Skip to content

Commit 32c60c6

Browse files
Fix notebook bug
1 parent c7711b5 commit 32c60c6

3 files changed

Lines changed: 23 additions & 0 deletions

File tree

python/egglog/egraph.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,7 @@ def _add_default_rewrite(
809809
resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
810810
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr, subsume)
811811
ruleset_decls = _add_default_rewrite_inner(decls, rewrite_decl, ruleset)
812+
ruleset_decls |= decls
812813
ruleset_decls |= resolved_value
813814

814815

python/tests/test_array_api.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,15 @@ def lda(X: NDArray, y: NDArray):
373373
return run_lda(X, y)
374374

375375

376+
def test_lda_symbolic_build_cold_schedule():
377+
X_arr = NDArray.var("X")
378+
y_arr = NDArray.var("y")
379+
egraph = EGraph()
380+
with set_array_api_egraph(egraph):
381+
res = lda(X_arr, y_arr)
382+
assert isinstance(res, NDArray)
383+
384+
376385
@pytest.mark.parametrize(
377386
"program",
378387
[

python/tests/test_high_level.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,19 @@ def f() -> A:
658658

659659
check_eq(f(), A(), r)
660660

661+
def test_function_ruleset_can_run_after_materialization_without_registration(self):
662+
r = ruleset()
663+
664+
@function(ruleset=r)
665+
def f() -> A:
666+
return A()
667+
668+
# Materialize the function once so its default rewrite is added to the ruleset,
669+
# but do not register any expression that would separately add `f` to the egraph.
670+
f()
671+
egraph = EGraph()
672+
assert not egraph.run(r).updated
673+
661674
def test_constant(self):
662675
a = constant("a", A, A())
663676
check_eq(a, A(), run())

0 commit comments

Comments
 (0)