Skip to content

Commit 998d29f

Browse files
authored
Add GF2MulMBUC Bloq (#1718)
Add Bloq for measurement based uncomputation of GF2 Multiplication circuits. The `Parity` bloq added to compute the parity of input bits should be generalized into a more general `CtrlSpec` that takes a bunch of control registers and uses their parity to control the operation. I'll open a new issue for it.
1 parent e283ae1 commit 998d29f

5 files changed

Lines changed: 165 additions & 55 deletions

File tree

qualtran/bloqs/basic_gates/diag_gates.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,8 @@
192192
"Two-qubit controlled-Z gate.\n",
193193
"\n",
194194
"#### Registers\n",
195-
" - `ctrl`: One-bit control register.\n",
196-
" - `target`: One-bit target register.\n"
195+
" - `q1`: One-bit control register.\n",
196+
" - `q2`: One-bit target register.\n"
197197
]
198198
},
199199
{

qualtran/bloqs/basic_gates/z_basis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,8 @@ class CZ(Bloq):
320320
"""Two-qubit controlled-Z gate.
321321
322322
Registers:
323-
ctrl: One-bit control register.
324-
target: One-bit target register.
323+
q1: One-bit control register.
324+
q2: One-bit target register.
325325
"""
326326

327327
@cached_property

qualtran/bloqs/gf_arithmetic/gf2_multiplication.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,8 +655,7 @@
655655
"The toffoli complexity is $n^{\\log_2{3}}$\n",
656656
"\n",
657657
"#### Parameters\n",
658-
" - `m_x`: The irreducible polynomial that defines the galois field.\n",
659-
" - `uncompute`: Whether to compute or uncompute the product. \n",
658+
" - `m_x`: The irreducible polynomial that defines the galois field. \n",
660659
"\n",
661660
"#### Registers\n",
662661
" - `x`: A TRHU register representing the first number (or polynomial).\n",

qualtran/bloqs/gf_arithmetic/gf2_multiplication.py

Lines changed: 115 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from functools import cached_property
15-
from typing import Dict, Optional, Sequence, Set, TYPE_CHECKING, Union
15+
from typing import Dict, Mapping, Optional, Sequence, Set, TYPE_CHECKING, Union
1616

1717
import attrs
1818
import galois
@@ -24,20 +24,22 @@
2424
Bloq,
2525
bloq_example,
2626
BloqDocSpec,
27+
CBit,
28+
CtrlSpec,
2729
DecomposeTypeError,
2830
QBit,
2931
QGF,
3032
Register,
3133
Side,
3234
Signature,
3335
)
34-
from qualtran.bloqs.basic_gates import CNOT, Toffoli
36+
from qualtran.bloqs.basic_gates import CNOT, CZ, Discard, MeasureX, Toffoli
3537
from qualtran.symbolics import ceil, is_symbolic, log2, Shaped, SymbolicInt
3638

3739
if TYPE_CHECKING:
3840
from qualtran import BloqBuilder, Soquet, SoquetT
3941
from qualtran.resource_counting import BloqCountDictT, BloqCountT, SympySymbolAllocator
40-
from qualtran.simulation.classical_sim import ClassicalValT
42+
from qualtran.simulation.classical_sim import ClassicalValRetT, ClassicalValT
4143

4244

4345
def _data_or_shape_to_tuple(data_or_shape: Union[np.ndarray, Shaped]) -> tuple:
@@ -172,21 +174,6 @@ def signature(self) -> 'Signature':
172174
def bitsize(self) -> SymbolicInt:
173175
return self.qgf.bitsize
174176

175-
@cached_property
176-
def reduction_matrix_q(self) -> np.ndarray:
177-
m = int(self.bitsize)
178-
f = self.qgf.gf_type.irreducible_poly
179-
M = np.zeros((m, m))
180-
alpha = [1] + [0] * m
181-
for i in range(m - 1):
182-
# x ** (m + i) % f
183-
coeffs = (Poly(alpha, GF(2)) % f).coeffs.tolist()[::-1]
184-
coeffs = coeffs + [0] * (m - len(coeffs))
185-
M[i] = coeffs
186-
alpha += [0]
187-
M[m - 1][m - 1] = 1
188-
return np.transpose(M)
189-
190177
def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'Soquet') -> Dict[str, 'Soquet']:
191178
if is_symbolic(self.bitsize):
192179
raise DecomposeTypeError(f"Cannot decompose symbolic {self}")
@@ -261,6 +248,110 @@ def _gf2_multiplication_symbolic() -> GF2Multiplication:
261248
)
262249

263250

