|
| 1 | +from typing import Optional, Union |
| 2 | +import numpy as np |
| 3 | + |
| 4 | + |
| 5 | +class PythonSampler: |
| 6 | + def __init__(self, weights: np.ndarray, seed: Optional[int] = None) -> None: |
| 7 | + self.rng = np.random.default_rng(seed or 1) |
| 8 | + n = len(weights) |
| 9 | + alias = np.zeros(n, dtype=int) |
| 10 | + proba = np.zeros(n, dtype=float) |
| 11 | + # Compute the average probability and cache it for later use. |
| 12 | + avg = 1.0 / n |
| 13 | + # Create two stacks to act as worklists as we populate the tables. |
| 14 | + small = [] |
| 15 | + large = [] |
| 16 | + # Populate the stacks with the input probabilities. |
| 17 | + for i in range(n): |
| 18 | + # If the probability is below the average probability, then we add it to the small |
| 19 | + # list; otherwise we add it to the large list. |
| 20 | + if weights[i] >= avg: |
| 21 | + large.append(i) |
| 22 | + else: |
| 23 | + small.append(i) |
| 24 | + # As a note: in the mathematical specification of the algorithm, we will always exhaust the |
| 25 | + # small list before the big list. However, due to floating point inaccuracies, this is not |
| 26 | + # necessarily true. Consequently, this inner loop (which tries to pair small and large |
| 27 | + # elements) will have to check that both lists aren't empty. |
| 28 | + while len(small) > 0 and len(large) > 0: |
| 29 | + # Get the index of the small and the large probabilities. |
| 30 | + less = small.pop(0) |
| 31 | + more = large.pop(0) |
| 32 | + # These probabilities have not yet been scaled up to be such that 1 / n is given weight |
| 33 | + # 1.0. We do this here instead. |
| 34 | + proba[less] = weights[less] * n |
| 35 | + alias[less] = more |
| 36 | + # Decrease the probability of the larger one by the appropriate amount. |
| 37 | + weights[more] = weights[more] + weights[less] - avg |
| 38 | + # If the new probability is less than the average, add it into the small list; |
| 39 | + # otherwise add it to the large list. |
| 40 | + if weights[more] >= avg: |
| 41 | + large.append(more) |
| 42 | + else: |
| 43 | + small.append(more) |
| 44 | + # At this point, everything is in one list, which means that the remaining probabilities |
| 45 | + # should all be 1 / n. Based on this, set them appropriately. Due to numerical issues, we |
| 46 | + # can't be sure which stack will hold the entries, so we empty both. |
| 47 | + while len(small) > 0: |
| 48 | + less = small.pop(0) |
| 49 | + proba[less] = 1.0 |
| 50 | + while len(large) > 0: |
| 51 | + more = large.pop(0) |
| 52 | + proba[more] = 1.0 |
| 53 | + self.n = n |
| 54 | + self.alias = alias |
| 55 | + self.proba = proba |
| 56 | + |
| 57 | + def sample_1(self) -> int: |
| 58 | + # Generate a fair die roll to determine which column to inspect. |
| 59 | + col = int(self.rng.uniform(0, self.n)) |
| 60 | + # Generate a biased coin toss to determine which option to pick. |
| 61 | + heads = self.rng.uniform() < 0.5 |
| 62 | + |
| 63 | + # Based on the outcome, return either the column or its alias. |
| 64 | + if heads: |
| 65 | + return col |
| 66 | + return self.alias[col] # type: ignore |
| 67 | + |
| 68 | + def sample( |
| 69 | + self, k: int = 1, values: Optional[np.ndarray] = None |
| 70 | + ) -> Union[int, np.ndarray]: |
| 71 | + """Sample a random integer or a value from a given array. |
| 72 | +
|
| 73 | + Parameters: |
| 74 | + k: The number of integers to sample. If `k = 1`, then a single int (or float if values is not None) is returned. In any |
| 75 | + other case, a numpy array is returned. |
| 76 | + values: The numpy array of values from which to sample from. |
| 77 | +
|
| 78 | + """ |
| 79 | + if values is None: |
| 80 | + if k == 1: |
| 81 | + return self.sample_1() |
| 82 | + return np.asarray([self.sample_1() for _ in range(k)]) |
| 83 | + else: |
| 84 | + if k == 1: |
| 85 | + return values[self.sample_1()] # type: ignore |
| 86 | + return np.asarray([values[self.sample_1()] for _ in range(k)]) |
| 87 | + |
| 88 | + |
| 89 | +try: |
| 90 | + import vose |
| 91 | + |
| 92 | + Sampler = vose.Sampler |
| 93 | +except ImportError: |
| 94 | + Sampler = PythonSampler |
0 commit comments