@@ -137,6 +137,8 @@ defmodule Module.Types.Expr do
137137 { @ stacktrace , context }
138138 end
139139
140+ @ dynamic_or_term_list [ dynamic ( ) , term ( ) ]
141+
140142 # left = right
141143 def of_expr ( { := , _ , [ left_expr , right_expr ] } = match , expected , expr , stack , context ) do
142144 { left_expr , right_expr } = repack_match ( left_expr , right_expr )
@@ -147,12 +149,22 @@ defmodule Module.Types.Expr do
147149 of_expr ( right_expr , expected , expr , stack , context )
148150
149151 _ ->
150- type_fun = fn pattern_type , context ->
151- # See if we can use the expected type to further refine the pattern type,
152- # if we cannot, use the pattern type as that will fail later on.
153- { _ok_or_error , type } = compatible_intersection ( dynamic ( pattern_type ) , expected )
154- of_expr ( right_expr , type , expr , stack , context )
155- end
152+ type_fun =
153+ fn pattern_type , context ->
154+ if expected in @ dynamic_or_term_list do
155+ of_expr ( right_expr , pattern_type , expr , stack , context )
156+ else
157+ # See if we can use the expected type to further refine the pattern type,
158+ # if we cannot, use the pattern type as that will fail later on.
159+ { _ok_or_error , type } = compatible_intersection ( dynamic ( pattern_type ) , expected )
160+ { result , context } = of_expr ( right_expr , type , expr , stack , context )
161+
162+ # The function may still return a too broad type, so we refine once again
163+ # to assign the most appropriate one for reverse arrows.
164+ { _ok_or_error , result } = compatible_intersection ( result , expected )
165+ { result , context }
166+ end
167+ end
156168
157169 Pattern . of_match ( left_expr , type_fun , match , stack , context )
158170 end
@@ -311,9 +323,20 @@ defmodule Module.Types.Expr do
311323 |> dynamic_unless_static ( stack )
312324 end
313325
314- def of_expr ( { :case , meta , [ case_expr , [ { :do , clauses } ] ] } , expected , _expr , stack , context ) do
315- _ = Keyword . fetch! ( meta , :version )
316- { case_type , context } = of_expr ( case_expr , @ pending , case_expr , stack , context )
326+ def of_expr ( { :case , meta , [ _case_expr , [ { :do , _clauses } ] ] } , _expected , _expr , stack , context )
327+ when stack . reverse_arrow == :use do
328+ version = Keyword . fetch! ( meta , :version )
329+ clauses = Map . fetch! ( context . reverse_arrows , version )
330+ result = Enum . reduce ( clauses , none ( ) , & union ( elem ( & 1 , 1 ) , & 2 ) )
331+ dynamic_unless_static ( { result , context } , stack )
332+ end
333+
334+ def of_expr ( { :case , meta , [ case_expr , [ { :do , clauses } ] ] } , expected , _expr , stack , base_context ) do
335+ version = Keyword . fetch! ( meta , :version )
336+
337+ { case_type , context } =
338+ of_expr ( case_expr , @ pending , case_expr , % { stack | reverse_arrow: :cache } , base_context )
339+
317340 info = { :case , meta , case_expr , case_type }
318341
319342 added_meta =
@@ -326,13 +349,35 @@ defmodule Module.Types.Expr do
326349 # If the expression is generated or the construct is a literal,
327350 # it is most likely a macro code. However, if no clause is matched,
328351 # we should still check for that.
329- if added_meta != [ ] do
330- for { :-> , meta , args } <- clauses , do: { :-> , [ generated: true ] ++ meta , args }
331- else
332- clauses
352+ clauses =
353+ if added_meta != [ ] do
354+ for { :-> , meta , args } <- clauses , do: { :-> , [ generated: true ] ++ meta , args }
355+ else
356+ clauses
357+ end
358+
359+ of_body = fn trees , body , context ->
360+ [ arg_type ] = Pattern . of_domain ( trees , stack , context )
361+
362+ { _ , context } =
363+ of_expr ( case_expr , arg_type , case_expr , % { stack | reverse_arrow: :use } , context )
364+
365+ of_expr ( body , expected , body , stack , context )
333366 end
334- |> of_clauses ( [ case_type ] , expected , info , stack , context , none ( ) )
335- |> dynamic_unless_static ( stack )
367+
368+ result_context =
369+ cache_arrows ( version , stack , fn ->
370+ of_clauses_fun ( clauses , [ case_type ] , info , stack , context , of_body , [ ] , fn
371+ trees , body_type , context , acc ->
372+ [ arg_type ] = Pattern . of_domain ( trees , stack , context )
373+ [ { arg_type , body_type } | acc ]
374+ end )
375+ end ) ||
376+ of_clauses_fun ( clauses , [ case_type ] , info , stack , context , of_body , none ( ) , fn
377+ _trees , body_type , _context , acc -> union ( acc , body_type )
378+ end )
379+
380+ dynamic_unless_static ( result_context , stack )
336381 end
337382
338383 # fn pat -> expr end
@@ -341,11 +386,13 @@ defmodule Module.Types.Expr do
341386 { patterns , _guards } = extract_head ( head )
342387 domain = Enum . map ( patterns , fn _ -> dynamic ( ) end )
343388
389+ of_body = fn _args_types , body , context -> of_expr ( body , @ pending , body , stack , context ) end
390+
344391 { acc , context } =
345- of_clauses_fun ( clauses , domain , @ pending , :fn , stack , context , [ ] , fn
346- trees , body , context , acc ->
392+ of_clauses_fun ( clauses , domain , :fn , stack , context , of_body , [ ] , fn
393+ trees , body_type , context , acc ->
347394 args_types = Pattern . of_domain ( trees , stack , context )
348- add_inferred ( acc , args_types , body )
395+ add_inferred ( acc , args_types , body_type )
349396 end )
350397
351398 { fun_from_inferred_clauses ( acc ) , context }
@@ -725,12 +772,22 @@ defmodule Module.Types.Expr do
725772 defp dynamic_unless_static ( { _ , _ } = output , % { mode: :static } ) , do: output
726773 defp dynamic_unless_static ( { type , context } , % { mode: _ } ) , do: { dynamic ( type ) , context }
727774
775+ defp cache_arrows ( _version , % { reverse_arrow: nil } , _fun ) , do: nil
776+
777+ defp cache_arrows ( version , % { reverse_arrow: :cache } , fun ) do
778+ { clauses , context } = fun . ( )
779+ context = put_in ( context . reverse_arrows [ version ] , clauses )
780+ result = Enum . reduce ( clauses , none ( ) , & union ( elem ( & 1 , 1 ) , & 2 ) )
781+ { result , context }
782+ end
783+
728784 defp of_clauses ( clauses , domain , expected , base_info , stack , context , acc ) do
729- fun = fn _args_types , result , _context , acc -> union ( result , acc ) end
730- of_clauses_fun ( clauses , domain , expected , base_info , stack , context , acc , fun )
785+ of_body = fn _args_types , body , context -> of_expr ( body , expected , body , stack , context ) end
786+ of_acc = fn _args_types , body_type , _context , acc -> union ( acc , body_type ) end
787+ of_clauses_fun ( clauses , domain , base_info , stack , context , of_body , acc , of_acc )
731788 end
732789
733- defp of_clauses_fun ( clauses , domain , expected , base_info , stack , original , acc , fun ) do
790+ defp of_clauses_fun ( clauses , domain , base_info , stack , original , of_body , acc , of_acc ) do
734791 % { failed: failed? } = original
735792
736793 { result , _previous , context } =
@@ -743,9 +800,9 @@ defmodule Module.Types.Expr do
743800 { trees , previous , context } =
744801 Pattern . of_head ( patterns , guards , domain , previous , info , meta , stack , context )
745802
746- { result , context } = of_expr ( body , expected , body , stack , context )
803+ { result , context } = of_body . ( trees , body , context )
747804
748- { fun . ( trees , result , context , acc ) , previous ,
805+ { of_acc . ( trees , result , context , acc ) , previous ,
749806 context |> set_failed ( failed? ) |> Of . reset_vars ( original ) }
750807 end )
751808
0 commit comments