|
6 | 6 | # LICENSE file in the root directory of this source tree) |
7 | 7 | ####################################################################### |
8 | 8 |
|
| 9 | +import dask.array as da |
| 10 | +import jax.numpy as jnp |
9 | 11 | import numpy as np |
10 | 12 | import pytest |
| 13 | +import tensorflow as tf |
| 14 | +import torch |
| 15 | + |
| 16 | +# TODO: import cupy as cp |
| 17 | +import zarr |
11 | 18 |
|
12 | 19 | import blosc2 |
13 | 20 | from blosc2.ndarray import get_chunks_idx |
@@ -172,3 +179,52 @@ def get_chunk(self, nchunk): |
172 | 179 | proxy = blosc2.Proxy(source) |
173 | 180 | result = proxy[...] |
174 | 181 | np.testing.assert_array_equal(result, data) |
| 182 | + |
| 183 | + |
| 184 | +@pytest.mark.parametrize( |
| 185 | + "xp", |
| 186 | + [torch, tf, np, jnp, da, zarr], |
| 187 | +) |
| 188 | +@pytest.mark.parametrize( |
| 189 | + "dtype", |
| 190 | + ["bool", "int32", "int64", "float32", "float64", "complex128"], |
| 191 | +) |
| 192 | +def test_simpleproxy(xp, dtype): |
| 193 | + dtype_ = getattr(xp, dtype) if hasattr(xp, dtype) else np.dtype(dtype) |
| 194 | + if dtype == "bool": |
| 195 | + blosc_matrix = blosc2.asarray([True, False, False], dtype=np.dtype(dtype), chunks=(2,)) |
| 196 | + foreign_matrix = xp.zeros((3,), dtype=dtype_) |
| 197 | + # Create a lazy expression object |
| 198 | + lexpr = blosc2.lazyexpr( |
| 199 | + "(b & a) | (~b)", operands={"a": blosc_matrix, "b": foreign_matrix} |
| 200 | + ) # this does not |
| 201 | + # Compare with numpy computation result |
| 202 | + npb = np.asarray(foreign_matrix) |
| 203 | + npa = blosc_matrix[()] |
| 204 | + res = (npb & npa) | np.logical_not(npb) |
| 205 | + else: |
| 206 | + N = 10 |
| 207 | + shape_a = (N, N, N) |
| 208 | + blosc_matrix = blosc2.full(shape=shape_a, fill_value=3, dtype=np.dtype(dtype), chunks=(N // 3,) * 3) |
| 209 | + foreign_matrix = xp.ones(shape_a, dtype=dtype_) |
| 210 | + if dtype == "complex128": |
| 211 | + foreign_matrix += 0.5j |
| 212 | + blosc_matrix = blosc2.full( |
| 213 | + shape=shape_a, fill_value=3 + 2j, dtype=np.dtype(dtype), chunks=(N // 3,) * 3 |
| 214 | + ) |
| 215 | + |
| 216 | + # Create a lazy expression object |
| 217 | + lexpr = blosc2.lazyexpr( |
| 218 | + "b + sin(a) + sum(b) - tensordot(a, b, axes=1)", |
| 219 | + operands={"a": blosc_matrix, "b": foreign_matrix}, |
| 220 | + ) # this does not |
| 221 | + # Compare with numpy computation result |
| 222 | + npb = np.asarray(foreign_matrix) |
| 223 | + npa = blosc_matrix[()] |
| 224 | + res = npb + np.sin(npa) + np.sum(npb) - np.tensordot(npa, npb, axes=1) |
| 225 | + |
| 226 | + # Test object metadata and result |
| 227 | + assert isinstance(lexpr, blosc2.LazyExpr) |
| 228 | + assert lexpr.dtype == res.dtype |
| 229 | + assert lexpr.shape == res.shape |
| 230 | + np.testing.assert_array_equal(lexpr[()], res) |
0 commit comments