|
10 | 10 |
|
11 | 11 | import numpy as np |
12 | 12 | import pytest |
| 13 | +import torch |
13 | 14 |
|
14 | 15 | import blosc2 |
15 | 16 | from blosc2.lazyexpr import ne_evaluate |
@@ -1697,13 +1698,13 @@ def test_lazylinalg(): |
1697 | 1698 | np.testing.assert_array_almost_equal(out[()], npres) |
1698 | 1699 |
|
1699 | 1700 | # --- squeeze --- |
1700 | | - out = blosc2.lazyexpr("squeeze(D)") |
1701 | | - npres = np.squeeze(npD) |
| 1701 | + out = blosc2.lazyexpr("squeeze(D, axis=-1)") |
| 1702 | + npres = np.squeeze(npD, -1) |
1702 | 1703 | assert out.shape == npres.shape |
1703 | 1704 | np.testing.assert_array_almost_equal(out[()], npres) |
1704 | 1705 |
|
1705 | | - out = blosc2.lazyexpr("D.squeeze()") |
1706 | | - npres = np.squeeze(npD) |
| 1706 | + out = blosc2.lazyexpr("D.squeeze(axis=-1)") |
| 1707 | + npres = np.squeeze(npD, -1) |
1707 | 1708 | assert out.shape == npres.shape |
1708 | 1709 | np.testing.assert_array_almost_equal(out[()], npres) |
1709 | 1710 |
|
@@ -1772,3 +1773,52 @@ def test_lazyexpr_2args(): |
1772 | 1773 | newexpr = blosc2.hypot(lexpr, 3) |
1773 | 1774 | assert newexpr.expression == "hypot((sin(o0)), 3)" |
1774 | 1775 | assert newexpr.operands["o0"] is a |
| 1776 | + |
| 1777 | + |
| 1778 | +@pytest.mark.parametrize( |
| 1779 | + "xp", |
| 1780 | + [torch, np], |
| 1781 | +) |
| 1782 | +@pytest.mark.parametrize( |
| 1783 | + "dtype", |
| 1784 | + ["bool", "int32", "int64", "float32", "float64", "complex128"], |
| 1785 | +) |
| 1786 | +def test_simpleproxy(xp, dtype): |
| 1787 | + dtype_ = getattr(xp, dtype) if hasattr(xp, dtype) else np.dtype(dtype) |
| 1788 | + if dtype == "bool": |
| 1789 | + blosc_matrix = blosc2.asarray([True, False, False], dtype=np.dtype(dtype), chunks=(2,)) |
| 1790 | + foreign_matrix = xp.zeros((3,), dtype=dtype_) |
| 1791 | + # Create a lazy expression object |
| 1792 | + lexpr = blosc2.lazyexpr( |
| 1793 | + "(b & a) | (~b)", operands={"a": blosc_matrix, "b": foreign_matrix} |
| 1794 | + ) # this does not |
| 1795 | + # Compare with numpy computation result |
| 1796 | + npb = np.asarray(foreign_matrix) |
| 1797 | + npa = blosc_matrix[()] |
| 1798 | + res = (npb & npa) | np.logical_not(npb) |
| 1799 | + else: |
| 1800 | + N = 10 |
| 1801 | + shape_a = (N, N, N) |
| 1802 | + blosc_matrix = blosc2.full(shape=shape_a, fill_value=3, dtype=np.dtype(dtype), chunks=(N // 3,) * 3) |
| 1803 | + foreign_matrix = xp.ones(shape_a, dtype=dtype_) |
| 1804 | + if dtype == "complex128": |
| 1805 | + foreign_matrix += 0.5j |
| 1806 | + blosc_matrix = blosc2.full( |
| 1807 | + shape=shape_a, fill_value=3 + 2j, dtype=np.dtype(dtype), chunks=(N // 3,) * 3 |
| 1808 | + ) |
| 1809 | + |
| 1810 | + # Create a lazy expression object |
| 1811 | + lexpr = blosc2.lazyexpr( |
| 1812 | + "b + sin(a) + sum(b) - tensordot(a, b, axes=1)", |
| 1813 | + operands={"a": blosc_matrix, "b": foreign_matrix}, |
| 1814 | + ) # this does not |
| 1815 | + # Compare with numpy computation result |
| 1816 | + npb = np.asarray(foreign_matrix) |
| 1817 | + npa = blosc_matrix[()] |
| 1818 | + res = npb + np.sin(npa) + np.sum(npb) - np.tensordot(npa, npb, axes=1) |
| 1819 | + |
| 1820 | + # Test object metadata and result |
| 1821 | + assert isinstance(lexpr, blosc2.LazyExpr) |
| 1822 | + assert lexpr.dtype == res.dtype |
| 1823 | + assert lexpr.shape == res.shape |
| 1824 | + np.testing.assert_array_equal(lexpr[()], res) |
0 commit comments