Skip to content

Commit 8b4629e

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6d156db commit 8b4629e

9 files changed

Lines changed: 51 additions & 46 deletions

File tree

docs/explanation/2023_11_17_pytensor.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@
167167
")\n",
168168
"converter(int, IntTuple, lambda i: IntTuple(Int(i64(i))))\n",
169169
"converter(i64, IntTuple, lambda i: IntTuple(Int(i)))\n",
170-
"converter(Int, IntTuple, lambda i: IntTuple(i))\n",
170+
"converter(Int, IntTuple, IntTuple)\n",
171171
"\n",
172172
"\n",
173173
"@egraph.register\n",

python/egglog/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from . import config, ipython_magic # noqa: F401
66
from .bindings import EggSmolError, StageInfo, TimeOnly, WithPlan # noqa: F401
7-
from .builtins import * # noqa: UP029
7+
from .builtins import *
88
from .conversion import *
99
from .deconstruct import *
1010
from .egraph import *

python/egglog/builtins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ def bool_le(self, other: BigIntLike) -> Bool: ...
801801
def bool_ge(self, other: BigIntLike) -> Bool: ...
802802

803803

804-
converter(i64, BigInt, lambda i: BigInt(i))
804+
converter(i64, BigInt, BigInt)
805805

806806
BigIntLike: TypeAlias = BigInt | i64Like
807807

python/egglog/exp/array_api.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
297297
yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int"))
298298

299299

300-
converter(i64, Int, lambda x: Int(x))
300+
converter(i64, Int, Int)
301301

302302
IntLike: TypeAlias = Int | i64Like
303303

@@ -377,8 +377,8 @@ def __gt__(self, other: FloatLike) -> Boolean: ...
377377
def __ge__(self, other: FloatLike) -> Boolean: ...
378378

379379

380-
converter(float, Float, lambda x: Float(x))
381-
converter(Int, Float, lambda x: Float.from_int(x))
380+
converter(float, Float, Float)
381+
converter(Int, Float, Float.from_int)
382382

383383

384384
FloatLike: TypeAlias = Float | float | IntLike
@@ -521,7 +521,7 @@ def deselect(self, indices: TupleIntLike) -> TupleInt:
521521
return TupleInt.range(self.length()).filter(lambda i: ~indices.contains(i)).map(lambda i: self[i])
522522

523523

524-
converter(Vec[Int], TupleInt, lambda x: TupleInt.from_vec(x))
524+
converter(Vec[Int], TupleInt, TupleInt.from_vec)
525525

526526
TupleIntLike: TypeAlias = TupleInt | VecLike[Int, IntLike]
527527

@@ -649,7 +649,7 @@ def product(self) -> TupleTupleInt:
649649
)
650650

651651

652-
converter(Vec[TupleInt], TupleTupleInt, lambda x: TupleTupleInt.from_vec(x))
652+
converter(Vec[TupleInt], TupleTupleInt, TupleTupleInt.from_vec)
653653

654654
TupleTupleIntLike: TypeAlias = TupleTupleInt | VecLike[TupleInt, TupleIntLike]
655655

@@ -755,8 +755,8 @@ def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ...
755755
def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ...
756756

757757

758-
converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
759-
converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
758+
converter(DType, IsDtypeKind, IsDtypeKind.dtype)
759+
converter(str, IsDtypeKind, IsDtypeKind.string)
760760
converter(
761761
tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL
762762
)
@@ -922,8 +922,8 @@ def from_tuple_int(cls, ti: TupleIntLike) -> TupleValue:
922922
return TupleValue(ti.length(), lambda i: Value.int(ti[i]))
923923

924924

925-
converter(Vec[Value], TupleValue, lambda x: TupleValue.from_vec(x))
926-
converter(TupleInt, TupleValue, lambda x: TupleValue.from_tuple_int(x))
925+
converter(Vec[Value], TupleValue, TupleValue.from_vec)
926+
converter(TupleInt, TupleValue, TupleValue.from_tuple_int)
927927

928928
TupleValueLike: TypeAlias = TupleValue | VecLike[Value, ValueLike] | TupleIntLike
929929

@@ -1073,9 +1073,9 @@ def ndarray(cls, key: NDArray) -> IndexKey:
10731073

10741074

10751075
converter(type(...), IndexKey, lambda _: IndexKey.ELLIPSIS)
1076-
converter(Int, IndexKey, lambda i: IndexKey.int(i))
1077-
converter(Slice, IndexKey, lambda s: IndexKey.slice(s))
1078-
converter(MultiAxisIndexKey, IndexKey, lambda m: IndexKey.multi_axis(m))
1076+
converter(Int, IndexKey, IndexKey.int)
1077+
converter(Slice, IndexKey, IndexKey.slice)
1078+
converter(MultiAxisIndexKey, IndexKey, IndexKey.multi_axis)
10791079

10801080

10811081
class Device(Expr, ruleset=array_api_ruleset): ...
@@ -1232,13 +1232,13 @@ def if_(cls, b: BooleanLike, i: NDArrayLike, j: NDArrayLike) -> NDArray: ...
12321232

12331233
NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike
12341234

1235-
converter(NDArray, IndexKey, lambda v: IndexKey.ndarray(v))
1236-
converter(Value, NDArray, lambda v: NDArray.scalar(v))
1235+
converter(NDArray, IndexKey, IndexKey.ndarray)
1236+
converter(Value, NDArray, NDArray.scalar)
12371237
# Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
12381238
# to prefer upcasting in the other direction when we can, which is safer at runtime
12391239
converter(NDArray, Value, lambda n: n.to_value(), 100)
1240-
converter(TupleValue, NDArray, lambda v: NDArray.vector(v))
1241-
converter(TupleInt, TupleValue, lambda v: TupleValue.from_tuple_int(v))
1240+
converter(TupleValue, NDArray, NDArray.vector)
1241+
converter(TupleInt, TupleValue, TupleValue.from_tuple_int)
12421242

12431243

12441244
@array_api_ruleset.register
@@ -1322,7 +1322,7 @@ def eval(self) -> tuple[NDArray, ...]:
13221322
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
13231323

13241324

1325-
converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x))
1325+
converter(Vec[NDArray], TupleNDArray, TupleNDArray.from_vec)
13261326

13271327
TupleNDArrayLike: TypeAlias = TupleNDArray | VecLike[NDArray, NDArrayLike]
13281328

@@ -1371,7 +1371,7 @@ def some(cls, value: Boolean) -> OptionalBool: ...
13711371

13721372

13731373
converter(type(None), OptionalBool, lambda _: OptionalBool.none)
1374-
converter(Boolean, OptionalBool, lambda x: OptionalBool.some(x))
1374+
converter(Boolean, OptionalBool, OptionalBool.some)
13751375

13761376

13771377
class OptionalDType(Expr, ruleset=array_api_ruleset):
@@ -1382,7 +1382,7 @@ def some(cls, value: DType) -> OptionalDType: ...
13821382

13831383

13841384
converter(type(None), OptionalDType, lambda _: OptionalDType.none)
1385-
converter(DType, OptionalDType, lambda x: OptionalDType.some(x))
1385+
converter(DType, OptionalDType, OptionalDType.some)
13861386

13871387

13881388
class OptionalDevice(Expr, ruleset=array_api_ruleset):
@@ -1393,7 +1393,7 @@ def some(cls, value: Device) -> OptionalDevice: ...
13931393

13941394

13951395
converter(type(None), OptionalDevice, lambda _: OptionalDevice.none)
1396-
converter(Device, OptionalDevice, lambda x: OptionalDevice.some(x))
1396+
converter(Device, OptionalDevice, OptionalDevice.some)
13971397

13981398

13991399
class OptionalTupleInt(Expr, ruleset=array_api_ruleset):
@@ -1404,7 +1404,7 @@ def some(cls, value: TupleIntLike) -> OptionalTupleInt: ...
14041404

14051405

14061406
converter(type(None), OptionalTupleInt, lambda _: OptionalTupleInt.none)
1407-
converter(TupleInt, OptionalTupleInt, lambda x: OptionalTupleInt.some(x))
1407+
converter(TupleInt, OptionalTupleInt, OptionalTupleInt.some)
14081408

14091409

14101410
class IntOrTuple(Expr, ruleset=array_api_ruleset):
@@ -1417,8 +1417,8 @@ def int(cls, value: Int) -> IntOrTuple: ...
14171417
def tuple(cls, value: TupleIntLike) -> IntOrTuple: ...
14181418

14191419

1420-
converter(Int, IntOrTuple, lambda v: IntOrTuple.int(v))
1421-
converter(TupleInt, IntOrTuple, lambda v: IntOrTuple.tuple(v))
1420+
converter(Int, IntOrTuple, IntOrTuple.int)
1421+
converter(TupleInt, IntOrTuple, IntOrTuple.tuple)
14221422

14231423

14241424
class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset):
@@ -1429,7 +1429,7 @@ def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: ...
14291429

14301430

14311431
converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none)
1432-
converter(IntOrTuple, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.some(v))
1432+
converter(IntOrTuple, OptionalIntOrTuple, OptionalIntOrTuple.some)
14331433

14341434

14351435
@function

python/egglog/exp/array_api_loopnest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def shape_api_ruleset(dims: TupleInt, axis: TupleInt):
3131
ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: ~axis.contains(i)).map(lambda i: dims[i]))
3232
)
3333
yield rewrite(s.select(axis), subsume=True).to(
34-
ShapeAPI(TupleInt.range(dims.length()).filter(lambda i: axis.contains(i)).map(lambda i: dims[i]))
34+
ShapeAPI(TupleInt.range(dims.length()).filter(axis.contains).map(lambda i: dims[i]))
3535
)
3636
yield rewrite(s.to_tuple(), subsume=True).to(dims)
3737

python/tests/test_array_api.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -210,11 +210,14 @@ def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
210210
return NDArray(
211211
outshape,
212212
X.dtype,
213-
lambda k: LoopNestAPI.from_tuple(reduce_axis)
214-
.unwrap()
215-
.indices()
216-
.foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0)
217-
.sqrt(),
213+
lambda k: (
214+
LoopNestAPI
215+
.from_tuple(reduce_axis)
216+
.unwrap()
217+
.indices()
218+
.foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0)
219+
.sqrt()
220+
),
218221
)
219222

220223

@@ -224,9 +227,11 @@ def linalg_norm_v2(X: NDArrayLike, axis: TupleIntLike) -> NDArray:
224227
return NDArray(
225228
X.shape.deselect(axis),
226229
X.dtype,
227-
lambda k: ndindex(X.shape.select(axis))
228-
.foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0)
229-
.sqrt(),
230+
lambda k: (
231+
ndindex(X.shape.select(axis))
232+
.foldl_value(lambda carry, i: carry + ((x := X.index(i + k)).conj() * x).real(), init=0.0)
233+
.sqrt()
234+
),
230235
)
231236

232237

python/tests/test_convert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_convert_to_generic():
9595
class G(BuiltinExpr, Generic[T]):
9696
def __init__(self, x: T) -> None: ...
9797

98-
converter(i64, G[i64], lambda x: G(x))
98+
converter(i64, G[i64], G)
9999
assert expr_parts(convert(10, G[i64])) == expr_parts(G(i64(10)))
100100

101101
with pytest.raises(ConvertError):
@@ -114,7 +114,7 @@ def test_convert_to_unbound_generic():
114114
class G(BuiltinExpr, Generic[T]):
115115
def __init__(self, x: i64) -> None: ...
116116

