Skip to content

Commit 0f283ec

Browse files
Merge pull request #592 from Blosc/dsl-constructors
DSL constructors
2 parents 7d45c9e + fc664ad commit 0f283ec

3 files changed

Lines changed: 139 additions & 17 deletions

File tree

src/blosc2/dsl_kernel.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,11 @@ def _remove_scalar_params_preserving_source(text: str, scalar_replacements: dict
105105
updated = f"{text[:pstart]}{', '.join(kept)}{text[pend:]}"
106106
body_start = 0
107107
if colon is not None:
108-
body_start = _to_abs(_line_starts(updated), colon.end[0], colon.end[1])
108+
# Signature shrink can move ':' to an earlier column, so recompute
109+
# on the rewritten text to avoid skipping first-line body tokens.
110+
_, _, updated_colon = _find_def_signature_span(updated)
111+
if updated_colon is not None:
112+
body_start = _to_abs(_line_starts(updated), updated_colon.end[0], updated_colon.end[1])
109113
return updated, body_start
110114

111115

@@ -162,7 +166,7 @@ def _replace_scalar_names_preserving_source(
162166
return out
163167

164168

165-
def _fold_numeric_cast_calls_preserving_source(text: str, body_start: int):
169+
def _fold_numeric_cast_calls_preserving_source(text: str, body_start: int): # noqa: C901
166170
"""Fold float(<number>) and int(<number>) calls into literals.
167171
168172
miniexpr parses DSL function calls in a restricted way, and scalar specialization can
@@ -176,6 +180,20 @@ def _fold_numeric_cast_calls_preserving_source(text: str, body_start: int):
176180

177181
line_starts = _line_starts(text)
178182
edits = []
183+
184+
def _numeric_literal_value(node):
185+
if isinstance(node, ast.Constant) and isinstance(node.value, int | float | bool):
186+
return node.value
187+
if (
188+
isinstance(node, ast.UnaryOp)
189+
and isinstance(node.op, ast.UAdd | ast.USub)
190+
and isinstance(node.operand, ast.Constant)
191+
and isinstance(node.operand.value, int | float | bool)
192+
):
193+
value = node.operand.value
194+
return value if isinstance(node.op, ast.UAdd) else -value
195+
return None
196+
179197
for node in ast.walk(tree):
180198
if not isinstance(node, ast.Call):
181199
continue
@@ -185,7 +203,8 @@ def _fold_numeric_cast_calls_preserving_source(text: str, body_start: int):
185203
continue
186204

187205
arg = node.args[0]
188-
if not isinstance(arg, ast.Constant) or not isinstance(arg.value, int | float | bool):
206+
value = _numeric_literal_value(arg)
207+
if value is None:
189208
continue
190209

191210
start_abs = _to_abs(line_starts, node.lineno, node.col_offset)
@@ -194,9 +213,9 @@ def _fold_numeric_cast_calls_preserving_source(text: str, body_start: int):
194213
end_abs = _to_abs(line_starts, node.end_lineno, node.end_col_offset)
195214

196215
if node.func.id == "float":
197-
repl = repr(float(arg.value))
216+
repl = repr(float(value))
198217
else:
199-
repl = repr(int(arg.value))
218+
repl = repr(int(value))
200219
edits.append((start_abs, end_abs, repl))
201220

202221
if not edits:

src/blosc2/ndarray.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5454,6 +5454,10 @@ def arange_fill(inputs, output, offset):
54545454
else: # use linspace to have finer control over exclusion of endpoint for float types
54555455
output[:] = np.linspace(start, stop, lout, endpoint=False, dtype=output.dtype)
54565456

5457+
@blosc2.dsl_kernel
5458+
def ramp_arange(start, step):
5459+
return start + _flat_idx * step # noqa: F821 # DSL index/shape symbols resolved by miniexpr
5460+
54575461
if step is None: # not array-api compliant but for backwards compatibility
54585462
step = 1
54595463
if stop is None:
@@ -5478,14 +5482,19 @@ def arange_fill(inputs, output, offset):
54785482
# We already have the dtype and shape, so return immediately
54795483
return blosc2.zeros(shape, dtype=dtype, **kwargs)
54805484

5481-
lshape = (math.prod(shape),)
5482-
lazyarr = blosc2.lazyudf(arange_fill, (start, stop, step), dtype=dtype, shape=lshape)
5485+
# Windows and wasm32 does not support complex numbers in DSL
5486+
if blosc2.isdtype(dtype, "complex floating"):
5487+
lshape = (math.prod(shape),)
5488+
lazyarr = blosc2.lazyudf(arange_fill, (start, stop, step), dtype=dtype, shape=lshape)
54835489

5484-
if len(shape) == 1:
5485-
# C order is guaranteed, and no reshape is needed
5486-
return lazyarr.compute(**kwargs)
5490+
if len(shape) == 1:
5491+
# C order is guaranteed, and no reshape is needed
5492+
return lazyarr.compute(**kwargs)
54875493

5488-
return reshape(lazyarr, shape, c_order=c_order, **kwargs)
5494+
return reshape(lazyarr, shape, c_order=c_order, **kwargs)
5495+
else:
5496+
lazyarr = blosc2.lazyudf(ramp_arange, (start, step), dtype=dtype, shape=shape)
5497+
return lazyarr.compute(**kwargs)
54895498

54905499

54915500
# Define a numpy linspace-like function
@@ -5550,6 +5559,10 @@ def linspace_fill(inputs, output, offset):
55505559
else:
55515560
output[:] = np.linspace(start_, stop_, lout, endpoint=False, dtype=output.dtype)
55525561

5562+
@blosc2.dsl_kernel
5563+
def ramp_linspace(start, step):
5564+
return float(start) + _flat_idx * float(step) # noqa: F821 # DSL index/shape symbols resolved by miniexpr
5565+
55535566
if shape is None:
55545567
if num is None:
55555568
raise ValueError("Either `shape` or `num` must be specified.")
@@ -5579,13 +5592,21 @@ def linspace_fill(inputs, output, offset):
55795592
# We already have the dtype and shape, so return immediately
55805593
return blosc2.zeros(shape, dtype=dtype, **kwargs) # will return empty array for num == 0
55815594

5582-
inputs = (start, stop, num, endpoint)
5583-
lazyarr = blosc2.lazyudf(linspace_fill, inputs, dtype=dtype, shape=(num,))
5584-
if len(shape) == 1:
5585-
# C order is guaranteed, and no reshape is needed
5586-
return lazyarr.compute(**kwargs)
5595+
# Windows and wasm32 does not support complex numbers in DSL
5596+
if blosc2.isdtype(dtype, "complex floating"):
5597+
inputs = (start, stop, num, endpoint)
5598+
lazyarr = blosc2.lazyudf(linspace_fill, inputs, dtype=dtype, shape=(num,))
5599+
if len(shape) == 1:
5600+
# C order is guaranteed, and no reshape is needed
5601+
return lazyarr.compute(**kwargs)
55875602

5588-
return reshape(lazyarr, shape, c_order=c_order, **kwargs)
5603+
return reshape(lazyarr, shape, c_order=c_order, **kwargs)
5604+
else:
5605+
nitems = num - 1 if endpoint else num
5606+
step = (float(stop) - float(start)) / float(nitems) if nitems > 0 else 0.0
5607+
inputs = (start, step)
5608+
lazyarr = blosc2.lazyudf(ramp_linspace, inputs, dtype=dtype, shape=shape)
5609+
return lazyarr.compute(**kwargs)
55895610

55905611

55915612
def eye(N, M=None, k=0, dtype=np.float64, **kwargs: Any) -> NDArray:

tests/ndarray/test_dsl_kernels.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
import subprocess
1010
import sys
11+
import tempfile
12+
import textwrap
13+
from pathlib import Path
1114

1215
import numpy as np
1316
import pytest
@@ -116,6 +119,18 @@ def kernel_scalar_start_step(start, step):
116119
return start + step * (_i0 * _n1 + _i1) # noqa: F821 # DSL index/shape symbols resolved by miniexpr
117120

118121

122+
@blosc2.dsl_kernel
123+
def kernel_scalar_start_stop_nitems(start, stop, nitems):
124+
step = (stop - start) / nitems
125+
return start + _flat_idx * step # noqa: F821 # DSL index/shape symbols resolved by miniexpr
126+
127+
128+
@blosc2.dsl_kernel
129+
def kernel_scalar_start_stop_nitems_float_cast(start, stop, nitems):
130+
step = (float(stop) - float(start)) / float(nitems)
131+
return float(start) + _flat_idx * step # noqa: F821 # DSL index/shape symbols resolved by miniexpr
132+
133+
119134
@blosc2.dsl_kernel
120135
def kernel_fallback_kw_call(x, y):
121136
return clip(x + y, a_min=0.5, a_max=2.5)
@@ -545,6 +560,73 @@ def test_dsl_kernel_two_scalar_params_start_step_linear_ramp():
545560
np.testing.assert_allclose(res[...], expected, rtol=0.0, atol=0.0)
546561

547562

563+
def test_dsl_kernel_three_scalar_params_start_stop_nitems_ramp():
564+
shape = (20, 25)
565+
start = np.float64(1.0)
566+
stop = np.float64(2.0)
567+
nitems = np.int64(np.prod(shape))
568+
569+
expr = blosc2.lazyudf(
570+
kernel_scalar_start_stop_nitems, (start, stop, nitems), dtype=np.float64, shape=shape
571+
)
572+
res = expr.compute()
573+
574+
step = (stop - start) / nitems
575+
expected = (start + step * np.arange(np.prod(shape), dtype=np.float64)).reshape(shape)
576+
np.testing.assert_allclose(res[...], expected, rtol=0.0, atol=0.0)
577+
578+
579+
def test_dsl_kernel_float_cast_with_negative_scalar_param():
580+
shape = (10, 100)
581+
start = -10
582+
stop = 10
583+
nitems = np.int64(np.prod(shape) - 1)
584+
585+
expr = blosc2.lazyudf(
586+
kernel_scalar_start_stop_nitems_float_cast, (start, stop, nitems), dtype=np.float32, shape=shape
587+
)
588+
res = expr.compute()
589+
590+
expected = np.linspace(start, stop, np.prod(shape), dtype=np.float32).reshape(shape)
591+
np.testing.assert_allclose(res[...], expected, rtol=1e-6, atol=1e-6)
592+
593+
594+
def test_dsl_kernel_float_cast_with_flat_idx_no_segfault_subprocess():
595+
if blosc2.IS_WASM:
596+
pytest.skip("subprocess is not supported on emscripten/wasm32")
597+
598+
code = textwrap.dedent(
599+
"""
600+
import numpy as np
601+
import blosc2
602+
603+
@blosc2.dsl_kernel
604+
def kernel(start, stop, nitems):
605+
step = (float(stop) - float(start)) / float(nitems)
606+
return float(start) + _flat_idx * step # noqa: F821
607+
608+
shape = (10, 100)
609+
arr = blosc2.lazyudf(kernel, (-10, 10, 999), dtype=np.float32, shape=shape).compute()
610+
exp = np.linspace(-10, 10, np.prod(shape), dtype=np.float32).reshape(shape)
611+
np.testing.assert_allclose(arr, exp, rtol=1e-6, atol=1e-6)
612+
print("ok")
613+
"""
614+
)
615+
616+
# Run from a real .py file so inspect.getsource() can recover the DSL source.
617+
with tempfile.TemporaryDirectory() as tmpdir:
618+
script = Path(tmpdir) / "dsl_kernel_subprocess.py"
619+
script.write_text(code, encoding="utf-8")
620+
result = subprocess.run([sys.executable, str(script)], capture_output=True, text=True, check=False)
621+
622+
assert result.returncode == 0, (
623+
"subprocess failed (possible segfault/regression in DSL float-cast path):\n"
624+
f"stdout:\n{result.stdout}\n"
625+
f"stderr:\n{result.stderr}"
626+
)
627+
assert "ok" in result.stdout
628+
629+
548630
def test_dsl_kernel_scalar_constant_subexpr_runtime_no_segfault(tmp_path):
549631
if blosc2.IS_WASM:
550632
pytest.skip("subprocess is not supported on emscripten/wasm32")

0 commit comments

Comments
 (0)