Skip to content

Commit b632260

Browse files
committed
Implemented a new DSL-powered blosc2.linspace
1 parent e0541dd commit b632260

3 files changed

Lines changed: 114 additions & 11 deletions

File tree

src/blosc2/dsl_kernel.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,11 @@ def _remove_scalar_params_preserving_source(text: str, scalar_replacements: dict
105105
updated = f"{text[:pstart]}{', '.join(kept)}{text[pend:]}"
106106
body_start = 0
107107
if colon is not None:
108-
body_start = _to_abs(_line_starts(updated), colon.end[0], colon.end[1])
108+
# Signature shrink can move ':' to an earlier column, so recompute
109+
# on the rewritten text to avoid skipping first-line body tokens.
110+
_, _, updated_colon = _find_def_signature_span(updated)
111+
if updated_colon is not None:
112+
body_start = _to_abs(_line_starts(updated), updated_colon.end[0], updated_colon.end[1])
109113
return updated, body_start
110114

111115

@@ -162,7 +166,7 @@ def _replace_scalar_names_preserving_source(
162166
return out
163167

164168

165-
def _fold_numeric_cast_calls_preserving_source(text: str, body_start: int):
169+
def _fold_numeric_cast_calls_preserving_source(text: str, body_start: int): # noqa: C901
166170
"""Fold float(<number>) and int(<number>) calls into literals.
167171
168172
miniexpr parses DSL function calls in a restricted way, and scalar specialization can
@@ -176,6 +180,20 @@ def _fold_numeric_cast_calls_preserving_source(text: str, body_start: int):
176180

177181
line_starts = _line_starts(text)
178182
edits = []
183+
184+
def _numeric_literal_value(node):
185+
if isinstance(node, ast.Constant) and isinstance(node.value, int | float | bool):
186+
return node.value
187+
if (
188+
isinstance(node, ast.UnaryOp)
189+
and isinstance(node.op, ast.UAdd | ast.USub)
190+
and isinstance(node.operand, ast.Constant)
191+
and isinstance(node.operand.value, int | float | bool)
192+
):
193+
value = node.operand.value
194+
return +value if isinstance(node.op, ast.UAdd) else -value
195+
return None
196+
179197
for node in ast.walk(tree):
180198
if not isinstance(node, ast.Call):
181199
continue
@@ -185,7 +203,8 @@ def _fold_numeric_cast_calls_preserving_source(text: str, body_start: int):
185203
continue
186204

187205
arg = node.args[0]
188-
if not isinstance(arg, ast.Constant) or not isinstance(arg.value, int | float | bool):
206+
value = _numeric_literal_value(arg)
207+
if value is None:
189208
continue
190209

191210
start_abs = _to_abs(line_starts, node.lineno, node.col_offset)
@@ -194,9 +213,9 @@ def _fold_numeric_cast_calls_preserving_source(text: str, body_start: int):
194213
end_abs = _to_abs(line_starts, node.end_lineno, node.end_col_offset)
195214

196215
if node.func.id == "float":
197-
repl = repr(float(arg.value))
216+
repl = repr(float(value))
198217
else:
199-
repl = repr(int(arg.value))
218+
repl = repr(int(value))
200219
edits.append((start_abs, end_abs, repl))
201220

202221
if not edits:

src/blosc2/ndarray.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5511,6 +5511,10 @@ def linspace_fill(inputs, output, offset):
55115511
else:
55125512
output[:] = np.linspace(start_, stop_, lout, endpoint=False, dtype=output.dtype)
55135513

5514+
@blosc2.dsl_kernel
5515+
def kernel_ramp(start, step):
5516+
return float(start) + _global_linear_idx * float(step) # noqa: F821 # DSL index/shape symbols resolved by miniexpr
5517+
55145518
if shape is None:
55155519
if num is None:
55165520
raise ValueError("Either `shape` or `num` must be specified.")
@@ -5540,13 +5544,21 @@ def linspace_fill(inputs, output, offset):
55405544
# We already have the dtype and shape, so return immediately
55415545
return blosc2.zeros(shape, dtype=dtype, **kwargs) # will return empty array for num == 0
55425546

5543-
inputs = (start, stop, num, endpoint)
5544-
lazyarr = blosc2.lazyudf(linspace_fill, inputs, dtype=dtype, shape=(num,))
5545-
if len(shape) == 1:
5546-
# C order is guaranteed, and no reshape is needed
5547-
return lazyarr.compute(**kwargs)
5547+
# Windows and wasm32 does not support complex numbers in DSL
5548+
if False or blosc2.isdtype(dtype, "complex floating"):
5549+
inputs = (start, stop, num, endpoint)
5550+
lazyarr = blosc2.lazyudf(linspace_fill, inputs, dtype=dtype, shape=(num,))
5551+
if len(shape) == 1:
5552+
# C order is guaranteed, and no reshape is needed
5553+
return lazyarr.compute(**kwargs)
55485554

5549-
return reshape(lazyarr, shape, c_order=c_order, **kwargs)
5555+
return reshape(lazyarr, shape, c_order=c_order, **kwargs)
5556+
else:
5557+
nitems = num - 1 if endpoint else num
5558+
step = (float(stop) - float(start)) / float(nitems) if nitems > 0 else 0.0
5559+
inputs = (start, step)
5560+
lazyarr = blosc2.lazyudf(kernel_ramp, inputs, dtype=dtype, shape=shape)
5561+
return lazyarr.compute(**kwargs)
55505562

55515563

55525564
def eye(N, M=None, k=0, dtype=np.float64, **kwargs: Any) -> NDArray:

tests/ndarray/test_dsl_kernels.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#######################################################################
77

88

9+
import subprocess
10+
import sys
11+
912
import numpy as np
1013
import pytest
1114

@@ -113,6 +116,18 @@ def kernel_scalar_start_step(start, step):
113116
return start + step * (_i0 * _n1 + _i1) # noqa: F821 # DSL index/shape symbols resolved by miniexpr
114117

115118

119+
@blosc2.dsl_kernel
120+
def kernel_scalar_start_stop_nitems(start, stop, nitems):
121+
step = (stop - start) / nitems
122+
return start + _global_linear_idx * step # noqa: F821 # DSL index/shape symbols resolved by miniexpr
123+
124+
125+
@blosc2.dsl_kernel
126+
def kernel_scalar_start_stop_nitems_float_cast(start, stop, nitems):
127+
step = (float(stop) - float(start)) / float(nitems)
128+
return float(start) + _global_linear_idx * step # noqa: F821 # DSL index/shape symbols resolved by miniexpr
129+
130+
116131
@blosc2.dsl_kernel
117132
def kernel_fallback_kw_call(x, y):
118133
return clip(x + y, a_min=0.5, a_max=2.5)
@@ -542,6 +557,63 @@ def test_dsl_kernel_two_scalar_params_start_step_linear_ramp():
542557
np.testing.assert_allclose(res[...], expected, rtol=0.0, atol=0.0)
543558

544559

560+
def test_dsl_kernel_three_scalar_params_start_stop_nitems_ramp():
561+
shape = (20, 25)
562+
start = np.float64(1.0)
563+
stop = np.float64(2.0)
564+
nitems = np.int64(np.prod(shape))
565+
566+
expr = blosc2.lazyudf(
567+
kernel_scalar_start_stop_nitems, (start, stop, nitems), dtype=np.float64, shape=shape
568+
)
569+
res = expr.compute()
570+
571+
step = (stop - start) / nitems
572+
expected = (start + step * np.arange(np.prod(shape), dtype=np.float64)).reshape(shape)
573+
np.testing.assert_allclose(res[...], expected, rtol=0.0, atol=0.0)
574+
575+
576+
def test_dsl_kernel_float_cast_with_negative_scalar_param():
577+
shape = (10, 100)
578+
start = -10
579+
stop = 10
580+
nitems = np.int64(np.prod(shape) - 1)
581+
582+
expr = blosc2.lazyudf(
583+
kernel_scalar_start_stop_nitems_float_cast, (start, stop, nitems), dtype=np.float32, shape=shape
584+
)
585+
res = expr.compute()
586+
587+
expected = np.linspace(start, stop, np.prod(shape), dtype=np.float32).reshape(shape)
588+
np.testing.assert_allclose(res[...], expected, rtol=1e-6, atol=1e-6)
589+
590+
591+
def test_dsl_kernel_float_cast_with_global_linear_idx_no_segfault_subprocess():
592+
if blosc2.IS_WASM:
593+
pytest.skip("subprocess is not supported on emscripten/wasm32")
594+
595+
code = (
596+
"import numpy as np\n"
597+
"import blosc2\n"
598+
"@blosc2.dsl_kernel\n"
599+
"def kernel(start, stop, nitems):\n"
600+
" step = (float(stop) - float(start)) / float(nitems)\n"
601+
" return float(start) + _global_linear_idx * step # noqa: F821\n"
602+
"shape = (10, 100)\n"
603+
"arr = blosc2.lazyudf(kernel, (-10, 10, 999), dtype=np.float32, shape=shape).compute()\n"
604+
"exp = np.linspace(-10, 10, np.prod(shape), dtype=np.float32).reshape(shape)\n"
605+
"np.testing.assert_allclose(arr, exp, rtol=1e-6, atol=1e-6)\n"
606+
"print('ok')\n"
607+
)
608+
result = subprocess.run([sys.executable, "-c", code], capture_output=True, text=True, check=False)
609+
assert result.returncode == 0, (
610+
"subprocess failed (possible segfault/regression in DSL float-cast path):\n"
611+
f"stdout:\n{result.stdout}\n"
612+
f"stderr:\n{result.stderr}"
613+
)
614+
assert "ok" in result.stdout
615+
616+
545617
def test_dsl_kernel_miniexpr_failure_raises_even_with_strict_disabled(monkeypatch):
546618
import importlib
547619

0 commit comments

Comments
 (0)