Skip to content

Commit e10560c

Browse files
Fix bug in KaliskiStep3 and add tests for all steps (#1496)
* Fix bug in KaliskiStep3 and add tests for all steps * cost
1 parent 07c98b6 commit e10560c

4 files changed

Lines changed: 99 additions & 27 deletions

File tree

qualtran/bloqs/arithmetic/comparison.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1462,20 +1462,17 @@ def on_classical_vals(
14621462
c: Optional['ClassicalValT'] = None,
14631463
target: Optional['ClassicalValT'] = None,
14641464
) -> Dict[str, 'ClassicalValT']:
1465+
if self._op_symbol in ('>', '<='):
1466+
c_val = add_ints(-int(a), int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
1467+
else:
1468+
c_val = add_ints(int(a), -int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
14651469
if self.uncompute:
1466-
assert c == add_ints(
1467-
int(a),
1468-
int(b),
1469-
num_bits=int(self.dtype.bitsize),
1470-
is_signed=isinstance(self.dtype, QInt),
1471-
)
1470+
assert c == c_val
14721471
assert target == self._classical_comparison(a, b)
14731472
return {'a': a, 'b': b}
1474-
if self._op_symbol in ('>', '<='):
1475-
c = add_ints(-int(a), int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
1476-
else:
1477-
c = add_ints(int(a), -int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
1478-
return {'a': a, 'b': b, 'c': c, 'target': int(self._classical_comparison(a, b))}
1473+
assert c is None
1474+
assert target is None
1475+
return {'a': a, 'b': b, 'c': c_val, 'target': int(self._classical_comparison(a, b))}
14791476

14801477
def _compute(self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet') -> Dict[str, 'SoquetT']:
14811478
if self._op_symbol in ('>', '<='):

qualtran/bloqs/factoring/ecc/ec_add_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def test_ec_add_symbolic_cost():
418418
# toffoli cost for Kaliski Mod Inverse, n extra toffolis in ModNeg, 2n extra toffolis to do n
419419
# 3-controlled toffolis in step 2. The expression is written with rationals because sympy
420420
# comparison fails with floats.
421-
assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(391, 2) * n - 31
421+
assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(407, 2) * n - 31
422422

423423

424424
def test_ec_add(bloq_autotester):

qualtran/bloqs/mod_arithmetic/mod_division.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,6 @@ def signature(self) -> 'Signature':
7272
def on_classical_vals(
7373
self, v: int, m: int, f: int, is_terminal: int
7474
) -> Dict[str, 'ClassicalValT']:
75-
print('here')
76-
assert False
7775
m ^= f & (v == 0)
7876
assert is_terminal == 0
7977
is_terminal ^= m
@@ -101,10 +99,10 @@ def build_composite_bloq(
10199

102100
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
103101
if is_symbolic(self.bitsize):
104-
cvs: Union[HasLength, List[int]] = HasLength(self.bitsize)
102+
cvs: Union[HasLength, List[int]] = HasLength(self.bitsize + 1)
105103
else:
106-
cvs = [0] * int(self.bitsize)
107-
return {MultiAnd(cvs=cvs): 1, MultiAnd(cvs=cvs).adjoint(): 1, CNOT(): 2}
104+
cvs = [0] * int(self.bitsize) + [1]
105+
return {MultiAnd(cvs=cvs): 1, MultiAnd(cvs=cvs).adjoint(): 1, CNOT(): 3}
108106

109107

110108
@frozen
@@ -197,25 +195,25 @@ def on_classical_vals(
197195
def build_composite_bloq(
198196
self, bb: 'BloqBuilder', u: Soquet, v: Soquet, b: Soquet, a: Soquet, m: Soquet, f: Soquet
199197
) -> Dict[str, 'SoquetT']:
200-
u, v, junk, greater_than = bb.add(
198+
u, v, junk_c, greater_than = bb.add(
201199
LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)), a=u, b=v
202200
)
203201

204-
(greater_than, f, b), junk, ctrl = bb.add(
202+
(greater_than, f, b), junk_m, ctrl = bb.add(
205203
MultiAnd(cvs=(1, 1, 0)), ctrl=(greater_than, f, b)
206204
)
207205

208206
ctrl, a = bb.add(CNOT(), ctrl=ctrl, target=a)
209207
ctrl, m = bb.add(CNOT(), ctrl=ctrl, target=m)
210208

211209
greater_than, f, b = bb.add(
212-
MultiAnd(cvs=(1, 1, 0)).adjoint(), ctrl=(greater_than, f, b), junk=junk, target=ctrl
210+
MultiAnd(cvs=(1, 1, 0)).adjoint(), ctrl=(greater_than, f, b), junk=junk_m, target=ctrl
213211
)
214212
u, v = bb.add(
215213
LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)).adjoint(),
216214
a=u,
217215
b=v,
218-
c=junk,
216+
c=junk_c,
219217
target=greater_than,
220218
)
221219
return {'u': u, 'v': v, 'b': b, 'a': a, 'm': m, 'f': f}
@@ -391,7 +389,7 @@ def build_composite_bloq(
391389

392390
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
393391
return {
394-
CNOT(): 4,
392+
CNOT(): 3,
395393
XGate(): 2,
396394
ModDbl(QMontgomeryUInt(self.bitsize), self.mod): 1,
397395
CSwapApprox(self.bitsize): 2,
@@ -475,7 +473,7 @@ def on_classical_vals(
475473
of `f` and `m`.
476474
"""
477475
assert m == 0
478-
is_terminal = f == 1 and v == 0
476+
is_terminal = int(f == 1 and v == 0)
479477
if f == 0:
480478
# When `f = 0` this means that the algorithm is nearly over and that we just need to
481479
# double the value of `r`.
@@ -489,7 +487,8 @@ def on_classical_vals(
489487
f = 0
490488
r = (r << 1) % self.mod
491489
else:
492-
m = (u % 2 == 1) & (v % 2 == 0)
490+
m = ((u % 2 == 1) & (v % 2 == 0)) or (u % 2 == 1 and v % 2 == 1 and u > v)
491+
m = int(m)
493492
# Kaliski iteration as described in Fig7 of https://arxiv.org/pdf/2001.09580.
494493
swap = (u % 2 == 0 and v % 2 == 1) or (u % 2 == 1 and v % 2 == 1 and u > v)
495494
if swap:

qualtran/bloqs/mod_arithmetic/mod_division_test.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import qualtran.testing as qlt_testing
2121
from qualtran import QMontgomeryUInt
22+
from qualtran.bloqs.mod_arithmetic import mod_division
2223
from qualtran.bloqs.mod_arithmetic.mod_division import _kaliskimodinverse_example, KaliskiModInverse
2324
from qualtran.resource_counting import get_cost_value, QECGatesCost
2425
from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join
@@ -36,7 +37,7 @@ def test_kaliski_mod_inverse_classical_action(bitsize, mod):
3637
continue
3738
x_montgomery = dtype.uint_to_montgomery(x, mod)
3839
res = blq.call_classically(x=x_montgomery)
39-
print(x, x_montgomery)
40+
4041
assert res == cblq.call_classically(x=x_montgomery)
4142
assert len(res) == 2
4243
assert res[0] == dtype.montgomery_inverse(x_montgomery, mod)
@@ -85,11 +86,11 @@ def test_kaliski_symbolic_cost():
8586
# construction this is just $n-1$ (BitwiseNot -> Add(p+1)).
8687
# - The cost of an iteration in Litinski $13n$ since they ignore constants.
8788
# Our construction is exactly the same but we also count the constants
88-
# which amout to $3$. for a total cost of $13n + 3$.
89+
# which amout to $3$. for a total cost of $13n + 4$.
8990
# For example the cost of ModDbl is 2n+1. In their figure 8, they report
9091
# it as just $2n$. ModDbl gets executed within the 2n loop so its contribution
9192
# to the overal cost should be 4n^2 + 2n instead of just 4n^2.
92-
assert total_toff == 26 * n**2 + 7 * n - 1
93+
assert total_toff == 26 * n**2 + 9 * n - 1
9394

9495

9596
def test_kaliskimodinverse_example(bloq_autotester):
@@ -99,3 +100,78 @@ def test_kaliskimodinverse_example(bloq_autotester):
99100
@pytest.mark.notebook
100101
def test_notebook():
101102
qlt_testing.execute_notebook('mod_division')
103+
104+
105+
def test_kaliski_iteration_decomposition():
106+
mod = 7
107+
bitsize = 5
108+
b = mod_division._KaliskiIteration(bitsize, mod)
109+
cb = b.decompose_bloq()
110+
for x in range(mod):
111+
u = mod
112+
v = x
113+
r = 0
114+
s = 1
115+
f = 1
116+
117+
for _ in range(2 * bitsize):
118+
inputs = {'u': u, 'v': v, 'r': r, 's': s, 'm': 0, 'f': f, 'is_terminal': 0}
119+
res = b.call_classically(**inputs)
120+
assert res == cb.call_classically(**inputs), f'{inputs=}'
121+
u, v, r, s, _, f, _ = res # type: ignore
122+
123+
qlt_testing.assert_valid_bloq_decomposition(b)
124+
qlt_testing.assert_equivalent_bloq_counts(b, generalizer=(ignore_alloc_free, ignore_split_join))
125+
126+
127+
def test_kaliski_steps():
128+
bitsize = 5
129+
mod = 7
130+
steps = [
131+
mod_division._KaliskiIterationStep1(bitsize),
132+
mod_division._KaliskiIterationStep2(bitsize),
133+
mod_division._KaliskiIterationStep3(bitsize),
134+
mod_division._KaliskiIterationStep4(bitsize),
135+
mod_division._KaliskiIterationStep5(bitsize),
136+
mod_division._KaliskiIterationStep6(bitsize, mod),
137+
]
138+
csteps = [b.decompose_bloq() for b in steps]
139+
140+
# check decomposition is valid.
141+
for step in steps:
142+
qlt_testing.assert_valid_bloq_decomposition(step)
143+
qlt_testing.assert_equivalent_bloq_counts(
144+
step, generalizer=(ignore_alloc_free, ignore_split_join)
145+
)
146+
147+
# check that for all inputs all 2n iteration work when excuted directly on the 6 steps
148+
# and their decompositions.
149+
for x in range(mod):
150+
u, v, r, s, f = mod, x, 0, 1, 1
151+
152+
for _ in range(2 * bitsize):
153+
a = b = m = is_terminal = 0
154+
155+
res = steps[0].call_classically(v=v, m=m, f=f, is_terminal=is_terminal)
156+
assert res == csteps[0].call_classically(v=v, m=m, f=f, is_terminal=is_terminal)
157+
v, m, f, is_terminal = res # type: ignore
158+
159+
res = steps[1].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
160+
assert res == csteps[1].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
161+
u, v, b, a, m, f = res # type: ignore
162+
163+
res = steps[2].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
164+
assert res == csteps[2].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
165+
u, v, b, a, m, f = res # type: ignore
166+
167+
res = steps[3].call_classically(u=u, v=v, r=r, s=s, a=a)
168+
assert res == csteps[3].call_classically(u=u, v=v, r=r, s=s, a=a)
169+
u, v, r, s, a = res # type: ignore
170+
171+
res = steps[4].call_classically(u=u, v=v, r=r, s=s, b=b, f=f)
172+
assert res == csteps[4].call_classically(u=u, v=v, r=r, s=s, b=b, f=f)
173+
u, v, r, s, b, f = res # type: ignore
174+
175+
res = steps[5].call_classically(u=u, v=v, r=r, s=s, b=b, a=a, m=m, f=f)
176+
assert res == csteps[5].call_classically(u=u, v=v, r=r, s=s, b=b, a=a, m=m, f=f)
177+
u, v, r, s, b, a, m, f = res # type: ignore

0 commit comments

Comments
 (0)