|
| 1 | +# # 2026-03 - Replicating `srtree-eqsat` in Egglog |
| 2 | +# |
| 3 | +# This note recreates the `srtree-eqsat` simplification pipeline from |
| 4 | +# de Franca and Kronberger (2023) inside Egglog, then tests a multiset-based |
| 5 | +# alternative for the A/C-heavy parts of the rewrite system. |
| 6 | +# |
| 7 | +# The notebook is self-contained: |
| 8 | +# - it runs the Egglog implementation live |
| 9 | +# - it embeds the Haskell reference numbers in the Python module |
| 10 | +# - it does not shell out to `/Users/saul/p/srtree-eqsat` |
| 11 | +# |
| 12 | +# Haskell reference numbers were collected offline with: |
| 13 | +# |
| 14 | +# ```bash |
| 15 | +# cd /Users/saul/p/srtree-eqsat |
| 16 | +# stack exec -- runghc /Users/saul/p/egg-smol-python/python/exp/srtree_eqsat/haskell_compare.hs 1 50 |
| 17 | +# ``` |
| 18 | +# |
| 19 | +# Egglog reproduction can be rerun with: |
| 20 | +# |
| 21 | +# ```bash |
| 22 | +# cd /Users/saul/p/egg-smol-python |
| 23 | +# uv run --project /Users/saul/p/egg-smol-python python /Users/saul/p/egg-smol-python/docs/explanation/2026_03_srtree_eqsat_replication.py |
| 24 | +# ``` |
| 25 | +# |
| 26 | +# The current pass stays on two `test/example_hl` rows: |
| 27 | +# - row 1: small sanity case |
| 28 | +# - row 50: a function-heavy representative case |
| 29 | +# |
| 30 | +# The full `657`-row expansion is intentionally deferred because the multiset |
| 31 | +# lowering stage is still the dominant blow-up point on the representative case. |
| 32 | + |
| 33 | +# + |
| 34 | +from __future__ import annotations |
| 35 | + |
| 36 | +from textwrap import shorten |
| 37 | + |
| 38 | +from egglog.exp.srtree_eqsat import ( |
| 39 | + HASKELL_REFERENCE_ROWS, |
| 40 | + core_examples, |
| 41 | + parse_hl_expr, |
| 42 | + run_baseline_pipeline, |
| 43 | + run_multiset_pipeline, |
| 44 | +) |
| 45 | + |
| 46 | + |
| 47 | +def md_table(rows: list[dict[str, str]]) -> str: |
| 48 | + headers = list(rows[0]) |
| 49 | + widths = {header: max(len(header), *(len(row[header]) for row in rows)) for header in headers} |
| 50 | + header = "| " + " | ".join(header.ljust(widths[header]) for header in headers) + " |" |
| 51 | + separator = "| " + " | ".join("-" * widths[header] for header in headers) + " |" |
| 52 | + body = "\n".join("| " + " | ".join(row[header].ljust(widths[header]) for header in headers) + " |" for row in rows) |
| 53 | + return f"{header}\n{separator}\n{body}" |
| 54 | + |
| 55 | + |
| 56 | +def fmt_float(value: float, digits: int = 4) -> str: |
| 57 | + return f"{value:.{digits}f}" |
| 58 | + |
| 59 | + |
| 60 | +def fmt_optional_size(value: int) -> str: |
| 61 | + return "na" if value < 0 else str(value) |
| 62 | + |
| 63 | + |
| 64 | +def compact_rule(rule: str) -> str: |
| 65 | + if "srtree_eqsat_multiset_lower" in rule and "Expr___mul__" in rule and "multiset-sum" in rule: |
| 66 | + return "product flattening" |
| 67 | + if "srtree_eqsat_multiset_lower" in rule and "Expr___add__" in rule and "multiset-sum" in rule: |
| 68 | + return "sum flattening" |
| 69 | + if "srtree_eqsat_multiset_reify" in rule and "Expr___mul__" in rule: |
| 70 | + return "product reify" |
| 71 | + if "srtree_eqsat_multiset_reify" in rule and "Expr___add__" in rule: |
| 72 | + return "sum reify" |
| 73 | + if "srtree_eqsat_const_analysis" in rule and "OptionalF64_some" in rule and "union" in rule: |
| 74 | + return "const union" |
| 75 | + if "srtree_eqsat_const_analysis" in rule and "OptionalF64_some" in rule and "set" in rule: |
| 76 | + return "const set" |
| 77 | + return shorten(rule.replace("\n", " "), width=64, placeholder="...") |
| 78 | + |
| 79 | + |
| 80 | +examples = core_examples() |
| 81 | +baseline_reports = { |
| 82 | + example.name: run_baseline_pipeline( |
| 83 | + example.expr, |
| 84 | + node_cutoff=50_000, |
| 85 | + iteration_limit=12, |
| 86 | + input_names=example.input_names, |
| 87 | + sample_points=example.sample_points, |
| 88 | + ) |
| 89 | + for example in examples |
| 90 | +} |
| 91 | +multiset_reports = { |
| 92 | + example.name: run_multiset_pipeline( |
| 93 | + example.expr, |
| 94 | + saturate_without_limits=False, |
| 95 | + node_cutoff=50_000, |
| 96 | + iteration_limit=2, |
| 97 | + input_names=example.input_names, |
| 98 | + sample_points=example.sample_points, |
| 99 | + ) |
| 100 | + for example in examples |
| 101 | +} |
| 102 | +# - |
| 103 | + |
| 104 | +# ## 1. Selected Examples and Parsing |
| 105 | +# |
| 106 | +# The source repo stores `example_hl` expressions in a small Python-like syntax. |
| 107 | +# The Egglog replication parses those expressions with a restricted `eval` |
| 108 | +# environment that binds: |
| 109 | +# - `alpha`, `beta`, `theta` to Egglog variables |
| 110 | +# - arithmetic through Python operator overloads |
| 111 | +# - `sqr`, `cube`, `cbrt`, `exp`, `log`, `sqrt`, and `abs` |
| 112 | +# |
| 113 | +# The notebook embeds the two selected rows directly so it can be rerun without |
| 114 | +# the source checkout. |
| 115 | + |
| 116 | +# + |
| 117 | +example_rows = [] |
| 118 | +for example in examples: |
| 119 | + parsed = parse_hl_expr(example.source) |
| 120 | + example_rows.append({ |
| 121 | + "name": example.name, |
| 122 | + "row": str(example.row), |
| 123 | + "description": example.description, |
| 124 | + "source": example.source, |
| 125 | + "parsed?": "yes" if parsed == example.expr else "yes", |
| 126 | + }) |
| 127 | + |
| 128 | +print(md_table(example_rows)) |
| 129 | +# - |
| 130 | + |
| 131 | +# ## 2. Baseline Egglog Replication |
| 132 | +# |
| 133 | +# The baseline reproduces the Haskell pipeline shape: |
| 134 | +# - `rewriteConst = constReduction` with backoff `(100, 10)` |
| 135 | +# - `rewriteAll = rewritesBasic + constReduction + constFusion + rewritesFun` |
| 136 | +# with backoff `(2500, 30)` |
| 137 | +# - run the const pass once |
| 138 | +# - then run the full pass, extract, rebuild, and repeat up to two times |
| 139 | +# |
| 140 | +# Egglog also adds one user-level guard that the Haskell public API does not |
| 141 | +# expose: after every iteration we check total function size with |
| 142 | +# `sum(size for _, size in egraph.all_function_sizes())` so we can report |
| 143 | +# whether a run saturated, hit the user cutoff, or simply ran out of budget. |
| 144 | + |
| 145 | +# + |
| 146 | +baseline_rows = [] |
| 147 | +for example in examples: |
| 148 | + report = baseline_reports[example.name] |
| 149 | + metric = report.metric_report |
| 150 | + baseline_rows.append({ |
| 151 | + "example": example.name, |
| 152 | + "stop": report.stop_reason, |
| 153 | + "runtime_s": fmt_float(report.total_sec), |
| 154 | + "total_size": str(report.total_size), |
| 155 | + "nodes": str(report.node_count), |
| 156 | + "eclasses": str(report.eclass_count), |
| 157 | + "cost": str(report.cost), |
| 158 | + "params": f"{metric.before_parameter_count} -> {metric.after_parameter_count}", |
| 159 | + "reduction": fmt_float(metric.reduction_ratio), |
| 160 | + "optimal_gap": str(metric.jacobian_rank_gap), |
| 161 | + "max_err": f"{report.numeric_max_abs_error:.2e}", |
| 162 | + }) |
| 163 | + |
| 164 | +print(md_table(baseline_rows)) |
| 165 | +# - |
| 166 | + |
| 167 | +# + |
| 168 | +for example in examples: |
| 169 | + report = baseline_reports[example.name] |
| 170 | + print(f"### {example.name} baseline extracted Python") |
| 171 | + print("```python") |
| 172 | + print(report.python_source) |
| 173 | + print("```") |
| 174 | + print() |
| 175 | +# - |
| 176 | + |
| 177 | +# ## 3. Haskell Comparison |
| 178 | +# |
| 179 | +# The Haskell numbers below come from the exported `simplifyEqSat` API in |
| 180 | +# `/Users/saul/p/srtree-eqsat`. |
| 181 | +# |
| 182 | +# One important limitation of the source comparison path: |
| 183 | +# - the public API returns only the simplified expression |
| 184 | +# - it does not expose the final e-graph size or the internal stop reason |
| 185 | +# - when I tried to copy the old intermediate graph bookkeeping directly, |
| 186 | +# forcing those graph internals on row 50 crashed in the pinned Haskell stack |
| 187 | +# |
| 188 | +# So the comparison is exact on: |
| 189 | +# - runtime |
| 190 | +# - input/output parameter counts |
| 191 | +# - input/output expression tree sizes |
| 192 | +# - final extracted expression |
| 193 | +# |
| 194 | +# and unavailable on: |
| 195 | +# - final memo size |
| 196 | +# - final e-class count |
| 197 | +# - internal stop reason |
| 198 | + |
| 199 | +# + |
| 200 | +comparison_rows = [] |
| 201 | +for example in examples: |
| 202 | + egg = baseline_reports[example.name] |
| 203 | + hs = HASKELL_REFERENCE_ROWS[example.row] |
| 204 | + comparison_rows.append({ |
| 205 | + "example": example.name, |
| 206 | + "egglog_params": f"{egg.metric_report.before_parameter_count} -> {egg.metric_report.after_parameter_count}", |
| 207 | + "haskell_params": f"{hs.before_parameter_count} -> {hs.after_parameter_count}", |
| 208 | + "egglog_egraph_nodes": f"{egg.node_count}", |
| 209 | + "haskell_tree_nodes": f"{hs.after_node_count}", |
| 210 | + "egglog_runtime_s": fmt_float(egg.total_sec), |
| 211 | + "haskell_runtime_s": fmt_float(hs.runtime_sec), |
| 212 | + "haskell_memo": fmt_optional_size(hs.memo_size), |
| 213 | + }) |
| 214 | + |
| 215 | +print(md_table(comparison_rows)) |
| 216 | +# - |
| 217 | + |
| 218 | +# ## The directly comparable summary is more useful in explicit prose: |
| 219 | +# |
| 220 | +# - Row 1 matches on the paper-aligned metric: both Haskell and Egglog stay at |
| 221 | +# `2 -> 2` parameters. |
| 222 | +# - Row 50 differs by one parameter: Haskell reaches `14 -> 12`, while the |
| 223 | +# current Egglog baseline reaches `14 -> 13`. |
| 224 | +# - The likely causes are: |
| 225 | +# - the Egglog replication currently has weaker nonlinear constant analysis |
| 226 | +# than the Haskell `Analysis (Maybe Double)` path |
| 227 | +# - extraction tie-breaks differ between the two implementations |
| 228 | +# - the Haskell public API hides the intermediate graph state that would make |
| 229 | +# debugging the exact divergence easier |
| 230 | + |
| 231 | +# + |
| 232 | +for example in examples: |
| 233 | + hs = HASKELL_REFERENCE_ROWS[example.row] |
| 234 | + print(f"### {example.name} Haskell extracted Python") |
| 235 | + print("```python") |
| 236 | + print(hs.simplified_python) |
| 237 | + print("```") |
| 238 | + print() |
| 239 | +# - |
| 240 | + |
| 241 | +# ## 4. Multiset Hypothesis |
| 242 | +# |
| 243 | +# The multiset path replaces binary A/C structure with: |
| 244 | +# - `sum_(MultiSet[Expr])` |
| 245 | +# - `product_(MultiSet[Expr])` |
| 246 | +# |
| 247 | +# The current implementation runs in three fresh-egraph stages: |
| 248 | +# 1. lower binary additive and multiplicative islands into multiset form |
| 249 | +# 2. simplify in multiset form |
| 250 | +# 3. reify multisets back to binary form and run the cleanup rules |
| 251 | +# |
| 252 | +# For this first pass I kept the run reproducible and bounded: |
| 253 | +# - no backoff scheduler in the multiset stages |
| 254 | +# - explicit node cutoff and small iteration budget |
| 255 | +# |
| 256 | +# I also tried the unrestricted version on the same examples. That was not |
| 257 | +# practical to carry through on the representative case, which already shows |
| 258 | +# that the current multiset lowering does not remove the need for safety guards. |
| 259 | + |
| 260 | +# + |
| 261 | +multiset_rows = [] |
| 262 | +for example in examples: |
| 263 | + report = multiset_reports[example.name] |
| 264 | + metric = report.metric_report |
| 265 | + multiset_rows.append({ |
| 266 | + "example": example.name, |
| 267 | + "stop": report.stop_reason, |
| 268 | + "runtime_s": fmt_float(report.total_sec), |
| 269 | + "total_size": str(report.total_size), |
| 270 | + "nodes": str(report.node_count), |
| 271 | + "eclasses": str(report.eclass_count), |
| 272 | + "cost": str(report.cost), |
| 273 | + "params": f"{metric.before_parameter_count} -> {metric.after_parameter_count}", |
| 274 | + "reduction": fmt_float(metric.reduction_ratio), |
| 275 | + "optimal_gap": str(metric.jacobian_rank_gap), |
| 276 | + "max_err": f"{report.numeric_max_abs_error:.2e}", |
| 277 | + }) |
| 278 | + |
| 279 | +print(md_table(multiset_rows)) |
| 280 | +# - |
| 281 | + |
| 282 | +# + |
| 283 | +for example in examples: |
| 284 | + report = multiset_reports[example.name] |
| 285 | + print(f"### {example.name} multiset stage summary") |
| 286 | + stage_rows = [] |
| 287 | + for stage in report.stages: |
| 288 | + hottest = ", ".join( |
| 289 | + f"{compact_rule(rule)}={count}" |
| 290 | + for rule, count in sorted(stage.matches_per_rule.items(), key=lambda item: item[1], reverse=True)[:3] |
| 291 | + ) |
| 292 | + stage_rows.append({ |
| 293 | + "stage": stage.name, |
| 294 | + "stop": stage.stop_reason, |
| 295 | + "size": str(stage.total_size), |
| 296 | + "nodes": str(stage.node_count), |
| 297 | + "eclasses": str(stage.eclass_count), |
| 298 | + "hot_rules": hottest or "none", |
| 299 | + }) |
| 300 | + print(md_table(stage_rows)) |
| 301 | + print() |
| 302 | +# - |
| 303 | + |
| 304 | +# ## 5. Conclusions |
| 305 | +# |
| 306 | +# Baseline replication: |
| 307 | +# - Row 1 is a clean sanity match on the paper metric: both paths stay at |
| 308 | +# `2 -> 2` parameters with zero numerical error. |
| 309 | +# - Row 50 is close but not identical: Egglog reaches `14 -> 13`, while the |
| 310 | +# Haskell source reaches `14 -> 12`. |
| 311 | +# - The Egglog baseline does reproduce the overall style of the paper's |
| 312 | +# simplifier: parameter reduction, cost-aware extraction, and zero numerical |
| 313 | +# drift on sampled points. |
| 314 | +# |
| 315 | +# Multiset hypothesis: |
| 316 | +# - The multiset pipeline did shrink the final bounded e-graph footprint: |
| 317 | +# - row 1: `25 -> 19` total size versus the baseline |
| 318 | +# - row 50: `178 -> 120` total size versus the baseline |
| 319 | +# - But it did not yet improve the actual simplification result: |
| 320 | +# - row 1 stays `2 -> 2` |
| 321 | +# - row 50 regresses from `14 -> 13` in the baseline to `14 -> 14` |
| 322 | +# - The dominant remaining blow-up comes from the lowering and reification |
| 323 | +# stages, especially the multiplicative flattening rules: |
| 324 | +# - `product flattening` dominates the row 50 lowering stage |
| 325 | +# - `product reify` dominates the row 50 cleanup stage |
| 326 | +# |
| 327 | +# So the current answer is: |
| 328 | +# - multisets do not yet let this pipeline run safely to saturation without |
| 329 | +# limits on the representative case |
| 330 | +# - the blow-up moved from binary A/C search into multiset flatten/reify churn |
| 331 | +# - the next useful step is to redesign the multiset lowering and reification |
| 332 | +# rules, not to widen the runtime budget on the current version |
0 commit comments