Skip to content

Commit ac5202f

Browse files
committed
fix: unify registry implementation for proxy and subclass
1 parent 2034dfd commit ac5202f

2 files changed

Lines changed: 25 additions & 18 deletions

File tree

traincheck/instrumentor/proxy_wrapper/proxy_registry.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
import threading
2-
import typing
3-
4-
from traincheck.utils import typename
5-
6-
if typing.TYPE_CHECKING:
7-
from .proxy import Proxy
82

93

104
class RegistryEntry:
11-
"""A class to store the proxy object and its associated metadata"""
5+
"""A class to store the tracked object and its associated metadata"""
126

13-
def __init__(self, proxy: "Proxy", stale: bool):
14-
self.proxy = proxy
7+
def __init__(self, obj, var_name, var_type, stale):
8+
self.var = obj
9+
self.var_name = var_name
10+
self.var_type = var_type
1511
self.stale = stale
1612

1713

@@ -30,14 +26,18 @@ def __init__(self):
3026
self.registry: dict[str, RegistryEntry] = {}
3127
self.registry_lock = threading.Lock()
3228

33-
def add_var(self, var: "Proxy", var_name: str):
29+
def add_var(self, var, var_name: str, var_type: str):
3430
"""Add a new proxy variable to the registry"""
3531
with self.registry_lock:
3632
if var_name in self.registry:
37-
self.registry[var_name].proxy = var
33+
self.registry[var_name].var = var
34+
self.registry[var_name].var_name = var_name
35+
self.registry[var_name].var_type = var_type
3836
self.registry[var_name].stale = False
3937
else:
40-
self.registry[var_name] = RegistryEntry(proxy=var, stale=False)
38+
self.registry[var_name] = RegistryEntry(
39+
var, var_name, var_type, stale=False
40+
)
4141

4242
def dump_sample(self, dump_loc=None):
4343
"""A complete dump of all present proxy objects
@@ -48,7 +48,7 @@ def dump_sample(self, dump_loc=None):
4848
with self.registry_lock:
4949
for _, entry in self.registry.items():
5050
entry.stale = True
51-
entry.proxy.dump_trace(phase="sample", dump_loc=dump_loc)
51+
entry.var.dump_trace(phase="sample", dump_loc=dump_loc)
5252

5353
def dump_modified(self, dump_loc=None, dump_config=None):
5454
"""Dump only the proxy variables that might be modified since last dump
@@ -73,16 +73,16 @@ def dump_modified(self, dump_loc=None, dump_config=None):
7373
"""
7474
to_dump_types = set(dump_config.keys())
7575
with self.registry_lock:
76-
for var_name, entry in self.registry.items():
77-
var_type = typename(entry.proxy._obj, is_runtime=True)
76+
for _, entry in self.registry.items():
77+
var_type = entry.var_type
7878
if var_type not in to_dump_types:
7979
continue
8080

8181
if entry.stale:
8282
continue
8383

8484
entry.stale = True
85-
entry.proxy.dump_trace(phase="selective-sample", dump_loc=dump_loc)
85+
entry.var.dump_trace(phase="selective-sample", dump_loc=dump_loc)
8686
if not dump_config[var_type]["dump_unchanged"]:
8787
# remove the var from to_dump_types so that we don't dump the same type twice
8888
to_dump_types.remove(var_type)

traincheck/instrumentor/proxy_wrapper/subclass.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from traincheck.instrumentor.dumper import dump_trace_VAR
1010
from traincheck.instrumentor.proxy_wrapper.dumper import dump_attributes, get_meta_vars
1111
from traincheck.instrumentor.tracer import TraceLineType
12-
from traincheck.utils import get_timestamp_ns
12+
from traincheck.utils import get_timestamp_ns, typename
1313

1414
from .proxy_basics import is_fake_tensor
1515
from .proxy_registry import get_global_registry
@@ -37,6 +37,7 @@ def __new__(
3737
# TODO
3838
# recurse=False,
3939
var_name="",
40+
var_type="",
4041
should_dump_trace=True,
4142
from_call=False,
4243
from_iter=False,
@@ -97,6 +98,7 @@ def __init__(
9798
# TODO
9899
# recurse=False,
99100
var_name="",
101+
var_type="",
100102
should_dump_trace=True,
101103
from_call=False,
102104
from_iter=False,
@@ -116,6 +118,7 @@ def __init__(
116118
# TODO
117119
# self.__dict__["recurse"] = recurse
118120
self.__dict__["var_name"] = var_name
121+
self.__dict__["var_type"] = var_type
119122
# TODO
120123
# self.__dict__["old_value"] = None
121124
# self.__dict__["old_meta_vars"] = None
@@ -165,7 +168,9 @@ def update_timestamp(self):
165168
# Proxy.var_dict[self.__dict__["var_name"]].last_update_timestamp = current_time
166169

167170
def register_object(self):
168-
get_global_registry().add_var(self, self.__dict__["var_name"])
171+
get_global_registry().add_var(
172+
self, self.__dict__["var_name"], self.__dict__["var_type"]
173+
)
169174

170175
def dump_trace(self, phase, dump_loc):
171176
# print(f"parameter: {self.var_name}, phase = {phase}, dump_loc = {dump_loc}")
@@ -216,11 +221,13 @@ def proxy_parameter(
216221
if in_dynamo():
217222
return
218223
for name, t in list(module.named_parameters(recurse=False)):
224+
var_type = typename(t, is_runtime=True)
219225
module._parameters[name] = ProxyParameter(
220226
t,
221227
logdir,
222228
log_level,
223229
parent_name + "." + name,
230+
var_type,
224231
should_dump_trace,
225232
from_call,
226233
from_iter,

0 commit comments

Comments
 (0)