@@ -83,6 +83,23 @@ def get_meta_vars() -> dict:
8383 return META_VARS
8484
8585
86+ def get_owner_class (func ):
87+ # Works for unbound functions defined on a class.
88+ qualname = getattr (func , "__qualname__" , "" )
89+ if "." not in qualname :
90+ return None # not a class method
91+ owner_path = qualname .rsplit ("." , 1 )[0 ] # e.g., "Optimizer"
92+ mod = inspect .getmodule (func )
93+ if mod is None :
94+ mod = importlib .import_module (func .__module__ )
95+ owner = mod
96+ for part in owner_path .split ("." ):
97+ owner = getattr (owner , part , None )
98+ if owner is None :
99+ return None
100+ return owner
101+
102+
86103def to_dict_args_kwargs (args , kwargs , dump_args_config = None ) -> dict :
87104 global DISABLE_WRAPPER
88105 DISABLE_WRAPPER = True
@@ -379,11 +396,10 @@ def wrapper(
379396 )
380397 original_function_name = typename (original_function )
381398 increment_step = False
382- if original_function_name .endswith (".step" ) and isinstance (
383- original_function .__self__ , torch .optim .Optimizer
384- ):
385- increment_step = True
386-
399+ if original_function_name .endswith (".step" ):
400+ owner = get_owner_class (original_function )
401+ if isinstance (owner , torch .optim .Optimizer ):
402+ increment_step = True
387403 # determine statically whether to dump the trace
388404 if not disable_dump :
389405 METRIC_INSTRUMENTED_FUNC_LIST ["dump" ].append (original_function_name )
0 commit comments