Skip to content

Commit 730463d

Browse files
committed
subclass selective instrumentation impl
1 parent 18b6a8b commit 730463d

2 files changed

Lines changed: 18 additions & 25 deletions

File tree

traincheck/instrumentor/proxy_wrapper/subclass.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,6 @@ def update_timestamp(self):
166166

167167
def register_object(self):
168168
get_global_registry().add_var(self, self.__dict__["var_name"])
169-
# TODO: implement the registry, we will need to make sure the registerred timestamp is updated and is consistent with the timestamp in the object
170-
# pass
171169

172170
def dump_trace(self, phase, dump_loc):
173171
# print(f"parameter: {self.var_name}, phase = {phase}, dump_loc = {dump_loc}")

traincheck/instrumentor/tracer.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,8 @@ def function_wrapper(
132132
dump_args_config,
133133
dump_ret: bool,
134134
dump_ret_config,
135-
handle_proxy: bool,
136-
trigger_proxy_state_dump: bool,
137-
proxy_state_dump_config: dict,
135+
trigger_var_dump: bool,
136+
var_dump_config: dict,
138137
need_unproxy_args_kwargs: bool,
139138
*args,
140139
**kwargs,
@@ -176,7 +175,6 @@ def function_wrapper(
176175
pre_record["stack_trace"] = traceback.format_stack()
177176

178177
if scan_proxy_in_args:
179-
# TODO: can be optimized: use static or dynamic analysis to determine which args/kwargs to scan
180178
proxy_in_args = []
181179

182180
def find_proxy_in_args(args):
@@ -214,10 +212,10 @@ def find_proxy_in_args(args):
214212
pre_record["kwargs"] = dict_args_kwargs["kwargs"]
215213
dump_trace_API(pre_record)
216214

217-
if trigger_proxy_state_dump:
215+
if trigger_var_dump:
218216
"""Mimicking the behavior the observer wrapper: pre-observe"""
219217
get_global_registry().dump_modified(
220-
dump_loc=original_function_name, dump_config=proxy_state_dump_config
218+
dump_loc=original_function_name, dump_config=var_dump_config
221219
)
222220

223221
if need_unproxy_args_kwargs:
@@ -233,10 +231,10 @@ def find_proxy_in_args(args):
233231
if COLLECT_OVERHEAD_METRICS:
234232
ORIG_EXIT_PERF_TIME = time.perf_counter()
235233

236-
if handle_proxy and trigger_proxy_state_dump:
234+
if trigger_var_dump:
237235
"""Mimicking the behavior the observer wrapper: post-observe"""
238236
get_global_registry().dump_modified(
239-
dump_loc=original_function_name, dump_config=proxy_state_dump_config
237+
dump_loc=original_function_name, dump_config=var_dump_config
240238
)
241239

242240
dump_trace_API(
@@ -261,9 +259,9 @@ def find_proxy_in_args(args):
261259
)
262260
raise e
263261

264-
if handle_proxy and trigger_proxy_state_dump:
262+
if trigger_var_dump:
265263
get_global_registry().dump_modified(
266-
dump_loc=original_function_name, dump_config=proxy_state_dump_config
264+
dump_loc=original_function_name, dump_config=var_dump_config
267265
)
268266

269267
post_record = {
@@ -351,7 +349,7 @@ def find_proxy_in_args(args):
351349
return result
352350

353351

354-
def core_wrapper_proxy(original_function, is_builtin, handle_proxy, *args, **kwargs):
352+
def core_wrapper_proxy(original_function, *args, **kwargs):
355353
"""Core wrapper that only handles unproxying for built-in functions."""
356354
global DISABLE_WRAPPER
357355
if DISABLE_WRAPPER:
@@ -372,8 +370,8 @@ def wrapper(
372370
dump_ret=True,
373371
dump_ret_config=None,
374372
handle_proxy=True,
375-
trigger_proxy_state_dump=False,
376-
proxy_state_dump_config=None,
373+
trigger_var_dump=False,
374+
var_dump_config=None,
377375
):
378376
is_builtin = is_c_level_function(original_function)
379377
need_unproxy_args_kwargs = handle_proxy and (
@@ -404,25 +402,22 @@ def wrapped(*args, **kwargs):
404402
dump_args_config=dump_args_config,
405403
dump_ret=dump_ret,
406404
dump_ret_config=dump_ret_config,
407-
handle_proxy=handle_proxy,
408-
trigger_proxy_state_dump=trigger_proxy_state_dump,
409-
proxy_state_dump_config=proxy_state_dump_config,
405+
trigger_var_dump=trigger_var_dump,
406+
var_dump_config=var_dump_config,
410407
need_unproxy_args_kwargs=need_unproxy_args_kwargs,
411408
*args,
412409
**kwargs,
413410
)
414411

415412
else:
416413
METRIC_INSTRUMENTED_FUNC_LIST["no_dump"].append(original_function_name)
417-
if handle_proxy:
414+
if need_unproxy_args_kwargs:
418415

419416
@functools.wraps(original_function)
420417
def wrapped(*args, **kwargs):
421418
if increment_step:
422419
META_VARS["step"] += 1
423-
return core_wrapper_proxy(
424-
original_function, is_builtin, handle_proxy, *args, **kwargs
425-
)
420+
return core_wrapper_proxy(original_function, *args, **kwargs)
426421

427422
else:
428423
if increment_step:
@@ -817,7 +812,7 @@ def get_wrapped_function(self, func_obj: Callable) -> Callable:
817812
if self.instr_opts is not None
818813
else config.MODEL_TRACKER_STYLE
819814
)
820-
used_proxy = tracker_style == "proxy"
815+
used_proxy = tracker_style == "proxy" # TODO: refactor this:
821816
if self.instr_opts is None:
822817
# inference stage instrumentation
823818
return wrapper(
@@ -861,9 +856,9 @@ def get_wrapped_function(self, func_obj: Callable) -> Callable:
861856
else None
862857
),
863858
handle_proxy=used_proxy,
864-
trigger_proxy_state_dump=self.instr_opts.disable_proxy_dumping
859+
trigger_var_dump=self.instr_opts.disable_proxy_dumping
865860
and len(func_instr_opt["var_types_to_track"]) > 0,
866-
proxy_state_dump_config=func_instr_opt["var_types_to_track"],
861+
var_dump_config=func_instr_opt["var_types_to_track"],
867862
)
868863

869864
def _instrument_module(

0 commit comments

Comments
 (0)