@@ -3945,20 +3945,12 @@ def get_fselection_numpy(self, key: list | np.ndarray) -> np.ndarray:
39453945 # Default when there are booleans
39463946 # TODO: for boolean indexing could be optimised by avoiding
39473947 # calculating out_shape prior to loop and keeping track on-the-fly (like in LazyExpr machinery)
3948- return self ._get_set_findex_default (_slice , out_shape )
3949-
3950- def _get_set_findex_default (self , _slice , out_shape = None , updater = None ):
3951- _get = False
3952- if not ((out_shape is None ) or (updater is None )):
3953- raise ValueError ("Cannot provide both out_shape and updater." )
3954- # we have a getitem
3955- if out_shape is not None :
3956- _get = True
3957- out = np .empty (out_shape , dtype = self .dtype )
3958- elif updater is None :
3959- raise ValueError ("Must provide one of out_shape or updater." )
3960- else :
3961- out = self # default return for no intersecting chunks
3948+ out = np .empty (out_shape , dtype = self .dtype )
3949+ return self ._get_set_findex_default (_slice , out )
3950+
3951+ def _get_set_findex_default (self , _slice , out = None , value = None ):
3952+ _get = out is not None
3953+ out = self if out is None else out # default return for setitem with no intersecting chunks
39623954 if 0 in self .shape :
39633955 return out
39643956 chunk_size = ndindex .ChunkSize (self .chunks ) # only works with nonzero chunks
@@ -3973,10 +3965,10 @@ def _get_set_findex_default(self, _slice, out_shape=None, updater=None):
39733965 chunk = np .empty (tuple (sp - st for st , sp in zip (start , stop , strict = True )), dtype = self .dtype )
39743966 super ().get_slice_numpy (chunk , (start , stop ))
39753967 if _get :
3976- new_shape = sel_idx .newshape (out_shape )
3968+ new_shape = sel_idx .newshape (out . shape )
39773969 out [sel_idx .raw ] = chunk [sub_idx ].reshape (new_shape )
39783970 else :
3979- chunk [sub_idx ] = updater ( sel_idx . raw )
3971+ chunk [sub_idx ] = value if np . isscalar ( value ) else value [ sel_idx ]
39803972 out = super ().set_slice ((start , stop ), chunk )
39813973 return out
39823974
@@ -3998,7 +3990,42 @@ def set_oselection_numpy(self, key: list | np.ndarray, arr: NDArray) -> np.ndarr
39983990 """
39993991 return super ().set_oindex_numpy (key , arr )
40003992
4001- def __getitem__ ( # noqa: C901
3993+ def _get_set_nonunit_steps (self , _slice , out = None , value = None ):
3994+ start , stop , step , mask = _slice
3995+ _get = out is not None
3996+ out = self if out is None else out # default return for setitem with no intersecting chunks
3997+ if 0 in self .shape :
3998+ return out
3999+
4000+ chunks = self .chunks
4001+ _slice = tuple (slice (s , st , stp ) for s , st , stp in zip (start , stop , step , strict = True ))
4002+ intersecting_chunks = [
4003+ slice_to_chunktuple (s , c ) for s , c in zip (_slice , chunks , strict = True )
4004+ ] # internally handles negative steps
4005+ for c in product (* intersecting_chunks ):
4006+ sel_idx , glob_selection , sub_idx = _get_selection (c , _slice , chunks )
4007+ sel_idx = tuple (s for s , m in zip (sel_idx , mask , strict = True ) if not m )
4008+ sub_idx = tuple (s if not m else s .start for s , m in zip (sub_idx , mask , strict = True ))
4009+ locstart , locstop = _get_local_slice (
4010+ glob_selection ,
4011+ (),
4012+ ((), ()), # switches start and stop for negative steps
4013+ )
4014+ chunk = np .empty (
4015+ tuple (sp - st for st , sp in zip (locstart , locstop , strict = True )), dtype = self .dtype
4016+ )
4017+ # basically load whole chunk, except for slice part at beginning and end
4018+ super ().get_slice_numpy (chunk , (locstart , locstop )) # copy relevant slice of chunk
4019+ if _get :
4020+ out [sel_idx ] = chunk [sub_idx ] # update relevant parts of chunk
4021+ else :
4022+ chunk [sub_idx ] = (
4023+ value if np .isscalar (value ) else value [sel_idx ]
4024+ ) # update relevant parts of chunk
4025+ out = super ().set_slice ((locstart , locstop ), chunk ) # load updated partial chunk into array
4026+ return out
4027+
4028+ def __getitem__ (
40024029 self ,
40034030 key : None
40044031 | int
@@ -4080,7 +4107,8 @@ def __getitem__( # noqa: C901
40804107 if key :
40814108 _slice = ndindex .ndindex (()).expand (self .shape ) # just get whole array
40824109 out_shape = _slice .newshape (self .shape )
4083- return np .expand_dims (self ._get_set_findex_default (_slice , out_shape = out_shape ), 0 )
4110+ out = np .empty (out_shape , dtype = self .dtype )
4111+ return np .expand_dims (self ._get_set_findex_default (_slice , out = out ), 0 )
40844112 else : # do nothing
40854113 return np .empty ((0 ,) + self .shape , dtype = self .dtype )
40864114 elif (
@@ -4096,12 +4124,9 @@ def __getitem__( # noqa: C901
40964124 return self .get_fselection_numpy (key ) # fancy index default, can be quite slow
40974125
40984126 start , stop , step , none_mask = get_ndarray_start_stop (self .ndim , key_ , self .shape )
4099- for i , s in enumerate (step ): # (start, stop, -1) => stop < start
4100- if s < 0 :
4101- temp = start [i ]
4102- start [i ] = stop [i ] + 1 # don't want to include stop
4103- stop [i ] = temp + 1 # want to include start
4104- shape = np .array ([sp - st for st , sp in zip (start , stop , strict = True )])
4127+ shape = np .array (
4128+ [(sp - st - np .sign (stp )) // stp + 1 for st , sp , stp in zip (start , stop , step , strict = True )]
4129+ )
41054130 if mask is not None : # there are some dummy dims from ints
41064131 # only get mask for not Nones in key to have nm_ same length as shape
41074132 nm_ = [not m for m , n in zip (mask , none_mask , strict = True ) if not n ]
@@ -4110,12 +4135,11 @@ def __getitem__( # noqa: C901
41104135 shape = tuple (shape [nm_ ])
41114136
41124137 # Create the array to store the result
4113- arr = np .empty (shape , dtype = self .dtype )
4114- nparr = super ().get_slice_numpy (arr , (start , stop ))
4115- if step != (1 ,) * self .ndim : # TODO: optimise to work like __setitem__ for non-unit steps
4116- # have to make step refer to sliced dims (which will be less if ints present)
4117- slice_ = tuple (slice (None , None , st ) for st , m in zip (step , nm_ , strict = True ) if m )
4118- nparr = nparr [slice_ ]
4138+ nparr = np .empty (shape , dtype = self .dtype )
4139+ if step != (1 ,) * self .ndim :
4140+ nparr = self ._get_set_nonunit_steps ((start , stop , step , [not i for i in nm_ ]), out = nparr )
4141+ else :
4142+ nparr = super ().get_slice_numpy (nparr , (start , stop ))
41194143
41204144 if np .any (none_mask ):
41214145 nparr = np .expand_dims (nparr , axis = [i for i , n in enumerate (none_mask ) if n ])
@@ -4127,7 +4151,7 @@ def __getitem__( # noqa: C901
41274151
41284152 return nparr
41294153
4130- def __setitem__ ( # noqa : C901
4154+ def __setitem__ (
41314155 self ,
41324156 key : None | int | slice | Sequence [slice | int | np .bool_ | np .ndarray [int | np .bool_ ] | None ],
41334157 value : object ,
@@ -4171,14 +4195,6 @@ def __setitem__( # noqa : C901
41714195 if hasattr (value , "shape" ) and value .shape == ():
41724196 value = value .item ()
41734197
4174- def updater (sel_idx ):
4175- return value [sel_idx ]
4176-
4177- if np .isscalar (value ): # overwrite updater function for simple cases (faster)
4178-
4179- def updater (sel_idx ):
4180- return value
4181-
41824198 if builtins .any (isinstance (k , (list , np .ndarray )) for k in key_ ): # fancy indexing
41834199 _slice = ndindex .ndindex (key_ ).expand (
41844200 self .shape
@@ -4191,36 +4207,14 @@ def updater(sel_idx):
41914207 _slice = ndindex .ndindex (()).expand (self .shape ) # just get whole array
41924208 else : # do nothing
41934209 return self
4194- return self ._get_set_findex_default (_slice , updater = updater )
4210+ return self ._get_set_findex_default (_slice , value = value )
41954211
41964212 start , stop , step , none_mask = get_ndarray_start_stop (self .ndim , key_ , self .shape )
41974213
41984214 if step != (1 ,) * self .ndim : # handle non-unit or negative steps
41994215 if np .any (none_mask ):
42004216 raise ValueError ("Cannot mix non-unit steps and None indexing for __setitem__." )
4201- chunks = self .chunks
4202- shape = self .shape
4203- _slice = tuple (slice (s , st , stp ) for s , st , stp in zip (start , stop , step , strict = True ))
4204- intersecting_chunks = [
4205- slice_to_chunktuple (s , c ) for s , c in zip (_slice , chunks , strict = True )
4206- ] # internally handles negative steps
4207- out = self # for when shape has 0 (i.e. arr is empty, as then skip loop)
4208- for c in product (* intersecting_chunks ):
4209- sel_idx , glob_selection , sub_idx = _get_selection (c , _slice , chunks )
4210- sel_idx = tuple (s for s , m in zip (sel_idx , mask , strict = True ) if not m )
4211- sub_idx = tuple (s if not m else s .start for s , m in zip (sub_idx , mask , strict = True ))
4212- locstart , locstop = _get_local_slice (
4213- glob_selection ,
4214- (),
4215- ((), ()), # switches start and stop for negative steps
4216- )
4217- chunk = np .empty (
4218- tuple (sp - st for st , sp in zip (locstart , locstop , strict = True )), dtype = self .dtype
4219- )
4220- super ().get_slice_numpy (chunk , (locstart , locstop )) # copy relevant slice of chunk
4221- chunk [sub_idx ] = updater (sel_idx ) # update relevant parts of chunk
4222- out = super ().set_slice ((locstart , locstop ), chunk ) # load updated partial chunk into array
4223- return out
4217+ return self ._get_set_nonunit_steps ((start , stop , step , mask ), value = value )
42244218
42254219 shape = [sp - st for sp , st in zip (stop , start , strict = False )]
42264220 if isinstance (value , NDArray ):
@@ -6313,7 +6307,7 @@ def _get_selection(ctuple, ptuple, chunks):
63136307 out_pselection = ()
63146308 i = 0
63156309 for ps , pt in zip (pselection , ptuple , strict = True ):
6316- sign_ = pt . step // builtins . abs (pt .step )
6310+ sign_ = np . sign (pt .step )
63176311 n = (ps .start - pt .start - sign_ ) // pt .step
63186312 out_start = n + 1
63196313 # ps.stop always positive except for case where get full array (it is then -1 since desire 0th element)
0 commit comments