Skip to content

ADR-229: Implement ModelTrainer, KerasBackend, and speech_10_train_model entry-point#235

Merged
jodavis merged 7 commits into
feature/ADR-191-oop-ml-pipelinefrom
dev/claude/ADR-229
Jun 28, 2026
Merged

ADR-229: Implement ModelTrainer, KerasBackend, and speech_10_train_model entry-point#235
jodavis merged 7 commits into
feature/ADR-191-oop-ml-pipelinefrom
dev/claude/ADR-229

Conversation

@jodavis-claude

Copy link
Copy Markdown
Collaborator

Work item

ADR-229: Implement the ModelTrainer class, the KerasBackend protocol with its default TF/Keras implementation, and the speech_10_train_model.py entry-point — the stage that trains a CTC speech-to-text model from the featurised manifests produced by earlier pipeline stages.

Changes

  • ml/pipeline/speech/model_trainer.py — NEW. KerasBackend Protocol, DefaultKerasBackend (TF/Keras with deferred imports), and ModelTrainer class that filters manifests by parent_id, builds a tf.data.Dataset, and calls the backend to train and save.
  • ml/pipeline/stages/speech_10_train_model.py — NEW. CLI entry-point following the speech_07 pattern; accepts --train-manifest-dir, --spectrogram-manifest-dir, --token-manifest-dir, --vocab-dir, --spectrogram-dir, --token-dir, --output-dir.
  • ml/pipeline/stages/params.py — MODIFIED. Added TrainModelParams(epochs, batch_size) dataclass and train_model: TrainModelParams field to PipelineParams.
  • ml/params.yaml — MODIFIED. Added train_model: epochs: 10, batch_size: 32 under stages:.
  • ml/test/pipeline/speech/test_model_trainer.py — NEW. 15 pytest-style unit tests for ModelTrainer (filtering, backend call args, save/return path).
  • ml/test/pipeline/stages/test_params.py — MODIFIED. Added train_model block to _VALID_DATA and 4 new TestTrainModelParamsLoad tests.
  • ml/test/pipeline/speech/test_token_stage.py — MODIFIED. Fixed flaky mtime-based assertions in TestSkipPath — replaced with file content comparison to eliminate timer-resolution races on Windows NTFS.

Design decisions

  • KerasBackend is a Protocol — the default implementation defers import tensorflow as tf inside each method body (same deferred-import pattern as librosa in SpectrogramStage), keeping the module importable without TF installed.
  • Filtering is ModelTrainer's responsibility — the full combined spectrogram and token manifests are passed in; train() filters internally to entries whose parent_id is in the train split.
  • ModelTrainer.train() is synchronous — the entry-point calls it directly without asyncio.run().
  • num_classes = vocab.ctc_blank_idx + 1 — the +1 is the CTC blank token, consistent with VocabResult.ctc_blank_idx = len(phoneme_list).
  • DefaultKerasBackend.build_ctc_model uses model.compile(loss='ctc') — Keras 3 / TF 2.x string alias for CTC loss; documented with a comment in the source.
  • Empty filtered set returns an empty list from _build_dataset — avoids a TF import and works naturally with stub backend tests.

@jodavis-claude jodavis-claude left a comment

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ADR-229 Review — ModelTrainer, KerasBackend, speech_10_train_model

Exit criteria check:

  • KerasBackend protocol + DefaultKerasBackend with all five methods (build_ctc_model, train, predict, save, load): implemented correctly.
  • ModelTrainer.train(): filters by parent_id, constructs tf.data.Dataset with batching/prefetching, calls KerasBackend.train then save, returns model path: implemented correctly.
  • speech_10_train_model.py entry-point with all required CLI args: implemented correctly.
  • TrainModelParams dataclass + stages.train_model in params.py and params.yaml: implemented correctly.
  • Unit tests (15 tests in test_model_trainer.py + 4 in test_params.py): all required test classes present and correctly structured.
  • Build/tests pass: confirmed by context file validate log.

Issues found:

