Skip to content

Commit 658ed00

Browse files
Fix let bindings and array api tests
1 parent 32c60c6 commit 658ed00

3 files changed

Lines changed: 22 additions & 8 deletions

File tree

python/egglog/egraph_state.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def span(frame_index: int = 0) -> bindings.RustSpan:
3939
return bindings.RustSpan("", 0, 0)
4040

4141

42+
def _normalize_global_let_name(name: str) -> str:
43+
return name if name.startswith("$") else f"${name}"
44+
45+
4246
@dataclass
4347
class EGraphState:
4448
"""
@@ -574,7 +578,7 @@ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912,
574578
res: bindings._Expr
575579
match expr_decl:
576580
case LetRefDecl(name):
577-
res = bindings.Var(span(), f"{name}")
581+
res = bindings.Var(span(), _normalize_global_let_name(name))
578582
case UnboundVarDecl(name, egg_name):
579583
res = bindings.Var(span(), egg_name or f"_{name}")
580584
case LitDecl(value):

python/tests/test_array_api.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -443,14 +443,13 @@ def test_polynomial_factoring(input: Value, expected: Value):
443443
egraph = EGraph()
444444
x = egraph.let("x", input)
445445
egraph.run(polynomial_schedule)
446-
# egraph.run(to_polynomial_ruleset.saturate())
447-
# egraph.display()
448-
# egraph.run(factor_ruleset.saturate())
449-
# egraph.display()
450-
# egraph.run(from_polynomial_ruleset.saturate())
451-
# egraph.display()
452446
equiv_expr = egraph.extract(x)
453-
assert eq(equiv_expr).to(expected), f"Expected {expected}, got {equiv_expr}"
447+
# Normalized them both so that we don't have to worry about term order.
448+
normalized = EGraph()
449+
extracted_ref = normalized.let("extracted", equiv_expr)
450+
expected_ref = normalized.let("expected", expected)
451+
normalized.run(to_polynomial_ruleset.saturate())
452+
normalized.check(eq(extracted_ref).to(expected_ref))
454453

455454

456455
# if calling as script, print out egglog source for test

python/tests/test_high_level.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@ def __mul__(self, other: Math) -> Math: ...
6060
egraph.check(eq(expr1).to(expr2))
6161

6262

63+
def test_let_auto_prefixes_global_names(capfd: pytest.CaptureFixture[str]):
64+
egraph = EGraph(save_egglog_string=True)
65+
66+
x = egraph.let("x", i64(1))
67+
egraph.check(eq(x).to(i64(1)))
68+
69+
captured = capfd.readouterr()
70+
assert "should start with `$`" not in captured.err
71+
assert "(let $x " in egraph.as_egglog_string
72+
73+
6374
def test_fib():
6475
egraph = EGraph()
6576

0 commit comments

Comments
 (0)