|
1 | 1 | import ast |
| 2 | +import io |
2 | 3 | import logging |
3 | 4 | import re |
| 5 | +import tokenize |
| 6 | +from collections import deque |
| 7 | +from typing import Dict, Set |
4 | 8 |
|
5 | 9 | from traincheck.config.config import INSTR_MODULES_TO_INSTR |
6 | 10 |
|
@@ -502,6 +506,284 @@ def instrument_model_tracker_sampler( |
502 | 506 | return source |
503 | 507 |
|
504 | 508 |
|
# Attribute calls that mark the start of a stage, e.g. ``model.train()``.
_STAGE_TRIGGER_ATTRS: Dict[str, str] = {
    "train": "training",
    "step": "training",
    "eval": "testing",
    "no_grad": "testing",
    "save": "checkpointing",
}

# When one statement contains triggers for several stages, the highest wins.
_STAGE_PRIORITY: Dict[str, int] = {
    "training": 3,
    "testing": 2,
    "checkpointing": 1,
}

# Tokens that never open a logical line.
_STAGE_LAYOUT_TOKENS = frozenset(
    (
        tokenize.NL,
        tokenize.COMMENT,
        tokenize.INDENT,
        tokenize.DEDENT,
        tokenize.ENCODING,
    )
)

# A statement cannot be inserted in front of these clause headers without
# detaching the clause from the compound statement it belongs to.
_STAGE_CLAUSE_KEYWORDS = frozenset(("elif", "else", "except", "finally"))

_STAGE_IMPORT_LINE = "from traincheck import annotate_stage\n"
# Only a column-0 import counts: an indented one (inside a function or a
# ``try:``) does not guarantee a module-level binding for the inserted calls.
_STAGE_IMPORT_RE = re.compile(r"^from\s+traincheck\s+import\s+annotate_stage\s*$")
_STAGE_MAIN_GUARD_RE = re.compile(
    r'^\s*if\s+__name__\s*==\s*[\'"]__main__[\'"]\s*:\s*$', re.M
)
# Single-line ``def main(...):`` header, optionally with a return annotation
# and a trailing comment. Group 1 is the indentation of the ``def`` keyword.
_STAGE_MAIN_DEF_RE = re.compile(
    r"^([ \t]*)def\s+main\s*\(.*?\)(?:\s*->\s*[^:#\n]+?)?\s*:\s*(?:#.*)?$", re.M
)


def _stage_ctx(msg: str) -> str:
    """Prefix *msg* so log records can be traced back to ``annotate_stage``."""
    return f"[annotate_stage] {msg}"


def _stage_split_lines(text: str) -> list[str]:
    """Split *text* into lines exactly the way the tokenizer reads them.

    ``str.splitlines`` also breaks on form feeds and other separators that
    neither ``tokenize`` nor ``ast`` count as line ends, which would shift
    every line index after such a character.
    """
    return io.StringIO(text).readlines()


def _stage_leading_ws(line: str) -> str:
    """Return the leading whitespace of a non-blank *line*."""
    return line[: len(line) - len(line.lstrip())]


def _stage_insert_line(lines: list[str], index: int, text: str) -> None:
    """Insert *text* at *index*, terminating an unterminated preceding line."""
    if 0 < index <= len(lines) and not lines[index - 1].endswith("\n"):
        lines[index - 1] += "\n"
    lines.insert(index, text)


def _stage_sites(source: str, present: Set[str]) -> Dict[int, str]:
    """Map statement start lines (1-based) to the stage to announce before them.

    A site is any ``.<name>(`` token triple whose name is a known trigger. The
    recorded line is the first physical line of the enclosing logical
    statement, so a trigger on a continuation line of a multi-line expression
    never causes an insertion in the middle of that expression. Stages listed
    in *present* are ignored.
    """
    sites: Dict[int, str] = {}
    window: deque = deque(maxlen=3)
    stmt_line = 0  # 0 means "no logical line is currently open"
    stmt_head = ""
    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        if tok.type == tokenize.NEWLINE:
            stmt_line = 0
        elif stmt_line == 0 and tok.type not in _STAGE_LAYOUT_TOKENS:
            stmt_line, stmt_head = tok.start[0], tok.string

        window.append(tok)
        if len(window) < 3:
            continue
        dot, attr, paren = window
        if (
            (dot.type, dot.string) != (tokenize.OP, ".")
            or attr.type != tokenize.NAME
            or (paren.type, paren.string) != (tokenize.OP, "(")
        ):
            continue

        stage = _STAGE_TRIGGER_ATTRS.get(attr.string)
        if stage is None or stage in present:
            continue
        if stmt_head in _STAGE_CLAUSE_KEYWORDS:
            logger.warning(
                _stage_ctx(
                    f"Skip inserting '{stage}' at line {stmt_line}: cannot insert a statement before a one-line '{stmt_head}' clause."
                )
            )
            continue
        if _STAGE_PRIORITY[stage] > _STAGE_PRIORITY.get(sites.get(stmt_line, ""), 0):
            sites[stmt_line] = stage
    return sites


def _stage_insert_calls(lines: list[str], sites: Dict[int, str]) -> list[str]:
    """Return a copy of *lines* with ``annotate_stage(...)`` before each site."""
    out: list[str] = []
    for lineno, line in enumerate(lines, start=1):
        stage = sites.get(lineno)
        if stage:
            prev = next((ln for ln in reversed(out) if ln.strip()), "")
            already_there = "annotate_stage" in prev and (
                f'"{stage}"' in prev or f"'{stage}'" in prev
            )
            if already_there:
                logger.info(
                    _stage_ctx(
                        f"Skip inserting '{stage}' at line {lineno} (previous non-empty line already has it)."
                    )
                )
            else:
                out.append(f'{_stage_leading_ws(line)}annotate_stage("{stage}")\n')
                logger.info(
                    _stage_ctx(
                        f"Inserted stage '{stage}' before line {lineno}: {line.strip()}"
                    )
                )
        out.append(line)
    return out


def _stage_import_slot(lines: list[str]) -> int:
    """Return the index at which a new top-level import can safely be added.

    The slot lies after any shebang / coding comment and, when the source
    parses, after the module docstring and every ``from __future__`` import.
    Placing an import ahead of a future statement is a SyntaxError, and placing
    it ahead of the docstring would stop the string from being the docstring.
    """
    slot = 0
    for line in lines:
        text = line.strip()
        if (
            line.startswith("#!")
            or (text.startswith("#") and "coding" in text)
            or text.startswith("from __future__ import")
        ):
            slot += 1
        else:
            break

    try:
        tree = ast.parse("".join(lines))
    except (SyntaxError, ValueError):
        # Unparseable input: fall back to the plain line scan above.
        return slot

    for pos, node in enumerate(tree.body):
        is_docstring = (
            pos == 0
            and isinstance(node, ast.Expr)
            and isinstance(node.value, ast.Constant)
            and isinstance(node.value.value, str)
        )
        is_future = isinstance(node, ast.ImportFrom) and node.module == "__future__"
        if not (is_docstring or is_future):
            break
        slot = max(slot, getattr(node, "end_lineno", None) or node.lineno)
    return slot


def _stage_ensure_import(lines: list[str]) -> int:
    """Make sure *lines* import ``annotate_stage`` at top level; return its index."""
    for idx, line in enumerate(lines):
        if _STAGE_IMPORT_RE.match(line):
            return idx
    slot = _stage_import_slot(lines)
    _stage_insert_line(lines, slot, _STAGE_IMPORT_LINE)
    logger.info(
        _stage_ctx(
            f"Inserted import 'from traincheck import annotate_stage' at line {slot + 1}."
        )
    )
    return slot


def _stage_main_body_start(lines: list[str], def_idx: int) -> int:
    """Return the index of the first line after ``main``'s header and docstring."""
    at = def_idx + 1
    while at < len(lines) and lines[at].strip() == "":
        at += 1
    if at >= len(lines):
        return at

    head = lines[at].lstrip()
    if not (head.startswith('"""') or head.startswith("'''")):
        return at

    quote = '"""' if head.startswith('"""') else "'''"
    closes_on_same_line = lines[at].count(quote) >= 2
    at += 1
    if not closes_on_same_line:
        # Multi-line docstring: step past the line holding the closing quotes.
        while at < len(lines):
            closed = quote in lines[at]
            at += 1
            if closed:
                break
    return at


def _stage_body_indent(lines: list[str], def_idx: int, def_indent: str) -> str:
    """Return the indentation used by the body of the ``def`` at *def_idx*.

    The indentation is copied from the first real body line so the inserted
    call matches tab- or 2-space-indented bodies. A conventional step is only
    guessed when no such line is available.
    """
    for line in lines[def_idx + 1 :]:
        text = line.strip()
        if not text or text.startswith("#"):
            continue
        lead = _stage_leading_ws(line)
        if len(lead) > len(def_indent) and lead.startswith(def_indent):
            return lead
        break
    step = "\t" if ("\t" in def_indent and " " not in def_indent) else "    "
    return def_indent + step


def _stage_mentions_init(line: str) -> bool:
    """Heuristic: does *line* already look like an ``annotate_stage`` init call?"""
    return "annotate_stage" in line and "init" in line


def _stage_insert_init(lines: list[str], import_idx: int) -> None:
    """Insert ``annotate_stage("init")`` into *lines* in place.

    If the script has both an ``if __name__ == "__main__":`` guard and a
    single-line ``def main(...)`` header, the call becomes the first statement
    of ``main`` (after its docstring). Otherwise it is placed at module level
    directly after the ``annotate_stage`` import at *import_idx*.
    """
    text = "".join(lines)
    main_def = _STAGE_MAIN_DEF_RE.search(text)

    if main_def is not None and _STAGE_MAIN_GUARD_RE.search(text) is not None:
        # ``^`` only matches after "\n", so counting "\n" yields the line index.
        def_idx = text.count("\n", 0, main_def.start())
        insert_at = _stage_main_body_start(lines, def_idx)
        body_indent = _stage_body_indent(lines, def_idx, main_def.group(1))
        prev_line = next((ln for ln in reversed(lines[:insert_at]) if ln.strip()), "")
        next_line = next((ln for ln in lines[insert_at:] if ln.strip()), "")
        if _stage_mentions_init(prev_line) or _stage_mentions_init(next_line):
            logger.info(
                _stage_ctx(
                    "Skip inserting 'init' inside main(): previous non-empty line already has it."
                )
            )
            return
        _stage_insert_line(lines, insert_at, f'{body_indent}annotate_stage("init")\n')
        logger.info(
            _stage_ctx(
                f"Inserted stage 'init' at start of main() body (line {insert_at + 1})."
            )
        )
        return

    insert_at = import_idx + 1
    next_line = next((ln for ln in lines[insert_at:] if ln.strip()), "")
    if _stage_mentions_init(next_line):
        logger.info(
            _stage_ctx(
                "Skip inserting 'init': next non-empty line after annotate_stage import is already init."
            )
        )
        return
    _stage_insert_line(lines, insert_at, 'annotate_stage("init")\n')
    logger.info(
        _stage_ctx(
            f"Inserted stage 'init' right after annotate_stage import at line {insert_at + 1}."
        )
    )


def annotate_stage(
    source: str,
) -> str:
    """Insert ``annotate_stage(...)`` calls into a training script's source.

    DEBT: Refactor the source tree exploration part with an AST-based approach.

    The source is scanned with ``tokenize`` for ``.<name>(`` call sites:
    ``train`` / ``step`` mark the "training" stage, ``eval`` / ``no_grad`` the
    "testing" stage and ``save`` the "checkpointing" stage. A call
    ``annotate_stage("<stage>")`` is inserted on its own line before each
    statement containing such a site, using that statement's indentation. When
    one statement triggers several stages, training wins over testing, which
    wins over checkpointing. A stage the source already annotates explicitly
    is left untouched.

    A top-level ``from traincheck import annotate_stage`` import is added if
    missing. Unless the source already contains ``annotate_stage("init")``, an
    init call is added at the start of ``main()`` for scripts with a
    ``__main__`` guard, or right after the import otherwise.

    Args:
        source: Python source code of the script to annotate.

    Returns:
        The annotated source code.

    Raises:
        RuntimeError: If the result contains no ``annotate_stage(`` call and
            the input did not contain one either.
        tokenize.TokenError, SyntaxError: Propagated from ``tokenize`` when
            *source* cannot be tokenized.
    """

    def has_stage(name: str) -> bool:
        return (
            re.search(rf'annotate_stage\(\s*[\'"]{name}[\'"]\s*\)', source)
            is not None
        )

    orig_has = {
        name: has_stage(name)
        for name in ("init", "training", "testing", "checkpointing")
    }
    orig_has_any = any(orig_has.values()) or ("annotate_stage(" in source)

    for stage_name, present in orig_has.items():
        if present:
            logger.info(
                _stage_ctx(
                    f"Stage '{stage_name}' already present in source; skip adding this stage."
                )
            )

    already_present = {name for name, present in orig_has.items() if present}
    sites = _stage_sites(source, already_present)

    lines = _stage_insert_calls(_stage_split_lines(source), sites)
    import_idx = _stage_ensure_import(lines)
    if not orig_has["init"]:
        _stage_insert_init(lines, import_idx)

    new_src = "".join(lines)

    if "annotate_stage(" not in new_src and not orig_has_any:
        logger.error(
            _stage_ctx(
                "Automatic insertion failed: no annotate_stage(...) found or added. Manual insertion required."
            )
        )
        raise RuntimeError(
            _stage_ctx("annotate_stage insertion failed; see logs for details.")
        )

    return new_src
| 785 | + |
| 786 | + |
505 | 787 | def instrument_file( |
506 | 788 | path: str, |
507 | 789 | modules_to_instr: list[str], |
@@ -532,7 +814,8 @@ def instrument_file( |
532 | 814 | funcs_to_instr, |
533 | 815 | API_dump_stack_trace, |
534 | 816 | ) |
535 | | - |
| 817 | + # annotate stages |
| 818 | + instrumented_source = annotate_stage(instrumented_source) |
536 | 819 | # logging configs |
537 | 820 | logging_start_code = f""" |
538 | 821 | import os |
|
0 commit comments