Skip to content

Commit 9178d4d

Browse files
committed
Added a vose polyfill, making the vose dependency optional
1 parent ef20aa3 commit 9178d4d

4 files changed

Lines changed: 103 additions & 9 deletions

File tree

synth/generation/sampler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import copy
1515

1616
import numpy as np
17-
import vose
17+
from synth.utils.vose_polyfill import Sampler as VoseSampler
1818

1919
from synth.syntax.type_system import List, Type
2020

@@ -53,7 +53,7 @@ def __init__(
5353
filled_probabilities = probabilites
5454
else:
5555
filled_probabilities = [1 / len(self.lexicon) for _ in lexicon]
56-
self.sampler = vose.Sampler(np.asarray(filled_probabilities), seed=seed)
56+
self.sampler = VoseSampler(np.asarray(filled_probabilities), seed=seed)
5757

5858
def sample(self, **kwargs: Any) -> U:
5959
index: int = self.sampler.sample()
@@ -104,7 +104,7 @@ def __init__(
104104
if not isinstance(probabilities[0], tuple):
105105
correct_prob = [(i + 1, p) for i, p in enumerate(probabilities)] # type: ignore
106106
self._length_mapping = [n for n, _ in correct_prob]
107-
self.sampler = vose.Sampler(
107+
self.sampler = VoseSampler(
108108
np.array([p for _, p in correct_prob]), seed=seed
109109
)
110110

synth/syntax/grammars/tagged_det_grammar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515

1616
import numpy as np
17-
import vose
17+
from synth.utils.vose_polyfill import Sampler as VoseSampler
1818

1919
if TYPE_CHECKING:
2020
from synth.syntax.grammars.cfg import CFG
@@ -165,7 +165,7 @@ def init_sampling(self, seed: Optional[int] = None) -> None:
165165

166166
for i, S in enumerate(self.tags):
167167
P_list = list(self.tags[S].keys())
168-
self.vose_samplers[S] = vose.Sampler(
168+
self.vose_samplers[S] = VoseSampler(
169169
np.array(
170170
[self.tags[S][P] for P in P_list],
171171
dtype=float,

synth/syntax/grammars/tagged_u_grammar.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212

1313
import numpy as np
14-
import vose
14+
from synth.utils.vose_polyfill import Sampler as VoseSampler
1515

1616
from synth.syntax.grammars.det_grammar import DerivableProgram
1717
from synth.syntax.grammars.u_grammar import UGrammar
@@ -183,7 +183,7 @@ def init_sampling(self, seed: Optional[int] = None) -> None:
183183

184184
for i, S in enumerate(self.tags):
185185
P_list = list(self.tags[S].keys())
186-
self.vose_samplers[S] = vose.Sampler(
186+
self.vose_samplers[S] = VoseSampler(
187187
np.array(
188188
[sum(p for p in self.tags[S][P].values()) for P in P_list],
189189
dtype=float,
@@ -192,7 +192,7 @@ def init_sampling(self, seed: Optional[int] = None) -> None:
192192
)
193193
self._vose_samplers_2[S] = {}
194194
for P in P_list:
195-
self._vose_samplers_2[S][P] = vose.Sampler(
195+
self._vose_samplers_2[S][P] = VoseSampler(
196196
np.array(
197197
[p for p in self.tags[S][P].values()],
198198
dtype=float,
@@ -202,7 +202,7 @@ def init_sampling(self, seed: Optional[int] = None) -> None:
202202
)
203203
self.sampling_map[S] = P_list
204204
self._int2start = list(self.starts)
205-
self._start_sampler = vose.Sampler(
205+
self._start_sampler = VoseSampler(
206206
np.array(
207207
[v for v in self.start_tags.values()],
208208
dtype=float,

synth/utils/vose_polyfill.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from typing import Optional, Union
2+
import numpy as np
3+
4+
5+
class PythonSampler:
    """Pure-Python weighted index sampler built on Vose's alias method.

    The constructor builds two tables of length ``n``: ``proba[i]`` is the
    probability of keeping column ``i`` once it has been picked uniformly,
    and ``alias[i]`` is the index returned otherwise. Drawing a sample then
    costs one uniform column pick and one biased coin toss.

    Parameters:
        weights: Non-negative weights, one per index. They do not need to sum
            to 1; they are normalised internally. The caller's array is never
            modified.
        seed: Seed for the random generator. When ``None`` the fixed seed 1 is
            used, so an unseeded sampler is still deterministic.

    Raises:
        ValueError: If ``weights`` is empty.
    """

    def __init__(self, weights: np.ndarray, seed: Optional[int] = None) -> None:
        # Only fall back to the default seed when none was given: ``seed or 1``
        # would silently replace an explicit seed of 0.
        self.rng = np.random.default_rng(1 if seed is None else seed)

        # Work on a private float copy: the table construction below rewrites
        # entries, and an integer input array would truncate the updates.
        scaled = np.array(weights, dtype=float)
        n = len(scaled)
        if n == 0:
            raise ValueError("weights must contain at least one element")

        # Rescale so that the average weight is exactly 1.0. This normalises
        # the input and scales it to the "column height" used by the alias
        # method in one step. If every weight is zero the array is left
        # untouched, which yields a uniform sampler.
        total = scaled.sum()
        if total > 0:
            scaled *= n / total

        alias = np.zeros(n, dtype=int)
        proba = np.zeros(n, dtype=float)

        # Two stacks acting as worklists: columns below the average height and
        # columns at or above it.
        small = np.flatnonzero(scaled < 1.0).tolist()
        large = np.flatnonzero(scaled >= 1.0).tolist()

        # Mathematically the small stack is always exhausted first, but
        # floating point inaccuracies can break that, hence the double check.
        while small and large:
            less = small.pop()
            more = large.pop()
            # Column ``less`` keeps its own mass and is topped up by ``more``.
            proba[less] = scaled[less]
            alias[less] = more
            # Remove from ``more`` the mass it just donated to ``less``.
            scaled[more] = scaled[more] + scaled[less] - 1.0
            if scaled[more] >= 1.0:
                large.append(more)
            else:
                small.append(more)

        # Whatever is left should have height 1.0 up to rounding errors; it may
        # sit in either stack, so both are emptied.
        for leftover in small + large:
            proba[leftover] = 1.0

        self.n = n
        self.alias = alias
        self.proba = proba

    def sample_1(self) -> int:
        """Draw a single index according to the weights."""
        # Fair die roll to determine which column to inspect.
        col = int(self.rng.integers(self.n))
        # Biased coin toss: keep the column with probability ``proba[col]``,
        # otherwise return its alias.
        if self.rng.random() < self.proba[col]:
            return col
        return int(self.alias[col])

    def sample(
        self, k: int = 1, values: Optional[np.ndarray] = None
    ) -> Union[int, np.ndarray]:
        """Sample random indices, or the values found at those indices.

        Parameters:
            k: The number of samples to draw. If ``k == 1`` a single int (or a
                single element of ``values`` if it is given) is returned. In
                any other case a numpy array is returned.
            values: Optional array to sample from; the drawn indices are used
                to index into it.

        Returns:
            A single sample when ``k == 1``, otherwise an array of ``k``
            samples (empty when ``k <= 0``).
        """
        if k == 1:
            index = self.sample_1()
            if values is None:
                return index
            return values[index]  # type: ignore
        indices = np.asarray([self.sample_1() for _ in range(k)], dtype=int)
        if values is None:
            return indices
        return np.asarray(values)[indices]
87+
88+
89+
# Select the sampler implementation exposed by this module: use the `vose`
# package when it can be imported, otherwise the PythonSampler polyfill.
try:
    import vose
except ImportError:
    Sampler = PythonSampler
else:
    Sampler = vose.Sampler

0 commit comments

Comments
 (0)