Skip to content

Commit c7698ad

Browse files
committed
Reverse arrows for case, precision for _
1 parent 28a66b5 commit c7698ad

8 files changed

Lines changed: 110 additions & 41 deletions

File tree

lib/elixir/lib/module/types/descr.ex

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1810,7 +1810,7 @@ defmodule Module.Types.Descr do
18101810
cache = Map.put(cache, cache_key, false)
18111811
{false, cache}
18121812
else
1813-
{_index, result2, cache} =
1813+
{_index, result, cache} =
18141814
Enum.reduce_while(arguments, {0, true, cache}, fn
18151815
type, {index, acc_result, acc_cache} ->
18161816
{new_result, new_cache} =
@@ -1825,7 +1825,6 @@ defmodule Module.Types.Descr do
18251825
end
18261826
end)
18271827

1828-
result = result1 and result2
18291828
cache = Map.put(cache, cache_key, result)
18301829
{result, cache}
18311830
end

lib/elixir/lib/module/types/expr.ex

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ defmodule Module.Types.Expr do
137137
{@stacktrace, context}
138138
end
139139

140+
@dynamic_or_term_list [dynamic(), term()]
141+
140142
# left = right
141143
def of_expr({:=, _, [left_expr, right_expr]} = match, expected, expr, stack, context) do
142144
{left_expr, right_expr} = repack_match(left_expr, right_expr)
@@ -147,12 +149,22 @@ defmodule Module.Types.Expr do
147149
of_expr(right_expr, expected, expr, stack, context)
148150

149151
_ ->
150-
type_fun = fn pattern_type, context ->
151-
# See if we can use the expected type to further refine the pattern type,
152-
# if we cannot, use the pattern type as that will fail later on.
153-
{_ok_or_error, type} = compatible_intersection(dynamic(pattern_type), expected)
154-
of_expr(right_expr, type, expr, stack, context)
155-
end
152+
type_fun =
153+
fn pattern_type, context ->
154+
if expected in @dynamic_or_term_list do
155+
of_expr(right_expr, pattern_type, expr, stack, context)
156+
else
157+
# See if we can use the expected type to further refine the pattern type,
158+
# if we cannot, use the pattern type as that will fail later on.
159+
{_ok_or_error, type} = compatible_intersection(dynamic(pattern_type), expected)
160+
{result, context} = of_expr(right_expr, type, expr, stack, context)
161+
162+
# The function may still return a too broad type, so we refine once again
163+
# to assign the most appropriate one for reverse arrows.
164+
{_ok_or_error, result} = compatible_intersection(result, expected)
165+
{result, context}
166+
end
167+
end
156168

157169
Pattern.of_match(left_expr, type_fun, match, stack, context)
158170
end
@@ -311,9 +323,20 @@ defmodule Module.Types.Expr do
311323
|> dynamic_unless_static(stack)
312324
end
313325

314-
def of_expr({:case, meta, [case_expr, [{:do, clauses}]]}, expected, _expr, stack, context) do
315-
_ = Keyword.fetch!(meta, :version)
316-
{case_type, context} = of_expr(case_expr, @pending, case_expr, stack, context)
326+
def of_expr({:case, meta, [_case_expr, [{:do, _clauses}]]}, _expected, _expr, stack, context)
327+
when stack.reverse_arrow == :use do
328+
version = Keyword.fetch!(meta, :version)
329+
clauses = Map.fetch!(context.reverse_arrows, version)
330+
result = Enum.reduce(clauses, none(), &union(elem(&1, 1), &2))
331+
dynamic_unless_static({result, context}, stack)
332+
end
333+
334+
def of_expr({:case, meta, [case_expr, [{:do, clauses}]]}, expected, _expr, stack, base_context) do
335+
version = Keyword.fetch!(meta, :version)
336+
337+
{case_type, context} =
338+
of_expr(case_expr, @pending, case_expr, %{stack | reverse_arrow: :cache}, base_context)
339+
317340
info = {:case, meta, case_expr, case_type}
318341

