From aaa0e111c5e9d8fb49bfa992f2272a121a95f741 Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Thu, 2 Jul 2026 02:10:44 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 941561111 --- seqio/beam_utils.py | 4 ++-- seqio/dataset_providers.py | 38 ++++++++++++++++---------------- seqio/dataset_providers_test.py | 26 +++++++++++----------- seqio/evaluation.py | 4 ++-- seqio/evaluation_test.py | 26 +++++++++++----------- seqio/experimental.py | 18 +++++++-------- seqio/feature_converters.py | 10 ++++----- seqio/feature_converters_test.py | 16 +++++++------- seqio/helpers.py | 4 ++-- seqio/loggers.py | 2 +- seqio/loggers_test.py | 2 +- seqio/metrics.py | 16 +++++++------- seqio/preprocessors_test.py | 2 +- seqio/test_utils.py | 26 +++++++++++----------- seqio/test_utils_test.py | 4 ++-- seqio/utils.py | 18 +++++++-------- seqio/vocabularies.py | 7 +++--- 17 files changed, 112 insertions(+), 111 deletions(-) diff --git a/seqio/beam_utils.py b/seqio/beam_utils.py index 91a21bd1..92ce0c59 100644 --- a/seqio/beam_utils.py +++ b/seqio/beam_utils.py @@ -351,14 +351,14 @@ def _info_dict(self, ex: List[Dict[str, Any]]): if not ex: return {} assert len(ex) == 1 - ex = ex[0] + ex = ex[0] # pyrefly: ignore[bad-assignment] info = { "num_shards": self._num_shards, "features": {}, "seqio_version": seqio.__version__, } feature_dict = info["features"] - for k, v in ex.items(): + for k, v in ex.items(): # pyrefly: ignore[missing-attribute] if self._exclude_provenance and k.startswith(PROVENANCE_PREFIX): continue if isinstance(v, tf.RaggedTensor): diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index ab683138..014a0a83 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -130,7 +130,7 @@ def splits(self) -> Sequence[str]: def get_dataset( self, sequence_length: Optional[Mapping[str, int]] = None, - split: str = tfds.Split.TRAIN, + split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute] use_cached: bool = False, shuffle: bool = True, seed: Optional[int] = None, @@ -334,9 +334,9 @@ def list_shards(self, split: str) -> Sequence[str]: raise NotImplementedError @abc.abstractmethod - def get_dataset( + def get_dataset( # pyrefly: ignore[bad-override] self, # pytype: disable=signature-mismatch # overriding-default-value-checks - split: str = tfds.Split.TRAIN, + split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute] shuffle: bool = True, seed: Optional[int] = None, shard_info: Optional[ShardInfo] = None, @@ -443,7 +443,7 @@ def __repr__(self): def get_dataset( self, - split: str = tfds.Split.TRAIN, + split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute] shuffle: bool = True, seed: Optional[int] = None, shard_info: Optional[ShardInfo] = None, @@ -569,7 +569,7 @@ def get_dataset( num_epochs: Optional[int] = 1, # Unused ) -> tf.data.Dataset: if split is None: - split = tfds.Split.TRAIN + split = tfds.Split.TRAIN # pyrefly: ignore[missing-attribute] return self.tfds_dataset.load( split, shuffle_files=shuffle, seed=seed, shard_info=shard_info ) @@ -657,7 +657,7 @@ def __repr__(self): def get_dataset( self, - split: str = tfds.Split.TRAIN, + split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute] shuffle: bool = True, seed: Optional[int] = None, shard_info: Optional[ShardInfo] = None, @@ -1455,8 +1455,8 @@ def preprocess_postcache( if self.supports_caching: # Skip a sufficient number of seeds to avoid duplicating any from # pre-cache preprocessing. - seed = None if seed is None else seed + 42 * self._cache_step_idx - start_idx = self._cache_step_idx + 1 + seed = None if seed is None else seed + 42 * self._cache_step_idx # pyrefly: ignore[unsupported-operation] + start_idx = self._cache_step_idx + 1 # pyrefly: ignore[unsupported-operation] with utils.map_seed_manager(seed): dataset = self._preprocess_dataset( dataset, @@ -1530,7 +1530,7 @@ def assert_cached(self) -> None: ), f"'{self.name}' does not exist in any of the task cache directories." def get_cached_stats( - self, split: str = tfds.Split.TRAIN + self, split: str = tfds.Split.TRAIN # pyrefly: ignore[missing-attribute] ) -> Mapping[str, Union[int, float]]: """Returns basic statistics for cached dataset.""" self.assert_cached() @@ -1547,7 +1547,7 @@ def get_cached_stats( def get_dataset( self, # pytype: disable=signature-mismatch # overriding-default-value-checks sequence_length: Optional[Mapping[str, int]] = None, - split: str = tfds.Split.TRAIN, + split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute] use_cached: bool = False, shuffle: bool = True, shuffle_buffer_size: Optional[int] = None, # Unique to Task @@ -1632,7 +1632,7 @@ def get_dataset( ) else: ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed) - ds = ds.shard(shard_info.num_shards, shard_info.index) + ds = ds.shard(shard_info.num_shards, shard_info.index) # pyrefly: ignore[missing-attribute] num_shards = shard_info.num_shards if shard_info else 1 if try_in_mem_cache and ( @@ -1643,7 +1643,7 @@ def get_dataset( ) or ( source.num_input_examples(split) - and source.num_input_examples(split) + and source.num_input_examples(split) # pyrefly: ignore[unsupported-operation] < _MAX_EXAMPLES_TO_MEM_CACHE * num_shards ) ): @@ -1692,10 +1692,10 @@ def _get_cached_source( self.assert_cached() file_shuffle_buffer_size = ( file_shuffle_buffer_size - or self._cache_dataset_placerholder.file_shuffle_buffer_size + or self._cache_dataset_placerholder.file_shuffle_buffer_size # pyrefly: ignore[missing-attribute] ) return _CachedDataSource( - cache_dir=self.cache_dir, + cache_dir=self.cache_dir, # pyrefly: ignore[bad-argument-type] split=split, file_shuffle_buffer_size=file_shuffle_buffer_size, ) @@ -1720,7 +1720,7 @@ class TaskRegistry(DatasetProviderRegistry): # pylint: disable=arguments-renamed @classmethod - def add( + def add( # pyrefly: ignore[bad-override] cls, name: str, source: DataSourceInterface, @@ -1910,7 +1910,7 @@ def _get_submixture_rate(self, mix: "Mixture") -> float: return float(rate) def num_input_examples(self, split: str) -> int: - return sum( + return sum( # pyrefly: ignore[no-matching-overload] t.num_input_examples(split) for t in self.tasks if split in t.splits ) @@ -1953,7 +1953,7 @@ def get_task_dataset( task: Task, output_feature_keys: Set[str], sequence_length: Optional[Mapping[str, int]] = None, - split: str = tfds.Split.TRAIN, + split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute] use_cached: bool = False, shuffle: bool = True, seed: Optional[int] = None, @@ -1985,7 +1985,7 @@ def _get_all_mixing_rates(self, tasks): def get_dataset( # pytype: disable=signature-mismatch # overriding-parameter-type-checks self, sequence_length: Optional[Mapping[str, int]] = None, - split: str = tfds.Split.TRAIN, + split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute] use_cached: bool = False, shuffle: bool = True, seed: Optional[int] = None, @@ -2284,7 +2284,7 @@ class MixtureRegistry(DatasetProviderRegistry): # pylint: disable=arguments-renamed @classmethod - def add( + def add( # pyrefly: ignore[bad-override] cls, name, tasks, diff --git a/seqio/dataset_providers_test.py b/seqio/dataset_providers_test.py index 7fbfb162..a8ba39d1 100644 --- a/seqio/dataset_providers_test.py +++ b/seqio/dataset_providers_test.py @@ -199,7 +199,7 @@ def predict_metric_fn_with_types( valid_task_with_types = TaskRegistry.add( "valid_metrics_with_types", - source=self.function_source, + source=self.function_source, # pyrefly: ignore[bad-argument-type] output_features={ "inputs": dataset_providers.Feature( test_utils.sentencepiece_vocab() @@ -666,7 +666,7 @@ def test_value_errors(self): _ = dataset_providers.Task( "multiple_cache_placeholders", source=dataset_providers.FunctionDataSource( - dataset_fn=dataset_fn, splits=["train", "validation"] + dataset_fn=dataset_fn, splits=["train", "validation"] # pyrefly: ignore[bad-argument-type] ), preprocessors=[ test_utils.test_text_preprocessor, @@ -691,7 +691,7 @@ def test_value_errors(self): task = dataset_providers.Task( "sequence_length_pre_cache", dataset_providers.FunctionDataSource( - dataset_fn=dataset_fn, + dataset_fn=dataset_fn, # pyrefly: ignore[bad-argument-type] splits=["train"], ), preprocessors=[ @@ -1739,8 +1739,8 @@ def test_multidimension_sequence_length( dataset_fn = lambda split, shuffle_files: ds dataset_providers.TaskRegistry.add( task_name, - source=dataset_providers.FunctionDataSource( - dataset_fn=dataset_fn, splits=["train", "validation"] + source=dataset_providers.FunctionDataSource( # pyrefly: ignore[bad-argument-type] + dataset_fn=dataset_fn, splits=["train", "validation"] # pyrefly: ignore[bad-argument-type] ), preprocessors=[ dataset_providers.CacheDatasetPlaceholder(), @@ -1953,8 +1953,8 @@ def register_dummy_task( """Register a dummy task for GetDatasetTest.""" dataset_providers.TaskRegistry.add( task_name, - source=dataset_providers.FunctionDataSource( - dataset_fn=dataset_fn, splits=["train", "validation"] + source=dataset_providers.FunctionDataSource( # pyrefly: ignore[bad-argument-type] + dataset_fn=dataset_fn, splits=["train", "validation"] # pyrefly: ignore[bad-argument-type] ), preprocessors=[ dataset_providers.CacheDatasetPlaceholder(), @@ -2022,13 +2022,13 @@ def good_fn(split, shuffle_files): del split del shuffle_files - dataset_providers.FunctionDataSource(good_fn, splits=("train",)) + dataset_providers.FunctionDataSource(good_fn, splits=("train",)) # pyrefly: ignore[bad-argument-type] def default_good_fn(split, shuffle_files=False): del split del shuffle_files - dataset_providers.FunctionDataSource(default_good_fn, splits=("train",)) + dataset_providers.FunctionDataSource(default_good_fn, splits=("train",)) # pyrefly: ignore[bad-argument-type] def seed_fn(split, shuffle_files=True, seed=0): del split @@ -2041,7 +2041,7 @@ def extra_kwarg_good_fn(split, shuffle_files, unused_kwarg=True): del split del shuffle_files - dataset_providers.FunctionDataSource(extra_kwarg_good_fn, splits=("train",)) + dataset_providers.FunctionDataSource(extra_kwarg_good_fn, splits=("train",)) # pyrefly: ignore[bad-argument-type] class GoodProtocol(dataset_providers.DatasetFnCallable): @@ -2062,7 +2062,7 @@ def __call__(self, split, shuffle_files, seed=None): def missing_shuff(split): del split - dataset_providers.FunctionDataSource(missing_shuff, splits=("train",)) + dataset_providers.FunctionDataSource(missing_shuff, splits=("train",)) # pyrefly: ignore[bad-argument-type] with self.assertRaisesWithLiteralMatch( ValueError, @@ -2075,7 +2075,7 @@ def missing_shuff(split): def missing_split(shuffle_files): del shuffle_files - dataset_providers.FunctionDataSource(missing_split, splits=("train",)) + dataset_providers.FunctionDataSource(missing_split, splits=("train",)) # pyrefly: ignore[bad-argument-type] with self.assertRaisesWithLiteralMatch( ValueError, @@ -2089,7 +2089,7 @@ def extra_pos_arg(split, shuffle_files, unused_arg): del split del shuffle_files - dataset_providers.FunctionDataSource(extra_pos_arg, splits=("train",)) + dataset_providers.FunctionDataSource(extra_pos_arg, splits=("train",)) # pyrefly: ignore[bad-argument-type] diff --git a/seqio/evaluation.py b/seqio/evaluation.py index e4e6557c..c58bba98 100644 --- a/seqio/evaluation.py +++ b/seqio/evaluation.py @@ -515,7 +515,7 @@ def dataset_fn(task: Task) -> tf.data.Dataset: self._cached_task_datasets = cached_task_datasets self._model_feature_shapes = { k: tuple(spec.shape) - for k, spec in eval_ds.element_spec.items() + for k, spec in eval_ds.element_spec.items() # pyrefly: ignore[unbound-name] if spec.shape.rank > 0 } @@ -810,7 +810,7 @@ def model_feature_shapes(self) -> Mapping[str, Tuple[int, ...]]: @property def loggers(self) -> Tuple[loggers_lib.Logger]: - return tuple(self._loggers) + return tuple(self._loggers) # pyrefly: ignore[bad-return] class MetricManager: diff --git a/seqio/evaluation_test.py b/seqio/evaluation_test.py index 8ca2ba0b..3d193e6b 100644 --- a/seqio/evaluation_test.py +++ b/seqio/evaluation_test.py @@ -95,8 +95,8 @@ def register_dummy_task( """Register a dummy task for GetDatasetTest.""" return dataset_providers.TaskRegistry.add( task_name, - source=dataset_providers.FunctionDataSource( - dataset_fn=dataset_fn, splits=["train", "validation"] + source=dataset_providers.FunctionDataSource( # pyrefly: ignore[bad-argument-type] + dataset_fn=dataset_fn, splits=["train", "validation"] # pyrefly: ignore[bad-argument-type] ), preprocessors=[preprocessor], postprocess_fn=postprocess_fn, @@ -143,7 +143,7 @@ def _task_from_tensor_slices(name, tensor_slices, label_classes): return dataset_providers.Task( name, dataset_providers.FunctionDataSource( - lambda split, shuffle_files: tf.data.Dataset.from_tensor_slices( + lambda split, shuffle_files: tf.data.Dataset.from_tensor_slices( # pyrefly: ignore[bad-argument-type] tensor_slices ), splits="validation", @@ -334,7 +334,7 @@ def _task_from_tensor_slices_rank2(name, tensor_slices, label_classes): return dataset_providers.Task( name, dataset_providers.FunctionDataSource( - lambda split, shuffle_files: tf.data.Dataset.from_tensor_slices( + lambda split, shuffle_files: tf.data.Dataset.from_tensor_slices( # pyrefly: ignore[bad-argument-type] tensor_slices ), splits="validation", @@ -496,9 +496,9 @@ def score_fn( all_metrics, _ = evaluator.evaluate( compute_metrics=True, - predict_fn=predict_fn, - score_fn=score_fn, - predict_with_aux_fn=predict_with_aux_fn, + predict_fn=predict_fn, # pyrefly: ignore[bad-argument-type] + score_fn=score_fn, # pyrefly: ignore[bad-argument-type] + predict_with_aux_fn=predict_with_aux_fn, # pyrefly: ignore[bad-argument-type] step=42, ) return all_metrics.result(), evaluator @@ -595,7 +595,7 @@ def predict_with_aux_fn( compute_metrics=True, predict_fn=self.uncalled_fn, score_fn=self.uncalled_fn, - predict_with_aux_fn=predict_with_aux_fn, + predict_with_aux_fn=predict_with_aux_fn, # pyrefly: ignore[bad-argument-type] step=42, ) @@ -732,7 +732,7 @@ def predict_fn( return [(0, [5, 6]), (1, [6, 8])] all_metrics, _ = evaluator.evaluate( - compute_metrics=True, predict_fn=predict_fn, score_fn=self.uncalled_fn + compute_metrics=True, predict_fn=predict_fn, score_fn=self.uncalled_fn # pyrefly: ignore[bad-argument-type] ) # expected = {"accuracy": 2.0 / 3 * 100} expected = {"sequence_accuracy": 50} @@ -777,7 +777,7 @@ def predict_fn( return [(0, [5]), (1, [6]), (2, [7])] all_metrics, _ = evaluator.evaluate( - compute_metrics=True, predict_fn=predict_fn, score_fn=self.uncalled_fn + compute_metrics=True, predict_fn=predict_fn, score_fn=self.uncalled_fn # pyrefly: ignore[bad-argument-type] ) expected = {"accuracy": 100} self.assertDictClose(expected, all_metrics.result()[task.name]) @@ -846,7 +846,7 @@ def score_fn( evaluator = Evaluator() # pytype: disable=missing-parameter all_metrics, _ = evaluator.evaluate( - compute_metrics=True, predict_fn=predict_fn, score_fn=score_fn + compute_metrics=True, predict_fn=predict_fn, score_fn=score_fn # pyrefly: ignore[bad-argument-type] ) expected = { task1.name: {"sequence_accuracy": 50.0, "total_score": 651}, @@ -1160,7 +1160,7 @@ def mixing_order_predict_fn( all_metrics, all_outputs = evaluator.evaluate( compute_metrics=True, - predict_fn=mixing_order_predict_fn, + predict_fn=mixing_order_predict_fn, # pyrefly: ignore[bad-argument-type] score_fn=self.uncalled_fn, ) expected_metric = {"sequence_accuracy": 100} @@ -1286,7 +1286,7 @@ def mock_init(self): compute_metrics=True, predict_fn=self.uncalled_fn, predict_with_aux_fn=self.uncalled_fn, - score_fn=score_fn_with_intermediates, + score_fn=score_fn_with_intermediates, # pyrefly: ignore[bad-argument-type] step=42, ) results = all_metrics.result() diff --git a/seqio/experimental.py b/seqio/experimental.py index 16cc4bb9..945c713e 100644 --- a/seqio/experimental.py +++ b/seqio/experimental.py @@ -172,7 +172,7 @@ def validate_sequence_length(ds, sequence_length): return TaskRegistry.add( new_name, - source=task.source, + source=task.source, # pyrefly: ignore[bad-argument-type] preprocessors=new_preprocessors, output_features=task.output_features, metric_fns=task.metric_fns, @@ -298,7 +298,7 @@ def list_shards(self, split: str) -> Sequence[str]: def get_dataset( self, - split: str = tfds.Split.TRAIN, + split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute] shuffle: bool = True, seed: Optional[int] = None, shard_info: Optional[ShardInfo] = None, @@ -372,7 +372,7 @@ def _get_maybe_sharded_dataset( train_ds = _get_maybe_sharded_dataset( split_=self._train_split, shuffle_=True, - seed_=train_seed if shuffle else 0, + seed_=train_seed if shuffle else 0, # pyrefly: ignore[bad-argument-type] ) train_ds = _apply_preprocessors(train_ds, self._train_preprocessors) train_ds = train_ds.map( @@ -385,12 +385,12 @@ def _get_maybe_sharded_dataset( datasets['train'] = train_ds eval_ds = _get_maybe_sharded_dataset( - split_=split, shuffle_=shuffle, seed_=eval_seed + split_=split, shuffle_=shuffle, seed_=eval_seed # pyrefly: ignore[bad-argument-type] ) eval_ds = _apply_preprocessors(eval_ds, self._eval_preprocessors) datasets['eval'] = eval_ds - return tf.data.Dataset.zip(datasets) + return tf.data.Dataset.zip(datasets) # pyrefly: ignore[bad-argument-type] @@ -587,10 +587,10 @@ def _add_sentinels(dataset, sequence_length, output_features): @utils.map_over_dataset def _my_fn(x): sentinels_input = [ - _sentinel_id(input_vocab, idx) for idx in range(num_sentinels) + _sentinel_id(input_vocab, idx) for idx in range(num_sentinels) # pyrefly: ignore[bad-argument-type] ] sentinels_output = [ - _sentinel_id(target_vocab, idx) for idx in range(num_sentinels) + _sentinel_id(target_vocab, idx) for idx in range(num_sentinels) # pyrefly: ignore[bad-argument-type] ] x['inputs'] = tf.concat([x['inputs'], sentinels_input], 0) x['targets'] = tf.concat([sentinels_output, x['targets']], 0) @@ -604,7 +604,7 @@ def _postprocess_fn_remove_sentinel(string_label, *args, **kwargs): del kwargs vocab = task.output_features['targets'].vocabulary sentinel_str = vocab.decode( - [_sentinel_id(vocab, idx) for idx in range(num_sentinels)] + [_sentinel_id(vocab, idx) for idx in range(num_sentinels)] # pyrefly: ignore[bad-argument-type] ) if string_label.startswith(sentinel_str): string_label = string_label[len(sentinel_str) :].strip() @@ -647,7 +647,7 @@ def new_fn(string_label, *args, **kwargs): TaskRegistry.add( sentinel_task_name, - source=task.source, + source=task.source, # pyrefly: ignore[bad-argument-type] preprocessors=new_preprocessors, output_features=task.output_features, postprocess_fn=new_postprocess_fn, diff --git a/seqio/feature_converters.py b/seqio/feature_converters.py index ced4a830..b849dc16 100644 --- a/seqio/feature_converters.py +++ b/seqio/feature_converters.py @@ -192,18 +192,18 @@ def _check_exact_match( actual_feature_source: str, ) -> None: """Check whether expected and actual features match one-to-one.""" - expected_features = set(expected_features) - actual_features = set(actual_features) + expected_features = set(expected_features) # pyrefly: ignore[bad-assignment] + actual_features = set(actual_features) # pyrefly: ignore[bad-assignment] if expected_features != actual_features: - if actual_features - expected_features: - extra_features = actual_features - expected_features + if actual_features - expected_features: # pyrefly: ignore[unsupported-operation] + extra_features = actual_features - expected_features # pyrefly: ignore[unsupported-operation] raise ValueError( f"The {actual_feature_source} contains extra features not specified " f"in the {expected_feature_source}: {extra_features}" ) else: - missing_features = expected_features - actual_features + missing_features = expected_features - actual_features # pyrefly: ignore[unsupported-operation] raise ValueError( f"The {actual_feature_source} is missing features specified " f"in the {expected_feature_source}: {missing_features}" diff --git a/seqio/feature_converters_test.py b/seqio/feature_converters_test.py index f7ccece4..fc7d3265 100644 --- a/seqio/feature_converters_test.py +++ b/seqio/feature_converters_test.py @@ -271,7 +271,7 @@ def test_pass_through_and_packing(self): ): expected_msg = "Packing is incompatible with pass-through features." with self.assertRaisesRegex(ValueError, expected_msg): - feature_converters.FeatureConverter( + feature_converters.FeatureConverter( # pyrefly: ignore[bad-instantiation] pack=True, passthrough_features={ "pass_through": feature_converters.FeatureConverter.FeatureSpec( @@ -284,7 +284,7 @@ def test_pass_through(self): with mock.patch.object( feature_converters.FeatureConverter, "__abstractmethods__", set() ): - converter = feature_converters.FeatureConverter( + converter = feature_converters.FeatureConverter( # pyrefly: ignore[bad-instantiation] pack=False, passthrough_features={ "pass_through": feature_converters.FeatureConverter.FeatureSpec( @@ -527,11 +527,11 @@ class EncDecFeatureConverterTest(tf.test.TestCase): def tearDown(self): if "passthrough" in feature_converters.EncDecFeatureConverter.TASK_FEATURES: - del feature_converters.EncDecFeatureConverter.TASK_FEATURES["passthrough"] - del feature_converters.EncDecFeatureConverter.MODEL_FEATURES[ + del feature_converters.EncDecFeatureConverter.TASK_FEATURES["passthrough"] # pyrefly: ignore[unsupported-operation] + del feature_converters.EncDecFeatureConverter.MODEL_FEATURES[ # pyrefly: ignore[unsupported-operation] "passthrough" ] - del feature_converters.EncDecFeatureConverter.PACKING_FEATURE_DTYPES[ + del feature_converters.EncDecFeatureConverter.PACKING_FEATURE_DTYPES[ # pyrefly: ignore[unsupported-operation] "passthrough" ] super().tearDown() @@ -827,13 +827,13 @@ def tearDown(self): "passthrough" in feature_converters.PrefixLMFeatureConverter.TASK_FEATURES ): - del feature_converters.PrefixLMFeatureConverter.TASK_FEATURES[ + del feature_converters.PrefixLMFeatureConverter.TASK_FEATURES[ # pyrefly: ignore[unsupported-operation] "passthrough" ] - del feature_converters.PrefixLMFeatureConverter.MODEL_FEATURES[ + del feature_converters.PrefixLMFeatureConverter.MODEL_FEATURES[ # pyrefly: ignore[unsupported-operation] "passthrough" ] - del feature_converters.PrefixLMFeatureConverter.PACKING_FEATURE_DTYPES[ + del feature_converters.PrefixLMFeatureConverter.PACKING_FEATURE_DTYPES[ # pyrefly: ignore[unsupported-operation] "passthrough" ] super().tearDown() diff --git a/seqio/helpers.py b/seqio/helpers.py index 99963182..ebb90ced 100644 --- a/seqio/helpers.py +++ b/seqio/helpers.py @@ -133,7 +133,7 @@ def _validate_output_features(og_output_features, new_output_features): new_task = dp.Task( new_mixture_or_task_name, source=mixture_or_task.source, - output_features=new_output_features, + output_features=new_output_features, # pyrefly: ignore[bad-argument-type] preprocessors=preprocessors, postprocess_fn=mixture_or_task.postprocessor, metric_fns=mixture_or_task.metric_fns, @@ -212,7 +212,7 @@ def list_shards(self, split: str) -> Sequence[str]: def get_dataset( self, - split: str = tfds.Split.TRAIN, + split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute] shuffle: bool = True, seed: Optional[int] = None, shard_info: Optional[dp.ShardInfo] = None, diff --git a/seqio/loggers.py b/seqio/loggers.py index 21680030..34a82730 100644 --- a/seqio/loggers.py +++ b/seqio/loggers.py @@ -145,7 +145,7 @@ def _write_metric( """Log a metric value to tensorboard, dispatched on value type.""" if isinstance(value, metrics_lib.Scalar): value: metrics_lib.Scalar = value - value = float(np.array(value.value)) + value = float(np.array(value.value)) # pyrefly: ignore[bad-assignment] with writer.as_default(): tf.summary.scalar(name=tag, data=value, step=step) elif isinstance(value, metrics_lib.Image): diff --git a/seqio/loggers_test.py b/seqio/loggers_test.py index 76fcc8d4..a71b66f9 100644 --- a/seqio/loggers_test.py +++ b/seqio/loggers_test.py @@ -372,7 +372,7 @@ def test_predictions_and_aux_values(self): step=42, metrics={"accuracy": metrics_lib.Scalar(100)}, dataset=task_dataset, - inferences=inferences, + inferences=inferences, # pyrefly: ignore[bad-argument-type] targets=targets, ) diff --git a/seqio/metrics.py b/seqio/metrics.py index 730f7a2d..dee1dfce 100644 --- a/seqio/metrics.py +++ b/seqio/metrics.py @@ -173,17 +173,17 @@ def from_model_output( # pylint:disable=missing-function-docstring ) if mask is None: - mask = jnp.ones((num_examples,), jnp.int32) + mask = jnp.ones((num_examples,), jnp.int32) # pyrefly: ignore[bad-assignment] if indices_2d is None: - indices_2d = jnp.transpose( + indices_2d = jnp.transpose( # pyrefly: ignore[bad-assignment] jnp.stack([ jnp.zeros((num_examples,), jnp.int32), jnp.arange(num_examples, dtype=jnp.int32), ]) ) return cls( - values={ + values={ # pyrefly: ignore[bad-argument-type] "model_output": model_output, "indices_2d": indices_2d, "mask": mask, @@ -227,7 +227,7 @@ class LegacyMetric(Metric): targets_and_inferences: Dict[str, Any] @classmethod - def empty(cls, metric_fn, postprocess_fn) -> "LegacyMetric": + def empty(cls, metric_fn, postprocess_fn) -> "LegacyMetric": # pyrefly: ignore[bad-override] pos_args = tuple( key for key, param in inspect.signature(metric_fn).parameters.items() @@ -264,7 +264,7 @@ def postprocess_fn( return self._postprocess_fn(targets_or_predictions, **postprocess_kwargs) return targets_or_predictions - def from_model_output( # pylint:disable=arguments-renamed + def from_model_output( # pylint:disable=arguments-renamed # pyrefly: ignore[bad-override] self, inputs: Sequence[Mapping[str, Any]], model_output: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]], @@ -321,7 +321,7 @@ def from_model_output( # pylint:disable=arguments-renamed # Postprocesses the predictions here. postprocessed_predictions = [ self.postprocess_fn(p, example=ex, is_target=False) - for ex, p in zip(inputs, predictions) + for ex, p in zip(inputs, predictions) # pyrefly: ignore[unbound-name] ] self.metric_fn_kwargs["predictions"] = postprocessed_predictions @@ -430,7 +430,7 @@ def _get_model_output_type() -> ModelOutputType: return model_output_type @flax.struct.dataclass - class FromMetricFun(cls): + class FromMetricFun(cls): # pyrefly: ignore[invalid-inheritance] """Wrapper PassthroughLegacyMetric class that runs metric_fn.""" model_output_type: ModelOutputType = _get_model_output_type() @@ -535,7 +535,7 @@ def actual_compute( # If neither 2d or 3d, assume that model_output is already # decoded. predictions = model_output - targets_and_inferences["output"] = predictions + targets_and_inferences["output"] = predictions # pyrefly: ignore[unbound-name] # Postprocesses the predictions here. postprocessed_predictions = [ diff --git a/seqio/preprocessors_test.py b/seqio/preprocessors_test.py index 2f53eebc..2e85bc4e 100644 --- a/seqio/preprocessors_test.py +++ b/seqio/preprocessors_test.py @@ -249,7 +249,7 @@ def test_append_eos(self): ) # Trim to sequence lengths (but with targets=None). - sequence_length['targets'] = None + sequence_length['targets'] = None # pyrefly: ignore[bad-assignment] assert_dataset( preprocessors.append_eos_after_trim( og_dataset, diff --git a/seqio/test_utils.py b/seqio/test_utils.py index 3fcea554..effe35ae 100644 --- a/seqio/test_utils.py +++ b/seqio/test_utils.py @@ -577,7 +577,7 @@ def get_fake_dataset( # Keep only defined features. examples = list( - map(lambda ex: {k: ex[k] for k in output_signature}, _FAKE_DATASET[split]) + map(lambda ex: {k: ex[k] for k in output_signature}, _FAKE_DATASET[split]) # pyrefly: ignore[unsupported-operation] ) ds = tf.data.Dataset.from_generator( @@ -679,7 +679,7 @@ def assert_dataset( """ if not isinstance(expected, list): - expected = [expected] + expected = [expected] # pyrefly: ignore[bad-assignment] actual = list(tfds.as_numpy(dataset)) _pyunit_proxy.assertEqual(len(actual), len(expected)) @@ -769,7 +769,7 @@ def _assert_compare_to_fake_dataset( dataset if not ragged_features else "token_preprocessed_ragged_features" ) _make_fake_datasets() - fake_examples = copy.deepcopy(_FAKE_DATASETS[dataset][split]) + fake_examples = copy.deepcopy(_FAKE_DATASETS[dataset][split]) # pyrefly: ignore[unsupported-operation] for key, feat in features.items(): for n, ex in enumerate(fake_examples): @@ -862,7 +862,7 @@ def create_default_dataset( if output_types is None: output_types = {feature_name: tf.int32 for feature_name in feature_names} if output_shapes is None: - output_shapes = {feature_name: [None] for feature_name in feature_names} + output_shapes = {feature_name: [None] for feature_name in feature_names} # pyrefly: ignore[bad-assignment] ds = tf.data.Dataset.from_generator( lambda: x, output_types=output_types, output_shapes=output_shapes @@ -934,7 +934,7 @@ def random_token_preprocessor(ex, seed, sequence_length): [], maxval=n_tokens, dtype=tf.int32, seed=seed ) res[feat] = tf.roll(tokens, shift=random_shift, axis=0) - return res + return res # pyrefly: ignore[unbound-name] def token_preprocessor_no_sequence_length(dataset, output_features): @@ -1218,7 +1218,7 @@ def __call__( return evaluator.evaluate( compute_metrics=True, predict_fn=PredictCallable(), - score_fn=ScoreCallable(), + score_fn=ScoreCallable(), # pyrefly: ignore[bad-argument-type] )[0].result()[task_name] @@ -1229,7 +1229,7 @@ def __init__(self, encode_dict, vocab_size=None): self._encode_dict = encode_dict self._vocab_size = vocab_size - def unk_id(self) -> Optional[int]: + def unk_id(self) -> Optional[int]: # pyrefly: ignore[bad-override] raise NotImplementedError def encode(self, s): @@ -1256,7 +1256,7 @@ def _decode(self, ids): def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor: raise NotImplementedError - def _base_vocab_size(self) -> int: + def _base_vocab_size(self) -> int: # pyrefly: ignore[bad-override] raise NotImplementedError @property @@ -1465,7 +1465,7 @@ def _load_shard(shard_instruction, shuffle_files, seed): # Prepare TextLineSource. _dump_fake_dataset( os.path.join(self.test_data_dir, "train.tsv"), - _FAKE_DATASET["train"], + _FAKE_DATASET["train"], # pyrefly: ignore[unsupported-operation] [2, 1], _dump_examples_to_tsv, ) @@ -1484,7 +1484,7 @@ def _load_shard(shard_instruction, shuffle_files, seed): # Prepare TFExampleSource. _dump_fake_dataset( os.path.join(self.test_data_dir, "train.tfrecord"), - _FAKE_DATASET["train"], + _FAKE_DATASET["train"], # pyrefly: ignore[unsupported-operation] [2, 1], _dump_examples_to_tfrecord, ) @@ -1564,13 +1564,13 @@ def decode_tf_example_fn(example): self.cached_task_dir = os.path.join(self.test_data_dir, "cached_task") _dump_fake_dataset( os.path.join(self.cached_task_dir, "train.tfrecord"), - _FAKE_TOKENIZED_DATASET["train"], + _FAKE_TOKENIZED_DATASET["train"], # pyrefly: ignore[unsupported-operation] [2, 1], _dump_examples_to_tfrecord, ) _dump_fake_dataset( os.path.join(self.cached_task_dir, "validation.tfrecord"), - _FAKE_TOKENIZED_DATASET["validation"], + _FAKE_TOKENIZED_DATASET["validation"], # pyrefly: ignore[unsupported-operation] [2], _dump_examples_to_tfrecord, ) @@ -1581,7 +1581,7 @@ def decode_tf_example_fn(example): os.path.join( self.test_data_dir, "cached_plaintext_task", "train.tfrecord" ), - _FAKE_PLAINTEXT_TOKENIZED_DATASET["train"], + _FAKE_PLAINTEXT_TOKENIZED_DATASET["train"], # pyrefly: ignore[unsupported-operation] [2, 1], _dump_examples_to_tfrecord, ) diff --git a/seqio/test_utils_test.py b/seqio/test_utils_test.py index 1287ea0d..c462eff7 100644 --- a/seqio/test_utils_test.py +++ b/seqio/test_utils_test.py @@ -118,12 +118,12 @@ def ds_fn(split, shuffle_files): return ds source = dataset_providers.FunctionDataSource( - dataset_fn=ds_fn, splits=['train'] + dataset_fn=ds_fn, splits=['train'] # pyrefly: ignore[bad-argument-type] ) dataset_providers.TaskRegistry.add( 'test_data_injection_task', - source=source, + source=source, # pyrefly: ignore[bad-argument-type] preprocessors=[], output_features={}, metric_fns=[], diff --git a/seqio/utils.py b/seqio/utils.py index 29ec37ab..e516dcd6 100644 --- a/seqio/utils.py +++ b/seqio/utils.py @@ -158,7 +158,7 @@ def __init__( `tfds.builder()`. read_only: whether `get_dataset` can trigger the generation of a dataset. """ - _validate_tfds_name(name) + _validate_tfds_name(name) # pyrefly: ignore[bad-argument-type] self._name = name self._data_dir = data_dir self._data_dir_override = None @@ -188,7 +188,7 @@ def set_decoders(self, decoders) -> None: @property def tfds_splits(self) -> Optional[Mapping[str, TfdsSplit]]: - return self._split_map if self._is_custom_split_map else None + return self._split_map if self._is_custom_split_map else None # pyrefly: ignore[bad-return] def resolved_tfds_name(self, split: Optional[str] = None) -> Optional[str]: """Returns the resolved TFDS dataset name. @@ -232,7 +232,7 @@ def get_split_params( ) -> Tuple[Optional[str], Optional[str]]: """Returns a tuple of (dataset, data_dir) for the given canonical split.""" if self._is_custom_split_map: - if mapped_split := self._split_map.get(split): + if mapped_split := self._split_map.get(split): # pyrefly: ignore[missing-attribute] dataset = mapped_split.dataset data_dir = mapped_split.data_dir else: @@ -337,7 +337,7 @@ def _get_builder(self, split: Optional[str] = None): "`builder_kwargs` should be empty when `dataset` value is not" " present." ) - builder = tfds.builder_from_directory(data_dir) + builder = tfds.builder_from_directory(data_dir) # pyrefly: ignore[bad-argument-type] LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key] = builder return LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key] @@ -349,10 +349,10 @@ def _map_split(self, split: str) -> Optional[str]: """Maps the given split to a dataset split.""" if self._is_custom_split_map: self._split_map: Mapping[str, TfdsSplit] - return self._split_map[split].split + return self._split_map[split].split # pyrefly: ignore[bad-return, unsupported-operation] elif self._split_map: self._split_map: Mapping[str, str] - return self._split_map[split] + return self._split_map[split] # pyrefly: ignore[bad-return] else: return split @@ -385,7 +385,7 @@ def load( shard_info=None, ): """Returns a tf.data.Dataset for the given split.""" - dataset_split = self._map_split(split) + dataset_split = self._map_split(split) # pyrefly: ignore[bad-argument-type] dataset, data_dir = self.get_split_params(split) read_config = self.read_config read_config.input_context = ( @@ -441,7 +441,7 @@ def size(self, split: str) -> Optional[int]: dataset_size = ds_splits[dataset_split].num_examples # Very large datasets have num_examples = 0; default instead to np.inf dataset_size = dataset_size if dataset_size > 0 else np.inf - return dataset_size + return dataset_size # pyrefly: ignore[bad-return] # ============================== TFExamples ==================================== @@ -1564,7 +1564,7 @@ def wrapped_fn(ds, *args, **kwargs): def map_with_seeds(fn): @functools.wraps(fn) def wrapped_fn(ds, *args, **kwargs): - return _GrainRandomMapFn(fn, num_seeds, num_parallel_calls)( + return _GrainRandomMapFn(fn, num_seeds, num_parallel_calls)( # pyrefly: ignore[bad-argument-type] ds, *args, **kwargs ) diff --git a/seqio/vocabularies.py b/seqio/vocabularies.py index 8e80db42..a8e5df2b 100644 --- a/seqio/vocabularies.py +++ b/seqio/vocabularies.py @@ -93,7 +93,7 @@ def _encode(self, s: str) -> Sequence[int]: def encode(self, s: Union[Sequence[int], str]) -> Sequence[int]: """Tokenizes string to an int sequence, without adding EOS.""" - return self._encode(s) + return self._encode(s) # pyrefly: ignore[bad-argument-type] @abc.abstractmethod def _decode(self, ids): @@ -292,7 +292,7 @@ class _ModelContext: sp_model: bytes -_load_model_lock: ClassVar[threading.Lock] = threading.Lock() +_load_model_lock: ClassVar[threading.Lock] = threading.Lock() # pyrefly: ignore[invalid-annotation] def _load_model( @@ -563,6 +563,7 @@ def __eq__(self, other): def __str__(self) -> str: return ( + # pyrefly: ignore[bad-argument-type] f"SentencePieceVocabulary(file={self.sentencepiece_model_file}, " f"extra_ids={self._extra_ids}, " f"spm_md5={hashlib.md5(self.sp_model).hexdigest()})" @@ -1035,7 +1036,7 @@ def unk_id(self) -> Optional[int]: return self._unk_id @property - def pad_id(self) -> Optional[int]: + def pad_id(self) -> Optional[int]: # pyrefly: ignore[bad-override] return self._pad_id @property