Skip to content

Commit 6cd79d7

Browse files
tmp
1 parent f722906 commit 6cd79d7

7 files changed

Lines changed: 2262 additions & 0 deletions

File tree

Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
# # 2026-03 - Replicating `srtree-eqsat` in Egglog
2+
#
3+
# This note recreates the `srtree-eqsat` simplification pipeline from
4+
# de Franca and Kronberger (2023) inside Egglog, then tests a multiset-based
5+
# alternative for the A/C-heavy parts of the rewrite system.
6+
#
7+
# The notebook is self-contained:
8+
# - it runs the Egglog implementation live
9+
# - it embeds the Haskell reference numbers in the Python module
10+
# - it does not shell out to `/Users/saul/p/srtree-eqsat`
11+
#
12+
# Haskell reference numbers were collected offline with:
13+
#
14+
# ```bash
15+
# cd /Users/saul/p/srtree-eqsat
16+
# stack exec -- runghc /Users/saul/p/egg-smol-python/python/exp/srtree_eqsat/haskell_compare.hs 1 50
17+
# ```
18+
#
19+
# Egglog reproduction can be rerun with:
20+
#
21+
# ```bash
22+
# cd /Users/saul/p/egg-smol-python
23+
# uv run --project /Users/saul/p/egg-smol-python python /Users/saul/p/egg-smol-python/docs/explanation/2026_03_srtree_eqsat_replication.py
24+
# ```
25+
#
26+
# The current pass stays on two `test/example_hl` rows:
27+
# - row 1: small sanity case
28+
# - row 50: a function-heavy representative case
29+
#
30+
# The full `657`-row expansion is intentionally deferred because the multiset
31+
# lowering stage is still the dominant blow-up point on the representative case.
32+
33+
# +
34+
from __future__ import annotations
35+
36+
from textwrap import shorten
37+
38+
from egglog.exp.srtree_eqsat import (
39+
HASKELL_REFERENCE_ROWS,
40+
core_examples,
41+
parse_hl_expr,
42+
run_baseline_pipeline,
43+
run_multiset_pipeline,
44+
)
45+
46+
47+
def md_table(rows: list[dict[str, str]]) -> str:
    """Render a list of uniform string dicts as a GitHub-style Markdown table.

    Every row must have the same keys as the first row; column order follows
    the first row's key order. Each column is padded to the widest cell
    (or the header, if that is wider).

    Returns an empty string for an empty ``rows`` list so callers can print
    the result unconditionally.
    """
    # Guard: the original indexed rows[0] directly and raised IndexError on [].
    if not rows:
        return ""
    headers = list(rows[0])
    widths = {header: max(len(header), *(len(row[header]) for row in rows)) for header in headers}
    header = "| " + " | ".join(header.ljust(widths[header]) for header in headers) + " |"
    separator = "| " + " | ".join("-" * widths[header] for header in headers) + " |"
    body = "\n".join("| " + " | ".join(row[header].ljust(widths[header]) for header in headers) + " |" for row in rows)
    return f"{header}\n{separator}\n{body}"
54+
55+
56+
def fmt_float(value: float, digits: int = 4) -> str:
    """Format *value* with a fixed number of decimal places (default 4)."""
    return format(value, f".{digits}f")
58+
59+
60+
def fmt_optional_size(value: int) -> str:
    """Render a size count, mapping the negative "unavailable" sentinel to "na"."""
    if value < 0:
        return "na"
    return str(value)