319342
added_meta =
@@ -326,13 +349,35 @@ defmodule Module.Types.Expr do
326349
# If the expression is generated or the construct is a literal,
327350
# it is most likely a macro code. However, if no clause is matched,
328351
# we should still check for that.
329-
if added_meta != [] do
330-
for {:->, meta, args} <- clauses, do: {:->, [generated: true] ++ meta, args}
331-
else
332-
clauses
352+
clauses =
353+
if added_meta != [] do
354+
for {:->, meta, args} <- clauses, do: {:->, [generated: true] ++ meta, args}
355+
else
356+
clauses
357+
end
358+
359+
of_body = fn trees, body, context ->
360+
[arg_type] = Pattern.of_domain(trees, stack, context)
361+
362+
{_, context} =
363+
of_expr(case_expr, arg_type, case_expr, %{stack | reverse_arrow: :use}, context)
364+
365+
of_expr(body, expected, body, stack, context)
333366
end
334-
|> of_clauses([case_type], expected, info, stack, context, none())
335-
|> dynamic_unless_static(stack)
367+
368+
result_context =
369+
cache_arrows(version, stack, fn ->
370+
of_clauses_fun(clauses, [case_type], info, stack, context, of_body, [], fn
371+
trees, body_type, context, acc ->
372+
[arg_type] = Pattern.of_domain(trees, stack, context)
373+
[{arg_type, body_type} | acc]
374+
end)
375+
end) ||
376+
of_clauses_fun(clauses, [case_type], info, stack, context, of_body, none(), fn
377+
_trees, body_type, _context, acc -> union(acc, body_type)
378+
end)
379+
380+
dynamic_unless_static(result_context, stack)
336381
end
337382

338383
# fn pat -> expr end
@@ -341,11 +386,13 @@ defmodule Module.Types.Expr do
341386
{patterns, _guards} = extract_head(head)
342387
domain = Enum.map(patterns, fn _ -> dynamic() end)
343388

