Skip to content

Commit 5c6f65f

Browse files
ninatu and martinarroyo committed
Fix: Ensure prepare_sample_fn is used for 'tfrecord' dataset type
Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent 1fe4ce0 commit 5c6f65f

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -113,7 +113,11 @@ def _make_tfrecord_iterator(
113113
"clip_embeddings": tf.io.FixedLenFeature([], tf.string),
114114
}
115115

116-
used_feature_description = feature_description_fn if make_cached_tfrecord_iterator else feature_description
116+
used_feature_description = (
117+
feature_description_fn
118+
if (make_cached_tfrecord_iterator or config.dataset_type == "tfrecord")
119+
else feature_description
120+
)
117121

118122
def _parse_tfrecord_fn(example):
119123
return tf.io.parse_single_example(example, used_feature_description)
@@ -141,7 +145,11 @@ def prepare_sample(features):
141145
ds = ds.concatenate(padding_ds)
142146
max_logging.log(f"Padded evaluation dataset with {num_to_pad} samples.")
143147

144-
used_prepare_sample = prepare_sample_fn if make_cached_tfrecord_iterator else prepare_sample
148+
used_prepare_sample = (
149+
prepare_sample_fn
150+
if (make_cached_tfrecord_iterator or config.dataset_type == "tfrecord")
151+
else prepare_sample
152+
)
145153
ds = (
146154
ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
147155
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)

0 commit comments

Comments
 (0)