117-
converter(i64, G, lambda x: G[get_type_args()[0]](x)) # type: ignore[misc, operator]
117+
converter(i64, G, G[get_type_args()[0]]) # type: ignore[misc, operator]
118118
assert expr_parts(convert(10, G[String])) == expr_parts(G[String](i64(10)))
119119

120120

python/tests/test_high_level.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -535,14 +535,14 @@ def _global_make_tuple(x):
535535

536536

537537
def test_eval_fn_globals():
538-
assert EGraph().extract(PyObject(lambda x: _global_make_tuple(x))(PyObject.from_int(1))).value == (1,)
538+
assert EGraph().extract(PyObject(_global_make_tuple)(PyObject.from_int(1))).value == (1,)
539539

540540

541541
def test_eval_fn_locals():
542542
def _locals_make_tuple(x):
543543
return (x,)
544544

545-
assert EGraph().extract(PyObject(lambda x: _locals_make_tuple(x))(PyObject.from_int(1))).value == (1,)
545+
assert EGraph().extract(PyObject(_locals_make_tuple)(PyObject.from_int(1))).value == (1,)
546546

547547

548548
def test_lazy_types():
@@ -1459,9 +1459,9 @@ def __contains__(self, item: int) -> bool:
14591459
pytest.param(lambda: int(m), 1000, id="int"),
14601460
pytest.param(lambda: float(m), 100.0, id="float"),
14611461
pytest.param(lambda: complex(m), 1 + 0j, id="complex"),
1462-
pytest.param(lambda: m.__index__(), 20, id="index"),
1462+
pytest.param(m.__index__, 20, id="index"),
14631463
pytest.param(lambda: len(m), 10, id="len"),
1464-
pytest.param(lambda: m.__length_hint__(), 5, id="length_hint"),
1464+
pytest.param(m.__length_hint__, 5, id="length_hint"),
14651465
pytest.param(lambda: list(m), [1], id="iter"),
14661466
pytest.param(lambda: list(reversed(m)), [10], id="reversed"),
14671467
pytest.param(lambda: 1 in m, True, id="contains"),

python/tests/test_unstable_fn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def apply_f(f: Callable[[A], A], x: A) -> A:
253253

254254
@r.register
255255
def _rewrite(a: A):
256-
yield rewrite(transform_a(a)).to(apply_f(lambda x: my_transform_a(x), a))
256+
yield rewrite(transform_a(a)).to(apply_f(my_transform_a, a))
257257

258258
assert check_eq(transform_a(A()), my_transform_a(A()), r * 10)
259259

@@ -276,7 +276,7 @@ def apply_f(f: Callable[[A], A], x: A) -> A:
276276

277277
@ruleset
278278
def my_ruleset(a: A):
279-
yield rewrite(transform_a(a)).to(apply_f(lambda x: my_transform_a(x), a))
279+
yield rewrite(transform_a(a)).to(apply_f(my_transform_a, a))
280280

281281
assert check_eq(transform_a(A()), my_transform_a(A()), (my_ruleset | apply_ruleset) * 10)
282282

@@ -296,7 +296,7 @@ def apply_f(f: Callable[[A], A], x: A) -> A:
296296

297297
@function(ruleset=r)
298298
def transform_a(a: A) -> A:
299-
return apply_f(lambda x: my_transform_a(x), a)
299+
return apply_f(my_transform_a, a)
300300

301301
assert check_eq(transform_a(A()), my_transform_a(A()), r * 10)
302302

@@ -325,7 +325,7 @@ def higher_order(f: Callable[[A], A]) -> A: ...
325325
@function
326326
def transform_a(a: A) -> A: ...
327327

328-
v = higher_order(lambda a: transform_a(a))
328+
v = higher_order(transform_a)
329329
assert str(v) == "higher_order(lambda a: transform_a(a))"
330330

331331
def test_multiple_same(self):

0 commit comments

Comments
 (0)