Skip to content

Commit 2528f19

Browse files
liurt1218 and Essoz authored
add: annotate stage logic (#133)
* add automatic stage annotation logic --------- Co-authored-by: Yuxuan Jiang <jyuxuan@umich.edu>
1 parent 5b1eb28 commit 2528f19

2 files changed

Lines changed: 291 additions & 1 deletion

File tree

traincheck/instrumentor/source_file.py

Lines changed: 284 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import ast
2+
import io
23
import logging
34
import re
5+
import tokenize
6+
from collections import deque
7+
from typing import Dict, Set
48

59
from traincheck.config.config import INSTR_MODULES_TO_INSTR
610

@@ -502,6 +506,284 @@ def instrument_model_tracker_sampler(
502506
return source
503507

504508

509+
# Attribute-call names that mark the start of a stage (``x.<name>(`` in source).
_ANNOTATE_STAGE_TRIGGERS: Dict[str, str] = {
    "train": "training",
    "step": "training",
    "eval": "testing",
    "no_grad": "testing",
    "save": "checkpointing",
}

# When several triggers hit the same statement, the highest priority wins.
_ANNOTATE_STAGE_PRIORITY: Dict[str, int] = {
    "training": 3,
    "testing": 2,
    "checkpointing": 1,
}


def _annotate_stage_lines(text: str) -> list:
    """Split *text* into physical lines the same way ``tokenize`` reads them.

    ``str.splitlines`` also breaks on form feeds, ``\\u2028`` and friends, which
    would make list indices disagree with tokenizer line numbers.
    """
    return io.StringIO(text).readlines()


def _annotate_stage_indent_of(line: str) -> str:
    """Return the leading horizontal whitespace of *line*."""
    match = re.match(r"[ \t\f]*", line)
    return match.group(0) if match else ""


def _annotate_stage_insert(lines: list, index: int, text: str) -> None:
    """Insert *text* as its own line at *index*.

    If the preceding line is an unterminated last line, terminate it first so
    the inserted text does not get glued onto it.
    """
    if 0 < index <= len(lines) and not lines[index - 1].endswith("\n"):
        lines[index - 1] += "\n"
    lines.insert(index, text)


def _annotate_stage_scan(source: str, skip_stages: Set[str]) -> Dict[int, str]:
    """Map statement start lines to the stage that should be announced there.

    A trigger is any ``.<name>(`` token triple whose name is listed in
    ``_ANNOTATE_STAGE_TRIGGERS``.  The returned key is the first physical line
    of the *logical* statement containing the trigger, so a trigger found on a
    continuation line never causes an insertion in the middle of a statement.
    Triggers inside decorator lines are ignored: an annotation placed in front
    of a decorator would run at definition time, not when the stage starts.
    """
    line_to_stage: Dict[int, str] = {}
    window: deque = deque(maxlen=3)
    filler = (
        tokenize.NL,
        tokenize.COMMENT,
        tokenize.INDENT,
        tokenize.DEDENT,
        tokenize.ENDMARKER,
    )
    stmt_line = 0  # 0 means "between logical statements"
    stmt_is_decorator = False

    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        window.append(tok)
        if tok.type == tokenize.NEWLINE:
            stmt_line = 0  # logical line finished
            continue
        if tok.type in filler:
            continue
        if stmt_line == 0:
            stmt_line = tok.start[0]
            stmt_is_decorator = tok.type == tokenize.OP and tok.string == "@"
        if len(window) < 3:
            continue

        dot, name, paren = window
        is_attr_call = (
            dot.type == tokenize.OP
            and dot.string == "."
            and name.type == tokenize.NAME
            and paren.type == tokenize.OP
            and paren.string == "("
        )
        if not is_attr_call or stmt_is_decorator:
            continue

        stage = _ANNOTATE_STAGE_TRIGGERS.get(name.string)
        if stage is None or stage in skip_stages:
            continue

        target = stmt_line or name.start[0]
        current = line_to_stage.get(target, "")
        if _ANNOTATE_STAGE_PRIORITY[stage] > _ANNOTATE_STAGE_PRIORITY.get(current, 0):
            line_to_stage[target] = stage

    return line_to_stage


def _annotate_stage_prologue_end(src: str, lines: list) -> int:
    """Return the line index at which an ordinary import may be inserted.

    The index lies after any shebang / coding comment and, when *src* parses,
    after the module docstring and every ``from __future__`` import.  Inserting
    earlier would demote the docstring and make the future imports illegal.
    """
    idx = 0
    while idx < len(lines):
        raw = lines[idx]
        stripped = raw.strip()
        if (
            raw.startswith("#!")
            or (stripped.startswith("#") and "coding" in stripped)
            or stripped.startswith("from __future__ import")
        ):
            idx += 1
        else:
            break

    try:
        tree = ast.parse(src)
    except (SyntaxError, ValueError):
        # Fall back to the purely textual answer for unparsable input.
        return idx

    for pos, node in enumerate(tree.body):
        is_docstring = (
            pos == 0
            and isinstance(node, ast.Expr)
            and isinstance(node.value, ast.Constant)
            and isinstance(node.value.value, str)
        )
        is_future = (
            isinstance(node, ast.ImportFrom)
            and node.module == "__future__"
            and node.level == 0
        )
        if not (is_docstring or is_future):
            break
        end_lineno = getattr(node, "end_lineno", None) or node.lineno
        idx = max(idx, end_lineno)
    return idx


def annotate_stage(
    source: str,
) -> str:
    """Insert ``annotate_stage(...)`` calls into a training script's source.

    DEBT: Refactor the source tree exploration part with a AST-based approach.

    The function works on text and token level:

    * before every statement containing ``.train(`` / ``.step(`` it inserts
      ``annotate_stage("training")``; ``.eval(`` / ``.no_grad(`` give
      ``"testing"`` and ``.save(`` gives ``"checkpointing"``.  A stage that is
      already annotated anywhere in *source* is never added again, and an
      insertion is skipped when the previous non-blank line already carries
      the same annotation;
    * ``from traincheck import annotate_stage`` is added (after the module
      docstring and ``from __future__`` imports) unless already present;
    * ``annotate_stage("init")`` is added at the top of ``main()`` when the
      script has both a ``def main(...)`` and an ``if __name__ == "__main__":``
      guard, otherwise directly after the ``annotate_stage`` import.

    Args:
        source: Python source code of the script to annotate.

    Returns:
        The annotated source code.

    Raises:
        RuntimeError: If the result contains no ``annotate_stage(`` call and
            the input had none either.
        tokenize.TokenError, SyntaxError: Propagated from ``tokenize`` when
            *source* cannot be tokenized.
    """

    def _ctx(msg: str) -> str:
        return f"[annotate_stage] {msg}"

    def has_stage(src: str, name: str) -> bool:
        return re.search(rf'annotate_stage\(\s*[\'"]{name}[\'"]\s*\)', src) is not None

    def _find_annotate_import_idx(lines):
        for idx, line in enumerate(lines):
            if re.match(r"^\s*from\s+traincheck\s+import\s+annotate_stage\s*$", line):
                return idx
        return -1

    orig_has = {
        "init": has_stage(source, "init"),
        "training": has_stage(source, "training"),
        "testing": has_stage(source, "testing"),
        "checkpointing": has_stage(source, "checkpointing"),
    }
    orig_has_any = any(orig_has.values()) or ("annotate_stage(" in source)

    for stage_name, present in orig_has.items():
        if present:
            logger.info(
                _ctx(
                    f"Stage '{stage_name}' already present in source; skip adding this stage."
                )
            )

    # ---- 1. stage annotations in front of trigger statements ---------------
    already_present = {name for name, present in orig_has.items() if present}
    line_to_stage = _annotate_stage_scan(source, already_present)

    lines_list: list = []
    for lineno, line in enumerate(_annotate_stage_lines(source), start=1):
        stage = line_to_stage.get(lineno)
        if stage:
            prev = next((p for p in reversed(lines_list) if p.strip()), "")
            duplicate = "annotate_stage" in prev and (
                f'"{stage}"' in prev or f"'{stage}'" in prev
            )
            if duplicate:
                logger.info(
                    _ctx(
                        f"Skip inserting '{stage}' at line {lineno} (previous non-empty line already has it)."
                    )
                )
            else:
                indent = _annotate_stage_indent_of(line)
                lines_list.append(f'{indent}annotate_stage("{stage}")\n')
                logger.info(
                    _ctx(
                        f"Inserted stage '{stage}' before line {lineno}: {line.strip()}"
                    )
                )
        lines_list.append(line)

    # ---- 2. import of annotate_stage ---------------------------------------
    annot_import_idx = _find_annotate_import_idx(lines_list)
    if annot_import_idx == -1:
        annot_import_idx = _annotate_stage_prologue_end("".join(lines_list), lines_list)
        _annotate_stage_insert(
            lines_list, annot_import_idx, "from traincheck import annotate_stage\n"
        )
        logger.info(
            _ctx(
                f"Inserted import 'from traincheck import annotate_stage' at line {annot_import_idx + 1}."
            )
        )

    new_src = "".join(lines_list)

    # ---- 3. the "init" stage -----------------------------------------------
    if not orig_has["init"]:
        has_guard = (
            re.search(
                r'^\s*if\s+__name__\s*==\s*[\'"]__main__[\'"]\s*:\s*$', new_src, re.M
            )
            is not None
        )
        # Optional "-> annotation" so that ``def main() -> None:`` is found too.
        main_def = re.search(
            r"^([ \t]*)def\s+main\s*\(.*?\)\s*(?:->\s*[^:\n]+)?:\s*(?:#.*)?$",
            new_src,
            re.M,
        )

        if has_guard and main_def:
            nl = lines_list
            def_line_idx = new_src.count("\n", 0, main_def.start())
            def_indent = main_def.group(1)

            # Prefer the indentation the body really uses; only guess (one tab
            # or four spaces) when no usable body line is found.
            body_indent = def_indent + (
                "\t" if ("\t" in def_indent and " " not in def_indent) else "    "
            )
            for candidate in nl[def_line_idx + 1 :]:
                stripped = candidate.strip()
                if not stripped or stripped.startswith("#"):
                    continue
                lead = _annotate_stage_indent_of(candidate)
                if lead.startswith(def_indent) and len(lead) > len(def_indent):
                    body_indent = lead
                break

            insert_at = def_line_idx + 1
            while insert_at < len(nl) and nl[insert_at].strip() == "":
                insert_at += 1

            # Step over a leading triple-quoted docstring (single- or multi-line).
            if insert_at < len(nl):
                head = nl[insert_at].lstrip()
                quote = next((q for q in ('"""', "'''") if head.startswith(q)), None)
                if quote is not None:
                    closes_on_same_line = nl[insert_at].count(quote) >= 2
                    insert_at += 1
                    if not closes_on_same_line:
                        while insert_at < len(nl):
                            is_closing_line = quote in nl[insert_at]
                            insert_at += 1
                            if is_closing_line:
                                break

            k = insert_at - 1
            while k >= 0 and nl[k].strip() == "":
                k -= 1
            prev = nl[k] if k >= 0 else ""
            if not (("annotate_stage" in prev) and ("init" in prev)):
                _annotate_stage_insert(
                    nl, insert_at, f'{body_indent}annotate_stage("init")\n'
                )
                logger.info(
                    _ctx(
                        f"Inserted stage 'init' at start of main() body (line {insert_at + 1})."
                    )
                )
            else:
                logger.info(
                    _ctx(
                        "Skip inserting 'init' inside main(): previous non-empty line already has it."
                    )
                )
            new_src = "".join(nl)
        else:
            # The import is guaranteed to exist at this point (found or added
            # above), so "init" always goes right behind it, at the same
            # indentation in case the import lives inside a nested block.
            insert_at = annot_import_idx + 1
            import_indent = _annotate_stage_indent_of(lines_list[annot_import_idx])

            k = insert_at
            while k < len(lines_list) and lines_list[k].strip() == "":
                k += 1
            next_line = lines_list[k] if k < len(lines_list) else ""
            if not (("annotate_stage" in next_line) and ("init" in next_line)):
                _annotate_stage_insert(
                    lines_list, insert_at, f'{import_indent}annotate_stage("init")\n'
                )
                logger.info(
                    _ctx(
                        f"Inserted stage 'init' right after annotate_stage import at line {insert_at + 1}."
                    )
                )
            else:
                logger.info(
                    _ctx(
                        "Skip inserting 'init': next non-empty line after annotate_stage import is already init."
                    )
                )

            new_src = "".join(lines_list)

    if "annotate_stage(" not in new_src and not orig_has_any:
        logger.error(
            _ctx(
                "Automatic insertion failed: no annotate_stage(...) found or added. Manual insertion required."
            )
        )
        raise RuntimeError(
            _ctx("annotate_stage insertion failed; see logs for details.")
        )

    return new_src
785+
786+
505787
def instrument_file(
506788
path: str,
507789
modules_to_instr: list[str],
@@ -532,7 +814,8 @@ def instrument_file(
532814
funcs_to_instr,
533815
API_dump_stack_trace,
534816
)
535-
817+
# annotate stages
818+
instrumented_source = annotate_stage(instrumented_source)
536819
# logging configs
537820
logging_start_code = f"""
538821
import os

traincheck/trace/trace_pandas.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,13 @@ def _rm_incomplete_trailing_func_calls(self):
183183
self.events.groupby("func_call_id").size().reset_index(name="count")
184184
)
185185

186+
multiple_func_call_ids = func_call_groups[func_call_groups["count"] > 2][
187+
"func_call_id"
188+
]
189+
assert (
190+
len(multiple_func_call_ids) == 0
191+
), "more than 2 events for one func call id"
192+
186193
incomplete_func_call_ids = func_call_groups[func_call_groups["count"] == 1][
187194
"func_call_id"
188195
]

0 commit comments

Comments
 (0)