Skip to content

Commit 0c0ebf8

Browse files
committed
WIP: strengthen instrumentation logic
1 parent 5fb1424 commit 0c0ebf8

11 files changed

Lines changed: 87 additions & 394 deletions

File tree

docs/assets/code/mnist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from torchvision import datasets, transforms
99

1010
from traincheck import annotate_stage
11-
from traincheck.instrumentor import meta_vars
11+
from traincheck.instrumentor import META_VARS
1212

13-
meta_vars["step"] = -1
13+
META_VARS["step"] = -1
1414

1515

1616
class Net(nn.Module):
@@ -43,7 +43,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
4343
annotate_stage("training") # ML_DAIKON: stage annotation
4444
model.train()
4545
for batch_idx, (data, target) in enumerate(train_loader):
46-
meta_vars["step"] += 1
46+
META_VARS["step"] += 1
4747
data, target = data.to(device), target.to(device)
4848
optimizer.zero_grad()
4949
output = model(data)

docs/assets/examples/traincheck-collect/mnist-config/mnist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from torchvision import datasets, transforms
99

1010
from traincheck import annotate_stage
11-
from traincheck.instrumentor import meta_vars
11+
from traincheck.instrumentor import META_VARS
1212

13-
meta_vars["step"] = -1
13+
META_VARS["step"] = -1
1414

1515

1616
class Net(nn.Module):
@@ -43,7 +43,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
4343
annotate_stage("training") # ML_DAIKON: stage annotation
4444
model.train()
4545
for batch_idx, (data, target) in enumerate(train_loader):
46-
meta_vars["step"] += 1
46+
META_VARS["step"] += 1
4747
data, target = data.to(device), target.to(device)
4848
optimizer.zero_grad()
4949
output = model(data)

traincheck/developer/annotations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import traincheck.instrumentor.tracer as tracer
22
from traincheck.config.config import ALL_STAGE_NAMES
3-
from traincheck.instrumentor import meta_vars
3+
from traincheck.instrumentor import META_VARS
44

55

66
def annotate_stage(stage_name: str):
@@ -16,7 +16,7 @@ def annotate_stage(stage_name: str):
1616
stage_name in ALL_STAGE_NAMES
1717
), f"Invalid stage name: {stage_name}, valid ones are {ALL_STAGE_NAMES}"
1818

19-
meta_vars["stage"] = stage_name
19+
META_VARS["stage"] = stage_name
2020

2121

