@@ -3948,20 +3948,12 @@ def get_fselection_numpy(self, key: list | np.ndarray) -> np.ndarray:
39483948 # Default when there are booleans
39493949 # TODO: for boolean indexing could be optimised by avoiding
39503950 # calculating out_shape prior to loop and keeping track on-the-fly (like in LazyExpr machinery)
3951- return self ._get_set_findex_default (_slice , out_shape )
3952-
3953- def _get_set_findex_default (self , _slice , out_shape = None , updater = None ):
3954- _get = False
3955- if not ((out_shape is None ) or (updater is None )):
3956- raise ValueError ("Cannot provide both out_shape and updater." )
3957- # we have a getitem
3958- if out_shape is not None :
3959- _get = True
3960- out = np .empty (out_shape , dtype = self .dtype )
3961- elif updater is None :
3962- raise ValueError ("Must provide one of out_shape or updater." )
3963- else :
3964- out = self # default return for no intersecting chunks
3951+ out = np .empty (out_shape , dtype = self .dtype )
3952+ return self ._get_set_findex_default (_slice , out )
3953+
3954+ def _get_set_findex_default (self , _slice , out = None , value = None ):
3955+ _get = out is not None
3956+ out = self if out is None else out # default return for setitem with no intersecting chunks
39653957 if 0 in self .shape :
39663958 return out
39673959 chunk_size = ndindex .ChunkSize (self .chunks ) # only works with nonzero chunks
@@ -3976,10 +3968,10 @@ def _get_set_findex_default(self, _slice, out_shape=None, updater=None):
39763968 chunk = np .empty (tuple (sp - st for st , sp in zip (start , stop , strict = True )), dtype = self .dtype )
39773969 super ().get_slice_numpy (chunk , (start , stop ))
39783970 if _get :
3979- new_shape = sel_idx .newshape (out_shape )
3971+ new_shape = sel_idx .newshape (out . shape )
39803972 out [sel_idx .raw ] = chunk [sub_idx ].reshape (new_shape )
39813973 else :
3982- chunk [sub_idx ] = updater ( sel_idx . raw )
3974+ chunk [sub_idx ] = value if np . isscalar ( value ) else value [ sel_idx ]
39833975 out = super ().set_slice ((start , stop ), chunk )
39843976 return out
39853977
@@ -4001,7 +3993,42 @@ def set_oselection_numpy(self, key: list | np.ndarray, arr: NDArray) -> np.ndarr
40013993 """
40023994 return super ().set_oindex_numpy (key , arr )
40033995
4004- def __getitem__ ( # noqa: C901
3996+ def _get_set_nonunit_steps (self , _slice , out = None , value = None ):
3997+ start , stop , step , mask = _slice
3998+ _get = out is not None
3999+ out = self if out is None else out # default return for setitem with no intersecting chunks
4000+ if 0 in self .shape :
4001+ return out
4002+
4003+ chunks = self .chunks
4004+ _slice = tuple (slice (s , st , stp ) for s , st , stp in zip (start , stop , step , strict = True ))
4005+ intersecting_chunks = [
4006+ slice_to_chunktuple (s , c ) for s , c in zip (_slice , chunks , strict = True )
4007+ ] # internally handles negative steps
4008+ for c in product (* intersecting_chunks ):
4009+ sel_idx , glob_selection , sub_idx = _get_selection (c , _slice , chunks )
4010+ sel_idx = tuple (s for s , m in zip (sel_idx , mask , strict = True ) if not m )
4011+ sub_idx = tuple (s if not m else s .start for s , m in zip (sub_idx , mask , strict = True ))
4012+ locstart , locstop = _get_local_slice (
4013+ glob_selection ,
4014+ (),
4015+ ((), ()), # switches start and stop for negative steps
4016+ )
4017+ chunk = np .empty (
4018+ tuple (sp - st for st , sp in zip (locstart , locstop , strict = True )), dtype = self .dtype
4019+ )
4020+ # basically load whole chunk, except for slice part at beginning and end
4021+ super ().get_slice_numpy (chunk , (locstart , locstop )) # copy relevant slice of chunk
4022+ if _get :
4023+ out [sel_idx ] = chunk [sub_idx ] # update relevant parts of chunk
4024+ else :
4025+ chunk [sub_idx ] = (
4026+ value if np .isscalar (value ) else value [sel_idx ]
4027+ ) # update relevant parts of chunk
4028+ out = super ().set_slice ((locstart , locstop ), chunk ) # load updated partial chunk into array
4029+ return out
4030+
4031+ def __getitem__ (
40054032 self ,
40064033 key : None
40074034 | int
@@ -4083,7 +4110,8 @@ def __getitem__( # noqa: C901
40834110 if key :
40844111 _slice = ndindex .ndindex (()).expand (self .shape ) # just get whole array
40854112 out_shape = _slice .newshape (self .shape )
4086- return np .expand_dims (self ._get_set_findex_default (_slice , out_shape = out_shape ), 0 )
4113+ out = np .empty (out_shape , dtype = self .dtype )
4114+ return np .expand_dims (self ._get_set_findex_default (_slice , out = out ), 0 )
40874115 else : # do nothing
40884116 return np .empty ((0 ,) + self .shape , dtype = self .dtype )
40894117 elif (
@@ -4099,12 +4127,9 @@ def __getitem__( # noqa: C901
40994127 return self .get_fselection_numpy (key ) # fancy index default, can be quite slow
41004128
41014129 start , stop , step , none_mask = get_ndarray_start_stop (self .ndim , key_ , self .shape )
4102- for i , s in enumerate (step ): # (start, stop, -1) => stop < start
4103- if s < 0 :
4104- temp = start [i ]
4105- start [i ] = stop [i ] + 1 # don't want to include stop
4106- stop [i ] = temp + 1 # want to include start
4107- shape = np .array ([sp - st for st , sp in zip (start , stop , strict = True )])
4130+ shape = np .array (
4131+ [(sp - st - np .sign (stp )) // stp + 1 for st , sp , stp in zip (start , stop , step , strict = True )]
4132+ )
41084133 if mask is not None : # there are some dummy dims from ints
41094134 # only get mask for not Nones in key to have nm_ same length as shape
41104135 nm_ = [not m for m , n in zip (mask , none_mask , strict = True ) if not n ]
@@ -4113,12 +4138,11 @@ def __getitem__( # noqa: C901
41134138 shape = tuple (shape [nm_ ])
41144139
41154140 # Create the array to store the result
4116- arr = np .empty (shape , dtype = self .dtype )
4117- nparr = super ().get_slice_numpy (arr , (start , stop ))
4118- if step != (1 ,) * self .ndim : # TODO: optimise to work like __setitem__ for non-unit steps
4119- # have to make step refer to sliced dims (which will be less if ints present)
4120- slice_ = tuple (slice (None , None , st ) for st , m in zip (step , nm_ , strict = True ) if m )
4121- nparr = nparr [slice_ ]
4141+ nparr = np .empty (shape , dtype = self .dtype )
4142+ if step != (1 ,) * self .ndim :
4143+ nparr = self ._get_set_nonunit_steps ((start , stop , step , [not i for i in nm_ ]), out = nparr )
4144+ else :
4145+ nparr = super ().get_slice_numpy (nparr , (start , stop ))
41224146
41234147 if np .any (none_mask ):
41244148 nparr = np .expand_dims (nparr , axis = [i for i , n in enumerate (none_mask ) if n ])
@@ -4130,7 +4154,7 @@ def __getitem__( # noqa: C901
41304154
41314155 return nparr
41324156
4133- def __setitem__ ( # noqa : C901
4157+ def __setitem__ (
41344158 self ,
41354159 key : None | int | slice | Sequence [slice | int | np .bool_ | np .ndarray [int | np .bool_ ] | None ],
41364160 value : object ,
@@ -4174,14 +4198,6 @@ def __setitem__( # noqa : C901
41744198 if hasattr (value , "shape" ) and value .shape == ():
41754199 value = value .item ()
41764200
4177- def updater (sel_idx ):
4178- return value [sel_idx ]
4179-
4180- if np .isscalar (value ): # overwrite updater function for simple cases (faster)
4181-
4182- def updater (sel_idx ):
4183- return value
4184-
41854201 if builtins .any (isinstance (k , (list , np .ndarray )) for k in key_ ): # fancy indexing
41864202 _slice = ndindex .ndindex (key_ ).expand (
41874203 self .shape
@@ -4194,36 +4210,14 @@ def updater(sel_idx):
41944210 _slice = ndindex .ndindex (()).expand (self .shape ) # just get whole array
41954211 else : # do nothing
41964212 return self
4197- return self ._get_set_findex_default (_slice , updater = updater )
4213+ return self ._get_set_findex_default (_slice , value = value )
41984214
41994215 start , stop , step , none_mask = get_ndarray_start_stop (self .ndim , key_ , self .shape )
42004216
42014217 if step != (1 ,) * self .ndim : # handle non-unit or negative steps
42024218 if np .any (none_mask ):
42034219 raise ValueError ("Cannot mix non-unit steps and None indexing for __setitem__." )
4204- chunks = self .chunks
4205- shape = self .shape
4206- _slice = tuple (slice (s , st , stp ) for s , st , stp in zip (start , stop , step , strict = True ))
4207- intersecting_chunks = [
4208- slice_to_chunktuple (s , c ) for s , c in zip (_slice , chunks , strict = True )
4209- ] # internally handles negative steps
4210- out = self # for when shape has 0 (i.e. arr is empty, as then skip loop)
4211- for c in product (* intersecting_chunks ):
4212- sel_idx , glob_selection , sub_idx = _get_selection (c , _slice , chunks )
4213- sel_idx = tuple (s for s , m in zip (sel_idx , mask , strict = True ) if not m )
4214- sub_idx = tuple (s if not m else s .start for s , m in zip (sub_idx , mask , strict = True ))
4215- locstart , locstop = _get_local_slice (
4216- glob_selection ,
4217- (),
4218- ((), ()), # switches start and stop for negative steps
4219- )
4220- chunk = np .empty (
4221- tuple (sp - st for st , sp in zip (locstart , locstop , strict = True )), dtype = self .dtype
4222- )
4223- super ().get_slice_numpy (chunk , (locstart , locstop )) # copy relevant slice of chunk
4224- chunk [sub_idx ] = updater (sel_idx ) # update relevant parts of chunk
4225- out = super ().set_slice ((locstart , locstop ), chunk ) # load updated partial chunk into array
4226- return out
4220+ return self ._get_set_nonunit_steps ((start , stop , step , mask ), value = value )
42274221
42284222 shape = [sp - st for sp , st in zip (stop , start , strict = False )]
42294223 if isinstance (value , NDArray ):
@@ -6320,7 +6314,7 @@ def _get_selection(ctuple, ptuple, chunks):
63206314 out_pselection = ()
63216315 i = 0
63226316 for ps , pt in zip (pselection , ptuple , strict = True ):
6323- sign_ = pt . step // builtins . abs (pt .step )
6317+ sign_ = np . sign (pt .step )
63246318 n = (ps .start - pt .start - sign_ ) // pt .step
63256319 out_start = n + 1
63266320 # ps.stop always positive except for case where get full array (it is then -1 since desire 0th element)
0 commit comments