Skip to content

Commit b9c3cd3

Browse files
mameclaude
andcommitted
Isolate variable types across case/when branches
Variables modified in one when branch no longer leak into other branches. Each branch starts from the original variable state, and all branches are joined after the case statement, similar to how if/unless handles variable branching. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 720f290 commit b9c3cd3

2 files changed

Lines changed: 76 additions & 26 deletions

File tree

lib/typeprof/core/ast/control.rb

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -322,44 +322,74 @@ def install0(genv)
322322
ret = Vertex.new(self)
323323
@pivot&.install(genv)
324324

325-
# case文での型絞り込みを実装
326-
if @pivot && @pivot.is_a?(LocalVariableReadNode)
327-
var = @pivot.var
328-
original_vtx = @lenv.get_var(var)
325+
# Collect modified variables across all branches
326+
vars = []
327+
@when_nodes.each {|wn| wn.body.modified_vars(@lenv.locals.keys, vars) }
328+
@else_clause.modified_vars(@lenv.locals.keys, vars) if @else_clause
329+
vars.uniq!
329330

330-
# ダミー変数に元の型情報を設定
331-
@lenv.set_var(:"*pivot", original_vtx)
331+
# Save original variable vertices
332+
saved_vtxs = {}
333+
vars.each do |var|
334+
saved_vtxs[var] = @lenv.get_var(var)
335+
end
332336

333-
# 各when節を実行
334-
@when_nodes.each do |when_node|
335-
clause_result = when_node.install(genv)
336-
@changes.add_edge(genv, clause_result, ret)
337-
# 元の型に戻す
338-
@lenv.set_var(var, original_vtx)
339-
end
337+
# Prepare per-branch result vertices
338+
branch_vtxs = []
339+
340+
# Setup pivot narrowing if applicable
341+
pivot_var = @pivot.is_a?(LocalVariableReadNode) ? @pivot.var : nil
342+
if pivot_var
343+
original_pivot = @lenv.get_var(pivot_var)
344+
@lenv.set_var(:"*pivot", original_pivot)
345+
end
346+
347+
# Install each when branch
348+
@when_nodes.each do |when_node|
349+
# Reset variables to original for each branch
350+
saved_vtxs.each {|var, vtx| @lenv.set_var(var, vtx.new_vertex(genv, self)) }
351+
@lenv.set_var(pivot_var, original_pivot) if pivot_var
352+
353+
clause_val = when_node.install(genv)
354+
@changes.add_edge(genv, clause_val, ret)
340355

341-
# else節(他のwhen節で除外された後の型)
342-
filtered_else_vtx = original_vtx.new_vertex(genv, self)
356+
modified = {}
357+
vars.each {|var| modified[var] = @lenv.get_var(var) }
358+
branch_vtxs << [clause_val, modified]
359+
end
360+
361+
# Install else branch
362+
saved_vtxs.each {|var, vtx| @lenv.set_var(var, vtx.new_vertex(genv, self)) }
363+
if pivot_var
364+
# Apply exclusion filters for else
365+
filtered_else_vtx = original_pivot.new_vertex(genv, self)
343366
@when_nodes.each do |when_node|
344367
when_node.get_exclusion_conditions.each do |static_ret|
345-
# 各when節の型を除外(negation)
346368
filtered_else_vtx = IsAFilter.new(genv, self, filtered_else_vtx, true, static_ret).next_vtx
347369
end
348370
end
349-
@lenv.set_var(var, filtered_else_vtx)
350-
@changes.add_edge(genv, @else_clause.install(genv), ret)
351-
@lenv.set_var(var, original_vtx)
371+
@lenv.set_var(pivot_var, filtered_else_vtx)
372+
end
373+
else_val = @else_clause.install(genv)
374+
@changes.add_edge(genv, else_val, ret)
352375

353-
# ダミー変数をクリア
354-
@lenv.locals.delete(:"*pivot")
355-
else
356-
# pivotが変数でない場合は従来通り
357-
@when_nodes.each do |when_node|
358-
@changes.add_edge(genv, when_node.install(genv), ret)
376+
else_modified = {}
377+
vars.each {|var| else_modified[var] = @lenv.get_var(var) }
378+
branch_vtxs << [else_val, else_modified]
379+
380+
# Join all branches
381+
vars.each do |var|
382+
joined = Vertex.new(self)
383+
branch_vtxs.each do |branch_val, modified|
384+
vtx = BotFilter.new(genv, self, modified[var], branch_val).next_vtx
385+
@changes.add_edge(genv, vtx, joined)
359386
end
360-
@changes.add_edge(genv, @else_clause.install(genv), ret)
387+
@lenv.set_var(var, joined)
361388
end
362389

390+
# Cleanup
391+
@lenv.locals.delete(:"*pivot") if pivot_var
392+
363393
ret
364394
end
365395
end

scenario/flow/case_variable.rb

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
## update
2+
def test(type, val)
3+
case type
4+
when :int
5+
val = val.to_i
6+
when :sym
7+
val = val.to_sym
8+
else
9+
val = val
10+
end
11+
val
12+
end
13+
14+
test(:int, "42")
15+
test(:sym, "hello")
16+
17+
## assert
18+
class Object
19+
def test: (:int | :sym, String) -> (Integer | String | Symbol)
20+
end

0 commit comments

Comments
 (0)