1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from functools import cached_property
15- from typing import Dict , Optional , Sequence , Set , TYPE_CHECKING , Union
15+ from typing import Dict , Mapping , Optional , Sequence , Set , TYPE_CHECKING , Union
1616
1717import attrs
1818import galois
2424 Bloq ,
2525 bloq_example ,
2626 BloqDocSpec ,
27+ CBit ,
28+ CtrlSpec ,
2729 DecomposeTypeError ,
2830 QBit ,
2931 QGF ,
3032 Register ,
3133 Side ,
3234 Signature ,
3335)
34- from qualtran .bloqs .basic_gates import CNOT , Toffoli
36+ from qualtran .bloqs .basic_gates import CNOT , CZ , Discard , MeasureX , Toffoli
3537from qualtran .symbolics import ceil , is_symbolic , log2 , Shaped , SymbolicInt
3638
3739if TYPE_CHECKING :
3840 from qualtran import BloqBuilder , Soquet , SoquetT
3941 from qualtran .resource_counting import BloqCountDictT , BloqCountT , SympySymbolAllocator
40- from qualtran .simulation .classical_sim import ClassicalValT
42+ from qualtran .simulation .classical_sim import ClassicalValRetT , ClassicalValT
4143
4244
4345def _data_or_shape_to_tuple (data_or_shape : Union [np .ndarray , Shaped ]) -> tuple :
@@ -172,21 +174,6 @@ def signature(self) -> 'Signature':
172174 def bitsize (self ) -> SymbolicInt :
173175 return self .qgf .bitsize
174176
175- @cached_property
176- def reduction_matrix_q (self ) -> np .ndarray :
177- m = int (self .bitsize )
178- f = self .qgf .gf_type .irreducible_poly
179- M = np .zeros ((m , m ))
180- alpha = [1 ] + [0 ] * m
181- for i in range (m - 1 ):
182- # x ** (m + i) % f
183- coeffs = (Poly (alpha , GF (2 )) % f ).coeffs .tolist ()[::- 1 ]
184- coeffs = coeffs + [0 ] * (m - len (coeffs ))
185- M [i ] = coeffs
186- alpha += [0 ]
187- M [m - 1 ][m - 1 ] = 1
188- return np .transpose (M )
189-
190177 def build_composite_bloq (self , bb : 'BloqBuilder' , ** soqs : 'Soquet' ) -> Dict [str , 'Soquet' ]:
191178 if is_symbolic (self .bitsize ):
192179 raise DecomposeTypeError (f"Cannot decompose symbolic { self } " )
@@ -261,6 +248,110 @@ def _gf2_multiplication_symbolic() -> GF2Multiplication:
261248)
262249
263250
251+ @attrs .frozen
252+ class Parity (Bloq ):
253+ n : int
254+
255+ @cached_property
256+ def signature (self ) -> 'Signature' :
257+ return Signature (
258+ [
259+ Register ('x' , dtype = CBit (), shape = (self .n ,)),
260+ Register ('parity' , dtype = CBit (), side = Side .RIGHT ),
261+ ]
262+ )
263+
264+ def on_classical_vals (
265+ self , * , x : Union ['sympy.Symbol' , 'ClassicalValT' ]
266+ ) -> Mapping [str , 'ClassicalValRetT' ]:
267+ assert isinstance (x , np .ndarray )
268+ return {'x' : x , 'parity' : np .sum (x , dtype = int ) & 1 }
269+
270+
271+ @attrs .frozen
272+ class GF2MulMBUC (Bloq ):
273+ r"""Measurement based uncomputation of out of place multiplication over GF($2^m$).
274+
275+ Args:
276+ bitsize: The degree $m$ of the galois field $GF(2^m)$. Also corresponds to the number of
277+ qubits in each of the two input registers $a$ and $b$ that should be multiplied.
278+
279+ Registers:
280+ x: Input THRU register of size $m$ that stores elements from $GF(2^m)$.
281+ y: Input THRU register of size $m$ that stores elements from $GF(2^m)$.
282+ result: Register of size $m$ that stores the product $x * y$ in $GF(2^m)$.
283+ """
284+
285+ qgf : QGF = attrs .field (converter = _qgf_converter )
286+
287+ @cached_property
288+ def signature (self ) -> 'Signature' :
289+ return Signature (
290+ [
291+ Register ('x' , dtype = self .qgf ),
292+ Register ('y' , dtype = self .qgf ),
293+ Register ('result' , dtype = self .qgf , side = Side .LEFT ),
294+ ]
295+ )
296+
297+ @cached_property
298+ def bitsize (self ) -> SymbolicInt :
299+ return self .qgf .bitsize
300+
301+ @cached_property
302+ def reduction_matrix_q (self ) -> np .ndarray :
303+ m = int (self .bitsize )
304+ f = self .qgf .gf_type .irreducible_poly
305+ M = np .zeros ((m , m ), dtype = int )
306+ alpha = [1 ] + [0 ] * m
307+ for i in range (m ):
308+ # x ** (m + i) % f
309+ coeffs = (Poly (alpha , GF (2 )) % f ).coeffs .tolist ()[::- 1 ]
310+ coeffs = coeffs + [0 ] * (m - len (coeffs ))
311+ M [i ] = coeffs
312+ alpha += [0 ]
313+ return np .transpose (M )
314+
315+ def build_composite_bloq (self , bb : 'BloqBuilder' , ** soqs : 'Soquet' ) -> Dict [str , 'Soquet' ]:
316+ if is_symbolic (self .bitsize ):
317+ raise DecomposeTypeError (f"Cannot decompose symbolic { self } " )
318+ x , y , result = soqs ['x' ], soqs ['y' ], soqs ['result' ]
319+ x , y , result = bb .split (x )[::- 1 ], bb .split (y )[::- 1 ], bb .split (result )[::- 1 ]
320+ result = np .array ([bb .add (MeasureX (), q = q ) for q in result ])
321+ m = int (self .bitsize )
322+
323+ # Inverse of Step-3: Multiply Monomials
324+ ctrl_cz = CZ ().controlled (CtrlSpec (qdtypes = [CBit ()]))
325+ for i in range (m ):
326+ for j in range (i + 1 ):
327+ result [i ], x [j ], y [i - j ] = bb .add (ctrl_cz , ctrl = result [i ], q1 = x [j ], q2 = y [i - j ])
328+
329+ # Inverse of Step-1 & 2: Multiply Monomials.
330+ for i in range (m ):
331+ inp_vec = GF (2 ).Zeros (m )
332+ inp_vec [i ] = 1
333+ out_vec = GF (2 )(self .reduction_matrix_q ) @ inp_vec
334+ indices = [k for k in range (m ) if out_vec [k ]]
335+ result [indices ], parity = bb .add (Parity (len (indices )), x = result [indices ])
336+ for j in range (i + 1 , m ):
337+ parity , x [m - j + i ], y [j ] = bb .add (ctrl_cz , ctrl = parity , q1 = x [m - j + i ], q2 = y [j ])
338+ bb .add (Discard (), c = parity )
339+
340+ # Done :)
341+ for c in result :
342+ bb .add (Discard (), c = c )
343+ return {'x' : bb .join (x [::- 1 ], dtype = self .qgf ), 'y' : bb .join (y [::- 1 ], dtype = self .qgf )}
344+
345+ def build_call_graph (
346+ self , ssa : 'SympySymbolAllocator'
347+ ) -> Union ['BloqCountDictT' , Set ['BloqCountT' ]]:
348+ m = self .bitsize
349+ return {CZ (): m ** 2 }
350+
351+ def adjoint (self ) -> 'Bloq' :
352+ return GF2MulViaKaratsuba (self .qgf )
353+
354+
264355@attrs .frozen
265356class GF2MulK (Bloq ):
266357 r"""Multiply by constant $f(x)$ modulo $m(x)$. Both $f(x)$ and $m(x)$ are constants.
@@ -953,7 +1044,6 @@ class GF2MulViaKaratsuba(Bloq):
9531044
9541045 Args:
9551046 m_x: The irreducible polynomial that defines the galois field.
956- uncompute: Whether to compute or uncompute the product.
9571047
9581048 Registers:
9591049 x: A TRHU register representing the first number (or polynomial).
@@ -966,7 +1056,6 @@ class GF2MulViaKaratsuba(Bloq):
9661056 """
9671057
9681058 dtype : QGF = attrs .field (converter = _qgf_converter )
969- uncompute : bool = False
9701059
9711060 @cached_property
9721061 def m_x (self ):
@@ -988,21 +1077,16 @@ def gf(self):
9881077 def qgf (self ):
9891078 return self .dtype
9901079
991- def adjoint (self ) -> 'GF2MulViaKaratsuba' :
992- return attrs .evolve (self , uncompute = not self .uncompute )
993-
9941080 def __str__ (self ):
995- return f'{ self .__class__ .__name__ } †' if self . uncompute else f' { self . __class__ . __name__ } '
1081+ return f'{ self .__class__ .__name__ } '
9961082
9971083 @cached_property
9981084 def signature (self ) -> 'Signature' :
999- # C is directional
1000- side = Side .LEFT if self .uncompute else Side .RIGHT
10011085 return Signature (
10021086 [
10031087 Register ('x' , dtype = self .qgf ),
10041088 Register ('y' , dtype = self .qgf ),
1005- Register ('result' , dtype = self .qgf , side = side ),
1089+ Register ('result' , dtype = self .qgf , side = Side . RIGHT ),
10061090 ]
10071091 )
10081092
@@ -1016,8 +1100,6 @@ def k(self):
10161100 @cached_property
10171101 def _GF2MulViaKaratsubamod_impl (self ) -> Bloq :
10181102 impl = _GF2MulViaKaratsubaImpl (self .m_x )
1019- if self .uncompute :
1020- return impl .adjoint ()
10211103 return impl
10221104
10231105 def build_composite_bloq (
@@ -1026,17 +1108,10 @@ def build_composite_bloq(
10261108 if is_symbolic (self .k , self .n ):
10271109 raise DecomposeTypeError (f"Symbolic Decomposition is not supported for { self } " )
10281110
1029- if self .uncompute :
1030- result = soqs ['result' ]
1031- else :
1032- result = bb .allocate (self .n , self .qgf )
1111+ result = bb .allocate (self .n , self .qgf )
10331112
10341113 x , y , result = bb .add_from (self ._GF2MulViaKaratsubamod_impl , f = x , g = y , h = result )
10351114
1036- if self .uncompute :
1037- bb .free (result ) # type: ignore[arg-type]
1038- return {'x' : x , 'y' : y }
1039-
10401115 return {'x' : x , 'y' : y , 'result' : result }
10411116
10421117 def build_call_graph (
@@ -1069,11 +1144,11 @@ def on_classical_vals(
10691144 ) -> Dict [str , 'ClassicalValT' ]:
10701145 assert isinstance (x , self .gf )
10711146 assert isinstance (y , self .gf )
1072- if self .uncompute :
1073- assert x * y == result
1074- return {'x' : x , 'y' : y }
10751147 return {'x' : x , 'y' : y , 'result' : x * y }
10761148
1149+ def adjoint (self ) -> 'Bloq' :
1150+ return GF2MulMBUC (self .qgf )
1151+
10771152
10781153@bloq_example
10791154def _gf2mulviakaratsuba () -> GF2MulViaKaratsuba :
0 commit comments