diff --git a/changelog.d/policyengine-bundle-v2-import.added.md b/changelog.d/policyengine-bundle-v2-import.added.md new file mode 100644 index 00000000..8abcd0a1 --- /dev/null +++ b/changelog.d/policyengine-bundle-v2-import.added.md @@ -0,0 +1 @@ +Add schema-v2 `policyengine-bundles` archive import support for future bundle-driven release updates. diff --git a/scripts/import_policyengine_bundle.py b/scripts/import_policyengine_bundle.py index 7f02d823..083f6841 100644 --- a/scripts/import_policyengine_bundle.py +++ b/scripts/import_policyengine_bundle.py @@ -1,666 +1,13 @@ from __future__ import annotations -import argparse -import hashlib -import json -import shutil +import os import sys -import tarfile -import tempfile -import urllib.request from pathlib import Path -from typing import Any, Optional -DEFAULT_RELEASE_BASE_URL = ( - "https://github.com/PolicyEngine/policyengine-bundles/releases/download" -) -REPO_ROOT = Path(__file__).resolve().parents[1] -DEFAULT_BUNDLE_DIR = REPO_ROOT / "src" / "policyengine" / "data" / "bundle" -DEFAULT_RELEASE_MANIFEST_DIR = ( - REPO_ROOT / "src" / "policyengine" / "data" / "release_manifests" -) -DEFAULT_PYPROJECT = REPO_ROOT / "pyproject.toml" -DEFAULT_CHANGELOG_DIR = REPO_ROOT / "changelog.d" -COUNTRY_OPTIONAL_DEPENDENCIES = { - "uk": "policyengine-uk", - "us": "policyengine-us", -} - - -class BundleImportError(RuntimeError): - """Raised when a PolicyEngine bundle cannot be imported into policyengine.py.""" - - -def main() -> int: - parser = argparse.ArgumentParser( - description=( - "Import one policyengine-bundles release into policyengine.py. " - "The script verifies release assets, vendors the exploded bundle, " - "regenerates country release manifests, and updates country extras." - ) - ) - parser.add_argument("version", help="Bundle version to import, e.g. 
4.14.0.") - parser.add_argument( - "--dist-dir", - type=Path, - help="Use local release assets instead of downloading from GitHub.", - ) - parser.add_argument( - "--base-url", - default=DEFAULT_RELEASE_BASE_URL, - help="GitHub release base URL used when --dist-dir is not provided.", - ) - parser.add_argument("--bundle-dir", type=Path, default=DEFAULT_BUNDLE_DIR) - parser.add_argument( - "--release-manifest-dir", - type=Path, - default=DEFAULT_RELEASE_MANIFEST_DIR, - ) - parser.add_argument("--pyproject", type=Path, default=DEFAULT_PYPROJECT) - parser.add_argument("--changelog-dir", type=Path, default=DEFAULT_CHANGELOG_DIR) - parser.add_argument( - "--no-changelog", - action="store_true", - help="Do not write a towncrier changelog fragment.", - ) - args = parser.parse_args() - - try: - imported = import_policyengine_bundle( - version=args.version, - dist_dir=args.dist_dir, - base_url=args.base_url, - bundle_dir=args.bundle_dir, - release_manifest_dir=args.release_manifest_dir, - pyproject_path=args.pyproject, - changelog_dir=None if args.no_changelog else args.changelog_dir, - ) - except BundleImportError as exc: - print(f"error: {exc}", file=sys.stderr) - return 1 - - print(f"imported bundle: {imported.bundle_dir}") - for manifest_path in imported.release_manifest_paths: - print(f"release manifest: {manifest_path}") - print(f"updated pyproject: {imported.pyproject_path}") - if imported.changelog_path is not None: - print(f"changelog: {imported.changelog_path}") - return 0 - - -class ImportResult: - def __init__( - self, - *, - bundle_dir: Path, - release_manifest_paths: list[Path], - pyproject_path: Path, - changelog_path: Optional[Path], - ) -> None: - self.bundle_dir = bundle_dir - self.release_manifest_paths = release_manifest_paths - self.pyproject_path = pyproject_path - self.changelog_path = changelog_path - - -def import_policyengine_bundle( - *, - version: str, - dist_dir: Optional[Path], - base_url: str, - bundle_dir: Path, - release_manifest_dir: Path, - 
pyproject_path: Path, - changelog_dir: Optional[Path], -) -> ImportResult: - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - assets_dir = dist_dir or temp_path / "dist" - if dist_dir is None: - _download_release_assets( - version=version, - output_dir=assets_dir, - base_url=base_url, - ) - - archive_path, summary = _verify_release_assets( - version=version, - dist_dir=assets_dir, - ) - unpacked_bundle_dir = _extract_bundle_archive( - archive_path=archive_path, - output_dir=temp_path / "unpacked", - version=version, - ) - _verify_bundle_digest(unpacked_bundle_dir, summary) - - if bundle_dir.exists(): - shutil.rmtree(bundle_dir) - shutil.copytree(unpacked_bundle_dir, bundle_dir) - - bundle = _load_json(bundle_dir / "bundle.json") - country_manifest_paths = _write_country_release_manifests( - bundle_dir=bundle_dir, - bundle=bundle, - release_manifest_dir=release_manifest_dir, - ) - _update_optional_dependency_pins( - pyproject_path=pyproject_path, - bundle=bundle, - ) - changelog_path = None - if changelog_dir is not None: - changelog_path = _write_changelog_fragment( - changelog_dir=changelog_dir, - version=version, - bundle=bundle, - ) - - return ImportResult( - bundle_dir=bundle_dir, - release_manifest_paths=country_manifest_paths, - pyproject_path=pyproject_path, - changelog_path=changelog_path, - ) - - -def _download_release_assets( - *, - version: str, - output_dir: Path, - base_url: str, -) -> None: - output_dir.mkdir(parents=True, exist_ok=True) - for asset_name in _release_asset_names(version): - url = f"{base_url.rstrip('/')}/v{version}/{asset_name}" - output_path = output_dir / asset_name - try: - urllib.request.urlretrieve(url, output_path) - except OSError as exc: - raise BundleImportError(f"Could not download {url}: {exc}") from exc - - -def _verify_release_assets(*, version: str, dist_dir: Path) -> tuple[Path, dict]: - archive_name, checksum_name, summary_name = _release_asset_names(version) - archive_path = dist_dir / 
archive_name - checksum_path = dist_dir / checksum_name - summary_path = dist_dir / summary_name - missing = [ - path.name - for path in (archive_path, checksum_path, summary_path) - if not path.exists() - ] - if missing: - raise BundleImportError(f"Missing bundle release assets: {', '.join(missing)}.") - - summary = _load_json(summary_path) - if summary.get("bundle_version") != version: - raise BundleImportError( - "Release summary bundle_version does not match requested version: " - f"expected {version}, got {summary.get('bundle_version')}." - ) - if summary.get("archive") != archive_name: - raise BundleImportError( - "Release summary archive name does not match expected asset: " - f"expected {archive_name}, got {summary.get('archive')}." - ) - - checksum = _read_checksum_file(checksum_path, archive_name) - if summary.get("archive_sha256") != checksum: - raise BundleImportError( - "Release summary archive_sha256 does not match checksum file: " - f"expected {summary.get('archive_sha256')}, got {checksum}." - ) - actual_checksum = _sha256_file(archive_path) - if actual_checksum != checksum: - raise BundleImportError( - "Archive sha256 does not match checksum file: " - f"expected {checksum}, got {actual_checksum}." 
- ) - return archive_path, summary - - -def _extract_bundle_archive( - *, - archive_path: Path, - output_dir: Path, - version: str, -) -> Path: - expected_root = f"policyengine-bundle-{version}" - output_dir.mkdir(parents=True, exist_ok=True) - try: - with tarfile.open(archive_path) as archive: - _validate_archive_members(archive, expected_root) - if sys.version_info >= (3, 12): - archive.extractall(output_dir, filter="data") - else: - archive.extractall(output_dir) - except (tarfile.TarError, OSError) as exc: - raise BundleImportError(f"Could not extract {archive_path}: {exc}") from exc - - bundle_dir = output_dir / expected_root - if not bundle_dir.is_dir(): - raise BundleImportError(f"Archive did not contain {expected_root}/.") - return bundle_dir - - -def _validate_archive_members(archive: tarfile.TarFile, expected_root: str) -> None: - root = Path(expected_root) - for member in archive.getmembers(): - member_path = Path(member.name) - if member_path.is_absolute() or ".." in member_path.parts: - raise BundleImportError(f"Unsafe archive member path: {member.name}") - if member_path.parts[:1] != root.parts: - raise BundleImportError( - f"Archive member is outside {expected_root}/: {member.name}" - ) - if member.issym() or member.islnk(): - raise BundleImportError( - f"Archive link members are not allowed: {member.name}" - ) - - -def _verify_bundle_digest(bundle_dir: Path, summary: dict) -> None: - expected = summary.get("bundle_digest") - if not isinstance(expected, str) or not expected.startswith("sha256:"): - raise BundleImportError("Release summary does not include bundle_digest.") - actual = f"sha256:{_bundle_directory_digest(bundle_dir)}" - if actual != expected: - raise BundleImportError( - "Release summary bundle_digest does not match unpacked bundle: " - f"expected {expected}, got {actual}." 
- ) - - -def _bundle_directory_digest(bundle_dir: Path) -> str: - hasher = hashlib.sha256() - for relative_path in _bundle_files(bundle_dir): - content = _normalized_file_content(bundle_dir, relative_path) - hasher.update(relative_path.as_posix().encode("utf-8")) - hasher.update(b"\0") - hasher.update(content.encode("utf-8")) - hasher.update(b"\0") - return hasher.hexdigest() - - -def _bundle_files(bundle_dir: Path) -> list[Path]: - return sorted( - path.relative_to(bundle_dir) - for path in bundle_dir.rglob("*") - if path.is_file() and path.name != ".DS_Store" - ) - - -def _normalized_file_content(bundle_dir: Path, relative_path: Path) -> str: - path = bundle_dir / relative_path - if relative_path.suffix == ".json": - payload = _load_json(path) - if relative_path.as_posix() == "bundle.json": - payload.pop("created_at", None) - payload.pop("bundle_digest", None) - elif relative_path.as_posix() == "validation-report.json": - payload.pop("generated_at", None) - checks = [] - for check in payload.get("checks", []): - if not isinstance(check, dict): - checks.append(check) - continue - check_payload = dict(check) - check_payload.pop("command", None) - check_payload.pop("started_at", None) - check_payload.pop("ended_at", None) - details = check_payload.get("details") - if isinstance(details, dict): - details_payload = dict(details) - details_payload.pop("validated_on_platform", None) - check_payload["details"] = details_payload - checks.append(check_payload) - payload["checks"] = checks - return json.dumps(payload, indent=2, sort_keys=True) + "\n" - text = path.read_text() - if path.name in {"constraints.txt", "pylock.toml"}: - return _strip_comment_lines(text) - return text - - -def _strip_comment_lines(text: str) -> str: - lines = [line for line in text.splitlines() if not line.lstrip().startswith("#")] - return "\n".join(lines) + ("\n" if text.endswith("\n") else "") - - -def _write_country_release_manifests( - *, - bundle_dir: Path, - bundle: dict, - 
release_manifest_dir: Path, -) -> list[Path]: - country_paths = bundle.get("countries") - if not isinstance(country_paths, dict) or not country_paths: - raise BundleImportError("Bundle manifest does not include countries.") - - release_manifest_dir.mkdir(parents=True, exist_ok=True) - written_paths = [] - for country_id, relative_path in sorted(country_paths.items()): - if not isinstance(country_id, str) or not isinstance(relative_path, str): - raise BundleImportError("Bundle countries must map ids to paths.") - country_bundle = _load_json(bundle_dir / relative_path) - release_manifest = _country_release_manifest(country_bundle) - output_path = release_manifest_dir / f"{country_id}.json" - _write_json(output_path, release_manifest) - written_paths.append(output_path) - return written_paths - - -def _country_release_manifest(country_bundle: dict) -> dict: - country_id = _required_string(country_bundle, "country_id") - bundle_version = _required_string(country_bundle, "bundle_version") - data_package = _required_dict(country_bundle, "data_package") - certification = _required_dict(country_bundle, "certification") - datasets = _required_dict(country_bundle, "datasets") - default_dataset = _required_string(country_bundle, "default_dataset") - default_artifact = _required_dict(datasets, default_dataset) - - data_package_payload = { - "name": _required_string(data_package, "name"), - "version": _required_string(data_package, "version"), - "repo_id": _required_string(data_package, "repo_id"), - "repo_type": data_package.get("repo_type", "model"), - "release_manifest_path": data_package.get( - "release_manifest_path", "release_manifest.json" - ), - } - release_manifest_revision = data_package.get("release_manifest_revision") - if release_manifest_revision: - data_package_payload["release_manifest_revision"] = release_manifest_revision - - return { - "schema_version": 1, - "bundle_id": f"{country_id}-{bundle_version}", - "country_id": country_id, - "policyengine_version": 
bundle_version, - "model_package": _package_version(country_bundle["model_package"]), - "data_package": data_package_payload, - "default_dataset": default_dataset, - "datasets": _dataset_path_references(datasets), - "region_datasets": _region_dataset_templates( - country_bundle.get("region_datasets", {}) - ), - "certified_data_artifact": { - "data_package": { - "name": data_package_payload["name"], - "version": data_package_payload["version"], - }, - "dataset": default_dataset, - "uri": _artifact_uri(default_artifact), - "sha256": default_artifact.get("sha256"), - "build_id": certification.get("data_build_id"), - }, - "certification": { - "compatibility_basis": _required_string( - certification, "compatibility_basis" - ), - "data_build_id": certification.get("data_build_id"), - "built_with_model_version": _package_pin_version( - certification.get("built_with_model_package") - ), - "built_with_model_git_sha": _package_pin_git_sha( - certification.get("built_with_model_package") - ), - "certified_for_model_version": _package_pin_version( - certification.get("certified_for_model_package") - ), - "data_build_fingerprint": certification.get("data_build_fingerprint"), - "certified_by": certification.get("certified_by"), - }, - } - - -def _package_version(package: dict) -> dict: - payload = { - "name": _required_string(package, "name"), - "version": _required_string(package, "version"), - } - if package.get("sha256"): - payload["sha256"] = package["sha256"] - if package.get("wheel_url"): - payload["wheel_url"] = package["wheel_url"] - return payload - - -def _dataset_path_references(datasets: dict) -> dict: - path_references = {} - for dataset, artifact in sorted(datasets.items()): - if not isinstance(dataset, str) or not isinstance(artifact, dict): - raise BundleImportError( - "Country bundle datasets must map names to objects." 
- ) - payload = {"path": _required_string(artifact, "path")} - if artifact.get("revision"): - payload["revision"] = artifact["revision"] - if artifact.get("sha256"): - payload["sha256"] = artifact["sha256"] - if artifact.get("metadata_sha256"): - payload["metadata_sha256"] = artifact["metadata_sha256"] - path_references[dataset] = payload - return path_references - - -def _region_dataset_templates(region_datasets: dict) -> dict: - templates = {} - if not isinstance(region_datasets, dict): - raise BundleImportError("Country bundle region_datasets must be an object.") - for region, template in sorted(region_datasets.items()): - if not isinstance(region, str) or not isinstance(template, dict): - raise BundleImportError( - "Country bundle region_datasets must map names to objects." - ) - if "path_template" in template: - templates[region] = {"path_template": template["path_template"]} - return templates - - -def _artifact_uri(artifact: dict) -> str: - uri = artifact.get("uri") - if isinstance(uri, str) and uri: - return uri - repo_id = _required_string(artifact, "repo_id") - path = _required_string(artifact, "path") - revision = _required_string(artifact, "revision") - return f"hf://{repo_id}/{path}@{revision}" - - -def _package_pin_version(package: Any) -> Optional[str]: - if isinstance(package, dict): - version = package.get("version") - if isinstance(version, str): - return version - return None - - -def _package_pin_git_sha(package: Any) -> Optional[str]: - if isinstance(package, dict): - git_sha = package.get("git_sha") - if isinstance(git_sha, str): - return git_sha - return None - - -def _update_optional_dependency_pins(*, pyproject_path: Path, bundle: dict) -> None: - packages = _required_dict(bundle, "packages") - core_version = _required_string( - _required_dict(packages, "policyengine-core"), "version" - ) - replacements = {"policyengine_core": core_version} - for package_name in COUNTRY_OPTIONAL_DEPENDENCIES.values(): - package = _required_dict(packages, 
package_name) - replacements[package_name] = _required_string(package, "version") - - text = pyproject_path.read_text() - text = _replace_optional_dependency_section( - text, - "uk", - [ - f"policyengine_core=={core_version}", - f"policyengine-uk=={replacements['policyengine-uk']}", - ], - ) - text = _replace_optional_dependency_section( - text, - "us", - [ - f"policyengine_core=={core_version}", - f"policyengine-us=={replacements['policyengine-us']}", - ], - ) - text = _replace_dependency_in_section( - text, "dev", "policyengine_core", core_version - ) - text = _replace_dependency_in_section( - text, - "dev", - "policyengine-uk", - replacements["policyengine-uk"], - ) - text = _replace_dependency_in_section( - text, - "dev", - "policyengine-us", - replacements["policyengine-us"], - ) - pyproject_path.write_text(text) - - -def _replace_optional_dependency_section( - text: str, - section_name: str, - dependencies: list[str], -) -> str: - section_start = text.find(f"{section_name} = [") - if section_start == -1: - raise BundleImportError( - f"pyproject optional dependency missing: {section_name}" - ) - content_start = text.find("\n", section_start) - if content_start == -1: - raise BundleImportError(f"Malformed pyproject section: {section_name}") - content_end = text.find("\n]", content_start) - if content_end == -1: - raise BundleImportError(f"Malformed pyproject section: {section_name}") - replacement = "\n".join(f' "{dependency}",' for dependency in dependencies) - return f"{text[: content_start + 1]}{replacement}{text[content_end:]}" - - -def _replace_dependency_in_section( - text: str, - section_name: str, - package_name: str, - version: str, -) -> str: - section_start = text.find(f"{section_name} = [") - if section_start == -1: - raise BundleImportError( - f"pyproject optional dependency missing: {section_name}" - ) - content_start = text.find("\n", section_start) - content_end = text.find("\n]", content_start) - if content_start == -1 or content_end == -1: - 
raise BundleImportError(f"Malformed pyproject section: {section_name}") - - lines = text[content_start + 1 : content_end].splitlines() - updated_lines = [] - replaced = False - for line in lines: - stripped = line.strip() - if stripped.startswith(f'"{package_name}==') or stripped.startswith( - f'"{package_name}>=' - ): - updated_lines.append(f' "{package_name}=={version}",') - replaced = True - else: - updated_lines.append(line) - if not replaced: - raise BundleImportError( - f"pyproject optional dependency {section_name} is missing {package_name}." - ) - replacement = "\n".join(updated_lines) - return f"{text[: content_start + 1]}{replacement}{text[content_end:]}" - - -def _write_changelog_fragment( - *, - changelog_dir: Path, - version: str, - bundle: dict, -) -> Path: - packages = _required_dict(bundle, "packages") - core_version = _required_string( - _required_dict(packages, "policyengine-core"), "version" - ) - uk_version = _required_string( - _required_dict(packages, "policyengine-uk"), "version" - ) - us_version = _required_string( - _required_dict(packages, "policyengine-us"), "version" - ) - changelog_dir.mkdir(parents=True, exist_ok=True) - path = changelog_dir / f"policyengine-bundle-{version}.changed.md" - path.write_text( - f"Vend PolicyEngine bundle {version} with policyengine-core " - f"{core_version}, policyengine-uk {uk_version}, and policyengine-us " - f"{us_version}.\n" - ) - return path - - -def _release_asset_names(version: str) -> tuple[str, str, str]: - archive_name = f"policyengine-bundle-{version}.tar.gz" - return archive_name, f"{archive_name}.sha256", f"policyengine-bundle-{version}.json" - - -def _read_checksum_file(path: Path, archive_name: str) -> str: - parts = path.read_text().strip().split() - if len(parts) != 2 or parts[1] != archive_name: - raise BundleImportError(f"Malformed checksum file: {path}") - return parts[0] - - -def _sha256_file(path: Path) -> str: - hasher = hashlib.sha256() - with path.open("rb") as file: - for chunk 
in iter(lambda: file.read(1024 * 1024), b""): - hasher.update(chunk) - return hasher.hexdigest() - - -def _load_json(path: Path) -> dict: - try: - with path.open() as file: - payload = json.load(file) - except (OSError, ValueError) as exc: - raise BundleImportError(f"Could not load JSON from {path}: {exc}") from exc - if not isinstance(payload, dict): - raise BundleImportError(f"Expected JSON object in {path}.") - return payload - - -def _write_json(path: Path, payload: dict) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(payload, indent=2, sort_keys=False) + "\n") - - -def _required_dict(payload: dict, key: str) -> dict: - value = payload.get(key) - if not isinstance(value, dict): - raise BundleImportError(f"Expected object at {key}.") - return value - - -def _required_string(payload: dict, key: str) -> str: - value = payload.get(key) - if not isinstance(value, str) or not value: - raise BundleImportError(f"Expected non-empty string at {key}.") - return value +os.environ.setdefault("POLICYENGINE_SKIP_COUNTRY_IMPORTS", "1") +sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src")) +from policyengine.provenance.bundle_import import main if __name__ == "__main__": raise SystemExit(main()) diff --git a/src/policyengine/provenance/__init__.py b/src/policyengine/provenance/__init__.py index a2b37ed1..07afe685 100644 --- a/src/policyengine/provenance/__init__.py +++ b/src/policyengine/provenance/__init__.py @@ -21,6 +21,15 @@ from .bundle import ( sync_release_manifest_policyengine_version as sync_release_manifest_policyengine_version, ) +from .bundle_import import ( + BundleImportError as BundleImportError, +) +from .bundle_import import ( + BundleImportResult as BundleImportResult, +) +from .bundle_import import ( + import_policyengine_bundle as import_policyengine_bundle, +) from .manifest import ( CertifiedDataArtifact as CertifiedDataArtifact, ) diff --git a/src/policyengine/provenance/bundle_import/__init__.py 
b/src/policyengine/provenance/bundle_import/__init__.py new file mode 100644 index 00000000..22a3179b --- /dev/null +++ b/src/policyengine/provenance/bundle_import/__init__.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from .api import import_policyengine_bundle as import_policyengine_bundle +from .cli import main as main +from .digest import _bundle_directory_digest as _bundle_directory_digest +from .types import BundleImportError as BundleImportError +from .types import BundleImportResult as BundleImportResult + +__all__ = [ + "BundleImportError", + "BundleImportResult", + "_bundle_directory_digest", + "import_policyengine_bundle", + "main", +] diff --git a/src/policyengine/provenance/bundle_import/api.py b/src/policyengine/provenance/bundle_import/api.py new file mode 100644 index 00000000..ef792fd6 --- /dev/null +++ b/src/policyengine/provenance/bundle_import/api.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import shutil +import tempfile +from pathlib import Path +from typing import Optional + +from .archive import ( + extract_bundle_archive, + load_country_bundles, + validate_bundle_manifest, +) +from .constants import ( + DEFAULT_BUNDLE_DIR, + DEFAULT_PYPROJECT, + DEFAULT_RELEASE_MANIFEST_DIR, +) +from .country_manifest import write_country_release_manifests +from .digest import verify_bundle_digest +from .io import load_json, required_dict, required_string +from .pyproject import update_optional_dependency_pins +from .types import BundleImportResult, TroRegenerator + + +def import_policyengine_bundle( + archive_path: Path, + *, + manifest_dir: Path = DEFAULT_RELEASE_MANIFEST_DIR, + pyproject_path: Path = DEFAULT_PYPROJECT, + update_pyproject: bool = True, + regenerate_tros: bool = True, + bundle_dir: Optional[Path] = DEFAULT_BUNDLE_DIR, + changelog_dir: Optional[Path] = None, + tro_regenerator: Optional[TroRegenerator] = None, +) -> BundleImportResult: + """Import a schema-v2 ``policyengine-bundles`` archive. 
+ + The runtime contract in policyengine.py remains the existing per-country + ``CountryReleaseManifest`` schema. This function accepts the newer + registry-only bundle archive and translates each country bundle into that + stable runtime shape. + """ + + archive_path = Path(archive_path) + with tempfile.TemporaryDirectory() as temp_dir: + unpacked_bundle_dir = extract_bundle_archive( + archive_path=archive_path, + output_dir=Path(temp_dir) / "unpacked", + ) + bundle = load_json(unpacked_bundle_dir / "bundle.json") + validate_bundle_manifest(bundle, unpacked_bundle_dir) + verify_bundle_digest(unpacked_bundle_dir, bundle) + + copied_bundle_dir = None + if bundle_dir is not None: + copied_bundle_dir = Path(bundle_dir) + if copied_bundle_dir.exists(): + shutil.rmtree(copied_bundle_dir) + shutil.copytree(unpacked_bundle_dir, copied_bundle_dir) + source_bundle_dir = copied_bundle_dir + else: + source_bundle_dir = unpacked_bundle_dir + + country_bundles = load_country_bundles( + bundle_dir=source_bundle_dir, + bundle=bundle, + ) + release_manifest_paths = write_country_release_manifests( + country_bundles=country_bundles, + manifest_dir=manifest_dir, + ) + + updated_pyproject = None + if update_pyproject: + update_optional_dependency_pins( + pyproject_path=pyproject_path, + country_bundles=country_bundles, + ) + updated_pyproject = pyproject_path + + trace_tro_paths: list[Path] = [] + if regenerate_tros: + trace_tro_paths = regenerate_trace_tros( + countries=sorted(country_bundles), + manifest_dir=manifest_dir, + tro_regenerator=tro_regenerator, + ) + + changelog_path = None + if changelog_dir is not None: + changelog_path = write_changelog_fragment( + changelog_dir=changelog_dir, + bundle=bundle, + country_bundles=country_bundles, + ) + + return BundleImportResult( + bundle_version=required_string(bundle, "bundle_version"), + countries=sorted(country_bundles), + bundle_dir=copied_bundle_dir, + release_manifest_paths=release_manifest_paths, + 
pyproject_path=updated_pyproject, + trace_tro_paths=trace_tro_paths, + changelog_path=changelog_path, + ) + + +def regenerate_trace_tros( + *, + countries: list[str], + manifest_dir: Path, + tro_regenerator: Optional[TroRegenerator], +) -> list[Path]: + if tro_regenerator is None: + from policyengine.provenance.bundle import regenerate_trace_tro + + tro_regenerator = regenerate_trace_tro + return [tro_regenerator(country, manifest_dir) for country in countries] + + +def write_changelog_fragment( + *, + changelog_dir: Path, + bundle: dict, + country_bundles: dict[str, dict], +) -> Path: + bundle_version = required_string(bundle, "bundle_version") + changelog_dir.mkdir(parents=True, exist_ok=True) + path = changelog_dir / f"policyengine-bundle-{bundle_version}.changed.md" + package_fragments = [] + for country_id, country_bundle in sorted(country_bundles.items()): + model_package = required_dict(country_bundle, "model_package") + package_fragments.append( + f"{country_id}: {required_string(model_package, 'name')} " + f"{required_string(model_package, 'version')}" + ) + path.write_text( + f"Import PolicyEngine bundle {bundle_version} " + f"({'; '.join(package_fragments)}).\n" + ) + return path diff --git a/src/policyengine/provenance/bundle_import/archive.py b/src/policyengine/provenance/bundle_import/archive.py new file mode 100644 index 00000000..526ac9a7 --- /dev/null +++ b/src/policyengine/provenance/bundle_import/archive.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import sys +import tarfile +from pathlib import Path + +from .io import load_json, required_dict, required_string +from .types import BundleImportError + + +def extract_bundle_archive(*, archive_path: Path, output_dir: Path) -> Path: + output_dir.mkdir(parents=True, exist_ok=True) + try: + with tarfile.open(archive_path) as archive: + root_name = validate_archive_members(archive) + if sys.version_info >= (3, 12): + archive.extractall(output_dir, filter="data") + else: + 
archive.extractall(output_dir) + except (tarfile.TarError, OSError) as exc: + raise BundleImportError(f"Could not extract {archive_path}: {exc}") from exc + + bundle_dir = output_dir / root_name + if not bundle_dir.is_dir(): + raise BundleImportError(f"Archive did not contain {root_name}/.") + return bundle_dir + + +def validate_archive_members(archive: tarfile.TarFile) -> str: + members = archive.getmembers() + if not members: + raise BundleImportError("Bundle archive is empty.") + + roots: set[str] = set() + for member in members: + member_path = Path(member.name) + if not member_path.parts: + raise BundleImportError("Archive contains an empty member path.") + if member_path.is_absolute() or ".." in member_path.parts: + raise BundleImportError(f"Unsafe archive member path: {member.name}") + if member.issym() or member.islnk(): + raise BundleImportError( + f"Archive link members are not allowed: {member.name}" + ) + if not member.isfile() and not member.isdir(): + raise BundleImportError( + f"Archive special members are not allowed: {member.name}" + ) + roots.add(member_path.parts[0]) + + if len(roots) != 1: + raise BundleImportError( + "Bundle archive must contain exactly one root directory." + ) + root_name = next(iter(roots)) + if not root_name.startswith("policyengine-bundle-"): + raise BundleImportError( + "Bundle archive root must be named policyengine-bundle-." + ) + if root_name == "policyengine-bundle-": + raise BundleImportError("Bundle archive root is missing a version.") + return root_name + + +def validate_bundle_manifest(bundle: dict, bundle_dir: Path) -> None: + schema_version = bundle.get("schema_version") + if schema_version != 2: + raise BundleImportError( + "Only schema v2 policyengine-bundles archives can be imported; " + f"got schema_version={schema_version!r}." 
+ ) + bundle_version = required_string(bundle, "bundle_version") + expected_root = f"policyengine-bundle-{bundle_version}" + if bundle_dir.name != expected_root: + raise BundleImportError( + "Bundle archive root does not match bundle_version: " + f"expected {expected_root}, got {bundle_dir.name}." + ) + required_dict(bundle, "countries") + required_dict(bundle, "packages") + validation_report = required_string(bundle, "validation_report") + report = load_json(bundle_dir / validation_report) + if report.get("schema_version") != 2: + raise BundleImportError("Bundle validation report must use schema v2.") + + +def load_country_bundles(*, bundle_dir: Path, bundle: dict) -> dict[str, dict]: + country_paths = required_dict(bundle, "countries") + country_bundles: dict[str, dict] = {} + for country_id, relative_path in sorted(country_paths.items()): + if not isinstance(country_id, str) or not isinstance(relative_path, str): + raise BundleImportError("Bundle countries must map ids to paths.") + country_bundle = load_json(bundle_dir / relative_path) + if country_bundle.get("schema_version") != 2: + raise BundleImportError( + f"Country bundle {country_id} must use schema_version=2." + ) + if country_bundle.get("country_id") != country_id: + raise BundleImportError( + f"Country bundle path for {country_id} contains country_id " + f"{country_bundle.get('country_id')!r}." + ) + if country_bundle.get("bundle_version") != bundle.get("bundle_version"): + raise BundleImportError( + f"Country bundle {country_id} bundle_version does not match " + "bundle.json." 
def main(argv: Optional[list[str]] = None) -> int:
    """Run the bundle importer command-line interface.

    Args:
        argv: Arguments handed to ``ArgumentParser.parse_args``; ``None``
            makes argparse fall back to the process arguments.

    Returns:
        ``0`` after a successful import, ``1`` when the import raised
        ``BundleImportError`` (its message is printed to stderr).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Import a schema-v2 policyengine-bundles archive into policyengine.py. "
            "The importer verifies the bundle digest, vendors the exploded bundle, "
            "writes .py release manifests, updates country extras, and regenerates "
            "TRACE TRO sidecars."
        )
    )
    parser.add_argument(
        "--archive",
        required=True,
        type=Path,
        help="Path to policyengine-bundle-.tar.gz.",
    )
    # Path-valued options that only differ in flag name and default.
    path_options = (
        ("--bundle-dir", DEFAULT_BUNDLE_DIR),
        ("--release-manifest-dir", DEFAULT_RELEASE_MANIFEST_DIR),
        ("--pyproject", DEFAULT_PYPROJECT),
        ("--changelog-dir", DEFAULT_CHANGELOG_DIR),
    )
    for flag, default in path_options:
        parser.add_argument(flag, type=Path, default=default)
    # Opt-out switches, registered in the same order as before so the
    # generated help text is unchanged.
    switch_options = (
        ("--no-pyproject", "Do not update pyproject optional dependency pins."),
        ("--no-tro", "Do not regenerate TRACE TRO sidecar files."),
        ("--no-changelog", "Do not write a towncrier changelog fragment."),
    )
    for flag, help_text in switch_options:
        parser.add_argument(flag, action="store_true", help=help_text)
    options = parser.parse_args(argv)

    changelog_dir = None if options.no_changelog else options.changelog_dir
    try:
        result = import_policyengine_bundle(
            options.archive,
            bundle_dir=options.bundle_dir,
            manifest_dir=options.release_manifest_dir,
            pyproject_path=options.pyproject,
            update_pyproject=not options.no_pyproject,
            regenerate_tros=not options.no_tro,
            changelog_dir=changelog_dir,
        )
    except BundleImportError as error:
        print(f"error: {error}", file=sys.stderr)
        return 1

    # Build the summary first, then emit it line by line.
    report = [
        f"imported bundle: {result.bundle_version}",
        f"countries: {', '.join(result.countries)}",
    ]
    if result.bundle_dir is not None:
        report.append(f"vendored bundle: {result.bundle_dir}")
    report.extend(
        f"release manifest: {manifest_path}"
        for manifest_path in result.release_manifest_paths
    )
    if result.pyproject_path is not None:
        report.append(f"updated pyproject: {result.pyproject_path}")
    report.extend(f"trace tro: {tro_path}" for tro_path in result.trace_tro_paths)
    if result.changelog_path is not None:
        report.append(f"changelog: {result.changelog_path}")
    for line in report:
        print(line)
    return 0
# constants.py: default locations, all derived from the repository root.
# The root is the fifth ancestor directory of this file.
REPO_ROOT = Path(__file__).resolve().parents[4]
_PACKAGE_DATA_DIR = REPO_ROOT / "src" / "policyengine" / "data"
DEFAULT_BUNDLE_DIR = _PACKAGE_DATA_DIR / "bundle"
DEFAULT_RELEASE_MANIFEST_DIR = _PACKAGE_DATA_DIR / "release_manifests"
DEFAULT_PYPROJECT = REPO_ROOT / "pyproject.toml"
DEFAULT_CHANGELOG_DIR = REPO_ROOT / "changelog.d"


# country_manifest.py
def write_country_release_manifests(
    *,
    country_bundles: dict[str, dict],
    manifest_dir: Path,
) -> list[Path]:
    """Write one ``<country_id>.json`` release manifest per country bundle.

    Args:
        country_bundles: Mapping of country id to parsed country bundle.
        manifest_dir: Output directory; created if it does not exist.

    Returns:
        Paths of the written manifests, in sorted country-id order.

    Raises:
        BundleImportError: If a generated manifest fails
            ``CountryReleaseManifest`` validation.
    """
    manifest_dir.mkdir(parents=True, exist_ok=True)
    outputs = []
    for country_id in sorted(country_bundles):
        manifest = country_release_manifest(country_bundles[country_id])
        # Validate before writing so an invalid manifest never lands on disk.
        try:
            CountryReleaseManifest.model_validate(manifest)
        except ValidationError as exc:
            raise BundleImportError(
                f"Generated release manifest for {country_id} is invalid: {exc}"
            ) from exc
        target = manifest_dir / f"{country_id}.json"
        write_json(target, manifest)
        outputs.append(target)
    return outputs
"bundle_version") + model_package = required_dict(country_bundle, "model_package") + data_package = required_dict(country_bundle, "data_package") + compatibility = required_dict(country_bundle, "compatibility") + compatibility_metadata = compatibility.get("metadata") + if not isinstance(compatibility_metadata, dict): + compatibility_metadata = {} + datasets = required_dict(country_bundle, "datasets") + default_dataset = required_string(country_bundle, "default_dataset") + default_artifact = required_dict(datasets, default_dataset) + + data_package_payload = data_package_version(data_package) + certified_artifact = certified_data_artifact( + default_dataset=default_dataset, + default_artifact=default_artifact, + data_package=data_package_payload, + compatibility_metadata=compatibility_metadata, + ) + certification = data_certification( + model_package=model_package, + compatibility=compatibility, + compatibility_metadata=compatibility_metadata, + ) + + return { + "schema_version": 1, + "bundle_id": f"{country_id}-{bundle_version}", + "country_id": country_id, + "policyengine_version": bundle_version, + "model_package": package_version(model_package), + "data_package": data_package_payload, + "default_dataset": default_dataset, + "datasets": dataset_path_references(datasets), + "region_datasets": region_dataset_templates( + country_bundle.get("region_datasets", {}) + ), + "certified_data_artifact": certified_artifact, + "certification": certification, + } + + +def data_package_version(data_package: dict) -> dict: + payload = { + "name": required_string(data_package, "name"), + "version": required_string(data_package, "version"), + "repo_id": required_string(data_package, "repo_id"), + "repo_type": data_package.get("repo_type", "model"), + "release_manifest_path": data_package.get( + "release_manifest_path", + "release_manifest.json", + ), + } + release_manifest_revision = data_package.get("release_manifest_revision") + if isinstance(release_manifest_revision, str) 
def certified_data_artifact(
    *,
    default_dataset: str,
    default_artifact: dict,
    data_package: dict,
    compatibility_metadata: dict,
) -> dict:
    """Describe the default dataset artifact for the release manifest.

    The payload names the data package, the dataset and its URI. ``sha256``
    is added when the artifact has a truthy one, and ``build_id`` when the
    compatibility metadata carries a non-empty ``data_build_id`` string.
    """
    package_reference = {
        "name": data_package["name"],
        "version": data_package["version"],
    }
    payload = {
        "data_package": package_reference,
        "dataset": default_dataset,
        "uri": artifact_uri(default_artifact),
    }
    if default_artifact.get("sha256"):
        payload["sha256"] = default_artifact["sha256"]
    build_id = compatibility_metadata.get("data_build_id")
    if isinstance(build_id, str) and build_id:
        payload["build_id"] = build_id
    return payload


def data_certification(
    *,
    model_package: dict,
    compatibility: dict,
    compatibility_metadata: dict,
) -> dict:
    """Build the ``certification`` section of the release manifest.

    ``compatibility_basis`` and ``certified_by`` default to
    ``"bundle_candidate"`` and ``"policyengine-bundles"``. Build provenance
    fields are copied from the metadata only when they are non-empty strings.

    Raises:
        BundleImportError: If the model package has no ``version`` string.
    """
    payload = {
        "compatibility_basis": compatibility.get("basis", "bundle_candidate"),
        "certified_for_model_version": required_string(
            model_package,
            "version",
        ),
        "certified_by": compatibility.get("asserted_by", "policyengine-bundles"),
    }
    # Output and metadata keys are identical, so one tuple covers both.
    provenance_keys = (
        "data_build_id",
        "built_with_model_version",
        "built_with_model_git_sha",
        "data_build_fingerprint",
    )
    for key in provenance_keys:
        value = compatibility_metadata.get(key)
        if isinstance(value, str) and value:
            payload[key] = value
    return payload


def package_version(package: dict) -> dict:
    """Return ``name`` and ``version`` plus any truthy ``sha256``/``wheel_url``.

    Raises:
        BundleImportError: If ``name`` or ``version`` is missing or empty.
    """
    payload = {
        "name": required_string(package, "name"),
        "version": required_string(package, "version"),
    }
    for optional_key in ("sha256", "wheel_url"):
        if package.get(optional_key):
            payload[optional_key] = package[optional_key]
    return payload
def dataset_path_references(datasets: dict) -> dict:
    """Convert country bundle datasets into release-manifest path references.

    For each dataset the explicit ``path``/``revision`` win; otherwise the
    values parsed from an ``hf://`` ``uri`` are used. ``sha256`` and
    ``metadata_sha256`` are copied when truthy.

    Raises:
        BundleImportError: If an entry is not a name-to-object pair or no
            path can be determined for a dataset.
    """
    references = {}
    for name, artifact in sorted(datasets.items()):
        if not (isinstance(name, str) and isinstance(artifact, dict)):
            raise BundleImportError(
                "Country bundle datasets must map names to objects."
            )
        parsed = parse_hf_reference_if_present(artifact.get("uri"))
        fallback_path = parsed.path if parsed else None
        fallback_revision = parsed.revision if parsed else None

        path = artifact.get("path") or fallback_path
        if not isinstance(path, str) or not path:
            raise BundleImportError(
                f"Dataset {name} does not include a path and its uri cannot "
                "be translated into a path reference."
            )
        entry = {"path": path}
        revision = artifact.get("revision") or fallback_revision
        if isinstance(revision, str) and revision:
            entry["revision"] = revision
        for digest_key in ("sha256", "metadata_sha256"):
            if artifact.get(digest_key):
                entry[digest_key] = artifact[digest_key]
        references[name] = entry
    return references


def region_dataset_templates(region_datasets: dict) -> dict:
    """Extract ``path_template`` entries for each region, in sorted order.

    Regions whose template has no non-empty ``path_template`` string are
    left out of the result.

    Raises:
        BundleImportError: If ``region_datasets`` is not an object or an
            entry is not a name-to-object pair.
    """
    if not isinstance(region_datasets, dict):
        raise BundleImportError("Country bundle region_datasets must be an object.")
    result = {}
    for region, template in sorted(region_datasets.items()):
        if not (isinstance(region, str) and isinstance(template, dict)):
            raise BundleImportError(
                "Country bundle region_datasets must map names to objects."
            )
        path_template = template.get("path_template")
        if not isinstance(path_template, str) or not path_template:
            continue
        result[region] = {"path_template": path_template}
    return result
# country_manifest.py
def artifact_uri(artifact: dict) -> str:
    """Return the URI recorded for an artifact in the release manifest.

    When ``repo_id``, ``path`` and ``revision`` are all available as
    non-empty strings (explicitly or parsed from an ``hf://`` ``uri``), the
    result is ``hf://<repo_id>/<path>@<revision>``. Otherwise the artifact's
    own ``uri`` is returned unchanged.

    Raises:
        BundleImportError: If neither form can be produced.
    """
    parsed = parse_hf_reference_if_present(artifact.get("uri"))
    # Explicit fields win; parsed URI components are only a fallback.
    components = [
        artifact.get(field) or (getattr(parsed, field) if parsed else None)
        for field in ("repo_id", "path", "revision")
    ]
    if all(isinstance(component, str) and component for component in components):
        repo_id, path, revision = components
        return f"hf://{repo_id}/{path}@{revision}"

    raw_uri = artifact.get("uri")
    if not isinstance(raw_uri, str) or not raw_uri:
        raise BundleImportError("Artifact does not include a resolvable uri.")
    return raw_uri


# digest.py
def verify_bundle_digest(bundle_dir: Path, bundle: dict) -> None:
    """Check that ``bundle_digest`` in ``bundle.json`` matches the directory.

    Args:
        bundle_dir: Extracted bundle directory to hash.
        bundle: Parsed ``bundle.json`` payload.

    Raises:
        BundleImportError: If ``bundle_digest`` is missing, is not a
            ``sha256:`` string, or differs from the computed digest.
    """
    expected = bundle.get("bundle_digest")
    has_digest = isinstance(expected, str) and expected.startswith("sha256:")
    if not has_digest:
        raise BundleImportError("bundle.json does not include bundle_digest.")
    actual = "sha256:" + bundle_directory_digest(bundle_dir)
    if actual == expected:
        return
    raise BundleImportError(
        "bundle.json bundle_digest does not match bundle contents: "
        f"expected {expected}, got {actual}."
    )
def bundle_directory_digest(bundle_dir: Path) -> str:
    """Return the hex SHA-256 digest of a bundle directory's contents.

    Each file contributes its POSIX-style relative path, a NUL byte, its
    normalized content encoded as UTF-8, and another NUL byte, in the order
    produced by ``bundle_files``.
    """
    hasher = hashlib.sha256()
    for relative_path in bundle_files(bundle_dir):
        content = normalized_file_content(bundle_dir, relative_path)
        hasher.update(relative_path.as_posix().encode("utf-8"))
        hasher.update(b"\0")
        hasher.update(content.encode("utf-8"))
        hasher.update(b"\0")
    return hasher.hexdigest()


def bundle_files(bundle_dir: Path) -> list[Path]:
    """List all files under ``bundle_dir`` as sorted relative paths.

    Files named ``.DS_Store`` are ignored.
    """
    return sorted(
        path.relative_to(bundle_dir)
        for path in bundle_dir.rglob("*")
        if path.is_file() and path.name != ".DS_Store"
    )


def normalized_file_content(bundle_dir: Path, relative_path: Path) -> str:
    """Return the text of a bundle file as it enters the digest.

    JSON files are re-serialised with sorted keys and two-space indentation.
    Before that, volatile fields are dropped: ``created_at`` and
    ``bundle_digest`` from ``bundle.json``; ``generated_at`` plus per-check
    ``command``, ``started_at``, ``ended_at`` and the ``details`` keys
    ``validated_on_platform`` and ``bundle_dir`` from
    ``validation-report.json``. Any other file is returned as decoded text.

    Raises:
        BundleImportError: If a file cannot be read, parsed or decoded.
    """
    path = bundle_dir / relative_path
    if relative_path.suffix == ".json":
        payload = load_json(path)
        if relative_path.as_posix() == "bundle.json":
            payload.pop("created_at", None)
            payload.pop("bundle_digest", None)
        elif relative_path.as_posix() == "validation-report.json":
            payload.pop("generated_at", None)
            checks = []
            for check in payload.get("checks", []):
                if not isinstance(check, dict):
                    checks.append(check)
                    continue
                # Copy before popping so the loaded payload is not mutated
                # while it is being iterated.
                check_payload = dict(check)
                check_payload.pop("command", None)
                check_payload.pop("started_at", None)
                check_payload.pop("ended_at", None)
                details = check_payload.get("details")
                if isinstance(details, dict):
                    details_payload = dict(details)
                    details_payload.pop("validated_on_platform", None)
                    details_payload.pop("bundle_dir", None)
                    check_payload["details"] = details_payload
                checks.append(check_payload)
            payload["checks"] = checks
        return json.dumps(payload, indent=2, sort_keys=True) + "\n"
    # Decode explicitly as UTF-8: the locale default differs between
    # platforms and would make the digest machine-dependent. Failures are
    # wrapped so they surface as BundleImportError, as load_json's do.
    try:
        return path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        raise BundleImportError(f"Could not read {path}: {exc}") from exc


# Backward-compatible private name used by importer tests.
_bundle_directory_digest = bundle_directory_digest
def parse_hf_reference_if_present(value: Any) -> Optional[HuggingFaceReference]:
    """Parse ``value`` as an ``hf://`` reference, or return ``None``.

    ``None`` is returned both for non-string values and for strings that
    ``parse_hf_reference`` rejects.
    """
    if isinstance(value, str):
        try:
            return parse_hf_reference(value)
        except BundleImportError:
            pass
    return None


def parse_hf_reference(uri: str) -> HuggingFaceReference:
    """Parse an ``hf://`` URI into repo type, repo id, revision and path.

    Raises:
        BundleImportError: If the scheme is not ``hf`` or the rest of the
            URI cannot be split into its parts.
    """
    parsed = urllib.parse.urlparse(uri)
    if parsed.scheme != "hf":
        raise BundleImportError(f"Expected hf:// URI, got {uri!r}.")
    repo_type, reference = hf_repo_type_and_reference(parsed)
    repo_id, revision, path = parse_hf_reference_parts(reference)
    return HuggingFaceReference(
        repo_type=repo_type,
        repo_id=repo_id,
        revision=revision,
        path=path,
    )


def hf_repo_type_and_reference(
    parsed: urllib.parse.ParseResult,
) -> tuple[str, str]:
    """Split a parsed ``hf://`` URI into its repo type and the remainder.

    If the network-location part is ``model``, ``dataset`` or ``space``, it
    is the repo type. Otherwise it is treated as the first segment of the
    reference and the repo type defaults to ``"model"``.
    """
    location = parsed.netloc
    if location not in {"model", "dataset", "space"}:
        return "model", location + parsed.path
    return location, parsed.path.lstrip("/")


def parse_hf_reference_parts(rest: str) -> tuple[str, str, str]:
    """Split an HF reference into ``(repo_id, revision, path)``.

    ``repo@revision/path`` is tried first. If that does not yield three
    non-empty parts, the legacy ``org/repo/path@revision`` form is used.

    Raises:
        BundleImportError: If there is no ``@`` revision marker or neither
            form yields a complete reference.
    """
    if "@" not in rest:
        raise BundleImportError(
            "HF URIs must include an immutable revision, for example "
            "hf://model/org/repo@version/path."
        )

    # Preferred form: everything before the first "@" is the repo id.
    repo_id, _, tail = rest.partition("@")
    revision, slash, path = tail.partition("/")
    if slash and repo_id and revision and path:
        return repo_id, revision, path

    # Legacy form: the revision follows the last "@", and the repo id is
    # the first two path segments.
    repo_and_path, _, revision = rest.rpartition("@")
    segments = repo_and_path.split("/")
    if len(segments) < 3:
        raise BundleImportError(
            "Legacy HF URIs must use hf://org/repo/path@revision form."
        )
    repo_id = "/".join(segments[:2])
    path = "/".join(segments[2:])
    if not (repo_id and revision and path):
        raise BundleImportError(f"Incomplete HF URI reference: {rest!r}.")
    return repo_id, revision, path
def load_json(path: Path) -> dict:
    """Load a JSON object from ``path``.

    The file is decoded as UTF-8 rather than the platform's locale encoding,
    so non-ASCII text reads the same on every machine.

    Raises:
        BundleImportError: If the file cannot be read, decoded or parsed, or
            if the top-level JSON value is not an object.
    """
    try:
        with path.open(encoding="utf-8") as file:
            payload = json.load(file)
    except (OSError, ValueError) as exc:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        raise BundleImportError(f"Could not load JSON from {path}: {exc}") from exc
    if not isinstance(payload, dict):
        raise BundleImportError(f"Expected JSON object in {path}.")
    return payload


def write_json(path: Path, payload: dict) -> None:
    """Write ``payload`` to ``path`` as indented JSON with a trailing newline.

    Parent directories are created as needed. Key order is preserved, and
    the file is written as UTF-8 to match ``load_json``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(
        json.dumps(payload, indent=2, sort_keys=False) + "\n",
        encoding="utf-8",
    )


def required_dict(payload: dict, key: str) -> dict:
    """Return ``payload[key]`` when it is a dict.

    Raises:
        BundleImportError: If the key is missing or its value is not a dict.
    """
    value = payload.get(key)
    if not isinstance(value, dict):
        raise BundleImportError(f"Expected object at {key}.")
    return value


def required_string(payload: dict, key: str) -> str:
    """Return ``payload[key]`` when it is a non-empty string.

    Raises:
        BundleImportError: If the key is missing, not a string, or empty.
    """
    value = payload.get(key)
    if not isinstance(value, str) or not value:
        raise BundleImportError(f"Expected non-empty string at {key}.")
    return value
"policyengine-us", +} + + +def update_optional_dependency_pins( + *, + pyproject_path: Path, + country_bundles: dict[str, dict], +) -> None: + text = pyproject_path.read_text() + core_versions = set() + for country_id, country_bundle in sorted(country_bundles.items()): + if country_id not in COUNTRY_OPTIONAL_DEPENDENCIES: + raise BundleImportError( + f"Cannot update pyproject pins for unknown country {country_id!r}." + ) + model_package = required_dict(country_bundle, "model_package") + core_package = required_dict(country_bundle, "core_package") + expected_package = COUNTRY_OPTIONAL_DEPENDENCIES[country_id] + package_name = required_string(model_package, "name") + if package_name != expected_package: + raise BundleImportError( + f"Country {country_id} expected model package {expected_package}, " + f"got {package_name}." + ) + package_version = required_string(model_package, "version") + core_version = required_string(core_package, "version") + core_versions.add(core_version) + + text = replace_dependency_in_section( + text, + section_name=country_id, + package_name="policyengine_core", + requirement=f"policyengine_core>={core_version}", + ) + text = replace_dependency_in_section( + text, + section_name=country_id, + package_name=package_name, + requirement=f"{package_name}=={package_version}", + ) + text = replace_dependency_in_section( + text, + section_name="dev", + package_name=package_name, + requirement=f"{package_name}=={package_version}", + ) + + if len(core_versions) != 1: + raise BundleImportError( + "Imported countries must use one policyengine-core version so the " + "dev extra can be updated unambiguously." 
# pyproject.py
def replace_dependency_in_section(
    text: str,
    *,
    section_name: str,
    package_name: str,
    requirement: str,
) -> str:
    """Replace one dependency line inside a multi-line pyproject array.

    The section is the first array whose key starts a line as
    ``<section_name> = [``, and it ends at the next line that starts with
    ``]``. Every line in it that pins ``package_name`` is replaced by the
    quoted ``requirement`` followed by a comma, keeping that line's existing
    indentation.

    Args:
        text: Full pyproject text.
        section_name: Key of the array to edit, for example ``"us"``.
        package_name: Package whose pin is replaced.
        requirement: New requirement string, without quotes.

    Returns:
        The updated pyproject text.

    Raises:
        BundleImportError: If the section is missing or malformed, or holds
            no pin for ``package_name``.
    """
    # Anchor the key to the start of a line. A plain substring search for
    # "us = [" also matches inside keys such as "bogus = [" and would edit
    # the wrong section.
    header = re.search(
        rf"^[ \t]*{re.escape(section_name)}[ \t]*=[ \t]*\[",
        text,
        flags=re.MULTILINE,
    )
    if header is None:
        raise BundleImportError(
            f"pyproject optional dependency missing: {section_name}"
        )
    section_start = header.start()
    content_start = text.find("\n", section_start)
    content_end = text.find("\n]", content_start)
    if content_start == -1 or content_end == -1:
        raise BundleImportError(f"Malformed pyproject section: {section_name}")

    lines = text[content_start + 1 : content_end].splitlines()
    updated_lines = []
    replaced = False
    for line in lines:
        stripped = line.strip()
        if dependency_line_matches(stripped, package_name):
            # Reuse the line's own leading whitespace so the file keeps its
            # formatting whatever indent width it uses.
            indent = line[: len(line) - len(line.lstrip())]
            updated_lines.append(f'{indent}"{requirement}",')
            replaced = True
        else:
            updated_lines.append(line)
    if not replaced:
        raise BundleImportError(
            f"pyproject optional dependency {section_name} is missing {package_name}."
        )
    replacement = "\n".join(updated_lines)
    return f"{text[: content_start + 1]}{replacement}{text[content_end:]}"


def dependency_line_matches(line: str, package_name: str) -> bool:
    """Tell whether ``line`` is a quoted requirement pinning ``package_name``.

    The line must start with ``"<package_name>``, optionally followed by
    whitespace, and then a version operator. A longer package name that
    merely starts with ``package_name`` therefore does not match.
    """
    return (
        re.match(rf'"{re.escape(package_name)}\s*(==|>=|<=|~=|!=|>|<)', line)
        is not None
    )


# types.py
class BundleImportError(RuntimeError):
    """Raised when a PolicyEngine bundle cannot be imported into policyengine.py."""


@dataclass(frozen=True)
class BundleImportResult:
    """Summary of one bundle import, as reported by the importer CLI."""

    bundle_version: str
    countries: list[str]
    # Where the bundle was vendored; None when not reported.
    bundle_dir: Optional[Path]
    release_manifest_paths: list[Path]
    # Path of the updated pyproject; None when not reported.
    pyproject_path: Optional[Path]
    trace_tro_paths: list[Path]
    # Path of the changelog fragment; None when not reported.
    changelog_path: Optional[Path]


@dataclass(frozen=True)
class HuggingFaceReference:
    """Components of a parsed ``hf://`` reference."""

    repo_type: str
    repo_id: str
    revision: str
    path: str


# Callable that takes a string and a path and returns a path.
TroRegenerator = Callable[[str, Path], Path]
-SPEC.loader.exec_module(import_policyengine_bundle) +from policyengine.provenance import bundle_import -def test_import_policyengine_bundle_verifies_and_vendors_release( +def test_import_policyengine_bundle_imports_schema_v2_archive( tmp_path: Path, ) -> None: - dist_dir = _write_release_assets(tmp_path, version="4.14.0") + archive_path = _write_bundle_archive(tmp_path, version="4.15.0") bundle_dir = tmp_path / "vendored-bundle" - release_manifest_dir = tmp_path / "release_manifests" + manifest_dir = tmp_path / "release_manifests" pyproject_path = _write_pyproject(tmp_path) changelog_dir = tmp_path / "changelog.d" - result = import_policyengine_bundle.import_policyengine_bundle( - version="4.14.0", - dist_dir=dist_dir, - base_url="unused", + def fake_tro_regenerator(country: str, output_dir: Path) -> Path: + path = output_dir / f"{country}.trace.tro.jsonld" + path.write_text(f"{country}\n") + return path + + result = bundle_import.import_policyengine_bundle( + archive_path, bundle_dir=bundle_dir, - release_manifest_dir=release_manifest_dir, + manifest_dir=manifest_dir, pyproject_path=pyproject_path, changelog_dir=changelog_dir, + tro_regenerator=fake_tro_regenerator, ) + assert result.bundle_version == "4.15.0" + assert result.countries == ["uk", "us"] assert (bundle_dir / "bundle.json").exists() assert result.bundle_dir == bundle_dir assert {path.name for path in result.release_manifest_paths} == { "uk.json", "us.json", } + assert {path.name for path in result.trace_tro_paths} == { + "uk.trace.tro.jsonld", + "us.trace.tro.jsonld", + } - us_manifest = json.loads((release_manifest_dir / "us.json").read_text()) - assert us_manifest["bundle_id"] == "us-4.14.0" - assert us_manifest["policyengine_version"] == "4.14.0" + us_manifest = json.loads((manifest_dir / "us.json").read_text()) + assert us_manifest["schema_version"] == 1 + assert us_manifest["bundle_id"] == "us-4.15.0" + assert us_manifest["policyengine_version"] == "4.15.0" assert 
us_manifest["model_package"]["version"] == "1.722.4" + assert us_manifest["datasets"]["enhanced_cps_2024"] == { + "path": "enhanced_cps_2024.h5", + "revision": "data-sha", + "sha256": "b" * 64, + "metadata_sha256": "c" * 64, + } assert ( us_manifest["certified_data_artifact"]["uri"] == "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@data-sha" ) + assert ( + us_manifest["certified_data_artifact"]["build_id"] + == "policyengine-us-data-1.0.0" + ) + assert us_manifest["certification"]["compatibility_basis"] == "bundle_candidate" + assert us_manifest["certification"]["certified_by"] == "policyengine-bundles" assert us_manifest["certification"]["certified_for_model_version"] == "1.722.4" + assert us_manifest["certification"]["data_build_fingerprint"] == ( + "sha256:" + "d" * 64 + ) pyproject = pyproject_path.read_text() - assert '"policyengine_core==3.26.1"' in pyproject + assert '"policyengine_core>=3.26.1"' in pyproject assert '"policyengine-us==1.722.4"' in pyproject assert '"policyengine-uk==3.0.0"' in pyproject - assert (changelog_dir / "policyengine-bundle-4.14.0.changed.md").exists() + assert (changelog_dir / "policyengine-bundle-4.15.0.changed.md").exists() + + +def test_import_policyengine_bundle_rejects_schema_v1_archive( + tmp_path: Path, +) -> None: + archive_path = _write_bundle_archive( + tmp_path, + version="4.15.0", + schema_version=1, + ) + + with pytest.raises(bundle_import.BundleImportError, match="schema v2"): + bundle_import.import_policyengine_bundle( + archive_path, + bundle_dir=tmp_path / "vendored-bundle", + manifest_dir=tmp_path / "release_manifests", + pyproject_path=_write_pyproject(tmp_path), + regenerate_tros=False, + ) + + +def test_import_policyengine_bundle_rejects_digest_mismatch( + tmp_path: Path, +) -> None: + archive_path = _write_bundle_archive( + tmp_path, + version="4.15.0", + bundle_digest="sha256:" + "0" * 64, + ) + + with pytest.raises(bundle_import.BundleImportError, match="bundle_digest"): + 
bundle_import.import_policyengine_bundle( + archive_path, + bundle_dir=tmp_path / "vendored-bundle", + manifest_dir=tmp_path / "release_manifests", + pyproject_path=_write_pyproject(tmp_path), + regenerate_tros=False, + ) + + +def test_import_policyengine_bundle_rejects_missing_default_dataset( + tmp_path: Path, +) -> None: + archive_path = _write_bundle_archive( + tmp_path, + version="4.15.0", + missing_default_dataset=True, + ) + + with pytest.raises(bundle_import.BundleImportError, match="missing_dataset"): + bundle_import.import_policyengine_bundle( + archive_path, + bundle_dir=tmp_path / "vendored-bundle", + manifest_dir=tmp_path / "release_manifests", + pyproject_path=_write_pyproject(tmp_path), + regenerate_tros=False, + ) -def test_import_policyengine_bundle_rejects_checksum_mismatch( +def test_import_policyengine_bundle_wraps_invalid_generated_manifest( tmp_path: Path, ) -> None: - dist_dir = _write_release_assets(tmp_path, version="4.14.0") - (dist_dir / "policyengine-bundle-4.14.0.tar.gz.sha256").write_text( - f"{'0' * 64} policyengine-bundle-4.14.0.tar.gz\n" + archive_path = _write_bundle_archive( + tmp_path, + version="4.15.0", + invalid_data_package_repo_type=True, ) - with pytest.raises( - import_policyengine_bundle.BundleImportError, - match="archive_sha256 does not match checksum file", - ): - import_policyengine_bundle.import_policyengine_bundle( - version="4.14.0", - dist_dir=dist_dir, - base_url="unused", + with pytest.raises(bundle_import.BundleImportError, match="us is invalid"): + bundle_import.import_policyengine_bundle( + archive_path, bundle_dir=tmp_path / "vendored-bundle", - release_manifest_dir=tmp_path / "release_manifests", + manifest_dir=tmp_path / "release_manifests", pyproject_path=_write_pyproject(tmp_path), - changelog_dir=None, + regenerate_tros=False, ) -def _write_release_assets(tmp_path: Path, *, version: str) -> Path: +def test_import_policyengine_bundle_can_skip_pyproject_and_tros( + tmp_path: Path, +) -> None: + 
archive_path = _write_bundle_archive(tmp_path, version="4.15.0") + pyproject_path = _write_pyproject(tmp_path) + original_pyproject = pyproject_path.read_text() + + result = bundle_import.import_policyengine_bundle( + archive_path, + bundle_dir=None, + manifest_dir=tmp_path / "release_manifests", + pyproject_path=pyproject_path, + update_pyproject=False, + regenerate_tros=False, + ) + + assert result.bundle_dir is None + assert result.pyproject_path is None + assert result.trace_tro_paths == [] + assert pyproject_path.read_text() == original_pyproject + + +def test_import_policyengine_bundle_updates_only_countries_in_archive( + tmp_path: Path, +) -> None: + archive_path = _write_bundle_archive( + tmp_path, + version="4.15.0", + countries=("us",), + ) + manifest_dir = tmp_path / "release_manifests" + pyproject_path = _write_pyproject(tmp_path) + + result = bundle_import.import_policyengine_bundle( + archive_path, + bundle_dir=None, + manifest_dir=manifest_dir, + pyproject_path=pyproject_path, + regenerate_tros=False, + ) + + assert result.countries == ["us"] + assert {path.name for path in result.release_manifest_paths} == {"us.json"} + assert (manifest_dir / "us.json").exists() + assert not (manifest_dir / "uk.json").exists() + pyproject = pyproject_path.read_text() + assert '"policyengine-us==1.722.4"' in pyproject + assert '"policyengine-uk==2.88.20"' in pyproject + + +def test_import_policyengine_bundle_cli_smoke(tmp_path: Path, capsys) -> None: + archive_path = _write_bundle_archive(tmp_path, version="4.15.0") + exit_code = bundle_import.main( + [ + "--archive", + str(archive_path), + "--bundle-dir", + str(tmp_path / "vendored-bundle"), + "--release-manifest-dir", + str(tmp_path / "release_manifests"), + "--pyproject", + str(_write_pyproject(tmp_path)), + "--changelog-dir", + str(tmp_path / "changelog.d"), + "--no-tro", + ] + ) + + assert exit_code == 0 + assert "imported bundle: 4.15.0" in capsys.readouterr().out + + +def _write_bundle_archive( + tmp_path: 
Path, + *, + version: str, + schema_version: int = 2, + bundle_digest: Optional[str] = None, + missing_default_dataset: bool = False, + invalid_data_package_repo_type: bool = False, + countries: tuple[str, ...] = ("us", "uk"), +) -> Path: bundle_root = tmp_path / f"policyengine-bundle-{version}" - _write_json(bundle_root / "countries" / "us.json", _country_bundle("us", version)) - _write_json(bundle_root / "countries" / "uk.json", _country_bundle("uk", version)) + for country in countries: + _write_json( + bundle_root / "countries" / f"{country}.json", + _country_bundle( + country, + version, + missing_default_dataset=(missing_default_dataset and country == "us"), + invalid_data_package_repo_type=( + invalid_data_package_repo_type and country == "us" + ), + ), + ) _write_json(bundle_root / "validation-report.json", _validation_report(version)) - bundle = _bundle_manifest(version) + bundle = _bundle_manifest( + version, + schema_version=schema_version, + countries=countries, + ) _write_json(bundle_root / "bundle.json", bundle) - bundle["bundle_digest"] = ( - f"sha256:{import_policyengine_bundle._bundle_directory_digest(bundle_root)}" + bundle["bundle_digest"] = bundle_digest or ( + f"sha256:{bundle_import._bundle_directory_digest(bundle_root)}" ) _write_json(bundle_root / "bundle.json", bundle) - dist_dir = tmp_path / "dist" - dist_dir.mkdir() - archive_path = dist_dir / f"policyengine-bundle-{version}.tar.gz" + archive_path = tmp_path / f"policyengine-bundle-{version}.tar.gz" with tarfile.open(archive_path, "w:gz") as archive: archive.add(bundle_root, arcname=bundle_root.name) - archive_sha256 = _sha256_file(archive_path) - (dist_dir / f"{archive_path.name}.sha256").write_text( - f"{archive_sha256} {archive_path.name}\n" - ) - _write_json( - dist_dir / f"policyengine-bundle-{version}.json", - { - "bundle_version": version, - "bundle_digest": bundle["bundle_digest"], - "archive": archive_path.name, - "archive_sha256": archive_sha256, - }, - ) - return dist_dir + 
return archive_path -def _bundle_manifest(version: str) -> dict: +def _bundle_manifest( + version: str, + *, + schema_version: int, + countries: tuple[str, ...], +) -> dict: return { - "schema_version": 1, + "schema_version": schema_version, "bundle_version": version, "created_at": "2026-06-03T00:00:00Z", - "policyengine": { - "name": "policyengine", - "version": version, - "resolution_status": "pinned", - }, + "policyengine": _package_pin("policyengine", version), "packages": { - "policyengine": { - "name": "policyengine", - "version": version, - "resolution_status": "pinned", - }, - "policyengine-core": { - "name": "policyengine-core", - "version": "3.26.1", - "resolution_status": "pinned", - }, - "policyengine-us": { - "name": "policyengine-us", - "version": "1.722.4", - "resolution_status": "pinned", - }, - "policyengine-uk": { - "name": "policyengine-uk", - "version": "3.0.0", - "resolution_status": "pinned", - }, - }, - "profiles": { - "us": {"packages": ["policyengine-us"], "countries": ["us"]}, - "uk": {"packages": ["policyengine-uk"], "countries": ["uk"]}, - "all": { - "packages": ["policyengine-us", "policyengine-uk"], - "countries": ["us", "uk"], - }, - }, - "countries": { - "us": "countries/us.json", - "uk": "countries/uk.json", + "policyengine": _package_pin("policyengine", version), + "policyengine-core": _package_pin("policyengine-core", "3.26.1"), + "policyengine-us": _package_pin("policyengine-us", "1.722.4"), + "policyengine-uk": _package_pin("policyengine-uk", "3.0.0"), }, + "countries": {country: f"countries/{country}.json" for country in countries}, "validation_report": "validation-report.json", } -def _country_bundle(country_id: str, version: str) -> dict: +def _country_bundle( + country_id: str, + version: str, + *, + missing_default_dataset: bool = False, + invalid_data_package_repo_type: bool = False, +) -> dict: model_package = "policyengine-us" if country_id == "us" else "policyengine-uk" model_version = "1.722.4" if country_id == "us" 
else "3.0.0" data_package = ( @@ -178,27 +309,18 @@ def _country_bundle(country_id: str, version: str) -> dict: ) dataset = "enhanced_cps_2024" if country_id == "us" else "enhanced_frs_2023_24" path = f"{dataset}.h5" + default_dataset = "missing_dataset" if missing_default_dataset else dataset return { - "schema_version": 1, + "schema_version": 2, "bundle_version": version, "country_id": country_id, - "model_package": { - "name": model_package, - "version": model_version, - "resolution_status": "pinned", - "sha256": "a" * 64, - "wheel_url": f"https://example.test/{model_package}.whl", - }, - "core_package": { - "name": "policyengine-core", - "version": "3.26.1", - "resolution_status": "pinned", - }, + "model_package": _package_pin(model_package, model_version), + "core_package": _package_pin("policyengine-core", "3.26.1"), "data_package": { "name": data_package, "version": "1.0.0", "repo_id": repo_id, - "repo_type": "model", + "repo_type": 123 if invalid_data_package_repo_type else "model", "release_manifest_path": "release_manifest.json", "release_manifest_revision": "data-sha", }, @@ -207,62 +329,79 @@ def _country_bundle(country_id: str, version: str) -> dict: "version": "data-sha", "repo_type": "model", "release_manifest_uri": f"hf://model/{repo_id}@data-sha/release_manifest.json", + "release_manifest_sha256": "a" * 64, }, - "default_dataset": dataset, + "default_dataset": default_dataset, "datasets": { dataset: { "kind": "microdata", + "uri": f"hf://model/{repo_id}@data-sha/{path}", "path": path, "repo_id": repo_id, "revision": "data-sha", "sha256": "b" * 64, + "metadata_sha256": "c" * 64, "status": "certified", } }, - "region_datasets": {"national": {"path_template": path}}, - "certification": { - "compatibility_basis": "manual_runtime_certification", - "built_with_model_package": { - "name": model_package, - "version": model_version, - "resolution_status": "pinned", - }, - "built_with_core_package": { - "name": "policyengine-core", - "version": "3.26.1", - 
"resolution_status": "pinned", - }, - "certified_for_model_package": { - "name": model_package, - "version": model_version, - "resolution_status": "pinned", - }, - "certified_for_core_package": { - "name": "policyengine-core", - "version": "3.26.1", - "resolution_status": "pinned", + "region_datasets": { + "national": { + "path_template": path, + "uri_template": f"hf://model/{repo_id}@data-sha/{path}", + } + }, + "compatibility": { + "basis": "bundle_candidate", + "model_package": _package_pin(model_package, model_version), + "core_package": _package_pin("policyengine-core", "3.26.1"), + "data_package": {"name": data_package, "version": "1.0.0"}, + "release_manifest_uri": f"hf://model/{repo_id}@data-sha/release_manifest.json", + "release_manifest_sha256": "a" * 64, + "asserted_by": "policyengine-bundles", + "metadata": { + "candidate_model_package": model_package, + "candidate_data_release_manifest_uri": ( + f"hf://model/{repo_id}@data-sha/release_manifest.json" + ), + "data_build_id": f"{data_package}-1.0.0", + "built_with_model_version": model_version, + "built_with_model_git_sha": "git-sha", + "data_build_fingerprint": "sha256:" + "d" * 64, }, - "certified_by": "test", - "data_build_id": f"{data_package}-1.0.0", - "data_build_fingerprint": "sha256:fingerprint", }, } +def _package_pin(name: str, version: str) -> dict: + return { + "name": name, + "version": version, + "resolution_status": "pinned", + "sha256": "a" * 64, + "wheel_url": f"https://example.test/{name}-{version}.whl", + } + + def _validation_report(version: str) -> dict: return { - "schema_version": 1, + "schema_version": 2, "bundle_version": version, "generated_at": "2026-06-03T00:00:00Z", "status": "passed", "checks": [ { - "name": "runtime", + "name": "registry_validation", "status": "passed", - "details": {"validated_on_platform": "test"}, + "command": "test", + "started_at": "2026-06-03T00:00:00Z", + "ended_at": "2026-06-03T00:00:01Z", + "details": { + "validated_on_platform": "test", + 
"bundle_dir": "/tmp/bundle", + }, } ], - "metadata": {"validation_scope": "full"}, + "metadata": {"validation_kind": "registry"}, } @@ -291,11 +430,3 @@ def _write_pyproject(tmp_path: Path) -> Path: def _write_json(path: Path, payload: dict) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.write_text(json.dumps(payload, indent=2, sort_keys=False) + "\n") - - -def _sha256_file(path: Path) -> str: - hasher = hashlib.sha256() - with path.open("rb") as file: - for chunk in iter(lambda: file.read(1024 * 1024), b""): - hasher.update(chunk) - return hasher.hexdigest()