Skip to content

Commit bfb0dfc

Browse files
committed
Add dtype conversion to SimpleProxy
1 parent 88fb47c commit bfb0dfc

3 files changed

Lines changed: 17 additions & 5 deletions

File tree

src/blosc2/lazyexpr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2530,7 +2530,7 @@ def where(self, value1=None, value2=None):
25302530
# This just acts as a 'decorator' for the existing expression
25312531
if value1 is not None and value2 is not None:
25322532
# Guess the outcome dtype for value1 and value2
2533-
dtype = np.result_type(value1, value2)
2533+
dtype = blosc2.result_type(value1, value2)
25342534
args = {"_where_x": value1, "_where_y": value2}
25352535
elif value1 is not None:
25362536
if hasattr(value1, "dtype"):

src/blosc2/linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
100100
n, k = x1.shape[-2:]
101101
m = x2.shape[-1]
102102
result_shape = np.broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) + (n, m)
103-
result = blosc2.zeros(result_shape, dtype=np.result_type(x1, x2), **kwargs)
103+
result = blosc2.zeros(result_shape, dtype=blosc2.result_type(x1, x2), **kwargs)
104104

105105
if 0 not in result.shape + x1.shape + x2.shape: # if any array is empty, return array of 0s
106106
p, q = result.chunks[-2:]
@@ -227,7 +227,7 @@ def tensordot(
227227
raise ValueError("x1 and x2 must have same shapes along reduction dimensions")
228228

229229
result_shape = tuple(x1shape[a_keep]) + tuple(x2shape[b_keep])
230-
result = blosc2.zeros(result_shape, dtype=np.result_type(x1, x2), **kwargs)
230+
result = blosc2.zeros(result_shape, dtype=blosc2.result_type(x1, x2), **kwargs)
231231

232232
op_chunks = [
233233
slice_to_chunktuple(slice(0, s, 1), c) for s, c in zip(x1shape[a_axes], a_chunks_red, strict=True)
@@ -363,7 +363,7 @@ def vecdot(x1: blosc2.NDArray, x2: blosc2.NDArray, axis: int = -1, **kwargs) ->
363363
raise ValueError("x1 and x2 must have same shapes along reduction dimensions")
364364

365365
result_shape = np.broadcast_shapes(x1shape[a_keep], x2shape[b_keep])
366-
result = blosc2.zeros(result_shape, dtype=np.result_type(x1, x2), **kwargs)
366+
result = blosc2.zeros(result_shape, dtype=blosc2.result_type(x1, x2), **kwargs)
367367

368368
res_chunks = [
369369
slice_to_chunktuple(s, c)

src/blosc2/proxy.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,18 @@ def __getitem__(self, item: slice | list[slice]) -> np.ndarray:
569569
return nparr[self.field]
570570

571571

572+
def _convert_dtype(dt: str | np.typing.DTypeLike):
573+
"""
574+
Attempts to convert to blosc2.dtype (i.e. numpy dtype)
575+
"""
576+
try:
577+
return np.dtype(dt)
578+
except TypeError: # likely passed e.g. a torch.float64
579+
return np.dtype(str(dt).split(".")[1])
580+
except Exception as e:
581+
raise TypeError("Could not parse dtype arg {dt}.") from e
582+
583+
572584
class SimpleProxy(blosc2.Operand):
573585
"""
574586
Simple proxy for any data container to be used with the compute engine.
@@ -597,7 +609,7 @@ def __init__(self, src, chunks: tuple | None = None, blocks: tuple | None = None
597609
if not hasattr(src, "__getitem__"):
598610
raise TypeError("The source must have a __getitem__ method")
599611
self._src = src
600-
self._dtype = src.dtype
612+
self._dtype = _convert_dtype(src.dtype)
601613
self._shape = src.shape
602614
# Compute reasonable values for chunks and blocks
603615
cparams = blosc2.CParams(clevel=0)

0 commit comments

Comments
 (0)