Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ml/params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,12 @@ stages:
vary_probability: 0.5
amplitude_min: 0.001
amplitude_max: 0.05
compute_spectrograms:
n_mels: 80
time_steps: 400
compute_tokens:
input_token_length: 50
create_set_manifests:
train_pct: 80
val_pct: 10
test_pct: 10
5 changes: 3 additions & 2 deletions ml/pipeline/speech/_doc_speech.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ produces a derived `AudioSample` manifest with per-sample variation applied.
| [`delay_stage.py`](delay_stage.py) | `DelayAugmentor` | `AudioSample → AudioSample` (silence padding) |
| [`background_noise_stage.py`](background_noise_stage.py) | `BackgroundNoiseAugmentor` | `AudioSample → AudioSample` (environmental noise mix) |
| [`mic_noise_stage.py`](mic_noise_stage.py) | `MicrophoneNoiseAugmentor` | `AudioSample → AudioSample` (Gaussian mic noise) |

Planned stages (not yet implemented): `token_stage.py`, `spectrogram_stage.py`.
| [`token_stage.py`](token_stage.py) | `TokenStage` | `AudioSample → SampleTokens` (phoneme token sequences) |
| [`spectrogram_stage.py`](spectrogram_stage.py) | `SpectrogramStage` | `AudioSample → SampleSpectrogram` (mel spectrogram NPY) |
| [`set_splitter.py`](set_splitter.py) | `SetManifestSplitter` | Splits augmented manifest into train/val/test |

## Key design decisions

Expand Down
73 changes: 73 additions & 0 deletions ml/pipeline/speech/set_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""SetManifestSplitter: split an augmented manifest into train/val/test sets.

Shuffles with seed=42 for reproducibility. Percentages must sum to 100.
Writes train_manifest.json, val_manifest.json, test_manifest.json via
conventions.split_manifest_path.
"""

from __future__ import annotations

import random
from pathlib import Path

from pipeline.core.manifest import Manifest, ManifestStore
Comment thread
Copilot marked this conversation as resolved.
from pipeline.core.sample import AudioSample
from pipeline.stages import conventions


class SetManifestSplitter:
    """Split a fully-augmented manifest into train/val/test sets.

    The samples are shuffled with a fixed seed (42) so that repeated runs
    over the same manifest produce the same partition, then three manifest
    files are written under ``output_dir``.

    Args:
        output_dir: Directory the split manifests are written to; created
            on demand by :meth:`split`.
        manifest_store: Store used to persist each split manifest.
        train_pct: Percentage of samples assigned to the train set.
        val_pct: Percentage of samples assigned to the val set.
        test_pct: Percentage of samples assigned to the test set.

    Raises:
        ValueError: If any percentage is negative, or if the three
            percentages do not sum to 100.
    """

    def __init__(
        self,
        output_dir: Path,
        manifest_store: ManifestStore,
        train_pct: int,
        val_pct: int,
        test_pct: int,
    ) -> None:
        # A sum check alone lets combinations such as (150, -50, 0) through,
        # which would turn into negative slice bounds in split().
        percentages = {
            "train_pct": train_pct,
            "val_pct": val_pct,
            "test_pct": test_pct,
        }
        negative = [name for name, value in percentages.items() if value < 0]
        if negative:
            raise ValueError(
                f"split percentages must be non-negative, got negative "
                f"value(s) for: {', '.join(negative)}"
            )

        total = train_pct + val_pct + test_pct
        if total != 100:
            raise ValueError(
                f"train_pct + val_pct + test_pct must equal 100, "
                f"got {total}"
            )

        self._output_dir = output_dir
        self._manifest_store = manifest_store
        self._train_pct = train_pct
        self._val_pct = val_pct
        self._test_pct = test_pct

    def split(self, manifest: Manifest[AudioSample]) -> None:
        """Shuffle the samples and write the train/val/test manifest files.

        The train and val sizes are rounded down; the test set receives
        every remaining sample, so the three sets always cover the whole
        input exactly once. The input manifest itself is not modified.
        """
        samples = list(manifest.samples)  # copy: shuffle must not touch the input
        random.Random(42).shuffle(samples)

        n = len(samples)
        n_train = int(n * self._train_pct / 100)
        n_val = int(n * self._val_pct / 100)
        val_end = n_train + n_val

        # Insertion order fixes the write order: train, val, test.
        splits = {
            "train": samples[:n_train],
            "val": samples[n_train:val_end],
            # test takes the remainder so that the sizes add up to n
            "test": samples[val_end:],
        }

        self._output_dir.mkdir(parents=True, exist_ok=True)

        for split_name, split_samples in splits.items():
            self._manifest_store.write(
                Manifest(split_samples),
                conventions.split_manifest_path(self._output_dir, split_name),
            )
103 changes: 103 additions & 0 deletions ml/pipeline/speech/spectrogram_stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""SpectrogramStage: compute mel spectrograms from WAV audio samples.

Deterministic stage (seed=0). Writes {id}.npy of shape (n_mels, time_steps).
Long spectrograms are truncated from the end; short ones are zero-padded.
"""

