|
5 | 5 |
|
6 | 6 | import math |
7 | 7 | import warnings |
| 8 | +from collections.abc import Sequence |
8 | 9 | from types import ModuleType |
9 | 10 | from typing import cast |
10 | 11 |
|
@@ -448,23 +449,30 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array: |
448 | 449 |
|
449 | 450 | def pad( |
450 | 451 | x: Array, |
451 | | - pad_width: int | tuple[int, int] | list[tuple[int, int]], |
| 452 | + pad_width: int | tuple[int, int] | Sequence[tuple[int, int]], |
452 | 453 | *, |
453 | 454 | constant_values: bool | int | float | complex = 0, |
454 | 455 | xp: ModuleType, |
455 | 456 | ) -> Array: # numpydoc ignore=PR01,RT01 |
456 | 457 | """See docstring in `array_api_extra._delegation.py`.""" |
457 | 458 | # make pad_width a list of length-2 tuples of ints |
458 | 459 | x_ndim = cast(int, x.ndim) |
| 460 | + |
459 | 461 | if isinstance(pad_width, int): |
460 | | - pad_width = [(pad_width, pad_width)] * x_ndim |
461 | | - if isinstance(pad_width, tuple): |
462 | | - pad_width = [pad_width] * x_ndim |
| 462 | + pad_width_seq = [(pad_width, pad_width)] * x_ndim |
| 463 | + elif ( |
| 464 | + isinstance(pad_width, tuple) |
| 465 | + and len(pad_width) == 2 |
| 466 | + and all(isinstance(i, int) for i in pad_width) |
| 467 | + ): |
| 468 | + pad_width_seq = [cast(tuple[int, int], pad_width)] * x_ndim |
| 469 | + else: |
| 470 | + pad_width_seq = cast(list[tuple[int, int]], list(pad_width)) |
463 | 471 |
|
464 | 472 | # https://github.com/python/typeshed/issues/13376 |
465 | 473 | slices: list[slice] = [] # type: ignore[no-any-explicit] |
466 | 474 | newshape: list[int] = [] |
467 | | - for ax, w_tpl in enumerate(pad_width): |
| 475 | + for ax, w_tpl in enumerate(pad_width_seq): |
468 | 476 | if len(w_tpl) != 2: |
469 | 477 | msg = f"expect a 2-tuple (before, after), got {w_tpl}." |
470 | 478 | raise ValueError(msg) |
|
0 commit comments