Skip to content

Commit 5fb1424

Browse files
committed
[WIP] refactor: move proxy_wrapper to within the instrumentor
1 parent 43b7984 commit 5fb1424

22 files changed

Lines changed: 35 additions & 50 deletions

.github/workflows/eval-overhead-e2e.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,11 @@ on:
66
paths:
77
- '.github/workflows/**'
88
- 'traincheck/instrumentor/**'
9-
- 'traincheck/proxy_wrapper/**'
109
- 'traincheck/collect_trace.py'
1110
pull_request:
1211
paths:
1312
- '.github/workflows/**'
1413
- 'traincheck/instrumentor/**'
15-
- 'traincheck/proxy_wrapper/**'
1614
- 'traincheck/collect_trace.py'
1715

1816

traincheck/collect_trace.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import traincheck.config.config as config
99
import traincheck.instrumentor as instrumentor
10-
import traincheck.proxy_wrapper.proxy_config as proxy_config
10+
import traincheck.instrumentor.proxy_wrapper.proxy_config as proxy_config
1111
import traincheck.runner as runner
1212
from traincheck.config.config import InstrOpt
1313
from traincheck.invariant.base_cls import (
@@ -157,19 +157,6 @@ def get_model_tracker_instr_opts(invariants: list[Invariant]) -> str | None:
157157
return tracker_type
158158

159159

160-
def get_disable_proxy_dumping(invariants: list[Invariant]) -> bool:
161-
"""
162-
Get disable proxy dumping options for checking
163-
164-
Always return True if an APIContain invariant requested proxy tracking
165-
166-
We cannot disable automatic variable dumping if only consistency relations but no APIContain
167-
require variable states, as then no APIs will trigger state dumps.
168-
However, the var tracker should be sampler if there's no APIContain anyway
169-
"""
170-
return True
171-
172-
173160
def dump_env(args, output_dir: str):
174161
with open(os.path.join(output_dir, "env_dump.txt"), "w") as f:
175162
f.write("Arguments:\n")

traincheck/instrumentor/dumper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
)
1818

1919
# if torch.cuda.is_available():
20-
from traincheck.proxy_wrapper.hash import tensor_hash
21-
from traincheck.proxy_wrapper.proxy_basics import is_fake_tensor
22-
from traincheck.proxy_wrapper.proxy_config import (
20+
from traincheck.instrumentor.proxy_wrapper.hash import tensor_hash
21+
from traincheck.instrumentor.proxy_wrapper.proxy_basics import is_fake_tensor
22+
from traincheck.instrumentor.proxy_wrapper.proxy_config import (
2323
attribute_black_list,
2424
primitive_types,
2525
proxy_attribute,

traincheck/proxy_wrapper/Changelog.md renamed to traincheck/instrumentor/proxy_wrapper/Changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning].
1313

1414
### Added
1515

16-
- Maintain global registry to proxied objects (to access the vars, use `from traincheck.proxy_wrapper.proxy import get_registered_object`)
16+
- Maintain global registry to proxied objects (to access the vars, use `from traincheck.instrumentor.proxy_wrapper.proxy import get_registered_object`)
1717
- Bypass tensor stats/hash computation if it has already been calculated
1818

1919
### Fixed
File renamed without changes.
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# This import is necessary to make the observer utility inside torch_proxy.py executed before the instrumented code. This would ensure the observer function is successfully registred before the instrumented code is executed.
22

3-
import traincheck.proxy_wrapper.proxy_config # noqa
4-
import traincheck.proxy_wrapper.torch_proxy # noqa
3+
import traincheck.instrumentor.proxy_wrapper.proxy_config # noqa
4+
import traincheck.instrumentor.proxy_wrapper.torch_proxy # noqa

traincheck/proxy_wrapper/dumper.py renamed to traincheck/instrumentor/proxy_wrapper/dumper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from typing import Dict
33

44
from traincheck.instrumentor.dumper import convert_var_to_dict
5+
from traincheck.instrumentor.proxy_wrapper.proxy_basics import is_proxied
6+
from traincheck.instrumentor.proxy_wrapper.proxy_config import primitive_types
57
from traincheck.instrumentor.tracer import TraceLineType
68
from traincheck.instrumentor.tracer import get_meta_vars as tracer_get_meta_vars
7-
from traincheck.proxy_wrapper.proxy_basics import is_proxied
8-
from traincheck.proxy_wrapper.proxy_config import primitive_types
99

1010

1111
class Singleton(type):
File renamed without changes.

traincheck/proxy_wrapper/proxy.py renamed to traincheck/instrumentor/proxy_wrapper/proxy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
import torch
1010

1111
import traincheck.config.config as general_config
12-
import traincheck.proxy_wrapper.proxy_config as proxy_config # HACK: cannot directly import config variables as then they would be local variables
13-
import traincheck.proxy_wrapper.proxy_methods as proxy_methods
14-
from traincheck.proxy_wrapper.dumper import dump_attributes, get_meta_vars
12+
import traincheck.instrumentor.proxy_wrapper.proxy_config as proxy_config # HACK: cannot directly import config variables as then they would be local variables
13+
import traincheck.instrumentor.proxy_wrapper.proxy_methods as proxy_methods
14+
from traincheck.instrumentor.proxy_wrapper.dumper import dump_attributes, get_meta_vars
1515
from traincheck.utils import get_timestamp_ns, typename
1616

1717
from .dumper import json_dumper as dumper

traincheck/proxy_wrapper/proxy_basics.py renamed to traincheck/instrumentor/proxy_wrapper/proxy_basics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def visit_FunctionDef(self, node):
123123
self.generic_visit(node)
124124
# Inject code right after the def statement
125125
inject_code = """
126-
from traincheck.proxy_wrapper.proxy_basics import type_handle_traincheck_proxy
126+
from traincheck.instrumentor.proxy_wrapper.proxy_basics import type_handle_traincheck_proxy
127127
"""
128128
inject_node = ast.parse(inject_code).body
129129
node.body = inject_node + node.body

0 commit comments

Comments
 (0)