Skip to content

Commit dccd673

Browse files
anurudhpmpharrigan
andauthored
Cleanup redundant code for computing t complexity (#1317)
* move _HasEps * reduce dependence on `_get_all_rotation_types` * remove `_get_all_rotation_types` * mypy * reorganize tests for `t_counts_from_sigma` * mypy * fix error * replace `t_counts_from_sigma(bloq)` with `get_cost_value(bloq, QECGatesCost())` * revert test * link issue --------- Co-authored-by: Matthew Harrigan <mpharrigan@google.com>
1 parent 35db018 commit dccd673

5 files changed

Lines changed: 23 additions & 65 deletions

File tree

qualtran/bloqs/basic_gates/rotation.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from functools import cached_property
15-
from typing import Optional, Protocol, runtime_checkable, Tuple, Union
15+
from typing import Optional, Tuple, Union
1616

1717
import attrs
1818
import cirq
@@ -27,13 +27,6 @@
2727
from qualtran.symbolics import SymbolicFloat
2828

2929

30-
@runtime_checkable
31-
class _HasEps(Protocol):
32-
"""Protocol for typing `RotationBloq` base class mixin that has accuracy specified as eps."""
33-
34-
eps: float
35-
36-
3730
@frozen
3831
class ZPowGate(CirqGateAsBloqBase):
3932
r"""A gate that rotates around the Z axis of the Bloch sphere.
@@ -115,7 +108,7 @@ def _z_pow() -> ZPowGate:
115108
class CZPowGate(CirqGateAsBloqBase):
116109
exponent: float = 1.0
117110
global_shift: float = 0.0
118-
eps: float = 1e-11
111+
eps: SymbolicFloat = 1e-11
119112

120113
def decompose_bloq(self) -> 'CompositeBloq':
121114
raise DecomposeTypeError(f"{self} is atomic")
@@ -183,7 +176,7 @@ class XPowGate(CirqGateAsBloqBase):
183176
"""
184177
exponent: Union[sympy.Expr, float] = 1.0
185178
global_shift: float = 0.0
186-
eps: float = 1e-11
179+
eps: SymbolicFloat = 1e-11
187180

188181
def decompose_bloq(self) -> 'CompositeBloq':
189182
raise DecomposeTypeError(f"{self} is atomic")
@@ -253,7 +246,7 @@ class YPowGate(CirqGateAsBloqBase):
253246
"""
254247
exponent: Union[sympy.Expr, float] = 1.0
255248
global_shift: float = 0.0
256-
eps: float = 1e-11
249+
eps: SymbolicFloat = 1e-11
257250

258251
def decompose_bloq(self) -> 'CompositeBloq':
259252
raise DecomposeTypeError(f"{self} is atomic")
@@ -321,7 +314,7 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
321314
@frozen
322315
class Rx(CirqGateAsBloqBase):
323316
angle: Union[sympy.Expr, float]
324-
eps: float = 1e-11
317+
eps: SymbolicFloat = 1e-11
325318

326319
def decompose_bloq(self) -> 'CompositeBloq':
327320
raise DecomposeTypeError(f"{self} is atomic")
@@ -342,7 +335,7 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
342335
@frozen
343336
class Ry(CirqGateAsBloqBase):
344337
angle: Union[sympy.Expr, float]
345-
eps: float = 1e-11
338+
eps: SymbolicFloat = 1e-11
346339

347340
def decompose_bloq(self) -> 'CompositeBloq':
348341
raise DecomposeTypeError(f"{self} is atomic")

qualtran/bloqs/chemistry/trotter/hubbard/qpe_cost_optimization.ipynb

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@
111111
"import numpy as np\n",
112112
"import sympy\n",
113113
"\n",
114-
"from qualtran.resource_counting.t_counts_from_sigma import _get_all_rotation_types\n",
114+
"from qualtran.resource_counting.classify_bloqs import bloq_is_rotation\n",
115115
"from qualtran.resource_counting.generalizers import PHI\n",
116116
"from qualtran.cirq_interop.t_complexity_protocol import TComplexity\n",
117117
"from qualtran import Bloq\n",
@@ -130,11 +130,10 @@
130130
"\n",
131131
"\n",
132132
"def t_and_rot_counts_from_sigma(sigma: Dict['Bloq', Union[int, 'sympy.Expr']]) -> Tuple[int, int]:\n",
133-
" rotation_types = _get_all_rotation_types()\n",
134133
" ret = sigma.get(TGate(), 0)\n",
135134
" n_rot = 0\n",
136135
" for bloq, counts in sigma.items():\n",
137-
" if isinstance(bloq, rotation_types):\n",
136+
" if bloq_is_rotation(bloq):\n",
138137
" n_rot += counts\n",
139138
" return ret, n_rot\n",
140139
"\n",

qualtran/bloqs/data_loading/select_swap_qrom_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity
2929
from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim
30-
from qualtran.resource_counting.t_counts_from_sigma import t_counts_from_sigma
30+
from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost
3131
from qualtran.testing import assert_valid_bloq_decomposition
3232

3333

@@ -187,8 +187,9 @@ def test_qroam_t_complexity():
187187
qroam = SelectSwapQROM.build_from_data(
188188
[1, 2, 3, 4, 5, 6, 7, 8], target_bitsizes=(4,), log_block_sizes=(2,)
189189
)
190-
_, sigma = qroam.call_graph()
191-
assert t_counts_from_sigma(sigma) == qroam.t_complexity().t == 192
190+
gate_counts = get_cost_value(qroam, QECGatesCost())
191+
assert gate_counts == GateCounts(t=192, clifford=1082)
192+
assert qroam.t_complexity() == TComplexity(t=192, clifford=1082)
192193

193194

194195
def test_qroam_many_registers():

qualtran/resource_counting/t_counts_from_sigma.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,45 +11,27 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import inspect
15-
import sys
16-
from typing import cast, Mapping, Optional, Tuple, Type, TYPE_CHECKING
14+
from typing import Mapping
1715

1816
import cirq
1917

18+
from qualtran import Bloq, Controlled
2019
from qualtran.symbolics import ceil, SymbolicInt
2120

22-
if TYPE_CHECKING:
23-
from qualtran import Bloq
24-
from qualtran.bloqs.basic_gates.rotation import _HasEps
2521

26-
27-
def _get_all_rotation_types() -> Tuple[Type['_HasEps'], ...]:
28-
"""Returns all classes defined in bloqs.basic_gates which have an attribute `eps`."""
29-
from qualtran.bloqs.basic_gates import GlobalPhase
30-
from qualtran.bloqs.basic_gates.rotation import _HasEps
31-
32-
bloqs_to_exclude = [GlobalPhase]
33-
34-
return tuple(
35-
cast(Type['_HasEps'], v) # Can't use `issubclass` with protocols with attributes.
36-
for (_, v) in inspect.getmembers(sys.modules['qualtran.bloqs.basic_gates'], inspect.isclass)
37-
if isinstance(v, _HasEps) and v not in bloqs_to_exclude
38-
)
39-
40-
41-
def t_counts_from_sigma(
42-
sigma: Mapping['Bloq', SymbolicInt],
43-
rotation_types: Optional[Tuple[Type['_HasEps'], ...]] = None,
44-
) -> SymbolicInt:
22+
def t_counts_from_sigma(sigma: Mapping['Bloq', SymbolicInt]) -> SymbolicInt:
4523
"""Aggregates T-counts from a sigma dictionary by summing T-costs for all rotation bloqs."""
4624
from qualtran.bloqs.basic_gates import TGate
4725
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
26+
from qualtran.resource_counting.classify_bloqs import bloq_is_rotation
4827

49-
if rotation_types is None:
50-
rotation_types = _get_all_rotation_types()
5128
ret = sigma.get(TGate(), 0) + sigma.get(TGate().adjoint(), 0)
5229
for bloq, counts in sigma.items():
53-
if isinstance(bloq, rotation_types) and not cirq.has_stabilizer_effect(bloq):
30+
if bloq_is_rotation(bloq) and not cirq.has_stabilizer_effect(bloq):
31+
if isinstance(bloq, Controlled):
32+
# TODO native controlled rotation bloqs missing (CRz, CRy etc.)
33+
# https://github.com/quantumlib/Qualtran/issues/878
34+
bloq = bloq.subbloq
35+
assert hasattr(bloq, 'eps')
5436
ret += ceil(TComplexity.rotation_cost(bloq.eps)) * counts
5537
return ret

qualtran/resource_counting/t_counts_from_sigma_test.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,14 @@
1919
Rx,
2020
Ry,
2121
Rz,
22-
SU2RotationGate,
2322
TGate,
2423
Toffoli,
2524
XPowGate,
2625
YPowGate,
2726
ZPowGate,
2827
)
2928
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
30-
from qualtran.resource_counting.t_counts_from_sigma import (
31-
_get_all_rotation_types,
32-
t_counts_from_sigma,
33-
)
34-
35-
36-
def test_all_rotation_types():
37-
assert set(_get_all_rotation_types()) == {
38-
CZPowGate,
39-
Rx,
40-
Ry,
41-
Rz,
42-
XPowGate,
43-
YPowGate,
44-
ZPowGate,
45-
SU2RotationGate,
46-
}
29+
from qualtran.resource_counting.t_counts_from_sigma import t_counts_from_sigma
4730

4831

4932
def test_t_counts_from_sigma():

0 commit comments

Comments
 (0)