from __future__ import annotations

import asyncio
from pathlib import Path
from typing import Any

import numpy as np

from pipeline.core.manifest import ManifestStore
from pipeline.core.modifier_stage import ModifierStage
from pipeline.core.randomization import VariationGenerator
from pipeline.core.sample import AudioSample, SampleSpectrogram
from pipeline.io.audio_io import AudioReader
from pipeline.stages import conventions


class SpectrogramStage(ModifierStage[AudioSample, SampleSpectrogram]):
    """Compute mel spectrograms from WAV audio samples.

    For every input sample the stage writes ``{id}.npy`` holding an array of
    shape ``(n_mels, time_steps)``. Spectrograms with more than
    ``time_steps`` frames are truncated from the end; shorter ones are
    zero-padded on the right.

    Args:
        output_dir: Directory the ``.npy`` files are written to.
        manifest_store: Passed through to the ``ModifierStage`` base class.
        audio_reader: Reader used to load each input audio file.
        input_dir: Directory that input sample paths are relative to.
        sample_rate: Stored on the instance but not read by this class.
        n_mels: Number of mel bands requested from librosa.
        time_steps: Exact number of frames (columns) in every output array.

    Raises:
        ValueError: If ``n_mels`` or ``time_steps`` is not positive.
    """

    _is_deterministic: bool = True

    def __init__(
        self,
        output_dir: Path,
        manifest_store: ManifestStore,
        audio_reader: AudioReader,
        input_dir: Path,
        sample_rate: int,
        n_mels: int,
        time_steps: int,
    ) -> None:
        # A non-positive time_steps would make the truncation slice drop
        # frames from the end instead of fixing the width, silently yielding
        # arrays of the wrong shape.
        if n_mels <= 0:
            raise ValueError(f"n_mels must be positive, got {n_mels}")
        if time_steps <= 0:
            raise ValueError(f"time_steps must be positive, got {time_steps}")

        super().__init__(output_dir, manifest_store)
        self._audio_reader = audio_reader
        self._input_dir = input_dir
        # NOTE(review): never read below; the mel computation uses the rate
        # reported by the audio reader. Confirm whether a mismatch between
        # the two should be rejected or resampled.
        self._sample_rate = sample_rate
        self._n_mels = n_mels
        self._time_steps = time_steps

    def _get_applied_values(
        self, sample: AudioSample, generator: VariationGenerator
    ) -> dict[str, Any]:
        """Return no applied values: this stage has no random variation."""
        return {}

    def _derive_id(self, input_sample: AudioSample, applied_values: dict[str, Any]) -> str:
        """Reuse the input sample id as the output id."""
        return input_sample.id

    def _fit_time_axis(self, mel: np.ndarray) -> np.ndarray:
        """Truncate or right-zero-pad *mel* to exactly ``time_steps`` columns."""
        n_frames = mel.shape[1]
        if n_frames >= self._time_steps:
            return mel[:, : self._time_steps]
        missing = self._time_steps - n_frames
        return np.pad(mel, ((0, 0), (0, missing)), mode="constant")

    async def _generate_output(
        self,
        input_sample: AudioSample,
        output_id: str,
        output_seed: int,
        applied_values: dict[str, Any],
        parent_content_hash: str,
    ) -> SampleSpectrogram:
        """Compute, size-normalise and save the spectrogram for one sample.

        Returns:
            A ``SampleSpectrogram`` whose ``path`` is the file name
            ``{output_id}.npy``.
        """
        import librosa  # deferred import to allow unit tests without librosa

        input_path = self._input_dir / input_sample.path
        audio = await self._audio_reader.read(input_path)

        self._output_dir.mkdir(parents=True, exist_ok=True)
        output_path = conventions.sample_file_path(self._output_dir, output_id, "npy")

        def _render_and_save() -> None:
            mel = librosa.feature.melspectrogram(
                y=audio.samples,
                sr=audio.sample_rate,
                n_mels=self._n_mels,
            )
            np.save(str(output_path), self._fit_time_axis(mel))

        # Both the mel computation and the disk write are blocking, so they
        # run together off the event loop.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _render_and_save)

        content_hash = self._compute_content_hash(
            parent_content_hash, output_seed, applied_values
        )

        return SampleSpectrogram(
            id=output_id,
            seed=output_seed,
            content_hash=content_hash,
            path=Path(f"{output_id}.npy"),
            parent_content_hash=parent_content_hash,
            transcript=input_sample.transcript,
            parent_id=input_sample.id,
        )