251+
@attrs.frozen
252+
class Parity(Bloq):
253+
n: int
254+
255+
@cached_property
256+
def signature(self) -> 'Signature':
257+
return Signature(
258+
[
259+
Register('x', dtype=CBit(), shape=(self.n,)),
260+
Register('parity', dtype=CBit(), side=Side.RIGHT),
261+
]
262+
)
263+
264+
def on_classical_vals(
265+
self, *, x: Union['sympy.Symbol', 'ClassicalValT']
266+
) -> Mapping[str, 'ClassicalValRetT']:
267+
assert isinstance(x, np.ndarray)
268+
return {'x': x, 'parity': np.sum(x, dtype=int) & 1}
269+
270+
271+
@attrs.frozen
272+
class GF2MulMBUC(Bloq):
273+
r"""Measurement based uncomputation of out of place multiplication over GF($2^m$).
274+
275+
Args:
276+
bitsize: The degree $m$ of the galois field $GF(2^m)$. Also corresponds to the number of
277+
qubits in each of the two input registers $a$ and $b$ that should be multiplied.
278+
279+
Registers:
280+
x: Input THRU register of size $m$ that stores elements from $GF(2^m)$.
281+
y: Input THRU register of size $m$ that stores elements from $GF(2^m)$.
282+
result: Register of size $m$ that stores the product $x * y$ in $GF(2^m)$.
283+
"""
284+
285+
qgf: QGF = attrs.field(converter=_qgf_converter)
286+
287+
@cached_property
288+
def signature(self) -> 'Signature':
289+
return Signature(
290+
[
291+
Register('x', dtype=self.qgf),
292+
Register('y', dtype=self.qgf),
293+
Register('result', dtype=self.qgf, side=Side.LEFT),
294+
]
295+
)
296+
297+
@cached_property
298+
def bitsize(self) -> SymbolicInt:
299+
return self.qgf.bitsize
300+
301+
@cached_property
302+
def reduction_matrix_q(self) -> np.ndarray:
303+
m = int(self.bitsize)
304+
f = self.qgf.gf_type.irreducible_poly
305+
M = np.zeros((m, m), dtype=int)
306+
alpha = [1] + [0] * m
307+
for i in range(m):
308+
# x ** (m + i) % f
309+
coeffs = (Poly(alpha, GF(2)) % f).coeffs.tolist()[::-1]
310+
coeffs = coeffs + [0] * (m - len(coeffs))
311+
M[i] = coeffs
312+
alpha += [0]
313+
return np.transpose(M)
314+
315+
def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'Soquet') -> Dict[str, 'Soquet']:
316+
if is_symbolic(self.bitsize):
317+
raise DecomposeTypeError(f"Cannot decompose symbolic {self}")
318+
x, y, result = soqs['x'], soqs['y'], soqs['result']
319+
x, y, result = bb.split(x)[::-1], bb.split(y)[::-1], bb.split(result)[::-1]
320+
result = np.array([bb.add(MeasureX(), q=q) for q in result])
321+
m = int(self.bitsize)
322+
323+
# Inverse of Step-3: Multiply Monomials
324+
ctrl_cz = CZ().controlled(CtrlSpec(qdtypes=[CBit()]))
325+
for i in range(m):
326+
for j in range(i + 1):
327+
result[i], x[j], y[i - j] = bb.add(ctrl_cz, ctrl=result[i], q1=x[j], q2=y[i - j])
328+
329+
# Inverse of Step-1 & 2: Multiply Monomials.
330+
for i in range(m):
331+
inp_vec = GF(2).Zeros(m)
332+
inp_vec[i] = 1
333+
out_vec = GF(2)(self.reduction_matrix_q) @ inp_vec
334+
indices = [k for k in range(m) if out_vec[k]]
335+
result[indices], parity = bb.add(Parity(len(indices)), x=result[indices])
336+
for j in range(i + 1, m):
337+
parity, x[m - j + i], y[j] = bb.add(ctrl_cz, ctrl=parity, q1=x[m - j + i], q2=y[j])
338+
bb.add(Discard(), c=parity)
339+
340+
# Done :)
341+
for c in result:
342+
bb.add(Discard(), c=c)
343+
return {'x': bb.join(x[::-1], dtype=self.qgf), 'y': bb.join(y[::-1], dtype=self.qgf)}
344+
345+
def build_call_graph(
346+
self, ssa: 'SympySymbolAllocator'
347+
) -> Union['BloqCountDictT', Set['BloqCountT']]:
348+
m = self.bitsize
349+
return {CZ(): m**2}
350+
351+
def adjoint(self) -> 'Bloq':
352+
return GF2MulViaKaratsuba(self.qgf)
353+
354+
264355
@attrs.frozen
265356
class GF2MulK(Bloq):
266357
r"""Multiply by constant $f(x)$ modulo $m(x)$. Both $f(x)$ and $m(x)$ are constants.
@@ -953,7 +1044,6 @@ class GF2MulViaKaratsuba(Bloq):
9531044
9541045
Args:
9551046
m_x: The irreducible polynomial that defines the galois field.
956-
uncompute: Whether to compute or uncompute the product.
9571047
9581048
Registers:
9591049
x: A TRHU register representing the first number (or polynomial).
@@ -966,7 +1056,6 @@ class GF2MulViaKaratsuba(Bloq):
9661056
"""
9671057