389+
of_body = fn _args_types, body, context -> of_expr(body, @pending, body, stack, context) end
390+
344391
{acc, context} =
345-
of_clauses_fun(clauses, domain, @pending, :fn, stack, context, [], fn
346-
trees, body, context, acc ->
392+
of_clauses_fun(clauses, domain, :fn, stack, context, of_body, [], fn
393+
trees, body_type, context, acc ->
347394
args_types = Pattern.of_domain(trees, stack, context)
348-
add_inferred(acc, args_types, body)
395+
add_inferred(acc, args_types, body_type)
349396
end)
350397

351398
{fun_from_inferred_clauses(acc), context}
@@ -725,12 +772,22 @@ defmodule Module.Types.Expr do
725772
defp dynamic_unless_static({_, _} = output, %{mode: :static}), do: output
726773
defp dynamic_unless_static({type, context}, %{mode: _}), do: {dynamic(type), context}
727774

775+
defp cache_arrows(_version, %{reverse_arrow: nil}, _fun), do: nil
776+
777+
defp cache_arrows(version, %{reverse_arrow: :cache}, fun) do
778+
{clauses, context} = fun.()
779+
context = put_in(context.reverse_arrows[version], clauses)
780+
result = Enum.reduce(clauses, none(), &union(elem(&1, 1), &2))
781+
{result, context}
782+
end
783+
728784
defp of_clauses(clauses, domain, expected, base_info, stack, context, acc) do
729-
fun = fn _args_types, result, _context, acc -> union(result, acc) end
730-
of_clauses_fun(clauses, domain, expected, base_info, stack, context, acc, fun)
785+
of_body = fn _args_types, body, context -> of_expr(body, expected, body, stack, context) end
786+
of_acc = fn _args_types, body_type, _context, acc -> union(acc, body_type) end
787+
of_clauses_fun(clauses, domain, base_info, stack, context, of_body, acc, of_acc)
731788
end
732789

733-
defp of_clauses_fun(clauses, domain, expected, base_info, stack, original, acc, fun) do
790+
defp of_clauses_fun(clauses, domain, base_info, stack, original, of_body, acc, of_acc) do
734791
%{failed: failed?} = original
735792

736793
{result, _previous, context} =
@@ -743,9 +800,9 @@ defmodule Module.Types.Expr do
743800
{trees, previous, context} =
744801
Pattern.of_head(patterns, guards, domain, previous, info, meta, stack, context)
745802

746-
{result, context} = of_expr(body, expected, body, stack, context)
803+
{result, context} = of_body.(trees, body, context)
747804

748-
{fun.(trees, result, context, acc), previous,
805+
{of_acc.(trees, result, context, acc), previous,
749806
context |> set_failed(failed?) |> Of.reset_vars(original)}
750807
end)
751808

lib/elixir/lib/module/types/pattern.ex

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -805,12 +805,7 @@ defmodule Module.Types.Pattern do
805805
end
806806
end
807807

808-
# _
809-
defp of_pattern({:_, _meta, _var_context}, _path, _stack, context) do
810-
{term(), true, context}
811-
end
812-
813-
# var
808+
# var (includes underscores)
814809
defp of_pattern({name, meta, ctx} = var, path, _stack, context)
815810
when is_atom(name) and is_atom(ctx) do
816811
version = Keyword.fetch!(meta, :version)

lib/elixir/lib/protocol.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ defmodule Protocol do
786786
end
787787

788788
defp fallback_clause_for(value, _protocol, meta) do
789-
{meta, [quote(do: _)], [], value}
789+
{meta, [{:_, [version: -1], __MODULE__}], [], value}
790790
end
791791

792792
# Finally compile the module and emit its bytecode.

lib/elixir/src/elixir_expand.erl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,13 +385,15 @@ expand({'^', Meta, [Arg]}, S, E) ->
385385
function_error(Meta, E, ?MODULE, {pin_outside_of_match, Arg}),
386386
{{'^', Meta, [Arg]}, S#elixir_ex{tainted_function=true}, E};
387387

388-
expand({'_', Meta, Kind} = Var, S, #{context := Context} = E) when is_atom(Kind) ->
388+
expand({'_', Meta, Kind}, #elixir_ex{version=Counter} = S, #{context := Context} = E) when is_atom(Kind) ->
389+
NewVar = {'_', [{version, Counter} | Meta], Kind},
390+
389391
case Context of
390392
match ->
391-
{Var, S, E};
393+
{NewVar, S#elixir_ex{version=Counter+1}, E};
392394
_ ->
393395
function_error(Meta, E, ?MODULE, unbound_underscore),
394-
{Var, S#elixir_ex{tainted_function=true}, E}
396+
{NewVar, S#elixir_ex{tainted_function=true, version=Counter+1}, E}
395397
end;
396398

397399
expand({Name, Meta, Kind}, S, #{context := match} = E) when is_atom(Name), is_atom(Kind) ->
@@ -798,7 +800,7 @@ expand_case(Meta, Expr, Opts, S, E) ->
798800
end,
799801

800802
{EOpts, #elixir_ex{version=Counter} = SO, EO} = elixir_clauses:'case'(Meta, ROpts, SE, EE),
801-
{{'case', [{version, Counter} | Meta], [EExpr, EOpts]}, SO#elixir_ex{version = Counter + 1}, EO}.
803+
{{'case', [{version, Counter} | Meta], [EExpr, EOpts]}, SO#elixir_ex{version=Counter+1}, EO}.
802804

803805
rewrite_case_clauses([{do, [
804806
{'->', FalseMeta, [

lib/elixir/test/elixir/kernel/expansion_test.exs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,8 @@ defmodule Kernel.ExpansionTest do
387387
[
388388
{:=, _, [var_ver(:x, 0), 0]},
389389
{:=, _, [_, var_ver(:x, 0)]},
390-
{:=, _, [var_ver(:x, 1), 1]},
391-
{:=, _, [_, var_ver(:x, 1)]}
390+
{:=, _, [var_ver(:x, 2), 1]},
391+
{:=, _, [_, var_ver(:x, 2)]}
392392
]} =
393393
expand_with_version(
394394
quote do
@@ -404,7 +404,7 @@ defmodule Kernel.ExpansionTest do
404404
{:=, _, [var_ver(:x, 0), 0]},
405405
{:fn, _, [{:->, _, [[var_ver(:x, 1)], {:=, _, [var_ver(:x, 2), 2]}]}]},
406406
{:=, _, [_, var_ver(:x, 0)]},
407-
{:=, _, [var_ver(:x, 3), 3]}
407+
{:=, _, [var_ver(:x, 4), 3]}
408408
]} =
409409
expand_with_version(
410410
quote do
@@ -420,7 +420,7 @@ defmodule Kernel.ExpansionTest do
420420
{:=, _, [var_ver(:x, 0), 0]},
421421
{:case, _, [:foo, [do: [{:->, _, [[var_ver(:x, 1)], var_ver(:x, 1)]}]]]},
422422
{:=, _, [_, var_ver(:x, 0)]},
423-
{:=, _, [var_ver(:x, 3), 2]}
423+
{:=, _, [var_ver(:x, 4), 2]}
424424
]} =
425425
expand_with_version(
426426
quote do

lib/elixir/test/elixir/module/types/expr_test.exs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1915,6 +1915,22 @@ defmodule Module.Types.ExprTest do
19151915
) == atom([:ok, nil])
19161916
end
19171917

1918+
test "refines expression type" do
1919+
assert typecheck!(
1920+
if x = System.get_env("HELLO") do
1921+
{:ok, x}
1922+
else
1923+
{:error, x}
1924+
end
1925+
) ==
1926+
dynamic(
1927+
union(
1928+
tuple([atom([:ok]), binary()]),
1929+
tuple([atom([:error]), atom([nil])])
1930+
)
1931+
)
1932+
end
1933+
19181934
test "and/or does not report on literals" do
19191935
assert typecheck!(false and true) == boolean()
19201936
assert typecheck!(false or true) == atom([true])

lib/elixir/test/elixir/module/types/integration_test.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ defmodule Module.Types.IntegrationTest do
338338
339339
previous clauses have already matched on the following types:
340340
341-
term(), integer()
341+
not integer(), integer()
342342
integer(), term()
343343
344344

0 commit comments

Comments
 (0)