106 changes: 106 additions & 0 deletions ml/pipeline/speech/token_stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""TokenStage: compute phoneme tokens from AudioSample transcripts.

Deterministic stage (seed=0). Writes {id}.json with phonemes and tokens
padded to input_token_length. Padding uses vocab.ctc_blank_idx (== len(phoneme_list)),
which cannot collide with any real phoneme index.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

from pipeline.core.manifest import ManifestStore
from pipeline.core.modifier_stage import ModifierStage
from pipeline.core.randomization import VariationGenerator
from pipeline.core.sample import AudioSample, SampleTokens
from pipeline.intent.vocab_computer import VocabResult
from pipeline.stages import conventions


class TokenStage(ModifierStage[AudioSample, SampleTokens]):
    """Compute phoneme tokens from AudioSample transcripts.

    Writes ``{id}.json`` with:
        phonemes: list[str] — phoneme strings padded with "" to
                  input_token_length
        tokens:   list[int] — phoneme indices padded with
                  vocab.ctc_blank_idx to input_token_length

    Sequences longer than input_token_length are truncated from the end.

    Args:
        output_dir: Directory the ``.json`` files are written to.
        manifest_store: Passed through to the ``ModifierStage`` base class.
        vocab: Supplies ``phoneme_list``, ``words_to_phonemes`` and
            ``ctc_blank_idx``.
        input_token_length: Exact length of both output lists.

    Raises:
        ValueError: If ``input_token_length`` is not positive.
        KeyError: From ``_generate_output`` if a transcript word is missing
            from vocab.words_to_phonemes, or one of its phonemes is missing
            from vocab.phoneme_list.
    """

    _is_deterministic: bool = True

    def __init__(
        self,
        output_dir: Path,
        manifest_store: ManifestStore,
        vocab: VocabResult,
        input_token_length: int,
    ) -> None:
        # A negative length would make the truncation slice cut from the end
        # instead of fixing the length.
        if input_token_length <= 0:
            raise ValueError(
                f"input_token_length must be positive, got {input_token_length}"
            )

        super().__init__(output_dir, manifest_store)
        self._vocab = vocab
        self._input_token_length = input_token_length
        # Built once per stage rather than once per sample; avoids repeated
        # list.index() calls and repeated dict construction.
        self._phoneme_to_idx = {p: i for i, p in enumerate(vocab.phoneme_list)}

    def _get_applied_values(
        self, sample: AudioSample, generator: VariationGenerator
    ) -> dict[str, Any]:
        """Return no applied values: this stage has no random variation."""
        return {}

    def _derive_id(self, input_sample: AudioSample, applied_values: dict[str, Any]) -> str:
        """Reuse the input sample id as the output id."""
        return input_sample.id

    def _transcript_phonemes(self, transcript: str) -> list[str]:
        """Return the phonemes of every whitespace-separated word, in order."""
        phonemes: list[str] = []
        for word in transcript.split():
            # Fail fast on unknown words
            if word not in self._vocab.words_to_phonemes:
                raise KeyError(
                    f"Word '{word}' in transcript is not in vocab.words_to_phonemes"
                )
            phonemes.extend(self._vocab.words_to_phonemes[word])
        return phonemes

    def _phoneme_indices(self, phonemes: list[str]) -> list[int]:
        """Map phoneme strings to their indices in vocab.phoneme_list."""
        tokens: list[int] = []
        for phoneme in phonemes:
            if phoneme not in self._phoneme_to_idx:
                raise KeyError(
                    f"Phoneme '{phoneme}' is not in vocab.phoneme_list"
                )
            tokens.append(self._phoneme_to_idx[phoneme])
        return tokens

    async def _generate_output(
        self,
        input_sample: AudioSample,
        output_id: str,
        output_seed: int,
        applied_values: dict[str, Any],
        parent_content_hash: str,
    ) -> SampleTokens:
        """Tokenise one transcript and write the padded result as JSON.

        Returns:
            A ``SampleTokens`` whose ``path`` is the file name
            ``{output_id}.json``.
        """
        phonemes = self._transcript_phonemes(input_sample.transcript)
        tokens = self._phoneme_indices(phonemes)

        # Pad with ctc_blank_idx, then clip, so both lists end up exactly
        # input_token_length long whether the transcript was short or long.
        length = self._input_token_length
        pad_idx = self._vocab.ctc_blank_idx
        pad_len = max(0, length - len(phonemes))
        padded_phonemes = (phonemes + [""] * pad_len)[:length]
        padded_tokens = (tokens + [pad_idx] * pad_len)[:length]

        self._output_dir.mkdir(parents=True, exist_ok=True)
        output_path = conventions.sample_file_path(self._output_dir, output_id, "json")
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump({"phonemes": padded_phonemes, "tokens": padded_tokens}, f)

        content_hash = self._compute_content_hash(
            parent_content_hash, output_seed, applied_values
        )

        return SampleTokens(
            id=output_id,
            seed=output_seed,
            content_hash=content_hash,
            path=Path(f"{output_id}.json"),
            parent_content_hash=parent_content_hash,
            transcript=input_sample.transcript,
            parent_id=input_sample.id,
        )
