Skip to content

Commit 72d1411

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ff224c3 commit 72d1411

3 files changed

Lines changed: 18 additions & 11 deletions

File tree

python/egglog/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from . import config, ipython_magic # noqa: F401
66
from .bindings import EggSmolError, StageInfo, TimeOnly, WithPlan # noqa: F401
7-
from .builtins import * # noqa: UP029
7+
from .builtins import *
88
from .conversion import *
99
from .deconstruct import *
1010
from .egraph import *

python/egglog/exp/array_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2379,7 +2379,8 @@ def vecdot(x1: NDArrayLike, x2: NDArrayLike) -> NDArray:
23792379
x1.shape.drop_last(),
23802380
x1.dtype,
23812381
lambda idx: (
2382-
TupleInt.range(x1.shape.last())
2382+
TupleInt
2383+
.range(x1.shape.last())
23832384
.map_value(lambda i: x1.index(idx.append(i)) * x2.index((i,)))
23842385
.foldl_value(Value.__add__, Value.from_float(0))
23852386
),
@@ -2740,7 +2741,8 @@ def unravel_index(flat_index: IntLike, shape: TupleIntLike) -> TupleInt:
27402741
shape = cast("TupleInt", shape)
27412742

27422743
return (
2743-
shape.reverse()
2744+
shape
2745+
.reverse()
27442746
.foldl_tuple_int(
27452747
# Store remainder as last item in accumulator
27462748
lambda acc, dim: acc.drop_last().append((r := acc.last()) % dim).append(r // dim),

python/tests/test_array_api.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,14 @@ def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
263263
return NDArray.fn(
264264
outshape,
265265
X.dtype,
266-
lambda k: LoopNestAPI.from_tuple(reduce_axis)
267-
.unwrap()
268-
.indices()
269-
.foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0)
270-
.sqrt(),
266+
lambda k: (
267+
LoopNestAPI
268+
.from_tuple(reduce_axis)
269+
.unwrap()
270+
.indices()
271+
.foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0)
272+
.sqrt()
273+
),
271274
)
272275

273276

@@ -277,9 +280,11 @@ def linalg_norm_v2(X: NDArrayLike, axis: TupleIntLike) -> NDArray:
277280
return NDArray.fn(
278281
X.shape.deselect(axis),
279282
X.dtype,
280-
lambda k: ndindex(X.shape.select(axis))
281-
.foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0)
282-
.sqrt(),
283+
lambda k: (
284+
ndindex(X.shape.select(axis))
285+
.foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0)
286+
.sqrt()
287+
),
283288
)
284289

285290

0 commit comments

Comments
 (0)