Skip to content

Commit e2deb1f

Browse files
Fix schedules
1 parent 67d15ba commit e2deb1f

3 files changed

Lines changed: 29 additions & 3 deletions

File tree

python/egglog/egraph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,7 +1565,7 @@ def __add__(self, other: Schedule) -> Schedule:
15651565
"""
15661566
Run two schedules in sequence.
15671567
"""
1568-
return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
1568+
return Schedule(partial(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
15691569

15701570

15711571
@dataclass
@@ -2046,7 +2046,7 @@ def run(ruleset: Ruleset | None = None, *until: FactLike, scheduler: BackOff | N
20462046
"""
20472047
facts = _fact_likes(until)
20482048
return Schedule(
2049-
Thunk.fn(Declarations.create, ruleset, *facts),
2049+
partial(Declarations.create, ruleset, *facts),
20502050
RunDecl(
20512051
ruleset.__egg_ident__ if ruleset else Ident(""),
20522052
tuple(f.fact for f in facts),
@@ -2089,7 +2089,7 @@ def seq(*schedules: Schedule) -> Schedule:
20892089
"""
20902090
Run a sequence of schedules.
20912091
"""
2092-
return Schedule(Thunk.fn(Declarations.create, *schedules), SequenceDecl(tuple(s.schedule for s in schedules)))
2092+
return Schedule(partial(Declarations.create, *schedules), SequenceDecl(tuple(s.schedule for s in schedules)))
20932093

20942094

20952095
def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:

python/tests/test_array_api.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,18 @@ def test_normalize_reshape_shape(self):
172172
array_api_schedule,
173173
)
174174

175+
def test_reshape_after_schedule_decls_access(self):
176+
_ = array_api_schedule.__egg_decls__
177+
178+
x = NDArray.var("x")
179+
assume_shape(x, TupleInt((Int(5),)))
180+
res = reshape(x, TupleInt((-1,)))
181+
egraph = EGraph()
182+
egraph.register(res)
183+
egraph.run(array_api_schedule)
184+
185+
egraph.check(eq(res).to(x))
186+
175187
def test_reshape_vec_noop(self):
176188
x = NDArray.var("x")
177189
assume_shape(x, TupleInt((Int(5),)))

python/tests/test_high_level.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,20 @@ def __sub__(self, other: E) -> E: ...
11171117

11181118

11191119
class TestScheduler:
1120+
def test_seq_schedule_decls_track_ruleset_updates(self):
1121+
egraph = EGraph()
1122+
1123+
rel = relation("rel_live", i64)
1124+
live_rules = ruleset(name="live-rules")
1125+
schedule = seq(live_rules, run()).saturate()
1126+
_ = schedule.__egg_decls__
1127+
1128+
live_rules.register(rule(rel(i64(0))).then(rel(i64(1))))
1129+
1130+
egraph.register(rel(i64(0)))
1131+
egraph.run(schedule)
1132+
egraph.check(rel(i64(1)))
1133+
11201134
def test_sequence_repeat_saturate(self):
11211135
"""
11221136
Mirrors the scheduling example: alternate step-right and step-left,

0 commit comments

Comments
 (0)