ADR-229: Implement ModelTrainer, KerasBackend, and speech_10_train_model entry-point #235
jodavis-claude
left a comment
ADR-229 Review — ModelTrainer, KerasBackend, speech_10_train_model
Exit criteria check:
- `KerasBackend` protocol + `DefaultKerasBackend` with all five methods (`build_ctc_model`, `train`, `predict`, `save`, `load`): implemented correctly.
- `ModelTrainer.train()`: filters by `parent_id`, constructs `tf.data.Dataset` with batching/prefetching, calls `KerasBackend.train` then `save`, returns model path: implemented correctly.
- `speech_10_train_model.py` entry-point with all required CLI args: implemented correctly.
- `TrainModelParams` dataclass + `stages.train_model` in `params.py` and `params.yaml`: implemented correctly.
- Unit tests (15 tests in `test_model_trainer.py` + 4 in `test_params.py`): all required test classes present and correctly structured.
- Build/tests pass: confirmed by context file validate log.
Issues found:
Priority 1 — Two correctness issues:
- Silent zero-sample training (`model_trainer.py` line 228): When the filter produces zero (spectrogram, token) pairs, `_build_dataset` returns `[]` and `backend.train(model, [], epochs)` is called with no warning. Keras will silently fit on zero steps. A `_logger.warning(...)` should be emitted before `build_ctc_model`/`train` are called in the empty case.
- Unused `import tensorflow as tf` in `DefaultKerasBackend.train` (line 137), `predict` (line 145), and `save` (line 151): `tf` is imported but never referenced in any of these three method bodies. The deferred-import convention is meaningful only when `tf.*` is actually used. Remove these three dead imports.
Priority 4 — Documentation gap:
`_doc_speech.md` omits `ModelTrainer` (line 26): `model_trainer.py` is a new module in `ml/pipeline/speech/` but has no entry in the Stages table and no design section. The intro does not mention training as a third cohort. The key decisions (KerasBackend as TF seam, synchronous training, filtering responsibility, constructor-injected params) need to be documented. Per `CONTRIBUTING.md`: new subsystems require a documentation file, and designs must be updated to match.
Style notes (not blocking):
- `DefaultKerasBackend` does not explicitly declare `(KerasBackend)` as a base. Structural typing makes this valid, but explicit inheritance is used by `EdgeTtsProvider` and similar classes elsewhere for clarity.
- `dvc.yaml` only defines stages through `speech_01`; downstream stages including `speech_10` are not yet wired. Not blocking since other stages are in-flight, but note that the entry-point cannot be exercised via `dvc repro` yet.
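To illustrate the first style note, here is a minimal sketch of structural versus explicit conformance to a `typing.Protocol`. The protocol is a reduced, hypothetical stand-in (`KerasBackendSketch` with two methods and assumed signatures), not the five-method `KerasBackend` from this PR.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class KerasBackendSketch(Protocol):
    """Reduced, hypothetical stand-in for the KerasBackend protocol."""

    def train(self, model: Any, dataset: Any, epochs: int) -> None: ...

    def save(self, model: Any, path: str) -> None: ...


class StructuralBackend:
    """Declares no base class; it conforms only because its methods match."""

    def train(self, model: Any, dataset: Any, epochs: int) -> None:
        return None

    def save(self, model: Any, path: str) -> None:
        return None


class ExplicitBackend(KerasBackendSketch):
    """Names the protocol as a base, so the intent is visible to readers."""

    def train(self, model: Any, dataset: Any, epochs: int) -> None:
        return None

    def save(self, model: Any, path: str) -> None:
        return None


# Both satisfy the protocol; only the explicit one lists it in its MRO.
structural_ok = isinstance(StructuralBackend(), KerasBackendSketch)
explicit_ok = isinstance(ExplicitBackend(), KerasBackendSketch)
```

Note that `runtime_checkable` only verifies that the methods exist, not their signatures, so a static type checker remains the real safeguard in either style.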
622bf38 to 040a373
…ech_10_train_model entry-point
…t comparison instead
…airs; remove dead TF imports
Issue 1: Added _logger.warning() before build_ctc_model/train when the parent_id filter produces an empty pairs list. Without this guard, a misconfigured pipeline (wrong manifest dirs, split-manifest mismatch) fails silently: Keras fits on zero steps and produces a trained model from nothing. Added TestZeroSampleWarning to verify.
Issue 2: Removed unused 'import tensorflow as tf' from DefaultKerasBackend.train(), predict(), and save(). The deferred-import guard is only meaningful when tf.* is actually referenced in the method body. build_ctc_model and _build_dataset (which do use tf.*) retain their deferred imports.
Added ModelTrainer to the Stages table and intro cohorts list. Added a ModelTrainer design section documenting: KerasBackend as the TF seam (Protocol + DefaultKerasBackend with deferred imports), synchronous train, filter-by-parent_id responsibility, constructor-injected params, dataset construction ownership, and the empty-pair guard.
040a373 to ccd15c5
…y; create TensorflowModelBuilder in tensorflow_backend.py; add manifest_filter.py; update ModelTrainer and tests
Work item
ADR-229: Implement the `ModelTrainer` class, the `KerasBackend` protocol with its default TF/Keras implementation, and the `speech_10_train_model.py` entry-point: the stage that trains a CTC speech-to-text model from the featurised manifests produced by earlier pipeline stages.
Changes
- `ml/pipeline/speech/model_trainer.py`: NEW. `KerasBackend` Protocol, `DefaultKerasBackend` (TF/Keras with deferred imports), and `ModelTrainer` class that filters manifests by `parent_id`, builds a `tf.data.Dataset`, and calls the backend to train and save.
- `ml/pipeline/stages/speech_10_train_model.py`: NEW. CLI entry-point following the `speech_07` pattern; accepts `--train-manifest-dir`, `--spectrogram-manifest-dir`, `--token-manifest-dir`, `--vocab-dir`, `--spectrogram-dir`, `--token-dir`, `--output-dir`.
- `ml/pipeline/stages/params.py`: MODIFIED. Added `TrainModelParams(epochs, batch_size)` dataclass and `train_model: TrainModelParams` field to `PipelineParams`.
- `ml/params.yaml`: MODIFIED. Added `train_model: epochs: 10, batch_size: 32` under `stages:`.
- `ml/test/pipeline/speech/test_model_trainer.py`: NEW. 15 pytest-style unit tests for `ModelTrainer` (filtering, backend call args, save/return path).
- `ml/test/pipeline/stages/test_params.py`: MODIFIED. Added `train_model` block to `_VALID_DATA` and 4 new `TestTrainModelParamsLoad` tests.
- `ml/test/pipeline/speech/test_token_stage.py`: MODIFIED. Fixed flaky mtime-based assertions in `TestSkipPath` by replacing them with file content comparison, which eliminates timer-resolution races on Windows NTFS.
Design decisions
- `KerasBackend` is a `Protocol`: the default implementation defers `import tensorflow as tf` inside each method body (same deferred-import pattern as `librosa` in `SpectrogramStage`), keeping the module importable without TF installed.
- Filtering is `ModelTrainer`'s responsibility: the full combined spectrogram and token manifests are passed in; `train()` filters internally to entries whose `parent_id` is in the train split.
- `ModelTrainer.train()` is synchronous: the entry-point calls it directly without `asyncio.run()`.
- `num_classes = vocab.ctc_blank_idx + 1`: the blank token takes the index after the last phoneme, so the `+1` turns that index into a class count that includes the CTC blank, consistent with `VocabResult.ctc_blank_idx = len(phoneme_list)`.
- `DefaultKerasBackend.build_ctc_model` uses `model.compile(loss='ctc')`: Keras 3 / TF 2.x string alias for CTC loss; documented with a comment in the source.
- `_build_dataset`: avoids a TF import and works naturally with stub backend tests.
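Two of the decisions above, filtering by `parent_id` and the `num_classes` arithmetic, can be sketched without TensorFlow. The `ManifestEntry` shape, the helper names, and the sample IDs below are assumptions made for illustration; the real manifest schema and `ModelTrainer` internals are not shown in this PR description.

```python
from dataclasses import dataclass
from typing import Iterable, List, Sequence, Set


@dataclass(frozen=True)
class ManifestEntry:
    """Hypothetical manifest row; only parent_id matters for the filter."""

    parent_id: str
    path: str


def filter_to_train_split(
    entries: Iterable[ManifestEntry], train_ids: Set[str]
) -> List[ManifestEntry]:
    """Keep only entries whose parent_id belongs to the train split."""
    return [entry for entry in entries if entry.parent_id in train_ids]


def num_classes_for(phoneme_list: Sequence[str]) -> int:
    """Phoneme classes plus one CTC blank, whose index is len(phoneme_list)."""
    ctc_blank_idx = len(phoneme_list)
    return ctc_blank_idx + 1


# Made-up sample data: three spectrogram entries, two of them in the train split.
spectrograms = [
    ManifestEntry("utt-1", "spec/utt-1.npy"),
    ManifestEntry("utt-2", "spec/utt-2.npy"),
    ManifestEntry("utt-3", "spec/utt-3.npy"),
]
train_ids = {"utt-1", "utt-3"}
kept = filter_to_train_split(spectrograms, train_ids)
```

Passing the full manifests and filtering inside the trainer keeps the entry-point thin, at the cost of the trainer needing the train split's IDs.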