Skip to content

Commit abce696

Browse files
committed
Implemented A*
1 parent 22b1081 commit abce696

6 files changed

Lines changed: 211 additions & 0 deletions

File tree

examples/pbe/solve.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
hs_enumerate_bucket_prob_grammar,
2525
hs_enumerate_bucket_prob_u_grammar,
2626
cd_enumerate_prob_grammar,
27+
as_enumerate_prob_grammar,
2728
ProgramEnumerator,
2829
Type,
2930
CFG,
@@ -62,6 +63,7 @@
6263
lambda x: hs_enumerate_bucket_prob_u_grammar(x, 3),
6364
),
6465
"bee_search": (bs_enumerate_prob_grammar, None),
66+
"a_star": (as_enumerate_prob_grammar, None),
6567
}
6668

6769
PRUNING = {"dfta", "obs-eq"}

synth/syntax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,6 @@
5050
hs_enumerate_bucket_prob_grammar,
5151
hs_enumerate_bucket_prob_u_grammar,
5252
cd_enumerate_prob_grammar,
53+
as_enumerate_prob_grammar,
5354
split,
5455
)

synth/syntax/grammars/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
hs_enumerate_bucket_prob_grammar,
1313
hs_enumerate_bucket_prob_u_grammar,
1414
cd_enumerate_prob_grammar,
15+
as_enumerate_prob_grammar,
1516
split,
1617
)
1718
from synth.syntax.grammars.u_grammar import UGrammar

synth/syntax/grammars/enumeration/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,8 @@
1818
from synth.syntax.grammars.enumeration.constant_delay import (
1919
enumerate_prob_grammar as cd_enumerate_prob_grammar,
2020
)
21+
from synth.syntax.grammars.enumeration.a_star import (
22+
enumerate_prob_grammar as as_enumerate_prob_grammar,
23+
)
2124

