File tree Expand file tree Collapse file tree
src/maxdiffusion/input_pipeline Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -113,7 +113,11 @@ def _make_tfrecord_iterator(
113113 "clip_embeddings" : tf .io .FixedLenFeature ([], tf .string ),
114114 }
115115
116- used_feature_description = feature_description_fn if make_cached_tfrecord_iterator else feature_description
116+ used_feature_description = (
117+ feature_description_fn
118+ if (make_cached_tfrecord_iterator or config .dataset_type == "tfrecord" )
119+ else feature_description
120+ )
117121
118122 def _parse_tfrecord_fn (example ):
119123 return tf .io .parse_single_example (example , used_feature_description )
@@ -141,7 +145,11 @@ def prepare_sample(features):
141145 ds = ds .concatenate (padding_ds )
142146 max_logging .log (f"Padded evaluation dataset with { num_to_pad } samples." )
143147
144- used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
148+ used_prepare_sample = (
149+ prepare_sample_fn
150+ if (make_cached_tfrecord_iterator or config .dataset_type == "tfrecord" )
151+ else prepare_sample
152+ )
145153 ds = (
146154 ds .shard (num_shards = dataloading_host_count , index = dataloading_host_index )
147155 .map (_parse_tfrecord_fn , num_parallel_calls = AUTOTUNE )
You can’t perform that action at this time.
0 commit comments