Priority 1 — Two correctness issues:

  1. Silent zero-sample training (model_trainer.py line 228): When the filter produces zero (spectrogram, token) pairs, _build_dataset returns [] and backend.train(model, [], epochs) is called with no warning. Keras will silently fit on zero steps. A _logger.warning(...) should be emitted before build_ctc_model / train are called in the empty case.

  2. Unused import tensorflow as tf in DefaultKerasBackend.train (line 137), predict (line 145), and save (line 151): tf is imported but never referenced in any of these three method bodies. The deferred-import convention is meaningful only when tf.* is actually used. Remove these three dead imports.

Priority 4 — Documentation gap:

  1. _doc_speech.md omits ModelTrainer (line 26): model_trainer.py is a new module in ml/pipeline/speech/ but has no entry in the Stages table and no design section. The intro does not mention training as a third cohort. The key decisions (KerasBackend as TF seam, synchronous training, filtering responsibility, constructor-injected params) need to be documented. Per CONTRIBUTING.md: new subsystems require a documentation file, and designs must be updated to match.

Style notes (not blocking):

  • DefaultKerasBackend does not explicitly declare (KerasBackend) as a base — structural typing makes this valid, but explicit inheritance is used by EdgeTtsProvider and similar classes elsewhere for clarity.
  • dvc.yaml only defines stages through speech_01 — downstream stages including speech_10 are not yet wired. Not blocking since other stages are in-flight, but note that the entry-point cannot be exercised via dvc repro yet.

Comment thread ml/pipeline/speech/model_trainer.py Outdated
Comment thread ml/pipeline/speech/model_trainer.py Outdated
Comment thread ml/pipeline/speech/_doc_speech.md
@github-actions

github-actions Bot commented Jun 23, 2026

Copy link
Copy Markdown

Test Results

401 tests  ±0   401 ✅ ±0   2m 20s ⏱️ -25s
  5 suites ±0     0 💤 ±0 
  5 files   ±0     0 ❌ ±0 

Results for commit febd849. ± Comparison against base commit ba36fd2.

♻️ This comment has been updated with latest results.

@jodavis-claude jodavis-claude force-pushed the dev/claude/ADR-229 branch 3 times, most recently from 622bf38 to 040a373 Compare June 26, 2026 23:16
…airs; remove dead TF imports

Issue 1: Added _logger.warning() before build_ctc_model/train when the parent_id
filter produces an empty pairs list. Without this guard, a misconfigured pipeline
(wrong manifest dirs, split-manifest mismatch) fails silently — Keras fits on zero
steps and produces a trained model from nothing. Added TestZeroSampleWarning to verify.

Issue 2: Removed unused 'import tensorflow as tf' from DefaultKerasBackend.train(),
predict(), and save(). The deferred-import guard is only meaningful when tf.* is
actually referenced in the method body. build_ctc_model and _build_dataset (which do
use tf.*) retain their deferred imports.
Added ModelTrainer to the Stages table and intro cohorts list. Added a
ModelTrainer design section documenting: KerasBackend as the TF seam
(Protocol + DefaultKerasBackend with deferred imports), synchronous train,
filter-by-parent_id responsibility, constructor-injected params, dataset
construction ownership, and the empty-pair guard.
@jodavis jodavis changed the base branch from main to feature/ADR-191-oop-ml-pipeline June 26, 2026 23:31
Comment thread ml/pipeline/speech/_doc_speech.md Outdated
Comment thread ml/pipeline/speech/_doc_speech.md Outdated
Comment thread ml/pipeline/speech/model_trainer.py Outdated
Comment thread ml/pipeline/speech/model_trainer.py Outdated
Comment thread ml/pipeline/speech/model_trainer.py Outdated
Comment thread ml/pipeline/speech/model_trainer.py Outdated
ElwoodMoves and others added 2 commits June 27, 2026 20:36
…y; create TensorflowModelBuilder in tensorflow_backend.py; add manifest_filter.py; update ModelTrainer and tests
@jodavis jodavis marked this pull request as ready for review June 28, 2026 13:21
@jodavis jodavis enabled auto-merge (squash) June 28, 2026 13:21
@jodavis jodavis merged commit 9ab9062 into feature/ADR-191-oop-ml-pipeline Jun 28, 2026
5 checks passed
@jodavis jodavis deleted the dev/claude/ADR-229 branch June 28, 2026 13:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants