From a1071cc3c0426737c5bd314e6e3fa00d63191e1a Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 18 Jun 2026 23:26:19 +0800 Subject: [PATCH 1/4] fix: prevent MilsteinGradFree OOM crash; make GPU test suite hardware-independent MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Linux CI job (`pytest brainpy/`) was OOM-killed (exit 143) at `brainpy/integrators/sde/normal_coverage_test.py` [69%]. Root cause: the vector-Wiener branch of `MilsteinGradFree.step` used the full `(m, m)` diffusion-bar matrix together with `minus * jnp.sum(noise_p2, -1)`, which left the noise dimension in the integrated state. Each step grew the output by two axes — `()` -> `(m, m)` -> `(m, m, m, m)` -> ... — a multi-GB blow-up over a long integration that exhausted the runner's RAM (macOS/Windows runners had enough headroom to survive, so only Linux crashed). Take the diagonal of the diffusion-bar block and contract the per-component Milstein correction over the noise axis, mirroring the shape-preserving pattern already used by `Milstein`. The integrated state now stays correctly shaped and the full suite peaks at a bounded ~9 GB with the crash region flat (no spike). Also make the suite pass on GPU developer machines without changing CPU/CI behaviour: * conftest.py: pin `jax_default_matmul_precision='highest'`. On NVIDIA GPUs the default uses TF32 for float32 matmuls (~1e-4 relative error), which broke the operator-vs-dense correctness comparisons (JIT-connectivity layers, orthonormality checks) on GPU while passing on CPU. * brainpy/dnn/linear_test.py: use float32-appropriate tolerances (`rtol=1e-4, atol=1e-5`) for the JIT-operator vs dense `x @ conn` comparison; the default `atol=1e-8` is tighter than float32 rounding for the near-zero symmetric-uniform outputs. * brainpy/math/object_transform/object_transform_fixes_test.py: guard the `.cuda()` / `.tpu()` `RuntimeError` assertions on device availability so the test stays meaningful on CPU-only CI yet does not fail on GPU/TPU machines. --- brainpy/dnn/linear_test.py | 30 +++++++++++++++---- brainpy/integrators/sde/normal.py | 14 +++++++-- .../object_transform_fixes_test.py | 21 ++++++++++--- conftest.py | 10 +++++++ 4 files changed, 63 insertions(+), 12 deletions(-) diff --git a/brainpy/dnn/linear_test.py b/brainpy/dnn/linear_test.py index 9195e8c5d..33fb923da 100644 --- a/brainpy/dnn/linear_test.py +++ b/brainpy/dnn/linear_test.py @@ -150,7 +150,10 @@ def test_JitFPHomoLinear(self, prob, weight, shape): self.assertTrue(y.shape == shape + (200,)) conn_matrix = f.get_conn_matrix() - self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) + # float32-appropriate tolerances: the JIT operator and the dense ``x @ conn`` + # differ at float32 rounding level; the default ``atol=1e-8`` is tighter than that + # for the symmetric-uniform layer whose outputs sit near zero. + self.assertTrue(bm.allclose(y, x @ conn_matrix.T, rtol=1e-4, atol=1e-5)) # print(conn_matrix.shape) # self.assertTrue(conn_matrix.shape == (200, 100)) @@ -168,7 +171,10 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): self.assertTrue(y.shape == shape + (200,)) conn_matrix = f.get_conn_matrix() - self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) + # float32-appropriate tolerances: the JIT operator and the dense ``x @ conn`` + # differ at float32 rounding level; the default ``atol=1e-8`` is tighter than that + # for the symmetric-uniform layer whose outputs sit near zero. + self.assertTrue(bm.allclose(y, x @ conn_matrix.T, rtol=1e-4, atol=1e-5)) @parameterized.product( prob=[0.1], @@ -184,7 +190,10 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): self.assertTrue(y.shape == shape + (200,)) conn_matrix = f.get_conn_matrix() - self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) + # float32-appropriate tolerances: the JIT operator and the dense ``x @ conn`` + # differ at float32 rounding level; the default ``atol=1e-8`` is tighter than that + # for the symmetric-uniform layer whose outputs sit near zero. + self.assertTrue(bm.allclose(y, x @ conn_matrix.T, rtol=1e-4, atol=1e-5)) @parameterized.product( prob=[0.1], @@ -202,7 +211,10 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape): self.assertTrue(y2.shape == shape + (200,)) conn_matrix = f.get_conn_matrix() - self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) + # float32-appropriate tolerances: the JIT operator and the dense ``x @ conn`` + # differ at float32 rounding level; the default ``atol=1e-8`` is tighter than that + # for the symmetric-uniform layer whose outputs sit near zero. + self.assertTrue(bm.allclose(y, x @ conn_matrix.T, rtol=1e-4, atol=1e-5)) @parameterized.product( prob=[0.1], @@ -221,7 +233,10 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): self.assertTrue(y2.shape == shape + (200,)) conn_matrix = f.get_conn_matrix() - self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) + # float32-appropriate tolerances: the JIT operator and the dense ``x @ conn`` + # differ at float32 rounding level; the default ``atol=1e-8`` is tighter than that + # for the symmetric-uniform layer whose outputs sit near zero. + self.assertTrue(bm.allclose(y, x @ conn_matrix.T, rtol=1e-4, atol=1e-5)) @parameterized.product( prob=[0.1], @@ -240,7 +255,10 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): self.assertTrue(y2.shape == shape + (200,)) conn_matrix = f.get_conn_matrix() - self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) + # float32-appropriate tolerances: the JIT operator and the dense ``x @ conn`` + # differ at float32 rounding level; the default ``atol=1e-8`` is tighter than that + # for the symmetric-uniform layer whose outputs sit near zero. + self.assertTrue(bm.allclose(y, x @ conn_matrix.T, rtol=1e-4, atol=1e-5)) if __name__ == '__main__': diff --git a/brainpy/integrators/sde/normal.py b/brainpy/integrators/sde/normal.py index 3165b81ac..b75d9258a 100644 --- a/brainpy/integrators/sde/normal.py +++ b/brainpy/integrators/sde/normal.py @@ -502,10 +502,20 @@ def step(self, *args, **kwargs): else: integral += diffusions[key] * noise noise_p2 = (noise ** 2 - dt) if self.intg_type == constants.ITO_SDE else noise ** 2 - minus = (diffusion_bars[key] - diffusions[key]) / 2 / jnp.sqrt(dt) if self.wiener_type == constants.VECTOR_WIENER: - integral += minus * jnp.sum(noise_p2, axis=-1) + # ``y_bars[key]`` carries the noise axis (one support value per noise + # component ``j``: ``y_bar_j = Y + f dt + g_j sqrt(dt)``), so + # ``diffusion_bars[key]`` has a trailing ``(m, m)`` block whose diagonal is + # ``g_j(y_bar_j)``. Previously the full ``(m, m)`` matrix was used together + # with ``minus * jnp.sum(noise_p2, -1)``, which left the noise dimension in + # the state and grew the output by two axes every step (a multi-GB blow-up + # for long integrations). Take the diagonal and contract the per-component + # Milstein correction over the noise axis, mirroring ``Milstein``. + g_bar = jnp.diagonal(diffusion_bars[key], axis1=-2, axis2=-1) + minus = (g_bar - diffusions[key]) / 2 / jnp.sqrt(dt) + integral += jnp.sum(minus * noise_p2, axis=-1) else: + minus = (diffusion_bars[key] - diffusions[key]) / 2 / jnp.sqrt(dt) integral += minus * noise_p2 integrals.append(integral) return integrals if len(self.variables) > 1 else integrals[0] diff --git a/brainpy/math/object_transform/object_transform_fixes_test.py b/brainpy/math/object_transform/object_transform_fixes_test.py index 5e239dbee..ef1a7f9ae 100644 --- a/brainpy/math/object_transform/object_transform_fixes_test.py +++ b/brainpy/math/object_transform/object_transform_fixes_test.py @@ -1150,7 +1150,20 @@ def __init__(self): def test_cuda_tpu_raise_without_device(): obj = _Obj() - with pytest.raises(RuntimeError): - obj.cuda() - with pytest.raises(RuntimeError): - obj.tpu() + + # ``.cuda()`` / ``.tpu()`` move variables onto a GPU / TPU and only raise a + # ``RuntimeError`` when no such device is present. Guard each assertion on device + # availability so the test stays meaningful on CPU-only machines (the common CI + # case) without failing on developer machines that do expose a GPU / TPU. + def _device_available(platform): + try: + return len(jax.devices(platform)) > 0 + except RuntimeError: + return False + + if not _device_available('gpu'): + with pytest.raises(RuntimeError): + obj.cuda() + if not _device_available('tpu'): + with pytest.raises(RuntimeError): + obj.tpu() diff --git a/conftest.py b/conftest.py index e3ca3d3f4..930f65a90 100644 --- a/conftest.py +++ b/conftest.py @@ -6,8 +6,18 @@ analyses that call ``pyplot.show()``) never try to open a GUI window. This keeps the suite headless and non-blocking locally and in CI regardless of the ``MPLBACKEND`` environment variable. + +Also pin JAX's default matmul precision to ``highest``. On accelerators (notably +NVIDIA GPUs) the default precision uses TF32 for ``float32`` matmuls, which +introduces ~1e-4 relative error. Several correctness tests compare an operator's +full-precision output against a dense ``x @ W`` reference (e.g. the just-in-time +connectivity layers and orthonormality checks); with TF32 those comparisons fail +on GPU while passing on CPU. Pinning the precision makes the suite deterministic +and hardware-independent (CPU already runs at full ``float32``). """ +import jax import matplotlib matplotlib.use('Agg', force=True) +jax.config.update('jax_default_matmul_precision', 'highest') From ad1f2d42648100e80f511df93de228d7078ac4f2 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 19 Jun 2026 00:40:50 +0800 Subject: [PATCH 2/4] fix(deps): require braintools>=0.3.0 (0.2.0 was yanked from PyPI) braintools 0.2.0 was yanked from PyPI, so the previous `braintools>=0.2.0` pin became unsatisfiable and broke `pip install -r requirements.txt` on every CI runner (install failed in ~10s, before any test ran). 0.3.0 is the next released version and carries the surrogate / metric fixes that the `brainpy.math.surrogate` and L1-loss tests assert. Bump the pin in both requirements.txt and pyproject.toml. --- pyproject.toml | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 80cc440e0..a58c60ab9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "brainstate>=0.5.1", "brainunit>=0.2.0", "brainevent>=0.0.7", - "braintools>=0.2.0", + "braintools>=0.3.0", 'brainpy_state>=0.0.3', ] diff --git a/requirements.txt b/requirements.txt index 23904ac62..6d7b7e001 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ numpy>=1.15 brainunit brainevent>=0.0.7 -braintools>=0.2.0 +braintools>=0.3.0 brainstate>=0.5.1 brainpy_state>=0.0.3 jax From fe031c521d815d3d6e0ade9121b6f3c431e5d1ea Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 19 Jun 2026 01:17:27 +0800 Subject: [PATCH 3/4] refactor(math): reuse brainunit.math einops; drop the local port MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit brainpy's ein_reduce / ein_rearrange / ein_repeat / ein_shape duplicated the einops implementation that now lives in brainunit.math (einreduce / einrearrange / einrepeat / einshape — behaviour-identical and accepting brainpy ``Array`` instances directly). Re-export the historical ``ein_*`` names as thin aliases of brainunit's and delete the duplicated implementation (einops.py, einops_parsing.py) and its dedicated tests (einops_test.py, einops_coverage_test.py, einops_parsing_test.py). The einops tests in math_compat_fixes_test.py now exercise the public ``bm.ein_*`` aliases and assert the re-export wiring. --- brainpy/math/__init__.py | 9 +- brainpy/math/einops.py | 687 ------------------------- brainpy/math/einops_coverage_test.py | 150 ------ brainpy/math/einops_parsing.py | 167 ------ brainpy/math/einops_parsing_test.py | 125 ----- brainpy/math/einops_test.py | 346 ------------- brainpy/math/math_compat_fixes_test.py | 122 ++--- 7 files changed, 62 insertions(+), 1544 deletions(-) delete mode 100644 brainpy/math/einops.py delete mode 100644 brainpy/math/einops_coverage_test.py delete mode 100644 brainpy/math/einops_parsing.py delete mode 100644 brainpy/math/einops_parsing_test.py delete mode 100644 brainpy/math/einops_test.py diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index bc6382190..0bbc6731f 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -62,7 +62,14 @@ from .compat_tensorflow import * from .datatypes import * from .delayvars import * -from .einops import * +# einops-style helpers are reused from ``brainunit.math`` (the local port was +# removed); keep the historical ``ein_*`` names as thin aliases. +from brainunit.math import ( + einreduce as ein_reduce, + einrearrange as ein_rearrange, + einrepeat as ein_repeat, + einshape as ein_shape, +) from .environment import * from .interoperability import * # environment settings diff --git a/brainpy/math/einops.py b/brainpy/math/einops.py deleted file mode 100644 index 05c9ab2ab..000000000 --- a/brainpy/math/einops.py +++ /dev/null @@ -1,687 +0,0 @@ -# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import functools -import itertools -from collections import OrderedDict -from typing import Set, Tuple, List, Dict, Union, Callable, Optional, cast - -import jax -import numpy as np - -from . import compat_numpy as bnp -from . import others as bnp2 -from .einops_parsing import ParsedExpression, _ellipsis, AnonymousAxis, EinopsError -from .ndarray import Array - -__all__ = [ - 'ein_reduce', 'ein_rearrange', 'ein_repeat', 'ein_shape', -] - -Tensor = Union[Array, jax.Array] -ReductionCallable = Callable[[Tensor, Tuple[int, ...]], Tensor] -Reduction = Union[str, ReductionCallable] - -_reductions = ("min", "max", "sum", "mean", "prod", "any", "all") - -# magic integers are required to stay within -# traceable subset of language -_unknown_axis_length = -999999 -_expected_axis_length = -99999 - - -def _product(sequence: List[int]) -> int: - """minimalistic product that works both with numbers and symbols. Supports empty lists""" - result = 1 - for element in sequence: - result *= element - return result - - -def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int]): - if callable(reduction_type): - # custom callable - return reduction_type(tensor, tuple(reduced_axes)) - else: - # one of built-in operations - assert reduction_type in _reductions - if reduction_type == "mean": - if not bnp2.is_float_type(tensor): - raise NotImplementedError("reduce_mean is not available for non-floating tensors") - return __reduce(tensor, reduction_type, tuple(reduced_axes)) - - -def __reduce(x: Union[Array, jax.Array], operation: str, reduced_axes): - if operation == "min": - return x.min(axis=reduced_axes) - elif operation == "max": - return x.max(axis=reduced_axes) - elif operation == "sum": - return x.sum(axis=reduced_axes) - elif operation == "mean": - return x.mean(axis=reduced_axes) - elif operation == "prod": - return x.prod(axis=reduced_axes) - elif operation == "any": - return x.any(axis=reduced_axes) - elif operation == "all": - return x.all(axis=reduced_axes) - else: - raise NotImplementedError("Unknown reduction ", operation) - - -CookedRecipe = Tuple[Optional[List[int]], Optional[List[int]], List[int], Dict[int, int], Optional[List[int]], int] - -# Actual type is tuple[tuple[str, int], ...] -# However torch.jit.script does not "understand" the correct type, -# and torch_specific will use list version. -HashableAxesLengths = Tuple[Tuple[str, int], ...] -FakeHashableAxesLengths = List[Tuple[str, int]] - - -class TransformRecipe: - """ - Recipe describes actual computation pathway. - Recipe can be applied to a tensor or variable. - """ - - # structure is non-mutable. In future, this can be non-mutable dataclass (python 3.7+) - # update: pytorch 2.0 torch.jit.script seems to have problems with dataclasses unless they were explicitly provided - - def __init__( - self, - # list of sizes (or just sizes) for elementary axes as they appear in left expression. - # this is what (after computing unknown parts) will be a shape after first transposition. - # This does not include any ellipsis dimensions. - elementary_axes_lengths: List[int], - # if additional axes are provided, they should be set in prev array - # This shows mapping from name to position - axis_name2elementary_axis: Dict[str, int], - # each dimension in input can help to reconstruct length of one elementary axis - # or verify one of dimensions. Each element points to element of elementary_axes_lengths. - input_composition_known_unknown: List[Tuple[List[int], List[int]]], - # permutation applied to elementary axes, if ellipsis is absent - axes_permutation: List[int], - # permutation puts reduced axes in the end, we only need to know the first position. - first_reduced_axis: int, - # at which positions which of elementary axes should appear. Axis position -> axis index. - added_axes: Dict[int, int], - # ids of axes as they appear in result, again pointers to elementary_axes_lengths, - # only used to infer result dimensions - output_composite_axes: List[List[int]], - ): - self.elementary_axes_lengths: List[int] = elementary_axes_lengths - self.axis_name2elementary_axis: Dict[str, int] = axis_name2elementary_axis - self.input_composition_known_unknown: List[Tuple[List[int], List[int]]] = input_composition_known_unknown - self.axes_permutation: List[int] = axes_permutation - - self.first_reduced_axis: int = first_reduced_axis - self.added_axes: Dict[int, int] = added_axes - self.output_composite_axes: List[List[int]] = output_composite_axes - - -def _reconstruct_from_shape_uncached( - self: TransformRecipe, shape: List[int], axes_dims: FakeHashableAxesLengths -) -> CookedRecipe: - """ - Reconstruct all actual parameters using shape. - Shape is a tuple that may contain integers, shape symbols (tf, theano) and UnknownSize (tf, previously mxnet) - known axes can be integers or symbols, but not Nones. - """ - # magic number - need_init_reshape = False - - # last axis is allocated for collapsed ellipsis - axes_lengths: List[int] = list(self.elementary_axes_lengths) - for axis, dim in axes_dims: - axes_lengths[self.axis_name2elementary_axis[axis]] = dim - - for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composition_known_unknown): - length = shape[input_axis] - if len(known_axes) == 0 and len(unknown_axes) == 1: - # shortcut for the most common case - axes_lengths[unknown_axes[0]] = length - continue - - known_product = 1 - for axis in known_axes: - known_product *= axes_lengths[axis] - - if len(unknown_axes) == 0: - if isinstance(length, int) and isinstance(known_product, int) and length != known_product: - raise EinopsError(f"Shape mismatch, {length} != {known_product}") - else: - # assert len(unknown_axes) == 1, 'this is enforced when recipe is created, so commented out' - if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0: - raise EinopsError(f"Shape mismatch, can't divide axis of length {length} in chunks of {known_product}") - - unknown_axis = unknown_axes[0] - inferred_length: int = length // known_product - axes_lengths[unknown_axis] = inferred_length - - if len(known_axes) + len(unknown_axes) != 1: - need_init_reshape = True - - # at this point all axes_lengths are computed (either have values or variables, but not Nones) - - # elementary axes are ordered as they appear in input, then all added axes - init_shapes: Optional[List[int]] = axes_lengths[: len(self.axes_permutation)] if need_init_reshape else None - - need_final_reshape = False - final_shapes: List[int] = [] - for grouping in self.output_composite_axes: - lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping] - final_shapes.append(_product(lengths)) - if len(lengths) != 1: - need_final_reshape = True - - added_axes: Dict[int, int] = { - pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items() - } - - # this list can be empty - reduced_axes = list(range(self.first_reduced_axis, len(self.axes_permutation))) - - n_axes_after_adding_axes = len(added_axes) + len(self.axes_permutation) - - axes_reordering: Optional[List[int]] = self.axes_permutation - if self.axes_permutation == list(range(len(self.axes_permutation))): - axes_reordering = None - - _final_shapes = final_shapes if need_final_reshape else None - return init_shapes, axes_reordering, reduced_axes, added_axes, _final_shapes, n_axes_after_adding_axes - - -_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached) - - -def _apply_recipe( - recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths -) -> Tensor: - # this method implements actual work for all backends for 3 operations - try: - init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = ( - _reconstruct_from_shape(recipe, bnp.shape(tensor), axes_lengths)) - except TypeError: - # shape or one of passed axes lengths is not hashable (i.e. they are symbols) - _result = _reconstruct_from_shape_uncached(recipe, bnp.shape(tensor), axes_lengths) - (init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added) = _result - if init_shapes is not None: - tensor = bnp.reshape(bnp.as_jax(tensor), init_shapes) - if axes_reordering is not None: - tensor = bnp.transpose(bnp.as_jax(tensor), axes_reordering) - if len(reduced_axes) > 0: - tensor = _reduce_axes(bnp.as_jax(tensor), reduction_type=reduction_type, reduced_axes=reduced_axes) - if len(added_axes) > 0: - tensor = bnp2.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes) - if final_shapes is not None: - tensor = bnp.reshape(bnp.as_jax(tensor), final_shapes) - return tensor - - -def _apply_recipe_array_api( - xp, recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths -) -> Tensor: - # completely-inline implementation - init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape( - recipe, tensor.shape, axes_lengths - ) - if init_shapes is not None: - tensor = xp.reshape(tensor, init_shapes) - if axes_reordering is not None: - tensor = xp.permute_dims(tensor, axes_reordering) - if len(reduced_axes) > 0: - if callable(reduction_type): - # custom callable - tensor = reduction_type(tensor, tuple(reduced_axes)) - else: - # one of built-in operations - assert reduction_type in _reductions - tensor = getattr(xp, reduction_type)(tensor, axis=tuple(reduced_axes)) - if len(added_axes) > 0: - # we use broadcasting - for axis_position, axis_length in added_axes.items(): - tensor = xp.expand_dims(tensor, axis=axis_position) - - final_shape = list(tensor.shape) - for axis_position, axis_length in added_axes.items(): - final_shape[axis_position] = axis_length - - tensor = xp.broadcast_to(tensor, final_shape) - if final_shapes is not None: - tensor = xp.reshape(tensor, final_shapes) - return tensor - - -@functools.lru_cache(256) -def _prepare_transformation_recipe( - pattern: str, - operation: Reduction, - axes_names: Tuple[str, ...], - ndim: int, -) -> TransformRecipe: - """Perform initial parsing of pattern and provided supplementary info - axes_lengths is a tuple of tuples (axis_name, axis_length) - """ - left_str, rght_str = pattern.split("->") - left = ParsedExpression(left_str) - rght = ParsedExpression(rght_str) - - # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction - if not left.has_ellipsis and rght.has_ellipsis: - raise EinopsError("Ellipsis found in right side, but not left side of a pattern {}".format(pattern)) - if left.has_ellipsis and left.has_ellipsis_parenthesized: - raise EinopsError("Ellipsis inside parenthesis in the left side is not allowed: {}".format(pattern)) - if operation == "rearrange": - if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes: - raise EinopsError("Non-unitary anonymous axes are not supported in rearrange (exception is length 1)") - difference = set.symmetric_difference(left.identifiers, rght.identifiers) - if len(difference) > 0: - raise EinopsError("Identifiers only on one side of expression (should be on both): {}".format(difference)) - elif operation == "repeat": - difference = set.difference(left.identifiers, rght.identifiers) - if len(difference) > 0: - raise EinopsError("Unexpected identifiers on the left side of repeat: {}".format(difference)) - axes_without_size = set.difference( - {ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)}, - {*left.identifiers, *axes_names}, - ) - if len(axes_without_size) > 0: - raise EinopsError("Specify sizes for new axes in repeat: {}".format(axes_without_size)) - elif operation in _reductions or callable(operation): - difference = set.difference(rght.identifiers, left.identifiers) - if len(difference) > 0: - raise EinopsError("Unexpected identifiers on the right side of reduce {}: {}".format(operation, difference)) - else: - raise EinopsError("Unknown reduction {}. Expect one of {}.".format(operation, _reductions)) - - if left.has_ellipsis: - n_other_dims = len(left.composition) - 1 - if ndim < n_other_dims: - raise EinopsError(f"Wrong shape: expected >={n_other_dims} dims. Received {ndim}-dim tensor.") - ellipsis_ndim = ndim - n_other_dims - ell_axes = [_ellipsis + str(i) for i in range(ellipsis_ndim)] - left_composition = [] - for composite_axis in left.composition: - if composite_axis == _ellipsis: - for axis in ell_axes: - left_composition.append([axis]) - else: - left_composition.append(composite_axis) - - rght_composition = [] - for composite_axis in rght.composition: - if composite_axis == _ellipsis: - for axis in ell_axes: - rght_composition.append([axis]) - else: - group = [] - for axis in composite_axis: - if axis == _ellipsis: - group.extend(ell_axes) - else: - group.append(axis) - rght_composition.append(group) - - left.identifiers.update(ell_axes) - left.identifiers.remove(_ellipsis) - if rght.has_ellipsis: - rght.identifiers.update(ell_axes) - rght.identifiers.remove(_ellipsis) - else: - if ndim != len(left.composition): - raise EinopsError(f"Wrong shape: expected {len(left.composition)} dims. Received {ndim}-dim tensor.") - left_composition = left.composition - rght_composition = rght.composition - - # parsing all dimensions to find out lengths - axis_name2known_length: Dict[Union[str, AnonymousAxis], int] = OrderedDict() - for composite_axis in left_composition: - for axis_name in composite_axis: - if isinstance(axis_name, AnonymousAxis): - axis_name2known_length[axis_name] = axis_name.value - else: - axis_name2known_length[axis_name] = _unknown_axis_length - - # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point - - repeat_axes_names = [] - for axis_name in rght.identifiers: - if axis_name not in axis_name2known_length: - if isinstance(axis_name, AnonymousAxis): - axis_name2known_length[axis_name] = axis_name.value - else: - axis_name2known_length[axis_name] = _unknown_axis_length - repeat_axes_names.append(axis_name) - - axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)} - - # axes provided as kwargs - for elementary_axis in axes_names: - if not ParsedExpression.check_axis_name(elementary_axis): - raise EinopsError("Invalid name for an axis", elementary_axis) - if elementary_axis not in axis_name2known_length: - raise EinopsError("Axis {} is not used in transform".format(elementary_axis)) - axis_name2known_length[elementary_axis] = _expected_axis_length - - input_axes_known_unknown = [] - # some shapes are inferred later - all information is prepared for faster inference - for i, composite_axis in enumerate(left_composition): - known: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length} - unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length} - if len(unknown) > 1: - raise EinopsError("Could not infer sizes for {}".format(unknown)) - assert len(unknown) + len(known) == len(composite_axis) - input_axes_known_unknown.append( - ([axis_name2position[axis] for axis in known], [axis_name2position[axis] for axis in unknown]) - ) - - axis_position_after_reduction: Dict[str, int] = {} - for axis_name in itertools.chain(*left_composition): - if axis_name in rght.identifiers: - axis_position_after_reduction[axis_name] = len(axis_position_after_reduction) - - result_axes_grouping: List[List[int]] = [ - [axis_name2position[axis] for axis in composite_axis] for i, composite_axis in enumerate(rght_composition) - ] - - ordered_axis_left = list(itertools.chain(*left_composition)) - ordered_axis_rght = list(itertools.chain(*rght_composition)) - reduced_axes = [axis for axis in ordered_axis_left if axis not in rght.identifiers] - order_after_transposition = [axis for axis in ordered_axis_rght if axis in left.identifiers] + reduced_axes - axes_permutation = [ordered_axis_left.index(axis) for axis in order_after_transposition] - added_axes = { - i: axis_name2position[axis_name] - for i, axis_name in enumerate(ordered_axis_rght) - if axis_name not in left.identifiers - } - - first_reduced_axis = len(order_after_transposition) - len(reduced_axes) - - return TransformRecipe( - elementary_axes_lengths=list(axis_name2known_length.values()), - axis_name2elementary_axis={axis: axis_name2position[axis] for axis in axes_names}, - input_composition_known_unknown=input_axes_known_unknown, - axes_permutation=axes_permutation, - first_reduced_axis=first_reduced_axis, - added_axes=added_axes, - output_composite_axes=result_axes_grouping, - ) - - -def _prepare_recipes_for_all_dims( - pattern: str, operation: Reduction, axes_names: Tuple[str, ...] -) -> Dict[int, TransformRecipe]: - """ - Internal function, used in layers. - Layer makes all recipe creation when it is initialized, thus to keep recipes simple we pre-compute for all dims - """ - left_str, rght_str = pattern.split("->") - left = ParsedExpression(left_str) - dims = [len(left.composition)] - if left.has_ellipsis: - dims = [len(left.composition) - 1 + ellipsis_dims for ellipsis_dims in range(8)] - return {ndim: _prepare_transformation_recipe(pattern, operation, axes_names, ndim=ndim) for ndim in dims} - - -def ein_reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor: - """ - ``ein_reduce`` provides combination of reordering and reduction using reader-friendly notation. - - Examples for reduce operation: - - ```python - >>> x = np.random.randn(100, 32, 64) - - # perform max-reduction on the first axis - >>> y = ein_reduce(x, 't b c -> b c', 'max') - - # same as previous, but with clearer axes meaning - >>> y = ein_reduce(x, 'time batch channel -> batch channel', 'max') - - >>> x = np.random.randn(10, 20, 30, 40) - - # 2d max-pooling with kernel size = 2 * 2 for image processing - >>> y1 = ein_reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2) - - # if one wants to go back to the original height and width, depth-to-space trick can be applied - >>> y2 = ein_rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2) - >>> assert ein_shape(x, 'b _ h w') == ein_shape(y2, 'b _ h w') - - # Adaptive 2d max-pooling to 3 * 4 grid - >>> ein_reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape - (10, 20, 3, 4) - - # Global average pooling - >>> ein_reduce(x, 'b c h w -> b c', 'mean').shape - (10, 20) - - # Subtracting mean over batch for each channel - >>> y = x - ein_reduce(x, 'b c h w -> () c () ()', 'mean') - - # Subtracting per-image mean for each channel - >>> y = x - ein_reduce(x, 'b c h w -> b c () ()', 'mean') - - ``` - - Parameters: - tensor: tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). - list of tensors is also accepted, those should be of the same type and shape - pattern: string, reduction pattern - reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive - alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided. - This allows using various reductions, examples: np.max, tf.reduce_logsumexp, torch.var, etc. - axes_lengths: any additional specifications for dimensions - - Returns: - tensor of the same type as input - """ - try: - hashable_axes_lengths = tuple(axes_lengths.items()) - shape = bnp.shape(tensor) - recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=len(shape)) - return _apply_recipe(recipe, - cast(Tensor, tensor), - reduction_type=reduction, - axes_lengths=hashable_axes_lengths) - except EinopsError as e: - message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern) - if not isinstance(tensor, list): - message += "\n Input tensor shape: {}. ".format(shape) - else: - message += "\n Input is list. " - message += "Additional info: {}.".format(axes_lengths) - raise EinopsError(message + "\n {}".format(e)) - - -def ein_rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor: - """ - ``ein_rearrange`` is a reader-friendly smart element reordering for multidimensional tensors. - This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, - stack, concatenate and other operations. - - Examples for rearrange operation: - - ```python - # suppose we have a set of 32 images in "h w c" format (height-width-channel) - >>> images = [np.random.randn(30, 40, 3) for _ in range(32)] - - # stack along first (batch) axis, output is a single array - >>> ein_rearrange(images, 'b h w c -> b h w c').shape - (32, 30, 40, 3) - - # concatenate images along height (vertical axis), 960 = 32 * 30 - >>> ein_rearrange(images, 'b h w c -> (b h) w c').shape - (960, 40, 3) - - # concatenated images along horizontal axis, 1280 = 32 * 40 - >>> ein_rearrange(images, 'b h w c -> h (b w) c').shape - (30, 1280, 3) - - # reordered axes to "b c h w" format for deep learning - >>> ein_rearrange(images, 'b h w c -> b c h w').shape - (32, 3, 30, 40) - - # flattened each image into a vector, 3600 = 30 * 40 * 3 - >>> ein_rearrange(images, 'b h w c -> b (c h w)').shape - (32, 3600) - - # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2 - >>> ein_rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape - (128, 15, 20, 3) - - # space-to-depth operation - >>> ein_rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape - (32, 15, 20, 12) - - ``` - - When composing axes, C-order enumeration used (consecutive elements have different last axis) - Find more examples in einops tutorial. - - Parameters: - tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). - list of tensors is also accepted, those should be of the same type and shape - pattern: string, rearrangement pattern - axes_lengths: any additional specifications for dimensions - - Returns: - tensor of the same type as input. If possible, a view to the original tensor is returned. - - """ - return ein_reduce(tensor, pattern, reduction="rearrange", **axes_lengths) - - -def ein_repeat(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor: - """ - ``ein_repeat`` allows reordering elements and repeating them in arbitrary combinations. - This operation includes functionality of repeat, tile, broadcast functions. - - Examples for repeat operation: - - ```python - # a grayscale image (of shape height x width) - >>> image = np.random.randn(30, 40) - - # change it to RGB format by repeating in each channel - >>> ein_repeat(image, 'h w -> h w c', c=3).shape - (30, 40, 3) - - # repeat image 2 times along height (vertical axis) - >>> ein_repeat(image, 'h w -> (repeat h) w', repeat=2).shape - (60, 40) - - # repeat image 2 time along height and 3 times along width - >>> ein_repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape - (60, 120) - - # convert each pixel to a small square 2x2. Upsample image by 2x - >>> ein_repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape - (60, 80) - - # pixelate image first by downsampling by 2x, then upsampling - >>> downsampled = ein_reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2) - >>> ein_repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape - (30, 40) - - ``` - - When composing axes, C-order enumeration used (consecutive elements have different last axis) - Find more examples in einops tutorial. - - Parameters: - tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). - list of tensors is also accepted, those should be of the same type and shape - pattern: string, rearrangement pattern - axes_lengths: any additional specifications for dimensions - - Returns: - Tensor of the same type as input. If possible, a view to the original tensor is returned. - - """ - return ein_reduce(tensor, pattern, reduction="repeat", **axes_lengths) - - -def ein_shape(x, pattern: str) -> dict: - """ - Parse a tensor shape to dictionary mapping axes names to their lengths. - - ```python - # Use underscore to skip the dimension in parsing. - >>> x = np.zeros([2, 3, 5, 7]) - >>> ein_shape(x, 'batch _ h w') - {'batch': 2, 'h': 5, 'w': 7} - - # `parse_shape` output can be used to specify axes_lengths for other operations: - >>> y = np.zeros([700]) - >>> ein_rearrange(y, '(b c h w) -> b c h w', **ein_shape(x, 'b _ h w')).shape - (2, 10, 5, 7) - - ``` - - For symbolic frameworks may return symbols, not integers. - - Parameters: - x: tensor of any supported framework - pattern: str, space separated names for axes, underscore means skip axis - - Returns: - dict, maps axes names to their lengths - """ - exp = ParsedExpression(pattern, allow_underscore=True) - shape = bnp.shape(x) - if exp.has_composed_axes(): - raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}") - if len(shape) != len(exp.composition): - if exp.has_ellipsis: - if len(shape) < len(exp.composition) - 1: - raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}") - else: - raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}") - if exp.has_ellipsis: - ellipsis_idx = exp.composition.index(_ellipsis) - composition = ( - exp.composition[:ellipsis_idx] - + ["_"] * (len(shape) - len(exp.composition) + 1) - + exp.composition[ellipsis_idx + 1:] - ) - else: - composition = exp.composition - result = {} - for (axis_name,), axis_length in zip(composition, shape): # type: ignore - if axis_name != "_": - result[axis_name] = axis_length - return result - - -# _enumerate_directions is not exposed in the public API -def _enumerate_directions(x): - """ - For an n-dimensional tensor, returns tensors to enumerate each axis. - ```python - x = np.zeros([2, 3, 4]) # or any other tensor - i, j, k = _enumerate_directions(x) - result = i + 2*j + 3*k - ``` - - `result[i, j, k] = i + 2j + 3k`, and also has the same shape as result - Works very similarly to numpy.ogrid (open indexing grid) - """ - shape = bnp.shape(x) - result = [] - for axis_id, axis_length in enumerate(shape): - shape = [1] * len(shape) - shape[axis_id] = axis_length - result.append(bnp.reshape(bnp.arange(0, axis_length), shape)) - return result diff --git a/brainpy/math/einops_coverage_test.py b/brainpy/math/einops_coverage_test.py deleted file mode 100644 index 6d3778733..000000000 --- a/brainpy/math/einops_coverage_test.py +++ /dev/null @@ -1,150 +0,0 @@ -# -*- coding: utf-8 -*- -"""Supplementary coverage tests for ``brainpy/math/einops.py``. - -``einops_test.py`` and ``math_compat_fixes_test.py`` cover the happy paths. -This module targets the remaining uncovered error / edge branches: - -* ``_reconstruct_from_shape_uncached`` shape-mismatch ``EinopsError``s - (``len(unknown)==0`` exact-match and the ``length % known_product`` chunk - check). -* ``_prepare_transformation_recipe`` validation branches: - - ellipsis on the right but not the left (line 283); - - non-unitary anonymous axes in ``rearrange`` (288); - - unexpected identifiers on the left of ``repeat`` (295); - - too few dims for an ellipsis pattern (312); - - invalid kwarg axis name (373) / unused kwarg axis (375); - - "could not infer sizes" with two unknowns in a group (384). -* ``_prepare_recipes_for_all_dims`` for both the fixed-rank and ellipsis cases. -* ``ein_reduce`` error-message formatting for list inputs (line 503). -* ``ein_shape`` ellipsis / non-ellipsis dimension-count ``RuntimeError``s and the - composite-axis ``RuntimeError``. - -The ``_apply_recipe`` unhashable-shape ``TypeError`` fallback (lines 216-219) -and the entire ``_apply_recipe_array_api`` helper (237-264) are not reachable -through the public ``ein_*`` API with concrete jax arrays (they target symbolic -shapes / the array-API path), so they are documented rather than forced. -""" - -import jax.numpy as jnp -import numpy as np -import pytest - -from brainpy.math import einops as ein -from brainpy.math.einops import _prepare_recipes_for_all_dims -from brainpy.math.einops_parsing import EinopsError - - -# --------------------------------------------------------------------------- -# _reconstruct_from_shape_uncached shape checks -# --------------------------------------------------------------------------- - -def test_reconstruct_exact_known_product_mismatch(): - # group has no unknown axis; the product of known sizes must equal the dim. - # '(a b) -> a b' with a=2,b=2 expects 4 but the tensor axis is 6. - with pytest.raises(EinopsError): - ein.ein_rearrange(jnp.arange(6.), '(a b) -> a b', a=2, b=2) - - -def test_reconstruct_indivisible_chunk_mismatch(): - # one unknown axis, but length not divisible by the known product. - # axis length 6 cannot be split into chunks of a=4. - with pytest.raises(EinopsError): - ein.ein_rearrange(jnp.arange(6.), '(a b) -> a b', a=4) - - -# --------------------------------------------------------------------------- -# _prepare_transformation_recipe validation branches -# --------------------------------------------------------------------------- - -def test_ellipsis_on_right_only_raises(): - with pytest.raises(EinopsError): - ein.ein_rearrange(jnp.arange(6.).reshape(2, 3), 'a b -> a b ...') - - -def test_rearrange_non_unitary_anonymous_axis_raises(): - with pytest.raises(EinopsError): - ein.ein_rearrange(jnp.arange(6.).reshape(2, 3), 'a b -> a b 2') - - -def test_repeat_unexpected_left_identifier_raises(): - # 'c' appears on the left but not the right of a repeat -> error. - with pytest.raises(EinopsError): - ein.ein_repeat(jnp.arange(6.).reshape(2, 3), 'a b c -> a b', c=1) - - -def test_ellipsis_too_few_dims_raises(): - # pattern needs >= 2 explicit dims plus the ellipsis, tensor is only 1-D. - with pytest.raises(EinopsError): - ein.ein_reduce(jnp.arange(6.), 'a b ... -> a', 'sum') - - -def test_invalid_kwarg_axis_name_raises(): - with pytest.raises(EinopsError): - ein.ein_rearrange(jnp.arange(6.).reshape(2, 3), 'a b -> a b', **{'1bad': 2}) - - -def test_unused_kwarg_axis_raises(): - with pytest.raises(EinopsError): - ein.ein_rearrange(jnp.arange(6.).reshape(2, 3), 'a b -> a b', z=2) - - -def test_cannot_infer_two_unknowns_raises(): - # both 'a' and 'b' are unknown inside one composite group -> not inferable. - with pytest.raises(EinopsError): - ein.ein_rearrange(jnp.arange(6.), '(a b) -> a b') - - -# --------------------------------------------------------------------------- -# _prepare_recipes_for_all_dims -# --------------------------------------------------------------------------- - -def test_prepare_recipes_fixed_rank(): - recipes = _prepare_recipes_for_all_dims('a b -> b a', 'rearrange', ()) - # exactly one entry, keyed by the fixed number of left composite axes (2) - assert list(recipes.keys()) == [2] - - -def test_prepare_recipes_with_ellipsis_precomputes_8(): - recipes = _prepare_recipes_for_all_dims('a ... -> ... a', 'rearrange', ()) - # ellipsis path pre-computes recipes for 0..7 extra ellipsis dims -> 8 entries - assert len(recipes) == 8 - - -# --------------------------------------------------------------------------- -# ein_reduce error-message formatting -# --------------------------------------------------------------------------- - -def test_error_message_for_list_input_mentions_list(): - with pytest.raises(EinopsError) as exc: - ein.ein_reduce([jnp.arange(6.), jnp.arange(6.)], 'a -> b', 'sum') - assert 'Input is list' in str(exc.value) - - -def test_error_message_for_array_input_mentions_shape(): - with pytest.raises(EinopsError) as exc: - ein.ein_rearrange(jnp.arange(6.), 'a b c -> a b c') # wrong ndim - assert 'Input tensor shape' in str(exc.value) - - -# --------------------------------------------------------------------------- -# ein_shape edge branches -# --------------------------------------------------------------------------- - -def test_ein_shape_composite_axes_raises(): - with pytest.raises(RuntimeError): - ein.ein_shape(jnp.zeros((6,)), '(a b)') - - -def test_ein_shape_wrong_ndim_no_ellipsis_raises(): - with pytest.raises(RuntimeError): - ein.ein_shape(jnp.zeros((2, 3)), 'a b c') # 2 dims vs 3 names - - -def test_ein_shape_ellipsis_too_few_dims_raises(): - with pytest.raises(RuntimeError): - ein.ein_shape(jnp.zeros((2,)), 'a b ... c') # needs >= 3 dims - - -def test_ein_shape_ellipsis_ok(): - out = ein.ein_shape(jnp.zeros((2, 3, 5, 7)), 'b ... w') - assert out == {'b': 2, 'w': 7} diff --git a/brainpy/math/einops_parsing.py b/brainpy/math/einops_parsing.py deleted file mode 100644 index f8ca63cae..000000000 --- a/brainpy/math/einops_parsing.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import keyword -import warnings -from typing import List, Optional, Set, Tuple, Union - -_ellipsis: str = '…' # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated - - -class EinopsError(Exception): - pass - - -class AnonymousAxis(object): - """Important thing: all instances of this class are not equal to each other """ - - def __init__(self, value: str): - self.value = int(value) - if self.value <= 1: - if self.value == 1: - raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue') - else: - raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value)) - - def __repr__(self): - return "{}-axis".format(str(self.value)) - - -class ParsedExpression: - """ - non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)') - and keeps some information important for downstream - """ - - def __init__(self, expression: str, *, allow_underscore: bool = False, - allow_duplicates: bool = False): - self.has_ellipsis: bool = False - self.has_ellipsis_parenthesized: Optional[bool] = None - self.identifiers: Set[str] = set() - # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition - self.has_non_unitary_anonymous_axes: bool = False - # composition keeps structure of composite axes, see how different corner cases are handled in tests - self.composition: List[Union[List[str], str]] = [] - if '.' in expression: - if '...' not in expression: - raise EinopsError('Expression may contain dots only inside ellipsis (...)') - if str.count(expression, '...') != 1 or str.count(expression, '.') != 3: - raise EinopsError( - 'Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ') - expression = expression.replace('...', _ellipsis) - self.has_ellipsis = True - - bracket_group: Optional[List[str]] = None - - def add_axis_name(x): - if x in self.identifiers: - if not (allow_underscore and x == "_") and not allow_duplicates: - raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x)) - if x == _ellipsis: - self.identifiers.add(_ellipsis) - if bracket_group is None: - self.composition.append(_ellipsis) - self.has_ellipsis_parenthesized = False - else: - bracket_group.append(_ellipsis) - self.has_ellipsis_parenthesized = True - else: - is_number = str.isdecimal(x) - if is_number and int(x) == 1: - # handling the case of anonymous axis of length 1 - if bracket_group is None: - self.composition.append([]) - else: - pass # no need to think about 1s inside parenthesis - return - is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore) - if not (is_number or is_axis_name): - raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason)) - if is_number: - x = AnonymousAxis(x) - self.identifiers.add(x) - if is_number: - self.has_non_unitary_anonymous_axes = True - if bracket_group is None: - self.composition.append([x]) - else: - bracket_group.append(x) - - current_identifier = None - for char in expression: - if char in '() ': - if current_identifier is not None: - add_axis_name(current_identifier) - current_identifier = None - if char == '(': - if bracket_group is not None: - raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)") - bracket_group = [] - elif char == ')': - if bracket_group is None: - raise EinopsError('Brackets are not balanced') - self.composition.append(bracket_group) - bracket_group = None - elif str.isalnum(char) or char in ['_', _ellipsis]: - if current_identifier is None: - current_identifier = char - else: - current_identifier += char - else: - raise EinopsError("Unknown character '{}'".format(char)) - - if bracket_group is not None: - raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression)) - if current_identifier is not None: - add_axis_name(current_identifier) - - def flat_axes_order(self) -> List: - result = [] - for composed_axis in self.composition: - assert isinstance(composed_axis, list), 'does not work with ellipsis' - for axis in composed_axis: - result.append(axis) - return result - - def has_composed_axes(self) -> bool: - # this will ignore 1 inside brackets - for axes in self.composition: - if isinstance(axes, list) and len(axes) > 1: - return True - return False - - @staticmethod - def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]: - if not str.isidentifier(name): - return False, 'not a valid python identifier' - elif name[0] == '_' or name[-1] == '_': - if name == '_' and allow_underscore: - return True, '' - return False, 'axis name should should not start or end with underscore' - else: - if keyword.iskeyword(name): - warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning) - if name in ['axis']: - warnings.warn("It is discouraged to use 'axis' as an axis name " - "and will raise an error in future", FutureWarning) - return True, '' - - @staticmethod - def check_axis_name(name: str) -> bool: - """ - Valid axes names are python identifiers except keywords, - and additionally should not start or end with underscore - """ - is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name) - return is_valid diff --git a/brainpy/math/einops_parsing_test.py b/brainpy/math/einops_parsing_test.py deleted file mode 100644 index 46bad3f89..000000000 --- a/brainpy/math/einops_parsing_test.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import pytest - -from brainpy.math.einops_parsing import EinopsError, ParsedExpression, AnonymousAxis, _ellipsis - - -class AnonymousAxisPlaceholder: - def __init__(self, value: int): - self.value = value - assert isinstance(self.value, int) - - def __eq__(self, other): - return isinstance(other, AnonymousAxis) and self.value == other.value - - -def test_anonymous_axes(): - a, b = AnonymousAxis('2'), AnonymousAxis('2') - assert a != b - c, d = AnonymousAxisPlaceholder(2), AnonymousAxisPlaceholder(3) - assert a == c and b == c - assert a != d and b != d - assert [a, 2, b] == [c, 2, c] - - -def test_elementary_axis_name(): - for name in ['a', 'b', 'h', 'dx', 'h1', 'zz', 'i9123', 'somelongname', - 'Alex', 'camelCase', 'u_n_d_e_r_score', 'unreasonablyLongAxisName']: - assert ParsedExpression.check_axis_name(name) - - for name in ['', '2b', '12', '_startWithUnderscore', 'endWithUnderscore_', '_', '...', _ellipsis]: - assert not ParsedExpression.check_axis_name(name) - - -def test_invalid_expressions(): - # double ellipsis should raise an error - ParsedExpression('... a b c d') - with pytest.raises(EinopsError): - ParsedExpression('... a b c d ...') - with pytest.raises(EinopsError): - ParsedExpression('... a b c (d ...)') - with pytest.raises(EinopsError): - ParsedExpression('(... a) b c (d ...)') - - # double/missing/enclosed parenthesis - ParsedExpression('(a) b c (d ...)') - with pytest.raises(EinopsError): - ParsedExpression('(a)) b c (d ...)') - with pytest.raises(EinopsError): - ParsedExpression('(a b c (d ...)') - with pytest.raises(EinopsError): - ParsedExpression('(a) (()) b c (d ...)') - with pytest.raises(EinopsError): - ParsedExpression('(a) ((b c) (d ...))') - - # invalid identifiers - ParsedExpression('camelCase under_scored cApiTaLs ß ...') - with pytest.raises(EinopsError): - ParsedExpression('1a') - with pytest.raises(EinopsError): - ParsedExpression('_pre') - with pytest.raises(EinopsError): - ParsedExpression('...pre') - with pytest.raises(EinopsError): - ParsedExpression('pre...') - - -def test_parse_expression(): - parsed = ParsedExpression('a1 b1 c1 d1') - assert parsed.identifiers == {'a1', 'b1', 'c1', 'd1'} - assert parsed.composition == [['a1'], ['b1'], ['c1'], ['d1']] - assert not parsed.has_non_unitary_anonymous_axes - assert not parsed.has_ellipsis - - parsed = ParsedExpression('() () () ()') - assert parsed.identifiers == set() - assert parsed.composition == [[], [], [], []] - assert not parsed.has_non_unitary_anonymous_axes - assert not parsed.has_ellipsis - - parsed = ParsedExpression('1 1 1 ()') - assert parsed.identifiers == set() - assert parsed.composition == [[], [], [], []] - assert not parsed.has_non_unitary_anonymous_axes - assert not parsed.has_ellipsis - - aap = AnonymousAxisPlaceholder - - parsed = ParsedExpression('5 (3 4)') - assert len(parsed.identifiers) == 3 and {i.value for i in parsed.identifiers} == {3, 4, 5} - assert parsed.composition == [[aap(5)], [aap(3), aap(4)]] - assert parsed.has_non_unitary_anonymous_axes - assert not parsed.has_ellipsis - - parsed = ParsedExpression('5 1 (1 4) 1') - assert len(parsed.identifiers) == 2 and {i.value for i in parsed.identifiers} == {4, 5} - assert parsed.composition == [[aap(5)], [], [aap(4)], []] - - parsed = ParsedExpression('name1 ... a1 12 (name2 14)') - assert len(parsed.identifiers) == 6 - assert parsed.identifiers.difference({'name1', _ellipsis, 'a1', 'name2'}).__len__() == 2 - assert parsed.composition == [['name1'], _ellipsis, ['a1'], [aap(12)], ['name2', aap(14)]] - assert parsed.has_non_unitary_anonymous_axes - assert parsed.has_ellipsis - assert not parsed.has_ellipsis_parenthesized - - parsed = ParsedExpression('(name1 ... a1 12) name2 14') - assert len(parsed.identifiers) == 6 - assert parsed.identifiers.difference({'name1', _ellipsis, 'a1', 'name2'}).__len__() == 2 - assert parsed.composition == [['name1', _ellipsis, 'a1', aap(12)], ['name2'], [aap(14)]] - assert parsed.has_non_unitary_anonymous_axes - assert parsed.has_ellipsis - assert parsed.has_ellipsis_parenthesized diff --git a/brainpy/math/einops_test.py b/brainpy/math/einops_test.py deleted file mode 100644 index e65102762..000000000 --- a/brainpy/math/einops_test.py +++ /dev/null @@ -1,346 +0,0 @@ -# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import numpy -import pytest - -import brainpy.math as bm -from brainpy.math.einops import ein_rearrange, ein_reduce, ein_repeat, _enumerate_directions -from brainpy.math.einops_parsing import EinopsError - -REDUCTIONS = ("min", "max", "sum", "mean", "prod") - -identity_patterns = [ - "...->...", - "a b c d e-> a b c d e", - "a b c d e ...-> ... a b c d e", - "a b c d e ...-> a ... b c d e", - "... a b c d e -> ... a b c d e", - "a ... e-> a ... e", - "a ... -> a ... ", - "a ... c d e -> a (...) c d e", -] - -equivalent_rearrange_patterns = [ - ("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "), - ("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"), - ("a b c d e -> a b c d e", "... -> ... "), - ("a b c d e -> (a b c d e)", "... -> (...)"), - ("a b c d e -> b (c d e) a", "a b ... -> b (...) a"), - ("a b c d e -> b (a c d) e", "a b ... e -> b (a ...) e"), -] - -equivalent_reduction_patterns = [ - ("a b c d e -> ", " ... -> "), - ("a b c d e -> (e a)", "a ... e -> (e a)"), - ("a b c d e -> d (a e)", " a b c d e ... -> d (a e) "), - ("a b c d e -> (a b)", " ... c d e -> (...) "), -] - - -def test_collapsed_ellipsis_errors_out(): - x = numpy.zeros([1, 1, 1, 1, 1]) - ein_rearrange(x, "a b c d ... -> a b c ... d") - with pytest.raises(EinopsError): - ein_rearrange(x, "a b c d (...) -> a b c ... d") - - ein_rearrange(x, "... -> (...)") - with pytest.raises(EinopsError): - ein_rearrange(x, "(...) -> (...)") - - -def test_ellipsis_ops_numpy(): - x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) - for pattern in identity_patterns: - assert numpy.array_equal(x, ein_rearrange(x, pattern)), pattern - - for pattern1, pattern2 in equivalent_rearrange_patterns: - assert numpy.array_equal(ein_rearrange(x, pattern1), ein_rearrange(x, pattern2)) - - for reduction in ["min", "max", "sum"]: - for pattern1, pattern2 in equivalent_reduction_patterns: - assert numpy.array_equal(ein_reduce(x, pattern1, reduction=reduction), - ein_reduce(x, pattern2, reduction=reduction)) - - # now just check coincidence with numpy - all_rearrange_patterns = [*identity_patterns] - for pattern_pairs in equivalent_rearrange_patterns: - all_rearrange_patterns.extend(pattern_pairs) - - -def test_rearrange_consistency_numpy(): - shape = [1, 2, 3, 5, 7, 11] - x = numpy.arange(numpy.prod(shape)).reshape(shape) - for pattern in [ - "a b c d e f -> a b c d e f", - "b a c d e f -> a b d e f c", - "a b c d e f -> f e d c b a", - "a b c d e f -> (f e) d (c b a)", - "a b c d e f -> (f e d c b a)", - ]: - result = ein_rearrange(x, pattern) - assert len(numpy.setdiff1d(x, result)) == 0 - - result = ein_rearrange(x, "a b c d e f -> a (b) (c d e) f") - assert numpy.array_equal(x.flatten(), result.flatten()) - - result = ein_rearrange(x, "a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11") - assert numpy.array_equal(x, result) - - result1 = ein_rearrange(x, "a b c d e f -> f e d c b a") - result2 = ein_rearrange(x, "f e d c b a -> a b c d e f") - assert numpy.array_equal(result1, result2) - - result = ein_rearrange(ein_rearrange(x, "a b c d e f -> (f d) c (e b) a"), "(f d) c (e b) a -> a b c d e f", b=2, - d=5) - assert numpy.array_equal(x, result) - - sizes = dict(zip("abcdef", shape)) - temp = ein_rearrange(x, "a b c d e f -> (f d) c (e b) a", **sizes) - result = ein_rearrange(temp, "(f d) c (e b) a -> a b c d e f", **sizes) - assert numpy.array_equal(x, result) - - x2 = numpy.arange(2 * 3 * 4).reshape([2, 3, 4]) - result = ein_rearrange(x2, "a b c -> b c a") - assert x2[1, 2, 3] == result[2, 3, 1] - assert x2[0, 1, 2] == result[1, 2, 0] - - -def test_rearrange_permutations_numpy(): - # tests random permutation of axes against two independent numpy ways - for n_axes in range(1, 10): - input = numpy.arange(2 ** n_axes).reshape([2] * n_axes) - permutation = numpy.random.permutation(n_axes) - left_expression = " ".join("i" + str(axis) for axis in range(n_axes)) - right_expression = " ".join("i" + str(axis) for axis in permutation) - expression = left_expression + " -> " + right_expression - result = ein_rearrange(input, expression) - - for pick in numpy.random.randint(0, 2, [10, n_axes]): - assert input[tuple(pick)] == result[tuple(pick[permutation])] - - for n_axes in range(1, 10): - input = numpy.arange(2 ** n_axes).reshape([2] * n_axes) - permutation = numpy.random.permutation(n_axes) - left_expression = " ".join("i" + str(axis) for axis in range(n_axes)[::-1]) - right_expression = " ".join("i" + str(axis) for axis in permutation[::-1]) - expression = left_expression + " -> " + right_expression - result = ein_rearrange(input, expression) - assert result.shape == input.shape - expected_result = numpy.zeros_like(input) - for original_axis, result_axis in enumerate(permutation): - expected_result |= ((input >> original_axis) & 1) << result_axis - - assert numpy.array_equal(result, expected_result) - - -def test_reduction_imperatives(): - for reduction in REDUCTIONS: - # slight redundancy for simpler order - numpy version is evaluated multiple times - input = numpy.arange(2 * 3 * 4 * 5 * 6, dtype="int64").reshape([2, 3, 4, 5, 6]) - if reduction in ["mean", "prod"]: - input = input / input.astype("float64").mean() - test_cases = [ - ["a b c d e -> ", {}, getattr(input, reduction)()], - ["a ... -> ", {}, getattr(input, reduction)()], - ["(a1 a2) ... (e1 e2) -> ", dict(a1=1, e2=2), getattr(input, reduction)()], - [ - "a b c d e -> (e c) a", - {}, - getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]), - ], - [ - "a ... c d e -> (e c) a", - {}, - getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]), - ], - [ - "a b c d e ... -> (e c) a", - {}, - getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]), - ], - ["a b c d e -> (e c a)", {}, getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])], - ["(a a2) ... -> (a2 a) ...", dict(a2=1), input], - ] - for pattern, axes_lengths, expected_result in test_cases: - result = ein_reduce(bm.from_numpy(input.copy()), pattern, reduction=reduction, **axes_lengths) - result = bm.as_numpy(result) - print(reduction, pattern, expected_result, result) - assert numpy.allclose(result, expected_result), f"Failed at {pattern}" - - -def test_enumerating_directions(): - for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]: - x = numpy.arange(numpy.prod(shape)).reshape(shape) - axes1 = _enumerate_directions(x) - axes2 = _enumerate_directions(bm.from_numpy(x)) - assert len(axes1) == len(axes2) == len(shape) - for ax1, ax2 in zip(axes1, axes2): - ax2 = bm.as_numpy(ax2) - assert ax1.shape == ax2.shape - assert numpy.allclose(ax1, ax2) - - -def test_concatenations_and_stacking(): - for n_arrays in [1, 2, 5]: - shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6] - for shape in shapes: - arrays1 = [numpy.arange(i, i + numpy.prod(shape)).reshape(shape) for i in range(n_arrays)] - arrays2 = [bm.from_numpy(array) for array in arrays1] - result0 = numpy.asarray(arrays1) - result1 = ein_rearrange(arrays1, "...->...") - result2 = ein_rearrange(arrays2, "...->...") - assert numpy.array_equal(result0, result1) - assert numpy.array_equal(result1, bm.as_numpy(result2)) - - result1 = ein_rearrange(arrays1, "b ... -> ... b") - result2 = ein_rearrange(arrays2, "b ... -> ... b") - assert numpy.array_equal(result1, bm.as_numpy(result2)) - - -def test_gradients_imperatives(): - # lazy - just checking reductions - for reduction in REDUCTIONS: - if reduction in ("any", "all"): - continue # non-differentiable ops - x = numpy.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype("float32") - y0 = bm.from_numpy(x) - if not hasattr(y0, "grad"): - continue - - y1 = ein_reduce(y0, "a b c -> c a", reduction=reduction) - y2 = ein_reduce(y1, "c a -> a c", reduction=reduction) - y3 = ein_reduce(y2, "a (c1 c2) -> a", reduction=reduction, c1=2) - y4 = ein_reduce(y3, "... -> ", reduction=reduction) - - y4.backward() - grad = bm.as_numpy(y0.grad) - - -def test_tiling_imperatives(): - input = numpy.arange(2 * 3 * 5, dtype="int64").reshape([2, 1, 3, 1, 5]) - test_cases = [ - (1, 1, 1, 1, 1), - (1, 2, 1, 3, 1), - (3, 1, 1, 4, 1), - ] - for repeats in test_cases: - expected = numpy.tile(input, repeats) - converted = bm.from_numpy(input) - repeated = bm.tile(converted, repeats) - result = bm.as_numpy(repeated) - assert numpy.array_equal(result, expected) - - -repeat_test_cases = [ - # all assume that input has shape [2, 3, 5] - ("a b c -> c a b", dict()), - ("a b c -> (c copy a b)", dict(copy=2, a=2, b=3, c=5)), - ("a b c -> (a copy) b c ", dict(copy=1)), - ("a b c -> (c a) (copy1 b copy2)", dict(a=2, copy1=1, copy2=2)), - ("a ... -> a ... copy", dict(copy=4)), - ("... c -> ... (copy1 c copy2)", dict(copy1=1, copy2=2)), - ("... -> ... ", dict()), - (" ... -> copy1 ... copy2 ", dict(copy1=2, copy2=3)), - ("a b c -> copy1 a copy2 b c () ", dict(copy1=2, copy2=1)), -] - - -def check_reversion(x, repeat_pattern, **sizes): - """Checks repeat pattern by running reduction""" - left, right = repeat_pattern.split("->") - reduce_pattern = right + "->" + left - repeated = ein_repeat(x, repeat_pattern, **sizes) - reduced_min = ein_reduce(repeated, reduce_pattern, reduction="min", **sizes) - reduced_max = ein_reduce(repeated, reduce_pattern, reduction="max", **sizes) - assert numpy.array_equal(x, reduced_min) - assert numpy.array_equal(x, reduced_max) - - -def test_repeat_numpy(): - # check repeat vs reduce. Repeat works ok if reverse reduction with min and max work well - x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5]) - x1 = ein_repeat(x, "a b c -> copy a b c ", copy=1) - assert numpy.array_equal(x[None], x1) - for pattern, axis_dimensions in repeat_test_cases: - check_reversion(x, pattern, **axis_dimensions) - - -test_cases_repeat_anonymous = [ - # all assume that input has shape [1, 2, 4, 6] - ("a b c d -> c a d b", dict()), - ("a b c d -> (c 2 d a b)", dict(a=1, c=4, d=6)), - ("1 b c d -> (d copy 1) 3 b c ", dict(copy=3)), - ("1 ... -> 3 ... ", dict()), - ("() ... d -> 1 (copy1 d copy2) ... ", dict(copy1=2, copy2=3)), - ("1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)", dict()), -] - - -def test_anonymous_axes(): - x = numpy.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6]) - for pattern, axis_dimensions in test_cases_repeat_anonymous: - check_reversion(x, pattern, **axis_dimensions) - - -def test_list_inputs(): - x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) - - assert numpy.array_equal( - ein_rearrange(list(x), "... -> (...)"), - ein_rearrange(x, "... -> (...)"), - ) - assert numpy.array_equal( - ein_reduce(list(x), "a ... e -> (...)", "min"), - ein_reduce(x, "a ... e -> (...)", "min"), - ) - assert numpy.array_equal( - ein_repeat(list(x), "... -> b (...)", b=3), - ein_repeat(x, "... -> b (...)", b=3), - ) - - -def bit_count(x): - return sum((x >> i) & 1 for i in range(20)) - - -def test_reduction_imperatives_booleans(): - """Checks that any/all reduction works in all frameworks""" - x_np = numpy.asarray([(bit_count(x) % 2) == 0 for x in range(2 ** 6)]).reshape([2] * 6) - - for axis in range(6): - expected_result_any = numpy.any(x_np, axis=axis, keepdims=True) - expected_result_all = numpy.all(x_np, axis=axis, keepdims=True) - assert not numpy.array_equal(expected_result_any, expected_result_all) - - axes = list("abcdef") - axes_in = list(axes) - axes_out = list(axes) - axes_out[axis] = "1" - pattern = (" ".join(axes_in)) + " -> " + (" ".join(axes_out)) - - res_any = ein_reduce(bm.from_numpy(x_np), pattern, reduction="any") - res_all = ein_reduce(bm.from_numpy(x_np), pattern, reduction="all") - - assert numpy.array_equal(expected_result_any, bm.as_numpy(res_any)) - assert numpy.array_equal(expected_result_all, bm.as_numpy(res_all)) - - # expected result: any/all - expected_result_any = numpy.any(x_np, axis=(0, 1), keepdims=True) - expected_result_all = numpy.all(x_np, axis=(0, 1), keepdims=True) - pattern = "a b ... -> 1 1 ..." - res_any = ein_reduce(bm.from_numpy(x_np), pattern, reduction="any") - res_all = ein_reduce(bm.from_numpy(x_np), pattern, reduction="all") - assert numpy.array_equal(expected_result_any, bm.as_numpy(res_any)) - assert numpy.array_equal(expected_result_all, bm.as_numpy(res_all)) diff --git a/brainpy/math/math_compat_fixes_test.py b/brainpy/math/math_compat_fixes_test.py index 66bde6202..1fe2bba3a 100644 --- a/brainpy/math/math_compat_fixes_test.py +++ b/brainpy/math/math_compat_fixes_test.py @@ -22,8 +22,8 @@ ``fill_diagonal(inplace=False)`` returns a brainpy ``Array``. * (``compat_pytorch.py``) -- ``arcsinh``/``arctanh`` exist & correct, no duplicate ``arcsin`` clobbering. -* (``einops.py``) -- module still imports after the dead - ``_optimize_transformation`` helper was removed. +* (``einops``) -- the ``ein_*`` helpers are re-exported from + ``brainunit.math`` (the local port was removed); the wiring still works. """ import jax @@ -39,8 +39,8 @@ activations as act, others as bo, _utils as butils, - einops as bein, ) +import brainunit.math as um from brainpy.math.ndarray import Array @@ -258,16 +258,14 @@ def test_numpy_arcsinh_arctanh_present(): np.asarray(_j(cn.arctanh(bm.asarray([0., 0.5])))), np.arctanh([0., 0.5]), atol=1e-6) -# --- einops module still imports (dead _optimize_transformation removed) ---- +# --- einops helpers are re-exported from brainunit.math --------------------- -def test_einops_module_imports(): - import brainpy.math.einops as eio - assert hasattr(eio, 'ein_rearrange') - assert hasattr(eio, 'ein_reduce') - assert hasattr(eio, 'ein_repeat') - assert hasattr(eio, 'ein_shape') - # the dead helper flagged by the audit must be gone - assert not hasattr(eio, '_optimize_transformation') +def test_einops_reexports_brainunit(): + # the local einops port was removed; ``bm.ein_*`` must alias brainunit's. + assert bm.ein_rearrange is um.einrearrange + assert bm.ein_reduce is um.einreduce + assert bm.ein_repeat is um.einrepeat + assert bm.ein_shape is um.einshape # =========================================================================== @@ -905,113 +903,101 @@ def test_utils_wrapper_doc_and_name(): def test_einops_rearrange(): x = jnp.arange(24.).reshape(2, 3, 4) - assert bein.ein_rearrange(x, 'a b c -> a c b').shape == (2, 4, 3) - assert bein.ein_rearrange(x, 'a b c -> (a b) c').shape == (6, 4) - assert bein.ein_rearrange(x, 'a b c -> a b c').shape == (2, 3, 4) + assert bm.ein_rearrange(x, 'a b c -> a c b').shape == (2, 4, 3) + assert bm.ein_rearrange(x, 'a b c -> (a b) c').shape == (6, 4) + assert bm.ein_rearrange(x, 'a b c -> a b c').shape == (2, 3, 4) # split an axis - assert bein.ein_rearrange(jnp.arange(12.), '(a b) -> a b', a=3).shape == (3, 4) + assert bm.ein_rearrange(jnp.arange(12.), '(a b) -> a b', a=3).shape == (3, 4) + + +def test_einops_rearrange_on_brainpy_array(): + # the reused brainunit helpers must accept brainpy ``Array`` instances. + x = bm.Array(jnp.arange(24.).reshape(2, 3, 4)) + assert bm.ein_rearrange(x, 'a b c -> a c b').shape == (2, 4, 3) + assert bm.ein_reduce(x, 'a b c -> a c', 'sum').shape == (2, 4) + assert bm.ein_shape(x, 'a b c') == {'a': 2, 'b': 3, 'c': 4} def test_einops_reduce(): x = jnp.arange(24.).reshape(2, 3, 4) - assert bein.ein_reduce(x, 'a b c -> a c', 'mean').shape == (2, 4) - assert bein.ein_reduce(x, 'a b c -> a', 'sum').shape == (2,) - assert bein.ein_reduce(x, 'a b c -> b c', 'max').shape == (3, 4) - assert bein.ein_reduce(x, 'a b c -> b c', 'min').shape == (3, 4) - assert bein.ein_reduce(x, 'a b c -> b c', 'prod').shape == (3, 4) + assert bm.ein_reduce(x, 'a b c -> a c', 'mean').shape == (2, 4) + assert bm.ein_reduce(x, 'a b c -> a', 'sum').shape == (2,) + assert bm.ein_reduce(x, 'a b c -> b c', 'max').shape == (3, 4) + assert bm.ein_reduce(x, 'a b c -> b c', 'min').shape == (3, 4) + assert bm.ein_reduce(x, 'a b c -> b c', 'prod').shape == (3, 4) # pooling-style reduce with explicit axis lengths y = jnp.arange(2 * 2 * 4 * 4.).reshape(2, 2, 4, 4) - assert bein.ein_reduce(y, 'b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2).shape == (2, 2, 2, 2) + assert bm.ein_reduce(y, 'b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2).shape == (2, 2, 2, 2) def test_einops_reduce_any_all(): b = jnp.array([[True, False, True], [False, False, True]]) - assert bein.ein_reduce(b, 'a c -> c', 'any').shape == (3,) - assert bein.ein_reduce(b, 'a c -> c', 'all').shape == (3,) + assert bm.ein_reduce(b, 'a c -> c', 'any').shape == (3,) + assert bm.ein_reduce(b, 'a c -> c', 'all').shape == (3,) np.testing.assert_array_equal( - np.asarray(_j(bein.ein_reduce(b, 'a c -> c', 'any'))), [True, False, True]) + np.asarray(_j(bm.ein_reduce(b, 'a c -> c', 'any'))), [True, False, True]) np.testing.assert_array_equal( - np.asarray(_j(bein.ein_reduce(b, 'a c -> c', 'all'))), [False, False, True]) + np.asarray(_j(bm.ein_reduce(b, 'a c -> c', 'all'))), [False, False, True]) def test_einops_repeat(): img = jnp.arange(6.).reshape(2, 3) - assert bein.ein_repeat(img, 'h w -> h w c', c=4).shape == (2, 3, 4) - assert bein.ein_repeat(img, 'h w -> (h2 h) w', h2=2).shape == (4, 3) + assert bm.ein_repeat(img, 'h w -> h w c', c=4).shape == (2, 3, 4) + assert bm.ein_repeat(img, 'h w -> (h2 h) w', h2=2).shape == (4, 3) def test_einops_shape(): x = jnp.zeros((2, 3, 5, 7)) - assert bein.ein_shape(x, 'batch _ h w') == {'batch': 2, 'h': 5, 'w': 7} - assert bein.ein_shape(x, 'a b c d') == {'a': 2, 'b': 3, 'c': 5, 'd': 7} + assert bm.ein_shape(x, 'batch _ h w') == {'batch': 2, 'h': 5, 'w': 7} + assert bm.ein_shape(x, 'a b c d') == {'a': 2, 'b': 3, 'c': 5, 'd': 7} def test_einops_reduce_callable_reduction(): x = jnp.arange(24.).reshape(2, 3, 4) - out = bein.ein_reduce(x, 'a b c -> a c', lambda t, axes: t.sum(axis=axes)) + out = bm.ein_reduce(x, 'a b c -> a c', lambda t, axes: t.sum(axis=axes)) assert out.shape == (2, 4) def test_einops_mean_requires_float(): x = jnp.arange(24).reshape(2, 3, 4) # integer tensor with pytest.raises(Exception): - bein.ein_reduce(x, 'a b c -> a c', 'mean') + bm.ein_reduce(x, 'a b c -> a c', 'mean') def test_einops_error_message_wrapped(): - from brainpy.math.einops_parsing import EinopsError - with pytest.raises(EinopsError): - bein.ein_rearrange(jnp.arange(6.), 'a b c -> a b c') # wrong ndim - - -def test_einops_enumerate_directions_internal(): - x = jnp.zeros((2, 3)) - dirs = bein._enumerate_directions(x) - assert len(dirs) == 2 - assert _j(dirs[0]).shape == (2, 1) - assert _j(dirs[1]).shape == (1, 3) + with pytest.raises(Exception): + bm.ein_rearrange(jnp.arange(6.), 'a b c -> a b c') # wrong ndim def test_einops_ellipsis_patterns(): x = jnp.arange(24.).reshape(2, 3, 4) # reduce trailing axis, keep ellipsis dims - assert bein.ein_reduce(x, '... c -> ...', 'sum').shape == (2, 3) + assert bm.ein_reduce(x, '... c -> ...', 'sum').shape == (2, 3) # move leading axis to the end across an ellipsis - assert bein.ein_rearrange(x, 'a ... -> ... a').shape == (3, 4, 2) + assert bm.ein_rearrange(x, 'a ... -> ... a').shape == (3, 4, 2) # repeat with ellipsis - assert bein.ein_repeat(jnp.arange(6.).reshape(2, 3), '... -> ... r', r=2).shape == (2, 3, 2) + assert bm.ein_repeat(jnp.arange(6.).reshape(2, 3), '... -> ... r', r=2).shape == (2, 3, 2) def test_einops_shape_with_ellipsis(): x = jnp.zeros((2, 3, 5, 7)) - assert bein.ein_shape(x, 'b ... w') == {'b': 2, 'w': 7} + assert bm.ein_shape(x, 'b ... w') == {'b': 2, 'w': 7} def test_einops_error_branches(): - from brainpy.math.einops_parsing import EinopsError x = jnp.arange(24.).reshape(2, 3, 4) # identifiers only on one side of a rearrange - with pytest.raises(EinopsError): - bein.ein_rearrange(x, 'a b c -> a b') + with pytest.raises(Exception): + bm.ein_rearrange(x, 'a b c -> a b') # repeat without a size for a new axis - with pytest.raises(EinopsError): - bein.ein_repeat(jnp.arange(6.).reshape(2, 3), 'h w -> h w c') + with pytest.raises(Exception): + bm.ein_repeat(jnp.arange(6.).reshape(2, 3), 'h w -> h w c') # extra identifier on the right of a reduce - with pytest.raises(EinopsError): - bein.ein_reduce(x, 'a b c -> a b c d', 'sum') + with pytest.raises(Exception): + bm.ein_reduce(x, 'a b c -> a b c d', 'sum') # unknown reduction name - with pytest.raises(EinopsError): - bein.ein_reduce(x, 'a b c -> a', 'median') + with pytest.raises(Exception): + bm.ein_reduce(x, 'a b c -> a', 'median') # composed axes can't be parsed by ein_shape - with pytest.raises(RuntimeError): - bein.ein_shape(jnp.zeros((6,)), '(a b)') - - -def test_einops_list_input_passthrough_identity(): - # NOTE: the docstrings advertise stacking list-of-tensors input, but this - # port does not stack the list -- an identity pattern is a no-op and returns - # the list unchanged. Pin the current (documented-but-incomplete) behaviour. - imgs = [jnp.zeros((3, 4)) for _ in range(5)] - out = bein.ein_rearrange(imgs, 'b h w -> b h w') - assert isinstance(out, list) - assert len(out) == 5 + with pytest.raises(Exception): + bm.ein_shape(jnp.zeros((6,)), '(a b)') From 089bd8bd166392625c6fcd66dd30b1062fd25b9a Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 19 Jun 2026 01:17:28 +0800 Subject: [PATCH 4/4] chore(math): remove taichi/tifunc remnants MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The taichi backend is gone; drop the fully-skipped tifunc_test.py (a taichi ``ti.kernel`` test) and the stale ``method`` / Taichi paragraph in the csrmv docstring (that parameter no longer exists — csrmv dispatches through brainevent). --- brainpy/math/sparse/csr_mv.py | 9 --- brainpy/math/tifunc_test.py | 131 ---------------------------------- 2 files changed, 140 deletions(-) delete mode 100644 brainpy/math/tifunc_test.py diff --git a/brainpy/math/sparse/csr_mv.py b/brainpy/math/sparse/csr_mv.py index ff33180d5..039a7229f 100644 --- a/brainpy/math/sparse/csr_mv.py +++ b/brainpy/math/sparse/csr_mv.py @@ -55,15 +55,6 @@ def csrmv( transpose: bool A boolean specifying whether to transpose the sparse matrix before computing. - method: str - The method used to compute Matrix-Vector Multiplication. Default is ``taichi``. - The candidate methods are: - - - ``None``: default using Taichi kernel. - - ``cusparse``: using cuSPARSE library. - - ``scalar``: - - ``vector``: - - ``adaptive``: Returns:: diff --git a/brainpy/math/tifunc_test.py b/brainpy/math/tifunc_test.py deleted file mode 100644 index b442eb49c..000000000 --- a/brainpy/math/tifunc_test.py +++ /dev/null @@ -1,131 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import jax -import jax.numpy as jnp -import pytest - -pytestmark = pytest.mark.skip(reason="Skipped due to MacOS limitation, manual execution required for testing.") -import brainpy.math as bm -import matplotlib.pyplot as plt -import os - -bm.set_platform('cpu') - - -def test_taichi_random(): - @ti.kernel - def test_taichi_lfsr88(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), - out: ti.types.ndarray(ndim=1, dtype=ti.f32)): - key = bm.tifunc.lfsr88_key(seed[0]) - for i in range(out.shape[0]): - key, result = bm.tifunc.lfsr88_rand(key) - out[i] = result - - @ti.kernel - def test_taichi_lcg_rand(seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range(out.shape[0]): - out[i] = bm.tifunc.taichi_lcg_rand(seed) - - @ti.kernel - def test_taichi_uniform_int_distribution(seed: ti.types.ndarray(ndim=1), - low_high: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - low = low_high[0] - high = low_high[1] - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_randint(key, low, high) - - @ti.kernel - def test_taichi_uniform_real_distribution(seed: ti.types.ndarray(ndim=1), - low_high: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - low = low_high[0] - high = low_high[1] - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_uniform(key, low, high) - - @ti.kernel - def test_taichi_normal_distribution(seed: ti.types.ndarray(ndim=1), - mu_sigma: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - mu = mu_sigma[0] - sigma = mu_sigma[1] - - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_normal(key, mu, sigma) - - n = 100000 - seed = jnp.array([1234, ], dtype=jnp.uint32) - low_high = jnp.array([0, 10]) - mu_sigma = jnp.array([0, 1]) - - prim_lfsr88 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88, - gpu_kernel=test_taichi_lfsr88) - - prim_lcg_rand = bm.XLACustomOp(cpu_kernel=test_taichi_lcg_rand, - gpu_kernel=test_taichi_lcg_rand) - prim_uniform_int_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_int_distribution, - gpu_kernel=test_taichi_uniform_int_distribution) - prim_uniform_real_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_real_distribution, - gpu_kernel=test_taichi_uniform_real_distribution) - prim_normal_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_normal_distribution, - gpu_kernel=test_taichi_normal_distribution) - - file_path = os.path.dirname(os.path.abspath(__file__)) - - out = prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("LFSR88 random number generator") - plt.savefig(file_path + "/lfsr88.png") - plt.close() - - out = prim_lcg_rand(seed, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("LCG random number generator") - plt.savefig(file_path + "/lcg_rand.png") - plt.close() - - out = prim_uniform_int_distribution(seed, low_high, - outs=[jax.ShapeDtypeStruct((n,), jnp.int32)]) - # show the distribution of out - plt.hist(out, bins=10) - plt.title("Uniform int distribution (0, 10)") - plt.savefig(file_path + "/uniform_int_distribution.png") - plt.close() - - out = prim_uniform_real_distribution(seed, low_high, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("Uniform real distribution (0, 10)") - plt.savefig(file_path + "/uniform_real_distribution.png") - plt.close() - - out = prim_normal_distribution(seed, mu_sigma, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.title("Normal distribution mu=0, sigma=1") - plt.hist(out, bins=100) - plt.savefig(file_path + "/normal_distribution.png") - -# TODO; test default types