Skip to content

Commit 90565a8

Browse files
committed
Add new fp_accuracy param for LazyArray.compute()
1 parent d1fd8f6 commit 90565a8

6 files changed

Lines changed: 64 additions & 12 deletions

File tree

doc/reference/classes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,4 @@ Other Classes
5454
Storage
5555
Tuner
5656
URLPath
57+
FPAccuracy

doc/reference/misc.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ This page documents the miscellaneous members of the ``blosc2`` module that do n
5757
SpecialValue,
5858
SplitMode,
5959
Tuner,
60+
FPAccuracy,
6061
compute_chunks_blocks,
6162
get_slice_nchunks,
6263
remove_urlpath,

src/blosc2/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,21 @@ class Tuner(Enum):
113113
BTUNE = 32
114114

115115

116+
class FPAccuracy(Enum):
117+
"""
118+
Floating point accuracy modes for Blosc2 computing with lazy expressions.
119+
120+
This is only relevant when using floating point dtypes with miniexpr.
121+
"""
122+
123+
#: Use 1.0 ULPs (Units in the Last Place) for floating point functions
124+
HIGH = 1
125+
#: Use 3.5 ULPs (Units in the Last Place) for floating point functions
126+
LOW = 2
127+
#: Use default accuracy. This is LOW, which is enough for most applications.
128+
DEFAULT = LOW
129+
130+
116131
from .blosc2_ext import (
117132
DEFINED_CODECS_STOP,
118133
EXTENDED_HEADER_LENGTH,

src/blosc2/blosc2_ext.pyx

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ cdef extern from "miniexpr.h":
578578
const int64_t *shape, const int32_t *chunkshape,
579579
const int32_t *blockshape, int *error, me_expr **out)
580580

581-
cdef enum me_compile_status:
581+
ctypedef enum me_compile_status:
582582
ME_COMPILE_SUCCESS
583583
ME_COMPILE_ERR_OOM
584584
ME_COMPILE_ERR_PARSE
@@ -590,7 +590,7 @@ cdef extern from "miniexpr.h":
590590
ME_COMPILE_ERR_INVALID_ARG_TYPE
591591
ME_COMPILE_ERR_MIXED_TYPE_NESTED
592592

593-
cdef enum me_simd_ulp_mode:
593+
ctypedef enum me_simd_ulp_mode:
594594
ME_SIMD_ULP_DEFAULT
595595
ME_SIMD_ULP_1
596596
ME_SIMD_ULP_3_5
@@ -647,7 +647,8 @@ ctypedef struct udf_udata:
647647
ctypedef struct me_udata:
648648
b2nd_array_t** inputs
649649
int ninputs
650-
b2nd_array_t *array
650+
me_eval_params* eval_params
651+
b2nd_array_t* array
651652
void* aux_reduc_ptr
652653
int64_t chunks_in_array[B2ND_MAX_DIM]
653654
int64_t blocks_in_chunk[B2ND_MAX_DIM]
@@ -1819,6 +1820,8 @@ cdef class SChunk:
18191820
free(me_data.inputs)
18201821
if me_data.miniexpr_handle != NULL: # XXX do we really need the conditional?
18211822
me_free(me_data.miniexpr_handle)
1823+
if me_data.eval_params != NULL:
1824+
free(me_data.eval_params)
18221825
free(me_data)
18231826
elif self.schunk.storage.cparams.prefilter != NULL:
18241827
# From Python the preparams->udata with always have the field py_func
@@ -2015,7 +2018,7 @@ cdef int aux_miniexpr(me_udata *udata, int64_t nchunk, int32_t nblock,
20152018
# NOTE: miniexpr handles scalar outputs in me_eval_nd without touching tail bytes.
20162019
aux_reduc_ptr = <void *> (<uintptr_t> udata.aux_reduc_ptr + offset_bytes)
20172020
rc = me_eval_nd(miniexpr_handle, <const void**> input_buffers, udata.ninputs,
2018-
aux_reduc_ptr, blocknitems, nchunk, nblock, NULL)
2021+
aux_reduc_ptr, blocknitems, nchunk, nblock, udata.eval_params)
20192022
if rc != 0:
20202023
raise RuntimeError(f"miniexpr: issues during evaluation; error code: {rc}")
20212024

@@ -2916,7 +2919,7 @@ cdef class NDArray:
29162919

