@@ -510,13 +510,15 @@ def check_smaller_shape(value_shape, shape, slice_shape, slice_):
510510 This follows the NumPy broadcasting rules.
511511 """
512512 # slice_shape must be as long as shape
513- if len (slice_shape ) != len (shape ):
514- raise ValueError ("slice_shape must be as long as shape" )
513+ if len (slice_shape ) != len (slice_ ):
514+ raise ValueError ("slice_shape must be as long as slice_" )
515+ no_nones_shape = tuple (sh for sh , s in zip (slice_shape , slice_ , strict = True ) if s is not None )
516+ no_nones_slice = tuple (s for sh , s in zip (slice_shape , slice_ , strict = True ) if s is not None )
515517 is_smaller_shape = any (
516- s > (1 if i >= len (value_shape ) else value_shape [i ]) for i , s in enumerate (slice_shape )
518+ s > (1 if i >= len (value_shape ) else value_shape [i ]) for i , s in enumerate (no_nones_shape )
517519 )
518520 slice_past_bounds = any (
519- s .stop > (1 if i >= len (value_shape ) else value_shape [i ]) for i , s in enumerate (slice_ )
521+ s .stop > (1 if i >= len (value_shape ) else value_shape [i ]) for i , s in enumerate (no_nones_slice )
520522 )
521523 return len (value_shape ) < len (shape ) or is_smaller_shape or slice_past_bounds
522524
@@ -547,10 +549,31 @@ def compute_smaller_slice(larger_shape, smaller_shape, larger_slice):
547549 """
548550 Returns the slice of the smaller array that corresponds to the slice of the larger array.
549551 """
550- diff_dims = len (larger_shape ) - len (smaller_shape )
552+ j_small = len (smaller_shape ) - 1
553+ j_large = len (larger_shape ) - 1
554+ smaller_shape_nones = []
555+ larger_shape_nones = []
556+ for s in reversed (larger_slice ):
557+ if s is None :
558+ smaller_shape_nones .append (1 )
559+ larger_shape_nones .append (1 )
560+ else :
561+ if j_small >= 0 :
562+ smaller_shape_nones .append (smaller_shape [j_small ])
563+ j_small -= 1
564+ if j_large >= 0 :
565+ larger_shape_nones .append (larger_shape [j_large ])
566+ j_large -= 1
567+ smaller_shape_nones .reverse ()
568+ larger_shape_nones .reverse ()
569+ diff_dims = len (larger_shape_nones ) - len (smaller_shape_nones )
551570 return tuple (
552- larger_slice [i ] if smaller_shape [i - diff_dims ] != 1 else slice (0 , larger_shape [i ])
553- for i in range (diff_dims , len (larger_shape ))
571+ None
572+ if larger_slice [i ] is None
573+ else (
574+ larger_slice [i ] if smaller_shape_nones [i - diff_dims ] != 1 else slice (0 , larger_shape_nones [i ])
575+ )
576+ for i in range (diff_dims , len (larger_shape_nones ))
554577 )
555578
556579
@@ -1694,7 +1717,6 @@ def slices_eval_getitem(
16941717 _slice_bcast = tuple (slice (i , i + 1 ) if isinstance (i , int ) else i for i in _slice .raw )
16951718 slice_shape = ndindex .ndindex (_slice_bcast ).newshape (shape ) # includes dummy dimensions
16961719 _slice = _slice .raw
1697- offset = tuple (s .start for s in _slice_bcast ) # offset for the udf
16981720
16991721 # Get the slice of each operand
17001722 slice_operands = {}
@@ -1715,6 +1737,7 @@ def slices_eval_getitem(
17151737
17161738 # Evaluate the expression using slices of operands
17171739 if callable (expression ):
1740+ offset = tuple (0 if s is None else s .start for s in _slice_bcast ) # offset for the udf
17181741 result = np .empty (slice_shape , dtype = dtype )
17191742 expression (tuple (slice_operands .values ()), result , offset = offset )
17201743 else :
@@ -2161,7 +2184,7 @@ def chunked_eval( # noqa: C901
21612184 """
21622185 try :
21632186 # standardise slice to be ndindex.Tuple
2164- item = () if item in ( None , slice (None , None , None ) ) else item
2187+ item = () if item == slice (None , None , None ) else item
21652188 item = item if isinstance (item , tuple ) else (item ,)
21662189 item = tuple (
21672190 slice (s .start , s .stop , 1 if s .step is None else s .step ) if isinstance (s , slice ) else s
0 commit comments