9681058
dtype: QGF = attrs.field(converter=_qgf_converter)
969-
uncompute: bool = False
9701059

9711060
@cached_property
9721061
def m_x(self):
@@ -988,21 +1077,16 @@ def gf(self):
9881077
def qgf(self):
9891078
return self.dtype
9901079

991-
def adjoint(self) -> 'GF2MulViaKaratsuba':
992-
return attrs.evolve(self, uncompute=not self.uncompute)
993-
9941080
def __str__(self):
995-
return f'{self.__class__.__name__}†' if self.uncompute else f'{self.__class__.__name__}'
1081+
return f'{self.__class__.__name__}'
9961082

9971083
@cached_property
9981084
def signature(self) -> 'Signature':
999-
# C is directional
1000-
side = Side.LEFT if self.uncompute else Side.RIGHT
10011085
return Signature(
10021086
[
10031087
Register('x', dtype=self.qgf),
10041088
Register('y', dtype=self.qgf),
1005-
Register('result', dtype=self.qgf, side=side),
1089+
Register('result', dtype=self.qgf, side=Side.RIGHT),
10061090
]
10071091
)
10081092

@@ -1016,8 +1100,6 @@ def k(self):
10161100
@cached_property
10171101
def _GF2MulViaKaratsubamod_impl(self) -> Bloq:
10181102
impl = _GF2MulViaKaratsubaImpl(self.m_x)
1019-
if self.uncompute:
1020-
return impl.adjoint()
10211103
return impl
10221104

