Skip to content

Commit 060872d

Browse files
committed
fix instrumentation logic to get the parent class of a method definition
1 parent 730463d commit 060872d

1 file changed

Lines changed: 21 additions & 5 deletions

File tree

traincheck/instrumentor/tracer.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,23 @@ def get_meta_vars() -> dict:
8383
return META_VARS
8484

8585

86+
def get_owner_class(func):
87+
# Works for unbound functions defined on a class.
88+
qualname = getattr(func, "__qualname__", "")
89+
if "." not in qualname:
90+
return None # not a class method
91+
owner_path = qualname.rsplit(".", 1)[0] # e.g., "Optimizer"
92+
mod = inspect.getmodule(func)
93+
if mod is None:
94+
mod = importlib.import_module(func.__module__)
95+
owner = mod
96+
for part in owner_path.split("."):
97+
owner = getattr(owner, part, None)
98+
if owner is None:
99+
return None
100+
return owner
101+
102+
86103
def to_dict_args_kwargs(args, kwargs, dump_args_config=None) -> dict:
87104
global DISABLE_WRAPPER
88105
DISABLE_WRAPPER = True
@@ -379,11 +396,10 @@ def wrapper(
379396
)
380397
original_function_name = typename(original_function)
381398
increment_step = False
382-
if original_function_name.endswith(".step") and isinstance(
383-
original_function.__self__, torch.optim.Optimizer
384-
):
385-
increment_step = True
386-
399+
if original_function_name.endswith(".step"):
400+
owner = get_owner_class(original_function)
401+
if isinstance(owner, torch.optim.Optimizer):
402+
increment_step = True
387403
# determine statically whether to dump the trace
388404
if not disable_dump:
389405
METRIC_INSTRUMENTED_FUNC_LIST["dump"].append(original_function_name)

0 commit comments

Comments
 (0)