Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions seqio/beam_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,14 +351,14 @@ def _info_dict(self, ex: List[Dict[str, Any]]):
if not ex:
return {}
assert len(ex) == 1
ex = ex[0]
ex = ex[0] # pyrefly: ignore[bad-assignment]
info = {
"num_shards": self._num_shards,
"features": {},
"seqio_version": seqio.__version__,
}
feature_dict = info["features"]
for k, v in ex.items():
for k, v in ex.items(): # pyrefly: ignore[missing-attribute]
if self._exclude_provenance and k.startswith(PROVENANCE_PREFIX):
continue
if isinstance(v, tf.RaggedTensor):
Expand Down
38 changes: 19 additions & 19 deletions seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def splits(self) -> Sequence[str]:
def get_dataset(
self,
sequence_length: Optional[Mapping[str, int]] = None,
split: str = tfds.Split.TRAIN,
split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute]
use_cached: bool = False,
shuffle: bool = True,
seed: Optional[int] = None,
Expand Down Expand Up @@ -334,9 +334,9 @@ def list_shards(self, split: str) -> Sequence[str]:
raise NotImplementedError

@abc.abstractmethod
def get_dataset(
def get_dataset( # pyrefly: ignore[bad-override]
self, # pytype: disable=signature-mismatch # overriding-default-value-checks
split: str = tfds.Split.TRAIN,
split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute]
shuffle: bool = True,
seed: Optional[int] = None,
shard_info: Optional[ShardInfo] = None,
Expand Down Expand Up @@ -443,7 +443,7 @@ def __repr__(self):

def get_dataset(
self,
split: str = tfds.Split.TRAIN,
split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute]
shuffle: bool = True,
seed: Optional[int] = None,
shard_info: Optional[ShardInfo] = None,
Expand Down Expand Up @@ -569,7 +569,7 @@ def get_dataset(
num_epochs: Optional[int] = 1, # Unused
) -> tf.data.Dataset:
if split is None:
split = tfds.Split.TRAIN
split = tfds.Split.TRAIN # pyrefly: ignore[missing-attribute]
return self.tfds_dataset.load(
split, shuffle_files=shuffle, seed=seed, shard_info=shard_info
)
Expand Down Expand Up @@ -657,7 +657,7 @@ def __repr__(self):

def get_dataset(
self,
split: str = tfds.Split.TRAIN,
split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute]
shuffle: bool = True,
seed: Optional[int] = None,
shard_info: Optional[ShardInfo] = None,
Expand Down Expand Up @@ -1455,8 +1455,8 @@ def preprocess_postcache(
if self.supports_caching:
# Skip a sufficient number of seeds to avoid duplicating any from
# pre-cache preprocessing.
seed = None if seed is None else seed + 42 * self._cache_step_idx
start_idx = self._cache_step_idx + 1
seed = None if seed is None else seed + 42 * self._cache_step_idx # pyrefly: ignore[unsupported-operation]
start_idx = self._cache_step_idx + 1 # pyrefly: ignore[unsupported-operation]
with utils.map_seed_manager(seed):
dataset = self._preprocess_dataset(
dataset,
Expand Down Expand Up @@ -1530,7 +1530,7 @@ def assert_cached(self) -> None:
), f"'{self.name}' does not exist in any of the task cache directories."

def get_cached_stats(
self, split: str = tfds.Split.TRAIN
self, split: str = tfds.Split.TRAIN # pyrefly: ignore[missing-attribute]
) -> Mapping[str, Union[int, float]]:
"""Returns basic statistics for cached dataset."""
self.assert_cached()
Expand All @@ -1547,7 +1547,7 @@ def get_cached_stats(
def get_dataset(
self, # pytype: disable=signature-mismatch # overriding-default-value-checks
sequence_length: Optional[Mapping[str, int]] = None,
split: str = tfds.Split.TRAIN,
split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute]
use_cached: bool = False,
shuffle: bool = True,
shuffle_buffer_size: Optional[int] = None, # Unique to Task
Expand Down Expand Up @@ -1632,7 +1632,7 @@ def get_dataset(
)
else:
ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed)
ds = ds.shard(shard_info.num_shards, shard_info.index)
ds = ds.shard(shard_info.num_shards, shard_info.index) # pyrefly: ignore[missing-attribute]

num_shards = shard_info.num_shards if shard_info else 1
if try_in_mem_cache and (
Expand All @@ -1643,7 +1643,7 @@ def get_dataset(
)
or (
source.num_input_examples(split)
and source.num_input_examples(split)
and source.num_input_examples(split) # pyrefly: ignore[unsupported-operation]
< _MAX_EXAMPLES_TO_MEM_CACHE * num_shards
)
):
Expand Down Expand Up @@ -1692,10 +1692,10 @@ def _get_cached_source(
self.assert_cached()
file_shuffle_buffer_size = (
file_shuffle_buffer_size
or self._cache_dataset_placerholder.file_shuffle_buffer_size
or self._cache_dataset_placerholder.file_shuffle_buffer_size # pyrefly: ignore[missing-attribute]
)
return _CachedDataSource(
cache_dir=self.cache_dir,
cache_dir=self.cache_dir, # pyrefly: ignore[bad-argument-type]
split=split,
file_shuffle_buffer_size=file_shuffle_buffer_size,
)
Expand All @@ -1720,7 +1720,7 @@ class TaskRegistry(DatasetProviderRegistry):

# pylint: disable=arguments-renamed
@classmethod
def add(
def add( # pyrefly: ignore[bad-override]
cls,
name: str,
source: DataSourceInterface,
Expand Down Expand Up @@ -1910,7 +1910,7 @@ def _get_submixture_rate(self, mix: "Mixture") -> float:
return float(rate)

def num_input_examples(self, split: str) -> int:
return sum(
return sum( # pyrefly: ignore[no-matching-overload]
t.num_input_examples(split) for t in self.tasks if split in t.splits
)

Expand Down Expand Up @@ -1953,7 +1953,7 @@ def get_task_dataset(
task: Task,
output_feature_keys: Set[str],
sequence_length: Optional[Mapping[str, int]] = None,
split: str = tfds.Split.TRAIN,
split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute]
use_cached: bool = False,
shuffle: bool = True,
seed: Optional[int] = None,
Expand Down Expand Up @@ -1985,7 +1985,7 @@ def _get_all_mixing_rates(self, tasks):
def get_dataset( # pytype: disable=signature-mismatch # overriding-parameter-type-checks
self,
sequence_length: Optional[Mapping[str, int]] = None,
split: str = tfds.Split.TRAIN,
split: str = tfds.Split.TRAIN, # pyrefly: ignore[missing-attribute]
use_cached: bool = False,
shuffle: bool = True,
seed: Optional[int] = None,
Expand Down Expand Up @@ -2284,7 +2284,7 @@ class MixtureRegistry(DatasetProviderRegistry):

# pylint: disable=arguments-renamed
@classmethod
def add(
def add( # pyrefly: ignore[bad-override]
cls,
name,
tasks,
Expand Down
26 changes: 13 additions & 13 deletions seqio/dataset_providers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def predict_metric_fn_with_types(

valid_task_with_types = TaskRegistry.add(
"valid_metrics_with_types",
source=self.function_source,
source=self.function_source, # pyrefly: ignore[bad-argument-type]
output_features={
"inputs": dataset_providers.Feature(
test_utils.sentencepiece_vocab()
Expand Down Expand Up @@ -666,7 +666,7 @@ def test_value_errors(self):
_ = dataset_providers.Task(
"multiple_cache_placeholders",
source=dataset_providers.FunctionDataSource(
dataset_fn=dataset_fn, splits=["train", "validation"]
dataset_fn=dataset_fn, splits=["train", "validation"] # pyrefly: ignore[bad-argument-type]
),
preprocessors=[
test_utils.test_text_preprocessor,
Expand All @@ -691,7 +691,7 @@ def test_value_errors(self):
task = dataset_providers.Task(
"sequence_length_pre_cache",
dataset_providers.FunctionDataSource(
dataset_fn=dataset_fn,
dataset_fn=dataset_fn, # pyrefly: ignore[bad-argument-type]
splits=["train"],
),
preprocessors=[
Expand Down Expand Up @@ -1739,8 +1739,8 @@ def test_multidimension_sequence_length(
dataset_fn = lambda split, shuffle_files: ds
dataset_providers.TaskRegistry.add(
task_name,
source=dataset_providers.FunctionDataSource(
dataset_fn=dataset_fn, splits=["train", "validation"]
source=dataset_providers.FunctionDataSource( # pyrefly: ignore[bad-argument-type]
dataset_fn=dataset_fn, splits=["train", "validation"] # pyrefly: ignore[bad-argument-type]
),
preprocessors=[
dataset_providers.CacheDatasetPlaceholder(),
Expand Down Expand Up @@ -1953,8 +1953,8 @@ def register_dummy_task(
"""Register a dummy task for GetDatasetTest."""
dataset_providers.TaskRegistry.add(
task_name,
source=dataset_providers.FunctionDataSource(
dataset_fn=dataset_fn, splits=["train", "validation"]
source=dataset_providers.FunctionDataSource( # pyrefly: ignore[bad-argument-type]
dataset_fn=dataset_fn, splits=["train", "validation"] # pyrefly: ignore[bad-argument-type]
),
preprocessors=[
dataset_providers.CacheDatasetPlaceholder(),
Expand Down Expand Up @@ -2022,13 +2022,13 @@ def good_fn(split, shuffle_files):
del split
del shuffle_files

dataset_providers.FunctionDataSource(good_fn, splits=("train",))
dataset_providers.FunctionDataSource(good_fn, splits=("train",)) # pyrefly: ignore[bad-argument-type]

def default_good_fn(split, shuffle_files=False):
del split
del shuffle_files

dataset_providers.FunctionDataSource(default_good_fn, splits=("train",))
dataset_providers.FunctionDataSource(default_good_fn, splits=("train",)) # pyrefly: ignore[bad-argument-type]

def seed_fn(split, shuffle_files=True, seed=0):
del split
Expand All @@ -2041,7 +2041,7 @@ def extra_kwarg_good_fn(split, shuffle_files, unused_kwarg=True):
del split
del shuffle_files

dataset_providers.FunctionDataSource(extra_kwarg_good_fn, splits=("train",))
dataset_providers.FunctionDataSource(extra_kwarg_good_fn, splits=("train",)) # pyrefly: ignore[bad-argument-type]

class GoodProtocol(dataset_providers.DatasetFnCallable):

Expand All @@ -2062,7 +2062,7 @@ def __call__(self, split, shuffle_files, seed=None):
def missing_shuff(split):
del split

dataset_providers.FunctionDataSource(missing_shuff, splits=("train",))
dataset_providers.FunctionDataSource(missing_shuff, splits=("train",)) # pyrefly: ignore[bad-argument-type]

with self.assertRaisesWithLiteralMatch(
ValueError,
Expand All @@ -2075,7 +2075,7 @@ def missing_shuff(split):
def missing_split(shuffle_files):
del shuffle_files

dataset_providers.FunctionDataSource(missing_split, splits=("train",))
dataset_providers.FunctionDataSource(missing_split, splits=("train",)) # pyrefly: ignore[bad-argument-type]

with self.assertRaisesWithLiteralMatch(
ValueError,
Expand All @@ -2089,7 +2089,7 @@ def extra_pos_arg(split, shuffle_files, unused_arg):
del split
del shuffle_files

dataset_providers.FunctionDataSource(extra_pos_arg, splits=("train",))
dataset_providers.FunctionDataSource(extra_pos_arg, splits=("train",)) # pyrefly: ignore[bad-argument-type]



Expand Down
4 changes: 2 additions & 2 deletions seqio/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def dataset_fn(task: Task) -> tf.data.Dataset:
self._cached_task_datasets = cached_task_datasets
self._model_feature_shapes = {
k: tuple(spec.shape)
for k, spec in eval_ds.element_spec.items()
for k, spec in eval_ds.element_spec.items() # pyrefly: ignore[unbound-name]
if spec.shape.rank > 0
}

Expand Down Expand Up @@ -810,7 +810,7 @@ def model_feature_shapes(self) -> Mapping[str, Tuple[int, ...]]:

@property
def loggers(self) -> Tuple[loggers_lib.Logger]:
return tuple(self._loggers)
return tuple(self._loggers) # pyrefly: ignore[bad-return]


class MetricManager:
Expand Down
26 changes: 13 additions & 13 deletions seqio/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def register_dummy_task(
"""Register a dummy task for GetDatasetTest."""
return dataset_providers.TaskRegistry.add(
task_name,
source=dataset_providers.FunctionDataSource(
dataset_fn=dataset_fn, splits=["train", "validation"]
source=dataset_providers.FunctionDataSource( # pyrefly: ignore[bad-argument-type]
dataset_fn=dataset_fn, splits=["train", "validation"] # pyrefly: ignore[bad-argument-type]
),
preprocessors=[preprocessor],
postprocess_fn=postprocess_fn,
Expand Down Expand Up @@ -143,7 +143,7 @@ def _task_from_tensor_slices(name, tensor_slices, label_classes):
return dataset_providers.Task(
name,
dataset_providers.FunctionDataSource(
lambda split, shuffle_files: tf.data.Dataset.from_tensor_slices(
lambda split, shuffle_files: tf.data.Dataset.from_tensor_slices( # pyrefly: ignore[bad-argument-type]
tensor_slices
),
splits="validation",
Expand Down Expand Up @@ -334,7 +334,7 @@ def _task_from_tensor_slices_rank2(name, tensor_slices, label_classes):
return dataset_providers.Task(
name,
dataset_providers.FunctionDataSource(
lambda split, shuffle_files: tf.data.Dataset.from_tensor_slices(
lambda split, shuffle_files: tf.data.Dataset.from_tensor_slices( # pyrefly: ignore[bad-argument-type]
tensor_slices
),
splits="validation",
Expand Down Expand Up @@ -496,9 +496,9 @@ def score_fn(

all_metrics, _ = evaluator.evaluate(
compute_metrics=True,
predict_fn=predict_fn,
score_fn=score_fn,
predict_with_aux_fn=predict_with_aux_fn,
predict_fn=predict_fn, # pyrefly: ignore[bad-argument-type]
score_fn=score_fn, # pyrefly: ignore[bad-argument-type]
predict_with_aux_fn=predict_with_aux_fn, # pyrefly: ignore[bad-argument-type]
step=42,
)
return all_metrics.result(), evaluator
Expand Down Expand Up @@ -595,7 +595,7 @@ def predict_with_aux_fn(
compute_metrics=True,
predict_fn=self.uncalled_fn,
score_fn=self.uncalled_fn,
predict_with_aux_fn=predict_with_aux_fn,
predict_with_aux_fn=predict_with_aux_fn, # pyrefly: ignore[bad-argument-type]
step=42,
)

Expand Down Expand Up @@ -732,7 +732,7 @@ def predict_fn(
return [(0, [5, 6]), (1, [6, 8])]

all_metrics, _ = evaluator.evaluate(
compute_metrics=True, predict_fn=predict_fn, score_fn=self.uncalled_fn
compute_metrics=True, predict_fn=predict_fn, score_fn=self.uncalled_fn # pyrefly: ignore[bad-argument-type]
)
# expected = {"accuracy": 2.0 / 3 * 100}
expected = {"sequence_accuracy": 50}
Expand Down Expand Up @@ -777,7 +777,7 @@ def predict_fn(
return [(0, [5]), (1, [6]), (2, [7])]

all_metrics, _ = evaluator.evaluate(
compute_metrics=True, predict_fn=predict_fn, score_fn=self.uncalled_fn
compute_metrics=True, predict_fn=predict_fn, score_fn=self.uncalled_fn # pyrefly: ignore[bad-argument-type]
)
expected = {"accuracy": 100}
self.assertDictClose(expected, all_metrics.result()[task.name])
Expand Down Expand Up @@ -846,7 +846,7 @@ def score_fn(

evaluator = Evaluator() # pytype: disable=missing-parameter
all_metrics, _ = evaluator.evaluate(
compute_metrics=True, predict_fn=predict_fn, score_fn=score_fn
compute_metrics=True, predict_fn=predict_fn, score_fn=score_fn # pyrefly: ignore[bad-argument-type]
)
expected = {
task1.name: {"sequence_accuracy": 50.0, "total_score": 651},
Expand Down Expand Up @@ -1160,7 +1160,7 @@ def mixing_order_predict_fn(

all_metrics, all_outputs = evaluator.evaluate(
compute_metrics=True,
predict_fn=mixing_order_predict_fn,
predict_fn=mixing_order_predict_fn, # pyrefly: ignore[bad-argument-type]
score_fn=self.uncalled_fn,
)
expected_metric = {"sequence_accuracy": 100}
Expand Down Expand Up @@ -1286,7 +1286,7 @@ def mock_init(self):
compute_metrics=True,
predict_fn=self.uncalled_fn,
predict_with_aux_fn=self.uncalled_fn,
score_fn=score_fn_with_intermediates,
score_fn=score_fn_with_intermediates, # pyrefly: ignore[bad-argument-type]
step=42,
)
results = all_metrics.result()
Expand Down
Loading
Loading