1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import warnings
1415from functools import cached_property
15- from typing import Dict , Set , TYPE_CHECKING , Union
16+ from typing import cast , Dict , Set , TYPE_CHECKING , Union
1617
1718import attrs
1819import numpy as np
1920
21+ import qualtran .bloqs .gf_arithmetic .gf_utils as gf_utils
2022from qualtran import (
2123 Bloq ,
2224 bloq_example ,
2830 Signature ,
2931)
3032from qualtran .bloqs .gf_arithmetic .gf2_addition import GF2Addition
31- from qualtran .bloqs .gf_arithmetic .gf2_multiplication import GF2MulViaKaratsuba
33+ from qualtran .bloqs .gf_arithmetic .gf2_multiplication import GF2MulViaKaratsuba , SynthesizeLRCircuit
3234from qualtran .bloqs .gf_arithmetic .gf2_square import GF2Square
3335from qualtran .resource_counting .generalizers import ignore_alloc_free , ignore_split_join
34- from qualtran .symbolics import bit_length , ceil , is_symbolic , log2 , SymbolicInt
36+ from qualtran .symbolics import bit_length , is_symbolic , Shaped , SymbolicInt
3537
3638if TYPE_CHECKING :
3739 from qualtran import BloqBuilder , Soquet , SoquetT
@@ -69,8 +71,7 @@ class GF2Inverse(Bloq):
6971 where $B_1 = x$ and $B_{i+j} = B_i B_j^{2^i}$.
7072
7173 Args:
72- bitsize: The degree $m$ of the galois field $GF(2^m)$. Also corresponds to the number of
73- qubits in the input register whose inverse should be calculated.
74+ qgf: QGF type of the registers.
7475
7576 Registers:
7677 x: Input THRU register of size $m$ that stores elements from $GF(2^m)$.
@@ -84,9 +85,26 @@ class GF2Inverse(Bloq):
8485
8586 [Structure of parallel multipliers for a class of fields GF(2^m)](https://doi.org/10.1016/0890-5401(89)90045-X).
8687 Itoh and Tsujii. 1989.
88+
89+ [Concrete quantum cryptanalysisof binary elliptic curves](https://tches.iacr.org/index.php/TCHES/article/view/8741/8341)
90+ Algorithm 2.
8791 """
8892
89- bitsize : SymbolicInt
93+ qgf : QGF = attrs .field (converter = gf_utils .qgf_converter )
94+
95+ def __init__ (self , qgf = None , bitsize = None ):
96+ if not ((qgf is None ) ^ (bitsize is None )):
97+ raise TypeError ("Exactly one of `qgf` or `bitsize` should be specified." )
98+ if qgf is not None :
99+ qgf = gf_utils .qgf_converter (qgf )
100+ if bitsize is not None :
101+ warnings .warn (
102+ "The `bitsize` attribute is deprecated. Use `qgf` instead." ,
103+ DeprecationWarning ,
104+ stacklevel = 2 ,
105+ )
106+ qgf = gf_utils .qgf_converter (bitsize )
107+ object .__setattr__ (self , 'qgf' , qgf )
90108
91109 @cached_property
92110 def signature (self ) -> 'Signature' :
@@ -104,12 +122,14 @@ def signature(self) -> 'Signature':
104122 )
105123
106124 @cached_property
107- def qgf (self ) -> QGF :
108- return QGF ( characteristic = 2 , degree = self .bitsize )
125+ def bitsize (self ) -> SymbolicInt :
126+ return self .qgf . degree
109127
110128 @cached_property
111129 def n_junk_regs (self ) -> SymbolicInt :
112- return 2 * bit_length (self .bitsize - 1 ) + self .bitsize_hamming_weight - 3
130+ if is_symbolic (self .bitsize ):
131+ return 2 * bit_length (self .bitsize - 1 ) - 2
132+ return bit_length (self .bitsize - 1 ) - 2 + max (self .bitsize_hamming_weight - 1 , 1 )
113133
114134 @cached_property
115135 def bitsize_hamming_weight (self ) -> SymbolicInt :
@@ -129,90 +149,110 @@ def my_static_costs(self, cost_key: 'CostKey'):
129149
130150 return NotImplemented
131151
152+ @cached_property
153+ def _bits (self ) -> list [int ]:
154+ k1 = bit_length (self .bitsize - 1 ) - 1
155+ return [- 1 ] + [k1 - i for i , b in enumerate (np .binary_repr (self .bitsize - 1 )) if b == '1' ]
156+
157+ @cached_property
158+ def gf2_multiplier (self ) -> Bloq :
159+ return GF2MulViaKaratsuba (self .qgf )
160+
132161 def build_composite_bloq (self , bb : 'BloqBuilder' , * , x : 'Soquet' ) -> Dict [str , 'SoquetT' ]:
133162 if is_symbolic (self .bitsize ):
134163 raise DecomposeTypeError (f"Cannot decompose symbolic { self } " )
135164
136- result = bb .allocate (dtype = self .qgf )
137165 if self .bitsize == 1 :
138- x , result = bb .add (GF2Addition (self .bitsize ), x = x , y = result )
166+ result = bb .allocate (dtype = self .qgf )
167+ x , result = bb .add (GF2Addition (self .qgf ), x = x , y = result )
139168 return {'x' : x , 'result' : result }
140169
141- junk = []
142- beta = x
143- is_first = True
144- bitsize_minus_one = int (self .bitsize - 1 )
145- n_iters = bitsize_minus_one .bit_length ()
146- for i in range (n_iters ):
147- if (1 << i ) & bitsize_minus_one :
148- if is_first :
149- beta , result = bb .add (GF2Addition (self .bitsize ), x = beta , y = result )
150- is_first = False
151- else :
152- for j in range (2 ** i ):
153- result = bb .add (GF2Square (self .bitsize ), x = result )
154- beta , result , new_result = bb .add (
155- GF2MulViaKaratsuba (self .bitsize ), x = beta , y = result
156- )
157- junk .append (result )
158- result = new_result
159- if i != n_iters - 1 :
160- beta_squared = bb .allocate (dtype = self .qgf )
161- beta , beta_squared = bb .add (GF2Addition (self .bitsize ), x = beta , y = beta_squared )
162- for j in range (2 ** i ):
163- beta_squared = bb .add (GF2Square (self .bitsize ), x = beta_squared )
164- beta , beta_squared , beta_new = bb .add (
165- GF2MulViaKaratsuba (self .bitsize ), x = beta , y = beta_squared
166- )
167- junk .extend ([beta , beta_squared ])
168- beta = beta_new
169- junk .append (beta )
170- result = bb .add (GF2Square (self .bitsize ), x = result )
171- x = junk .pop (0 )
172- assert len (junk ) == self .n_junk_regs , f'{ len (junk )= } , { self .n_junk_regs = } '
173- return {'x' : x , 'result' : result , 'junk' : np .array (junk )}
170+ t = (self .bitsize - 1 ).bit_count ()
171+ k1 = bit_length (self .bitsize - 1 ) - 1
172+ k = max (k1 + t - 1 , k1 + 1 )
173+ f = [x ] + [None ] * k
174+ f [k ] = bb .allocate (self .bitsize , self .qgf )
175+ f = cast (list ['Soquet' ], f )
176+ for i in range (1 , k1 + 1 ):
177+ f [i - 1 ], f [k ] = bb .add (GF2Addition (self .qgf ), x = f [i - 1 ], y = f [k ])
178+ f [k ] = bb .add (GF2Square (self .qgf , 2 ** (i - 1 )), x = f [k ])
179+ f [i - 1 ], f [k ], f [i ] = bb .add (self .gf2_multiplier , x = f [i - 1 ], y = f [k ])
180+ f [k ] = bb .add (GF2Square (self .qgf , 2 ** (i - 1 )).adjoint (), x = f [k ])
181+ f [i - 1 ], f [k ] = bb .add (GF2Addition (self .qgf ), x = f [i - 1 ], y = f [k ])
182+ bits = self ._bits
183+ if k1 + t - 1 == k :
184+ bb .free (f [k ])
185+ for s in range (1 , t ):
186+ f [k1 + s - 1 ] = bb .add (GF2Square (self .qgf , 2 ** bits [s + 1 ]), x = f [k1 + s - 1 ])
187+ f [k1 + s - 1 ], f [bits [s + 1 ]], f [k1 + s ] = bb .add (
188+ self .gf2_multiplier , x = f [k1 + s - 1 ], y = f [bits [s + 1 ]]
189+ )
190+
191+ if t == 1 :
192+ if k1 == 0 :
193+ assert self .bitsize == 2
194+ f [0 ], f [k ] = bb .add (GF2Addition (self .qgf ), x = f [0 ], y = f [k ])
195+ f [k1 ], f [k ] = f [k ], f [k1 ]
196+
197+ f [k ] = bb .add (GF2Square (qgf = self .qgf ), x = f [k ])
198+
199+ return {'x' : f [0 ], 'result' : f [k ], 'junk' : np .array (f [1 :k ])}
174200
175201 def build_call_graph (
176202 self , ssa : 'SympySymbolAllocator'
177203 ) -> Union ['BloqCountDictT' , Set ['BloqCountT' ]]:
178204 if not is_symbolic (self .bitsize ) and self .bitsize == 1 :
179- return {GF2Addition (self .bitsize ): 1 }
180- square_count = self .bitsize + 2 ** ceil (log2 (self .bitsize )) - 1
181- if not is_symbolic (self .bitsize ):
182- n = self .bitsize - 1
183- square_count -= n & (- n )
184- square_count -= 1 << (n .bit_length () - 1 )
185- mul_count = ceil (log2 (self .bitsize )) + self .bitsize_hamming_weight - 2
186- return {
187- GF2Addition (self .bitsize ): ceil (log2 (self .bitsize )),
188- GF2Square (self .bitsize ): square_count ,
189- } | ({GF2MulViaKaratsuba (self .bitsize ): mul_count } if mul_count else {})
205+ return {GF2Addition (self .qgf ): 1 }
206+ k1 = bit_length (self .bitsize - 1 ) - 1
207+ if is_symbolic (self .bitsize ):
208+ t = bit_length (self .bitsize - 1 )
209+ return {
210+ GF2Addition (self .qgf ): 2 * k1 ,
211+ self .gf2_multiplier : k1 + t - 1 ,
212+ SynthesizeLRCircuit (Shaped ((self .bitsize , self .bitsize ))): 2 * k1 + t ,
213+ }
214+
215+ t = (self .bitsize - 1 ).bit_count ()
216+ bloq_counts : dict [Bloq , int ] = (
217+ {GF2Square (self .qgf , 2 ** (i - 1 )): 1 for i in range (2 , k1 + 1 )}
218+ | {GF2Square (self .qgf , 2 ** (i - 1 )).adjoint (): 1 for i in range (1 , k1 + 1 )}
219+ | {GF2Square (self .qgf ): 1 + (k1 > 0 )}
220+ )
221+
222+ for i in self ._bits [2 :]:
223+ s = GF2Square (self .qgf , 2 ** i )
224+ bloq_counts [s ] = bloq_counts .get (s , 0 ) + 1
225+ mul_count = k1 + t - 1
226+ if mul_count :
227+ bloq_counts [self .gf2_multiplier ] = mul_count
228+ add_count = 2 * k1 + (self .bitsize == 2 )
229+ if add_count :
230+ bloq_counts [GF2Addition (self .qgf )] = add_count
231+ return bloq_counts
190232
191233 def on_classical_vals (self , * , x ) -> Dict [str , 'ClassicalValT' ]:
192234 assert isinstance (x , self .qgf .gf_type )
193- junk = []
194- bitsize_minus_one = int (self .bitsize - 1 )
195- beta = x
196- result = self .qgf .gf_type (0 )
197- is_first = True
198- for i in range (bitsize_minus_one .bit_length ()):
199- if (1 << i ) & bitsize_minus_one :
200- if is_first :
201- is_first = False
202- result = beta
203- else :
204- for j in range (2 ** i ):
205- result = result ** 2
206- junk .append (result )
207- result = result * beta
208- if i != bitsize_minus_one .bit_length () - 1 :
209- beta_squared = beta ** (2 ** (2 ** i ))
210- junk .extend ([beta , beta_squared ])
211- beta = beta * beta_squared
212- junk .append (beta )
213- assert x == junk [0 ]
214- junk = junk [1 :]
215- return {'x' : x , 'result' : x ** (- 1 ) if x else self .qgf .gf_type (0 ), 'junk' : np .array (junk )}
235+ t = (self .bitsize - 1 ).bit_count ()
236+ k1 = bit_length (self .bitsize - 1 ) - 1
237+ k = max (k1 + t - 1 , k1 + 1 )
238+ f = [x ] + [0 ] * k
239+ for i in range (1 , k1 + 1 ):
240+ f [i ] = f [i - 1 ] ** (2 ** (2 ** (i - 1 )) + 1 )
241+ bits = self ._bits
242+ for s in range (1 , t ):
243+ f [k1 + s - 1 ] = f [k1 + s - 1 ] ** (2 ** (2 ** bits [s + 1 ]))
244+ f [k1 + s ] = f [k1 + s - 1 ] * f [bits [s + 1 ]]
245+
246+ if t == 1 :
247+ f [k1 ], f [k ] = f [k ], f [k1 ]
248+
249+ f [k ] = f [k ] ** 2
250+
251+ return {
252+ 'x' : x ,
253+ 'result' : x ** (- 1 ) if x else self .qgf .gf_type (0 ),
254+ 'junk' : np .array (f [1 :- 1 ]),
255+ }
216256
217257
218258@bloq_example (generalizer = [ignore_split_join , ignore_alloc_free ])
0 commit comments