2225
enumerate_prob_grammar = cd_enumerate_prob_grammar
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from heapq import heappush, heappop
2+
from typing import (
3+
Generator,
4+
Generic,
5+
List,
6+
Optional,
7+
Tuple,
8+
TypeVar,
9+
Union,
10+
)
11+
from dataclasses import dataclass, field
12+
13+
import numpy as np
14+
15+
from synth.filter.filter import Filter
16+
from synth.syntax.grammars.enumeration.program_enumerator import ProgramEnumerator
17+
from synth.syntax.grammars.tagged_u_grammar import ProbUGrammar
18+
from synth.syntax.program import Function, Program
19+
from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar, DerivableProgram
20+
from synth.syntax.type_system import Type
21+
22+
U = TypeVar("U")
23+
V = TypeVar("V")
24+
W = TypeVar("W")
25+
26+
27+
def _build_(
28+
elems: List[Tuple[DerivableProgram, Tuple[Type, U]]], G: ProbDetGrammar[U, V, W]
29+
) -> Program:
30+
P, S = elems.pop(0)
31+
nargs = G.arguments_length_for(S, P)
32+
if nargs == 0:
33+
return P
34+
else:
35+
args = []
36+
while nargs > 0:
37+
args.append(_build_(elems, G))
38+
nargs -= 1
39+
return Function(P, args)
40+
41+
42+
@dataclass(order=True, frozen=True)
43+
class HeapElement(Generic[U]):
44+
priority: float
45+
to_expand: List[Tuple[Type, U]] = field(compare=False)
46+
parts: List[Tuple[DerivableProgram, Tuple[Type, U]]] = field(compare=False)
47+
48+
def __repr__(self) -> str:
49+
return f"({self.priority}, {self.parts})"
50+
51+
def make_program(self, g: ProbDetGrammar[U, V, W]) -> Program:
52+
return _build_(self.parts, g)
53+
54+
55+
class AStar(
56+
ProgramEnumerator[None],
57+
Generic[U, V, W],
58+
):
59+
def __init__(
60+
self,
61+
G: ProbDetGrammar[U, V, W],
62+
filter: Optional[Filter[Program]] = None,
63+
) -> None:
64+
super().__init__(filter)
65+
self.current: Optional[Program] = None
66+
67+
self.G = G
68+
self.start = G.start
69+
self.rules = G.rules
70+
71+
self.frontier: List[HeapElement[U]] = []
72+
73+
def probability(self, program: Program) -> float:
74+
return self.G.probability(program)
75+
76+
@classmethod
77+
def name(cls) -> str:
78+
return "a-star"
79+
80+
def generator(self) -> Generator[Program, None, None]:
81+
"""
82+
A generator which outputs the next most probable program
83+
"""
84+
first = (self.G.start[0], self.G.start[1][0]) # type: ignore
85+
heappush(self.frontier, HeapElement(0, [first], [])) # type: ignore
86+
87+
while self.frontier:
88+
elem = heappop(self.frontier)
89+
if len(elem.to_expand) == 0:
90+
p = elem.make_program(self.G)
91+
if self._should_keep_subprogram(p):
92+
yield p
93+
else:
94+
partS = elem.to_expand.pop()
95+
S = (partS[0], (partS[1], None))
96+
for P in self.G.rules[S]: # type: ignore
97+
args = self.G.rules[S][P][0] # type: ignore
98+
p = self.G.probabilities[S][P] # type: ignore
99+
new_el = HeapElement(
100+
elem.priority + p,
101+
elem.to_expand + list(args),
102+
elem.parts + [(P, S)], # type: ignore
103+
)
104+
heappush(self.frontier, new_el)
105+
106+
def merge_program(self, representative: Program, other: Program) -> None:
107+
"""
108+
Merge other into representative.
109+
In other words, other will no longer be generated through heap search
110+
"""
111+
pass
112+
113+
def programs_in_banks(self) -> int:
114+
return 0
115+
116+
def programs_in_queues(self) -> int:
117+
return len(self.frontier)
118+
119+
def clone(self, G: Union[ProbDetGrammar, ProbUGrammar]) -> "AStar[U, V, W]":
120+
assert isinstance(G, ProbDetGrammar)
121+
enum = self.__class__(G)
122+
return enum
123+
124+
125+
def enumerate_prob_grammar(G: ProbDetGrammar[U, V, W]) -> AStar[U, V, W]:
126+
Gp: ProbDetGrammar = ProbDetGrammar(
127+
G.grammar,
128+
{
129+
S: {P: -np.log(p) for P, p in val.items() if p > 0}
130+
for S, val in G.probabilities.items()
131+
},
132+
)
133+
return AStar(Gp)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from synth.syntax.grammars.enumeration.a_star import (
2+
enumerate_prob_grammar,
3+
)
4+
from synth.syntax.grammars.cfg import CFG
5+
from synth.syntax.grammars.ttcfg import TTCFG
6+
from synth.syntax.grammars.tagged_det_grammar import ProbDetGrammar
7+
from synth.syntax.dsl import DSL
8+
from synth.syntax.type_system import (
9+
INT,
10+
STRING,
11+
List,
12+
PolymorphicType,
13+
PrimitiveType,
14+
)
15+
from synth.syntax.type_helper import FunctionType, auto_type
16+
17+
import pytest
18+
19+
20+
syntax = {
21+
"+": FunctionType(INT, INT, INT),
22+
"head": FunctionType(List(PolymorphicType("a")), PolymorphicType("a")),
23+
"non_reachable": PrimitiveType("non_reachable"),
24+
"1": INT,
25+
"2": INT,
26+
"non_productive": FunctionType(INT, STRING),
27+
}
28+
dsl = DSL(syntax)
29+
dsl.instantiate_polymorphic_types()
30+
testdata = [
31+
CFG.depth_constraint(dsl, FunctionType(INT, INT), 3),
32+
CFG.depth_constraint(dsl, FunctionType(INT, INT), 4),
33+
]
34+
35+
36+
@pytest.mark.parametrize("cfg", testdata)
37+
def test_unicity_a_star(cfg: TTCFG) -> None:
38+
pcfg = ProbDetGrammar.uniform(cfg)
39+
seen = set()
40+
print(cfg)
41+
for program in enumerate_prob_grammar(pcfg):
42+
assert program not in seen
43+
seen.add(program)
44+
# print(pcfg.grammar)
45+
assert len(seen) == cfg.programs()
46+
47+
48+
@pytest.mark.parametrize("cfg", testdata)
49+
def test_order_a_star(cfg: TTCFG) -> None:
50+
pcfg = ProbDetGrammar.uniform(cfg)
51+
last = 1.0
52+
for program in enumerate_prob_grammar(pcfg):
53+
p = pcfg.probability(program)
54+
assert p <= last
55+
last = p
56+
57+
58+
def test_infinite() -> None:
59+
pcfg = ProbDetGrammar.random(
60+
CFG.infinite(dsl, testdata[0].type_request, n_gram=1), 1
61+
)
62+
count = 10000
63+
last = 1.0
64+
for program in enumerate_prob_grammar(pcfg):
65+
count -= 1
66+
p = pcfg.probability(program)
67+
assert -1e-12 <= last - p, f"failed at program n°{count}:{program}"
68+
last = p
69+
if count < 0:
70+
break
71+
assert count == -1

0 commit comments

Comments
 (0)