2222
def annotate_answer_start_token_ids(

traincheck/instrumentor/VFProxy.py

Lines changed: 0 additions & 60 deletions
This file was deleted.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .caches import meta_vars # noqa: F401
1+
from .caches import META_VARS # noqa: F401
22
from .source_file import * # noqa: F403
33
from .tracer import * # noqa: F403
44
from .tracer import VarSampler # noqa: F401

traincheck/instrumentor/caches.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
from collections import defaultdict
2-
3-
from traincheck.instrumentor.types import PTID
4-
5-
cache_meta_vars: dict[PTID, dict[str, dict]] = defaultdict(lambda: defaultdict(dict))
6-
meta_vars: dict[str, object] = {
1+
META_VARS: dict[str, object] = {
72
"step": 0,
83
}

traincheck/instrumentor/proxy_wrapper/proxy_basics.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,24 +93,23 @@ def unproxy_arg(arg, inspect_torch_module=False):
9393
return arg
9494

9595

96+
def unproxy_args_kwargs(args, kwargs, inspect_torch_module=False):
97+
args = [unproxy_arg(arg, inspect_torch_module) for arg in args]
98+
kwargs = {k: unproxy_arg(v) for k, v in kwargs.items()}
99+
return args, kwargs
100+
101+
96102
def unproxy_func(func, inspect_torch_module=False):
97103
original_func = func
98104

99105
@functools.wraps(original_func)
100106
def wrapper(*args, **kwargs):
101-
args = [unproxy_arg(arg, inspect_torch_module) for arg in args]
102-
kwargs = {k: unproxy_arg(v) for k, v in kwargs.items()}
107+
args, kwargs = unproxy_args_kwargs(args, kwargs, inspect_torch_module)
103108
return original_func(*args, **kwargs)
104109

105110
return wrapper
106111

107112

108-
def unproxy_args_kwargs(args, kwargs, inspect_torch_module=False):
109-
args = [unproxy_arg(arg, inspect_torch_module) for arg in args]
110-
kwargs = {k: unproxy_arg(v) for k, v in kwargs.items()}
111-
return args, kwargs
112-
113-
114113
def type_handle_traincheck_proxy(x):
115114
if hasattr(x, "is_traincheck_proxied_obj"):
116115
return type(x._obj)

traincheck/instrumentor/proxy_wrapper/proxy_observer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def wrapper(*args, **kwargs):
6161
result = processed_function(*args, **kwargs)
6262

6363
# post observe
64-
for i, var in enumerate(proxied_vars):
64+
for var in proxied_vars:
6565
observe_proxy_var(
6666
var,
6767
"post_observe",

traincheck/instrumentor/proxy_wrapper/proxy_registry.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@ def __init__(self, proxy: "Proxy", stale: bool):
1515
self.stale = stale
1616

1717

18-
class ProxyRegistry:
19-
"""A helper class managing all proxy variables being tracked and allow for controlled dumps of
18+
class VarRegistry:
19+
"""A helper class managing all variables being tracked and allow for controlled dumps of
2020
the variable states.
2121
2222
A variable is uniquely identified by its "name"
23+
When a variable is added to the registry, it is marked as "not stale".
24+
When a variable is dumped through `dump_sample` or `dump_modified`, it is marked as "stale".
25+
A variable is only dumped through `dump_modified` if it is not stale.
26+
2327
"""
2428

2529
def __init__(self):
@@ -29,20 +33,24 @@ def __init__(self):
2933
def add_var(self, var: "Proxy", var_name: str):
3034
"""Add a new proxy variable to the registry"""
3135
with self.registry_lock:
32-
self.registry[var_name] = RegistryEntry(proxy=var, stale=False)
36+
if var_name in self.registry:
37+
self.registry[var_name].proxy = var
38+
self.registry[var_name].stale = False
39+
else:
40+
self.registry[var_name] = RegistryEntry(proxy=var, stale=False)
3341

3442
def dump_sample(self, dump_loc=None):
3543
"""A complete dump of all present proxy objects
3644
3745
Calling this API mark all proxy objects as stale which
38-
will affect the `dump_only_modified` API.
46+
will affect the `dump_modified` API.
3947
"""
4048
with self.registry_lock:
41-
for var_name, entry in self.registry.items():
49+
for _, entry in self.registry.items():
4250
entry.stale = True
4351
entry.proxy.dump_trace(phase="sample", dump_loc=dump_loc)
4452

45-
def dump_only_modified(self, dump_loc=None, dump_config=None):
53+
def dump_modified(self, dump_loc=None, dump_config=None):
4654
"""Dump only the proxy variables that might be modified since last dump
4755
4856
args:
@@ -81,7 +89,7 @@ def dump_only_modified(self, dump_loc=None, dump_config=None):
8189

8290

8391
# Global dictionary to store registered objects
84-
global_registry = ProxyRegistry()
92+
global_registry = VarRegistry()
8593

8694

8795
def get_global_registry():

traincheck/instrumentor/proxy_wrapper/subclass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from traincheck.utils import get_timestamp_ns
1313

1414
from .proxy_basics import is_fake_tensor
15+
from .proxy_registry import get_global_registry
1516

16-
# from .proxy_registry import get_global_registry
1717
# from .utils import print_debug
1818

1919

@@ -165,9 +165,9 @@ def update_timestamp(self):
165165
# Proxy.var_dict[self.__dict__["var_name"]].last_update_timestamp = current_time
166166

167167
def register_object(self):
168-
# get_global_registry().add_var(self, self.__dict__["var_name"])
168+
get_global_registry().add_var(self, self.__dict__["var_name"])
169169
# TODO: implement the registry, we will need to make sure the registerred timestamp is updated and is consistent with the timestamp in the object
170-
pass
170+
# pass
171171

172172
def dump_trace(self, phase, dump_loc):
173173
# print(f"parameter: {self.var_name}, phase = {phase}, dump_loc = {dump_loc}")

0 commit comments

Comments
 (0)