Skip to content

Commit 80768af

Browse files
pvcresincodex
authored andcommitted
Support direct argument forwarding
Co-authored-by: OpenAI Codex <codex@openai.com>
1 parent 9adbaa7 commit 80768af

7 files changed

Lines changed: 225 additions & 39 deletions

File tree

lib/typeprof/core/ast/call.rb

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module TypeProf::Core
22
class AST
33
class CallBaseNode < Node
4-
def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_block, lenv)
4+
def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_block, lenv, forwarding_arguments: false)
55
super(raw_node, lenv)
66

77
@recv = recv
@@ -20,6 +20,7 @@ def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_bloc
2020
@block_body = nil
2121
@safe_navigation = raw_node.respond_to?(:safe_navigation?) && raw_node.safe_navigation?
2222
@anonymous_block_forwarding = false
23+
@forwarding_arguments = forwarding_arguments
2324

2425
if raw_args
2526
args = []
@@ -30,7 +31,7 @@ def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_bloc
3031
args << raw_arg.expression
3132
@splat_flags << true
3233
when Prism::ForwardingArgumentsNode
33-
# TODO: Support forwarding arguments
34+
@forwarding_arguments = true
3435
else
3536
args << raw_arg
3637
@splat_flags << false
@@ -98,10 +99,10 @@ def initialize(raw_node, recv, mid, mid_code_range, raw_args, last_arg, raw_bloc
9899
attr_reader :positional_args, :splat_flags, :keyword_args
99100
attr_reader :block_tbl, :block_f_args, :block_opt_positional_defaults, :block_body, :block_pass, :anonymous_block_forwarding
100101
attr_reader :block_multi_targets
101-
attr_reader :safe_navigation
102+
attr_reader :safe_navigation, :forwarding_arguments
102103

103104
def subnodes = { recv:, positional_args:, keyword_args:, block_opt_positional_defaults:, block_body:, block_pass: }
104-
def attrs = { mid:, splat_flags:, block_tbl:, block_f_args:, yield:, safe_navigation:, anonymous_block_forwarding: }
105+
def attrs = { mid:, splat_flags:, block_tbl:, block_f_args:, yield:, safe_navigation:, anonymous_block_forwarding:, forwarding_arguments: }
105106

106107
def install0(genv)
107108
recv = @recv ? @recv.install(genv) : @yield ? @lenv.get_var(:"*given_block") : @lenv.get_var(:"*self")
@@ -111,22 +112,29 @@ def install0(genv)
111112
recv = NilFilter.new(genv, self, recv, false).next_vtx
112113
end
113114

114-
positional_args = @positional_args.map do |arg|
115-
if arg.is_a?(DummyNilNode)
116-
@lenv.get_var(:"*anonymous_rest")
117-
else
118-
arg.install(genv)
115+
if @forwarding_arguments
116+
forward_a_args = (@lenv.forward_args || raise).to_actual_arguments(genv, @changes, self)
117+
positional_args = forward_a_args.positionals
118+
splat_flags = forward_a_args.splat_flags
119+
keyword_args = forward_a_args.keywords
120+
else
121+
positional_args = @positional_args.map do |arg|
122+
if arg.is_a?(DummyNilNode)
123+
@lenv.get_var(:"*anonymous_rest")
124+
else
125+
arg.install(genv)
126+
end
119127
end
128+
splat_flags = @splat_flags
129+
keyword_args = @keyword_args ? @keyword_args.install(genv) : nil
120130
end
121131

122-
keyword_args = @keyword_args ? @keyword_args.install(genv) : nil
123-
124132
if @block_body
125133
block_body = @block_body # kinda type annotationty
126134
block_tbl = @block_tbl || raise
127135
@lenv.locals.each {|var, vtx| block_body.lenv.locals[var] = vtx }
128136
block_tbl.each {|var| block_body.lenv.locals[var] = Source.new(genv.nil_type) }
129-
@block_body.lenv.locals[:"*self"] = @block_body.lenv.cref.get_self(genv)
137+
block_body.lenv.locals[:"*self"] = block_body.lenv.cref.get_self(genv)
130138

131139
blk_f_args = []
132140
if @block_f_args
@@ -156,7 +164,7 @@ def install0(genv)
156164
block_body.lenv.set_var(var, vtx)
157165
end
158166
vars = []
159-
@block_body.modified_vars(@lenv.locals.keys - block_tbl, vars)
167+
block_body.modified_vars(@lenv.locals.keys - block_tbl, vars)
160168
vars.uniq!
161169
vars.each do |var|
162170
vtx = @lenv.get_var(var)
@@ -165,9 +173,9 @@ def install0(genv)
165173
block_body.lenv.set_var(var, nvtx)
166174
end
167175

168-
@block_body.lenv.locals[:"*expected_block_ret"] = Vertex.new(self)
169-
@block_body.install(genv)
170-
@block_body.lenv.add_next_box(@changes.add_escape_box(genv, @block_body.ret))
176+
block_body.lenv.locals[:"*expected_block_ret"] = Vertex.new(self)
177+
block_body.install(genv)
178+
block_body.lenv.add_next_box(@changes.add_escape_box(genv, block_body.ret))
171179

172180
vars.each do |var|
173181
@changes.add_edge(genv, block_body.lenv.get_var(var), @lenv.get_var(var))
@@ -179,15 +187,17 @@ def install0(genv)
179187
elem_vtx = @changes.add_splat_box(genv, blk_f_ary_arg, i).ret
180188
@changes.add_edge(genv, elem_vtx, f_arg)
181189
end
182-
block = Block.new(self, blk_f_ary_arg, blk_f_args, @block_body.lenv.next_boxes)
190+
block = Block.new(self, blk_f_ary_arg, blk_f_args, block_body.lenv.next_boxes)
183191
blk_ty = Source.new(Type::Proc.new(genv, block))
184192
elsif @block_pass
185193
blk_ty = @block_pass.install(genv)
186194
elsif @anonymous_block_forwarding
187195
blk_ty = @lenv.get_var(:"*anonymous_block")
196+
elsif @forwarding_arguments
197+
blk_ty = forward_a_args.block
188198
end
189199

190-
a_args = ActualArguments.new(positional_args, @splat_flags, keyword_args, blk_ty)
200+
a_args = ActualArguments.new(positional_args, splat_flags, keyword_args, blk_ty)
191201
box = @changes.add_method_call_box(genv, recv, @mid, a_args, !@recv)
192202

193203
block_body = @block_body

lib/typeprof/core/ast/method.rb

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,23 @@ def install0(genv)
271271
block = @body.lenv.new_var(:"*given_block", self)
272272
end
273273

274+
forward_opt_positionals = @opt_positionals.map do
275+
elem_vtx = Vertex.new(self)
276+
[Source.new(genv.gen_ary_type(elem_vtx)), elem_vtx]
277+
end
278+
forward_opt_keywords = @opt_keywords.map {|_name| Vertex.new(self) }
279+
@body.lenv.forward_args = ForwardingArguments.new(
280+
req_positionals,
281+
forward_opt_positionals.map(&:first),
282+
forward_opt_positionals.map(&:last),
283+
rest_positionals,
284+
post_positionals,
285+
@req_keywords.zip(req_keywords),
286+
@opt_keywords.zip(forward_opt_keywords),
287+
rest_keywords,
288+
block,
289+
)
290+
274291
if @body
275292
@body.lenv.locals[:"*expected_method_ret"] = Vertex.new(self)
276293
@body.install(genv)

lib/typeprof/core/env.rb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def code_units_cache
322322
end
323323

324324
class LocalEnv
325-
def initialize(file_context, cref, locals, return_boxes)
325+
def initialize(file_context, cref, locals, return_boxes, forward_args = nil)
326326
@file_context = file_context
327327
@cref = cref
328328
@locals = locals
@@ -331,10 +331,11 @@ def initialize(file_context, cref, locals, return_boxes)
331331
@next_boxes = []
332332
@ivar_narrowings = {}
333333
@strict_const_scope = false
334+
@forward_args = forward_args
334335
end
335336

336337
attr_reader :file_context, :cref, :locals, :return_boxes, :break_vtx, :next_boxes, :strict_const_scope
337-
attr_accessor :module_function
338+
attr_accessor :module_function, :forward_args
338339

339340
def path = @file_context&.path
340341
def code_range_from_node(node)
@@ -369,7 +370,6 @@ def get_break_vtx
369370
@break_vtx ||= Vertex.new(:break_vtx)
370371
end
371372

372-
373373
def push_ivar_narrowing(name, narrowing)
374374
raise unless narrowing.is_a?(Narrowing::Constraint)
375375
(@ivar_narrowings[name] ||= []) << narrowing

lib/typeprof/core/env/method.rb

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,162 @@ def get_keyword_arg(genv, changes, name)
9595
end
9696
end
9797

98+
class ForwardingArguments
99+
def initialize(req_positionals, opt_positionals, opt_positional_elems, rest_positionals, post_positionals, req_keyword_pairs, opt_keyword_pairs, rest_keywords, block)
100+
@req_positionals = req_positionals
101+
@opt_positionals = opt_positionals
102+
@opt_positional_elems = opt_positional_elems
103+
@rest_positionals = rest_positionals
104+
@post_positionals = post_positionals
105+
@req_keyword_pairs = req_keyword_pairs
106+
@opt_keyword_pairs = opt_keyword_pairs
107+
@rest_keywords = rest_keywords
108+
@block = block
109+
end
110+
111+
attr_reader :block
112+
113+
def to_actual_arguments(genv, changes, node)
114+
positionals = @req_positionals.dup
115+
splat_flags = ::Array.new(positionals.size, false)
116+
117+
@opt_positionals.each do |arg|
118+
positionals << arg
119+
splat_flags << true
120+
end
121+
122+
if @rest_positionals
123+
positionals << @rest_positionals
124+
splat_flags << true
125+
end
126+
127+
@post_positionals.each do |arg|
128+
positionals << arg
129+
splat_flags << false
130+
end
131+
132+
keywords = build_keyword_args(genv, changes, node)
133+
ActualArguments.new(positionals, splat_flags, keywords, @block)
134+
end
135+
136+
def accept_actual_arguments(genv, changes, a_args)
137+
if a_args.splat_flags.any?
138+
start_rest = [a_args.splat_flags.index(true), @req_positionals.size + @opt_positionals.size].min
139+
end_rest = [a_args.splat_flags.rindex(true) + 1, a_args.positionals.size - @post_positionals.size].max
140+
rest_vtxs = a_args.get_rest_args(genv, changes, start_rest, end_rest)
141+
142+
@req_positionals.each_with_index do |f_vtx, i|
143+
if i < start_rest
144+
changes.add_edge(genv, a_args.positionals[i], f_vtx)
145+
else
146+
rest_vtxs.each do |vtx|
147+
changes.add_edge(genv, vtx, f_vtx)
148+
end
149+
end
150+
end
151+
152+
@opt_positional_elems.each_with_index do |elem_vtx, i|
153+
i += @req_positionals.size
154+
if i < start_rest
155+
changes.add_edge(genv, a_args.positionals[i], elem_vtx)
156+
else
157+
rest_vtxs.each do |vtx|
158+
changes.add_edge(genv, vtx, elem_vtx)
159+
end
160+
end
161+
end
162+
163+
@post_positionals.each_with_index do |f_vtx, i|
164+
i += a_args.positionals.size - @post_positionals.size
165+
if end_rest <= i
166+
changes.add_edge(genv, a_args.positionals[i], f_vtx)
167+
else
168+
rest_vtxs.each do |vtx|
169+
changes.add_edge(genv, vtx, f_vtx)
170+
end
171+
end
172+
end
173+
174+
else
175+
@req_positionals.each_with_index do |f_vtx, i|
176+
changes.add_edge(genv, a_args.positionals[i], f_vtx)
177+
end
178+
179+
@post_positionals.each_with_index do |f_vtx, i|
180+
i -= @post_positionals.size
181+
changes.add_edge(genv, a_args.positionals[i], f_vtx)
182+
end
183+
184+
start_rest = @req_positionals.size
185+
end_rest = a_args.positionals.size - @post_positionals.size
186+
i = 0
187+
while i < @opt_positional_elems.size && start_rest < end_rest
188+
changes.add_edge(genv, a_args.positionals[start_rest], @opt_positional_elems[i])
189+
i += 1
190+
start_rest += 1
191+
end
192+
end
193+
194+
changes.add_edge(genv, a_args.block, @block) if @block && a_args.block
195+
196+
return unless a_args.keywords
197+
198+
@req_keyword_pairs.each do |name, f_vtx|
199+
changes.add_edge(genv, a_args.get_keyword_arg(genv, changes, name), f_vtx)
200+
end
201+
202+
@opt_keyword_pairs.each do |name, f_vtx|
203+
changes.add_edge(genv, a_args.get_keyword_arg(genv, changes, name), f_vtx)
204+
end
205+
206+
if @rest_keywords
207+
named_keys = @req_keyword_pairs.map(&:first) + @opt_keyword_pairs.map(&:first)
208+
a_args.keywords.each_type do |kw_ty|
209+
case kw_ty
210+
when Type::Record
211+
rest_fields = kw_ty.fields.reject {|key, _| named_keys.include?(key) }
212+
base = kw_ty.base_type(genv)
213+
rest_record = Type::Record.new(genv, rest_fields, base)
214+
changes.add_edge(genv, Source.new(rest_record), @rest_keywords)
215+
when Type::Hash, Type::Instance
216+
changes.add_edge(genv, Source.new(kw_ty), @rest_keywords)
217+
end
218+
end
219+
end
220+
end
221+
222+
private
223+
224+
def build_keyword_args(genv, changes, node)
225+
return nil if @req_keyword_pairs.empty? && @opt_keyword_pairs.empty? && !@rest_keywords
226+
return @rest_keywords if @req_keyword_pairs.empty? && @opt_keyword_pairs.empty?
227+
228+
unified_key = Vertex.new(node)
229+
unified_val = Vertex.new(node)
230+
literal_pairs = {}
231+
232+
@req_keyword_pairs.each do |name, vtx|
233+
changes.add_edge(genv, Source.new(Type::Symbol.new(genv, name)), unified_key)
234+
changes.add_edge(genv, vtx, unified_val)
235+
literal_pairs[name] = vtx
236+
end
237+
238+
@opt_keyword_pairs.each do |name, vtx|
239+
changes.add_edge(genv, Source.new(Type::Symbol.new(genv, name)), unified_key)
240+
changes.add_edge(genv, vtx, unified_val)
241+
end
242+
243+
base_hash_type = genv.gen_hash_type(unified_key, unified_val)
244+
changes.add_hash_splat_box(genv, @rest_keywords, unified_key, unified_val) if @rest_keywords
245+
246+
if literal_pairs.empty?
247+
Source.new(base_hash_type)
248+
else
249+
Source.new(Type::Record.new(genv, literal_pairs, base_hash_type))
250+
end
251+
end
252+
end
253+
98254
class Block
99255
#: (AST::CallBaseNode, Vertex, Array[Vertex], Array[EscapeBox]) -> void
100256
def initialize(node, f_ary_arg, f_args, next_boxes)

lib/typeprof/core/graph/box.rb

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -768,8 +768,6 @@ def run0(genv, changes)
768768
end
769769

770770
def pass_arguments(changes, genv, a_args)
771-
a_args = normalize_keyword_hash_argument_for_def(a_args)
772-
773771
if a_args.splat_flags.any?
774772
# there is at least one splat actual argument
775773

@@ -906,7 +904,11 @@ def normalize_keyword_hash_argument_for_def(a_args)
906904
end
907905

908906
def call(changes, genv, a_args, ret)
907+
a_args = normalize_keyword_hash_argument_for_def(a_args)
909908
if pass_arguments(changes, genv, a_args)
909+
if @node.is_a?(AST::DefNode)
910+
@node.body.lenv.forward_args&.accept_actual_arguments(genv, changes, a_args)
911+
end
910912
changes.add_edge(genv, a_args.block, @f_args.block) if @f_args.block && a_args.block
911913

912914
changes.add_edge(genv, @ret, ret)

scenario/args/forwarding_arguments.rb

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,20 @@ def foo(a, ...)
1919
class Object
2020
def foo: (Integer, *untyped, **untyped) -> Integer
2121
end
22+
23+
## update
24+
def foo(...)
25+
bar(...)
26+
end
27+
28+
def bar(*a, **b)
29+
[a, b]
30+
end
31+
32+
foo(1, x: 4, y: 5)
33+
34+
## assert
35+
class Object
36+
def foo: (*Integer, **Integer) -> [Array[Integer], { x: Integer, y: Integer }]
37+
def bar: (*Integer, **Integer) -> [Array[Integer], { x: Integer, y: Integer }]
38+
end

scenario/known-issues/forwarding-arguments.rb

Lines changed: 0 additions & 16 deletions
This file was deleted.

0 commit comments

Comments
 (0)