@@ -132,9 +132,8 @@ def function_wrapper(
132132 dump_args_config ,
133133 dump_ret : bool ,
134134 dump_ret_config ,
135- handle_proxy : bool ,
136- trigger_proxy_state_dump : bool ,
137- proxy_state_dump_config : dict ,
135+ trigger_var_dump : bool ,
136+ var_dump_config : dict ,
138137 need_unproxy_args_kwargs : bool ,
139138 * args ,
140139 ** kwargs ,
@@ -176,7 +175,6 @@ def function_wrapper(
176175 pre_record ["stack_trace" ] = traceback .format_stack ()
177176
178177 if scan_proxy_in_args :
179- # TODO: can be optimized: use static or dynamic analysis to determine which args/kwargs to scan
180178 proxy_in_args = []
181179
182180 def find_proxy_in_args (args ):
@@ -214,10 +212,10 @@ def find_proxy_in_args(args):
214212 pre_record ["kwargs" ] = dict_args_kwargs ["kwargs" ]
215213 dump_trace_API (pre_record )
216214
217- if trigger_proxy_state_dump :
215+ if trigger_var_dump :
218216 """Mimicking the behavior the observer wrapper: pre-observe"""
219217 get_global_registry ().dump_modified (
220- dump_loc = original_function_name , dump_config = proxy_state_dump_config
218+ dump_loc = original_function_name , dump_config = var_dump_config
221219 )
222220
223221 if need_unproxy_args_kwargs :
@@ -233,10 +231,10 @@ def find_proxy_in_args(args):
233231 if COLLECT_OVERHEAD_METRICS :
234232 ORIG_EXIT_PERF_TIME = time .perf_counter ()
235233
236- if handle_proxy and trigger_proxy_state_dump :
234+ if trigger_var_dump :
237235 """Mimicking the behavior the observer wrapper: post-observe"""
238236 get_global_registry ().dump_modified (
239- dump_loc = original_function_name , dump_config = proxy_state_dump_config
237+ dump_loc = original_function_name , dump_config = var_dump_config
240238 )
241239
242240 dump_trace_API (
@@ -261,9 +259,9 @@ def find_proxy_in_args(args):
261259 )
262260 raise e
263261
264- if handle_proxy and trigger_proxy_state_dump :
262+ if trigger_var_dump :
265263 get_global_registry ().dump_modified (
266- dump_loc = original_function_name , dump_config = proxy_state_dump_config
264+ dump_loc = original_function_name , dump_config = var_dump_config
267265 )
268266
269267 post_record = {
@@ -351,7 +349,7 @@ def find_proxy_in_args(args):
351349 return result
352350
353351
354- def core_wrapper_proxy (original_function , is_builtin , handle_proxy , * args , ** kwargs ):
352+ def core_wrapper_proxy (original_function , * args , ** kwargs ):
355353 """Core wrapper that only handles unproxying for built-in functions."""
356354 global DISABLE_WRAPPER
357355 if DISABLE_WRAPPER :
@@ -372,8 +370,8 @@ def wrapper(
372370 dump_ret = True ,
373371 dump_ret_config = None ,
374372 handle_proxy = True ,
375- trigger_proxy_state_dump = False ,
376- proxy_state_dump_config = None ,
373+ trigger_var_dump = False ,
374+ var_dump_config = None ,
377375):
378376 is_builtin = is_c_level_function (original_function )
379377 need_unproxy_args_kwargs = handle_proxy and (
@@ -404,25 +402,22 @@ def wrapped(*args, **kwargs):
404402 dump_args_config = dump_args_config ,
405403 dump_ret = dump_ret ,
406404 dump_ret_config = dump_ret_config ,
407- handle_proxy = handle_proxy ,
408- trigger_proxy_state_dump = trigger_proxy_state_dump ,
409- proxy_state_dump_config = proxy_state_dump_config ,
405+ trigger_var_dump = trigger_var_dump ,
406+ var_dump_config = var_dump_config ,
410407 need_unproxy_args_kwargs = need_unproxy_args_kwargs ,
411408 * args ,
412409 ** kwargs ,
413410 )
414411
415412 else :
416413 METRIC_INSTRUMENTED_FUNC_LIST ["no_dump" ].append (original_function_name )
417- if handle_proxy :
414+ if need_unproxy_args_kwargs :
418415
419416 @functools .wraps (original_function )
420417 def wrapped (* args , ** kwargs ):
421418 if increment_step :
422419 META_VARS ["step" ] += 1
423- return core_wrapper_proxy (
424- original_function , is_builtin , handle_proxy , * args , ** kwargs
425- )
420+ return core_wrapper_proxy (original_function , * args , ** kwargs )
426421
427422 else :
428423 if increment_step :
@@ -817,7 +812,7 @@ def get_wrapped_function(self, func_obj: Callable) -> Callable:
817812 if self .instr_opts is not None
818813 else config .MODEL_TRACKER_STYLE
819814 )
820- used_proxy = tracker_style == "proxy"
815+ used_proxy = tracker_style == "proxy" # TODO: refactor this:
821816 if self .instr_opts is None :
822817 # inference stage instrumentation
823818 return wrapper (
@@ -861,9 +856,9 @@ def get_wrapped_function(self, func_obj: Callable) -> Callable:
861856 else None
862857 ),
863858 handle_proxy = used_proxy ,
864- trigger_proxy_state_dump = self .instr_opts .disable_proxy_dumping
859+ trigger_var_dump = self .instr_opts .disable_proxy_dumping
865860 and len (func_instr_opt ["var_types_to_track" ]) > 0 ,
866- proxy_state_dump_config = func_instr_opt ["var_types_to_track" ],
861+ var_dump_config = func_instr_opt ["var_types_to_track" ],
867862 )
868863
869864 def _instrument_module (
0 commit comments