@@ -137,24 +137,46 @@ def merge(a: dict, b: dict, path=[]):
137137 return func_instr_opts
138138
139139
140- def get_model_tracker_instr_opts (invariants : list [Invariant ]) -> str | None :
140+ def get_model_tracker_instr_opts (
141+ invariants : list [Invariant ], config_tracker_style : str
142+ ) -> str | None :
141143 """
142144 Get model tracker instrumentation options
143145 """
144146
145- tracker_type = None
147+ logger = logging .getLogger (__name__ )
148+ need_immediate_var_tracking = False
149+ need_var_tracking = False
146150 for inv in invariants :
147151 if inv .relation == APIContainRelation :
148152 for param in inv .params :
149153 if isinstance (param , (VarNameParam , VarTypeParam )):
150- tracker_type = "proxy"
154+ need_var_tracking = True
155+ need_immediate_var_tracking = True
151156 break
152- if tracker_type is None and inv .relation == ConsistencyRelation :
153- tracker_type = "sampler"
157+ if not need_var_tracking and inv .relation == ConsistencyRelation :
158+ need_immediate_var_tracking = False
159+ need_var_tracking = True
154160
155- if tracker_type == "proxy" :
161+ if need_var_tracking and need_immediate_var_tracking :
156162 break
157- return tracker_type
163+
164+ if need_immediate_var_tracking :
165+ if config_tracker_style in ["proxy" , "subclass" ]:
166+ return config_tracker_style
167+ else :
168+ logger .warning (
169+ f"Model tracker style { config_tracker_style } is not suitable for immediate variable tracking, using 'subclass' by default instead."
170+ )
171+ return "subclass"
172+ elif need_var_tracking :
173+ if not config_tracker_style == "sampler" :
174+ logger .warning (
175+ f"Model tracker style { config_tracker_style } is not suitable for non-immediate variable tracking, using 'sampler' by default instead."
176+ )
177+ return "sampler"
178+
179+ return None
158180
159181
160182def dump_env (args , output_dir : str ):
@@ -435,9 +457,12 @@ def main():
435457 if args .invariants :
436458 # selective instrumentation if invariants are provided, only funcs_to_instr will be instrumented with trace collection
437459 invariants = read_inv_file (args .invariants )
460+
438461 instr_opts = InstrOpt (
439462 func_instr_opts = get_per_func_instr_opts (invariants ),
440- model_tracker_style = get_model_tracker_instr_opts (invariants ),
463+ model_tracker_style = get_model_tracker_instr_opts (
464+ invariants , args .model_tracker_style
465+ ),
441466 disable_proxy_dumping = True ,
442467 )
443468 models_to_track = (
0 commit comments