62+
63+
64+
def compact_rule(rule: str) -> str:
    """Collapse a verbose rewrite-rule name into a short table label.

    Known rule families are matched by substring; anything unrecognized is
    flattened to one line and truncated to 64 characters.
    """
    # Ordered (needles, label) pairs: the first entry whose needles all occur
    # in the rule wins, mirroring the priority of the original if-chain.
    labelled_patterns = (
        (("srtree_eqsat_multiset_lower", "Expr___mul__", "multiset-sum"), "product flattening"),
        (("srtree_eqsat_multiset_lower", "Expr___add__", "multiset-sum"), "sum flattening"),
        (("srtree_eqsat_multiset_reify", "Expr___mul__"), "product reify"),
        (("srtree_eqsat_multiset_reify", "Expr___add__"), "sum reify"),
        (("srtree_eqsat_const_analysis", "OptionalF64_some", "union"), "const union"),
        (("srtree_eqsat_const_analysis", "OptionalF64_some", "set"), "const set"),
    )
    for needles, label in labelled_patterns:
        if all(needle in rule for needle in needles):
            return label
    return shorten(rule.replace("\n", " "), width=64, placeholder="...")
78+
79+
80+
# Run both pipelines once up front so every later section renders its tables
# from the same cached reports instead of re-running saturation.
examples = core_examples()
# Baseline replication of the Haskell pipeline on each example, with a
# generous node cutoff and a 12-iteration budget.
baseline_reports = {
    example.name: run_baseline_pipeline(
        example.expr,
        node_cutoff=50_000,
        iteration_limit=12,
        input_names=example.input_names,
        sample_points=example.sample_points,
    )
    for example in examples
}
# Multiset alternative, deliberately bounded: saturation without limits is
# disabled and only 2 iterations are allowed, because the lowering stage
# blows up on the representative case (see section 4).
multiset_reports = {
    example.name: run_multiset_pipeline(
        example.expr,
        saturate_without_limits=False,
        node_cutoff=50_000,
        iteration_limit=2,
        input_names=example.input_names,
        sample_points=example.sample_points,
    )
    for example in examples
}
102+
# -
103+
104+
# ## 1. Selected Examples and Parsing
105+
#
106+
# The source repo stores `example_hl` expressions in a small Python-like syntax.
107+
# The Egglog replication parses those expressions with a restricted `eval`
108+
# environment that binds:
109+
# - `alpha`, `beta`, `theta` to Egglog variables
110+
# - arithmetic through Python operator overloads
111+
# - `sqr`, `cube`, `cbrt`, `exp`, `log`, `sqrt`, and `abs`
112+
#
113+
# The notebook embeds the two selected rows directly so it can be rerun without
114+
# the source checkout.
115+
116+
# +
117+
# Parse each embedded source string and check that it round-trips to the
# stored expression; the "parsed?" column reports the comparison result.
example_rows = []
for example in examples:
    parsed = parse_hl_expr(example.source)
    example_rows.append({
        "name": example.name,
        "row": str(example.row),
        "description": example.description,
        "source": example.source,
        # BUG FIX: both branches previously said "yes", so a parse mismatch
        # could never show up in the table.
        "parsed?": "yes" if parsed == example.expr else "no",
    })

print(md_table(example_rows))
129+
# -
130+
131+
# ## 2. Baseline Egglog Replication
132+
#
133+
# The baseline reproduces the Haskell pipeline shape:
134+
# - `rewriteConst = constReduction` with backoff `(100, 10)`
135+
# - `rewriteAll = rewritesBasic + constReduction + constFusion + rewritesFun`
136+
# with backoff `(2500, 30)`
137+
# - run the const pass once
138+
# - then run the full pass, extract, rebuild, and repeat up to two times
139+
#
140+
# Egglog also adds one user-level guard that the Haskell public API does not
141+
# expose: after every iteration we check total function size with
142+
# `sum(size for _, size in egraph.all_function_sizes())` so we can report
143+
# whether a run saturated, hit the user cutoff, or simply ran out of budget.
144+
145+
# +
146+
def _baseline_row(ex) -> dict[str, str]:
    """One summary-table row for the baseline report of example *ex*."""
    rep = baseline_reports[ex.name]
    met = rep.metric_report
    return {
        "example": ex.name,
        "stop": rep.stop_reason,
        "runtime_s": fmt_float(rep.total_sec),
        "total_size": str(rep.total_size),
        "nodes": str(rep.node_count),
        "eclasses": str(rep.eclass_count),
        "cost": str(rep.cost),
        "params": f"{met.before_parameter_count} -> {met.after_parameter_count}",
        "reduction": fmt_float(met.reduction_ratio),
        "optimal_gap": str(met.jacobian_rank_gap),
        "max_err": f"{rep.numeric_max_abs_error:.2e}",
    }


