|
| 1 | +# Copyright 2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Verify that every committed .ipynb has a corresponding pytest notebook test. |
| 16 | +
|
| 17 | +For each committed notebook, checks that: |
| 18 | +1. A *_test.py file exists containing an execute_notebook('name') call |
| 19 | +2. That test function is decorated with @pytest.mark.notebook |
| 20 | +
|
| 21 | +Usage: |
| 22 | + python dev_tools/check-notebook-tests.py |
| 23 | +""" |
| 24 | + |
| 25 | +import ast |
| 26 | +import subprocess |
| 27 | +import sys |
| 28 | +from pathlib import Path |
| 29 | +from typing import Dict, Tuple |
| 30 | + |
| 31 | +from qualtran_dev_tools.git_tools import get_git_root |
| 32 | + |
| 33 | +_EXCLUDED_DIRS = {'dev_tools'} |
| 34 | + |
| 35 | + |
| 36 | +def get_committed_notebooks(reporoot: Path) -> Dict[str, Path]: |
| 37 | + """Return {stem: relative_path} for all committed .ipynb files under reporoot. |
| 38 | +
|
| 39 | + Excludes notebooks in dev_tools/ since those are developer utilities, |
| 40 | + not user-facing documentation. |
| 41 | + """ |
| 42 | + result = subprocess.run( |
| 43 | + ['git', 'ls-files', '*.ipynb'], capture_output=True, text=True, check=True, cwd=reporoot |
| 44 | + ) |
| 45 | + return { |
| 46 | + Path(f).stem: Path(f) |
| 47 | + for f in result.stdout.strip().split('\n') |
| 48 | + if f and not any(Path(f).parts[0] == d for d in _EXCLUDED_DIRS) |
| 49 | + } |
| 50 | + |
| 51 | + |
| 52 | +def _is_notebook_marker(decorator: ast.expr) -> bool: |
| 53 | + """Check if a decorator is @pytest.mark.notebook.""" |
| 54 | + # Handle pytest.mark.notebook (attr chain) |
| 55 | + if isinstance(decorator, ast.Attribute) and decorator.attr == 'notebook': |
| 56 | + return True |
| 57 | + return False |
| 58 | + |
| 59 | + |
| 60 | +def find_notebook_tests(reporoot: Path) -> Dict[str, Tuple[Path, bool]]: |
| 61 | + """Find all execute_notebook() calls in test files. |
| 62 | +
|
| 63 | + Searches all *_test.py files under the repo root (including qualtran/ |
| 64 | + and tutorials/). |
| 65 | +
|
| 66 | + Returns {notebook_name: (test_file_path_relative, has_notebook_marker)}. |
| 67 | + """ |
| 68 | + results: Dict[str, Tuple[Path, bool]] = {} |
| 69 | + for test_file in reporoot.rglob('*_test.py'): |
| 70 | + try: |
| 71 | + tree = ast.parse(test_file.read_text()) |
| 72 | + except SyntaxError: |
| 73 | + continue |
| 74 | + |
| 75 | + for node in ast.walk(tree): |
| 76 | + if not isinstance(node, ast.FunctionDef): |
| 77 | + continue |
| 78 | + # Check if function body contains execute_notebook('xxx') |
| 79 | + for child in ast.walk(node): |
| 80 | + if ( |
| 81 | + isinstance(child, ast.Call) |
| 82 | + and _is_execute_notebook_call(child) |
| 83 | + and child.args |
| 84 | + and isinstance(child.args[0], ast.Constant) |
| 85 | + ): |
| 86 | + nb_name = child.args[0].value |
| 87 | + # Check for @pytest.mark.notebook decorator |
| 88 | + has_marker = any(_is_notebook_marker(dec) for dec in node.decorator_list) |
| 89 | + results[nb_name] = (test_file.relative_to(reporoot), has_marker) |
| 90 | + return results |
| 91 | + |
| 92 | + |
| 93 | +def _is_execute_notebook_call(node: ast.Call) -> bool: |
| 94 | + """Check if a Call node is a call to execute_notebook (with or without module prefix).""" |
| 95 | + if isinstance(node.func, ast.Attribute) and node.func.attr == 'execute_notebook': |
| 96 | + return True |
| 97 | + if isinstance(node.func, ast.Name) and node.func.id == 'execute_notebook': |
| 98 | + return True |
| 99 | + return False |
| 100 | + |
| 101 | + |
| 102 | +def main(): |
| 103 | + reporoot = get_git_root() |
| 104 | + |
| 105 | + committed = get_committed_notebooks(reporoot) |
| 106 | + tested = find_notebook_tests(reporoot) |
| 107 | + |
| 108 | + errors = [] |
| 109 | + |
| 110 | + for stem, nb_rel_path in sorted(committed.items()): |
| 111 | + if stem not in tested: |
| 112 | + # Suggest the likely test file location |
| 113 | + nb_dir = nb_rel_path.parent |
| 114 | + test_file = nb_dir / f'{stem}_test.py' |
| 115 | + errors.append( |
| 116 | + f" MISSING TEST: {nb_rel_path}\n" |
| 117 | + f" Add to {test_file}:\n" |
| 118 | + f"\n" |
| 119 | + f" @pytest.mark.notebook\n" |
| 120 | + f" def test_{stem}_notebook():\n" |
| 121 | + f" qlt_testing.execute_notebook('{stem}')\n" |
| 122 | + ) |
| 123 | + else: |
| 124 | + test_file, has_marker = tested[stem] |
| 125 | + if not has_marker: |
| 126 | + errors.append( |
| 127 | + f" MISSING MARKER: {nb_rel_path}\n" |
| 128 | + f" The test in {test_file} calls execute_notebook('{stem}')\n" |
| 129 | + f" but is not decorated with @pytest.mark.notebook.\n" |
| 130 | + f" Add the decorator so this test runs in the notebooks CI job:\n" |
| 131 | + f"\n" |
| 132 | + f" @pytest.mark.notebook\n" |
| 133 | + f" def test_...():\n" |
| 134 | + ) |
| 135 | + |
| 136 | + if errors: |
| 137 | + print(f"ERROR: {len(errors)} notebook(s) have issues:\n") |
| 138 | + print("\n".join(errors)) |
| 139 | + sys.exit(1) |
| 140 | + |
| 141 | + print(f"OK: All {len(committed)} notebooks have properly marked tests.") |
| 142 | + |
| 143 | + |
| 144 | +if __name__ == '__main__': |
| 145 | + main() |
0 commit comments