29172920
return udata
29182921

2919-
cdef me_udata *_fill_me_udata(self, inputs, aux_reduc):
2922+
cdef me_udata *_fill_me_udata(self, inputs, fp_accuracy, aux_reduc):
29202923
cdef me_udata *udata = <me_udata *> malloc(sizeof(me_udata))
29212924
operands = list(inputs.values())
29222925
ninputs = len(operands)
@@ -2927,6 +2930,10 @@ cdef class NDArray:
29272930
inputs_[i].chunk_cache.data = NULL
29282931
udata.inputs = inputs_
29292932
udata.ninputs = ninputs
2933+
cdef me_eval_params* eval_params = <me_eval_params*> malloc(sizeof(me_eval_params))
2934+
eval_params.disable_simd = False
2935+
eval_params.simd_ulp_mode = ME_SIMD_ULP_3_5 if fp_accuracy == blosc2.FPAccuracy.LOW else ME_SIMD_ULP_1
2936+
udata.eval_params = eval_params
29302937
udata.array = self.array
29312938
cdef void* aux_reduc_ptr = NULL
29322939
if aux_reduc is not None:
@@ -2941,12 +2948,12 @@ cdef class NDArray:
29412948

29422949
return udata
29432950

2944-
def _set_pref_expr(self, expression, inputs, aux_reduc=None):
2951+
def _set_pref_expr(self, expression, inputs, fp_accuracy, aux_reduc=None):
29452952
# Set prefilter for miniexpr
29462953
cdef blosc2_cparams* cparams = self.array.sc.storage.cparams
29472954
cparams.prefilter = <blosc2_prefilter_fn> miniexpr_prefilter
29482955

2949-
cdef me_udata* udata = self._fill_me_udata(inputs, aux_reduc)
2956+
cdef me_udata* udata = self._fill_me_udata(inputs, fp_accuracy, aux_reduc)
29502957

29512958
# Get the compiled expression handle for multi-threading
29522959
cdef Py_ssize_t n = len(inputs)

src/blosc2/lazyexpr.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,12 @@ def sort(self, order: str | list[str] | None = None) -> blosc2.LazyArray:
302302
pass
303303