baseline_rows = [_baseline_row(ex) for ex in examples]
print(md_table(baseline_rows))
165+
# -
166+
167+
# +
168+
# Show the extracted Python of each baseline run as a fenced code block.
for ex in examples:
    extracted = baseline_reports[ex.name].python_source
    print(f"### {ex.name} baseline extracted Python")
    # One print emits the fence, the source, the closing fence, and the
    # trailing blank line exactly as the separate prints did.
    print(f"```python\n{extracted}\n```\n")
175+
# -
176+
177+
# ## 3. Haskell Comparison
178+
#
179+
# The Haskell numbers below come from the exported `simplifyEqSat` API in
180+
# `/Users/saul/p/srtree-eqsat`.
181+
#
182+
# One important limitation of the source comparison path:
183+
# - the public API returns only the simplified expression
184+
# - it does not expose the final e-graph size or the internal stop reason
185+
# - when I tried to copy the old intermediate graph bookkeeping directly,
186+
# forcing those graph internals on row 50 crashed in the pinned Haskell stack
187+
#
188+
# So the comparison is exact on:
189+
# - runtime
190+
# - input/output parameter counts
191+
# - input/output expression tree sizes
192+
# - final extracted expression
193+
#
194+
# and unavailable on:
195+
# - final memo size
196+
# - final e-class count
197+
# - internal stop reason
198+
199+
# +
200+
# Side-by-side Egglog vs. Haskell comparison on the fields both sides expose.
comparison_rows = []
for ex in examples:
    base = baseline_reports[ex.name]
    ref = HASKELL_REFERENCE_ROWS[ex.row]
    comparison_rows.append(
        {
            "example": ex.name,
            "egglog_params": f"{base.metric_report.before_parameter_count} -> {base.metric_report.after_parameter_count}",
            "haskell_params": f"{ref.before_parameter_count} -> {ref.after_parameter_count}",
            "egglog_egraph_nodes": f"{base.node_count}",
            "haskell_tree_nodes": f"{ref.after_node_count}",
            "egglog_runtime_s": fmt_float(base.total_sec),
            "haskell_runtime_s": fmt_float(ref.runtime_sec),
            # memo size is unavailable through the Haskell public API,
            # hence the optional formatting.
            "haskell_memo": fmt_optional_size(ref.memo_size),
        }
    )

print(md_table(comparison_rows))
216+
# -
217+
218+
# ## The directly comparable results are easier to state in explicit prose:
219+
#
220+
# - Row 1 matches on the paper-aligned metric: both Haskell and Egglog stay at
221+
# `2 -> 2` parameters.
222+
# - Row 50 differs by one parameter: Haskell reaches `14 -> 12`, while the
223+
# current Egglog baseline reaches `14 -> 13`.
224+
# - The likely causes are:
225+
# - the Egglog replication currently has weaker nonlinear constant analysis
226+
# than the Haskell `Analysis (Maybe Double)` path
227+
# - extraction tie-breaks differ between the two implementations
228+
# - the Haskell public API hides the intermediate graph state that would make
229+
# debugging the exact divergence easier
230+
231+
# +
232+
# Reference output: the Haskell-simplified Python for each selected row.
for ex in examples:
    ref = HASKELL_REFERENCE_ROWS[ex.row]
    print(f"### {ex.name} Haskell extracted Python")
    # Single print reproduces fence + source + fence + blank line verbatim.
    print(f"```python\n{ref.simplified_python}\n```\n")
