1919
2020import qualtran .testing as qlt_testing
2121from qualtran import QMontgomeryUInt
22+ from qualtran .bloqs .mod_arithmetic import mod_division
2223from qualtran .bloqs .mod_arithmetic .mod_division import _kaliskimodinverse_example , KaliskiModInverse
2324from qualtran .resource_counting import get_cost_value , QECGatesCost
2425from qualtran .resource_counting .generalizers import ignore_alloc_free , ignore_split_join
@@ -36,7 +37,7 @@ def test_kaliski_mod_inverse_classical_action(bitsize, mod):
3637 continue
3738 x_montgomery = dtype .uint_to_montgomery (x , mod )
3839 res = blq .call_classically (x = x_montgomery )
39- print ( x , x_montgomery )
40+
4041 assert res == cblq .call_classically (x = x_montgomery )
4142 assert len (res ) == 2
4243 assert res [0 ] == dtype .montgomery_inverse (x_montgomery , mod )
@@ -85,11 +86,11 @@ def test_kaliski_symbolic_cost():
8586 # construction this is just $n-1$ (BitwiseNot -> Add(p+1)).
8687 # - The cost of an iteration in Litinski $13n$ since they ignore constants.
8788 # Our construction is exactly the same but we also count the constants
88- # which amout to $3$. for a total cost of $13n + 3 $.
89+ # which amout to $3$. for a total cost of $13n + 4 $.
8990 # For example the cost of ModDbl is 2n+1. In their figure 8, they report
9091 # it as just $2n$. ModDbl gets executed within the 2n loop so its contribution
9192 # to the overal cost should be 4n^2 + 2n instead of just 4n^2.
92- assert total_toff == 26 * n ** 2 + 7 * n - 1
93+ assert total_toff == 26 * n ** 2 + 9 * n - 1
9394
9495
9596def test_kaliskimodinverse_example (bloq_autotester ):
@@ -99,3 +100,78 @@ def test_kaliskimodinverse_example(bloq_autotester):
99100@pytest .mark .notebook
100101def test_notebook ():
101102 qlt_testing .execute_notebook ('mod_division' )
103+
104+
105+ def test_kaliski_iteration_decomposition ():
106+ mod = 7
107+ bitsize = 5
108+ b = mod_division ._KaliskiIteration (bitsize , mod )
109+ cb = b .decompose_bloq ()
110+ for x in range (mod ):
111+ u = mod
112+ v = x
113+ r = 0
114+ s = 1
115+ f = 1
116+
117+ for _ in range (2 * bitsize ):
118+ inputs = {'u' : u , 'v' : v , 'r' : r , 's' : s , 'm' : 0 , 'f' : f , 'is_terminal' : 0 }
119+ res = b .call_classically (** inputs )
120+ assert res == cb .call_classically (** inputs ), f'{ inputs = } '
121+ u , v , r , s , _ , f , _ = res # type: ignore
122+
123+ qlt_testing .assert_valid_bloq_decomposition (b )
124+ qlt_testing .assert_equivalent_bloq_counts (b , generalizer = (ignore_alloc_free , ignore_split_join ))
125+
126+
127+ def test_kaliski_steps ():
128+ bitsize = 5
129+ mod = 7
130+ steps = [
131+ mod_division ._KaliskiIterationStep1 (bitsize ),
132+ mod_division ._KaliskiIterationStep2 (bitsize ),
133+ mod_division ._KaliskiIterationStep3 (bitsize ),
134+ mod_division ._KaliskiIterationStep4 (bitsize ),
135+ mod_division ._KaliskiIterationStep5 (bitsize ),
136+ mod_division ._KaliskiIterationStep6 (bitsize , mod ),
137+ ]
138+ csteps = [b .decompose_bloq () for b in steps ]
139+
140+ # check decomposition is valid.
141+ for step in steps :
142+ qlt_testing .assert_valid_bloq_decomposition (step )
143+ qlt_testing .assert_equivalent_bloq_counts (
144+ step , generalizer = (ignore_alloc_free , ignore_split_join )
145+ )
146+
147+ # check that for all inputs all 2n iteration work when excuted directly on the 6 steps
148+ # and their decompositions.
149+ for x in range (mod ):
150+ u , v , r , s , f = mod , x , 0 , 1 , 1
151+
152+ for _ in range (2 * bitsize ):
153+ a = b = m = is_terminal = 0
154+
155+ res = steps [0 ].call_classically (v = v , m = m , f = f , is_terminal = is_terminal )
156+ assert res == csteps [0 ].call_classically (v = v , m = m , f = f , is_terminal = is_terminal )
157+ v , m , f , is_terminal = res # type: ignore
158+
159+ res = steps [1 ].call_classically (u = u , v = v , b = b , a = a , m = m , f = f )
160+ assert res == csteps [1 ].call_classically (u = u , v = v , b = b , a = a , m = m , f = f )
161+ u , v , b , a , m , f = res # type: ignore
162+
163+ res = steps [2 ].call_classically (u = u , v = v , b = b , a = a , m = m , f = f )
164+ assert res == csteps [2 ].call_classically (u = u , v = v , b = b , a = a , m = m , f = f )
165+ u , v , b , a , m , f = res # type: ignore
166+
167+ res = steps [3 ].call_classically (u = u , v = v , r = r , s = s , a = a )
168+ assert res == csteps [3 ].call_classically (u = u , v = v , r = r , s = s , a = a )
169+ u , v , r , s , a = res # type: ignore
170+
171+ res = steps [4 ].call_classically (u = u , v = v , r = r , s = s , b = b , f = f )
172+ assert res == csteps [4 ].call_classically (u = u , v = v , r = r , s = s , b = b , f = f )
173+ u , v , r , s , b , f = res # type: ignore
174+
175+ res = steps [5 ].call_classically (u = u , v = v , r = r , s = s , b = b , a = a , m = m , f = f )
176+ assert res == csteps [5 ].call_classically (u = u , v = v , r = r , s = s , b = b , a = a , m = m , f = f )
177+ u , v , r , s , b , a , m , f = res # type: ignore
0 commit comments