Skip to content

Commit d240ed2

Browse files
Add a test util for classical action and refactor factoring/mod_mul (#1339)
1 parent f2eabb2 commit d240ed2

21 files changed

Lines changed: 697 additions & 509 deletions

dev_tools/autogenerate-bloqs-notebooks-v2.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -511,9 +511,11 @@
511511
),
512512
NotebookSpecV2(
513513
title='Modular Multiplication',
514-
module=qualtran.bloqs.factoring.mod_mul,
515-
bloq_specs=[qualtran.bloqs.factoring.mod_mul._MODMUL_DOC],
516-
directory=f'{SOURCE_DIR}/bloqs/factoring',
514+
module=qualtran.bloqs.mod_arithmetic.mod_multiplication,
515+
bloq_specs=[
516+
qualtran.bloqs.mod_arithmetic.mod_multiplication._MOD_DBL_DOC,
517+
qualtran.bloqs.mod_arithmetic.mod_multiplication._C_MOD_MUL_K_DOC,
518+
],
517519
),
518520
NotebookSpecV2(
519521
title='Modular Exponentiation',

docs/bloqs/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ Bloqs Library
8181

8282
mod_arithmetic/mod_addition.ipynb
8383
mod_arithmetic/mod_subtraction.ipynb
84-
factoring/mod_mul.ipynb
84+
mod_arithmetic/mod_multiplication.ipynb
8585
factoring/mod_exp.ipynb
8686
factoring/ecc/ec_add.ipynb
8787
factoring/ecc/ecc.ipynb

qualtran/bloqs/arithmetic/bitwise.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from functools import cached_property
15-
from typing import Optional, Sequence, TYPE_CHECKING
15+
from typing import Dict, Optional, Sequence, TYPE_CHECKING
1616

1717
import numpy as np
1818
import sympy
@@ -26,6 +26,7 @@
2626
DecomposeTypeError,
2727
QAny,
2828
QDType,
29+
QMontgomeryUInt,
2930
QUInt,
3031
Register,
3132
Signature,
@@ -221,6 +222,12 @@ def wire_symbol(
221222

222223
return TextBox("~x")
223224

225+
def on_classical_vals(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
226+
x = -x - 1
227+
if isinstance(self.dtype, (QUInt, QMontgomeryUInt)):
228+
x %= 2**self.dtype.bitsize
229+
return {'x': x}
230+
224231

225232
@bloq_example
226233
def _bitwise_not() -> BitwiseNot:

qualtran/bloqs/arithmetic/bitwise_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
import numpy as np
1616
import pytest
1717

18-
from qualtran import BloqBuilder, QAny, QUInt
18+
import qualtran.testing as qlt_testing
19+
from qualtran import BloqBuilder, QAny, QInt, QMontgomeryUInt, QUInt
1920
from qualtran.bloqs.arithmetic.bitwise import (
2021
_bitwise_not,
2122
_bitwise_not_symb,
@@ -172,3 +173,14 @@ def test_bitwise_not_diagram():
172173
x3: ───~x───
173174
''',
174175
)
176+
177+
178+
@pytest.mark.parametrize('dtype', [QUInt, QMontgomeryUInt, QInt])
179+
@pytest.mark.parametrize('bitsize', range(2, 6))
180+
def test_bitwisenot_classical_action(dtype, bitsize):
181+
b = BitwiseNot(dtype(bitsize))
182+
if dtype is QInt:
183+
valid_range = range(-(2 ** (bitsize - 1)), 2 ** (bitsize - 1))
184+
else:
185+
valid_range = range(2**bitsize)
186+
qlt_testing.assert_consistent_classical_action(b, x=valid_range)

qualtran/bloqs/factoring/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,3 @@
1313
# limitations under the License.
1414

1515
from .mod_exp import ModExp
16-
from .mod_mul import CtrlModMul

qualtran/bloqs/factoring/ecc/ec_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
7474
(CModSub(QUInt(self.n), mod=self.mod), 4),
7575
(ModNeg(QUInt(self.n), mod=self.mod), 2),
7676
(CModNeg(QUInt(self.n), mod=self.mod), 1),
77-
(ModDbl(n=self.n, mod=self.mod), 2),
77+
(ModDbl(QUInt(self.n), mod=self.mod), 2),
7878
(ModMul(n=self.n, mod=self.mod), 10),
7979
(ModInv(n=self.n, mod=self.mod), 4),
8080
}

qualtran/bloqs/factoring/mod_exp.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import math
1415
from functools import cached_property
1516
from typing import Dict, Optional, Set, Tuple, Union
1617

@@ -33,7 +34,7 @@
3334
SoquetT,
3435
)
3536
from qualtran.bloqs.basic_gates import IntState
36-
from qualtran.bloqs.factoring.mod_mul import CtrlModMul
37+
from qualtran.bloqs.mod_arithmetic import CModMulK
3738
from qualtran.drawing import Text, WireSymbol
3839
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
3940
from qualtran.resource_counting.generalizers import ignore_split_join
@@ -69,6 +70,10 @@ class ModExp(Bloq):
6970
exp_bitsize: Union[int, sympy.Expr]
7071
x_bitsize: Union[int, sympy.Expr]
7172

73+
def __post_init__(self):
74+
if isinstance(self.base, int) and isinstance(self.mod, int):
75+
assert math.gcd(self.base, self.mod) == 1
76+
7277
@cached_property
7378
def signature(self) -> 'Signature':
7479
return Signature(
@@ -96,8 +101,8 @@ def make_for_shor(cls, big_n: int, g: Optional[int] = None):
96101
return cls(base=g, mod=big_n, exp_bitsize=2 * little_n, x_bitsize=little_n)
97102

98103
def _CtrlModMul(self, k: Union[int, sympy.Expr]):
99-
"""Helper method to return a `CtrlModMul` with attributes forwarded."""
100-
return CtrlModMul(k=k, bitsize=self.x_bitsize, mod=self.mod)
104+
"""Helper method to return a `CModMulK` with attributes forwarded."""
105+
return CModMulK(QUInt(self.x_bitsize), k=k, mod=self.mod)
101106

102107
def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[str, 'SoquetT']:
103108
if isinstance(self.exp_bitsize, sympy.Expr):
@@ -135,7 +140,7 @@ def wire_symbol(
135140

136141

137142
def _generalize_k(b: Bloq) -> Optional[Bloq]:
138-
if isinstance(b, CtrlModMul):
143+
if isinstance(b, CModMulK):
139144
return attrs.evolve(b, k=_K)
140145

141146
return b

qualtran/bloqs/factoring/mod_exp_test.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,25 @@
2222
from qualtran import Bloq
2323
from qualtran.bloqs.bookkeeping import Join, Split
2424
from qualtran.bloqs.factoring.mod_exp import _modexp, _modexp_symb, ModExp
25-
from qualtran.bloqs.factoring.mod_mul import CtrlModMul
25+
from qualtran.bloqs.mod_arithmetic import CModMulK
2626
from qualtran.drawing import Text
2727
from qualtran.resource_counting import SympySymbolAllocator
2828
from qualtran.testing import execute_notebook
2929

3030

31+
# TODO: Fix ModExp and improve this test
3132
def test_mod_exp_consistent_classical():
3233
rs = np.random.RandomState(52)
3334

3435
# 100 random attribute choices.
3536
for _ in range(100):
3637
# Sample moduli in a range. Set x_bitsize=n big enough to fit.
37-
mod = rs.randint(4, 123)
38+
mod = 7 * 13
3839
n = int(np.ceil(np.log2(mod)))
39-
n = rs.randint(n, n + 10)
4040

4141
# Choose an exponent in a range. Set exp_bitsize=ne bit enough to fit.
42-
exponent = rs.randint(1, 20)
43-
ne = int(np.ceil(np.log2(exponent)))
44-
ne = rs.randint(ne, ne + 10)
42+
exponent = rs.randint(1, 2**n)
43+
ne = 2 * n
4544

4645
# Choose a base smaller than mod.
4746
base = rs.randint(1, mod)
@@ -59,30 +58,30 @@ def test_modexp_symb_manual():
5958
counts = modexp.bloq_counts()
6059
counts_by_bloq = {bloq.pretty_name(): n for bloq, n in counts.items()}
6160
assert counts_by_bloq['|1>'] == 1
62-
assert counts_by_bloq['CtrlModMul'] == n_e
61+
assert counts_by_bloq['CModMulK'] == n_e
6362

6463
b, x = modexp.call_classically(exponent=sympy.Symbol('b'))
6564
assert str(x) == 'Mod(g**b, N)'
6665

6766

6867
def test_mod_exp_consistent_counts():
6968
bloq = ModExp(base=8, exp_bitsize=3, x_bitsize=10, mod=50)
69+
7070
counts1 = bloq.bloq_counts()
7171

7272
ssa = SympySymbolAllocator()
7373
my_k = ssa.new_symbol('k')
7474

7575
def generalize(b: Bloq) -> Optional[Bloq]:
76-
if isinstance(b, CtrlModMul):
77-
# Symbolic k in `CtrlModMul`.
76+
if isinstance(b, CModMulK):
77+
# Symbolic k in `CModMulK`.
7878
return attrs.evolve(b, k=my_k)
7979
if isinstance(b, (Split, Join)):
8080
# Ignore these
8181
return None
8282
return b
8383

8484
counts2 = bloq.decompose_bloq().bloq_counts(generalizer=generalize)
85-
8685
assert counts1 == counts2
8786

8887

qualtran/bloqs/factoring/mod_mul.ipynb

Lines changed: 0 additions & 189 deletions
This file was deleted.

0 commit comments

Comments
 (0)