From f4054bc884df766a85a675c3b2384cebb0980e61 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 19 Jun 2026 04:07:26 +0000 Subject: [PATCH 1/3] ADR-228: implement SpectrogramStage, TokenStage, SetManifestSplitter Add two deterministic featurisation stages and a set splitter: - SpectrogramStage: writes {id}.npy of shape (n_mels, time_steps); truncates long spectrograms and zero-pads short ones - TokenStage: writes {id}.json with phoneme strings and token indices padded to input_token_length; padding value is 0 (index 0) - SetManifestSplitter: shuffles with seed=42 and writes train/val/test manifests Also adds three DVC entry-point scripts (speech_07, 08, 09), extends PipelineParams with ComputeSpectrogramsParams, ComputeTokensParams, CreateSetManifestsParams, adds matching sections to params.yaml, and updates _doc_speech.md. --- ml/params.yaml | 9 + ml/pipeline/speech/_doc_speech.md | 5 +- ml/pipeline/speech/set_splitter.py | 74 ++++++ ml/pipeline/speech/spectrogram_stage.py | 103 ++++++++ ml/pipeline/speech/token_stage.py | 95 +++++++ ml/pipeline/stages/params.py | 36 +++ .../stages/speech_07_compute_tokens.py | 62 +++++ .../stages/speech_08_compute_spectrograms.py | 53 ++++ .../stages/speech_09_create_set_manifests.py | 44 ++++ ml/test/pipeline/speech/test_set_splitter.py | 165 ++++++++++++ .../pipeline/speech/test_spectrogram_stage.py | 211 +++++++++++++++ ml/test/pipeline/speech/test_token_stage.py | 248 ++++++++++++++++++ ml/test/pipeline/stages/test_params.py | 110 ++++++++ 13 files changed, 1213 insertions(+), 2 deletions(-) create mode 100644 ml/pipeline/speech/set_splitter.py create mode 100644 ml/pipeline/speech/spectrogram_stage.py create mode 100644 ml/pipeline/speech/token_stage.py create mode 100644 ml/pipeline/stages/speech_07_compute_tokens.py create mode 100644 ml/pipeline/stages/speech_08_compute_spectrograms.py create mode 100644 ml/pipeline/stages/speech_09_create_set_manifests.py create mode 100644 ml/test/pipeline/speech/test_set_splitter.py create mode 100644 ml/test/pipeline/speech/test_spectrogram_stage.py create mode 100644 ml/test/pipeline/speech/test_token_stage.py diff --git a/ml/params.yaml b/ml/params.yaml index e57d7eaa..b7d7c26c 100644 --- a/ml/params.yaml +++ b/ml/params.yaml @@ -29,3 +29,12 @@ stages: vary_probability: 0.5 amplitude_min: 0.001 amplitude_max: 0.05 + compute_spectrograms: + n_mels: 80 + time_steps: 400 + compute_tokens: + input_token_length: 50 + create_set_manifests: + train_pct: 80 + val_pct: 10 + test_pct: 10 diff --git a/ml/pipeline/speech/_doc_speech.md b/ml/pipeline/speech/_doc_speech.md index e548c093..02560661 100644 --- a/ml/pipeline/speech/_doc_speech.md +++ b/ml/pipeline/speech/_doc_speech.md @@ -12,8 +12,9 @@ produces a derived `AudioSample` manifest with per-sample variation applied. | [`delay_stage.py`](delay_stage.py) | `DelayAugmentor` | `AudioSample → AudioSample` (silence padding) | | [`background_noise_stage.py`](background_noise_stage.py) | `BackgroundNoiseAugmentor` | `AudioSample → AudioSample` (environmental noise mix) | | [`mic_noise_stage.py`](mic_noise_stage.py) | `MicrophoneNoiseAugmentor` | `AudioSample → AudioSample` (Gaussian mic noise) | - -Planned stages (not yet implemented): `token_stage.py`, `spectrogram_stage.py`. +| [`token_stage.py`](token_stage.py) | `TokenStage` | `AudioSample → SampleTokens` (phoneme token sequences) | +| [`spectrogram_stage.py`](spectrogram_stage.py) | `SpectrogramStage` | `AudioSample → SampleSpectrogram` (mel spectrogram NPY) | +| [`set_splitter.py`](set_splitter.py) | `SetManifestSplitter` | Splits augmented manifest into train/val/test | ## Key design decisions diff --git a/ml/pipeline/speech/set_splitter.py b/ml/pipeline/speech/set_splitter.py new file mode 100644 index 00000000..36204153 --- /dev/null +++ b/ml/pipeline/speech/set_splitter.py @@ -0,0 +1,74 @@ +"""SetManifestSplitter: split an augmented manifest into train/val/test sets. + +Shuffles with seed=42 for reproducibility. Percentages must sum to 100. +Writes train_manifest.json, val_manifest.json, test_manifest.json via +conventions.split_manifest_path. +""" + +from __future__ import annotations + +import random +from pathlib import Path +from typing import Any + +from pipeline.core.manifest import Manifest, ManifestStore +from pipeline.core.sample import AudioSample +from pipeline.stages import conventions + + +class SetManifestSplitter: + """Split a fully-augmented manifest into train/val/test sets. + + Shuffles the samples with seed=42 then writes three manifest files. + train_pct + val_pct + test_pct must equal 100. + """ + + def __init__( + self, + output_dir: Path, + manifest_store: ManifestStore, + train_pct: int, + val_pct: int, + test_pct: int, + ) -> None: + if train_pct + val_pct + test_pct != 100: + raise ValueError( + f"train_pct + val_pct + test_pct must equal 100, " + f"got {train_pct + val_pct + test_pct}" + ) + self._output_dir = output_dir + self._manifest_store = manifest_store + self._train_pct = train_pct + self._val_pct = val_pct + self._test_pct = test_pct + + def split(self, manifest: Manifest[AudioSample]) -> None: + """Shuffle and write train/val/test manifest files.""" + samples = list(manifest.samples) + rng = random.Random(42) + rng.shuffle(samples) + + n = len(samples) + n_train = int(n * self._train_pct / 100) + n_val = int(n * self._val_pct / 100) + # test gets the remainder to ensure total == n + n_test = n - n_train - n_val + + train_samples = samples[:n_train] + val_samples = samples[n_train : n_train + n_val] + test_samples = samples[n_train + n_val :] + + self._output_dir.mkdir(parents=True, exist_ok=True) + + self._manifest_store.write( + Manifest(train_samples), + conventions.split_manifest_path(self._output_dir, "train"), + ) + self._manifest_store.write( + Manifest(val_samples), + conventions.split_manifest_path(self._output_dir, "val"), + ) + self._manifest_store.write( + Manifest(test_samples), + conventions.split_manifest_path(self._output_dir, "test"), + ) diff --git a/ml/pipeline/speech/spectrogram_stage.py b/ml/pipeline/speech/spectrogram_stage.py new file mode 100644 index 00000000..75f250dc --- /dev/null +++ b/ml/pipeline/speech/spectrogram_stage.py @@ -0,0 +1,103 @@ +"""SpectrogramStage: compute mel spectrograms from WAV audio samples. + +Deterministic stage (seed=0). Writes {id}.npy of shape (n_mels, time_steps). +Long spectrograms are truncated from the end; short ones are zero-padded. +""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any + +import numpy as np + +from pipeline.core.manifest import ManifestStore +from pipeline.core.modifier_stage import ModifierStage +from pipeline.core.randomization import VariationGenerator +from pipeline.core.sample import AudioSample, SampleSpectrogram +from pipeline.io.audio_io import AudioReader +from pipeline.stages import conventions + + +class SpectrogramStage(ModifierStage[AudioSample, SampleSpectrogram]): + """Compute mel spectrograms from WAV audio samples. + + Writes {id}.npy of shape (n_mels, time_steps). + Truncates long spectrograms from the end; zero-pads short ones. + """ + + _is_deterministic: bool = True + + def __init__( + self, + output_dir: Path, + manifest_store: ManifestStore, + audio_reader: AudioReader, + input_dir: Path, + sample_rate: int, + n_mels: int, + time_steps: int, + ) -> None: + super().__init__(output_dir, manifest_store) + self._audio_reader = audio_reader + self._input_dir = input_dir + self._sample_rate = sample_rate + self._n_mels = n_mels + self._time_steps = time_steps + + def _get_applied_values( + self, sample: AudioSample, generator: VariationGenerator + ) -> dict[str, Any]: + return {} + + def _derive_id(self, input_sample: AudioSample, applied_values: dict[str, Any]) -> str: + return input_sample.id + + async def _generate_output( + self, + input_sample: AudioSample, + output_id: str, + output_seed: int, + applied_values: dict[str, Any], + parent_content_hash: str, + ) -> SampleSpectrogram: + import librosa # deferred import to allow unit tests without librosa + + input_path = self._input_dir / input_sample.path + audio = await self._audio_reader.read(input_path) + + loop = asyncio.get_running_loop() + mel = await loop.run_in_executor( + None, + lambda: librosa.feature.melspectrogram( + y=audio.samples, + sr=self._sample_rate, + n_mels=self._n_mels, + ), + ) + + # Truncate or zero-pad to exactly time_steps columns + if mel.shape[1] >= self._time_steps: + mel = mel[:, : self._time_steps] + else: + pad_width = self._time_steps - mel.shape[1] + mel = np.pad(mel, ((0, 0), (0, pad_width)), mode="constant") + + self._output_dir.mkdir(parents=True, exist_ok=True) + output_path = conventions.sample_file_path(self._output_dir, output_id, "npy") + np.save(str(output_path), mel) + + content_hash = self._compute_content_hash( + parent_content_hash, output_seed, applied_values + ) + + return SampleSpectrogram( + id=output_id, + seed=output_seed, + content_hash=content_hash, + path=Path(f"{output_id}.npy"), + parent_content_hash=parent_content_hash, + transcript=input_sample.transcript, + parent_id=input_sample.id, + ) diff --git a/ml/pipeline/speech/token_stage.py b/ml/pipeline/speech/token_stage.py new file mode 100644 index 00000000..27e49f8f --- /dev/null +++ b/ml/pipeline/speech/token_stage.py @@ -0,0 +1,95 @@ +"""TokenStage: compute phoneme tokens from AudioSample transcripts. + +Deterministic stage (seed=0). Writes {id}.json with phonemes and tokens +padded to input_token_length. Token index 0 is used for padding. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from pipeline.core.manifest import ManifestStore +from pipeline.core.modifier_stage import ModifierStage +from pipeline.core.randomization import VariationGenerator +from pipeline.core.sample import AudioSample, SampleTokens +from pipeline.intent.vocab_computer import VocabResult +from pipeline.stages import conventions + + +class TokenStage(ModifierStage[AudioSample, SampleTokens]): + """Compute phoneme tokens from AudioSample transcripts. + + Writes {id}.json with: + phonemes: list[str] — phoneme strings padded with "" to input_token_length + tokens: list[int] — phoneme indices padded with 0 to input_token_length + + Token index 0 is used for padding (index into phoneme_list). + """ + + _is_deterministic: bool = True + + def __init__( + self, + output_dir: Path, + manifest_store: ManifestStore, + vocab: VocabResult, + input_token_length: int, + ) -> None: + super().__init__(output_dir, manifest_store) + self._vocab = vocab + self._input_token_length = input_token_length + + def _get_applied_values( + self, sample: AudioSample, generator: VariationGenerator + ) -> dict[str, Any]: + return {} + + def _derive_id(self, input_sample: AudioSample, applied_values: dict[str, Any]) -> str: + return input_sample.id + + async def _generate_output( + self, + input_sample: AudioSample, + output_id: str, + output_seed: int, + applied_values: dict[str, Any], + parent_content_hash: str, + ) -> SampleTokens: + # Collect phonemes for each word in the transcript + phonemes: list[str] = [] + for word in input_sample.transcript.split(): + word_phonemes = self._vocab.words_to_phonemes.get(word, []) + phonemes.extend(word_phonemes) + + # Compute token indices (1-based lookup into phoneme_list) + tokens = [self._vocab.phoneme_list.index(p) for p in phonemes] + + # Pad to input_token_length + pad_len = max(0, self._input_token_length - len(phonemes)) + padded_phonemes = phonemes + [""] * pad_len + padded_tokens = tokens + [0] * pad_len + + # Truncate if longer than input_token_length + padded_phonemes = padded_phonemes[: self._input_token_length] + padded_tokens = padded_tokens[: self._input_token_length] + + self._output_dir.mkdir(parents=True, exist_ok=True) + output_path = conventions.sample_file_path(self._output_dir, output_id, "json") + with open(output_path, "w", encoding="utf-8") as f: + json.dump({"phonemes": padded_phonemes, "tokens": padded_tokens}, f) + + content_hash = self._compute_content_hash( + parent_content_hash, output_seed, applied_values + ) + + return SampleTokens( + id=output_id, + seed=output_seed, + content_hash=content_hash, + path=Path(f"{output_id}.json"), + parent_content_hash=parent_content_hash, + transcript=input_sample.transcript, + parent_id=input_sample.id, + ) diff --git a/ml/pipeline/stages/params.py b/ml/pipeline/stages/params.py index 8153dc77..8e9208cb 100644 --- a/ml/pipeline/stages/params.py +++ b/ml/pipeline/stages/params.py @@ -51,6 +51,24 @@ class AddMicNoiseParams: amplitude_max: float +@dataclass +class ComputeSpectrogramsParams: + n_mels: int + time_steps: int + + +@dataclass +class ComputeTokensParams: + input_token_length: int + + +@dataclass +class CreateSetManifestsParams: + train_pct: int + val_pct: int + test_pct: int + + @dataclass class PipelineParams: variations_per_phrase: int @@ -61,6 +79,9 @@ class PipelineParams: add_delays: AddDelaysParams add_background_noise: AddBackgroundNoiseParams add_mic_noise: AddMicNoiseParams + compute_spectrograms: ComputeSpectrogramsParams + compute_tokens: ComputeTokensParams + create_set_manifests: CreateSetManifestsParams @classmethod def load(cls, path: Path) -> "PipelineParams": @@ -73,6 +94,9 @@ def load(cls, path: Path) -> "PipelineParams": stage_delays = raw["stages"]["add_delays"] stage_bg_noise = raw["stages"]["add_background_noise"] stage_mic_noise = raw["stages"]["add_mic_noise"] + stage_spectrograms = raw["stages"]["compute_spectrograms"] + stage_tokens = raw["stages"]["compute_tokens"] + stage_set_manifests = raw["stages"]["create_set_manifests"] return cls( variations_per_phrase=int(pipeline["variations_per_phrase"]), @@ -107,4 +131,16 @@ def load(cls, path: Path) -> "PipelineParams": amplitude_min=float(stage_mic_noise["amplitude_min"]), amplitude_max=float(stage_mic_noise["amplitude_max"]), ), + compute_spectrograms=ComputeSpectrogramsParams( + n_mels=int(stage_spectrograms["n_mels"]), + time_steps=int(stage_spectrograms["time_steps"]), + ), + compute_tokens=ComputeTokensParams( + input_token_length=int(stage_tokens["input_token_length"]), + ), + create_set_manifests=CreateSetManifestsParams( + train_pct=int(stage_set_manifests["train_pct"]), + val_pct=int(stage_set_manifests["val_pct"]), + test_pct=int(stage_set_manifests["test_pct"]), + ), ) diff --git a/ml/pipeline/stages/speech_07_compute_tokens.py b/ml/pipeline/stages/speech_07_compute_tokens.py new file mode 100644 index 00000000..96bf37d4 --- /dev/null +++ b/ml/pipeline/stages/speech_07_compute_tokens.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import argparse +import asyncio +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from pipeline.core.manifest import ManifestStore +from pipeline.intent.vocab_computer import VocabResult +from pipeline.speech.token_stage import TokenStage +from pipeline.stages import conventions +from pipeline.stages.params import PipelineParams + +_PROJECT_ROOT = Path(__file__).parents[2] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Compute phoneme tokens from WAV audio samples" + ) + parser.add_argument("--input-manifest-dir", required=True, type=Path) + parser.add_argument("--output-dir", required=True, type=Path) + parser.add_argument("--vocab-dir", required=True, type=Path) + args = parser.parse_args() + + params = PipelineParams.load(conventions.params_path(_PROJECT_ROOT)) + + # Reconstruct VocabResult from files written by the vocab stage + phoneme_list = (args.vocab_dir / "phoneme_list.txt").read_text(encoding="utf-8").splitlines() + with open(args.vocab_dir / "words_to_phonemes.json", encoding="utf-8") as f: + words_to_phonemes = json.load(f) + vocab = VocabResult( + phoneme_list=phoneme_list, + words_to_phonemes=words_to_phonemes, + ctc_blank_idx=len(phoneme_list), + ) + + store = ManifestStore() + input_manifest = store.read(conventions.manifest_path(args.input_manifest_dir)) + + args.output_dir.mkdir(parents=True, exist_ok=True) + + stage = TokenStage( + output_dir=args.output_dir, + manifest_store=store, + vocab=vocab, + input_token_length=params.compute_tokens.input_token_length, + ) + + asyncio.run( + stage.transform( + input_manifest, + conventions.manifest_path(args.output_dir), + ) + ) + + +if __name__ == "__main__": + main() diff --git a/ml/pipeline/stages/speech_08_compute_spectrograms.py b/ml/pipeline/stages/speech_08_compute_spectrograms.py new file mode 100644 index 00000000..9d7d7918 --- /dev/null +++ b/ml/pipeline/stages/speech_08_compute_spectrograms.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import argparse +import asyncio +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from pipeline.core.manifest import ManifestStore +from pipeline.io.audio_io import LibrosaAudioReader +from pipeline.speech.spectrogram_stage import SpectrogramStage +from pipeline.stages import conventions +from pipeline.stages.params import PipelineParams + +_PROJECT_ROOT = Path(__file__).parents[2] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Compute mel spectrograms from WAV audio samples" + ) + parser.add_argument("--input-manifest-dir", required=True, type=Path) + parser.add_argument("--output-dir", required=True, type=Path) + args = parser.parse_args() + + params = PipelineParams.load(conventions.params_path(_PROJECT_ROOT)) + + store = ManifestStore() + input_manifest = store.read(conventions.manifest_path(args.input_manifest_dir)) + + args.output_dir.mkdir(parents=True, exist_ok=True) + + stage = SpectrogramStage( + output_dir=args.output_dir, + manifest_store=store, + audio_reader=LibrosaAudioReader(), + input_dir=args.input_manifest_dir, + sample_rate=params.sample_rate, + n_mels=params.compute_spectrograms.n_mels, + time_steps=params.compute_spectrograms.time_steps, + ) + + asyncio.run( + stage.transform( + input_manifest, + conventions.manifest_path(args.output_dir), + ) + ) + + +if __name__ == "__main__": + main() diff --git a/ml/pipeline/stages/speech_09_create_set_manifests.py b/ml/pipeline/stages/speech_09_create_set_manifests.py new file mode 100644 index 00000000..5bbf875b --- /dev/null +++ b/ml/pipeline/stages/speech_09_create_set_manifests.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from pipeline.core.manifest import ManifestStore +from pipeline.speech.set_splitter import SetManifestSplitter +from pipeline.stages import conventions +from pipeline.stages.params import PipelineParams + +_PROJECT_ROOT = Path(__file__).parents[2] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Split augmented manifest into train/val/test sets" + ) + parser.add_argument("--input-manifest-dir", required=True, type=Path) + parser.add_argument("--output-dir", required=True, type=Path) + args = parser.parse_args() + + params = PipelineParams.load(conventions.params_path(_PROJECT_ROOT)) + + store = ManifestStore() + input_manifest = store.read(conventions.manifest_path(args.input_manifest_dir)) + + args.output_dir.mkdir(parents=True, exist_ok=True) + + splitter = SetManifestSplitter( + output_dir=args.output_dir, + manifest_store=store, + train_pct=params.create_set_manifests.train_pct, + val_pct=params.create_set_manifests.val_pct, + test_pct=params.create_set_manifests.test_pct, + ) + + splitter.split(input_manifest) + + +if __name__ == "__main__": + main() diff --git a/ml/test/pipeline/speech/test_set_splitter.py b/ml/test/pipeline/speech/test_set_splitter.py new file mode 100644 index 00000000..5840d80f --- /dev/null +++ b/ml/test/pipeline/speech/test_set_splitter.py @@ -0,0 +1,165 @@ +"""Unit tests for SetManifestSplitter.""" + +from __future__ import annotations + +import hashlib +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + +from pipeline.core.manifest import Manifest, ManifestStore +from pipeline.core.sample import AudioSample +from pipeline.speech.set_splitter import SetManifestSplitter +from pipeline.stages import conventions + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_audio_sample(sample_id: str, transcript: str = "TV_ON") -> AudioSample: + content = f"{sample_id}:audio" + content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest() + return AudioSample( + id=sample_id, + seed=0, + content_hash=content_hash, + path=Path(f"{sample_id}.wav"), + parent_content_hash="parent_hash", + transcript=transcript, + applied_values={}, + ) + + +def _make_manifest(n: int) -> Manifest[AudioSample]: + return Manifest([_make_audio_sample(f"sample_{i:04d}") for i in range(n)]) + + +def _make_splitter( + output_dir: Path, + *, + train_pct: int = 80, + val_pct: int = 10, + test_pct: int = 10, +) -> SetManifestSplitter: + return SetManifestSplitter( + output_dir=output_dir, + manifest_store=ManifestStore(), + train_pct=train_pct, + val_pct=val_pct, + test_pct=test_pct, + ) + + +# --------------------------------------------------------------------------- +# TestSplit +# --------------------------------------------------------------------------- + + +class TestSplit: + def test_writes_train_manifest(self, tmp_path: Path) -> None: + splitter = _make_splitter(tmp_path) + manifest = _make_manifest(10) + splitter.split(manifest) + assert conventions.split_manifest_path(tmp_path, "train").exists() + + def test_writes_val_manifest(self, tmp_path: Path) -> None: + splitter = _make_splitter(tmp_path) + manifest = _make_manifest(10) + splitter.split(manifest) + assert conventions.split_manifest_path(tmp_path, "val").exists() + + def test_writes_test_manifest(self, tmp_path: Path) -> None: + splitter = _make_splitter(tmp_path) + manifest = _make_manifest(10) + splitter.split(manifest) + assert conventions.split_manifest_path(tmp_path, "test").exists() + + def test_total_samples_equals_input(self, tmp_path: Path) -> None: + splitter = _make_splitter(tmp_path, train_pct=70, val_pct=20, test_pct=10) + manifest = _make_manifest(100) + splitter.split(manifest) + store = ManifestStore() + train = store.read(conventions.split_manifest_path(tmp_path, "train")) + val = store.read(conventions.split_manifest_path(tmp_path, "val")) + test = store.read(conventions.split_manifest_path(tmp_path, "test")) + total = len(train.samples) + len(val.samples) + len(test.samples) + assert total == 100 + + def test_sample_ids_preserved(self, tmp_path: Path) -> None: + splitter = _make_splitter(tmp_path) + manifest = _make_manifest(10) + original_ids = {s.id for s in manifest.samples} + splitter.split(manifest) + store = ManifestStore() + all_ids: set[str] = set() + for split in ["train", "val", "test"]: + m = store.read(conventions.split_manifest_path(tmp_path, split)) + for s in m.samples: + all_ids.add(s.id) + assert all_ids == original_ids + + def test_no_duplicate_samples_across_splits(self, tmp_path: Path) -> None: + splitter = _make_splitter(tmp_path) + manifest = _make_manifest(20) + splitter.split(manifest) + store = ManifestStore() + all_ids: list[str] = [] + for split in ["train", "val", "test"]: + m = store.read(conventions.split_manifest_path(tmp_path, split)) + all_ids.extend(s.id for s in m.samples) + assert len(all_ids) == len(set(all_ids)) + + def test_train_set_largest(self, tmp_path: Path) -> None: + splitter = _make_splitter(tmp_path, train_pct=80, val_pct=10, test_pct=10) + manifest = _make_manifest(100) + splitter.split(manifest) + store = ManifestStore() + train = store.read(conventions.split_manifest_path(tmp_path, "train")) + val = store.read(conventions.split_manifest_path(tmp_path, "val")) + test = store.read(conventions.split_manifest_path(tmp_path, "test")) + assert len(train.samples) > len(val.samples) + assert len(train.samples) > len(test.samples) + + def test_percentages_must_sum_to_100(self, tmp_path: Path) -> None: + import pytest + with pytest.raises(ValueError, match="100"): + SetManifestSplitter( + output_dir=tmp_path, + manifest_store=ManifestStore(), + train_pct=70, + val_pct=20, + test_pct=5, + ) + + def test_shuffles_with_seed_42(self, tmp_path: Path) -> None: + """Two separate runs produce the same split (seed=42 is fixed).""" + manifest = _make_manifest(30) + store = ManifestStore() + + out1 = tmp_path / "run1" + out1.mkdir() + _make_splitter(out1).split(manifest) + train1 = [s.id for s in store.read(conventions.split_manifest_path(out1, "train")).samples] + + out2 = tmp_path / "run2" + out2.mkdir() + _make_splitter(out2).split(manifest) + train2 = [s.id for s in store.read(conventions.split_manifest_path(out2, "train")).samples] + + assert train1 == train2 + + def test_split_sizes_proportional_to_percentages(self, tmp_path: Path) -> None: + splitter = _make_splitter(tmp_path, train_pct=60, val_pct=20, test_pct=20) + manifest = _make_manifest(100) + splitter.split(manifest) + store = ManifestStore() + train = store.read(conventions.split_manifest_path(tmp_path, "train")) + val = store.read(conventions.split_manifest_path(tmp_path, "val")) + test = store.read(conventions.split_manifest_path(tmp_path, "test")) + assert len(train.samples) == 60 + assert len(val.samples) == 20 + assert len(test.samples) == 20 diff --git a/ml/test/pipeline/speech/test_spectrogram_stage.py b/ml/test/pipeline/speech/test_spectrogram_stage.py new file mode 100644 index 00000000..77ea3aa5 --- /dev/null +++ b/ml/test/pipeline/speech/test_spectrogram_stage.py @@ -0,0 +1,211 @@ +"""Unit tests for SpectrogramStage.""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import sys +from pathlib import Path +from typing import Any + +import numpy as np + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + +from pipeline.core.manifest import Manifest, ManifestStore +from pipeline.core.randomization import VariationGenerator +from pipeline.core.sample import AudioSample, SampleSpectrogram +from pipeline.io.audio_io import AudioData +from pipeline.speech.spectrogram_stage import SpectrogramStage + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_audio_sample( + sample_id: str = "TV_ON_Jenny_r100", + transcript: str = "TV_ON", + sample_rate: int = 16000, + duration_s: float = 0.5, +) -> AudioSample: + content = f"{sample_id}:audio" + content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest() + return AudioSample( + id=sample_id, + seed=0, + content_hash=content_hash, + path=Path(f"{sample_id}.wav"), + parent_content_hash="parent_hash", + transcript=transcript, + applied_values={}, + ) + + +class _RecordingAudioReader: + """Stub AudioReader that returns silence at a fixed sample rate.""" + + def __init__(self, sample_rate: int = 16000, duration_s: float = 0.5) -> None: + self.calls: list[Path] = [] + self._sample_rate = sample_rate + self._duration_s = duration_s + + async def read(self, path: Path) -> AudioData: + self.calls.append(path) + n = int(self._sample_rate * self._duration_s) + return AudioData(samples=np.zeros(n, dtype=np.float32), sample_rate=self._sample_rate) + + +def _make_stage( + output_dir: Path, + *, + audio_reader: _RecordingAudioReader | None = None, + n_mels: int = 80, + time_steps: int = 200, + sample_rate: int = 16000, + duration_s: float = 0.5, + input_dir: Path | None = None, +) -> tuple[SpectrogramStage, _RecordingAudioReader]: + if audio_reader is None: + audio_reader = _RecordingAudioReader(sample_rate=sample_rate, duration_s=duration_s) + if input_dir is None: + input_dir = output_dir + stage = SpectrogramStage( + output_dir=output_dir, + manifest_store=ManifestStore(), + audio_reader=audio_reader, + input_dir=input_dir, + sample_rate=sample_rate, + n_mels=n_mels, + time_steps=time_steps, + ) + return stage, audio_reader + + +# --------------------------------------------------------------------------- +# TestIsDeterministic +# --------------------------------------------------------------------------- + + +class TestIsDeterministic: + def test_is_deterministic_true(self, tmp_path: Path) -> None: + stage, _ = _make_stage(tmp_path) + assert stage._is_deterministic is True + + +# --------------------------------------------------------------------------- +# TestGetAppliedValues +# --------------------------------------------------------------------------- + + +class TestGetAppliedValues: + def test_get_applied_values_returns_empty_dict(self, tmp_path: Path) -> None: + stage, _ = _make_stage(tmp_path) + sample = _make_audio_sample() + result = stage._get_applied_values(sample, VariationGenerator(0)) + assert result == {} + + +# --------------------------------------------------------------------------- +# TestDeriveId +# --------------------------------------------------------------------------- + + +class TestDeriveId: + def test_derive_id_returns_input_sample_id(self, tmp_path: Path) -> None: + stage, _ = _make_stage(tmp_path) + sample = _make_audio_sample(sample_id="TV_ON_Jenny_r100") + result = stage._derive_id(sample, {}) + assert result == "TV_ON_Jenny_r100" + + def test_derive_id_returns_unchanged_id_for_different_samples(self, tmp_path: Path) -> None: + stage, _ = _make_stage(tmp_path) + sample = _make_audio_sample(sample_id="VOLUME_UP_Aria_r110") + result = stage._derive_id(sample, {}) + assert result == "VOLUME_UP_Aria_r110" + + +# --------------------------------------------------------------------------- +# TestGenerateOutput +# --------------------------------------------------------------------------- + + +class TestGenerateOutput: + def test_output_npy_file_written(self, tmp_path: Path) -> None: + stage, _ = _make_stage(tmp_path, n_mels=40, time_steps=100) + sample = _make_audio_sample() + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + assert (tmp_path / f"{sample.id}.npy").exists() + + def test_output_shape_is_n_mels_by_time_steps(self, tmp_path: Path) -> None: + n_mels = 40 + time_steps = 100 + stage, _ = _make_stage(tmp_path, n_mels=n_mels, time_steps=time_steps) + sample = _make_audio_sample() + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + arr = np.load(tmp_path / f"{sample.id}.npy") + assert arr.shape == (n_mels, time_steps) + + def test_output_short_audio_zero_padded_to_time_steps(self, tmp_path: Path) -> None: + """Short audio (fewer frames than time_steps) is zero-padded.""" + stage, _ = _make_stage(tmp_path, n_mels=40, time_steps=500, duration_s=0.1) + sample = _make_audio_sample() + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + arr = np.load(tmp_path / f"{sample.id}.npy") + assert arr.shape[1] == 500 + + def test_output_long_audio_truncated_to_time_steps(self, tmp_path: Path) -> None: + """Long audio (more frames than time_steps) is truncated from the end.""" + stage, _ = _make_stage(tmp_path, n_mels=40, time_steps=10, duration_s=2.0) + sample = _make_audio_sample() + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + arr = np.load(tmp_path / f"{sample.id}.npy") + assert arr.shape[1] == 10 + + def test_output_sample_parent_id_equals_input_sample_id(self, tmp_path: Path) -> None: + stage, _ = _make_stage(tmp_path) + sample = _make_audio_sample(sample_id="MY_SAMPLE") + result = asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + assert result.samples[0].parent_id == "MY_SAMPLE" + + def test_output_sample_id_equals_input_sample_id(self, tmp_path: Path) -> None: + stage, _ = _make_stage(tmp_path) + sample = _make_audio_sample(sample_id="MY_SAMPLE") + result = asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + assert result.samples[0].id == "MY_SAMPLE" + + def test_output_sample_transcript_preserved(self, tmp_path: Path) -> None: + stage, _ = _make_stage(tmp_path) + sample = _make_audio_sample(transcript="VOLUME_UP") + result = asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + assert result.samples[0].transcript == "VOLUME_UP" + + def test_output_sample_is_sample_spectrogram(self, tmp_path: Path) -> None: + stage, _ = _make_stage(tmp_path) + sample = _make_audio_sample() + result = asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + assert isinstance(result.samples[0], SampleSpectrogram) + + def test_audio_reader_called_with_input_path(self, tmp_path: Path) -> None: + reader = _RecordingAudioReader() + input_dir = tmp_path / "input" + input_dir.mkdir() + stage, _ = _make_stage(tmp_path, audio_reader=reader, input_dir=input_dir) + sample = _make_audio_sample(sample_id="TV_ON_Jenny_r100") + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + assert reader.calls[0] == input_dir / "TV_ON_Jenny_r100.wav" + + def test_second_run_skips_file_write(self, tmp_path: Path) -> None: + """Deterministic stage: second run with same input reads manifest and skips.""" + reader = _RecordingAudioReader() + stage, _ = _make_stage(tmp_path, audio_reader=reader) + sample = _make_audio_sample() + + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + calls_after_first = len(reader.calls) + + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + # Second run should not call reader again (skip path) + assert len(reader.calls) == calls_after_first diff --git a/ml/test/pipeline/speech/test_token_stage.py b/ml/test/pipeline/speech/test_token_stage.py new file mode 100644 index 00000000..028704f8 --- /dev/null +++ b/ml/test/pipeline/speech/test_token_stage.py @@ -0,0 +1,248 @@ +"""Unit tests for TokenStage.""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + +from pipeline.core.manifest import Manifest, ManifestStore +from pipeline.core.randomization import VariationGenerator +from pipeline.core.sample import AudioSample, SampleTokens +from pipeline.intent.vocab_computer import VocabResult +from pipeline.speech.token_stage import TokenStage + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_audio_sample( + sample_id: str = "TV_ON_Jenny_r100", + transcript: str = "TV_ON", +) -> AudioSample: + content = f"{sample_id}:audio" + content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest() + return AudioSample( + id=sample_id, + seed=0, + content_hash=content_hash, + path=Path(f"{sample_id}.wav"), + parent_content_hash="parent_hash", + transcript=transcript, + applied_values={}, + ) + + +def _make_vocab( + phoneme_list: list[str] | None = None, + words_to_phonemes: dict[str, list[str]] | None = None, +) -> VocabResult: + if phoneme_list is None: + phoneme_list = ["AH", "N", "T", "V"] + if words_to_phonemes is None: + words_to_phonemes = {"TV_ON": ["T", "V", "AH", "N"]} + return VocabResult( + phoneme_list=phoneme_list, + words_to_phonemes=words_to_phonemes, + ctc_blank_idx=len(phoneme_list), + ) + + +def _make_stage( + output_dir: Path, + *, + vocab: VocabResult | None = None, + input_token_length: int = 20, +) -> TokenStage: + if vocab is None: + vocab = _make_vocab() + return TokenStage( + output_dir=output_dir, + manifest_store=ManifestStore(), + vocab=vocab, + input_token_length=input_token_length, + ) + + +# --------------------------------------------------------------------------- +# TestIsDeterministic +# --------------------------------------------------------------------------- + + +class TestIsDeterministic: + def test_is_deterministic_true(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path) + assert stage._is_deterministic is True + + +# --------------------------------------------------------------------------- +# TestGetAppliedValues +# --------------------------------------------------------------------------- + + +class TestGetAppliedValues: + def test_get_applied_values_returns_empty_dict(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path) + sample = _make_audio_sample() + result = stage._get_applied_values(sample, VariationGenerator(0)) + assert result == {} + + +# --------------------------------------------------------------------------- +# TestDeriveId +# --------------------------------------------------------------------------- + + +class TestDeriveId: + def test_derive_id_returns_input_sample_id(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path) + sample = _make_audio_sample(sample_id="TV_ON_Jenny_r100") + result = stage._derive_id(sample, {}) + assert result == "TV_ON_Jenny_r100" + + def test_derive_id_returns_unchanged_id_for_different_samples(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path) + sample = _make_audio_sample(sample_id="VOLUME_UP_Aria_r110") + result = stage._derive_id(sample, {}) + assert result == "VOLUME_UP_Aria_r110" + + +# --------------------------------------------------------------------------- +# TestGenerateOutput +# --------------------------------------------------------------------------- + + +class TestGenerateOutput: + def test_output_json_file_written(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path) + sample = _make_audio_sample() + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + assert (tmp_path / f"{sample.id}.json").exists() + + def test_output_json_has_phonemes_key(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path) + sample = _make_audio_sample() + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + data = json.loads((tmp_path / f"{sample.id}.json").read_text()) + assert "phonemes" in data + + def test_output_json_has_tokens_key(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path) + sample = _make_audio_sample() + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + data = json.loads((tmp_path / f"{sample.id}.json").read_text()) + assert "tokens" in data + + def test_tokens_padded_to_input_token_length(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path, input_token_length=20) + sample = _make_audio_sample(transcript="TV_ON") + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + data = json.loads((tmp_path / f"{sample.id}.json").read_text()) + assert len(data["tokens"]) == 20 + + def test_phonemes_padded_to_input_token_length(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path, input_token_length=20) + sample = _make_audio_sample(transcript="TV_ON") + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + data = json.loads((tmp_path / f"{sample.id}.json").read_text()) + assert len(data["phonemes"]) == 20 + + def test_tokens_padding_value_is_zero(self, tmp_path: Path) -> None: + """Padding uses index 0.""" + vocab = _make_vocab( + phoneme_list=["AH", "N", "T", "V"], + words_to_phonemes={"TV_ON": ["T", "V"]}, + ) + stage = _make_stage(tmp_path, vocab=vocab, input_token_length=10) + sample = _make_audio_sample(transcript="TV_ON") + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + data = json.loads((tmp_path / f"{sample.id}.json").read_text()) + # Only 2 phonemes for "TV_ON", remaining 8 should be 0 + assert data["tokens"][2:] == [0] * 8 + + def test_tokens_are_phoneme_indices(self, tmp_path: Path) -> None: + """Token values are indices into phoneme_list.""" + phoneme_list = ["AH", "N", "T", "V"] + vocab = _make_vocab( + phoneme_list=phoneme_list, + words_to_phonemes={"TV_ON": ["T", "V", "AH", "N"]}, + ) + stage = _make_stage(tmp_path, vocab=vocab, input_token_length=10) + sample = _make_audio_sample(transcript="TV_ON") + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + data = json.loads((tmp_path / f"{sample.id}.json").read_text()) + expected = [phoneme_list.index("T"), phoneme_list.index("V"), + phoneme_list.index("AH"), phoneme_list.index("N")] + assert data["tokens"][:4] == expected + + def test_output_sample_parent_id_equals_input_sample_id(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path) + sample = _make_audio_sample(sample_id="MY_SAMPLE") + result = asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + assert result.samples[0].parent_id == "MY_SAMPLE" + + def test_output_sample_id_equals_input_sample_id(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path) + sample = _make_audio_sample(sample_id="MY_SAMPLE") + result = asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + assert result.samples[0].id == "MY_SAMPLE" + + def test_output_sample_transcript_preserved(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path) + sample = _make_audio_sample(transcript="TV_ON") + result = asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + assert result.samples[0].transcript == "TV_ON" + + def test_output_sample_is_sample_tokens(self, tmp_path: Path) -> None: + stage = _make_stage(tmp_path) + sample = _make_audio_sample() + result = asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + assert isinstance(result.samples[0], SampleTokens) + + def test_phonemes_list_contains_phoneme_strings(self, tmp_path: Path) -> None: + phoneme_list = ["AH", "N", "T", "V"] + vocab = _make_vocab( + phoneme_list=phoneme_list, + words_to_phonemes={"TV_ON": ["T", "V"]}, + ) + stage = _make_stage(tmp_path, vocab=vocab, input_token_length=5) + sample = _make_audio_sample(transcript="TV_ON") + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + data = json.loads((tmp_path / f"{sample.id}.json").read_text()) + assert data["phonemes"][0] == "T" + assert data["phonemes"][1] == "V" + + def test_phonemes_padding_is_empty_string(self, tmp_path: Path) -> None: + """Phoneme padding uses empty string for slots after the actual phonemes.""" + vocab = _make_vocab( + phoneme_list=["AH", "N", "T", "V"], + words_to_phonemes={"TV_ON": ["T", "V"]}, + ) + stage = _make_stage(tmp_path, vocab=vocab, input_token_length=5) + sample = _make_audio_sample(transcript="TV_ON") + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + data = json.loads((tmp_path / f"{sample.id}.json").read_text()) + # slots 2,3,4 are padding + assert data["phonemes"][2:] == ["", "", ""] + + def test_multi_word_transcript_concatenates_phonemes(self, tmp_path: Path) -> None: + phoneme_list = ["AH", "N", "T", "V", "AH0", "P"] + vocab = _make_vocab( + phoneme_list=phoneme_list, + words_to_phonemes={ + "TV_ON": ["T", "V"], + "VOLUME_UP": ["V", "AH0", "P"], + }, + ) + stage = _make_stage(tmp_path, vocab=vocab, input_token_length=10) + sample = _make_audio_sample(transcript="TV_ON VOLUME_UP") + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) + data = json.loads((tmp_path / f"{sample.id}.json").read_text()) + # TV_ON=[T,V] + VOLUME_UP=[V,AH0,P] = 5 phonemes + assert data["phonemes"][:5] == ["T", "V", "V", "AH0", "P"] diff --git a/ml/test/pipeline/stages/test_params.py b/ml/test/pipeline/stages/test_params.py index 39a09a8d..300f9971 100644 --- a/ml/test/pipeline/stages/test_params.py +++ b/ml/test/pipeline/stages/test_params.py @@ -14,6 +14,9 @@ AddBackgroundNoiseParams, AddDelaysParams, AddMicNoiseParams, + ComputeSpectrogramsParams, + ComputeTokensParams, + CreateSetManifestsParams, GeneratePhraseParams, GenerateSamplesParams, PipelineParams, @@ -62,6 +65,18 @@ def _write_params(path: Path, data: dict) -> Path: "amplitude_min": 0.001, "amplitude_max": 0.05, }, + "compute_spectrograms": { + "n_mels": 80, + "time_steps": 400, + }, + "compute_tokens": { + "input_token_length": 50, + }, + "create_set_manifests": { + "train_pct": 80, + "val_pct": 10, + "test_pct": 10, + }, }, } @@ -253,3 +268,98 @@ def test_missing_add_mic_noise_stage_raises(self, tmp_path: Path) -> None: params_file = _write_params(tmp_path, data) with pytest.raises(KeyError): PipelineParams.load(params_file) + + +class TestComputeSpectrogramsParamsLoad: + def test_loads_compute_spectrograms_fields(self, tmp_path: Path) -> None: + params_file = _write_params(tmp_path, _VALID_DATA) + params = PipelineParams.load(params_file) + + cs = params.compute_spectrograms + assert cs.n_mels == 80 + assert cs.time_steps == 400 + + def test_compute_spectrograms_is_correct_type(self, tmp_path: Path) -> None: + params_file = _write_params(tmp_path, _VALID_DATA) + params = PipelineParams.load(params_file) + + assert isinstance(params.compute_spectrograms, ComputeSpectrogramsParams) + + def test_compute_spectrograms_fields_are_ints(self, tmp_path: Path) -> None: + params_file = _write_params(tmp_path, _VALID_DATA) + params = PipelineParams.load(params_file) + + cs = params.compute_spectrograms + assert isinstance(cs.n_mels, int) + assert isinstance(cs.time_steps, int) + + def test_missing_compute_spectrograms_stage_raises(self, tmp_path: Path) -> None: + import copy + data = copy.deepcopy(_VALID_DATA) + del data["stages"]["compute_spectrograms"] + params_file = _write_params(tmp_path, data) + with pytest.raises(KeyError): + PipelineParams.load(params_file) + + +class TestComputeTokensParamsLoad: + def test_loads_compute_tokens_fields(self, tmp_path: Path) -> None: + params_file = _write_params(tmp_path, _VALID_DATA) + params = PipelineParams.load(params_file) + + ct = params.compute_tokens + assert ct.input_token_length == 50 + + def test_compute_tokens_is_correct_type(self, tmp_path: Path) -> None: + params_file = _write_params(tmp_path, _VALID_DATA) + params = PipelineParams.load(params_file) + + assert isinstance(params.compute_tokens, ComputeTokensParams) + + def test_compute_tokens_field_is_int(self, tmp_path: Path) -> None: + params_file = _write_params(tmp_path, _VALID_DATA) + params = PipelineParams.load(params_file) + + assert isinstance(params.compute_tokens.input_token_length, int) + + def test_missing_compute_tokens_stage_raises(self, tmp_path: Path) -> None: + import copy + data = copy.deepcopy(_VALID_DATA) + del data["stages"]["compute_tokens"] + params_file = _write_params(tmp_path, data) + with pytest.raises(KeyError): + PipelineParams.load(params_file) + + +class TestCreateSetManifestsParamsLoad: + def test_loads_create_set_manifests_fields(self, tmp_path: Path) -> None: + params_file = _write_params(tmp_path, _VALID_DATA) + params = PipelineParams.load(params_file) + + csm = params.create_set_manifests + assert csm.train_pct == 80 + assert csm.val_pct == 10 + assert csm.test_pct == 10 + + def test_create_set_manifests_is_correct_type(self, tmp_path: Path) -> None: + params_file = _write_params(tmp_path, _VALID_DATA) + params = PipelineParams.load(params_file) + + assert isinstance(params.create_set_manifests, CreateSetManifestsParams) + + def test_create_set_manifests_fields_are_ints(self, tmp_path: Path) -> None: + params_file = _write_params(tmp_path, _VALID_DATA) + params = PipelineParams.load(params_file) + + csm = params.create_set_manifests + assert isinstance(csm.train_pct, int) + assert isinstance(csm.val_pct, int) + assert isinstance(csm.test_pct, int) + + def test_missing_create_set_manifests_stage_raises(self, tmp_path: Path) -> None: + import copy + data = copy.deepcopy(_VALID_DATA) + del data["stages"]["create_set_manifests"] + params_file = _write_params(tmp_path, data) + with pytest.raises(KeyError): + PipelineParams.load(params_file) From deeea1ec00ab3e77f13ef8aab5f9d9a95c30b35c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 19 Jun 2026 04:28:45 +0000 Subject: [PATCH 2/3] ADR-228: address PR review comments - padding, unknown words, sample_rate, unused imports --- ml/pipeline/speech/set_splitter.py | 1 - ml/pipeline/speech/spectrogram_stage.py | 2 +- ml/pipeline/speech/token_stage.py | 35 +++++++++++++------- ml/test/pipeline/speech/test_set_splitter.py | 1 - ml/test/pipeline/speech/test_token_stage.py | 24 +++++++++++--- 5 files changed, 43 insertions(+), 20 deletions(-) diff --git a/ml/pipeline/speech/set_splitter.py b/ml/pipeline/speech/set_splitter.py index 36204153..b31b444d 100644 --- a/ml/pipeline/speech/set_splitter.py +++ b/ml/pipeline/speech/set_splitter.py @@ -9,7 +9,6 @@ import random from pathlib import Path -from typing import Any from pipeline.core.manifest import Manifest, ManifestStore from pipeline.core.sample import AudioSample diff --git a/ml/pipeline/speech/spectrogram_stage.py b/ml/pipeline/speech/spectrogram_stage.py index 75f250dc..912a8a54 100644 --- a/ml/pipeline/speech/spectrogram_stage.py +++ b/ml/pipeline/speech/spectrogram_stage.py @@ -72,7 +72,7 @@ async def _generate_output( None, lambda: librosa.feature.melspectrogram( y=audio.samples, - sr=self._sample_rate, + sr=audio.sample_rate, n_mels=self._n_mels, ), ) diff --git a/ml/pipeline/speech/token_stage.py b/ml/pipeline/speech/token_stage.py index 27e49f8f..76127699 100644 --- a/ml/pipeline/speech/token_stage.py +++ b/ml/pipeline/speech/token_stage.py @@ -1,7 +1,8 @@ """TokenStage: compute phoneme tokens from AudioSample transcripts. Deterministic stage (seed=0). Writes {id}.json with phonemes and tokens -padded to input_token_length. Token index 0 is used for padding. +padded to input_token_length. Padding uses vocab.ctc_blank_idx (== len(phoneme_list)), +which cannot collide with any real phoneme index. """ from __future__ import annotations @@ -23,9 +24,11 @@ class TokenStage(ModifierStage[AudioSample, SampleTokens]): Writes {id}.json with: phonemes: list[str] — phoneme strings padded with "" to input_token_length - tokens: list[int] — phoneme indices padded with 0 to input_token_length + tokens: list[int] — phoneme indices padded with vocab.ctc_blank_idx to + input_token_length. ctc_blank_idx == len(phoneme_list), + so it cannot collide with any real phoneme index. - Token index 0 is used for padding (index into phoneme_list). + Raises KeyError if a transcript word is missing from vocab.words_to_phonemes. """ _is_deterministic: bool = True @@ -57,19 +60,27 @@ async def _generate_output( applied_values: dict[str, Any], parent_content_hash: str, ) -> SampleTokens: - # Collect phonemes for each word in the transcript + # Build index lookup once to avoid repeated list.index() calls + phoneme_to_idx = {p: i for i, p in enumerate(self._vocab.phoneme_list)} + + # Collect phonemes for each word in the transcript; fail fast on unknown words phonemes: list[str] = [] for word in input_sample.transcript.split(): - word_phonemes = self._vocab.words_to_phonemes.get(word, []) - phonemes.extend(word_phonemes) - - # Compute token indices (1-based lookup into phoneme_list) - tokens = [self._vocab.phoneme_list.index(p) for p in phonemes] - - # Pad to input_token_length + if word not in self._vocab.words_to_phonemes: + raise KeyError( + f"Word '{word}' in transcript is not in vocab.words_to_phonemes" + ) + phonemes.extend(self._vocab.words_to_phonemes[word]) + + # Compute token indices + tokens = [phoneme_to_idx[p] for p in phonemes] + + # Pad to input_token_length using ctc_blank_idx (== len(phoneme_list)), + # which cannot collide with any real phoneme index + pad_idx = self._vocab.ctc_blank_idx pad_len = max(0, self._input_token_length - len(phonemes)) padded_phonemes = phonemes + [""] * pad_len - padded_tokens = tokens + [0] * pad_len + padded_tokens = tokens + [pad_idx] * pad_len # Truncate if longer than input_token_length padded_phonemes = padded_phonemes[: self._input_token_length] diff --git a/ml/test/pipeline/speech/test_set_splitter.py b/ml/test/pipeline/speech/test_set_splitter.py index 5840d80f..a5e7362b 100644 --- a/ml/test/pipeline/speech/test_set_splitter.py +++ b/ml/test/pipeline/speech/test_set_splitter.py @@ -3,7 +3,6 @@ from __future__ import annotations import hashlib -import json import sys from pathlib import Path diff --git a/ml/test/pipeline/speech/test_token_stage.py b/ml/test/pipeline/speech/test_token_stage.py index 028704f8..724085e2 100644 --- a/ml/test/pipeline/speech/test_token_stage.py +++ b/ml/test/pipeline/speech/test_token_stage.py @@ -153,18 +153,32 @@ def test_phonemes_padded_to_input_token_length(self, tmp_path: Path) -> None: data = json.loads((tmp_path / f"{sample.id}.json").read_text()) assert len(data["phonemes"]) == 20 - def test_tokens_padding_value_is_zero(self, tmp_path: Path) -> None: - """Padding uses index 0.""" + def test_tokens_padding_value_is_ctc_blank_idx(self, tmp_path: Path) -> None: + """Padding uses vocab.ctc_blank_idx (== len(phoneme_list)), not 0.""" + phoneme_list = ["AH", "N", "T", "V"] vocab = _make_vocab( - phoneme_list=["AH", "N", "T", "V"], + phoneme_list=phoneme_list, words_to_phonemes={"TV_ON": ["T", "V"]}, ) stage = _make_stage(tmp_path, vocab=vocab, input_token_length=10) sample = _make_audio_sample(transcript="TV_ON") asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) data = json.loads((tmp_path / f"{sample.id}.json").read_text()) - # Only 2 phonemes for "TV_ON", remaining 8 should be 0 - assert data["tokens"][2:] == [0] * 8 + # Only 2 phonemes for "TV_ON"; remaining 8 should equal ctc_blank_idx == 4 + pad_idx = vocab.ctc_blank_idx + assert data["tokens"][2:] == [pad_idx] * 8 + + def test_unknown_transcript_word_raises_key_error(self, tmp_path: Path) -> None: + """A word not in vocab.words_to_phonemes raises KeyError immediately.""" + import pytest + vocab = _make_vocab( + phoneme_list=["AH", "N", "T", "V"], + words_to_phonemes={"TV_ON": ["T", "V"]}, + ) + stage = _make_stage(tmp_path, vocab=vocab, input_token_length=10) + sample = _make_audio_sample(transcript="UNKNOWN_WORD") + with pytest.raises(KeyError, match="UNKNOWN_WORD"): + asyncio.run(stage.transform(Manifest([sample]), tmp_path / "manifest.json")) def test_tokens_are_phoneme_indices(self, tmp_path: Path) -> None: """Token values are indices into phoneme_list.""" From 38c474e7b3a700e914111d4d5e7ef7be97e35ccf Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 19 Jun 2026 04:29:47 +0000 Subject: [PATCH 3/3] Move pytest import to module level in test_token_stage.py --- ml/test/pipeline/speech/test_token_stage.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ml/test/pipeline/speech/test_token_stage.py b/ml/test/pipeline/speech/test_token_stage.py index 724085e2..84c8f157 100644 --- a/ml/test/pipeline/speech/test_token_stage.py +++ b/ml/test/pipeline/speech/test_token_stage.py @@ -8,6 +8,8 @@ import sys from pathlib import Path +import pytest + sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) from pipeline.core.manifest import Manifest, ManifestStore @@ -170,7 +172,6 @@ def test_tokens_padding_value_is_ctc_blank_idx(self, tmp_path: Path) -> None: def test_unknown_transcript_word_raises_key_error(self, tmp_path: Path) -> None: """A word not in vocab.words_to_phonemes raises KeyError immediately.""" - import pytest vocab = _make_vocab( phoneme_list=["AH", "N", "T", "V"], words_to_phonemes={"TV_ON": ["T", "V"]},