diff --git a/.claude/settings.json b/.claude/settings.json index 8a49fc40..fd7af7a2 100644 --- a/.claude/settings.json +++ b/.claude/settings.json @@ -1,4 +1,13 @@ { + "extraKnownMarketplaces": { + "dev-team-agents": { + "source": { + "source": "github", + "repo": "jodavis/agent-plugins", + "ref": "feature/ADR-246-cloud-dev-team" + } + } + }, "enabledPlugins": { "dotnet@dotnet-agent-skills": false, "dotnet-diag@dotnet-agent-skills": true, @@ -10,13 +19,6 @@ "microsoft-learn": { "type": "sse", "url": "https://learn.microsoft.com/api/mcp" - }, - "github": { - "type": "http", - "url": "https://api.githubcopilot.com/mcp", - "headers": { - "Authorization": "Bearer $GITHUB_PAT" - } } } } diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index 4a760032..00000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "permissions": { - "allow": [ - "mcp__jira__getAccessibleAtlassianResources", - "mcp__jira__searchJiraIssuesUsingJql", - "mcp__jira__getJiraIssue", - "Bash(xargs:*)", - "mcp__jira__createJiraIssue", - "mcp__jira__editJiraIssue", - "mcp__jira__getJiraProjectIssueTypesMetadata", - "Monitor" - ] - } -} diff --git a/.dvc/.gitignore b/.dvc/.gitignore new file mode 100644 index 00000000..528f30c7 --- /dev/null +++ b/.dvc/.gitignore @@ -0,0 +1,3 @@ +/config.local +/tmp +/cache diff --git a/.dvc/config b/.dvc/config new file mode 100644 index 00000000..90ef99af --- /dev/null +++ b/.dvc/config @@ -0,0 +1,4 @@ +[core] + remote = adr-ml-training-data +[remote "adr-ml-training-data"] + url = s3://adr-ml-training-data/dvc diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 4c2eb5ea..1efa25b3 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -60,3 +60,25 @@ jobs: with: files: | TestResults/**/*.trx + + python-tests: + runs-on: ubuntu-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install pytest + run: pip install pytest + + - name: Run Python unit tests + working-directory: ml + run: python -m pytest --verbosity=1 + diff --git a/.gitignore b/.gitignore index 5f555a0a..00473d2c 100644 --- a/.gitignore +++ b/.gitignore @@ -373,4 +373,4 @@ FodyWeavers.xsd dotnet/ # Local configuration files -*.local.* \ No newline at end of file +*.local.* diff --git a/ml/_doc_ml.md b/ml/_doc_ml.md deleted file mode 100644 index 325c2640..00000000 --- a/ml/_doc_ml.md +++ /dev/null @@ -1,10 +0,0 @@ -Speech Recognition Models - -========================= - -## Folders -/ml/scripts -> Python scripts for implementing speech recognition training and evaluation. -/ml/data -> Contains datasets used for training and evaluating speech recognition models. -/ml/notebooks -> Jupyter notebooks for experimenting with different speech recognition techniques. -/ml/models -> Pre-trained speech recognition models and scripts for training new models. - diff --git a/ml/_spec_OopPipeline.md b/ml/_spec_OopPipeline.md new file mode 100644 index 00000000..7255ab7a --- /dev/null +++ b/ml/_spec_OopPipeline.md @@ -0,0 +1,1090 @@ +# OOP ML Pipeline + +> **Status:** Implementation-ready +> **Will become:** `_doc_MLPipeline.md` once implementation is complete + +## Overview + +Refactors the ML pipeline from numbered procedural scripts to a proper object-oriented design. Every pipeline stage becomes a Python class with injectable dependencies, enabling full unit testability with mocked collaborators — the same discipline used for the C# application code. A unified `Manifest[S]` / `Sample` abstraction replaces the existing multi-format CSV files and carries typed sample objects — with their applied values, seeds, and content hashes — through the entire pipeline, from intent phrase generation to model evaluation. The core innovation is a seed-based randomisation algorithm that stabilises sample generation across experiment changes: widening a noise range regenerates only the samples whose applied values would change; all others reuse their existing files. DVC continues to orchestrate stage execution; thin CLI entry points bridge DVC to the OOP classes. + +## Responsibilities & Boundaries + +- **Owns:** All pipeline stage logic (`ml/pipeline/`); `Manifest`/`Sample` data model and JSON serialisation; seed-based randomisation algorithm; `PassFilter` implementations; per-stage injectable protocol abstractions; DVC entry-point scripts; `dvc.yaml`; `params.yaml` +- **Does not own:** Model architecture decisions; TensorFlow/Keras internals; the edge_tts service; the CMU Pronouncing Dictionary source data +- **Integrates with:** DVC (`params.yaml` for experiment parameters, `dvc.yaml` for stage wiring); edge_tts for TTS generation; TensorFlow/Keras for model training and evaluation; librosa for spectrogram computation; CMU Pronouncing Dictionary for phoneme lookup + +## Key Design Decisions + +### All stages get OOP classes + +_Context:_ Procedural scripts with module-level side effects, hardcoded paths, and `if __name__ == "__main__"` guards are difficult to unit test. The Jira requires the same testability standard as the C# application code. + +_Decision:_ Every pipeline stage is a class with dependencies injected via its constructor. This is idiomatic Python: constructor injection is standard, `Protocol` classes are the Python-native way to define injectable interfaces (structural typing — no inheritance required), and `dataclass` is the natural container type. `ABC` is used only where there is real shared *implementation* to inherit — specifically `ModifierStage`, which encapsulates the skip-unchanged and GC logic. Stages with no shared implementation are standalone classes; no artificial common base is added just for uniformity. + +_Consequences:_ Every class is independently testable with mocked collaborators. DVC entry points become thin wrappers. + +--- + +### Unified end-to-end Manifest format + +_Context:_ The existing pipeline uses two distinct CSV formats — a variation CSV for augmentation stages and set-manifest CSVs for training/evaluation — making sample lineage opaque and parameter tracking fragile. + +_Decision:_ A single `Manifest[S]` class (generic over a `Sample` subtype) replaces both formats. The manifest serialises to JSON and carries `TextSample`, `AudioSample`, `SampleSpectrogram`, or `SampleTokens` objects with their applied values, seeds, and content hashes. Train/val/test splits are three separate files of the same format. All pipeline stages — including featurisation, training, and evaluation — consume the manifest type directly. + +_Consequences:_ All consumers share one JSON schema. The schema must be stable across pipeline runs to preserve stored seeds. Existing DVC stage I/O paths change. + +--- + +### ModifierStage for all per-sample file transformations + +_Context:_ The skip-unchanged and GC logic is valuable for any stage that transforms files one-by-one — not just data augmentation. Spectrogram computation has historically been expensive; recomputing all spectrograms when only a few samples changed wastes significant time. + +_Decision:_ `ModifierStage[T_in, T_out]` is used for every stage that transforms an input manifest into an output manifest on a per-sample basis: augmentation stages (`TtsSampleGenerator`, `DelayAugmentor`, `BackgroundNoiseAugmentor`, `MicrophoneNoiseAugmentor`) and featurisation stages (`SpectrogramStage`, `TokenStage`). Stages with no randomisation (featurisation) set `_is_deterministic = True`; `transform()` uses `output_seed = 0` for those stages. Intent phrase generation is handled by `PhraseVariator` (called from the entry-point) and is not a `ModifierStage` subclass — see the "ml/pipeline/ package" design decision below. + +_Consequences:_ `SampleSpectrogram` and `SampleTokens` carry a `seed` field set to 0. This is a minor inconsistency accepted in exchange for full code reuse of the skip-unchanged, GC, and manifest-management logic. + +--- + +### Content hash determines sample identity + +_Context:_ The goal is to skip regeneration for samples whose upstream source did not change, while regenerating those that did. DVC will skip a stage entirely if no deps change, but will re-run a stage if even one dep file changes — it cannot skip individual samples within a stage. + +_Decision:_ All non-text samples use a **unified content hash formula**: + +``` +content_hash = sha256(parent_content_hash + ":" + str(seed) + ":" + canonical(applied_values)) +``` + +Where `canonical(applied_values) = json.dumps(applied_values, sort_keys=True, separators=(',',':'), ensure_ascii=True)` and `str(seed)` is the decimal representation of the integer seed. All numeric values in `applied_values` are stored as raw int/float, never as formatted strings, to ensure hash stability across code changes. + +For `TextSample`: `content_hash = sha256(content.encode('utf-8'))` (no parent). + +For deterministic stages (`SpectrogramStage`, `TokenStage`): `seed = 0` and `applied_values = {}`, so: +``` +content_hash = sha256(parent_content_hash + ":0:{}") +``` + +Every `SampleWithPath` output type stores `parent_content_hash: str` (the content hash of the sample it was derived from). This field is the lookup key for skip-unchanged detection across all stage types, including chained `AudioSample → AudioSample` stages. + +All `ModifierStage` output dirs are configured with `persist: true` in `dvc.yaml` so DVC does not delete them between runs; `ModifierStage` handles GC of unreferenced files itself. GC algorithm: after building the output manifest, collect `{sample.path.name for sample in output_samples}`; glob `output_dir` flat (non-recursive); delete any file not in that set and not named `manifest.json`. DVC tracks directory outputs by hashing all files and automatically picks up GC changes on the next `dvc repro` — no manual `dvc commit` needed. + +**Terminology:** +- **Variation constraints** — inputs from `params.yaml` controlling randomisation shape (min/max, frequency, distribution). Not stored on the sample. +- **Applied values** — the specific values `VariationGenerator` selected, stored in `AudioSample.applied_values` and included in `content_hash`. + +_Consequences:_ A sample is regenerated only when its source content or applied values change. Applied values are stable across runs for the same source and seed. Changing constraints regenerates only samples whose rejection-sampling chain selects a different value. + +--- + +### Seed-based randomisation with pass filters + +_Context:_ Experiments frequently adjust variation constraints. Without a stable algorithm, all samples regenerate on every constraint change. + +_Decision:_ Each new output sample gets a seed via `int.from_bytes(os.urandom(8), 'big')`, stored in the manifest. If an input sample was seen before (matched by `parent_content_hash`), the stored seed is reused. Per-variable sub-seeds are derived as `int.from_bytes(sha256(f"{seed}:{variable_name}").digest()[:8], 'big')`. The frequency check uses a `:vary` sub-key: `sha256(f"{seed}:{variable_name}:vary")`. Rejection sampling draws candidates uniformly from `pass_filter.sample_domain()` and accepts them with probability `pass_filter.density(candidate)`. This makes each variable's value independent of ordering — reordering or adding variables in `_get_applied_values` does not affect existing variables' values. + +_Consequences:_ The randomisation logic is non-trivial and must be fully deterministic and independently reproducible. + +--- + +### Previous output manifest as the seed store + +_Context:_ Seeds must persist across DVC reruns. DVC invalidates outputs when params change, but `persist: true` preserves output files. + +_Decision:_ Every `ModifierStage` writes its output manifest to a path supplied by the entry-point. On each run, `transform` reads the previous manifest (if present) via `ManifestStore` to recover stored seeds. Seed recovery uses a **three-case algorithm**: +- **Skip**: previous output found for this input (`parent_content_hash` matches) AND recomputing the content hash with the stored seed and new constraints gives the same result → keep output unchanged. +- **Regenerate with stored seed**: previous output found BUT constraints changed, giving different applied values → re-run `_generate_output` with the same id and seed; new content_hash reflects new applied values. +- **New sample**: no previous output found → assign new id and seed via `os.urandom`. + +_Consequences:_ Deleting the output manifest resets all seeds. DVC stages using `ModifierStage` must have `persist: true`. + +--- + +### Python generics for typed stage input/output + +_Context:_ `ModifierStage[T_in, T_out]` and `Manifest[S]` have meaningful type parameters. + +_Decision:_ Use `typing.Generic[T_in, T_out]` with `TypeVar`. Python generics are **erased at runtime**; type parameters are for mypy only. `from __future__ import annotations` required for forward references. + +--- + +### Directories as stage I/O boundaries + +_Context:_ The existing pipeline duplicates paths in both `cmd:` and `deps:` of `dvc.yaml`. + +_Decision:_ Every stage exchanges whole directories. Entry-point scripts resolve file paths via `conventions.py`. `ManifestStore` and `ModifierStage` accept explicit `Path` arguments. DVC `deps` lists input directories; `outs` lists the output directory. + +_Consequences:_ `dvc.yaml` entries are short. File convention changes require updating only `conventions.py`. + +--- + +### ml/pipeline/ package with thin DVC entry points + +_Context:_ The existing `ml/scripts/` tree will be deleted before implementation begins. + +_Decision:_ All OOP code lives in `ml/pipeline/`. Each DVC stage is a minimal entry-point script in `ml/pipeline/stages/` that parses CLI args, resolves paths via `conventions.py`, constructs the stage with injected dependencies, and calls it. Stage filenames carry a two-digit prefix for sort order. + +The initial `Manifest[TextSample]` is bootstrapped in `intent_01_generate_phrases.py`. The entry-point constructs `PhraseVariator(random.Random(42))` and calls it to generate surface-form variants from the input CSV (reading the `phrase` and `command` columns). Each valid variant becomes a `TextSample` with `content = surface_form` and `label = command` (speech_to_detect). The entry-point applies `subsample_rate` filter and writes `Manifest[TextSample]`. + +The CMU Pronouncing Dictionary download is retained as a plain non-OOP DVC stage. The existing `ml/scripts/` directory and old `dvc.yaml` are deleted as the first implementation step; the new `dvc.yaml` is written from scratch. + +**Required params.yaml keys:** + +| Key | Default | Description | +|-----|---------|-------------| +| `pipeline.input_phrases_path` | `scripts/intent_prediction/01_input_phrases.csv` | Relative to DVC root (`ml/`); overridden in CI | +| `pipeline.subsample_rate` | `1` | 1 in N phrase variants; 1 = all | +| `pipeline.variations_per_phrase` | (impl) | Variants attempted per base phrase in PhraseVariator | +| `pipeline.n_mels` | (impl) | Spectrogram mel bands | +| `pipeline.time_steps` | (impl) | Spectrogram time dimension | +| `pipeline.input_token_length` | (impl) | Padded token length | +| `pipeline.epochs` | (impl) | Overridden in CI | +| `pipeline.batch_size` | (impl) | Training batch size | + +Per-stage variation constraints are under `stages.:` sections. Required keys: + +**`stages.create_set_manifests`** — split percentages (must sum to 100): +- `train_pct`, `val_pct`, `test_pct` (int) + +**`stages.add_delays`** — delay augmentation: +- `prefix_vary_probability`, `suffix_vary_probability` (float, 0.0–1.0) +- `prefix_min_s`, `prefix_max_s`, `suffix_min_s`, `suffix_max_s` (float, seconds) + +**`stages.add_background_noise`** — background noise augmentation: +- `vary_probability` (float) +- `volume_min`, `volume_max` (float, multiplier applied to noise signal) + +**`stages.add_mic_noise`** — microphone noise augmentation: +- `vary_probability` (float) +- `amplitude_min`, `amplitude_max` (float) + +**`stages.generate_speech_samples`** — TTS speech rate: +- `speech_rate_min`, `speech_rate_max` (int, percent; e.g. -10 to +20) + +--- + +### Live TTS in CI with a small fixed phrase set + +_Context:_ The Jira requires an E2E CI test that runs the full pipeline on a small sample. + +_Decision:_ CI uses `dvc repro --set-param` flags to override expensive parameters: + +```bash +dvc repro \ + --set-param pipeline.input_phrases_path=test/fixtures/ci_phrases.csv \ + --set-param pipeline.epochs=1 \ + --set-param pipeline.subsample_rate=100 +``` + +Edge_tts is called live. The `TtsProvider` protocol is the replacement seam if TTS strategy changes. + +## Planned Implementation + +### Directory Layout + +``` +ml/ + pipeline/ + __init__.py + core/ + sample.py # Sample, SampleWithPath, TextSample, AudioSample, + # SampleSpectrogram, SampleTokens + manifest.py # Manifest[S], ManifestStore + modifier_stage.py # ModifierStage[T_in, T_out] + randomization.py # VariationGenerator, PassFilter, MinMaxFilter, NormalFilter + io/ + audio_io.py # AudioReader, AudioWriter protocols + defaults + intent/ + phrase_variator.py # PhraseVariator (rng injectable; ports existing logic + sanity_check) + vocab_computer.py # VocabComputer + speech/ + tts_stage.py # TtsSampleGenerator(ModifierStage[TextSample, AudioSample]) + delay_stage.py # DelayAugmentor(ModifierStage[AudioSample, AudioSample]) + background_noise_stage.py # BackgroundNoiseAugmentor(ModifierStage[AudioSample, AudioSample]) + mic_noise_stage.py # MicrophoneNoiseAugmentor(ModifierStage[AudioSample, AudioSample]) + set_splitter.py # SetManifestSplitter + token_stage.py # TokenStage(ModifierStage[AudioSample, SampleTokens]) + spectrogram_stage.py # SpectrogramStage(ModifierStage[AudioSample, SampleSpectrogram]) + model_trainer.py # ModelTrainer + model_evaluator.py # ModelEvaluator + stages/ + conventions.py + intent_01_generate_phrases.py # PhraseVariator; bootstraps Manifest[TextSample] + intent_02_compute_vocab.py + speech_03_generate_samples.py + speech_04_add_delays.py + speech_05_add_background_noise.py + speech_06_add_mic_noise.py + speech_07_compute_tokens.py + speech_08_compute_spectrograms.py + speech_09_create_set_manifests.py + speech_10_train_model.py + speech_11_evaluate_model.py # val set: writes metrics.json + speech_12_package_test_samples.py # test set: writes test_samples.zip for app E2E tests + test/ + pipeline/ + core/ + intent/ + speech/ + e2e_pipeline_test.py # pytest; invokes dvc repro via subprocess + fixtures/ + ci_phrases.csv # 10 canonical phrases + download_phoneme_dictionary.py + dvc.yaml # written from scratch + params.yaml +``` + +### Interfaces + +#### conventions.py + +```python +def manifest_path(output_dir: Path) -> Path: + return output_dir / "manifest.json" + +def split_manifest_path(output_dir: Path, split: str) -> Path: + """split: 'train', 'val', or 'test'""" + return output_dir / f"{split}.json" + +def sample_file_path(output_dir: Path, sample_id: str, ext: str) -> Path: + """ext has no leading dot, e.g. 'wav', 'npy', 'json'""" + return output_dir / f"{sample_id}.{ext}" + +def model_path(output_dir: Path) -> Path: + return output_dir / "speech_to_text_model.keras" + +def evaluation_predictions_path(output_dir: Path) -> Path: + return output_dir / "evaluation_predictions.txt" + +def evaluation_metrics_path(output_dir: Path) -> Path: + """JSON file written by ModelEvaluator: {"wer": }""" + return output_dir / "metrics.json" + +def test_samples_path(output_dir: Path) -> Path: + """Zip written by ModelEvaluator.package_test_samples(): known-good audio fixtures.""" + return output_dir / "test_samples.zip" +``` + +--- + +#### Sample and Manifest + +```python +@dataclass +class Sample(ABC): + id: str # human-readable stable identifier derived from input values + applied values + # TextSample: uuid4() — has no file; not user-visible + # SampleWithPath types: derived by each stage's _derive_id(); used as filename stem + # e.g. "TV_ON_Jenny_r77_pre40_suf0" after TTS + delay stages + seed: int # stable as long as parent_content_hash is unchanged; 0 for deterministic + content_hash: str # sha256(parent_content_hash + ":" + str(seed) + ":" + canonical(applied)) + # exception: TextSample uses sha256(content.encode('utf-8')) + +@dataclass +class SampleWithPath(Sample, ABC): + """All ModifierStage T_out types. GC uses sample.path.name.""" + path: Path # relative filename derived from id (e.g. 'TV_ON_Jenny_r77.wav') + parent_content_hash: str # content_hash of the sample this was derived from; + # used as the skip-unchanged lookup key in transform() + +@dataclass +class TextSample(Sample): + content: str # phrase to speak (surface form variation) + label: str # speech_to_detect — what the model outputs + # seed = 0; no parent_content_hash (bootstrapped, not from a ModifierStage) + +@dataclass +class AudioSample(SampleWithPath): + transcript: str # speech_to_detect (= TextSample.label) + applied_values: dict[str, Any] # raw int/float values; see stage specs below + +@dataclass +class SampleSpectrogram(SampleWithPath): + transcript: str + parent_id: str # id of the AudioSample; used by ModelTrainer/Evaluator for lookup + +@dataclass +class SampleTokens(SampleWithPath): + transcript: str + parent_id: str # id of the AudioSample; used by ModelTrainer/Evaluator for lookup + +class Manifest(Generic[S]): + def __init__(self, samples: Sequence[S]): ... + @property + def samples(self) -> tuple[S, ...]: ... + def by_content_hash(self, h: str) -> S | None: ... + def by_id(self, id: str) -> S | None: ... + +class ManifestStore: + def read(self, path: Path) -> Manifest[Any]: + """Deserialise using 'sample_type' field: + 'text' → TextSample, 'audio' → AudioSample, + 'spectrogram' → SampleSpectrogram, 'tokens' → SampleTokens. + Path fields are relative filenames; callers prepend output_dir.""" + ... + def write(self, manifest: Manifest, path: Path) -> None: ... +``` + +JSON schema (version 1): +```json +{ + "version": 1, + "sample_type": "audio", + "samples": [ + { + "id": "uuid", + "path": "uuid.wav", + "transcript": "TV_ON", + "seed": 67890, + "content_hash": "sha256hex", + "parent_content_hash": "sha256hex", + "applied_values": { "voice": "en-US-JennyNeural", "speech_rate": 5 } + } + ] +} +``` + +`"sample_type"` declared once per manifest. Valid values: `"text"`, `"audio"`, `"spectrogram"`, `"tokens"`. `"path"` is a filename with no directory component. Numeric values in `applied_values` stored as raw int/float. + +`TextSample` JSON (no `path`, `parent_content_hash`, or `applied_values`; `seed` is always 0 and is serialised): +```json +{ + "version": 1, + "sample_type": "text", + "samples": [ + { "id": "uuid", "seed": 0, "content_hash": "sha256hex", + "content": "okay turn on the tv", "label": "TV_ON" } + ] +} +``` + +`SampleSpectrogram` / `SampleTokens` JSON (include `parent_id`; no `applied_values` since always `{}`): +```json +{ + "version": 1, + "sample_type": "spectrogram", + "samples": [ + { "id": "uuid", "path": "uuid.npy", "seed": 0, "content_hash": "sha256hex", + "parent_content_hash": "sha256hex", "transcript": "TV_ON", "parent_id": "audio-uuid" } + ] +} +``` + +--- + +#### PhraseVariator + +```python +class PhraseVariator: + def __init__(self, rng: random.Random): ... + def generate( + self, + base_phrases: Sequence[tuple[str, str]], # (phrase, command) from 01_input_phrases.csv + variations_per_phrase: int, + ) -> list[TextSample]: + """Ports _create_variation() and sanity_check() from the existing VariationGenerator + class in 01_generate_phrases.py (pleasantries, hesitations, case transforms, spelling + variants, repeats). NOT the full generate_variations() / incremental-deduplication loop + — that logic is replaced by a simple loop: for each base phrase, attempt to generate + `variations_per_phrase` valid variants; each attempt calls _create_variation() once + and passes it through sanity_check(). Does NOT port `target_samples` or + `load_existing_variations()` — no incremental logic. + This is NOT the new VariationGenerator in randomization.py (the seed-based numeric + randomiser). + "Port" means: every `random.*` module-level call in the original is replaced by + `self.rng.*` — same logic, no restructuring. Output must be identical for a fixed seed. + TextSample.content = surface form with all transformations (what TTS will speak); + TextSample.label = command (speech_to_detect, the canonical form the model should output). + Each output: id=uuid4(), seed=0 (PhraseVariator explicitly sets seed=0 at construction), + content_hash=sha256(content.encode('utf-8')). + Entry-point uses rng=random.Random(42) and reads variations_per_phrase from params.yaml.""" + ... +``` + +`intent_01_generate_phrases.py`: +1. Reads `pipeline.input_phrases_path` CSV, columns `phrase` (surface_form) and `command` (speech_to_detect) +2. Constructs `PhraseVariator(random.Random(42))` +3. Calls `generate(base_phrases)` → `list[TextSample]` +4. Applies: `[s for i, s in enumerate(variants) if i % subsample_rate == 0]` +5. Writes `Manifest[TextSample]` + +--- + +#### PassFilter and VariationGenerator + +```python +class PassFilter(ABC): + @abstractmethod + def density(self, value: float) -> float: + """Normalised density; max == 1.0. Acceptance probability in rejection sampling.""" + ... + + @abstractmethod + def sample_domain(self) -> tuple[float, float]: + """(low, high): range for uniform candidate generation. + MinMaxFilter: (min_val, max_val). + NormalFilter: (mean - 5*std_dev, mean + 5*std_dev).""" + ... + +class MinMaxFilter(PassFilter): + """Uniform over [min_val, max_val]. density() == 1.0 in range, 0.0 outside.""" + def __init__(self, min_val: float, max_val: float): ... + +class NormalFilter(PassFilter): + """Gaussian. density(x) = gaussian_pdf(x)/gaussian_pdf(mean); peak == 1.0. + Raises ValueError if std_dev <= 0.""" + def __init__(self, mean: float, std_dev: float): ... + +class VariationGenerator: + def __init__(self, sample_seed: int): ... + + def should_vary(self, variable_name: str, frequency: float) -> bool: + """True with probability frequency. + int.from_bytes(sha256(f"{seed}:{variable_name}:vary").digest()[:8], 'big') / 2^64 < frequency""" + ... + + def generate(self, variable_name: str, pass_filter: PassFilter) -> float: + """Rejection-sample deterministically. For attempt n = 0, 1, ...: + domain_low, domain_high = pass_filter.sample_domain() + raw = int.from_bytes(sha256(f"{seed}:{variable_name}:{n}").digest()[:8], 'big') + candidate = domain_low + (raw / 2^64) * (domain_high - domain_low) + accept_raw = int.from_bytes(sha256(f"{seed}:{variable_name}:{n}:accept").digest()[:8], 'big') + if (accept_raw / 2^64) < pass_filter.density(candidate): return candidate + Raises ValueError after 1000 iterations.""" + ... + + def generate_int(self, variable_name: str, pass_filter: MinMaxFilter) -> int: + """Integer in [int(min_val), int(max_val)] inclusive, uniform distribution. + Uses attempt-indexed hashes exactly like generate() — for attempt n = 0, 1, ...: + raw_int = int.from_bytes(sha256(f"{seed}:{variable_name}:{n}").digest()[:8], 'big') + Then bitmask rejection: range = int(max_val) - int(min_val); + mask = 2^ceil(log2(range+1))-1. + When range == 0: mask = 0; returns int(min_val) immediately (n=0, no loop). + candidate = int(min_val) + (raw_int & mask); accepted if candidate <= int(max_val). + Raises ValueError after 1000 iterations (same guard as generate()).""" + ... + + def choose(self, variable_name: str, options: Sequence[T]) -> T: + """Direct selection, no rejection loop: + idx = int.from_bytes(sha256(f"{seed}:{variable_name}:0").digest()[:8], 'big') % len(options) + return options[idx]""" + ... +``` + +--- + +#### ModifierStage + +```python +T_out = TypeVar('T_out', bound=SampleWithPath) + +class ModifierStage(ABC, Generic[T_in, T_out]): + _is_deterministic: ClassVar[bool] = False + # SpectrogramStage and TokenStage set this to True. + # transform() uses output_seed = 0 when True; os.urandom(8) otherwise. + + def __init__(self, output_dir: Path, manifest_store: ManifestStore): ... + + async def transform( + self, + input_manifest: Manifest[T_in], + manifest_path: Path, + ) -> Manifest[T_out]: + """ + Entry-point scripts call: asyncio.run(stage.transform(manifest, path)) + + Steps: + 1. Read previous output manifest from manifest_path (if present). + Build index: prev_by_parent = {out.parent_content_hash: out for out in prev.samples} + + 2. For each input_sample in input_manifest.samples: + + a. prev_out = prev_by_parent.get(input_sample.content_hash) + + b. If prev_out is not None: + - Compute new_applied = _get_applied_values(input_sample, + VariationGenerator(prev_out.seed)) + - Compute expected_hash = sha256(input_sample.content_hash + ":" + + str(prev_out.seed) + ":" + canonical(new_applied)) + - If expected_hash == prev_out.content_hash: + → KEEP prev_out unchanged (step is skipped; file already exists) + - Else (constraints changed → different applied_values): + → new_id = _derive_id(input_sample, new_applied) + → await _generate_output(input_sample, + output_id=new_id, output_seed=prev_out.seed, + applied_values=new_applied, + parent_content_hash=input_sample.content_hash) + (seed preserved; id, content_hash, and file updated; old file GC'd in step 3) + + c. If prev_out is None (new sample): + - output_seed = 0 if self._is_deterministic else int.from_bytes(os.urandom(8), 'big') + - generator = VariationGenerator(output_seed) + - new_applied = _get_applied_values(input_sample, generator) + - output_id = _derive_id(input_sample, new_applied) + - await _generate_output(input_sample, output_id, output_seed, + new_applied, input_sample.content_hash) + + 3. GC: collect {sample.path.name for sample in output_samples}; + flat-glob output_dir; delete files not in that set and not 'manifest.json'. + + 4. Write output manifest to manifest_path. + """ + ... + + @abstractmethod + def _get_applied_values( + self, sample: T_in, generator: VariationGenerator + ) -> dict[str, Any]: + """Return applied values dict. Return {} for deterministic stages.""" + ... + + @abstractmethod + async def _generate_output( + self, + input_sample: T_in, + output_id: str, + output_seed: int, + applied_values: dict[str, Any], + parent_content_hash: str, + ) -> T_out: + """Generate output file; return complete output Sample with output_id, output_seed, + parent_content_hash, and content_hash all set. + MUST compute content_hash via _compute_content_hash — do not reimplement the formula.""" + ... + + @abstractmethod + def _derive_id(self, input_sample: T_in, applied_values: dict[str, Any]) -> str: + """Return the id (= filename stem) for the output sample. + Called for both new samples AND regens with changed constraints — the id must be + deterministically derivable from input_sample.id and applied_values alone. + Each stage composes: f"{input_sample.id}_{stage_suffix(applied_values)}". + SpectrogramStage and TokenStage return input_sample.id unchanged (same stem, + different extension). + Uniqueness within the output directory is guaranteed because input_sample.id is + already unique and all stages only write to their own output directory.""" + ... + + @staticmethod + def _compute_content_hash( + parent_content_hash: str, output_seed: int, applied_values: dict[str, Any] + ) -> str: + """Compute the canonical content_hash for a non-text output sample. + All _generate_output implementations MUST call this — it is the single source of + truth for the hash formula; reimplementing it risks silent skip-detection breakage. + content_hash = sha256(parent_content_hash + ":" + str(output_seed) + ":" + + json.dumps(applied_values, sort_keys=True, + separators=(',',':'), ensure_ascii=True))""" + ... +``` + +--- + +#### Stage Constructor Signatures and Applied Values + +```python +class TtsSampleGenerator(ModifierStage[TextSample, AudioSample]): + def __init__( + self, + output_dir: Path, + manifest_store: ManifestStore, + tts_provider: TtsProvider, # retries are the TtsProvider's responsibility + voices: list[str], # en-US female ShortNames; fetched by entry-point via + # asyncio.run(edge_tts.list_voices()), filtered: + # Gender=='Female', Locale=='en-US', + # ':' not in ShortName, 'DragonHD'/'Turbo' not in ShortName + ): ... + # applied_values: {"voice": str, "speech_rate": int} + # voice: VariationGenerator.choose("voice", voices) + # speech_rate: int (raw, e.g. 5 for +5%), from generate_int("speech_rate", MinMaxFilter(...)) + # edge_tts rate string (e.g. "+5%") is formatted in _generate_output, NOT stored + # + # Synthesis: tts_provider.synthesize(text=input_sample.content, ...) + # TextSample.content is the full spoken form (surface form with pleasantries/hesitations) + # AudioSample.transcript is set to input_sample.label (speech_to_detect, no hesitations) + # i.e. TTS speaks "um, turn on the TV" but the model is trained to output "TV_ON" + # + # _derive_id: f"{input_sample.label}_{voice_short}_r{speech_rate + 100}" + # voice_short = voice.split('-')[-1].replace('Neural', '') # "en-US-JennyNeural" → "Jenny" + # e.g. "TV_ON_Jenny_r77" for label="TV_ON", voice=JennyNeural, rate=-23% + +class DelayAugmentor(ModifierStage[AudioSample, AudioSample]): + # applied_values: {"prefix_delay_s": float, "suffix_delay_s": float} + # Each drawn via generate("prefix_delay_s", MinMaxFilter(min_s, max_s)) if should_vary(...) else 0.0 + # 0.0 stored when not applied — both keys always present in applied_values for hash stability + # + # _derive_id: f"{input_sample.id}_pre{int(prefix_delay_s*1000)}_suf{int(suffix_delay_s*1000)}" + # e.g. "TV_ON_Jenny_r77_pre40_suf0" for 40ms prefix, no suffix + # params.yaml keys (new float semantics; replaces old integer 1-in-N frequency keys): + # stages.add_delays.prefix_vary_probability: float # e.g. 0.333 = 1-in-3 chance + # stages.add_delays.prefix_min_s: float + # stages.add_delays.prefix_max_s: float + # stages.add_delays.suffix_vary_probability: float + # stages.add_delays.suffix_min_s: float + # stages.add_delays.suffix_max_s: float + +class BackgroundNoiseAugmentor(ModifierStage[AudioSample, AudioSample]): + def __init__( + self, + output_dir: Path, + manifest_store: ManifestStore, + noise_provider: NoiseProvider, + audio_reader: AudioReader, + audio_writer: AudioWriter, + ): ... + # applied_values: {"noise_file": str, "noise_start_s": float, "noise_volume": float} + # noise_file: filename only (no path), from: + # VariationGenerator.choose("noise_file", sorted([p.name for p in provider.list_files()])) + # sorted() ensures OS-independent determinism + # choose() is ALWAYS called (regardless of should_vary), so the filename is always stored + # noise_start_s and noise_volume are 0.0 if should_vary returns False + # → all three keys always present in applied_values for hash stability + # + # noise_start_s bounds: derived at runtime from file durations. + # max_start_s = noise_file_duration_s - audio_sample_duration_s + # Filter: MinMaxFilter(0.0, max(0.0, max_start_s)) + # Both durations read via AudioReader; max(0.0, ...) handles edge case where noise is + # shorter than the audio sample (start forced to 0.0). + # + # _derive_id: f"{input_sample.id}_{noise_filestem}_v{int(noise_volume*100)}" + # noise_filestem = Path(noise_file).stem (noise_file is always stored per choose() contract) + # e.g. "TV_ON_Jenny_r77_pre40_suf0_BabyCry_v0" when volume=0 (noise not applied) + # "TV_ON_Jenny_r77_pre40_suf0_CafeFar_v45" when volume=0.45 + +class MicrophoneNoiseAugmentor(ModifierStage[AudioSample, AudioSample]): + # applied_values: {"mic_noise_amplitude": float} + # amplitude: 0.0 if should_vary returns False; drawn from MinMaxFilter otherwise + # + # _derive_id: f"{input_sample.id}_mic{int(mic_noise_amplitude*1000)}" + # e.g. "TV_ON_Jenny_r77_pre40_suf0_CafeFar_v45_mic0" (not applied) + # "TV_ON_Jenny_r77_pre40_suf0_CafeFar_v45_mic12" (amplitude=0.012) + +class SpectrogramStage(ModifierStage[AudioSample, SampleSpectrogram]): + _is_deterministic = True + def __init__( + self, + output_dir: Path, + manifest_store: ManifestStore, + n_mels: int, + time_steps: int, + audio_reader: AudioReader, + ): ... + # _get_applied_values returns {}; output is .npy file of shape (n_mels, time_steps) + # _derive_id: return input_sample.id (same stem, .npy extension) + +class TokenStage(ModifierStage[AudioSample, SampleTokens]): + _is_deterministic = True + def __init__( + self, + output_dir: Path, + manifest_store: ManifestStore, + vocab: VocabResult, + input_token_length: int, + ): ... + # _get_applied_values returns {}; tokens derived from AudioSample.transcript + # output is .json file named {id}.json (conventions.sample_file_path(output_dir, id, 'json')) + # content: {"phonemes": [...], "tokens": [...]} padded to input_token_length + # _derive_id: return input_sample.id (same stem as spectrogram and audio; .json extension) +``` + +--- + +#### IO Protocols + +```python +class AudioReader(Protocol): + async def read(self, path: Path) -> tuple[np.ndarray, int]: ... + # Returns (samples, sample_rate). Array is always 1-D mono float32. + # If the source file is stereo, the implementation converts to mono before returning. + # Consumers never need to handle channel reduction. + +class AudioWriter(Protocol): + async def write(self, path: Path, data: np.ndarray, sample_rate: int) -> None: ... + +class TtsProvider(Protocol): + async def synthesize(self, text: str, voice: str, rate: str, output_path: Path) -> None: ... + # Retries are the implementation's responsibility (not TtsSampleGenerator's) + +class NoiseProvider(Protocol): + def list_files(self) -> list[Path]: ... + +class PhonemeProvider(Protocol): + def lookup(self, word: str) -> list[str]: ... + # Raises PhonemeNotFoundError if word not in dictionary +``` + +--- + +#### Non-ModifierStage Classes + +```python +class PhraseVariator: + # See PhraseVariator section above + +class VocabComputer: + def __init__(self, phoneme_provider: PhonemeProvider): ... + def compute(self, manifest: Manifest[TextSample], output_dir: Path) -> VocabResult: + """Extract phoneme vocabulary from TextSample.label values (speech_to_detect). + INTENTIONAL CHANGE from existing pipeline (which used surface_form): + label is what the model outputs; surface forms are irrelevant for vocabulary coverage. + Labels are canonical command names (TV_ON, VOLUME_UP, etc.) — no digits. + Digit-to-word substitution (e.g. '1'→'ONE') from the old pipeline is NOT needed + and is NOT ported. + Writes to output_dir: phoneme_list.txt and words_to_phonemes.json. + The phoneme_trie.json from the old pipeline is dropped (not in VocabResult).""" + ... + +@dataclass +class VocabResult: + phoneme_list: list[str] + words_to_phonemes: dict[str, list[str]] + ctc_blank_idx: int # = len(phoneme_list); blank token appended at end — matches existing io_utils convention + +class SetManifestSplitter: + def __init__(self, seed: int = 42): ... + def split( + self, manifest: Manifest[AudioSample], + train_pct: int, val_pct: int, test_pct: int, + output_dir: Path, + ) -> tuple[Manifest[AudioSample], Manifest[AudioSample], Manifest[AudioSample]]: + """Shuffle and split by individual AudioSample (not stratified by transcript). + Writes train.json, val.json, test.json to output_dir. Percentages must sum to 100. + Input is the fully-augmented manifest from MicrophoneNoiseAugmentor. + INTENTIONAL CHANGE from existing pipeline (which split clean audio only). + Does NOT reassign ids — the AudioSample objects in the split manifests are the same + objects from the input manifest with unchanged id values. ModelTrainer/ModelEvaluator + filter SampleSpectrogram/SampleTokens via {parent_id in {s.id for s in split_manifest}}.""" + ... + +class KerasBackend(Protocol): + def build_ctc_model(self, num_classes: int, n_mels: int, time_steps: int) -> Any: ... + def train( + self, model: Any, + dataset: Any, # tf.data.Dataset yielding (spectrogram, tokens) tuples; + # spectrogram shape (n_mels, time_steps), tokens shape (input_token_length,) + # ModelTrainer is responsible for batching and prefetching + epochs: int, + ) -> list[float]: ... # per-epoch loss values (logged but not written to disk) + def predict(self, model: Any, dataset: Any) -> np.ndarray: ... + def save(self, model: Any, path: Path) -> None: ... + def load(self, path: Path) -> Any: ... + +class ModelTrainer: + def __init__(self, keras_backend: KerasBackend): ... + def train( + self, + train_manifest: Manifest[AudioSample], + vocab: VocabResult, + spectrogram_manifest: Manifest[SampleSpectrogram], + token_manifest: Manifest[SampleTokens], + spectrogram_dir: Path, + token_dir: Path, + output_dir: Path, + ) -> Path: + """spectrogram_manifest and token_manifest are the FULL combined manifests from + SpectrogramStage/TokenStage — they cover all splits (train + val + test). + train_manifest is the split subset. Filter: keep only spectrogram/token entries + where parent_id ∈ {s.id for s in train_manifest.samples}. + Build {parent_id: sample} lookup dicts from the filtered sets. + Construct tf.data.Dataset with batching/prefetching. + Call KerasBackend.train, then KerasBackend.save to conventions.model_path(output_dir). + Return the saved model path.""" + ... + +class ModelEvaluator: + def __init__(self, keras_backend: KerasBackend): ... + def evaluate( + self, + manifest: Manifest[AudioSample], + model_path: Path, + vocab: VocabResult, + spectrogram_manifest: Manifest[SampleSpectrogram], + token_manifest: Manifest[SampleTokens], + spectrogram_dir: Path, + token_dir: Path, + output_dir: Path, + ) -> EvaluationResult: + """Same parent_id lookup as ModelTrainer for the provided manifest (val split). + Writes to output_dir: + evaluation_predictions.txt — tab-separated lines: '{reference}\\t{hypothesis}' + metrics.json — {"wer": }""" + ... + + def package_test_samples( + self, + manifest: Manifest[AudioSample], + model_path: Path, + vocab: VocabResult, + spectrogram_manifest: Manifest[SampleSpectrogram], + token_manifest: Manifest[SampleTokens], + spectrogram_dir: Path, + token_dir: Path, + audio_dir: Path, # MicrophoneNoiseAugmentor output dir; used to locate WAV files for zip + output_dir: Path, + ) -> Path: + """Runs the same prediction loop as evaluate() (implemented via shared private logic), + then writes test_samples.zip to output_dir containing audio files for samples the + model predicted correctly (hypothesis == reference). These become known-good fixtures + for app unit/E2E tests. Implemented as a separate public method; both evaluate() and + package_test_samples() call private _run_predictions(...) to avoid duplication. + Returns the zip path (conventions.test_samples_path(output_dir)).""" + ... + +@dataclass +class EvaluationResult: + wer: float + predictions: list[tuple[str, str]] +``` + +--- + +### Data Flow + +``` +01_input_phrases.csv (phrase, command columns) + │ + │ PhraseVariator(rng=Random(42)) → TextSample variants (with sanity_check) + │ subsample_rate filter applied + ▼ +Manifest[TextSample] + │ + ├──► VocabComputer → VocabResult + │ (labels extracted from TextSample.label; phoneme_list.txt + + │ words_to_phonemes.json written; intentional change from surface_form) + │ + ▼ +TtsSampleGenerator(ModifierStage[TextSample, AudioSample]) + Manifest[AudioSample] (voice from sorted en-US list; speech_rate as int) + │ + ▼ +DelayAugmentor(ModifierStage[AudioSample, AudioSample]) + Manifest[AudioSample] (prefix_delay_s, suffix_delay_s) + │ + ▼ +BackgroundNoiseAugmentor(ModifierStage[AudioSample, AudioSample]) + Manifest[AudioSample] (noise_file, noise_start_s, noise_volume) + │ + ▼ +MicrophoneNoiseAugmentor(ModifierStage[AudioSample, AudioSample]) + Manifest[AudioSample] (mic_noise_amplitude; fully augmented) + │ + ├──► SpectrogramStage (_is_deterministic=True) + │ Manifest[SampleSpectrogram] (parent_id + parent_content_hash → AudioSample) + │ + ├──► TokenStage (_is_deterministic=True) ◄── VocabResult + │ Manifest[SampleTokens] (parent_id + parent_content_hash → AudioSample) + │ + └──► SetManifestSplitter (splits augmented audio by individual sample) + train.json / val.json / test.json (Manifest[AudioSample]) + (intentional change: existing pipeline split clean audio only) + +All three outputs are DVC-parallel (no interdependence) + │ + ├──► ModelTrainer ◄── train manifest + spectrogram/token manifests + VocabResult + │ speech_to_text_model.keras (KerasBackend.save called by ModelTrainer) + │ + ├──► ModelEvaluator.evaluate() ◄── val manifest + spectrogram/token manifests + VocabResult + │ evaluation_predictions.txt, metrics.json ({"wer": }) [stage 11] + │ + └──► ModelEvaluator.package_test_samples() ◄── test manifest + same inputs + test_samples.zip (known-good audio fixtures for app E2E tests) [stage 12] +``` + +**Skip-unchanged detection across chained AudioSample→AudioSample stages:** each `AudioSample` stores `parent_content_hash` (the content_hash of the audio it was derived from). `transform()` builds a `{output.parent_content_hash: output}` index. For each input audio sample, it looks up `input.content_hash` in that index — a match means this stage previously processed this exact input. It then re-derives applied_values using the stored seed and checks whether the content_hash would change; only regenerates if constraints changed. + +--- + +### Testing Approach + +**Unit tests** (pytest, `ml/test/pipeline/`): + +- `PhraseVariator`: determinism with fixed seed; variation types produced; sanity_check filters malformed variants. +- `VariationGenerator`: same seed → same value; stability across range widening; `ValueError` after 1000 iters; `choose` is direct (no loop); range=0 → returns min_val immediately. +- `ModifierStage`: unchanged samples preserved intact (step 2b skip path); constraint change → regenerate with stored seed (step 2b regen path); new samples get fresh seed (step 2c); GC removes orphaned files; `_is_deterministic=True` → output_seed=0. +- `TtsSampleGenerator`: applied_values has `voice` (str) and `speech_rate` (int); rate string formatted in `_generate_output`. +- `BackgroundNoiseAugmentor`: noise_file always stored even when noise not applied; noise_start_s and noise_volume are 0.0 when not applied. +- Async mocks use `asyncio.Event` and `asyncio.Future`. + +**E2E CI test** (`ml/test/e2e_pipeline_test.py`, pytest): + +```python +ml_root = Path(__file__).parent.parent # ml/ +train_output_dir = ml_root / "data" / "speech_10_train_model" +eval_output_dir = ml_root / "data" / "speech_11_evaluate_model" + +def test_full_pipeline_ci(): + subprocess.run([ + "dvc", "repro", + "--set-param", "pipeline.input_phrases_path=test/fixtures/ci_phrases.csv", + "--set-param", "pipeline.epochs=1", + "--set-param", "pipeline.subsample_rate=100", + ], check=True, cwd=ml_root) + + assert conventions.model_path(train_output_dir).exists() + metrics = json.loads(conventions.evaluation_metrics_path(eval_output_dir).read_text()) + assert math.isfinite(metrics["wer"]) # any finite WER ok; guards against NaN/inf only +``` + +Edge_tts is called live; CI runners must have internet access. The test is marked +`@pytest.mark.e2e` and excluded from the default `pytest` run via `pyproject.toml` +(`addopts = "-m 'not e2e'"`); CI invokes it explicitly with `pytest -m e2e`. + +## Open Questions + +_(None — all questions resolved during spec review.)_ + +## Tasks + +### [ADR-221](https://jodasoft.atlassian.net/browse/ADR-221) Task 1: Core data model + +Implement the `Sample` hierarchy and `Manifest[S]` / `ManifestStore` in `ml/pipeline/core/`. + +- [ ] `sample.py`: `Sample`, `SampleWithPath`, `TextSample`, `AudioSample`, `SampleSpectrogram`, `SampleTokens` dataclasses with all fields from the spec; `TextSample.id = uuid4()`; `SampleWithPath.id` derived per `_derive_id()` contract +- [ ] `manifest.py`: `Manifest[S]` with `samples`, `by_content_hash()`, `by_id()`; `ManifestStore.read()` (type registry) and `write()`; JSON schema version 1 +- [ ] Unit tests: round-trip serialisation for all four sample types; `ManifestStore.read()` selects the correct class from `sample_type`; `TextSample` serialises `seed: 0`; `SampleSpectrogram` serialises `parent_id` +- [ ] `validate-build` and `validate-tests` pass + +--- + +### [ADR-222](https://jodasoft.atlassian.net/browse/ADR-222) Task 2a: Randomisation engine — PassFilter and VariationGenerator + +Implement `ml/pipeline/core/randomization.py`. + +- [ ] `PassFilter` ABC with `density()` and `sample_domain()` +- [ ] `MinMaxFilter`: uniform over `[min_val, max_val]`; `density()` = 1.0 in range, 0.0 outside +- [ ] `NormalFilter`: Gaussian; `density(x) = gaussian_pdf(x)/gaussian_pdf(mean)`; raises `ValueError` for `std_dev <= 0` +- [ ] `VariationGenerator`: `should_vary`, `generate` (rejection-sample with attempt indexing), `generate_int` (bitmask + attempt indexing; range=0 → return `min_val` immediately), `choose` (direct, no loop) — all hash formulas from spec +- [ ] Unit tests: same seed → same value for all methods; stability across range widening; `generate` raises `ValueError` after 1000 iterations; `choose` is direct (no rejection loop); `generate_int` range=0; `NormalFilter` rejects `std_dev <= 0`; `should_vary` probability converges over many seeds; change constraints (make max higher and lower) with a value that changes and a value that doesn't change (find a seed that exhibits each behavior, one that gets higher or lower when max changes, another that stays in the lower range when max changes) +- [ ] `validate-build` and `validate-tests` pass + +--- + +### [ADR-223](https://jodasoft.atlassian.net/browse/ADR-223) Task 2b: Randomisation engine — ModifierStage + +Implement `ml/pipeline/core/modifier_stage.py`. Depends on Task 2a. + +- [ ] `ModifierStage[T_in, T_out]`: `transform()` three-case algorithm (skip / regen-with-stored-seed / new sample); `_derive_id()`, `_get_applied_values()`, `_generate_output()` abstract; `_compute_content_hash()` static; `_is_deterministic` class var +- [ ] `transform()` step 2b: re-derives applied values with stored seed; computes expected hash; skips if unchanged; calls `_derive_id(input, new_applied)` for regen (old file GC'd) +- [ ] `transform()` step 2c: calls `_derive_id(input, new_applied)` for new samples (no `uuid4()`) +- [ ] GC: deletes files in `output_dir` not in `{sample.path.name for sample in output_samples}` and not named `manifest.json` +- [ ] Unit tests: skip path preserves output file and id unchanged; constraint change → new id, same seed, updated content_hash, old file GC'd; new sample → `_derive_id` called, fresh seed; `_is_deterministic=True` → `output_seed=0`; GC removes orphaned files; find seeds where the same constraint change causes a change in one sample but not another +- [ ] `validate-build` and `validate-tests` pass + +--- + +### [ADR-224](https://jodasoft.atlassian.net/browse/ADR-224) Task 3: Intent stages + +Implement `ml/pipeline/intent/` and the first two DVC entry-points. + +**⚠️ Must be implemented before Task 9 runs** — `01_generate_phrases.py` is the reference for `_create_variation()` and `sanity_check()`. Read it in full before the scripts directory is deleted. + +- [ ] `phrase_variator.py`: `PhraseVariator` — port `_create_variation()` and `sanity_check()` from `ml/scripts/intent_prediction/01_generate_phrases.py`; replace every `random.*` call with `self.rng.*`; `generate(base_phrases, variations_per_phrase)` signature +- [ ] `vocab_computer.py`: `VocabComputer`, `VocabResult`; extracts from `TextSample.label`; writes `phoneme_list.txt` and `words_to_phonemes.json`; no digit substitution; `ctc_blank_idx = len(phoneme_list)` +- [ ] `stages/conventions.py`: all functions from the spec's Interfaces section +- [ ] `stages/intent_01_generate_phrases.py` and `intent_02_compute_vocab.py` entry-points +- [ ] Unit tests: `PhraseVariator` determinism with fixed seed; output identical to original `VariationGenerator` for same inputs; `sanity_check` filters malformed variants; `VocabComputer` produces correct phoneme list from label words +- [ ] `validate-build` and `validate-tests` pass + +--- + +### [ADR-225](https://jodasoft.atlassian.net/browse/ADR-225) Task 4: TTS stage + +Implement `ml/pipeline/speech/tts_stage.py` and `stages/speech_03_generate_samples.py`. + +- [ ] `TtsProvider` protocol; edge_tts implementation (retries internal to implementation) +- [ ] `TtsSampleGenerator(ModifierStage[TextSample, AudioSample])`: `_derive_id` = `f"{input.label}_{voice_short}_r{rate+100}"`; synthesizes `input_sample.content`; stores `transcript = input_sample.label`; `applied_values = {"voice": str, "speech_rate": int}`; rate string formatted in `_generate_output`, not stored +- [ ] Voice list fetched in entry-point; sorted for determinism; filtered per spec +- [ ] Unit tests: `applied_values` keys and types; skip path; rate string formatted correctly; `TtsProvider` called with `input_sample.content`; `AudioSample.transcript` = `input_sample.label`; derived id format +- [ ] `validate-build` and `validate-tests` pass + +--- + +### [ADR-226](https://jodasoft.atlassian.net/browse/ADR-226) Task 5a: Audio I/O and delay augmentation + +Implement `ml/pipeline/io/` and `delay_stage.py`. + +- [ ] `io/audio_io.py`: `AudioReader` protocol (returns mono float32 array; converts stereo internally), `AudioWriter` protocol; default implementations using librosa/soundfile +- [ ] `delay_stage.py`: `DelayAugmentor`; `_derive_id` = `f"{input.id}_pre{int(prefix*1000)}_suf{int(suffix*1000)}"`; both keys always in `applied_values` (0.0 when not applied); independent `should_vary` checks; float `params.yaml` keys per spec +- [ ] `stages/speech_04_add_delays.py` entry-point +- [ ] Unit tests: both keys always present; 0.0 stored correctly; `AudioReader` converts stereo input to mono; derived id format +- [ ] `validate-build` and `validate-tests` pass + +--- + +### [ADR-227](https://jodasoft.atlassian.net/browse/ADR-227) Task 5b: Background noise and microphone noise augmentation + +Implement `background_noise_stage.py` and `mic_noise_stage.py`. Depends on Task 5a (AudioReader). + +- [ ] `background_noise_stage.py`: `BackgroundNoiseAugmentor`; `NoiseProvider` protocol; `choose()` always called (noise_file always stored); `noise_start_s` bounds derived at runtime from durations; `noise_start_s` and `noise_volume` = 0.0 when not applied; all three keys always present; `_derive_id` = `f"{input.id}_{noise_filestem}_v{int(volume*100)}"` +- [ ] `mic_noise_stage.py`: `MicrophoneNoiseAugmentor`; `_derive_id` = `f"{input.id}_mic{int(amplitude*1000)}"` +- [ ] `stages/speech_05_add_background_noise.py` and `speech_06_add_mic_noise.py` +- [ ] Unit tests: `BackgroundNoiseAugmentor` stores noise_file even when not applied; noise_start_s clamped to 0.0 when noise shorter than audio; all keys always present; derived id formats +- [ ] `validate-build` and `validate-tests` pass + +--- + +### [ADR-228](https://jodasoft.atlassian.net/browse/ADR-228) Task 6: Featurisation and set splitting + +Implement `SpectrogramStage`, `TokenStage`, `SetManifestSplitter` and their entry-points. + +- [ ] `spectrogram_stage.py`: `SpectrogramStage(_is_deterministic=True)`; `_derive_id` returns `input_sample.id`; writes `{id}.npy` of shape `(n_mels, time_steps)`; `SampleSpectrogram.parent_id = input_sample.id` +- [ ] `token_stage.py`: `TokenStage(_is_deterministic=True)`; `_derive_id` returns `input_sample.id`; writes `{id}.json` with `{"phonemes": [...], "tokens": [...]}` padded to `input_token_length`; `SampleTokens.parent_id = input_sample.id` +- [ ] `set_splitter.py`: `SetManifestSplitter`; shuffles full augmented manifest; writes `train.json`, `val.json`, `test.json`; preserves `AudioSample.id` values unchanged +- [ ] `stages/speech_07_compute_tokens.py`, `speech_08_compute_spectrograms.py`, `speech_09_create_set_manifests.py` +- [ ] Unit tests: `SpectrogramStage` and `TokenStage` skip-unchanged paths; `parent_id` set correctly; `SetManifestSplitter` percentages sum correctly; ids unchanged in split outputs +- [ ] `validate-build` and `validate-tests` pass + +--- + +### [ADR-229](https://jodasoft.atlassian.net/browse/ADR-229) Task 7: Model training + +Implement `ModelTrainer` and `stages/speech_10_train_model.py`. + +- [ ] `KerasBackend` protocol + default implementation: `build_ctc_model`, `train` (using `model.fit()` with CTC loss), `predict`, `save`, `load` +- [ ] `ModelTrainer.train()`: filters spectrogram/token manifests by `parent_id ∈ {s.id for s in train_manifest}`; constructs `tf.data.Dataset` with batching/prefetching; calls `KerasBackend.train` then `save`; returns model path +- [ ] `speech_10_train_model.py` entry-point +- [ ] Unit tests: filters correctly by `parent_id`; `KerasBackend.train` called with correct dataset; `KerasBackend.save` called after training +- [ ] `validate-build` and `validate-tests` pass + +--- + +### [ADR-230](https://jodasoft.atlassian.net/browse/ADR-230) Task 8: Model evaluation and test packaging + +Implement `ModelEvaluator` and the final two entry-points. + +- [ ] `ModelEvaluator._run_predictions()`: shared private logic; same `parent_id` filter as `ModelTrainer` +- [ ] `ModelEvaluator.evaluate()`: calls `_run_predictions()`; writes `evaluation_predictions.txt` (tab-separated reference/hypothesis) and `metrics.json` +- [ ] `ModelEvaluator.package_test_samples()`: calls `_run_predictions()`; filters to hypothesis == reference; zips matching WAV files from `audio_dir`; writes `test_samples.zip` +- [ ] `stages/speech_11_evaluate_model.py` and `speech_12_package_test_samples.py` +- [ ] Unit tests: `evaluate()` writes correct files; `package_test_samples()` includes only correctly-predicted samples; both methods call `_run_predictions()` (shared, not duplicated) +- [ ] `validate-build` and `validate-tests` pass + +--- + +### [ADR-231](https://jodasoft.atlassian.net/browse/ADR-231) Task 9: DVC wiring and old script cleanup + +Wire all stages in `dvc.yaml`; migrate `params.yaml`; delete old scripts. + +**⚠️ Script deletion must happen after Task 3 is merged** — the old `ml/scripts/` tree is the reference for `PhraseVariator`. Delete it only once Task 3's implementation is reviewed and merged. + +- [ ] Write `ml/dvc.yaml` from scratch: all 12 stages wired with correct `cmd`, `deps`, `outs`, `params`; `persist: true` on all `ModifierStage` output dirs +- [ ] Write `ml/params.yaml`: all `pipeline.*` keys (including `input_phrases_path`, `variations_per_phrase`, `subsample_rate`, `n_mels`, `time_steps`, `input_token_length`, `epochs`, `batch_size`) and all `stages.*` variation-constraint keys +- [ ] Retain `download_phoneme_dictionary.py` as a non-OOP stage +- [ ] Delete `ml/scripts/` tree (after Task 3 is merged) +- [ ] `dvc repro` runs end-to-end without errors on the dev machine +- [ ] `validate-build` and `validate-tests` pass + +--- + +### [ADR-232](https://jodasoft.atlassian.net/browse/ADR-232) Task 10: E2E CI test + +Add the full-pipeline integration test and pytest configuration. The test is plain Python pytest — no BDD framework. The Given/When/Then structure is written as a docstring in the test function for clarity, not as a Gherkin framework construct. + +- [ ] `ml/test/e2e_pipeline_test.py` with `@pytest.mark.e2e`; test function docstring documents the Given/When/Then scenario +- [ ] `ml/test/fixtures/ci_phrases.csv` with 10 canonical phrases +- [ ] `pyproject.toml` configured: `addopts = "-m 'not e2e'"` excludes from default runs; CI calls `pytest -m e2e` explicitly +- [ ] Assertions: `conventions.model_path(train_output_dir).exists()`; `metrics["wer"]` is finite +- [ ] `validate-build` and `validate-tests` pass (unit tests only; E2E requires live internet) + +## Related Docs + +- [`ml/_doc_ml.md`](./ml/_doc_ml.md) — current pipeline architecture and stage descriptions +- [`src/_doc_Projects.md`](./src/_doc_Projects.md) — project boundaries diff --git a/ml/pipeline/__init__.py b/ml/pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ml/pipeline/core/__init__.py b/ml/pipeline/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ml/pipeline/core/manifest.py b/ml/pipeline/core/manifest.py new file mode 100644 index 00000000..05a899c8 --- /dev/null +++ b/ml/pipeline/core/manifest.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Generic, Sequence, TypeVar + +from pipeline.core.sample import ( + AudioSample, + Sample, + SampleSpectrogram, + SampleTokens, + TextSample, +) + +S = TypeVar("S", bound=Sample) + + +class Manifest(Generic[S]): + """Typed, immutable collection of samples with O(1) lookup by id or content_hash.""" + + def __init__(self, samples: Sequence[S]) -> None: + self._samples: tuple[S, ...] = tuple(samples) + # Keep first occurrence for content_hash — duplicate hashes are possible + # (deterministic stage, same parent + seed=0 + empty applied_values). + self._by_content_hash: dict[str, S] = {} + for s in self._samples: + self._by_content_hash.setdefault(s.content_hash, s) + self._by_id: dict[str, S] = {s.id: s for s in self._samples} + if len(self._by_id) != len(self._samples): + counts: dict[str, int] = {} + for s in self._samples: + counts[s.id] = counts.get(s.id, 0) + 1 + dupes = [sid for sid, n in counts.items() if n > 1] + raise ValueError(f"Manifest contains duplicate sample ids: {dupes}") + + @property + def samples(self) -> tuple[S, ...]: + return self._samples + + def by_content_hash(self, h: str) -> S | None: + return self._by_content_hash.get(h) + + def by_id(self, id: str) -> S | None: + return self._by_id.get(id) + + +_SAMPLE_TYPE_KEY: dict[type, str] = { + TextSample: "text", + AudioSample: "audio", + SampleSpectrogram: "spectrogram", + SampleTokens: "tokens", +} + +_SAMPLE_TYPE_CLASS: dict[str, type] = {v: k for k, v in _SAMPLE_TYPE_KEY.items()} + + +class ManifestStore: + """Reads and writes Manifest JSON files (schema version 1). + + JSON format: + {"version": 1, "sample_type": "", "samples": [...]} + + Path fields in JSON are bare filenames (no directory component). + Callers prepend output_dir when resolving full paths. + Numeric applied_values are stored as raw int/float — never as strings. + """ + + def read(self, path: Path) -> Manifest[Any]: + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + if data.get("version") != 1: + raise ValueError(f"Unsupported manifest version: {data.get('version')!r}") + sample_type = data.get("sample_type") + if sample_type is None: + raise ValueError("Missing 'sample_type' in manifest") + cls = _SAMPLE_TYPE_CLASS.get(sample_type) + if cls is None: + raise ValueError(f"Unknown sample_type: {sample_type!r}") + samples = [_deserialise(cls, entry) for entry in data["samples"]] + return Manifest(samples) + + def write(self, manifest: Manifest[Any], path: Path) -> None: + samples = manifest.samples + if not samples: + raise ValueError("Cannot write empty manifest") + first_type = type(samples[0]) + if any(type(s) is not first_type for s in samples): + raise ValueError( + f"Manifest contains mixed sample types: " + f"{set(type(s).__name__ for s in samples)}" + ) + sample_type = _SAMPLE_TYPE_KEY[first_type] + payload: dict[str, Any] = { + "version": 1, + "sample_type": sample_type, + "samples": [_serialise(s) for s in samples], + } + with open(path, "w", encoding="utf-8") as f: + json.dump(payload, f) + + +# --------------------------------------------------------------------------- +# Serialisation helpers +# --------------------------------------------------------------------------- + +def _serialise(s: Sample) -> dict[str, Any]: + if isinstance(s, TextSample): + return { + "id": s.id, + "seed": s.seed, + "content_hash": s.content_hash, + "content": s.content, + "label": s.label, + } + if isinstance(s, AudioSample): + return { + "id": s.id, + "seed": s.seed, + "content_hash": s.content_hash, + "path": s.path.name, + "parent_content_hash": s.parent_content_hash, + "transcript": s.transcript, + "applied_values": s.applied_values, + } + if isinstance(s, SampleSpectrogram): + return { + "id": s.id, + "seed": s.seed, + "content_hash": s.content_hash, + "path": s.path.name, + "parent_content_hash": s.parent_content_hash, + "transcript": s.transcript, + "parent_id": s.parent_id, + } + if isinstance(s, SampleTokens): + return { + "id": s.id, + "seed": s.seed, + "content_hash": s.content_hash, + "path": s.path.name, + "parent_content_hash": s.parent_content_hash, + "transcript": s.transcript, + "parent_id": s.parent_id, + } + raise ValueError(f"Unrecognised sample type: {type(s)}") + + +def _deserialise(cls: type, e: dict[str, Any]) -> Sample: + if cls is TextSample: + return TextSample( + seed=e["seed"], + content_hash=e["content_hash"], + content=e["content"], + label=e["label"], + ) + if cls is AudioSample: + return AudioSample( + id=e["id"], + seed=e["seed"], + content_hash=e["content_hash"], + path=Path(e["path"]), + parent_content_hash=e["parent_content_hash"], + transcript=e["transcript"], + applied_values=e["applied_values"], + ) + if cls is SampleSpectrogram: + return SampleSpectrogram( + id=e["id"], + seed=e["seed"], + content_hash=e["content_hash"], + path=Path(e["path"]), + parent_content_hash=e["parent_content_hash"], + transcript=e["transcript"], + parent_id=e["parent_id"], + ) + if cls is SampleTokens: + return SampleTokens( + id=e["id"], + seed=e["seed"], + content_hash=e["content_hash"], + path=Path(e["path"]), + parent_content_hash=e["parent_content_hash"], + transcript=e["transcript"], + parent_id=e["parent_id"], + ) + raise ValueError(f"Unrecognised class: {cls}") diff --git a/ml/pipeline/core/modifier_stage.py b/ml/pipeline/core/modifier_stage.py new file mode 100644 index 00000000..f7b67b14 --- /dev/null +++ b/ml/pipeline/core/modifier_stage.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import hashlib +import json +import os +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, ClassVar, Generic, TypeVar + +from pipeline.core.manifest import Manifest, ManifestStore +from pipeline.core.randomization import VariationGenerator +from pipeline.core.sample import Sample, SampleWithPath + +T_in = TypeVar("T_in", bound=Sample) +T_out = TypeVar("T_out", bound=SampleWithPath) + + +class ModifierStage(ABC, Generic[T_in, T_out]): + """Abstract base for all per-sample file-transformation stages. + + Subclasses implement _get_applied_values, _generate_output, and _derive_id. + transform() drives the three-case skip/regen/new algorithm, GC, and manifest write. + """ + + _is_deterministic: ClassVar[bool] = False + + def __init__(self, output_dir: Path, manifest_store: ManifestStore) -> None: + self._output_dir = output_dir + self._manifest_store = manifest_store + + async def transform( + self, + input_manifest: Manifest[T_in], + manifest_path: Path, + ) -> Manifest[T_out]: + """Run the three-case algorithm over input_manifest; write output manifest. + + Step 1: Read previous manifest (if present) and build parent_content_hash index. + Step 2: For each input sample, skip / regen / generate-new. + Step 3: GC — delete output_dir files not in the new output set and not manifest.json. + Step 4: Write output manifest to manifest_path. + """ + # Step 1: build previous-output index keyed on parent_content_hash + prev_by_parent: dict[str, T_out] = {} + if manifest_path.exists(): + prev = self._manifest_store.read(manifest_path) + prev_by_parent = { + out.parent_content_hash: out for out in prev.samples + } + + # Step 2: process each input sample + output_samples: list[T_out] = [] + for input_sample in input_manifest.samples: + prev_out = prev_by_parent.get(input_sample.content_hash) + + if prev_out is not None: + # 2b: previous output exists for this input — check if constraints changed + new_applied = self._get_applied_values( + input_sample, VariationGenerator(prev_out.seed) + ) + expected_hash = self._compute_content_hash( + input_sample.content_hash, prev_out.seed, new_applied + ) + if expected_hash == prev_out.content_hash: + # Skip: file and id unchanged + output_samples.append(prev_out) + else: + # Regen: constraints changed; preserve seed; derive new id + new_id = self._derive_id(input_sample, new_applied) + result = await self._generate_output( + input_sample, + output_id=new_id, + output_seed=prev_out.seed, + applied_values=new_applied, + parent_content_hash=input_sample.content_hash, + ) + output_samples.append(result) + else: + # 2c: new sample — assign fresh seed and derive id + if self._is_deterministic: + output_seed = 0 + else: + output_seed = int.from_bytes(os.urandom(8), "big") + generator = VariationGenerator(output_seed) + new_applied = self._get_applied_values(input_sample, generator) + output_id = self._derive_id(input_sample, new_applied) + result = await self._generate_output( + input_sample, + output_id=output_id, + output_seed=output_seed, + applied_values=new_applied, + parent_content_hash=input_sample.content_hash, + ) + output_samples.append(result) + + # Step 3: GC — flat glob; delete files not in the new output set + kept_names = {sample.path.name for sample in output_samples} + if self._output_dir.exists(): + for file in self._output_dir.glob("*"): + if file.is_file() and file.name != "manifest.json" and file.name not in kept_names: + file.unlink() + + # Step 4: write output manifest + output_manifest: Manifest[T_out] = Manifest(output_samples) + self._manifest_store.write(output_manifest, manifest_path) + + return output_manifest + + @abstractmethod + def _get_applied_values( + self, sample: T_in, generator: VariationGenerator + ) -> dict[str, Any]: + """Return applied values dict. Return {} for deterministic stages.""" + ... + + @abstractmethod + async def _generate_output( + self, + input_sample: T_in, + output_id: str, + output_seed: int, + applied_values: dict[str, Any], + parent_content_hash: str, + ) -> T_out: + """Generate output file; return complete output Sample with all fields set. + MUST compute content_hash via _compute_content_hash.""" + ... + + @abstractmethod + def _derive_id(self, input_sample: T_in, applied_values: dict[str, Any]) -> str: + """Return the output sample id (= filename stem). + Called for both new samples and regens with changed constraints.""" + ... + + @staticmethod + def _compute_content_hash( + parent_content_hash: str, output_seed: int, applied_values: dict[str, Any] + ) -> str: + """Single source of truth for the content_hash formula. + All _generate_output implementations MUST call this method. + """ + canonical = json.dumps( + applied_values, sort_keys=True, separators=(",", ":"), ensure_ascii=True + ) + raw = f"{parent_content_hash}:{output_seed}:{canonical}" + return hashlib.sha256(raw.encode("utf-8")).hexdigest() diff --git a/ml/pipeline/core/randomization.py b/ml/pipeline/core/randomization.py new file mode 100644 index 00000000..2b98f05b --- /dev/null +++ b/ml/pipeline/core/randomization.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import hashlib +import math +from abc import ABC, abstractmethod +from typing import Sequence, TypeVar + +T = TypeVar("T") + +_TWO_TO_64 = 2**64 +_MAX_ATTEMPTS = 1000 + + +def _hash_int(key: str) -> int: + return int.from_bytes(hashlib.sha256(key.encode()).digest()[:8], "big") + + +class PassFilter(ABC): + def __init__(self, precision: int = 0) -> None: + domain_low, domain_high = self.sample_domain() + scale = 10 ** precision + low_s = round(domain_low * scale) + high_s = round(domain_high * scale) + bias_s = max(0, -low_s) + shifted_high = high_s + bias_s + self._precision = precision + self._scale = scale + self._low_s = low_s + self._high_s = high_s + self._bias_s = bias_s + self._pow2_range = 1 << math.ceil(math.log2(shifted_high + 1)) if shifted_high > 0 else 1 + + @abstractmethod + def density(self, value: float) -> float: + """Normalised density; max == 1.0. Acceptance probability in rejection sampling.""" + ... + + @abstractmethod + def sample_domain(self) -> tuple[float, float]: + """(low, high): range for uniform candidate generation.""" + ... + + +class MinMaxFilter(PassFilter): + """Uniform over [min_val, max_val]. density() == 1.0 in range, 0.0 outside.""" + + def __init__(self, min_val: float, max_val: float, *, precision: int = 0) -> None: + if min_val > max_val: + raise ValueError( + f"min_val must be <= max_val, got min_val={min_val}, max_val={max_val}" + ) + self._min_val = min_val + self._max_val = max_val + super().__init__(precision) + + def density(self, value: float) -> float: + return 1.0 if self._min_val <= value <= self._max_val else 0.0 + + def sample_domain(self) -> tuple[float, float]: + return (self._min_val, self._max_val) + + +class NormalFilter(PassFilter): + """Gaussian. density(x) = gaussian_pdf(x)/gaussian_pdf(mean); peak == 1.0.""" + + def __init__(self, mean: float, std_dev: float, *, precision: int = 0) -> None: + if std_dev <= 0: + raise ValueError(f"std_dev must be positive, got {std_dev}") + self._mean = mean + self._std_dev = std_dev + super().__init__(precision) + + def density(self, value: float) -> float: + z = (value - self._mean) / self._std_dev + return math.exp(-0.5 * z * z) + + def sample_domain(self) -> tuple[float, float]: + return (self._mean - 5 * self._std_dev, self._mean + 5 * self._std_dev) + + +class VariationGenerator: + def __init__(self, sample_seed: int) -> None: + self._seed = sample_seed + + def should_vary(self, variable_name: str, frequency: float) -> bool: + """True with probability frequency using sha256-based deterministic hash.""" + if not 0.0 <= frequency <= 1.0: + raise ValueError(f"frequency must be in [0.0, 1.0], got {frequency}") + raw = _hash_int(f"{self._seed}:{variable_name}:vary") + return (raw / _TWO_TO_64) < frequency + + def generate(self, variable_name: str, pass_filter: PassFilter) -> float: + """Rejection-sample using power-of-2 modulo for stability across domain changes.""" + if pass_filter._high_s == pass_filter._low_s: + return pass_filter._low_s / pass_filter._scale + pow2_range = pass_filter._pow2_range + bias_s = pass_filter._bias_s + scale = pass_filter._scale + for n in range(_MAX_ATTEMPTS): + raw = _hash_int(f"{self._seed}:{variable_name}:{n}") + candidate = (raw % pow2_range - bias_s) / scale + accept_raw = _hash_int(f"{self._seed}:{variable_name}:{n}:accept") + if (accept_raw / _TWO_TO_64) < pass_filter.density(candidate): + return candidate + raise ValueError( + f"generate() failed to find an accepted candidate after {_MAX_ATTEMPTS} " + f"attempts for variable '{variable_name}'" + ) + + def generate_int(self, variable_name: str, pass_filter: MinMaxFilter) -> int: + """Integer in [int(min_val), int(max_val)] inclusive using bitmask rejection.""" + domain_low, domain_high = pass_filter.sample_domain() + min_val = int(domain_low) + max_val = int(domain_high) + range_ = max_val - min_val + if range_ == 0: + return min_val + mask = (1 << math.ceil(math.log2(range_ + 1))) - 1 + for n in range(_MAX_ATTEMPTS): + raw = _hash_int(f"{self._seed}:{variable_name}:{n}") + candidate = min_val + (raw & mask) + if candidate <= max_val: + return candidate + raise ValueError( + f"generate_int() failed to find an accepted candidate after {_MAX_ATTEMPTS} " + f"attempts for variable '{variable_name}'" + ) + + def choose(self, variable_name: str, options: Sequence[T]) -> T: + """Direct selection via sha256 hash modulo len(options); no rejection loop.""" + if not options: + raise ValueError("options must be non-empty") + raw = _hash_int(f"{self._seed}:{variable_name}:0") + return options[raw % len(options)] diff --git a/ml/pipeline/core/sample.py b/ml/pipeline/core/sample.py new file mode 100644 index 00000000..3ff44014 --- /dev/null +++ b/ml/pipeline/core/sample.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from abc import ABC +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +@dataclass +class Sample(ABC): + """Base class for all pipeline sample types. + + id: stable identifier. + TextSample: equals content_hash — content-addressable; no file, not user-visible. + SampleWithPath subtypes: derived by each stage's _derive_id(); used as filename stem. + seed: 0 for TextSample and deterministic stages; os.urandom(8) for stochastic stages. + content_hash: sha256 hex digest. + TextSample: sha256(content.encode('utf-8')). + SampleWithPath subtypes: sha256(parent_content_hash + ":" + str(seed) + ":" + canonical(applied_values)). + """ + + id: str + seed: int + content_hash: str + + +@dataclass +class SampleWithPath(Sample, ABC): + """Base for all ModifierStage output types that produce a file. + + path: relative filename (bare name only, no directory component). + parent_content_hash: content_hash of the upstream sample; skip-unchanged lookup key. + """ + + path: Path + parent_content_hash: str + + +@dataclass +class TextSample(Sample): + """Bootstrapped phrase variant; no output file. + + id is derived from content_hash (content-addressable; callers do not pass it). + seed=0 for all TextSamples. + content_hash = sha256(content.encode('utf-8')). + """ + + content: str + label: str + id: str = field(init=False) + + def __post_init__(self) -> None: + self.id = self.content_hash + + +@dataclass +class AudioSample(SampleWithPath): + """WAV file produced by TtsSampleGenerator or an augmentation stage.""" + + transcript: str + applied_values: dict[str, Any] + + +@dataclass +class SampleSpectrogram(SampleWithPath): + """NPY file produced by SpectrogramStage (seed=0, deterministic).""" + + transcript: str + parent_id: str + + +@dataclass +class SampleTokens(SampleWithPath): + """JSON file produced by TokenStage (seed=0, deterministic).""" + + transcript: str + parent_id: str diff --git a/ml/pyproject.toml b/ml/pyproject.toml new file mode 100644 index 00000000..fc3d8fed --- /dev/null +++ b/ml/pyproject.toml @@ -0,0 +1,5 @@ +[tool.pytest.ini_options] +testpaths = ["test"] +pythonpath = ["."] +addopts = "-m 'not e2e'" +markers = ["e2e: End-to-end tests that invoke dvc repro; excluded from default run"] diff --git a/ml/scripts/intent_prediction/01_generate_phrases.py b/ml/scripts/intent_prediction/01_generate_phrases.py new file mode 100644 index 00000000..e8af56a5 --- /dev/null +++ b/ml/scripts/intent_prediction/01_generate_phrases.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +""" +Training Data Generator for LLM Intent Classification + +This script generates training data variations from the 01_input_phrases.csv file +for fine-tuning an LLM to handle remote control commands as a fallback. +""" + +import csv +import random +import re +from pathlib import Path +from typing import List, Dict, Tuple + + +# ===================== +# CONFIGURATION SETTINGS +# ===================== +SCRIPT_DIR = Path(__file__).parent +DATA_FOLDER = SCRIPT_DIR / "../../data/intent_prediction/01_generate_phrases/" +INPUT_FILE = SCRIPT_DIR / "01_input_phrases.csv" +OUTPUT_FILE = DATA_FOLDER / "training_data.csv" + +# Number of total samples to generate +TARGET_SAMPLES = 10000 + +# Probability settings for variations +REPEAT_MODIFIER_CHANCE = 0.25 +PLEASANTRY_CHANCE = 0.3 +HESITATION_CHANCE = 0.3 +SPELLING_VARIANT_CHANCE = 0.3 +CASE_VARIANT_CHANCE = 0.3 + +# ===================== + +# Variation components +REPEAT_MODIFIERS = { + 1: ["", "once", "one", "one time", "one more time", "another one", "another time", "again"], + 2: ["twice", "two", "two times", "two more times", "another two", "another two times"], + 3: ["three", "three times", "three more times", "another three", "another three times"], + 4: ["four", "four times", "four more times", "another four", "another four times"], + 5: ["five", "five times", "five more times", "another five", "another five times"], + 6: ["six", "six times", "six more times", "another six", "another six times"], + 7: ["seven", "seven times", "seven more times", "another seven", "another seven times"], + 8: ["eight", "eight times", "eight more times", "another eight", "another eight times"], + 9: ["nine", "nine times", "nine more times", "another nine", "another nine times"], +} + +PLEASANTRIES = [ + "", + "please ", + ", please ", + "please, ", + ", please, ", + "could you ", + "can you ", + "would you ", + ", thank you ", + ", thanks " +] + +HESITATIONS = [ + "", + ", um, ", + ", uh, ", + ", umm, ", + ", err, ", + ", hmm, ", + " ... ", +] + +# Homophone/spelling variations +SPELLING_VARIANTS = { + "one": ["one", "won", "1"], + "to": ["to", "too", "two", "2"], + "two": ["two", "to", "too", "2"], + "three": ["three", "3"], + "for": ["for", "four", "4"], + "four": ["four", "for", "4"], + "five": ["five", "5"], + "six": ["six", "6"], + "seven": ["seven", "7"], + "eight": ["eight", "ate", "8"], + "nine": ["nine", "9"], + "right": ["right", "rite", "write", "wright"], + "OK": ["OK", "okay", "ok"], + "pause": ["pause", "paws"] +} + +class VariationGenerator: + """Generates variations of command phrases.""" + + def __init__(self, target_samples: int): + self.target_samples = target_samples + (self.generated, self.existing_variations) = self.load_existing_variations(OUTPUT_FILE) + + def load_existing_variations(self, filepath: Path) -> Tuple[set, List[Dict[str, str]]]: + """Load existing variations from a CSV file to avoid duplicates.""" + variations = [] + generated_keys = set() + if not filepath.exists(): + return generated_keys, variations + with open(filepath, newline='', encoding='utf-8') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + variations.append(row) + key = self._create_key(row) + generated_keys.add(key) + return (generated_keys, variations) + + def generate_variations(self, commands: List[Dict[str, str]]) -> List[Dict[str, str]]: + """Generate variations for all commands.""" + variations = self.existing_variations + samples_per_phrase = max(1, self.target_samples // len(commands)) + + for cmd in commands: + existing_for_phrase = len([v for v in variations if v['base_phrase'] == cmd['phrase']]) + cmd_variations = self._generate_for_command( + cmd['phrase'], + cmd['canonical'], + samples_per_phrase - existing_for_phrase + ) + variations.extend(cmd_variations) + + # If we haven't reached target, add more variations + while len(variations) < self.target_samples and len(commands) > 0: + cmd = random.choice(commands) + extra = self._generate_for_command( + cmd['phrase'], + cmd['canonical'], + 1 + ) + variations.extend(extra) + + return variations[:self.target_samples] + + def _generate_for_command(self, phrase: str, canonical: str, count: int) -> List[Dict[str, str]]: + """Generate variations for a single command.""" + variations = [] + attempts = 0 + max_attempts = count * 10 # Prevent infinite loops + + while len(variations) < count and attempts < max_attempts: + attempts += 1 + + # Generate a variation + variation = self._create_variation(phrase) + + # Check for duplicates + variant_key = self._create_key(variation) + if variant_key not in self.generated: + self.generated.add(variant_key) + variation['canonical_label'] = canonical + variations.append(variation) + + return variations + + def _create_key(self, row: Dict[str, str]) -> str: + """Create a unique key for a variation based on base phrase and canonical label.""" + return row['base_phrase'].lower() + "|" + row['transformations'] + + def _create_variation(self, phrase: str) -> Dict[str, str]: + """Create a single variation of a phrase.""" + transformations = [] + result = phrase + repeat_count_used = 1 + + # Add repeat modifier (for data variety only) + repeat_modifier = "" + if random.random() < REPEAT_MODIFIER_CHANCE: + repeat_count_used = random.choice(list(REPEAT_MODIFIERS.keys())) + modifiers = REPEAT_MODIFIERS.get(repeat_count_used, []) + if modifiers: + repeat_modifier = random.choice(modifiers) + else: + repeat_modifier = "" + repeat_count_used = 1 + if repeat_modifier: + result = f"{result} {repeat_modifier}" + transformations.append(f"repeat_modifier:{repeat_count_used}") + + # Add pleasantry + if random.random() < PLEASANTRY_CHANCE: + pleasantry = random.choice(PLEASANTRIES) + if pleasantry: + if pleasantry.startswith("could you") or pleasantry.startswith("can you") or pleasantry.startswith("would you"): + result = f"{pleasantry} {result}" + transformations.append(f"prefix_pleasantry:{pleasantry}") + elif pleasantry.startswith(","): + result = f"{result}{pleasantry}" + transformations.append(f"suffix_pleasantry:{pleasantry}") + else: + if random.random() > 0.5: + result = f"{result}, {pleasantry}" + else: + result = f"{pleasantry}, {result}" + transformations.append(f"pleasantry:{pleasantry}") + + # Save the speech to detect at this point. This is what the STT system should recognize. + speech_to_detect = self._normalize_commas_and_whitespace(result) + + # Add hesitation + if random.random() < HESITATION_CHANCE: + hesitation = random.choice([h for h in HESITATIONS if h]) + if hesitation: + # Insert hesitation at random position + words = result.split() + if len(words) > 1: + pos = random.randint(0, len(words)) + words.insert(pos, hesitation) + result = " ".join(words) + transformations.append(f"hesitation:{hesitation}") + + # Apply spelling variations + if random.random() < SPELLING_VARIANT_CHANCE: + for word, variants in SPELLING_VARIANTS.items(): + lower_result = result.lower() + if word in lower_result and len(variants) > 1: + variant = random.choice([v for v in variants if v != word]) + result_lower = lower_result.replace(word, variant, 1) + # Match original casing roughly + if result.isupper(): + result = result_lower.upper() + elif result.istitle(): + result = result_lower.title() + else: + result = result_lower + transformations.append(f"spelling_variant:{word}->{variant}") + + # Random case variations + if random.random() < CASE_VARIANT_CHANCE: + case_transform = random.choice(["lower", "upper", "title", "original"]) + if case_transform == "lower": + result = result.lower() + transformations.append("lowercase") + elif case_transform == "upper": + result = result.upper() + transformations.append("uppercase") + elif case_transform == "title": + result = result.title() + transformations.append("titlecase") + + result = self._normalize_commas_and_whitespace(result) + + return { + 'base_phrase': phrase, + 'surface_form': result, + 'speech_to_detect': speech_to_detect, + 'transformations': "|".join(transformations) if transformations else "none", + 'repeat_count': repeat_count_used + } + + def _normalize_commas_and_whitespace(self, text: str) -> str: + """Normalize commas and whitespace in the given text.""" + # Collapse consecutive commas into a single comma + text = re.sub(r'(?:\s*,\s*){2,}', ',', text) + # Remove leading/trailing commas and normalize whitespace + text = re.sub(r'^\s*,\s*|\s*,\s*$', '', text).strip() + text = re.sub(r'\s+', ' ', text).strip() + text = re.sub(r'\s,', ',', text).strip() + return text + + def _tokenize(self, text: str) -> List[str]: + """Tokenize text into alphanumeric lowercase tokens.""" + return re.findall(r"[a-z0-9']+", text.lower()) + + def sanity_check(self, variations: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], List[str]]: + """Perform sanity checks on generated variations.""" + valid = [] + issues = [] + + for var in variations: + surface = var['surface_form'] + canonical = var['canonical_label'] + + # Check 1: Not empty + if not surface or not surface.strip(): + issues.append(f"Empty surface form for {canonical}") + continue + + # Check 2: Reasonable length (5-150 characters) + if len(surface) < 2 or len(surface) > 150: + issues.append(f"Unusual length ({len(surface)}) for: {surface}") + continue + + # Check 3: Contains at least one letter + if not any(c.isalpha() for c in surface): + issues.append(f"No letters in: {surface}") + continue + + # Check 4: Base phrase recognizable via token overlap (with punctuation stripped) + base_tokens_list = self._tokenize(var['base_phrase']) + surface_token_set = set(self._tokenize(surface)) + + # Fallback to canonical label tokens if base tokens are empty + if not base_tokens_list: + base_tokens_list = self._tokenize(canonical) + + augmented_base_tokens = set(base_tokens_list) + + transformations_str = var['transformations'] + if transformations_str and transformations_str != "none": + for t in transformations_str.split('|'): + if t.startswith("spelling_variant:"): + try: + mapping = t.split(":", 1)[1] + if "->" in mapping: + base_word, variant_word = mapping.split("->", 1) + augmented_base_tokens.update(self._tokenize(base_word)) + augmented_base_tokens.update(self._tokenize(variant_word)) + except ValueError: + # Ignore malformed transformation metadata + pass + + if augmented_base_tokens and surface_token_set: + if augmented_base_tokens.isdisjoint(surface_token_set): + issues.append(f"Base phrase '{var['base_phrase']}' not recognizable in '{surface}'") + continue + + valid.append(var) + + return valid, issues + + +def main(): + """Main execution function.""" + print("Training Data Generator for LLM Intent Classification") + print("=" * 60) + + # Step 1: Extract commands from CSV + print("\n1. Extracting commands from " + INPUT_FILE.as_posix() + "...") + csv_path = INPUT_FILE + commands = [] + with open(csv_path, newline='', encoding='utf-8') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + commands.append({ + 'phrase': row['phrase'], + 'canonical': row['command'] + }) + print(f" Found {len(commands)} command phrases") + + # Step 2: Generate variations + print(f"\n2. Generating {TARGET_SAMPLES} variations...") + generator = VariationGenerator(TARGET_SAMPLES) + variations = generator.generate_variations(commands) + print(f" Generated {len(variations)} initial variations") + + # Step 3: Sanity check + print("\n3. Running sanity checks...") + valid_variations, issues = generator.sanity_check(variations) + print(f" Valid: {len(valid_variations)}") + print(f" Issues: {len(issues)}") + if issues: + print("\n Sample issues:") + for issue in issues[:5]: + print(f" - {issue}") + + # Step 4: Write to CSV + print(f"\n4. Writing to {OUTPUT_FILE}...") + output_path = Path(OUTPUT_FILE) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', newline='', encoding='utf-8') as f: + fieldnames = [ + 'surface_form', + 'base_phrase', + 'speech_to_detect', + 'canonical_label', + 'repeat_count', + 'transformations' + ] + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for row in valid_variations: + writer.writerow({ + 'surface_form': row['surface_form'], + 'base_phrase': row['base_phrase'], + 'speech_to_detect': row['speech_to_detect'], + 'canonical_label': row['canonical_label'], + 'repeat_count': row.get('repeat_count', ''), + 'transformations': row['transformations'] + }) + + print(f"\n✓ Successfully generated {len(valid_variations)} training samples") + print(f"✓ Output saved to: {OUTPUT_FILE}") + + # Step 5: Show sample outputs + print("\n5. Sample outputs:") + print("-" * 60) + samples = random.sample(valid_variations, min(10, len(valid_variations))) + for i, sample in enumerate(samples, 1): + print(f"\n{i}. Surface form: {sample['surface_form']}") + print(f" Canonical: {sample['canonical_label']}") + print(f" Base: {sample['base_phrase']}") + print(f" Repeat count: {sample.get('repeat_count', '')}") + print(f" Transforms: {sample['transformations']}") + + print("\n" + "=" * 60) + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/ml/scripts/intent_prediction/01_input_phrases.csv b/ml/scripts/intent_prediction/01_input_phrases.csv new file mode 100644 index 00000000..9394dfd3 --- /dev/null +++ b/ml/scripts/intent_prediction/01_input_phrases.csv @@ -0,0 +1,52 @@ +phrase,command +Back,Back +Go back,Back +Page down,ChannelDown +Channel down,ChannelDown +Page up,ChannelUp +Channel up,ChannelUp +Go down,Down +Down,Down +Quit,Exit +Exit,Exit +Guide,Guide +Go to Guide,Guide +Left,Left +Go left,Left +Mute,Mute +Netflix,Netflix +Go to Netflix,Netflix +Pause,Pause +Play,Play +Turn Off,PowerOff +Power Off,PowerOff +Off,PowerOff +Turn Off the TV,PowerOff +On,PowerOn +Power On,PowerOn +Turn On,PowerOn +Turn On the TV,PowerOn +Record,Record +Replay,Replay +Skip back,Replay +Right,Right +Go right,Right +Select,Select +OK,Select +Advance,Skip +Skip forward,Skip +Skip,Skip +Go to TiVo,TiVo +TiVo,TiVo +Up,Up +Go up,Up +Turn it down,VolumeDown +Quieter,VolumeDown +Softer,VolumeDown +Volume down,VolumeDown +Turn down the volume,VolumeDown +Louder,VolumeUp +Volume up,VolumeUp +Turn it up,VolumeUp +Crank it up,VolumeUp +Turn up the volume,VolumeUp diff --git a/ml/scripts/intent_prediction/02_compute_vocab.py b/ml/scripts/intent_prediction/02_compute_vocab.py new file mode 100644 index 00000000..95a1ed2c --- /dev/null +++ b/ml/scripts/intent_prediction/02_compute_vocab.py @@ -0,0 +1,99 @@ +import string +import argparse +from pathlib import Path +import os +from tqdm import tqdm +import pandas as pd +import json +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) # adds ml/scripts/ to sys.path + +from shared.phoneme_utils import load_phoneme_dict, build_trie + +# Settings + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Download phoneme dictionary files.") +parser.add_argument('--phoneme-dictionary-dir', type=Path, required=True, help='Directory containing phoneme dictionary files') +parser.add_argument('--training-data-file', type=Path, required=True, help='Path to training data CSV file') +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output of vocabulary files') + +word_translation_table = str.maketrans({ + **{ + '1': 'ONE', + '2': 'TWO', + '3': 'THREE', + '4': 'FOR', + '5': 'FIVE', + '6': 'SIX', + '7': 'SEVEN', + '8': 'EIGHT', + '9': 'NINE', + }, + **{p: ' ' for p in string.punctuation} +}) + + +def extract_words_from_csv(training_data_file): + """Extract all distinct words from the 'surface_form' column of the training data CSV file, stripping punctuation.""" + words = set() + df = pd.read_csv(training_data_file, encoding="utf-8") + if 'surface_form' not in df.columns: + raise ValueError("'surface_form' column not found in training data file.") + # Create translation table: punctuation -> space + for cell in df['surface_form'].dropna(): + # Replace punctuation with spaces + cell_clean = str(cell).translate(word_translation_table) + for word in cell_clean.split(): + if word: + words.add(word.upper()) + return sorted(words) + + +if __name__ == "__main__": + paths = parser.parse_args() + + os.makedirs(paths.output_dir, exist_ok=True) + + # 1. Load phoneme dictionary + phoneme_dict = load_phoneme_dict(paths.phoneme_dictionary_dir) + + # 2. Extract words from training data + words = extract_words_from_csv(paths.training_data_file) + + # 3. Map words to phoneme sequences + word_to_phonemes = {} + missing_words = [] + for word in words: + if word in phoneme_dict: + word_to_phonemes[word] = phoneme_dict[word] + else: + missing_words.append(word) + + if missing_words: + print(f"Warning: {len(missing_words)} words not found in phoneme dictionary. They will be skipped.") + for entry in missing_words[:10]: # Print first 10 missing words + print(f" - {entry}") + + # 4. Collect all phonemes used + phoneme_set = set() + for phonemes in word_to_phonemes.values(): + phoneme_set.update(phonemes) + phoneme_list = sorted(phoneme_set) + + # 5. Output JSON file: words and their phoneme sequences + words_json_path = Path(paths.output_dir) / "words_to_phonemes.json" + with open(words_json_path, "w", encoding="utf-8") as f: + json.dump({w: word_to_phonemes[w] for w in sorted(word_to_phonemes)}, f, indent=2, ensure_ascii=False) + + # 6. Output phoneme list file + phoneme_list_path = Path(paths.output_dir) / "phoneme_list.txt" + with open(phoneme_list_path, "w", encoding="utf-8") as f: + for phoneme in phoneme_list: + f.write(phoneme + "\n") + + # 7. Output trie JSON file + trie = build_trie(word_to_phonemes) + trie_json_path = Path(paths.output_dir) / "phoneme_trie.json" + with open(trie_json_path, "w", encoding="utf-8") as f: + json.dump(trie, f, indent=2, ensure_ascii=False) diff --git a/ml/scripts/intent_prediction/02a_download_phoneme_dictionary.py b/ml/scripts/intent_prediction/02a_download_phoneme_dictionary.py new file mode 100644 index 00000000..7110b4be --- /dev/null +++ b/ml/scripts/intent_prediction/02a_download_phoneme_dictionary.py @@ -0,0 +1,38 @@ +import argparse +from pathlib import Path +import os +import asyncio +from tqdm import tqdm +import aiohttp + +# Settings +download_urls = [ + "https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b", + "https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.phones" +] + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Download phoneme dictionary files.") +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output of downloaded files') +paths = parser.parse_args() + +os.makedirs(paths.output_dir, exist_ok=True) + +async def download_file(session, url, output_dir): + filename = url.split("/")[-1] + output_path = Path(output_dir) / filename + async with session.get(url) as resp: + resp.raise_for_status() + with open(output_path, "wb") as f: + async for chunk in resp.content.iter_chunked(1024): + f.write(chunk) + +async def main(): + async with aiohttp.ClientSession() as session: + tasks = [download_file(session, url, paths.output_dir) for url in download_urls] + for f in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Downloading"): + await f + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/ml/scripts/requirements.txt b/ml/scripts/requirements.txt new file mode 100644 index 00000000..9018cabf --- /dev/null +++ b/ml/scripts/requirements.txt @@ -0,0 +1,13 @@ +edge-tts==6.1.9 +pytest>=8.0 +pytest-mock>=3.0 +pyyaml>=6.0 +pandas==2.2.0 +tqdm==4.66.1 +pydub==0.25.1 +soundfile==0.12.1 +librosa==0.10.1 +tensorflow==2.15.0 +onnx==1.17.0 +tf2onnx==1.16.1 +jiwer==3.0.3 \ No newline at end of file diff --git a/ml/scripts/shared/__init__.py b/ml/scripts/shared/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ml/scripts/shared/constants.py b/ml/scripts/shared/constants.py new file mode 100644 index 00000000..c6f076bc --- /dev/null +++ b/ml/scripts/shared/constants.py @@ -0,0 +1,52 @@ +""" +Centralised pipeline dimension constants. All scripts import from here +instead of hardcoding values independently. + +Values are read from ml/params.yaml so they can be overridden via +``dvc exp run --set-param pipeline.n_mels=40`` without touching source code. +""" +from pathlib import Path + +import yaml + +_params_path = Path(__file__).parent.parent.parent / "params.yaml" +with open(_params_path, encoding='utf-8') as _f: + _y = yaml.safe_load(_f) + _p = _y["pipeline"] + _d = _y["add_delays"] + _gssv = _y["generate_speech_sample_variations"] + _gs = _y["generate_speech_samples"] + _abn = _y["add_background_noise"] + _amn = _y["add_microphone_noise"] + _csm = _y["create_set_manifests"] + +INPUT_TOKEN_LENGTH: int = _p["input_token_length"] +N_MELS: int = _p["n_mels"] +TIME_STEPS: int = _p["time_steps"] +BATCH_SIZE: int = _p["batch_size"] +EPOCHS: int = _p["epochs"] +SUBSAMPLE_RATE: int = _p["subsample_rate"] + +PREFIX_DELAY_FREQUENCY: int = _d["prefix_frequency"] +SUFFIX_DELAY_FREQUENCY: int = _d["suffix_frequency"] +MAX_DELAY_DURATION: float = _d["max_duration"] +MIN_DELAY_DURATION: float = _d["min_duration"] + +MIN_SPEECH_RATE: int = _gssv["min_speech_rate"] +MAX_SPEECH_RATE: int = _gssv["max_speech_rate"] + +MAX_RETRIES: int = _gs["max_retries"] +RETRY_DELAY_SEC: int = _gs["retry_delay_sec"] + +BACKGROUND_NOISE_FREQUENCY: int = _abn["noise_frequency"] +BACKGROUND_NOISE_VOLUME_MIN: float = _abn["volume_min"] +BACKGROUND_NOISE_VOLUME_MAX: float = _abn["volume_max"] + +MICROPHONE_NOISE_FREQUENCY: int = _amn["noise_frequency"] +MICROPHONE_NOISE_VOLUME_MIN: float = _amn["volume_min"] +MICROPHONE_NOISE_VOLUME_MAX: float = _amn["volume_max"] +MICROPHONE_NOISE_TYPE: str = _amn["noise_type"] + +TRAINING_SET_PERCENTAGE: int = _csm["training_set_percentage"] +VALIDATION_SET_PERCENTAGE: int = _csm["validation_set_percentage"] +TEST_SET_PERCENTAGE: int = _csm["test_set_percentage"] \ No newline at end of file diff --git a/ml/scripts/shared/ctc_utils.py b/ml/scripts/shared/ctc_utils.py new file mode 100644 index 00000000..9829d1bf --- /dev/null +++ b/ml/scripts/shared/ctc_utils.py @@ -0,0 +1,43 @@ +""" +CTC decoding utilities shared between 09_evaluate_model.py and +10_evaluate_test_samples.py. Both scripts previously contained identical +copies of these functions. +""" +import numpy as np + + +def ctc_greedy_decode(pred, blank: int) -> list[list[int]]: + """Greedy CTC decode. + + Args: + pred: array of shape (batch, time, classes) + blank: index of the CTC blank token + + Returns: + List of decoded index lists (one per batch item). + """ + pred_ids = np.argmax(pred, axis=-1) + decoded = [] + for seq in pred_ids: + prev = blank + out: list[int] = [] + for idx in seq: + if idx != prev and idx != blank: + out.append(int(idx)) + prev = idx + decoded.append(out) + return decoded + + +def indices_to_words(indices: list[int], vocab_list: list[str]) -> list[str]: + return [vocab_list[i] if i < len(vocab_list) else '[UNK]' for i in indices] + + +def trim_at_blank(seq: np.ndarray, vocab_list: list[str]) -> np.ndarray: + """Trim sequence at first index that is out-of-vocab (>= len(vocab_list)).""" + idxs = np.where(seq >= len(vocab_list))[0] + return seq[:idxs[0]] if len(idxs) > 0 else seq + + +def tokens_to_text(tokens: list[int], vocab_list: list[str]) -> str: + return ''.join([vocab_list[idx] for idx in tokens if idx < len(vocab_list)]) diff --git a/ml/scripts/shared/io_utils.py b/ml/scripts/shared/io_utils.py new file mode 100644 index 00000000..4bde9886 --- /dev/null +++ b/ml/scripts/shared/io_utils.py @@ -0,0 +1,42 @@ +""" +Shared I/O helpers used by scripts 06, 07, 08, 09, 10. + +Centralising these patterns means a format change (e.g. npy → JSON for +spectrograms) only requires one code change instead of four. +""" +import json +from pathlib import Path + +import numpy as np +import pandas as pd + + +def write_token_list(output_dir: Path, wav_stem: str, phonemes: list, tokens: list) -> None: + out_path = output_dir / f"{wav_stem}_tokens.json" + with open(out_path, 'w', encoding='utf-8') as f: + json.dump({'phonemes': phonemes, 'tokens': tokens}, f) + + +def read_token_list(token_list_dir: Path, wav_stem: str) -> np.ndarray: + tokens_file = token_list_dir / f"{wav_stem}_tokens.json" + with open(tokens_file, 'r', encoding='utf-8') as f: + return np.array(json.load(f)['tokens'], dtype=np.int32) + + +def write_spectrogram(output_dir: Path, wav_stem: str, data: np.ndarray) -> None: + np.save(output_dir / f"{wav_stem}.npy", data) + + +def read_spectrogram(spectrogram_dir: Path, wav_stem: str) -> np.ndarray: + return np.load(spectrogram_dir / f"{wav_stem}.npy") + + +def read_phoneme_list(phoneme_list_path: Path) -> tuple[list[str], int]: + """Return (vocab_list, ctc_blank_idx). ctc_blank_idx = len(vocab_list).""" + with open(phoneme_list_path, 'r', encoding='utf-8') as f: + vocab_list = [line.strip() for line in f if line.strip()] + return vocab_list, len(vocab_list) + + +def read_manifest(manifest_path: Path) -> pd.DataFrame: + return pd.read_csv(manifest_path, encoding='utf-8') diff --git a/ml/scripts/shared/phoneme_utils.py b/ml/scripts/shared/phoneme_utils.py new file mode 100644 index 00000000..079c0268 --- /dev/null +++ b/ml/scripts/shared/phoneme_utils.py @@ -0,0 +1,38 @@ +""" +Phoneme dictionary helpers shared between 02_compute_vocab.py and tests. +""" +from pathlib import Path + + +def load_phoneme_dict(phoneme_dict_dir: Path) -> dict[str, list[str]]: + """Load CMU phoneme dictionary from *phoneme_dict_dir*/cmudict-0.7b. + + Stress numerals are stripped from each phoneme (e.g. ``AH0`` → ``AH``). + First entry wins on duplicate headwords. A custom entry for ``NETFLIX`` + is always added. + """ + phoneme_dict: dict[str, list[str]] = {} + dict_path = phoneme_dict_dir / "cmudict-0.7b" + with open(dict_path, encoding='latin-1') as f: + for line in f: + line = line.strip() + if not line or line.startswith("#") or line.startswith(";;;"): + continue + parts = line.split() + word = parts[0] + phonemes = [''.join(c for c in p if c.isalpha()) for p in parts[1:]] + if word not in phoneme_dict: + phoneme_dict[word] = phonemes + phoneme_dict['NETFLIX'] = ['N', 'EH', 'T', 'F', 'L', 'IH', 'K', 'S'] + return phoneme_dict + + +def build_trie(word_to_phonemes: dict[str, list[str]]) -> dict: + """Build a trie mapping phoneme sequences to words.""" + root: dict = {} + for word, phonemes in word_to_phonemes.items(): + node = root + for phoneme in phonemes: + node = node.setdefault(phoneme, {}) + node.setdefault("$", []).append(word) + return root diff --git a/ml/scripts/speech_to_text/01_generate_speech_samples.py b/ml/scripts/speech_to_text/01_generate_speech_samples.py new file mode 100644 index 00000000..38b07094 --- /dev/null +++ b/ml/scripts/speech_to_text/01_generate_speech_samples.py @@ -0,0 +1,74 @@ +import argparse +from pathlib import Path +import pandas as pd +import os +import asyncio +import edge_tts +from tqdm import tqdm + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) # adds ml/scripts/ to sys.path + +from shared.constants import MAX_RETRIES, RETRY_DELAY_SEC + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Generate speech samples from variations CSV.") +parser.add_argument('--input-file', type=Path, required=True, help='Path to the input CSV file (variations)') +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output speech samples') + + +def is_valid_audio_file(path: Path) -> bool: + """Return True if the file exists and has non-zero size.""" + return path.exists() and path.stat().st_size > 0 + + +async def generate_sample(idx, phrase, voice, speech_rate_str, output_path): + for attempt in range(MAX_RETRIES): + try: + communicate = edge_tts.Communicate( + text=phrase, + voice=voice, + rate=speech_rate_str, + ) + await communicate.save(str(output_path)) + if is_valid_audio_file(output_path): + return True + # File missing or empty — delete corrupt file before retry + if output_path.exists(): + output_path.unlink() + except Exception as e: + print(f"Attempt {attempt + 1}/{MAX_RETRIES} failed for index {idx}: {e}") + if attempt < MAX_RETRIES - 1: + await asyncio.sleep(RETRY_DELAY_SEC * (2 ** attempt)) + print(f"Failed after {MAX_RETRIES} attempts for index {idx}, skipping.") + return False + + +async def generate_samples(paths, phrases_df): + count = 0 + for idx in tqdm(range(0, len(phrases_df)), desc="Generating speech samples"): + output_path = paths.output_dir / phrases_df.iloc[idx]['sample_file_name'] + if is_valid_audio_file(output_path): + continue # Skip if a valid file already exists + phrase = phrases_df.iloc[idx]['phrase_to_speak'] + voice = phrases_df.iloc[idx]['voice'] + speech_rate_str = phrases_df.iloc[idx]['speech_rate'] + if await generate_sample(idx, phrase, voice, speech_rate_str, output_path): + count += 1 + return count + + +async def main(): + paths = parser.parse_args() + os.makedirs(paths.output_dir, exist_ok=True) + + # Load the phrases from CSV + phrases_df = pd.read_csv(paths.input_file, encoding='utf-8') + + print("Starting sample generation...") + total_generated = await generate_samples(paths, phrases_df) + print(f"Sample generation completed. Total samples generated: {total_generated}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/ml/scripts/speech_to_text/01a_generate_speech_sample_variations.py b/ml/scripts/speech_to_text/01a_generate_speech_sample_variations.py new file mode 100644 index 00000000..e7b3a45a --- /dev/null +++ b/ml/scripts/speech_to_text/01a_generate_speech_sample_variations.py @@ -0,0 +1,114 @@ +import argparse +from collections import defaultdict +from pathlib import Path +import pandas as pd +import random +import os +import asyncio +import edge_tts +from tqdm import tqdm +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) # adds ml/scripts/ to sys.path + +from shared.constants import MIN_SPEECH_RATE, MAX_SPEECH_RATE, SUBSAMPLE_RATE + +# Read file and directory paths from command line arguments +parser = argparse.ArgumentParser(description="Generate speech sample variations.") +parser.add_argument('--input-file', type=Path, required=True, help='Path to the input CSV file') +parser.add_argument('--output-file', type=Path, required=True, help='Path to the output CSV file') +parser.add_argument('--samples-dir', type=Path, required=True, help='Directory for speech samples') + +# Functions +async def get_voices(): + voices = await edge_tts.list_voices() + # Filter for female voices and exclude problematic ones + female_voices = [ + v for v in voices + if v['Gender'] == 'Female' + and v['Locale'] == 'en-US' + and ':' not in v['ShortName'] + and 'DragonHD' not in v['ShortName'] + and 'Turbo' not in v['ShortName'] + ] + print(f"Sample female voices: {[v['ShortName'] for v in female_voices[:5]]}") + print(f"Total female voices found: {len(female_voices)}") + return female_voices + + +async def generate_variation_records(paths, phrases, labels, speech_to_detect, existing_records, voices: list): + """Generate variation records, reusing existing records where the phrase already exists. + + Bug fix: the previous implementation used list.index() + list.pop() to match existing + records by phrase. When a phrase appeared multiple times (different voice/rate variations), + only the first row was ever matched; subsequent rows were silently lost, and their .wav + files were then deleted as "obsolete". + + Fix: build a dict[phrase -> list[record]] so every existing row for a phrase is preserved + in FIFO order, and we pop the front entry for each occurrence of that phrase. + """ + records = [] + if not voices: + print("No voices available.") + return records + + # Build a lookup: phrase_to_speak -> [record, record, ...] (preserves all rows) + existing_by_phrase: dict[str, list[dict]] = defaultdict(list) + for rec in existing_records: + existing_by_phrase[rec['phrase_to_speak']].append(rec) + + for idx, (phrase, label, speech) in enumerate(tqdm(zip(phrases, labels, speech_to_detect), desc="Generating variations", total=len(phrases))): + if idx % SUBSAMPLE_RATE != 0: + continue + if existing_by_phrase.get(phrase): + records.append(existing_by_phrase[phrase].pop(0)) + continue + voice = random.choice(voices)['ShortName'] + speech_rate = random.randint(MIN_SPEECH_RATE, MAX_SPEECH_RATE) + speech_rate_str = f"+{speech_rate}%" if speech_rate >= 0 else f"{speech_rate}%" + records.append({ + 'phrase_to_speak': phrase, + 'phrase_to_detect': speech, + 'voice': voice, + 'speech_rate': speech_rate_str, + 'sample_file_name': f"{label}_{idx}_{voice}_r{speech_rate + 100}.wav", + }) + return records + + +async def main(): + paths = parser.parse_args() + + os.makedirs(paths.output_file.parent, exist_ok=True) + + phrases_df = pd.read_csv(paths.input_file, encoding='utf-8') + phrases = phrases_df['surface_form'].tolist() + labels = phrases_df['canonical_label'].tolist() + speech_to_detect = phrases_df['speech_to_detect'].tolist() + + # Load the existing records if the file exists + try: + existing_df = pd.read_csv(paths.output_file, encoding='utf-8') + existing_records = existing_df.to_dict(orient='records') + print(f"Loaded {len(existing_records)} existing variation records from {paths.output_file}...") + except FileNotFoundError: + print(f"Did not find existing {paths.output_file}.") + existing_records = [] + + print("Fetching available voices...") + voices = await get_voices() + print("Generating variation records...") + variation_records = await generate_variation_records(paths, phrases, labels, speech_to_detect, existing_records, voices) + print(f"Saving variation records to {paths.output_file}...") + variations_df = pd.DataFrame(variation_records) + variations_df.to_csv(paths.output_file, index=False, encoding='utf-8') + print("Deleting obsolete speech samples.") + # Remove files from samples_dir that are not in the new variations_df + if paths.samples_dir.exists(): + for file in paths.samples_dir.iterdir(): + if not variations_df["sample_file_name"].eq(file.name).any(): + print(f"Deleting obsolete sample file: {file.name}") + file.unlink() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/ml/scripts/speech_to_text/02_add_delays.py b/ml/scripts/speech_to_text/02_add_delays.py new file mode 100644 index 00000000..0ca1bb9a --- /dev/null +++ b/ml/scripts/speech_to_text/02_add_delays.py @@ -0,0 +1,46 @@ +import argparse +import os +from pathlib import Path +import pandas as pd +import soundfile as sf +import numpy as np +from tqdm import tqdm + +# Settings +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Add random delays to audio samples.") +parser.add_argument('--input-file', type=Path, required=True, help='Input wav file list (CSV) with delay variations') +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output wav files') +paths = parser.parse_args() + +os.makedirs(paths.output_dir, exist_ok=True) + +# Read the CSV file with delay variations +df = pd.read_csv(paths.input_file) + +for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing audio files", unit="file"): + file_path = Path(row['input_file_path']) + + data, samplerate = sf.read(file_path) + new_data = data + prefix_delay = 0.0 + suffix_delay = 0.0 + + # Add prefix silence + if row['prefix_delay_seconds'] > 0.0: + prefix_delay = row['prefix_delay_seconds'] + num_prefix_samples = int(prefix_delay * samplerate) + silence_prefix = np.zeros((num_prefix_samples, data.shape[1]) if data.ndim > 1 else num_prefix_samples, dtype=data.dtype) + new_data = np.concatenate([silence_prefix, new_data], axis=0) + + # Add suffix silence + if row['suffix_delay_seconds'] > 0.0: + suffix_delay = row['suffix_delay_seconds'] + num_suffix_samples = int(suffix_delay * samplerate) + silence_suffix = np.zeros((num_suffix_samples, data.shape[1]) if data.ndim > 1 else num_suffix_samples, dtype=data.dtype) + new_data = np.concatenate([new_data, silence_suffix], axis=0) + + out_path = paths.output_dir / row['new_file_name'] + + sf.write(out_path, new_data, samplerate) + diff --git a/ml/scripts/speech_to_text/02a_randomize_delay_variations.py b/ml/scripts/speech_to_text/02a_randomize_delay_variations.py new file mode 100644 index 00000000..842f93a4 --- /dev/null +++ b/ml/scripts/speech_to_text/02a_randomize_delay_variations.py @@ -0,0 +1,45 @@ +import argparse +from pathlib import Path +import pandas as pd +import random +import os + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) # adds ml/scripts/ to sys.path +from shared.constants import PREFIX_DELAY_FREQUENCY, SUFFIX_DELAY_FREQUENCY, MAX_DELAY_DURATION, MIN_DELAY_DURATION + +# Read file and directory paths from command line arguments +parser = argparse.ArgumentParser(description="Generate speech sample variations.") +parser.add_argument('--input-dir', type=Path, required=True, help='Path to the input directory containing speech samples') +parser.add_argument('--output-file', type=Path, required=True, help='Path to the output CSV file containing random delay values') +paths = parser.parse_args() + +os.makedirs(paths.output_file.parent, exist_ok=True) + +records = [] +for file_path in paths.input_dir.glob('*.wav'): + stem = file_path.stem + + if random.randint(1, PREFIX_DELAY_FREQUENCY) == 1: + prefix_delay = random.uniform(MIN_DELAY_DURATION, MAX_DELAY_DURATION) + stem = f"{stem}_pre{int(prefix_delay * 1000):04d}" + else: + prefix_delay = 0.0 + + if random.randint(1, SUFFIX_DELAY_FREQUENCY) == 1: + suffix_delay = random.uniform(MIN_DELAY_DURATION, MAX_DELAY_DURATION) + stem = f"{stem}_suf{int(suffix_delay * 1000):04d}" + else: + suffix_delay = 0.0 + + records.append({ + 'input_file_path': str(file_path), + 'prefix_delay_seconds': prefix_delay, + 'suffix_delay_seconds': suffix_delay, + 'new_file_name': f"{stem}{file_path.suffix}" + }) + +# Save to CSV +df = pd.DataFrame.from_records(records) +df.to_csv(paths.output_file, index=False, encoding='utf-8') +print(f"Delay variations saved to {paths.output_file}") \ No newline at end of file diff --git a/ml/scripts/speech_to_text/03_add_background_noise.py b/ml/scripts/speech_to_text/03_add_background_noise.py new file mode 100644 index 00000000..bf0f0e85 --- /dev/null +++ b/ml/scripts/speech_to_text/03_add_background_noise.py @@ -0,0 +1,74 @@ +import argparse +import os +from pathlib import Path +import random +import numpy as np +import soundfile as sf +from tqdm import tqdm +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) # adds ml/scripts/ to sys.path + +from shared.constants import BACKGROUND_NOISE_FREQUENCY, BACKGROUND_NOISE_VOLUME_MIN, BACKGROUND_NOISE_VOLUME_MAX + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Add background noise to audio samples.") +parser.add_argument('--input-dir', type=Path, required=True, help='Directory containing input wav files') +parser.add_argument('--noise-dir', type=Path, required=True, help='Directory containing noise wav files') +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output wav files') + + +def get_random_noise(noise_files, length, sr): + noise_file = random.choice(noise_files) + noise, noise_sr = sf.read(noise_file) + if len(noise.shape) > 1: + noise = noise[:,0] # Use first channel if stereo + if noise_sr != sr: + # Resample noise to match target sample rate + num_samples = int(len(noise) * sr / noise_sr) + indices = np.linspace(0, len(noise) - 1, num_samples).astype(int) + noise = noise[indices] + if len(noise) < length: + # Loop noise if too short + repeats = int(np.ceil(length / len(noise))) + noise = np.tile(noise, repeats) + max_start = max(0, len(noise) - length) + start = random.randint(0, max_start) + return noise[start:start+length] + + +def add_noise_to_audio(audio, noise, volume): + return audio + noise * volume + + +def main(paths): + os.makedirs(paths.output_dir, exist_ok=True) + noise_files = list(paths.noise_dir.glob("*.wav")) + if not noise_files: + print(f"No noise samples found in {paths.noise_dir}") + return + input_files = list(paths.input_dir.glob("*.wav")) + for input_file in tqdm(input_files, desc="Processing audio files", unit="file", total=len(input_files)): + stem = input_file.stem + audio, sr = sf.read(input_file) + if len(audio.shape) > 1: + audio = audio[:,0] # Use first channel if stereo + add_noise = (random.randint(1, BACKGROUND_NOISE_FREQUENCY) == 1) + if add_noise: + noise = get_random_noise(noise_files, len(audio), sr) + volume = random.uniform(BACKGROUND_NOISE_VOLUME_MIN, BACKGROUND_NOISE_VOLUME_MAX) + audio_noisy = add_noise_to_audio(audio, noise, volume) + # Clip to [-1,1] to avoid overflow + audio_noisy = np.clip(audio_noisy, -1.0, 1.0) + out_audio = audio_noisy + # Modify filename to include _bg{volume} without leading '0.' + volume_str = f"{int(volume * 1000):03d}" + out_filename = f"{stem}_bg{volume_str}.wav" + else: + out_audio = audio + out_filename = input_file.name + out_path = paths.output_dir / out_filename + sf.write(out_path, out_audio, sr) + + +if __name__ == "__main__": + main(parser.parse_args()) diff --git a/ml/scripts/speech_to_text/03a_download_background_noise.py b/ml/scripts/speech_to_text/03a_download_background_noise.py new file mode 100644 index 00000000..215cbcc4 --- /dev/null +++ b/ml/scripts/speech_to_text/03a_download_background_noise.py @@ -0,0 +1,36 @@ +import argparse +import os +from pathlib import Path +import requests + +# Settings +# Noise samples +noise_samples = { + "creative-background-short-ver.wav": "https://cdn.freesound.org/sounds/721/721949-a0b57121-2a03-4dac-97c0-ee15fc5db207?filename=721949__audiocoffee__creative-background-short-ver.wav", + "trailer.wav": "https://cdn.freesound.org/sounds/785/785516-53995c18-2299-49bc-b042-357c8cb919fd?filename=785516__litesaturation__trailer.wav", + "tv-chatter.wav": "https://cdn.freesound.org/sounds/765/765157-8a98bb7d-6d3d-4869-af6c-4ba18aaddf27?filename=765157__mieckevanhoek__tv-chatter.wav", + "tv-news-loop.wav": "https://cdn.freesound.org/sounds/468/468539-e433c8eb-7f21-467d-9910-a37f4738c868?filename=468539__sergequadrado__tv-news-loop.wav", + "tv-recording-of-a-handball-match-3.wav": "https://cdn.freesound.org/sounds/786/786263-6ef16c1d-183a-4143-beca-6b9528e9cdb5?filename=786263__king_anna__tv-recording-of-a-handball-match-3.wav", +} + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Download background noise samples.") +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for downloaded noise wav files') +paths = parser.parse_args() + +os.makedirs(paths.output_dir, exist_ok=True) + +# Download noise samples +for filename, url in noise_samples.items(): + output_path = paths.output_dir / filename + if output_path.exists(): + print(f"File {output_path} already exists. Skipping download.") + continue + print(f"Downloading {filename} from {url}...") + response = requests.get(url, timeout=30, stream=True) + response.raise_for_status() + with open(output_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + print(f"Saved to {output_path}") + diff --git a/ml/scripts/speech_to_text/04_add_microphone_noise.py b/ml/scripts/speech_to_text/04_add_microphone_noise.py new file mode 100644 index 00000000..521a8380 --- /dev/null +++ b/ml/scripts/speech_to_text/04_add_microphone_noise.py @@ -0,0 +1,49 @@ +import argparse +import os +from pathlib import Path +import random +import numpy as np +import soundfile as sf +from tqdm import tqdm +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) # adds ml/scripts/ to sys.path + +from shared.constants import ( + MICROPHONE_NOISE_FREQUENCY, MICROPHONE_NOISE_VOLUME_MIN, MICROPHONE_NOISE_VOLUME_MAX, + MICROPHONE_NOISE_TYPE, +) + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Add microphone noise to audio samples.") +parser.add_argument('--input-dir', type=Path, required=True, help='Directory containing input wav files') +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output wav files') +paths = parser.parse_args() + +os.makedirs(paths.output_dir, exist_ok=True) + +# Process each audio file in the input directory +input_files = list(paths.input_dir.glob('*.wav')) +for file_path in tqdm(input_files, desc="Processing audio files", unit="file", total=len(input_files)): + stem = file_path.stem + # Decide randomly whether to add noise + add_noise = random.randint(1, MICROPHONE_NOISE_FREQUENCY) == 1 + data, samplerate = sf.read(file_path) + noise_volume = 0.0 + if add_noise: + # Random noise volume + noise_volume = random.uniform(MICROPHONE_NOISE_VOLUME_MIN, MICROPHONE_NOISE_VOLUME_MAX) + # Generate white noise + noise = np.random.normal(0, 1, data.shape) * noise_volume + data_noisy = data + noise + # Clip to valid range + data_noisy = np.clip(data_noisy, -1.0, 1.0) + # Modify filename to indicate noise + noise_str = f"_mic{int(noise_volume * 1000):03d}" + out_name = file_path.stem + noise_str + file_path.suffix + out_path = paths.output_dir / out_name + sf.write(out_path, data_noisy, samplerate) + else: + # Save original file without noise + out_path = paths.output_dir / file_path.name + sf.write(out_path, data, samplerate) + diff --git a/ml/scripts/speech_to_text/05_create_set_manifests.py b/ml/scripts/speech_to_text/05_create_set_manifests.py new file mode 100644 index 00000000..5c8cbe72 --- /dev/null +++ b/ml/scripts/speech_to_text/05_create_set_manifests.py @@ -0,0 +1,71 @@ + +import argparse +import os +from pathlib import Path +import random +import csv +import re +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) # adds ml/scripts/ to sys.path + +from shared.constants import TRAINING_SET_PERCENTAGE, VALIDATION_SET_PERCENTAGE, TEST_SET_PERCENTAGE + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Create set manifests for training, validation, and test sets.") +parser.add_argument('--input-manifest', type=Path, required=True, help='Path to input manifest CSV (training_data.csv)') +parser.add_argument('--clean-dir', type=Path, required=True, help='Directory with clean speech samples') +parser.add_argument('--noisy-dir', type=Path, required=True, help='Directory with noisy speech samples') +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output manifest files') +paths = parser.parse_args() + +os.makedirs(paths.output_dir, exist_ok=True) + +# Read surface_form from input manifest +surface_forms = [] +with open(paths.input_manifest, newline='', encoding='utf-8') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + surface_forms.append(row['speech_to_detect']) + +# Collect all files from clean and noisy dirs +all_files = [] +# for input_dir in [paths.clean_dir, paths.noisy_dir]: +for input_dir in [paths.clean_dir]: + for root, _, files in os.walk(input_dir): + for file in files: + file_path = Path(root) / file + all_files.append(str(file_path.resolve())) + +# Shuffle the list (seeded for reproducibility) +random.seed(42) +random.shuffle(all_files) + +total_files = len(all_files) +train_count = int(total_files * TRAINING_SET_PERCENTAGE / 100) +val_count = int(total_files * VALIDATION_SET_PERCENTAGE / 100) +test_count = total_files - train_count - val_count + +train_files = all_files[:train_count] +val_files = all_files[train_count:train_count+val_count] +test_files = all_files[train_count+val_count:] + + +# Helper to extract number between underscores +def extract_number_from_filename(filename): + match = re.search(r'_(\d+)_', filename) + if match: + return int(match.group(1)) + return None + +def write_manifest(file_list, manifest_path): + with open(manifest_path, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['filepath', 'speech_to_detect']) + for f in file_list: + num = extract_number_from_filename(os.path.basename(f)) + surface_form = surface_forms[num] if num is not None and num < len(surface_forms) else '' + writer.writerow([f, surface_form]) + +write_manifest(train_files, paths.output_dir / 'train_manifest.csv') +write_manifest(val_files, paths.output_dir / 'val_manifest.csv') +write_manifest(test_files, paths.output_dir / 'test_manifest.csv') diff --git a/ml/scripts/speech_to_text/06_compute_token_lists.py b/ml/scripts/speech_to_text/06_compute_token_lists.py new file mode 100644 index 00000000..802b7a57 --- /dev/null +++ b/ml/scripts/speech_to_text/06_compute_token_lists.py @@ -0,0 +1,69 @@ +import argparse +import json +from pathlib import Path +import os +import pandas as pd +from tqdm import tqdm +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) # adds ml/scripts/ to sys.path + +from shared.constants import INPUT_TOKEN_LENGTH +from shared.io_utils import write_token_list, read_phoneme_list, read_manifest + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Compute token lists for audio files.") + +parser.add_argument('--train-manifest', type=Path, required=True, help='Path to train_manifest.csv') +parser.add_argument('--eval-manifest', type=Path, required=True, help='Path to eval_manifest.csv') +parser.add_argument('--test-manifest', type=Path, required=True, help='Path to test_manifest.csv') +parser.add_argument('--phoneme-list', type=Path, required=True, help='Path to phoneme_list.txt') +parser.add_argument('--words-to-phonemes', type=Path, required=True, help='Path to words_to_phonemes.json') +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output token JSON files') + + +def compute_tokens(phoneme_list, words_to_phonemes, transcription): + # Convert transcription to uppercase and split into words + words = transcription.upper().replace(',', ' ').split() + phonemes = [] + for word in words: + if word in words_to_phonemes: + phonemes.extend(words_to_phonemes[word]) + # If word not in mapping, skip it (could log warning if desired) + # Convert phonemes to indices + tokens = [phoneme_list.index(ph) for ph in phonemes if ph in phoneme_list] + tokens = tokens + [len(phoneme_list)] * (INPUT_TOKEN_LENGTH - len(tokens)) if len(tokens) < INPUT_TOKEN_LENGTH else tokens[:INPUT_TOKEN_LENGTH] + return phonemes, tokens + + +def load_rows_from_manifest(manifest_path): + df = read_manifest(manifest_path) + print(f"Loaded {len(df)} samples from {manifest_path}.") + return [(row['filepath'], row['speech_to_detect']) for _, row in df.iterrows()] + + +if __name__ == "__main__": + paths = parser.parse_args() + + os.makedirs(paths.output_dir, exist_ok=True) + + # Read the phoneme list from file + phoneme_list, pad_value = read_phoneme_list(paths.phoneme_list) + print(f"Loaded phoneme list with {len(phoneme_list)} entries.") + + # Read the words-to-phonemes mapping from JSON + with open(paths.words_to_phonemes, 'r', encoding='utf-8') as w2p_file: + words_to_phonemes = json.load(w2p_file) + + manifest_rows = ( + load_rows_from_manifest(paths.train_manifest) + + load_rows_from_manifest(paths.eval_manifest) + + load_rows_from_manifest(paths.test_manifest) + ) + + for wav_path, transcription in tqdm(manifest_rows, desc="Computing token lists from manifests", total=len(manifest_rows)): + wav_path = Path(wav_path) + try: + phonemes, tokens = compute_tokens(phoneme_list, words_to_phonemes, transcription) + write_token_list(paths.output_dir, wav_path.stem, phonemes, tokens) + except Exception as e: + print(f'Error processing {wav_path}: {e}') diff --git a/ml/scripts/speech_to_text/07_compute_spectrograms.py b/ml/scripts/speech_to_text/07_compute_spectrograms.py new file mode 100644 index 00000000..21146bde --- /dev/null +++ b/ml/scripts/speech_to_text/07_compute_spectrograms.py @@ -0,0 +1,61 @@ +import argparse +from pathlib import Path +import os +import numpy as np +import librosa +import soundfile as sf +from tqdm import tqdm +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) # adds ml/scripts/ to sys.path + +from shared.constants import TIME_STEPS, N_MELS +from shared.io_utils import write_spectrogram, read_manifest + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Compute log-mel spectrograms for audio files.") +parser.add_argument('--train-manifest', type=Path, required=True, help='Path to train_manifest.csv') +parser.add_argument('--eval-manifest', type=Path, required=True, help='Path to eval_manifest.csv') +parser.add_argument('--test-manifest', type=Path, required=True, help='Path to test_manifest.csv') +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output spectrogram npy files') + + +def compute_melspectrogram(wav_path, time_steps=TIME_STEPS, n_mels=N_MELS): + y, sr = sf.read(str(wav_path)) + # If stereo, convert to mono (average channels) + if y.ndim > 1: + y = np.mean(y, axis=1) + S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels) + log_S = librosa.power_to_db(S, ref=np.max) + if log_S.shape[1] < time_steps: + pad_width = time_steps - log_S.shape[1] + log_S = np.pad(log_S, ((0, 0), (0, pad_width)), mode='constant') + else: + print(f'Warning: Truncating spectrogram for {wav_path}, has {log_S.shape[1]}>{time_steps} time steps.') + log_S = log_S[:, :time_steps] + return log_S + + +def load_rows_from_manifest(manifest_path): + df = read_manifest(manifest_path) + print(f"Loaded {len(df)} samples from {manifest_path}.") + return [(row['filepath'], row['speech_to_detect']) for _, row in df.iterrows()] + + +if __name__ == "__main__": + paths = parser.parse_args() + + os.makedirs(paths.output_dir, exist_ok=True) + + manifest_rows = ( + load_rows_from_manifest(paths.train_manifest) + + load_rows_from_manifest(paths.eval_manifest) + + load_rows_from_manifest(paths.test_manifest) + ) + + for wav_path, transcription in tqdm(manifest_rows, desc="Computing spectrograms from manifests", total=len(manifest_rows)): + wav_path = Path(wav_path) + try: + log_S = compute_melspectrogram(wav_path) + write_spectrogram(paths.output_dir, wav_path.stem, log_S) + except Exception as e: + print(f'Error processing {wav_path}: {e}') diff --git a/ml/scripts/speech_to_text/08_train_model.py b/ml/scripts/speech_to_text/08_train_model.py new file mode 100644 index 00000000..df3582d4 --- /dev/null +++ b/ml/scripts/speech_to_text/08_train_model.py @@ -0,0 +1,94 @@ +import argparse +from pathlib import Path +import os +import numpy as np +from tqdm import tqdm +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) # adds ml/scripts/ to sys.path + +from shared.constants import INPUT_TOKEN_LENGTH, N_MELS, TIME_STEPS, EPOCHS, BATCH_SIZE +from shared.io_utils import read_spectrogram, read_token_list, read_phoneme_list, read_manifest + +print("Initializing TensorFlow...") +import tensorflow as tf +from tensorflow.keras import layers, Model, Input + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Train speech-to-text model.") +parser.add_argument('--manifest', type=Path, required=True, help='Path to train_manifest.csv') +parser.add_argument('--vocab', type=Path, required=True, help='Path to vocab_list.txt') +parser.add_argument('--spectrogram-dir', type=Path, required=True, help='Directory with spectrogram npy files') +parser.add_argument('--token-list-dir', type=Path, required=True, help='Directory with token list JSON files') +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for output model') + +if __name__ == "__main__": + paths = parser.parse_args() + + os.makedirs(paths.output_dir, exist_ok=True) + output_model_file = paths.output_dir / "speech_to_text_model.keras" + + # Read the sample file names from manifest + training_set = read_manifest(paths.manifest) + print(f"Loaded {len(training_set)} training samples from manifest.") + + # Build the model + vocab_list, ctc_blank_idx = read_phoneme_list(paths.vocab) + num_classes = len(vocab_list) + 1 # +1 for CTC blank token + print(f"Vocabulary size: {len(vocab_list)}, Number of classes (with CTC blank): {num_classes}, CTC blank index: {ctc_blank_idx}") + + input_layer = Input(shape=(N_MELS, TIME_STEPS), name='input') + x = layers.Reshape((N_MELS, TIME_STEPS, 1))(input_layer) + x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x) + x = layers.BatchNormalization()(x) + x = layers.MaxPooling2D(pool_size=(2, 2))(x) + x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x) + x = layers.BatchNormalization()(x) + x = layers.MaxPooling2D(pool_size=(2, 2))(x) + new_time_steps = x.shape[1] + new_features = x.shape[2] * x.shape[3] + x = layers.Reshape((new_time_steps, new_features))(x) + x = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(x) + output_layer = layers.Dense(num_classes, activation='softmax', name='output')(x) + model = Model(inputs=input_layer, outputs=output_layer) + model.summary() + + # Compile model with dummy loss (real loss in Lambda layer) + model.compile(optimizer='adam', loss='categorical_crossentropy') # placeholder + print("Model compiled.") + + x_train = [] + y_train = [] + + # Prepare input/output pairs for training + for _, row in tqdm(training_set.iterrows(), total=len(training_set), desc="Loading training data"): + wav_filename = Path(row['filepath']).stem + x_train.append(read_spectrogram(paths.spectrogram_dir, wav_filename)) + y_train.append(read_token_list(paths.token_list_dir, wav_filename)) + + print(f"Prepared {len(x_train)} input-output pairs for training.") + + train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\ + .batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE) + + history = [] + for epoch in range(EPOCHS): + epoch_loss = [] + for batch in tqdm(train_dataset, desc=f"Epoch {epoch+1}/{EPOCHS}"): + x_batch, y_batch = batch + with tf.GradientTape() as tape: + y_pred = model(x_batch, training=True) + # Compute prediction lengths (time steps of y_pred) + pred_len = tf.fill([tf.shape(y_pred)[0], 1], tf.shape(y_pred)[1]) + # Compute true label lengths by counting non-padding tokens (assumes 0 is padding) + lbl_len = tf.math.count_nonzero(y_batch, axis=1, dtype=tf.int32) + lbl_len_reshaped = tf.expand_dims(lbl_len, axis=1) + loss = tf.keras.backend.ctc_batch_cost(y_batch, y_pred, pred_len, lbl_len_reshaped) + grads = tape.gradient(loss, model.trainable_variables) + model.optimizer.apply_gradients(zip(grads, model.trainable_variables)) + epoch_loss.append(tf.reduce_mean(loss).numpy()) + print(f'Epoch {epoch+1}/{EPOCHS} - Loss: {np.mean(epoch_loss):.4f}') + history.append(np.mean(epoch_loss)) + + # Save the trained model in Keras format + model.save(output_model_file) + print(f"Model saved to {output_model_file}") diff --git a/ml/scripts/speech_to_text/09_evaluate_model.py b/ml/scripts/speech_to_text/09_evaluate_model.py new file mode 100644 index 00000000..d113acfe --- /dev/null +++ b/ml/scripts/speech_to_text/09_evaluate_model.py @@ -0,0 +1,84 @@ +import argparse +from pathlib import Path +import os +import numpy as np +from tqdm import tqdm +from jiwer import wer +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) # adds ml/scripts/ to sys.path + +from shared.constants import BATCH_SIZE +from shared.ctc_utils import ctc_greedy_decode, indices_to_words, trim_at_blank +from shared.io_utils import read_spectrogram, read_token_list, read_phoneme_list, read_manifest + +print("Initializing TensorFlow...") +import tensorflow as tf + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Evaluate speech-to-text model.") +parser.add_argument('--manifest', type=Path, required=True, help='Path to val_manifest.csv') +parser.add_argument('--model', type=Path, required=True, help='Path to model file (speech_to_text_model.keras)') +parser.add_argument('--vocab', type=Path, required=True, help='Path to vocab_list.txt') +parser.add_argument('--spectrogram-dir', type=Path, required=True, help='Directory with spectrogram npy files') +parser.add_argument('--token-list-dir', type=Path, required=True, help='Directory with token list JSON files') +parser.add_argument('--output-dir', type=Path, required=True, help='Directory for evaluation results (predictions and metrics)') + +if __name__ == "__main__": + paths = parser.parse_args() + + os.makedirs(paths.output_dir, exist_ok=True) + + # Read the sample file names from manifest + eval_set = read_manifest(paths.manifest) + print(f"Loaded {len(eval_set)} evaluation samples from manifest.") + + # Prepare input/output pairs for evaluation + x_eval = [] + y_eval = [] + for _, row in tqdm(eval_set.iterrows(), total=len(eval_set), desc="Loading evaluation data"): + wav_filename = Path(row['filepath']).stem + x_eval.append(read_spectrogram(paths.spectrogram_dir, wav_filename)) + y_eval.append(read_token_list(paths.token_list_dir, wav_filename)) + + # Load the trained model + print("Loading speech-to-text model...") + model = tf.keras.models.load_model(paths.model) + print(f"Loaded {paths.model}") + + # Load the vocabulary list from vocab file + vocab_list, ctc_blank_idx = read_phoneme_list(paths.vocab) + print( + f"Vocabulary size: {len(vocab_list)}, " + f"Number of classes (with CTC blank): {len(vocab_list) + 1}, " + f"CTC blank index: {ctc_blank_idx}" + ) + + eval_dataset = tf.data.Dataset.from_tensor_slices((x_eval, y_eval))\ + .batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE) + + # Evaluate on eval set + all_preds = [] + for batch, _ in eval_dataset: + pred = model.predict(batch) + all_preds.extend(ctc_greedy_decode(pred, blank=ctc_blank_idx)) + + # Compute WER + refs = [' '.join(indices_to_words(trim_at_blank(seq, vocab_list), vocab_list)) for seq in y_eval] + hyps = [' '.join(indices_to_words(seq, vocab_list)) for seq in all_preds] + wer_score = wer(refs, hyps) + print(f'WER: {wer_score:.3f}') + + # Show a few example predictions + for i in range(min(5, refs.__len__())): + print('REF:', refs[i]) + print('HYP:', hyps[i]) + print() + + # Save all predictions to a file in output dir + output_predictions_file = paths.output_dir / "evaluation_predictions.txt" + with open(output_predictions_file, 'w', encoding='utf-8') as f: + f.write(f'WER: {wer_score:.3f}\n\n') + for ref, hyp in zip(refs, hyps): + f.write(f'REF: {ref}\n') + f.write(f'HYP: {hyp}\n\n') + print(f"Saved evaluation predictions to {output_predictions_file}") diff --git a/ml/scripts/speech_to_text/10_evaluate_test_samples.py b/ml/scripts/speech_to_text/10_evaluate_test_samples.py new file mode 100644 index 00000000..2f1e54f3 --- /dev/null +++ b/ml/scripts/speech_to_text/10_evaluate_test_samples.py @@ -0,0 +1,75 @@ +import argparse +from pathlib import Path +import os +import numpy as np +from tqdm import tqdm +from zipfile import ZipFile +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) # adds ml/scripts/ to sys.path + +from shared.constants import BATCH_SIZE +from shared.ctc_utils import ctc_greedy_decode, tokens_to_text +from shared.io_utils import read_spectrogram, read_token_list, read_phoneme_list, read_manifest + +print("Initializing TensorFlow...") +import tensorflow as tf + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="Evaluate test samples and create ZIP of successfully recognized files.") +parser.add_argument('--manifest', type=Path, required=True, help='Path to test_manifest.csv') +parser.add_argument('--model', type=Path, required=True, help='Path to model file (speech_to_text_model.keras)') +parser.add_argument('--vocab', type=Path, required=True, help='Path to vocab_list.txt') +parser.add_argument('--spectrogram-dir', type=Path, required=True, help='Directory with spectrogram npy files') +parser.add_argument('--token-list-dir', type=Path, required=True, help='Directory with token list JSON files') +parser.add_argument('--output-zip', type=Path, required=True, help='Path for output zip file') + +if __name__ == "__main__": + paths = parser.parse_args() + + os.makedirs(paths.output_zip.parent, exist_ok=True) + + # Read the sample file names from manifest + eval_set = read_manifest(paths.manifest) + print(f"Loaded {len(eval_set)} evaluation samples from manifest.") + + # Prepare input/output pairs for evaluation + x_eval = [] + y_eval = [] + for _, row in tqdm(eval_set.iterrows(), total=len(eval_set), desc="Loading evaluation data"): + wav_filename = Path(row['filepath']).stem + x_eval.append(read_spectrogram(paths.spectrogram_dir, wav_filename)) + y_eval.append(read_token_list(paths.token_list_dir, wav_filename)) + + # Load the trained model + print("Loading speech-to-text model...") + model = tf.keras.models.load_model(paths.model) + print(f"Loaded {paths.model}") + + # Load the vocabulary list from vocab file + vocab_list, ctc_blank_idx = read_phoneme_list(paths.vocab) + print(f"Vocabulary size: {len(vocab_list)}, Number of classes (with CTC blank): {len(vocab_list) + 1}") + + eval_dataset = tf.data.Dataset.from_tensor_slices((x_eval, y_eval))\ + .batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE) + + # Evaluate on eval set + all_preds = [] + for batch, _ in eval_dataset: + pred = model.predict(batch) + all_preds.extend(ctc_greedy_decode(pred, blank=ctc_blank_idx)) + + # Collect files where predicted text matches reference + success_files = [] + for i, (pred_tokens, true_tokens) in enumerate(zip(all_preds, y_eval)): + pred_text = tokens_to_text(pred_tokens, vocab_list) + true_text = tokens_to_text(true_tokens, vocab_list) + if pred_text == true_text: + wav_path = eval_set.iloc[i]['filepath'] + success_files.append(wav_path) + + # Add successfully matched files to ZIP + with ZipFile(paths.output_zip, 'w') as zipf: + for file_path in success_files: + zipf.write(file_path, arcname=Path(file_path).name) + + print(f"Successfully matched and added {len(success_files)} files to {paths.output_zip}") diff --git a/ml/test/__init__.py b/ml/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ml/test/pipeline/__init__.py b/ml/test/pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ml/test/pipeline/core/__init__.py b/ml/test/pipeline/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ml/test/pipeline/core/test_manifest.py b/ml/test/pipeline/core/test_manifest.py new file mode 100644 index 00000000..4339db18 --- /dev/null +++ b/ml/test/pipeline/core/test_manifest.py @@ -0,0 +1,459 @@ +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from pipeline.core.sample import ( + AudioSample, + SampleSpectrogram, + SampleTokens, + TextSample, +) +from pipeline.core.manifest import Manifest, ManifestStore + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _text_sample() -> TextSample: + return TextSample( + seed=0, + content_hash="deadbeef", + content="okay turn on the tv", + label="TV_ON", + ) + + +def _audio_sample() -> AudioSample: + return AudioSample( + id="TV_ON_Jenny_r77", + seed=67890, + content_hash="audiohash", + path=Path("TV_ON_Jenny_r77.wav"), + parent_content_hash="texthash", + transcript="TV_ON", + applied_values={"voice": "en-US-JennyNeural", "speech_rate": 5}, + ) + + +def _spectrogram_sample() -> SampleSpectrogram: + return SampleSpectrogram( + id="TV_ON_Jenny_r77", + seed=0, + content_hash="spechash", + path=Path("TV_ON_Jenny_r77.npy"), + parent_content_hash="audiohash", + transcript="TV_ON", + parent_id="audio-uuid-123", + ) + + +def _tokens_sample() -> SampleTokens: + return SampleTokens( + id="TV_ON_Jenny_r77", + seed=0, + content_hash="tokenshash", + path=Path("TV_ON_Jenny_r77.json"), + parent_content_hash="audiohash", + transcript="TV_ON", + parent_id="audio-uuid-123", + ) + + +# --------------------------------------------------------------------------- +# TextSample round-trip +# --------------------------------------------------------------------------- + +class TestTextSampleRoundTrip: + def test_round_trip_preserves_all_fields(self, tmp_path: Path) -> None: + sample = _text_sample() + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([sample]), manifest_path) + + result = store.read(manifest_path) + + assert len(result.samples) == 1 + s = result.samples[0] + assert isinstance(s, TextSample) + assert s.id == "deadbeef" + assert s.seed == 0 + assert s.content_hash == "deadbeef" + assert s.content == "okay turn on the tv" + assert s.label == "TV_ON" + + def test_seed_serialised_as_zero(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_text_sample()]), manifest_path) + + raw = json.loads(manifest_path.read_text()) + + assert raw["samples"][0]["seed"] == 0 + + def test_json_omits_path_parent_content_hash_applied_values(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_text_sample()]), manifest_path) + + raw = json.loads(manifest_path.read_text()) + entry = raw["samples"][0] + + assert "path" not in entry + assert "parent_content_hash" not in entry + assert "applied_values" not in entry + + def test_json_schema_version_is_one(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_text_sample()]), manifest_path) + + raw = json.loads(manifest_path.read_text()) + + assert raw["version"] == 1 + assert raw["sample_type"] == "text" + + +# --------------------------------------------------------------------------- +# AudioSample round-trip +# --------------------------------------------------------------------------- + +class TestAudioSampleRoundTrip: + def test_round_trip_preserves_all_fields(self, tmp_path: Path) -> None: + sample = _audio_sample() + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([sample]), manifest_path) + + result = store.read(manifest_path) + + assert len(result.samples) == 1 + s = result.samples[0] + assert isinstance(s, AudioSample) + assert s.id == "TV_ON_Jenny_r77" + assert s.seed == 67890 + assert s.content_hash == "audiohash" + assert s.path == Path("TV_ON_Jenny_r77.wav") + assert s.parent_content_hash == "texthash" + assert s.transcript == "TV_ON" + assert s.applied_values == {"voice": "en-US-JennyNeural", "speech_rate": 5} + + def test_applied_values_int_type_preserved(self, tmp_path: Path) -> None: + sample = AudioSample( + id="id", + seed=1, + content_hash="h", + path=Path("id.wav"), + parent_content_hash="ph", + transcript="TV_ON", + applied_values={"speech_rate": 5}, + ) + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([sample]), manifest_path) + + result = store.read(manifest_path) + s = result.samples[0] + + assert isinstance(s.applied_values["speech_rate"], int) + + def test_applied_values_float_type_preserved(self, tmp_path: Path) -> None: + sample = AudioSample( + id="id", + seed=1, + content_hash="h", + path=Path("id.wav"), + parent_content_hash="ph", + transcript="TV_ON", + applied_values={"noise_volume": 0.45}, + ) + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([sample]), manifest_path) + + result = store.read(manifest_path) + s = result.samples[0] + + assert isinstance(s.applied_values["noise_volume"], float) + + def test_path_stored_as_bare_filename(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_audio_sample()]), manifest_path) + + raw = json.loads(manifest_path.read_text()) + # Must be filename only, no directory component + assert raw["samples"][0]["path"] == "TV_ON_Jenny_r77.wav" + + +# --------------------------------------------------------------------------- +# SampleSpectrogram round-trip +# --------------------------------------------------------------------------- + +class TestSampleSpectrogramRoundTrip: + def test_round_trip_preserves_all_fields(self, tmp_path: Path) -> None: + sample = _spectrogram_sample() + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([sample]), manifest_path) + + result = store.read(manifest_path) + + assert len(result.samples) == 1 + s = result.samples[0] + assert isinstance(s, SampleSpectrogram) + assert s.id == "TV_ON_Jenny_r77" + assert s.seed == 0 + assert s.content_hash == "spechash" + assert s.path == Path("TV_ON_Jenny_r77.npy") + assert s.parent_content_hash == "audiohash" + assert s.transcript == "TV_ON" + assert s.parent_id == "audio-uuid-123" + + def test_parent_id_serialised_in_json(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_spectrogram_sample()]), manifest_path) + + raw = json.loads(manifest_path.read_text()) + + assert raw["samples"][0]["parent_id"] == "audio-uuid-123" + + def test_applied_values_omitted_from_json(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_spectrogram_sample()]), manifest_path) + + raw = json.loads(manifest_path.read_text()) + + assert "applied_values" not in raw["samples"][0] + + def test_sample_type_is_spectrogram(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_spectrogram_sample()]), manifest_path) + + raw = json.loads(manifest_path.read_text()) + + assert raw["sample_type"] == "spectrogram" + + +# --------------------------------------------------------------------------- +# SampleTokens round-trip +# --------------------------------------------------------------------------- + +class TestSampleTokensRoundTrip: + def test_round_trip_preserves_all_fields(self, tmp_path: Path) -> None: + sample = _tokens_sample() + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([sample]), manifest_path) + + result = store.read(manifest_path) + + assert len(result.samples) == 1 + s = result.samples[0] + assert isinstance(s, SampleTokens) + assert s.id == "TV_ON_Jenny_r77" + assert s.seed == 0 + assert s.content_hash == "tokenshash" + assert s.path == Path("TV_ON_Jenny_r77.json") + assert s.parent_content_hash == "audiohash" + assert s.transcript == "TV_ON" + assert s.parent_id == "audio-uuid-123" + + def test_sample_type_is_tokens(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_tokens_sample()]), manifest_path) + + raw = json.loads(manifest_path.read_text()) + + assert raw["sample_type"] == "tokens" + + def test_applied_values_omitted_from_json(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_tokens_sample()]), manifest_path) + + raw = json.loads(manifest_path.read_text()) + + assert "applied_values" not in raw["samples"][0] + + +# --------------------------------------------------------------------------- +# ManifestStore.read() dispatch by sample_type +# --------------------------------------------------------------------------- + +class TestManifestStoreReadDispatch: + def test_reads_text_sample_type(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_text_sample()]), manifest_path) + + result = store.read(manifest_path) + + assert all(isinstance(s, TextSample) for s in result.samples) + + def test_reads_audio_sample_type(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_audio_sample()]), manifest_path) + + result = store.read(manifest_path) + + assert all(isinstance(s, AudioSample) for s in result.samples) + + def test_reads_spectrogram_sample_type(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_spectrogram_sample()]), manifest_path) + + result = store.read(manifest_path) + + assert all(isinstance(s, SampleSpectrogram) for s in result.samples) + + def test_reads_tokens_sample_type(self, tmp_path: Path) -> None: + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([_tokens_sample()]), manifest_path) + + result = store.read(manifest_path) + + assert all(isinstance(s, SampleTokens) for s in result.samples) + + def test_reads_multiple_samples(self, tmp_path: Path) -> None: + s1 = _audio_sample() + s2 = AudioSample( + id="TV_OFF_Jenny_r80", + seed=11111, + content_hash="hash2", + path=Path("TV_OFF_Jenny_r80.wav"), + parent_content_hash="texthash2", + transcript="TV_OFF", + applied_values={"voice": "en-US-JennyNeural", "speech_rate": -20}, + ) + store = ManifestStore() + manifest_path = tmp_path / "manifest.json" + store.write(Manifest([s1, s2]), manifest_path) + + result = store.read(manifest_path) + + assert len(result.samples) == 2 + + +# --------------------------------------------------------------------------- +# Manifest lookup methods +# --------------------------------------------------------------------------- + +class TestManifestLookup: + def test_by_content_hash_returns_matching_sample(self) -> None: + sample = _text_sample() + manifest = Manifest([sample]) + + result = manifest.by_content_hash("deadbeef") + + assert result is sample + + def test_by_content_hash_returns_none_when_not_found(self) -> None: + manifest = Manifest([_text_sample()]) + + result = manifest.by_content_hash("notfound") + + assert result is None + + def test_by_id_returns_matching_sample(self) -> None: + sample = _text_sample() + manifest = Manifest([sample]) + + result = manifest.by_id("deadbeef") + + assert result is sample + + def test_by_id_returns_none_when_not_found(self) -> None: + manifest = Manifest([_text_sample()]) + + result = manifest.by_id("notfound") + + assert result is None + + def test_samples_property_returns_tuple(self) -> None: + manifest = Manifest([_text_sample()]) + + assert isinstance(manifest.samples, tuple) + + def test_empty_manifest(self) -> None: + manifest = Manifest([]) + + assert manifest.samples == () + assert manifest.by_id("x") is None + assert manifest.by_content_hash("x") is None + + def test_write_empty_manifest_raises(self, tmp_path: Path) -> None: + with pytest.raises(ValueError, match="Cannot write empty manifest"): + ManifestStore().write(Manifest([]), tmp_path / "manifest.json") + + def test_duplicate_id_raises(self) -> None: + s1 = AudioSample( + id="same-id", + seed=1, + content_hash="hash1", + path=Path("f1.wav"), + parent_content_hash="ph", + transcript="T", + applied_values={}, + ) + s2 = AudioSample( + id="same-id", + seed=2, + content_hash="hash2", + path=Path("f2.wav"), + parent_content_hash="ph", + transcript="T", + applied_values={}, + ) + with pytest.raises(ValueError, match="duplicate sample ids"): + Manifest([s1, s2]) + + def test_duplicate_content_hash_keeps_first(self) -> None: + s1 = AudioSample( + id="id-first", + seed=0, + content_hash="shared-hash", + path=Path("f1.wav"), + parent_content_hash="ph", + transcript="T", + applied_values={}, + ) + s2 = AudioSample( + id="id-second", + seed=0, + content_hash="shared-hash", + path=Path("f2.wav"), + parent_content_hash="ph", + transcript="T", + applied_values={}, + ) + manifest = Manifest([s1, s2]) + + assert manifest.by_content_hash("shared-hash") is s1 + + def test_read_missing_sample_type_raises(self, tmp_path: Path) -> None: + manifest_path = tmp_path / "manifest.json" + manifest_path.write_text(json.dumps({"version": 1, "samples": []})) + with pytest.raises(ValueError, match="Missing 'sample_type'"): + ManifestStore().read(manifest_path) + + def test_write_mixed_sample_types_raises_correctly(self, tmp_path: Path) -> None: + text = _text_sample() + audio = _audio_sample() + # Bypass Manifest type-safety by injecting samples directly + m: Manifest = object.__new__(Manifest) + m._samples = (text, audio) # type: ignore[attr-defined] + m._by_content_hash = {text.content_hash: text, audio.content_hash: audio} # type: ignore[attr-defined] + m._by_id = {text.id: text, audio.id: audio} # type: ignore[attr-defined] + with pytest.raises(ValueError, match="mixed sample types"): + ManifestStore().write(m, tmp_path / "manifest.json") diff --git a/ml/test/pipeline/core/test_modifier_stage.py b/ml/test/pipeline/core/test_modifier_stage.py new file mode 100644 index 00000000..3b968070 --- /dev/null +++ b/ml/test/pipeline/core/test_modifier_stage.py @@ -0,0 +1,492 @@ +from __future__ import annotations + +import asyncio +import hashlib +import json +from pathlib import Path +from typing import Any, ClassVar + +import pytest + +from pipeline.core.manifest import Manifest, ManifestStore +from pipeline.core.modifier_stage import ModifierStage +from pipeline.core.randomization import VariationGenerator +from pipeline.core.sample import AudioSample, TextSample + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _content_hash(parent_content_hash: str, seed: int, applied_values: dict[str, Any]) -> str: + return ModifierStage._compute_content_hash(parent_content_hash, seed, applied_values) + + +def _text_sample(content: str = "turn on the tv", label: str = "TV_ON") -> TextSample: + h = hashlib.sha256(content.encode("utf-8")).hexdigest() + return TextSample(seed=0, content_hash=h, content=content, label=label) + + +def _audio_sample( + *, + id: str = "TV_ON_fake", + seed: int = 42, + parent_content_hash: str, + applied_values: dict[str, Any] | None = None, +) -> AudioSample: + av = applied_values or {} + ch = _content_hash(parent_content_hash, seed, av) + return AudioSample( + id=id, + seed=seed, + content_hash=ch, + path=Path(f"{id}.wav"), + parent_content_hash=parent_content_hash, + transcript="TV_ON", + applied_values=av, + ) + + +def _write_prev_manifest(path: Path, samples: list[AudioSample]) -> None: + ManifestStore().write(Manifest(samples), path) + + +class _FakeStage(ModifierStage[TextSample, AudioSample]): + """Minimal concrete subclass for testing ModifierStage logic.""" + + _is_deterministic: ClassVar[bool] = False + + def __init__( + self, + output_dir: Path, + manifest_store: ManifestStore, + *, + av_by_content_hash: dict[str, dict[str, Any]] | None = None, + ) -> None: + super().__init__(output_dir, manifest_store) + # per-sample applied_values; keyed on input_sample.content_hash + self._av_map: dict[str, dict[str, Any]] = av_by_content_hash or {} + self.generate_output_calls: list[tuple[str, int, dict[str, Any]]] = [] + self.derive_id_calls: list[tuple[str, dict[str, Any]]] = [] + + def _get_applied_values( + self, sample: TextSample, generator: VariationGenerator + ) -> dict[str, Any]: + return dict(self._av_map.get(sample.content_hash, {})) + + async def _generate_output( + self, + input_sample: TextSample, + output_id: str, + output_seed: int, + applied_values: dict[str, Any], + parent_content_hash: str, + ) -> AudioSample: + self.generate_output_calls.append((output_id, output_seed, applied_values)) + (self._output_dir / f"{output_id}.wav").write_bytes(b"audio") + return AudioSample( + id=output_id, + seed=output_seed, + content_hash=_content_hash(parent_content_hash, output_seed, applied_values), + path=Path(f"{output_id}.wav"), + parent_content_hash=parent_content_hash, + transcript=input_sample.label, + applied_values=applied_values, + ) + + def _derive_id( + self, input_sample: TextSample, applied_values: dict[str, Any] + ) -> str: + self.derive_id_calls.append((input_sample.content_hash, applied_values)) + return f"{input_sample.id}_fake" + + +class _DeterministicFakeStage(_FakeStage): + _is_deterministic: ClassVar[bool] = True + + +# --------------------------------------------------------------------------- +# TestComputeContentHash +# --------------------------------------------------------------------------- + +class TestComputeContentHash: + def test_known_value(self) -> None: + # Verify exact sha256 formula: sha256(parent + ":" + str(seed) + ":" + canonical(av)) + parent = "abc" + seed = 0 + av: dict[str, Any] = {} + canonical = json.dumps(av, sort_keys=True, separators=(",", ":"), ensure_ascii=True) + raw = f"{parent}:{seed}:{canonical}" + expected = hashlib.sha256(raw.encode("utf-8")).hexdigest() + + result = ModifierStage._compute_content_hash(parent, seed, av) + + assert result == expected + + def test_sort_keys_ensures_stability(self) -> None: + h1 = ModifierStage._compute_content_hash("p", 1, {"b": 2, "a": 1}) + h2 = ModifierStage._compute_content_hash("p", 1, {"a": 1, "b": 2}) + assert h1 == h2 + + def test_different_seeds_give_different_hashes(self) -> None: + h1 = ModifierStage._compute_content_hash("p", 1, {}) + h2 = ModifierStage._compute_content_hash("p", 2, {}) + assert h1 != h2 + + def test_different_parents_give_different_hashes(self) -> None: + h1 = ModifierStage._compute_content_hash("pA", 0, {}) + h2 = ModifierStage._compute_content_hash("pB", 0, {}) + assert h1 != h2 + + def test_different_applied_values_give_different_hashes(self) -> None: + h1 = ModifierStage._compute_content_hash("p", 0, {"x": 1}) + h2 = ModifierStage._compute_content_hash("p", 0, {"x": 2}) + assert h1 != h2 + + +# --------------------------------------------------------------------------- +# TestSkipPath +# --------------------------------------------------------------------------- + +class TestSkipPath: + def test_unchanged_sample_is_kept_verbatim(self, tmp_path: Path) -> None: + inp = _text_sample() + prev = _audio_sample(seed=77, parent_content_hash=inp.content_hash, applied_values={}) + _write_prev_manifest(tmp_path / "manifest.json", [prev]) + + stage = _FakeStage(tmp_path, ManifestStore()) + result = asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert len(result.samples) == 1 + out = result.samples[0] + assert out.id == prev.id + assert out.content_hash == prev.content_hash + assert out.seed == prev.seed + + def test_generate_output_not_called_on_skip(self, tmp_path: Path) -> None: + inp = _text_sample() + prev = _audio_sample(seed=77, parent_content_hash=inp.content_hash, applied_values={}) + _write_prev_manifest(tmp_path / "manifest.json", [prev]) + + stage = _FakeStage(tmp_path, ManifestStore()) + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert stage.generate_output_calls == [] + + def test_existing_output_file_not_deleted_on_skip(self, tmp_path: Path) -> None: + inp = _text_sample() + prev = _audio_sample(seed=77, parent_content_hash=inp.content_hash, applied_values={}) + _write_prev_manifest(tmp_path / "manifest.json", [prev]) + existing_file = tmp_path / prev.path.name + existing_file.write_bytes(b"audio") + + stage = _FakeStage(tmp_path, ManifestStore()) + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert existing_file.exists() + + def test_manifest_written_with_preserved_sample(self, tmp_path: Path) -> None: + inp = _text_sample() + prev = _audio_sample(seed=77, parent_content_hash=inp.content_hash, applied_values={}) + _write_prev_manifest(tmp_path / "manifest.json", [prev]) + + stage = _FakeStage(tmp_path, ManifestStore()) + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + written = ManifestStore().read(tmp_path / "manifest.json") + assert len(written.samples) == 1 + assert written.samples[0].id == prev.id + + +# --------------------------------------------------------------------------- +# TestRegenPath +# --------------------------------------------------------------------------- + +class TestRegenPath: + def _setup_regen( + self, tmp_path: Path, old_av: dict[str, Any], new_av: dict[str, Any] + ) -> tuple[TextSample, AudioSample, _FakeStage]: + """Creates prev manifest with old_av; stage returns new_av for the sample.""" + inp = _text_sample() + stored_seed = 99 + # content_hash was computed with old_av, so it won't match new_av → regen + prev = AudioSample( + id="TV_ON_old", + seed=stored_seed, + content_hash=_content_hash(inp.content_hash, stored_seed, old_av), + path=Path("TV_ON_old.wav"), + parent_content_hash=inp.content_hash, + transcript="TV_ON", + applied_values=old_av, + ) + _write_prev_manifest(tmp_path / "manifest.json", [prev]) + stage = _FakeStage( + tmp_path, + ManifestStore(), + av_by_content_hash={inp.content_hash: new_av}, + ) + return inp, prev, stage + + def test_regen_produces_new_id_via_derive_id(self, tmp_path: Path) -> None: + inp, prev, stage = self._setup_regen(tmp_path, {"x": 1}, {"x": 2}) + result = asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert len(result.samples) == 1 + assert result.samples[0].id != prev.id + assert len(stage.derive_id_calls) == 1 + + def test_regen_preserves_stored_seed(self, tmp_path: Path) -> None: + inp, prev, stage = self._setup_regen(tmp_path, {"x": 1}, {"x": 2}) + result = asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert result.samples[0].seed == prev.seed + + def test_regen_updates_content_hash(self, tmp_path: Path) -> None: + inp, prev, stage = self._setup_regen(tmp_path, {"x": 1}, {"x": 2}) + result = asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert result.samples[0].content_hash != prev.content_hash + expected = _content_hash(inp.content_hash, prev.seed, {"x": 2}) + assert result.samples[0].content_hash == expected + + def test_regen_old_file_gc_deleted(self, tmp_path: Path) -> None: + inp, prev, stage = self._setup_regen(tmp_path, {"x": 1}, {"x": 2}) + old_file = tmp_path / prev.path.name + old_file.write_bytes(b"old audio") + + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert not old_file.exists() + + def test_regen_generate_output_called_with_stored_seed(self, tmp_path: Path) -> None: + inp, prev, stage = self._setup_regen(tmp_path, {"x": 1}, {"x": 2}) + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert len(stage.generate_output_calls) == 1 + _id, seed, av = stage.generate_output_calls[0] + assert seed == prev.seed + assert av == {"x": 2} + + +# --------------------------------------------------------------------------- +# TestNewSamplePath +# --------------------------------------------------------------------------- + +class TestNewSamplePath: + def test_new_sample_calls_derive_id(self, tmp_path: Path) -> None: + inp = _text_sample() + stage = _FakeStage(tmp_path, ManifestStore()) + + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert len(stage.derive_id_calls) == 1 + + def test_new_sample_calls_generate_output(self, tmp_path: Path) -> None: + inp = _text_sample() + stage = _FakeStage(tmp_path, ManifestStore()) + + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert len(stage.generate_output_calls) == 1 + + def test_new_sample_gets_nonzero_seed(self, tmp_path: Path) -> None: + # Stochastic stage → seed from os.urandom; not 0 (astronomically unlikely to be 0) + inp = _text_sample() + stage = _FakeStage(tmp_path, ManifestStore()) + + result = asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + # os.urandom(8) producing 0 is 1 in 2^64; treat as impossible for test purposes + assert result.samples[0].seed != 0 + + def test_no_previous_manifest_means_all_samples_are_new(self, tmp_path: Path) -> None: + samples = [_text_sample(content=f"phrase {i}") for i in range(3)] + stage = _FakeStage(tmp_path, ManifestStore()) + + asyncio.run(stage.transform(Manifest(samples), tmp_path / "manifest.json")) + + assert len(stage.generate_output_calls) == 3 + + def test_new_sample_id_derived_not_uuid(self, tmp_path: Path) -> None: + # _derive_id is called (not uuid4); our stub returns "{id}_fake" + inp = _text_sample() + stage = _FakeStage(tmp_path, ManifestStore()) + + result = asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert result.samples[0].id.endswith("_fake") + + def test_manifest_written_after_new_sample(self, tmp_path: Path) -> None: + inp = _text_sample() + stage = _FakeStage(tmp_path, ManifestStore()) + + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + written = ManifestStore().read(tmp_path / "manifest.json") + assert len(written.samples) == 1 + + +# --------------------------------------------------------------------------- +# TestDeterministicStage +# --------------------------------------------------------------------------- + +class TestDeterministicStage: + def test_new_sample_gets_seed_zero(self, tmp_path: Path) -> None: + inp = _text_sample() + stage = _DeterministicFakeStage(tmp_path, ManifestStore()) + + result = asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert result.samples[0].seed == 0 + + def test_generate_output_called_with_seed_zero(self, tmp_path: Path) -> None: + inp = _text_sample() + stage = _DeterministicFakeStage(tmp_path, ManifestStore()) + + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + _id, seed, _av = stage.generate_output_calls[0] + assert seed == 0 + + def test_multiple_new_samples_all_get_seed_zero(self, tmp_path: Path) -> None: + samples = [_text_sample(content=f"phrase {i}") for i in range(3)] + stage = _DeterministicFakeStage(tmp_path, ManifestStore()) + + result = asyncio.run(stage.transform(Manifest(samples), tmp_path / "manifest.json")) + + assert all(s.seed == 0 for s in result.samples) + + +# --------------------------------------------------------------------------- +# TestGarbageCollection +# --------------------------------------------------------------------------- + +class TestGarbageCollection: + def test_gc_removes_orphaned_file(self, tmp_path: Path) -> None: + orphan = tmp_path / "orphan.wav" + orphan.write_bytes(b"stale") + inp = _text_sample() + stage = _FakeStage(tmp_path, ManifestStore()) + + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert not orphan.exists() + + def test_gc_removes_multiple_orphaned_files(self, tmp_path: Path) -> None: + orphans = [tmp_path / f"orphan{i}.wav" for i in range(3)] + for f in orphans: + f.write_bytes(b"stale") + inp = _text_sample() + stage = _FakeStage(tmp_path, ManifestStore()) + + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert all(not f.exists() for f in orphans) + + def test_gc_does_not_delete_manifest_json(self, tmp_path: Path) -> None: + inp = _text_sample() + stage = _FakeStage(tmp_path, ManifestStore()) + + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert (tmp_path / "manifest.json").exists() + + def test_gc_does_not_delete_current_output_file(self, tmp_path: Path) -> None: + inp = _text_sample() + stage = _FakeStage(tmp_path, ManifestStore()) + + result = asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + output_path = tmp_path / result.samples[0].path.name + assert output_path.exists() + + def test_gc_does_not_delete_skipped_file(self, tmp_path: Path) -> None: + inp = _text_sample() + prev = _audio_sample(seed=77, parent_content_hash=inp.content_hash, applied_values={}) + _write_prev_manifest(tmp_path / "manifest.json", [prev]) + existing_file = tmp_path / prev.path.name + existing_file.write_bytes(b"audio") + orphan = tmp_path / "orphan.wav" + orphan.write_bytes(b"stale") + + stage = _FakeStage(tmp_path, ManifestStore()) + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + assert existing_file.exists() + assert not orphan.exists() + + def test_gc_empty_output_dir_does_not_raise(self, tmp_path: Path) -> None: + inp = _text_sample() + stage = _FakeStage(tmp_path, ManifestStore()) + + # Should not raise even when output_dir is empty (no files to GC) + asyncio.run(stage.transform(Manifest([inp]), tmp_path / "manifest.json")) + + +# --------------------------------------------------------------------------- +# TestSplitBehavior +# --------------------------------------------------------------------------- + +class TestSplitBehavior: + """Two samples; same 'constraint change' causes a skip for one but regen for the other. + + Demonstrates that skip/regen is determined per-sample by whether the recomputed + content_hash matches the stored content_hash — not by the seed alone. + """ + + def test_split_behavior_skip_and_regen_in_one_transform(self, tmp_path: Path) -> None: + inp_a = _text_sample(content="phrase a", label="A") + inp_b = _text_sample(content="phrase b", label="B") + + seed_a = 10 + seed_b = 20 + # Both samples previously had {"x": 1} + old_av = {"x": 1} + prev_a = AudioSample( + id="A_fake", + seed=seed_a, + content_hash=_content_hash(inp_a.content_hash, seed_a, old_av), + path=Path("A_fake.wav"), + parent_content_hash=inp_a.content_hash, + transcript="A", + applied_values=old_av, + ) + prev_b = AudioSample( + id="B_fake", + seed=seed_b, + content_hash=_content_hash(inp_b.content_hash, seed_b, old_av), + path=Path("B_fake.wav"), + parent_content_hash=inp_b.content_hash, + transcript="B", + applied_values=old_av, + ) + _write_prev_manifest(tmp_path / "manifest.json", [prev_a, prev_b]) + + # After constraint change: sample A still returns {"x": 1} (no change → skip) + # but sample B now returns {"x": 2} (changed → regen) + stage = _FakeStage( + tmp_path, + ManifestStore(), + av_by_content_hash={ + inp_a.content_hash: {"x": 1}, + inp_b.content_hash: {"x": 2}, + }, + ) + result = asyncio.run( + stage.transform(Manifest([inp_a, inp_b]), tmp_path / "manifest.json") + ) + + assert len(result.samples) == 2 + out_a = next(s for s in result.samples if s.parent_content_hash == inp_a.content_hash) + out_b = next(s for s in result.samples if s.parent_content_hash == inp_b.content_hash) + + # Sample A: skipped — same id, same content_hash, no regen call + assert out_a.id == prev_a.id + assert out_a.content_hash == prev_a.content_hash + + # Sample B: regenerated — new id, same seed, different content_hash + assert out_b.id != prev_b.id + assert out_b.seed == prev_b.seed + assert out_b.content_hash != prev_b.content_hash + + assert len(stage.generate_output_calls) == 1 diff --git a/ml/test/pipeline/core/test_randomization.py b/ml/test/pipeline/core/test_randomization.py new file mode 100644 index 00000000..8cdd800a --- /dev/null +++ b/ml/test/pipeline/core/test_randomization.py @@ -0,0 +1,418 @@ +from __future__ import annotations + +import math +import pytest + +from pipeline.core.randomization import ( + MinMaxFilter, + NormalFilter, + PassFilter, + VariationGenerator, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class _NeverAcceptFilter(PassFilter): + """PassFilter whose density() always returns 0.0, so every candidate is rejected.""" + + def density(self, value: float) -> float: + return 0.0 + + def sample_domain(self) -> tuple[float, float]: + return (0.0, 1.0) + + +# --------------------------------------------------------------------------- +# MinMaxFilter — density +# --------------------------------------------------------------------------- + +class TestMinMaxFilterDensity: + def test_density_at_min_returns_1(self) -> None: + f = MinMaxFilter(2.0, 8.0) + assert f.density(2.0) == 1.0 + + def test_density_at_max_returns_1(self) -> None: + f = MinMaxFilter(2.0, 8.0) + assert f.density(8.0) == 1.0 + + def test_density_in_range_returns_1(self) -> None: + f = MinMaxFilter(0.0, 10.0) + assert f.density(5.0) == 1.0 + + def test_density_below_min_returns_0(self) -> None: + f = MinMaxFilter(2.0, 8.0) + assert f.density(1.9) == 0.0 + + def test_density_above_max_returns_0(self) -> None: + f = MinMaxFilter(2.0, 8.0) + assert f.density(8.1) == 0.0 + + def test_sample_domain_returns_min_max(self) -> None: + f = MinMaxFilter(3.0, 7.5) + assert f.sample_domain() == (3.0, 7.5) + + def test_min_greater_than_max_raises_value_error(self) -> None: + with pytest.raises(ValueError): + MinMaxFilter(8.0, 2.0) + + def test_equal_min_max_constructs_successfully(self) -> None: + f = MinMaxFilter(5.0, 5.0) + assert f.density(5.0) == 1.0 + + +# --------------------------------------------------------------------------- +# NormalFilter — density and domain +# --------------------------------------------------------------------------- + +class TestNormalFilterDensity: + def test_density_at_mean_is_1(self) -> None: + f = NormalFilter(5.0, 2.0) + assert f.density(5.0) == 1.0 + + def test_density_at_one_std_dev_from_mean(self) -> None: + f = NormalFilter(5.0, 2.0) + # exp(-0.5 * 1^2) = exp(-0.5) + assert f.density(7.0) == pytest.approx(math.exp(-0.5)) + assert f.density(3.0) == pytest.approx(math.exp(-0.5)) + + def test_density_decreases_away_from_mean(self) -> None: + f = NormalFilter(5.0, 2.0) + assert f.density(5.0) > f.density(6.0) > f.density(8.0) + + def test_density_is_symmetric_around_mean(self) -> None: + f = NormalFilter(5.0, 2.0) + assert f.density(3.0) == pytest.approx(f.density(7.0)) + + def test_sample_domain_is_5_std_devs(self) -> None: + f = NormalFilter(5.0, 2.0) + low, high = f.sample_domain() + assert low == pytest.approx(5.0 - 5 * 2.0) + assert high == pytest.approx(5.0 + 5 * 2.0) + + +# --------------------------------------------------------------------------- +# NormalFilter — validation +# --------------------------------------------------------------------------- + +class TestNormalFilterValidation: + def test_std_dev_zero_raises_value_error(self) -> None: + with pytest.raises(ValueError): + NormalFilter(5.0, 0.0) + + def test_std_dev_negative_raises_value_error(self) -> None: + with pytest.raises(ValueError): + NormalFilter(5.0, -1.0) + + def test_positive_std_dev_constructs_successfully(self) -> None: + f = NormalFilter(0.0, 0.001) + assert f.density(0.0) == 1.0 + + +# --------------------------------------------------------------------------- +# VariationGenerator — should_vary +# --------------------------------------------------------------------------- + +class TestVariationGeneratorShouldVary: + def test_same_seed_and_name_returns_same_result(self) -> None: + result_a = VariationGenerator(42).should_vary("prefix_delay", 0.5) + result_b = VariationGenerator(42).should_vary("prefix_delay", 0.5) + assert result_a == result_b + + def test_frequency_zero_always_returns_false(self) -> None: + for seed in [0, 1, 42, 99, 12345]: + assert VariationGenerator(seed).should_vary("x", 0.0) is False + + def test_frequency_one_always_returns_true(self) -> None: + for seed in [0, 1, 42, 99, 12345]: + assert VariationGenerator(seed).should_vary("x", 1.0) is True + + def test_different_seeds_can_give_different_results(self) -> None: + # seed=0 returns True, seed=42 returns False for freq=0.5 + assert VariationGenerator(0).should_vary("prefix_delay", 0.5) is True + assert VariationGenerator(42).should_vary("prefix_delay", 0.5) is False + + def test_probability_converges_over_many_seeds(self) -> None: + frequency = 0.7 + n_seeds = 1000 + true_count = sum( + 1 for s in range(n_seeds) + if VariationGenerator(s).should_vary("v", frequency) + ) + ratio = true_count / n_seeds + # Expect within 8% of the target frequency for 1000 samples + assert abs(ratio - frequency) < 0.08 + + def test_different_variable_names_are_independent(self) -> None: + g = VariationGenerator(3) + a = g.should_vary("prefix_delay_s", 0.5) + b = g.should_vary("suffix_delay_s", 0.5) + # Same seed, different variable names -> independent hashes -> different values + assert a != b + + def test_frequency_below_zero_raises_value_error(self) -> None: + with pytest.raises(ValueError): + VariationGenerator(0).should_vary("x", -0.1) + + def test_frequency_above_one_raises_value_error(self) -> None: + with pytest.raises(ValueError): + VariationGenerator(0).should_vary("x", 1.1) + + +# --------------------------------------------------------------------------- +# VariationGenerator — generate (float) +# --------------------------------------------------------------------------- + +class TestVariationGeneratorGenerate: + def test_same_seed_and_name_returns_same_value(self) -> None: + v1 = VariationGenerator(0).generate("speed", MinMaxFilter(0.0, 1.0)) + v2 = VariationGenerator(0).generate("speed", MinMaxFilter(0.0, 1.0)) + assert v1 == v2 + + def test_different_seeds_produce_different_values(self) -> None: + v0 = VariationGenerator(0).generate("speed", MinMaxFilter(0.0, 10.0)) + v42 = VariationGenerator(42).generate("speed", MinMaxFilter(0.0, 10.0)) + assert v0 != v42 + + def test_minmax_filter_value_in_range(self) -> None: + f = MinMaxFilter(2.0, 8.0) + val = VariationGenerator(0).generate("speed", f) + assert 2.0 <= val <= 8.0 + + def test_normal_filter_value_in_domain(self) -> None: + f = NormalFilter(5.0, 2.0) + low, high = f.sample_domain() + val = VariationGenerator(0).generate("noise_vol", f) + assert low <= val <= high + + def test_normal_filter_exact_value_is_deterministic(self) -> None: + # seed=0, "noise_vol", NormalFilter(5,2), precision=0 -> 7.0 + val = VariationGenerator(0).generate("noise_vol", NormalFilter(5.0, 2.0)) + assert val == 7.0 + + def test_minmax_filter_exact_value_is_deterministic(self) -> None: + # seed=0, "speed", MinMaxFilter(0,10), precision=0 -> 0.0 + val = VariationGenerator(0).generate("speed", MinMaxFilter(0.0, 10.0)) + assert val == 0.0 + + def test_precision_zero_returns_whole_number(self) -> None: + val = VariationGenerator(3).generate("x", MinMaxFilter(0.0, 100.0)) + assert val == math.floor(val) + + def test_precision_two_returns_at_most_two_decimal_places(self) -> None: + val = VariationGenerator(3).generate("x", MinMaxFilter(0.0, 10.0, precision=2)) + assert round(val, 2) == val + + def test_different_variable_names_produce_independent_values(self) -> None: + g = VariationGenerator(99) + va = g.generate("prefix_delay_s", MinMaxFilter(0.0, 1.0)) + vb = g.generate("suffix_delay_s", MinMaxFilter(0.0, 1.0)) + # Same seed, different variable names -> different hash keys -> independent + assert va != vb + + def test_raises_value_error_after_1000_failed_attempts(self) -> None: + with pytest.raises(ValueError): + VariationGenerator(0).generate("x", _NeverAcceptFilter()) + + +# --------------------------------------------------------------------------- +# VariationGenerator — generate_int +# --------------------------------------------------------------------------- + +class TestVariationGeneratorGenerateInt: + def test_same_seed_and_name_returns_same_value(self) -> None: + v1 = VariationGenerator(0).generate_int("speed", MinMaxFilter(0, 10)) + v2 = VariationGenerator(0).generate_int("speed", MinMaxFilter(0, 10)) + assert v1 == v2 + + def test_different_seeds_can_produce_different_values(self) -> None: + v0 = VariationGenerator(0).generate_int("speed", MinMaxFilter(0, 10)) + v42 = VariationGenerator(42).generate_int("speed", MinMaxFilter(0, 10)) + assert v0 != v42 + + def test_value_in_range(self) -> None: + for seed in range(20): + val = VariationGenerator(seed).generate_int("x", MinMaxFilter(0, 10)) + assert 0 <= val <= 10 + + def test_returns_int_type(self) -> None: + val = VariationGenerator(0).generate_int("speed", MinMaxFilter(0, 10)) + assert isinstance(val, int) + + def test_range_zero_returns_min_val_immediately(self) -> None: + # Special case: min_val == max_val -> return min_val without looping + assert VariationGenerator(0).generate_int("x", MinMaxFilter(5, 5)) == 5 + assert VariationGenerator(42).generate_int("y", MinMaxFilter(0, 0)) == 0 + + def test_exact_value_seed_0(self) -> None: + # seed=0, "speed", MinMaxFilter(0,10) -> 0 + assert VariationGenerator(0).generate_int("speed", MinMaxFilter(0, 10)) == 0 + + def test_exact_value_seed_42(self) -> None: + # seed=42, "speed", MinMaxFilter(0,10) -> 2 + assert VariationGenerator(42).generate_int("speed", MinMaxFilter(0, 10)) == 2 + + def test_different_variable_names_are_independent(self) -> None: + g = VariationGenerator(0) + v1 = g.generate_int("speech_rate", MinMaxFilter(-20, 20)) + v2 = g.generate_int("pitch", MinMaxFilter(-20, 20)) + assert v1 != v2 + + +# --------------------------------------------------------------------------- +# VariationGenerator — generate_int stability across range changes +# +# Seeds computed offline to exhibit specific bitmask-rejection behaviors: +# seed=1: generate_int("x", MinMaxFilter(0,10)) == 2 +# generate_int("x", MinMaxFilter(0,20)) == 2 (stable: same value) +# generate_int("x", MinMaxFilter(0,5)) == 2 (stable: value is in lower range) +# seed=2: generate_int("x", MinMaxFilter(0,10)) == 3 +# generate_int("x", MinMaxFilter(0,20)) == 11 (changes: bitmask bit-4 set, 3+16=19>10) +# seed=5: generate_int("x", MinMaxFilter(0,10)) == 7 +# generate_int("x", MinMaxFilter(0,5)) == 1 (changes: 7>5 rejected, resampled) +# --------------------------------------------------------------------------- + +class TestVariationGeneratorGenerateIntStability: + def test_stable_seed_same_value_when_max_widened(self) -> None: + # seed=1: value 2 is below 10 and its raw bits have bit-4 == 0, so mask=31 gives + # the same 2 as mask=15. + v_narrow = VariationGenerator(1).generate_int("x", MinMaxFilter(0, 10)) + v_wide = VariationGenerator(1).generate_int("x", MinMaxFilter(0, 20)) + assert v_narrow == 2 + assert v_wide == 2 + + def test_changing_seed_gets_higher_when_max_widened(self) -> None: + # seed=2: raw_0 & 31 == 11 (accepted at n=0 for max=20), but for max=10 the value + # 11 > 10 is rejected; the narrow case loops to n=2 where raw_2 & 15 == 3 (<= 10). + # The two ranges draw from different attempt indices, so the values differ. + v_narrow = VariationGenerator(2).generate_int("x", MinMaxFilter(0, 10)) + v_wide = VariationGenerator(2).generate_int("x", MinMaxFilter(0, 20)) + assert v_narrow == 3 + assert v_wide == 11 + + def test_stable_seed_same_value_when_max_narrowed(self) -> None: + # seed=1: value 2 is within the narrowed range [0,5], so it stays the same. + v_original = VariationGenerator(1).generate_int("x", MinMaxFilter(0, 10)) + v_narrow = VariationGenerator(1).generate_int("x", MinMaxFilter(0, 5)) + assert v_original == 2 + assert v_narrow == 2 + + def test_changing_seed_when_max_narrowed(self) -> None: + # seed=5: raw_0 & 15 == 7 (accepted for max=10), but 7 > 5 so rejected for max=5; + # resampling produces a different value. + v_original = VariationGenerator(5).generate_int("x", MinMaxFilter(0, 10)) + v_narrow = VariationGenerator(5).generate_int("x", MinMaxFilter(0, 5)) + assert v_original == 7 + assert v_narrow == 1 + + +# --------------------------------------------------------------------------- +# VariationGenerator — generate stability across range changes +# +# pow2_range for MinMaxFilter(0, N) is the smallest power-of-2 > N: +# N=5 -> pow2_range=8 (candidates 0..7, reject 6..7) +# N=6 -> pow2_range=8 (candidates 0..7, reject 7) +# N=7 -> pow2_range=8 (candidates 0..7, none rejected) +# N=10 -> pow2_range=16 (candidates 0..15, reject 11..15) +# N=15 -> pow2_range=16 (candidates 0..15, none rejected) +# +# seed=0, "x": n=0 candidate = raw%8 = 2 (accepted for all ranges above) +# seed=0, "x": n=0 candidate = raw%16 = 10 (accepted for N>=10; for N=7, raw%8=2 so stable) +# seed=5, "x": n=0 candidate = raw%8 = 7 (rejected for N<7, accepted for N>=7) +# --------------------------------------------------------------------------- + +class TestVariationGeneratorGenerateStability: + def test_stable_value_when_max_widens_within_same_pow2_range(self) -> None: + # seed=0: n=0 candidate=2; both MinMaxFilter(0,5) and MinMaxFilter(0,6) share + # pow2_range=8 and both accept 2 -> same value. + v_narrow = VariationGenerator(0).generate("x", MinMaxFilter(0, 5)) + v_wide = VariationGenerator(0).generate("x", MinMaxFilter(0, 6)) + assert v_narrow == 2.0 + assert v_wide == 2.0 + + def test_value_changes_when_pow2_range_expands(self) -> None: + # seed=0: MinMaxFilter(0,7) pow2_range=8 -> candidate=2; + # MinMaxFilter(0,15) pow2_range=16 -> candidate=10 (bit-3 of raw is set, so + # raw%16 != raw%8). + v_narrow = VariationGenerator(0).generate("x", MinMaxFilter(0, 7)) + v_wide = VariationGenerator(0).generate("x", MinMaxFilter(0, 15)) + assert v_narrow == 2.0 + assert v_wide == 10.0 + + def test_value_changes_when_rejected_candidate_becomes_accepted(self) -> None: + # seed=5: n=0 candidate=7; rejected for MinMaxFilter(0,5) (7>5), so first + # accepted hit is at a later attempt -> 1.0. With MinMaxFilter(0,7) the same + # n=0 candidate=7 is now within range and accepted first -> 7.0. + v_strict = VariationGenerator(5).generate("x", MinMaxFilter(0, 5)) + v_relaxed = VariationGenerator(5).generate("x", MinMaxFilter(0, 7)) + assert v_strict == 1.0 + assert v_relaxed == 7.0 + + def test_value_changes_when_max_narrowed_below_candidate(self) -> None: + # seed=0: [0,10] pow2_range=16, first accepted candidate=10 (within [0,10]). + # [0,5] pow2_range=8, candidate=2 (10 is outside [0,5], so earlier attempt wins). + v_wide = VariationGenerator(0).generate("x", MinMaxFilter(0, 10)) + v_narrow = VariationGenerator(0).generate("x", MinMaxFilter(0, 5)) + assert v_wide == 10.0 + assert v_narrow == 2.0 + + def test_stable_value_when_max_narrowed_but_candidate_still_in_range(self) -> None: + # seed=1: first accepted candidate=2, which is within both [0,10] and [0,5]. + v_wide = VariationGenerator(1).generate("x", MinMaxFilter(0, 10)) + v_narrow = VariationGenerator(1).generate("x", MinMaxFilter(0, 5)) + assert v_wide == 2.0 + assert v_narrow == 2.0 + + +# --------------------------------------------------------------------------- +# VariationGenerator — choose +# --------------------------------------------------------------------------- + +class TestVariationGeneratorChoose: + def test_same_seed_and_name_returns_same_item(self) -> None: + options = ["a", "b", "c", "d"] + c1 = VariationGenerator(0).choose("voice", options) + c2 = VariationGenerator(0).choose("voice", options) + assert c1 == c2 + + def test_result_is_one_of_the_options(self) -> None: + options = ["en-US-JennyNeural", "en-US-GuyNeural", "en-GB-LibbyNeural"] + result = VariationGenerator(0).choose("voice", options) + assert result in options + + def test_uses_index_zero_hash_key(self) -> None: + # choose must use sha256("{seed}:{variable_name}:0") with index 0, no loop. + # Verified by exact match against the expected hash formula. + import hashlib + options = ["a", "b", "c", "d"] + seed = 0 + raw = int.from_bytes( + hashlib.sha256(f"{seed}:voice:0".encode()).digest()[:8], "big" + ) + expected = options[raw % len(options)] + assert VariationGenerator(seed).choose("voice", options) == expected + + def test_different_seeds_can_choose_different_items(self) -> None: + options = ["en-US-JennyNeural", "en-US-GuyNeural", "en-GB-LibbyNeural"] + # seed=0 -> JennyNeural, seed=1 -> GuyNeural + assert VariationGenerator(0).choose("voice", options) == "en-US-JennyNeural" + assert VariationGenerator(1).choose("voice", options) == "en-US-GuyNeural" + + def test_different_variable_names_select_independently(self) -> None: + options = ["x", "y", "z"] + g = VariationGenerator(0) + # Different keys -> independent hash values -> different selections + a = g.choose("noise_file", options) + b = g.choose("voice", options) + assert a != b + + def test_single_option_list_always_returns_it(self) -> None: + only = ["only-choice"] + for seed in [0, 1, 42, 999]: + assert VariationGenerator(seed).choose("x", only) == "only-choice" + + def test_empty_options_raises_value_error(self) -> None: + with pytest.raises(ValueError): + VariationGenerator(0).choose("x", []) diff --git a/scripts/validate-tests.cmd b/scripts/validate-tests.cmd index aa738179..eb615092 100644 --- a/scripts/validate-tests.cmd +++ b/scripts/validate-tests.cmd @@ -4,4 +4,8 @@ dotnet test --no-build "%~dp0validate-unit-tests.proj" if %ERRORLEVEL% neq 0 ( popd & exit /b %ERRORLEVEL% ) dotnet test --no-build "%~dp0validate-e2e-tests.proj" if %ERRORLEVEL% neq 0 ( popd & exit /b %ERRORLEVEL% ) +cd ml +if %ERRORLEVEL% neq 0 ( popd & exit /b %ERRORLEVEL% ) +python -m pytest +if %ERRORLEVEL% neq 0 ( popd & exit /b %ERRORLEVEL% ) popd diff --git a/scripts/validate-tests.sh b/scripts/validate-tests.sh index 8894f734..bdbf2167 100755 --- a/scripts/validate-tests.sh +++ b/scripts/validate-tests.sh @@ -6,3 +6,6 @@ echo 'Testing unit test projects...' dotnet test --no-build "$SCRIPT_DIR/validate-unit-tests.proj" echo 'Testing E2E test projects...' dotnet test --no-build "$SCRIPT_DIR/validate-e2e-tests.proj" +echo 'Testing Python unit tests...' +cd ml || exit 1 +python3 -m pytest