@@ -568,7 +568,7 @@ def compute_broadcast_shape(arrays):
568568 return np .broadcast_shapes (* shapes ) if shapes else None
569569
570570
571- def check_smaller_shape (value_shape , shape , slice_shape ):
571+ def check_smaller_shape (value_shape , shape , slice_shape , slice_ ):
572572 """Check whether the shape of the value is smaller than the shape of the array.
573573
574574 This follows the NumPy broadcasting rules.
@@ -579,7 +579,10 @@ def check_smaller_shape(value_shape, shape, slice_shape):
579579 is_smaller_shape = any (
580580 s > (1 if i >= len (value_shape ) else value_shape [i ]) for i , s in enumerate (slice_shape )
581581 )
582- return len (value_shape ) < len (shape ) or is_smaller_shape
582+ slice_past_bounds = any (
583+ s .stop > (1 if i >= len (value_shape ) else value_shape [i ]) for i , s in enumerate (slice_ )
584+ )
585+ return len (value_shape ) < len (shape ) or is_smaller_shape or slice_past_bounds
583586
584587
585588def _compute_smaller_slice (larger_shape , smaller_shape , larger_slice ):
@@ -1396,29 +1399,35 @@ def slices_eval( # noqa: C901
13961399
13971400 dtype = kwargs .pop ("dtype" , None )
13981401 shape_slice = None
1399- need_orig_slice = False
1402+ need_final_slice = False
14001403
1401- # keep orig_slice for final getitem
1404+ # keep orig_slice
14021405 _slice = _slice .raw
14031406 orig_slice = _slice
1407+ full_slice = () # by default the full_slice is the whole array
1408+ final_slice = () # by default the final_slice is the whole array
1409+
1410+ # Compute the shape and chunks of the output array, including broadcasting
1411+ shape = compute_broadcast_shape (operands .values ())
14041412 if out is None :
1405- # Compute the shape and chunks of the output array, including broadcasting
1406- shape = compute_broadcast_shape (operands .values ())
14071413 if _slice != ():
14081414 # Check whether _slice contains an integer, or any step that are not None or 1
14091415 if any (
14101416 (isinstance (s , int )) or (isinstance (s , slice ) and s .step not in (None , 1 )) for s in _slice
14111417 ):
1412- need_orig_slice = True
1418+ need_final_slice = True
14131419 _slice = tuple (slice (i , i + 1 ) if isinstance (i , int ) else i for i in _slice )
14141420 full_slice = tuple (
1415- slice (s .start or 0 , s .stop or shape [i ], None ) for i , s in enumerate (_slice )
1421+ slice (s .start or 0 , s .stop or shape [i ], 1 ) for i , s in enumerate (_slice )
14161422 ) # get rid of non-unit steps
14171423 # shape_slice in general not equal to final shape:
1418- # dummy dims (due to ints) or non-unit steps will be dealt with by taking orig_slice
1424+ # dummy dims (due to ints) or non-unit steps will be dealt with by taking final_slice
14191425 shape_slice = ndindex .ndindex (full_slice ).newshape (shape )
1426+ final_slice = ndindex .ndindex (orig_slice ).as_subindex (full_slice ).raw
14201427 else :
1421- # TODO: check that this is fine since __out__ could have shape_slice and __shape__ should refer to full array
1428+ # # out should always have shape of full array
1429+ # if shape is not None and shape != out.shape:
1430+ # raise ValueError("Provided output shape does not match the operands' shape.")
14221431 shape = out .shape
14231432
14241433 if chunks is None : # Guess chunk shape
@@ -1457,9 +1466,8 @@ def slices_eval( # noqa: C901
14571466 intersecting_chunks = chunk_size .as_subchunks (_slice , shape ) # if _slice is (), returns all chunks
14581467
14591468 for nchunk , chunk_slice in enumerate (intersecting_chunks ):
1460- # Check whether current chunk_slice intersects with _slice
1469+ # get intersection of chunk and target
14611470 if _slice != ():
1462- # get intersection of chunk and target
14631471 cslice = tuple (
14641472 slice (max (s1 .start , s2 .start ), min (s1 .stop , s2 .stop ))
14651473 for s1 , s2 in zip (chunk_slice .raw , _slice , strict = True )
@@ -1469,6 +1477,10 @@ def slices_eval( # noqa: C901
14691477
14701478 cslice_shape = tuple (s .stop - s .start for s in cslice )
14711479 len_chunk = math .prod (cslice_shape )
1480+ # get local index of part of out that is to be updated
1481+ cslice_subidx = (
1482+ ndindex .ndindex (cslice ).as_subindex (full_slice ).raw
1483+ ) # in the case full_slice=(), just gives cslice
14721484
14731485 # Get the starts and stops for the slice
14741486 starts = [s .start if s .start is not None else 0 for s in cslice ]
@@ -1482,7 +1494,7 @@ def slices_eval( # noqa: C901
14821494 if value .shape == ():
14831495 chunk_operands [key ] = value [()]
14841496 continue
1485- if check_smaller_shape (value .shape , shape , cslice_shape ):
1497+ if check_smaller_shape (value .shape , shape , cslice_shape , cslice ):
14861498 # We need to fetch the part of the value that broadcasts with the operand
14871499 smaller_slice = compute_smaller_slice (shape , value .shape , cslice )
14881500 chunk_operands [key ] = value [smaller_slice ]
@@ -1502,6 +1514,7 @@ def slices_eval( # noqa: C901
15021514
15031515 if callable (expression ):
15041516 result = np .empty (cslice_shape , dtype = out .dtype ) # raises error if out is None
1517+ # cslice should be equal to cslice_subidx
15051518 # Call the udf directly and use result as the output array
15061519 offset = tuple (s .start for s in cslice )
15071520 expression (tuple (chunk_operands .values ()), result , offset = offset )
@@ -1573,12 +1586,12 @@ def slices_eval( # noqa: C901
15731586 out .schunk .update_data (nchunk , result , copy = False )
15741587 else :
15751588 try :
1576- out [cslice ] = result
1589+ out [cslice_subidx ] = result
15771590 except ComplexWarning :
15781591 # The result is a complex number, so we need to convert it to real.
15791592 # This is a workaround for rigidness of numpy with type casting.
15801593 result = result .real .astype (out .dtype )
1581- out [cslice ] = result
1594+ out [cslice_subidx ] = result
15821595 elif len (where ) == 1 :
15831596 lenres = len (result )
15841597 out [lenout : lenout + lenres ] = result
@@ -1588,7 +1601,7 @@ def slices_eval( # noqa: C901
15881601 else :
15891602 raise ValueError ("The where condition must be a tuple with one or two elements" )
15901603
1591- if where is not None and len (where ) < 2 : # Don't need to take orig_slice since filled up from 0 index
1604+ if where is not None and len (where ) < 2 : # Don't need to take final_slice since filled up from 0 index
15921605 if _order is not None :
15931606 # argsort the result following _order
15941607 new_order = np .argsort (out [:lenout ])
@@ -1600,16 +1613,16 @@ def slices_eval( # noqa: C901
16001613 else :
16011614 out .resize ((lenout ,))
16021615
1603- else : # Need to take orig_slice since filled up array according to slice_ for each chunk
1604- if orig_slice != ():
1616+ else : # Need to take final_slice since filled up array according to slice_ for each chunk
1617+ if final_slice != ():
16051618 if isinstance (out , np .ndarray ):
1606- if need_orig_slice :
1607- out = out [orig_slice ]
1619+ if need_final_slice : # only called if out was None
1620+ out = out [final_slice ]
16081621 elif isinstance (out , blosc2 .NDArray ):
16091622 # It *seems* better to choose an automatic chunks and blocks for the output array
16101623 # out = out.slice(_slice, chunks=out.chunks, blocks=out.blocks)
1611- if need_orig_slice :
1612- out = out .slice (orig_slice )
1624+ if need_final_slice : # only called if out was None
1625+ out = out .slice (final_slice )
16131626 else :
16141627 raise ValueError ("The output array is not a NumPy array or a NDArray" )
16151628
@@ -1679,7 +1692,7 @@ def slices_eval_getitem(
16791692 if value .shape == ():
16801693 slice_operands [key ] = value [()]
16811694 continue
1682- if check_smaller_shape (value .shape , shape , slice_shape ):
1695+ if check_smaller_shape (value .shape , shape , slice_shape , _slice_bcast ):
16831696 # We need to fetch the part of the value that broadcasts with the operand
16841697 smaller_slice = compute_smaller_slice (shape , value .shape , _slice )
16851698 slice_operands [key ] = value [smaller_slice ]
@@ -1703,7 +1716,8 @@ def slices_eval_getitem(
17031716 # This is a workaround for rigidness of numpy with type casting.
17041717 return result .real .astype (dtype , copy = False )
17051718 else :
1706- out [()] = result
1719+ # out should always have maximal shape
1720+ out [_slice ] = result
17071721 return out
17081722
17091723
@@ -1721,6 +1735,19 @@ def infer_reduction_dtype(dtype, operation):
17211735 raise ValueError (f"Unsupported operation: { operation } " )
17221736
17231737
def step_handler(s1start, s2start, s1stop, s2stop, s2step):
    """Intersect a unit-step range with a strided slice.

    The first range ``[s1start, s1stop)`` is assumed to have step 1; the
    second one is ``slice(s2start, s2stop, s2step)``. The result selects the
    elements of the second slice that fall inside the first range.

    Parameters
    ----------
    s1start, s1stop : int
        Bounds of the unit-step range.
    s2start, s2stop : int
        Bounds of the strided slice.
    s2step : int or None
        Step of the strided slice. ``None`` is treated as a step of 1 for the
        arithmetic and is passed through unchanged in the returned slice.
        NOTE(review): only positive steps appear to be supported by this
        arithmetic -- confirm against callers.

    Returns
    -------
    slice
        ``slice(newstart, newstop, s2step)``. When the start of the overlap
        had to be moved forward to land on the stride grid, ``newstop`` is
        tightened to one past the last selected element; otherwise it is
        simply ``min(s1stop, s2stop)``.
    """
    # A plain slice(a, b) carries step=None; use 1 for the modular arithmetic
    # so that it does not raise TypeError.
    step = 1 if s2step is None else s2step

    newstart = max(s1start, s2start)
    newstop = min(s1stop, s2stop)

    # Distance from newstart back to the previous element on the stride grid.
    rem = (newstart - s2start) % step
    if rem != 0:  # can only happen when the step is not 1
        # Move forward to the next element that lies on the stride grid.
        newstart += step - rem
        # Last selected element is newstart + n * step with
        # n = (newstop - newstart - 1) // step; stop is one past it.
        newstop = newstart + (newstop - newstart - 1) // step * step + 1
    return slice(newstart, newstop, s2step)
1749+
1750+
17241751def reduce_slices ( # noqa: C901
17251752 expression : str | Callable [[tuple , np .ndarray , tuple [int ]], None ],
17261753 operands : dict ,
@@ -1770,14 +1797,26 @@ def reduce_slices( # noqa: C901
17701797 # Compute the shape and chunks of the output array, including broadcasting
17711798 shape = compute_broadcast_shape (operands .values ())
17721799
1800+ _slice = _slice .raw
1801+ shape_slice = shape
1802+ full_slice = () # by default the full_slice is the whole array
1803+ if out is None and _slice != ():
1804+ shape_slice = ndindex .ndindex (_slice ).newshape (shape )
1805+ full_slice = _slice
1806+
1807+ # after slicing, we reduce to calculate shape of output
17731808 if axis is None :
1774- axis = tuple (range (len (shape )))
1809+ axis = tuple (range (len (shape_slice )))
17751810 elif not isinstance (axis , tuple ):
17761811 axis = (axis ,)
1812+ axis = tuple (a if a >= 0 else a + len (shape_slice ) for a in axis )
17771813 if keepdims :
1778- reduced_shape = tuple (1 if i in axis else s for i , s in enumerate (shape ))
1814+ reduced_shape = tuple (1 if i in axis else s for i , s in enumerate (shape_slice ))
17791815 else :
1780- reduced_shape = tuple (s for i , s in enumerate (shape ) if i not in axis )
1816+ reduced_shape = tuple (s for i , s in enumerate (shape_slice ) if i not in axis )
1817+
1818+ if out is not None and reduced_shape != out .shape :
1819+ raise ValueError ("Provided output shape does not match the reduced shape." )
17811820
17821821 if is_inside_new_expr ():
17831822 # We already have the dtype and reduced_shape, so return immediately
@@ -1820,23 +1859,17 @@ def reduce_slices( # noqa: C901
18201859 chunk_operands = {}
18211860 # Check which chunks intersect with _slice
18221861 chunk_size = ndindex .ChunkSize (chunks )
1823- _slice = _slice .raw
18241862 intersecting_chunks = chunk_size .as_subchunks (_slice , shape ) # if _slice is (), returns all chunks
18251863 out_init = False
18261864
18271865 for nchunk , chunk_slice in enumerate (intersecting_chunks ):
18281866 cslice = chunk_slice .raw
1829- if keepdims :
1830- reduced_slice = tuple (slice (None ) if i in axis else sl for i , sl in enumerate (cslice ))
1831- else :
1832- reduced_slice = tuple (sl for i , sl in enumerate (cslice ) if i not in axis )
18331867 offset = tuple (s .start for s in cslice ) # offset for the udf
1834- # TODO: Is this necessary, shouldn't slice always be None for a reduction?
18351868 # Check whether current cslice intersects with _slice
18361869 if cslice != () and _slice != ():
18371870 # get intersection of chunk and target
18381871 cslice = tuple (
1839- slice ( max ( s1 .start , s2 .start ), min ( s1 .stop , s2 .stop ) )
1872+ step_handler ( s1 .start , s2 .start , s1 .stop , s2 .stop , s2 . step )
18401873 for s1 , s2 in zip (cslice , _slice , strict = True )
18411874 )
18421875 chunks_ = tuple (s .stop - s .start for s in cslice )
@@ -1859,7 +1892,7 @@ def reduce_slices( # noqa: C901
18591892 if value .shape == ():
18601893 chunk_operands [key ] = value [()]
18611894 continue
1862- if check_smaller_shape (value .shape , shape , chunks_ ):
1895+ if check_smaller_shape (value .shape , shape , chunks_ , cslice ):
18631896 # We need to fetch the part of the value that broadcasts with the operand
18641897 smaller_slice = compute_smaller_slice (operand .shape , value .shape , cslice )
18651898 chunk_operands [key ] = value [smaller_slice ]
@@ -1874,6 +1907,15 @@ def reduce_slices( # noqa: C901
18741907 continue
18751908 chunk_operands [key ] = value [cslice ]
18761909
1910+ # get local index of part of out that is to be updated
1911+ cslice_subidx = (
1912+ ndindex .ndindex (cslice ).as_subindex (full_slice ).raw
1913+ ) # if full_slice is (), just gives cslice
1914+ if keepdims :
1915+ reduced_slice = tuple (slice (None ) if i in axis else sl for i , sl in enumerate (cslice_subidx ))
1916+ else :
1917+ reduced_slice = tuple (sl for i , sl in enumerate (cslice_subidx ) if i not in axis )
1918+
18771919 # Evaluate and reduce the expression using chunks of operands
18781920
18791921 if callable (expression ):
@@ -2622,23 +2664,16 @@ def get_num_elements(self, axis, item):
26222664 num_elements = self .sum (axis = axis , dtype = np .int64 , item = item )
26232665 self ._where_args = orig_where_args
26242666 return num_elements
2625- if np .isscalar (axis ):
2626- axis = (axis ,)
26272667 # Compute the number of elements in the array
26282668 shape = self .shape
2669+ if np .isscalar (axis ):
2670+ axis = (axis ,)
26292671 if item is not None :
26302672 # Compute the shape of the slice
2631- if not isinstance (item , tuple ):
2632- item = (item ,)
2633- # Ensure that the limits in item slices are not None
2634- item = tuple (slice (s .start or 0 , s .stop or self .shape [i ], s .step ) for i , s in enumerate (item ))
2635- # Compute the intersection of the slice with the shape
2636- item = tuple (slice (s1 .start , min (s1 .stop , s2 )) for s1 , s2 in zip (item , shape , strict = True ))
2637- if axis is None :
2638- shape = [s .stop - s .start for s in item ]
2639- else :
2640- shape = [s .stop - s .start for i , s in enumerate (item ) if i in axis ]
2641- return math .prod (shape ) if axis is None else math .prod ([shape [i ] for i in axis ])
2673+ shape = ndindex .ndindex (item ).newshape (shape )
2674+ axis = tuple (range (len (shape ))) if axis is None else axis
2675+ axis = tuple (a if a >= 0 else a + len (shape ) for a in axis ) # handle negative indexing
2676+ return math .prod ([shape [i ] for i in axis ])
26422677
26432678 def mean (self , axis = None , dtype = None , keepdims = False , ** kwargs ):
26442679 item = kwargs .pop ("item" , None )
@@ -2657,9 +2692,15 @@ def mean(self, axis=None, dtype=None, keepdims=False, **kwargs):
26572692
26582693 def std (self , axis = None , dtype = None , keepdims = False , ddof = 0 , ** kwargs ):
26592694 item = kwargs .pop ("item" , None )
2660- mean_value = self .mean (axis = axis , dtype = dtype , keepdims = True , item = item )
2661- expr = (self - mean_value ) ** 2
2662- out = expr .mean (axis = axis , dtype = dtype , keepdims = keepdims , item = item )
2695+ if item is None : # fast path
2696+ mean_value = self .mean (axis = axis , dtype = dtype , keepdims = True )
2697+ expr = (self - mean_value ) ** 2
2698+ else :
2699+ mean_value = self .mean (axis = axis , dtype = dtype , keepdims = True , item = item )
2700+ # TODO: Not optimal because we load the whole slice in memory. Would have to write
2701+ # a bespoke std function that executed within slice_eval to avoid this probably.
2702+ expr = (self .slice (item ) - mean_value ) ** 2
2703+ out = expr .mean (axis = axis , dtype = dtype , keepdims = keepdims )
26632704 if ddof != 0 :
26642705 num_elements = self .get_num_elements (axis , item )
26652706 out = np .sqrt (out * num_elements / (num_elements - ddof ))
@@ -2675,14 +2716,18 @@ def std(self, axis=None, dtype=None, keepdims=False, ddof=0, **kwargs):
26752716
26762717 def var (self , axis = None , dtype = None , keepdims = False , ddof = 0 , ** kwargs ):
26772718 item = kwargs .pop ("item" , None )
2678- mean_value = self .mean (axis = axis , dtype = dtype , keepdims = True , item = item )
2679- expr = (self - mean_value ) ** 2
2719+ if item is None : # fast path
2720+ mean_value = self .mean (axis = axis , dtype = dtype , keepdims = True )
2721+ expr = (self - mean_value ) ** 2
2722+ else :
2723+ mean_value = self .mean (axis = axis , dtype = dtype , keepdims = True , item = item )
2724+ # TODO: Not optimal because we load the whole slice in memory. Would have to write
2725+ # a bespoke var function that executed within slice_eval to avoid this probably.
2726+ expr = (self .slice (item ) - mean_value ) ** 2
2727+ out = expr .mean (axis = axis , dtype = dtype , keepdims = keepdims )
26802728 if ddof != 0 :
2681- out = expr .mean (axis = axis , dtype = dtype , keepdims = keepdims , item = item )
26822729 num_elements = self .get_num_elements (axis , item )
26832730 out = out * num_elements / (num_elements - ddof )
2684- else :
2685- out = expr .mean (axis = axis , dtype = dtype , keepdims = keepdims , item = item )
26862731 out2 = kwargs .pop ("out" , None )
26872732 if out2 is not None :
26882733 out2 [:] = out
@@ -3241,6 +3286,7 @@ def __getitem__(self, item):
32413286 # TODO: as this creates a big array, this can potentially consume a lot of memory
32423287 output = np .empty (self .shape , self .dtype )
32433288 # It is important to pass kwargs here, because chunks can be used internally
3289+ # fills numpy array with desired slice
32443290 chunked_eval (self .func , self .inputs_dict , item , _getitem = True , _output = output , ** self .kwargs )
32453291 return output [item ]
32463292 return self .res_getitem [item ]
0 commit comments