Skip to content

Commit ebb7c73

Browse files
committed
Improve inference across for and try
1 parent 898f1a2 commit ebb7c73

3 files changed

Lines changed: 44 additions & 20 deletions

File tree

lib/elixir/lib/module/types/expr.ex

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ defmodule Module.Types.Expr do
411411

412412
{type, context} =
413413
if else_block do
414-
{type, context} = of_expr(body, @pending, body, stack, original)
414+
{type, context} = of_expr(body, term(), body, stack, original)
415415
info = {:try_else, meta, body, type}
416416
of_clauses(else_block, [type], expected, info, stack, context, none())
417417
else
@@ -425,13 +425,13 @@ defmodule Module.Types.Expr do
425425
Enum.reduce(clauses, acc_context, fn
426426
{:->, _, [[{:in, meta, [var, exceptions]} = expr], body]}, {acc, context} ->
427427
{type, context} =
428-
of_rescue(var, exceptions, body, expr, :rescue, meta, stack, context)
428+
of_rescue(var, exceptions, body, expr, expected, :rescue, meta, stack, context)
429429

430430
{union(type, acc), context}
431431

432432
{:->, meta, [[var], body]}, {acc, context} ->
433433
{type, context} =
434-
of_rescue(var, [], body, var, :anonymous_rescue, meta, stack, context)
434+
of_rescue(var, [], body, var, expected, :anonymous_rescue, meta, stack, context)
435435

436436
{union(type, acc), context}
437437
end)
@@ -494,28 +494,37 @@ defmodule Module.Types.Expr do
494494
of_clauses(block, args, expected, :for_reduce, stack, context, reduce_type)
495495
else
496496
# TODO: Use the collectable protocol for the output
497-
# TODO: Use the expected type for the block output
498497
into = Keyword.get(opts, :into, [])
499498
{into_type, into_kind, context} = for_into(into, meta, stack, context)
500-
{block_type, context} = of_expr(block, @pending, block, stack, context)
501499

502500
case into_kind do
503501
:bitstring ->
502+
{block_type, context} = of_expr(block, bitstring(), block, stack, context)
503+
504504
case compatible_intersection(block_type, bitstring()) do
505505
{:ok, intersection} ->
506-
{return_union(into_type, intersection, stack), context}
506+
{union(into_type, intersection), context}
507507

508508
{:error, _} ->
509509
error = {:badbitbody, block_type, block, context}
510510
{error_type(), error(__MODULE__, error, meta, stack, context)}
511511
end
512512

513513
:non_empty_list ->
514-
{return_union(into_type, non_empty_list(block_type), stack), context}
514+
expected =
515+
case list_hd(expected) do
516+
{:ok, head} -> head
517+
_ -> term()
518+
end
519+
520+
{block_type, context} = of_expr(block, expected, block, stack, context)
521+
{union(into_type, non_empty_list(block_type)), context}
515522

516523
:none ->
524+
{_, context} = of_expr(block, term(), block, stack, context)
517525
{into_type, context}
518526
end
527+
|> dynamic_unless_static(stack)
519528
end
520529
end)
521530
end
@@ -535,7 +544,7 @@ defmodule Module.Types.Expr do
535544

536545
# TODO: Perform inference based on the strong domain of a function
537546
{args_types, context} =
538-
Enum.map_reduce(args, context, &of_expr(&1, @pending, &1, stack, &2))
547+
Enum.map_reduce(args, context, &of_expr(&1, term(), &1, stack, &2))
539548

540549
Apply.fun(fun_type, args_types, call, stack, context)
541550
end
@@ -619,7 +628,7 @@ defmodule Module.Types.Expr do
619628

620629
## Try
621630

622-
defp of_rescue(var, exceptions, body, expr, info, meta, stack, original) do
631+
defp of_rescue(var, exceptions, body, expr, expected, info, meta, stack, original) do
623632
args = [__exception__: term()]
624633

625634
{structs, context} =
@@ -648,7 +657,7 @@ defmodule Module.Types.Expr do
648657
context
649658
end
650659

651-
{type, context} = of_expr(body, @pending, body, stack, context)
660+
{type, context} = of_expr(body, expected, body, stack, context)
652661
{type, Of.reset_vars(context, original)}
653662
end
654663

@@ -658,16 +667,16 @@ defmodule Module.Types.Expr do
658667
expr = {:<-, [type_check: :generator] ++ meta, [left, right]}
659668
{pattern, guards} = extract_head([left])
660669

670+
# TODO: Extract the type from enumerable protocol
661671
{_type, context} =
662-
Apply.remote(Enumerable, :count, [right], dynamic(), expr, stack, context, &of_expr/5)
672+
Apply.remote(Enumerable, :count, [right], term(), expr, stack, context, &of_expr/5)
663673

664674
Pattern.of_generator(pattern, guards, dynamic(), :for, expr, stack, context)
665675
end
666676

667677
defp for_clause({:<<>>, _, [{:<-, meta, [left, right]}]} = expr, stack, context) do
668678
{right_type, context} = of_expr(right, bitstring(), expr, stack, context)
669-
info = {:for, expr, dynamic()}
670-
context = Pattern.of_generator(left, [], bitstring(), info, expr, stack, context)
679+
context = Pattern.of_generator(left, [], bitstring(), :for, expr, stack, context)
671680

672681
if compatible?(right_type, bitstring()) do
673682
context
@@ -728,17 +737,12 @@ defmodule Module.Types.Expr do
728737
end
729738
end
730739

731-
defp return_union(left, right, stack) do
732-
Apply.return(union(left, right), [left, right], stack)
733-
end
734-
735740
## With
736741

737742
defp with_clause({:<-, _meta, [left, right]} = expr, stack, context) do
738743
{pattern, guards} = extract_head([left])
739744
{_type, context} = of_expr(right, @pending, right, stack, context)
740-
info = {:with, expr, dynamic()}
741-
Pattern.of_generator(pattern, guards, dynamic(), info, expr, stack, context)
745+
Pattern.of_generator(pattern, guards, dynamic(), :with, expr, stack, context)
742746
end
743747

744748
defp with_clause(expr, stack, context) do

lib/elixir/lib/module/types/pattern.ex

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,9 @@ defmodule Module.Types.Pattern do
323323
@doc """
324324
Handles matches in generators.
325325
"""
326-
def of_generator(pattern, guards, expected, tag, expr, stack, %{vars: vars} = context) do
326+
def of_generator(pattern, guards, expected, op, expr, stack, %{vars: vars} = context)
327+
when is_atom(op) do
328+
tag = {op, expr, expected}
327329
context = init_pattern_info(context, [])
328330

329331
{tree, _precise?, context} =

lib/elixir/test/elixir/module/types/expr_test.exs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2618,6 +2618,24 @@ defmodule Module.Types.ExprTest do
26182618
) == union(bitstring(), list(term()))
26192619
end
26202620

2621+
test ":into inference" do
2622+
assert typecheck!(
2623+
[x, y],
2624+
(
2625+
List.to_integer([_ | _] = for(_ <- x, do: y))
2626+
y
2627+
)
2628+
) == dynamic(integer())
2629+
2630+
assert typecheck!(
2631+
[x, y],
2632+
(
2633+
for(<<_ <- x>>, do: y, into: "")
2634+
y
2635+
)
2636+
) == dynamic(bitstring())
2637+
end
2638+
26212639
test ":into incompatibility" do
26222640
assert typeerror!([binary], for(<<x <- binary>>, do: x, into: "")) =~ ~l"""
26232641
expected the body of a for-comprehension with into: binary() (or bitstring()) to be a binary (or bitstring):

0 commit comments

Comments
 (0)