36 changes: 36 additions & 0 deletions ml/pipeline/stages/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ class AddMicNoiseParams:
amplitude_max: float


@dataclass
class ComputeSpectrogramsParams:
    """Settings read from the ``stages.compute_spectrograms`` section."""

    # Both values are coerced to int by PipelineParams.load.
    n_mels: int
    time_steps: int


@dataclass
class ComputeTokensParams:
    """Settings read from the ``stages.compute_tokens`` section."""

    # Coerced to int by PipelineParams.load.
    input_token_length: int


@dataclass
class CreateSetManifestsParams:
    """Settings read from the ``stages.create_set_manifests`` section."""

    # Each value is coerced to int by PipelineParams.load.
    train_pct: int
    val_pct: int
    test_pct: int


@dataclass
class PipelineParams:
variations_per_phrase: int
Expand All @@ -61,6 +79,9 @@ class PipelineParams:
add_delays: AddDelaysParams
add_background_noise: AddBackgroundNoiseParams
add_mic_noise: AddMicNoiseParams
compute_spectrograms: ComputeSpectrogramsParams
compute_tokens: ComputeTokensParams
create_set_manifests: CreateSetManifestsParams

@classmethod
def load(cls, path: Path) -> "PipelineParams":
Expand All @@ -73,6 +94,9 @@ def load(cls, path: Path) -> "PipelineParams":
stage_delays = raw["stages"]["add_delays"]
stage_bg_noise = raw["stages"]["add_background_noise"]
stage_mic_noise = raw["stages"]["add_mic_noise"]
stage_spectrograms = raw["stages"]["compute_spectrograms"]
stage_tokens = raw["stages"]["compute_tokens"]
stage_set_manifests = raw["stages"]["create_set_manifests"]

return cls(
variations_per_phrase=int(pipeline["variations_per_phrase"]),
Expand Down Expand Up @@ -107,4 +131,16 @@ def load(cls, path: Path) -> "PipelineParams":
amplitude_min=float(stage_mic_noise["amplitude_min"]),
amplitude_max=float(stage_mic_noise["amplitude_max"]),
),
compute_spectrograms=ComputeSpectrogramsParams(
n_mels=int(stage_spectrograms["n_mels"]),
time_steps=int(stage_spectrograms["time_steps"]),
),
compute_tokens=ComputeTokensParams(
input_token_length=int(stage_tokens["input_token_length"]),
),
create_set_manifests=CreateSetManifestsParams(
train_pct=int(stage_set_manifests["train_pct"]),
val_pct=int(stage_set_manifests["val_pct"]),
test_pct=int(stage_set_manifests["test_pct"]),
),
)
Loading