Skip to content

Commit 2034dfd

Browse files
committed
fix: respect configured tracker type during selective instrumentation
1 parent b14e0d4 commit 2034dfd

2 files changed

Lines changed: 40 additions & 17 deletions

File tree

traincheck/collect_trace.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,24 +137,46 @@ def merge(a: dict, b: dict, path=[]):
137137
return func_instr_opts
138138

139139

140-
def get_model_tracker_instr_opts(invariants: list[Invariant]) -> str | None:
140+
def get_model_tracker_instr_opts(
141+
invariants: list[Invariant], config_tracker_style: str
142+
) -> str | None:
141143
"""
142144
Get model tracker instrumentation options
143145
"""
144146

145-
tracker_type = None
147+
logger = logging.getLogger(__name__)
148+
need_immediate_var_tracking = False
149+
need_var_tracking = False
146150
for inv in invariants:
147151
if inv.relation == APIContainRelation:
148152
for param in inv.params:
149153
if isinstance(param, (VarNameParam, VarTypeParam)):
150-
tracker_type = "proxy"
154+
need_var_tracking = True
155+
need_immediate_var_tracking = True
151156
break
152-
if tracker_type is None and inv.relation == ConsistencyRelation:
153-
tracker_type = "sampler"
157+
if not need_var_tracking and inv.relation == ConsistencyRelation:
158+
need_immediate_var_tracking = False
159+
need_var_tracking = True
154160

155-
if tracker_type == "proxy":
161+
if need_var_tracking and need_immediate_var_tracking:
156162
break
157-
return tracker_type
163+
164+
if need_immediate_var_tracking:
165+
if config_tracker_style in ["proxy", "subclass"]:
166+
return config_tracker_style
167+
else:
168+
logger.warning(
169+
f"Model tracker style {config_tracker_style} is not suitable for immediate variable tracking, using 'subclass' by default instead."
170+
)
171+
return "subclass"
172+
elif need_var_tracking:
173+
if not config_tracker_style == "sampler":
174+
logger.warning(
175+
f"Model tracker style {config_tracker_style} is not suitable for non-immediate variable tracking, using 'sampler' by default instead."
176+
)
177+
return "sampler"
178+
179+
return None
158180

159181

160182
def dump_env(args, output_dir: str):
@@ -435,9 +457,12 @@ def main():
435457
if args.invariants:
436458
# selective instrumentation if invariants are provided, only funcs_to_instr will be instrumented with trace collection
437459
invariants = read_inv_file(args.invariants)
460+
438461
instr_opts = InstrOpt(
439462
func_instr_opts=get_per_func_instr_opts(invariants),
440-
model_tracker_style=get_model_tracker_instr_opts(invariants),
463+
model_tracker_style=get_model_tracker_instr_opts(
464+
invariants, args.model_tracker_style
465+
),
441466
disable_proxy_dumping=True,
442467
)
443468
models_to_track = (

traincheck/trace/types.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -154,17 +154,15 @@ def __init__(self, func_name: str, pre_record: dict, post_record: dict):
154154
)
155155

156156
# TODO: use the Arguments class to replace self.args and self.kwargs
157-
self.args: dict[str, dict[str, dict[str, object]]] = pre_record[
158-
"args"
159-
] # lists of [type -> attr_name -> value]
160-
self.kwargs: dict[str, dict[str, object]] = pre_record[
161-
"kwargs"
162-
] # key --> attr_name -> value
157+
self.args: dict[str, dict[str, dict[str, object]]] = pre_record.get(
158+
"args", {}
159+
) # lists of [type -> attr_name -> value]
160+
self.kwargs: dict[str, dict[str, object]] = pre_record.get("kwargs", {})
163161
self.return_values: (
164162
dict[str, dict[str, object]] | list[dict[str, dict[str, object]]]
165-
) = post_record[
166-
"return_values"
167-
] # key --> attr_name -> value
163+
) = post_record.get(
164+
"return_values", {}
165+
) # key --> attr_name -> value
168166

169167
def __str__(self):
170168
return f"FuncCallEvent: {self.func_name}"

0 commit comments

Comments
 (0)