77import warnings
88from collections .abc import Sequence
99from types import ModuleType
10- from typing import TYPE_CHECKING , cast
10+ from typing import cast
1111
1212from ._at import at
1313from ._utils import _compat , _helpers
1414from ._utils ._compat import array_namespace , is_jax_array
15- from ._utils ._helpers import asarrays , ndindex
15+ from ._utils ._helpers import asarrays , eager_shape , ndindex
1616from ._utils ._typing import Array
1717
1818__all__ = [
@@ -211,11 +211,13 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
211211 m = xp .astype (m , dtype )
212212
213213 avg = _helpers .mean (m , axis = 1 , xp = xp )
214- fact = m .shape [1 ] - 1
214+
215+ m_shape = eager_shape (m )
216+ fact = m_shape [1 ] - 1
215217
216218 if fact <= 0 :
217219 warnings .warn ("Degrees of freedom <= 0 for slice" , RuntimeWarning , stacklevel = 2 )
218- fact = 0.0
220+ fact = 0
219221
220222 m -= avg [:, None ]
221223 m_transpose = m .T
@@ -274,8 +276,10 @@ def create_diagonal(
274276 if x .ndim == 0 :
275277 err_msg = "`x` must be at least 1-dimensional."
276278 raise ValueError (err_msg )
277- batch_dims = x .shape [:- 1 ]
278- n = x .shape [- 1 ] + abs (offset )
279+
280+ x_shape = eager_shape (x )
281+ batch_dims = x_shape [:- 1 ]
282+ n = x_shape [- 1 ] + abs (offset )
279283 diag = xp .zeros ((* batch_dims , n ** 2 ), dtype = x .dtype , device = _compat .device (x ))
280284
281285 target_slice = slice (
@@ -385,10 +389,6 @@ def isclose(
385389) -> Array : # numpydoc ignore=PR01,RT01
386390 """See docstring in array_api_extra._delegation."""
387391 a , b = asarrays (a , b , xp = xp )
388- # FIXME https://github.com/microsoft/pyright/issues/10085
389- if TYPE_CHECKING : # pragma: nocover
390- assert _compat .is_array_api_obj (a )
391- assert _compat .is_array_api_obj (b )
392392
393393 a_inexact = xp .isdtype (a .dtype , ("real floating" , "complex floating" ))
394394 b_inexact = xp .isdtype (b .dtype , ("real floating" , "complex floating" ))
@@ -505,24 +505,17 @@ def kron(
505505 if xp is None :
506506 xp = array_namespace (a , b )
507507 a , b = asarrays (a , b , xp = xp )
508- # FIXME https://github.com/microsoft/pyright/issues/10085
509- if TYPE_CHECKING : # pragma: nocover
510- assert _compat .is_array_api_obj (a )
511- assert _compat .is_array_api_obj (b )
512508
513509 singletons = (1 ,) * (b .ndim - a .ndim )
514- a = xp .broadcast_to (a , singletons + a .shape )
515- # FIXME https://github.com/microsoft/pyright/issues/10085
516- if TYPE_CHECKING : # pragma: nocover
517- assert _compat .is_array_api_obj (a )
510+ a = cast (Array , xp .broadcast_to (a , singletons + a .shape ))
518511
519512 nd_b , nd_a = b .ndim , a .ndim
520513 nd_max = max (nd_b , nd_a )
521514 if nd_a == 0 or nd_b == 0 :
522515 return xp .multiply (a , b )
523516
524- a_shape = a . shape
525- b_shape = b . shape
517+ a_shape = eager_shape ( a )
518+ b_shape = eager_shape ( b )
526519
527520 # Equalise the shapes by prepending smaller one with 1s
528521 a_shape = (1 ,) * max (0 , nd_b - nd_a ) + a_shape
@@ -587,16 +580,14 @@ def pad(
587580) -> Array : # numpydoc ignore=PR01,RT01
588581 """See docstring in `array_api_extra._delegation.py`."""
589582 # make pad_width a list of length-2 tuples of ints
590- x_ndim = cast (int , x .ndim )
591-
592583 if isinstance (pad_width , int ):
593- pad_width_seq = [(pad_width , pad_width )] * x_ndim
584+ pad_width_seq = [(pad_width , pad_width )] * x . ndim
594585 elif (
595586 isinstance (pad_width , tuple )
596587 and len (pad_width ) == 2
597588 and all (isinstance (i , int ) for i in pad_width )
598589 ):
599- pad_width_seq = [cast (tuple [int , int ], pad_width )] * x_ndim
590+ pad_width_seq = [cast (tuple [int , int ], pad_width )] * x . ndim
600591 else :
601592 pad_width_seq = cast (list [tuple [int , int ]], list (pad_width ))
602593
@@ -608,7 +599,8 @@ def pad(
608599 msg = f"expect a 2-tuple (before, after), got { w_tpl } ."
609600 raise ValueError (msg )
610601
611- sh = x .shape [ax ]
602+ sh = eager_shape (x )[ax ]
603+
612604 if w_tpl [0 ] == 0 and w_tpl [1 ] == 0 :
613605 sl = slice (None , None , None )
614606 else :
@@ -674,20 +666,17 @@ def setdiff1d(
674666 """
675667 if xp is None :
676668 xp = array_namespace (x1 , x2 )
677- x1 , x2 = asarrays (x1 , x2 , xp = xp )
669+ # https://github.com/microsoft/pyright/issues/10103
670+ x1_ , x2_ = asarrays (x1 , x2 , xp = xp )
678671
679672 if assume_unique :
680- x1 = xp .reshape (x1 , (- 1 ,))
681- x2 = xp .reshape (x2 , (- 1 ,))
673+ x1_ = xp .reshape (x1_ , (- 1 ,))
674+ x2_ = xp .reshape (x2_ , (- 1 ,))
682675 else :
683- x1 = xp .unique_values (x1 )
684- x2 = xp .unique_values (x2 )
685-
686- # FIXME https://github.com/microsoft/pyright/issues/10085
687- if TYPE_CHECKING : # pragma: nocover
688- assert _compat .is_array_api_obj (x1 )
676+ x1_ = xp .unique_values (x1_ )
677+ x2_ = xp .unique_values (x2_ )
689678
690- return x1 [_helpers .in1d (x1 , x2 , assume_unique = True , invert = True , xp = xp )]
679+ return x1_ [_helpers .in1d (x1_ , x2_ , assume_unique = True , invert = True , xp = xp )]
691680
692681
693682def sinc (x : Array , / , * , xp : ModuleType | None = None ) -> Array :
0 commit comments