304304
@abstractmethod
305-
def compute(self, item: slice | list[slice] | None = None, **kwargs: Any) -> blosc2.NDArray:
305+
def compute(
306+
self,
307+
item: slice | list[slice] | None = None,
308+
fp_accuracy: blosc2.FPAccuracy = blosc2.FPAccuracy.DEFAULT,
309+
**kwargs: Any,
310+
) -> blosc2.NDArray:
306311
"""
307312
Return a :ref:`NDArray` containing the evaluation of the :ref:`LazyArray`.
308313
@@ -313,9 +318,14 @@ def compute(self, item: slice | list[slice] | None = None, **kwargs: Any) -> blo
313318
the evaluated result. This difference between slicing operands and slicing the final expression
314319
is important when reductions or a where clause are used in the expression.
315320
321+
fp_accuracy: :ref:`blosc2.FPAccuracy`, optional
322+
Specifies the floating-point accuracy to be used during computation.
323+
By default, :ref:`blosc2.FPAccuracy.DEFAULT` is used.
324+
316325
kwargs: Any, optional
317326
Keyword arguments that are supported by the :func:`empty` constructor.
318327
These arguments will be set in the resulting :ref:`NDArray`.
328+
Additionally, the following special kwargs are supported:
319329
320330
Returns
321331
-------
@@ -1296,10 +1306,11 @@ def fast_eval( # noqa: C901
12961306
if use_miniexpr:
12971307
cparams = kwargs.pop("cparams", blosc2.CParams())
12981308
# All values will be overwritten, so we can use an uninitialized array
1309+
fp_accuracy = kwargs.pop("fp_accuracy", blosc2.FPAccuracy.DEFAULT)
12991310
res_eval = blosc2.uninit(shape, dtype, chunks=chunks, blocks=blocks, cparams=cparams, **kwargs)
13001311
try:
13011312
print("expr->miniexpr:", expression)
1302-
res_eval._set_pref_expr(expression, operands)
1313+
res_eval._set_pref_expr(expression, operands, fp_accuracy=fp_accuracy)
13031314
# Data to compress is fetched from operands, so it can be uninitialized here
13041315
data = np.empty(res_eval.schunk.chunksize, dtype=np.uint8)
13051316
# Exercise prefilter for each chunk
@@ -2001,6 +2012,7 @@ def reduce_slices( # noqa: C901
20012012
if use_miniexpr:
20022013
# Experiments say that not splitting is best (at least on Apple Silicon M4 Pro)
20032014
cparams = kwargs.pop("cparams", blosc2.CParams(splitmode=blosc2.SplitMode.NEVER_SPLIT))
2015+
fp_accuracy = kwargs.pop("fp_accuracy", blosc2.FPAccuracy.DEFAULT)
20042016
# Create a fake NDArray just to drive the miniexpr evaluation (values won't be used)
20052017
res_eval = blosc2.uninit(shape, dtype, chunks=chunks, blocks=blocks, cparams=cparams, **kwargs)
20062018
# Compute the number of blocks in the result
@@ -2027,7 +2039,7 @@ def reduce_slices( # noqa: C901
20272039
try:
20282040
print("expr->miniexpr:", expression, reduce_op)
20292041
expression = f"{reduce_op_str}({expression})"
2030-
res_eval._set_pref_expr(expression, operands, aux_reduc)
2042+
res_eval._set_pref_expr(expression, operands, fp_accuracy, aux_reduc)
20312043
# Data won't even try to be compressed, so buffers can be unitialized and reused
20322044
data = np.empty(res_eval.schunk.chunksize, dtype=np.uint8)
20332045
chunk_data = np.empty(res_eval.schunk.chunksize + blosc2.MAX_OVERHEAD, dtype=np.uint8)
@@ -3142,7 +3154,9 @@ def sort(self, order: str | list[str] | None = None) -> blosc2.LazyArray:
31423154
lazy_expr._order = order
31433155
return lazy_expr
31443156

3145-
def compute(self, item=(), **kwargs) -> blosc2.NDArray:
3157+
def compute(
3158+
self, item=(), fp_accuracy: blosc2.FPAccuracy = blosc2.FPAccuracy.DEFAULT, **kwargs
3159+
) -> blosc2.NDArray:
31463160
# When NumPy ufuncs are called, the user may add an `out` parameter to kwargs
31473161
if "out" in kwargs: # use provided out preferentially
31483162
kwargs["_output"] = kwargs.pop("out")
@@ -3452,7 +3466,7 @@ def sort(self, order: str | list[str] | None = None) -> blosc2.LazyArray:
34523466
lazy_expr._order = order
34533467
return lazy_expr
34543468

3455-
def compute(self, item=(), **kwargs):
3469+
def compute(self, item=(), fp_accuracy: blosc2.FPAccuracy = blosc2.FPAccuracy.DEFAULT, **kwargs):
34563470
# Get kwargs
34573471
if kwargs is None:
34583472
kwargs = {}

tests/ndarray/test_lazyexpr.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,20 @@ def test_expression_with_constants(array_fixture):
278278
np.testing.assert_allclose(res[:], nres)
279279

280280

281+
@pytest.mark.parametrize("accuracy", [blosc2.FPAccuracy.LOW, blosc2.FPAccuracy.HIGH])
282+
def test_fp_precision(array_fixture, accuracy):
283+
a1, a2, a3, a4, na1, na2, na3, na4 = array_fixture
284+
# Test with operands with same chunks and blocks
285+
expr = blosc2.sin(a1) ** 2 - blosc2.cos(a2) ** 2 + blosc2.sqrt(a3)
286+
# All precisions in miniexpr should be quite good for this expression
287+
res = expr.compute(fp_accuracy=accuracy)
288+
nres = ne_evaluate("sin(na1) ** 2 - cos(na2) ** 2 + sqrt(na3)")
289+
if na1.dtype == np.float32:
290+
np.testing.assert_allclose(res[:], nres, rtol=1e-6, atol=1e-6)
291+
else:
292+
np.testing.assert_allclose(res[:], nres)
293+
294+
281295
@pytest.mark.parametrize("compare_expressions", [True, False])
282296
@pytest.mark.parametrize("comparison_operator", ["==", "!=", ">=", ">", "<=", "<"])
283297
def test_comparison_operators(dtype_fixture, compare_expressions, comparison_operator):

0 commit comments

Comments
 (0)