239+
# -
240+
241+
# ## 4. Multiset Hypothesis
242+
#
243+
# The multiset path replaces binary A/C structure with:
244+
# - `sum_(MultiSet[Expr])`
245+
# - `product_(MultiSet[Expr])`
246+
#
247+
# The current implementation runs in three fresh-egraph stages:
248+
# 1. lower binary additive and multiplicative islands into multiset form
249+
# 2. simplify in multiset form
250+
# 3. reify multisets back to binary form and run the cleanup rules
251+
#
252+
# For this first pass I kept the run reproducible and bounded:
253+
# - no backoff scheduler in the multiset stages
254+
# - explicit node cutoff and small iteration budget
255+
#
256+
# I also tried the unrestricted version on the same examples. That was not
257+
# practical to carry through on the representative case, which already shows
258+
# that the current multiset lowering does not remove the need for safety guards.
259+
260+
# +
261+
def _multiset_row(ex) -> dict[str, str]:
    """One summary-table row for the multiset-pipeline report of *ex*."""
    rep = multiset_reports[ex.name]
    met = rep.metric_report
    return {
        "example": ex.name,
        "stop": rep.stop_reason,
        "runtime_s": fmt_float(rep.total_sec),
        "total_size": str(rep.total_size),
        "nodes": str(rep.node_count),
        "eclasses": str(rep.eclass_count),
        "cost": str(rep.cost),
        "params": f"{met.before_parameter_count} -> {met.after_parameter_count}",
        "reduction": fmt_float(met.reduction_ratio),
        "optimal_gap": str(met.jacobian_rank_gap),
        "max_err": f"{rep.numeric_max_abs_error:.2e}",
    }


multiset_rows = [_multiset_row(ex) for ex in examples]
print(md_table(multiset_rows))
280+
# -
281+
282+
# +
283+
# Per-stage breakdown of the multiset pipeline, with the three hottest
# rules (by match count) shown in compacted form.
for ex in examples:
    print(f"### {ex.name} multiset stage summary")
    rows = []
    for stage in multiset_reports[ex.name].stages:
        # Stable sort keeps the original tie order from matches_per_rule.
        ranked = sorted(stage.matches_per_rule.items(), key=lambda kv: kv[1], reverse=True)
        parts = [f"{compact_rule(rule)}={count}" for rule, count in ranked[:3]]
        rows.append({
            "stage": stage.name,
            "stop": stage.stop_reason,
            "size": str(stage.total_size),
            "nodes": str(stage.node_count),
            "eclasses": str(stage.eclass_count),
            "hot_rules": ", ".join(parts) if parts else "none",
        })
    print(md_table(rows))
    print()
302+
# -
303+
304+
# ## 5. Conclusions
305+
#
306+
# Baseline replication:
307+
# - Row 1 is a clean sanity match on the paper metric: both paths stay at
308+
# `2 -> 2` parameters with zero numerical error.
309+
# - Row 50 is close but not identical: Egglog reaches `14 -> 13`, while the
310+
# Haskell source reaches `14 -> 12`.
311+
# - The Egglog baseline does reproduce the overall style of the paper's
312+
# simplifier: parameter reduction, cost-aware extraction, and zero numerical
313+
# drift on sampled points.
314+
#
315+
# Multiset hypothesis:
316+
# - The multiset pipeline did shrink the final bounded e-graph footprint:
317+
# - row 1: `25 -> 19` total size versus the baseline
318+
# - row 50: `178 -> 120` total size versus the baseline
319+
# - But it did not yet improve the actual simplification result:
320+
# - row 1 stays `2 -> 2`
321+
# - row 50 regresses from `14 -> 13` in the baseline to `14 -> 14`
322+
# - The dominant remaining blow-up comes from the lowering and reification
323+
# stages, especially the multiplicative flattening rules:
324+
# - `product flattening` dominates the row 50 lowering stage
325+
# - `product reify` dominates the row 50 cleanup stage
326+
#
327+
# So the current answer is:
328+
# - multisets do not yet let this pipeline run safely to saturation without
329+
# limits on the representative case
330+
# - the blow-up moved from binary A/C search into multiset flatten/reify churn
331+
# - the next useful step is to redesign the multiset lowering and reification
332+
# rules, not to widen the runtime budget on the current version

0 commit comments

Comments
 (0)