10231105
def build_composite_bloq(
@@ -1026,17 +1108,10 @@ def build_composite_bloq(
10261108
if is_symbolic(self.k, self.n):
10271109
raise DecomposeTypeError(f"Symbolic Decomposition is not supported for {self}")
10281110

1029-
if self.uncompute:
1030-
result = soqs['result']
1031-
else:
1032-
result = bb.allocate(self.n, self.qgf)
1111+
result = bb.allocate(self.n, self.qgf)
10331112

10341113
x, y, result = bb.add_from(self._GF2MulViaKaratsubamod_impl, f=x, g=y, h=result)
10351114

1036-
if self.uncompute:
1037-
bb.free(result) # type: ignore[arg-type]
1038-
return {'x': x, 'y': y}
1039-
10401115
return {'x': x, 'y': y, 'result': result}
10411116

10421117
def build_call_graph(
@@ -1069,11 +1144,11 @@ def on_classical_vals(
10691144
) -> Dict[str, 'ClassicalValT']:
10701145
assert isinstance(x, self.gf)
10711146
assert isinstance(y, self.gf)
1072-
if self.uncompute:
1073-
assert x * y == result
1074-
return {'x': x, 'y': y}
10751147
return {'x': x, 'y': y, 'result': x * y}
10761148

1149+
def adjoint(self) -> 'Bloq':
1150+
return GF2MulMBUC(self.qgf)
1151+
10771152

10781153
@bloq_example
10791154
def _gf2mulviakaratsuba() -> GF2MulViaKaratsuba:

qualtran/bloqs/gf_arithmetic/gf2_multiplication_test.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@
2222
from qualtran import QGF
2323
from qualtran.bloqs.gf_arithmetic.gf2_multiplication import (
2424
_gf2_multiplication_symbolic,
25+
_GF2MulViaKaratsubaImpl,
2526
_gf16_multiplication,
2627
BinaryPolynomialMultiplication,
2728
GF2MulK,
29+
GF2MulMBUC,
2830
GF2Multiplication,
2931
GF2MulViaKaratsuba,
3032
GF2ShiftLeft,
@@ -34,6 +36,7 @@
3436
)
3537
from qualtran.resource_counting import get_cost_value, QECGatesCost
3638
from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join
39+
from qualtran.simulation.classical_sim import do_phased_classical_simulation
3740
from qualtran.testing import assert_consistent_classical_action
3841

3942

@@ -47,7 +50,7 @@ def test_gf2_multiplication_symbolic(bloq_autotester):
4750

4851
@pytest.mark.parametrize('m', [2, 4, 6, 8])
4952
def test_synthesize_lr_circuit(m: int):
50-
matrix = GF2Multiplication(m).reduction_matrix_q
53+
matrix = GF2MulMBUC(m).reduction_matrix_q
5154
bloq = SynthesizeLRCircuit(matrix)
5255
bloq_adj = bloq.adjoint()
5356
QGFM, GFM = QGF(2, m), GF(2**m)
@@ -61,7 +64,7 @@ def test_synthesize_lr_circuit(m: int):
6164
@pytest.mark.slow
6265
@pytest.mark.parametrize('m', [3, 4, 5])
6366
def test_synthesize_lr_circuit_slow(m):
64-
matrix = GF2Multiplication(m).reduction_matrix_q
67+
matrix = GF2MulMBUC(m).reduction_matrix_q
6568
bloq = SynthesizeLRCircuit(matrix)
6669
bloq_adj = bloq.adjoint()
6770
QGFM, GFM = QGF(2, m), GF(2**m)
@@ -349,15 +352,16 @@ def test_gf2mulmod_classical_action_slow():
349352

350353
@pytest.mark.parametrize('m_x', [[2, 1, 0], [3, 1, 0], [5, 2, 0]])
351354
def test_gf2mulmod_classical_action_adjoint(m_x):
352-
blq = GF2MulViaKaratsuba(m_x)
355+
blq = _GF2MulViaKaratsubaImpl(m_x)
353356
adjoint = blq.adjoint()
354357
rs = np.random.default_rng(42)
355-
for i, j in rs.integers(0, len(blq.gf.elements) - 1, (10, 2)):
356-
f = blq.gf.elements[i]
357-
g = blq.gf.elements[j]
358-
a, b, c = blq.call_classically(x=f, y=g)
359-
a, b = adjoint.call_classically(x=a, y=b, result=c)
360-
assert a == f and b == g
358+
zero = blq.qgf.gf_type(0)
359+
for i, j in rs.integers(0, len(blq.qgf.gf_type.elements) - 1, (10, 2)):
360+
f = blq.qgf.gf_type.elements[i]
361+
g = blq.qgf.gf_type.elements[j]
362+
a, b, c = blq.call_classically(f=f, g=g, h=zero)
363+
a, b, c = adjoint.call_classically(f=a, g=b, h=c)
364+
assert a == f and b == g and c == zero
361365

362366

363367
@pytest.mark.parametrize('m_x', [[2, 1, 0], [8, 4, 3, 1, 0], [16, 5, 3, 1, 0]])
@@ -374,3 +378,35 @@ def test_gf2mulmod_classical_complexity(m_x):
374378
def test_gf2mul_invalid_input_raises():
375379
with pytest.raises(ValueError):
376380
_ = GF2MulViaKaratsuba([0, 1]) # type: ignore[arg-type]
381+
382+
383+
def test_gf2_mul_mbuc_quick():
384+
m = 3
385+
bloq_mbuc = GF2MulMBUC(m)
386+
rng = np.random.default_rng(seed=123)
387+
for x in bloq_mbuc.qgf.gf_type.elements:
388+
for y in bloq_mbuc.qgf.gf_type.elements:
389+
in_vals = {'x': x, 'y': y, 'result': x * y}
390+
out_vals, phase = do_phased_classical_simulation(
391+
bloq_mbuc.decompose_bloq(), in_vals, rng=rng
392+
)
393+
assert out_vals['x'] == x
394+
assert out_vals['y'] == y
395+
assert 'result' not in out_vals
396+
assert phase == 1
397+
398+
399+
@pytest.mark.parametrize('m', [4, 5])
400+
def test_gf2_mul_mbuc(m: int):
401+
bloq_mbuc = GF2MulMBUC(m)
402+
rng = np.random.default_rng(seed=123)
403+
for x in bloq_mbuc.qgf.gf_type.elements:
404+
for y in bloq_mbuc.qgf.gf_type.elements:
405+
in_vals = {'x': x, 'y': y, 'result': x * y}
406+
out_vals, phase = do_phased_classical_simulation(
407+
bloq_mbuc.decompose_bloq(), in_vals, rng=rng
408+
)
409+
assert out_vals['x'] == x
410+
assert out_vals['y'] == y
411+
assert 'result' not in out_vals
412+
assert phase == 1

0 commit comments

Comments
 (0)