|
14 | 14 |
|
15 | 15 | from collections import defaultdict |
16 | 16 | from functools import cached_property |
17 | | -from typing import ( |
18 | | - Dict, |
19 | | - Iterable, |
20 | | - Iterator, |
21 | | - List, |
22 | | - Optional, |
23 | | - Sequence, |
24 | | - Set, |
25 | | - Tuple, |
26 | | - TYPE_CHECKING, |
27 | | - Union, |
28 | | -) |
| 17 | +from typing import Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union |
29 | 18 |
|
30 | 19 | import attrs |
31 | 20 | import cirq |
|
65 | 54 |
|
66 | 55 | if TYPE_CHECKING: |
67 | 56 | from qualtran import BloqBuilder |
68 | | - from qualtran.resource_counting import BloqCountT, SympySymbolAllocator |
| 57 | + from qualtran.resource_counting import ( |
| 58 | + BloqCountDictT, |
| 59 | + MutableBloqCountDictT, |
| 60 | + SympySymbolAllocator, |
| 61 | + ) |
69 | 62 | from qualtran.simulation.classical_sim import ClassicalValT |
70 | 63 |
|
71 | 64 |
|
@@ -183,22 +176,22 @@ def decompose_from_registers( |
183 | 176 | def _has_unitary_(self): |
184 | 177 | return True |
185 | 178 |
|
186 | | - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: |
| 179 | + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': |
187 | 180 | if ( |
188 | 181 | not is_symbolic(self.less_than_val, self.bitsize) |
189 | 182 | and self.less_than_val >= 2**self.bitsize |
190 | 183 | ): |
191 | | - return {(XGate(), 1)} |
| 184 | + return {XGate(): 1} |
192 | 185 | num_set_bits = ( |
193 | 186 | int(self.less_than_val).bit_count() |
194 | 187 | if not is_symbolic(self.less_than_val) |
195 | 188 | else self.bitsize |
196 | 189 | ) |
197 | 190 | return { |
198 | | - (And(), self.bitsize), |
199 | | - (And().adjoint(), self.bitsize), |
200 | | - (CNOT(), num_set_bits + 2 * self.bitsize), |
201 | | - (XGate(), 2 * (1 + num_set_bits)), |
| 191 | + And(): self.bitsize, |
| 192 | + And().adjoint(): self.bitsize, |
| 193 | + CNOT(): num_set_bits + 2 * self.bitsize, |
| 194 | + XGate(): 2 * (1 + num_set_bits), |
202 | 195 | } |
203 | 196 |
|
204 | 197 |
|
@@ -307,8 +300,8 @@ def __pow__(self, power: int) -> 'BiQubitsMixer': |
307 | 300 | return self.adjoint() |
308 | 301 | return NotImplemented # pragma: no cover |
309 | 302 |
|
310 | | - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: |
311 | | - return {(XGate(), 1), (CNOT(), 9), (And(uncompute=self.is_adjoint), 2)} |
| 303 | + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': |
| 304 | + return {XGate(): 1, CNOT(): 9, And(uncompute=self.is_adjoint): 2} |
312 | 305 |
|
313 | 306 | def _has_unitary_(self): |
314 | 307 | return not self.is_adjoint |
@@ -380,8 +373,8 @@ def __pow__(self, power: int) -> Union['SingleQubitCompare', cirq.Gate]: |
380 | 373 | return self.adjoint() |
381 | 374 | return self |
382 | 375 |
|
383 | | - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: |
384 | | - return {(XGate(), 1), (CNOT(), 4), (And(uncompute=self.is_adjoint), 1)} |
| 376 | + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': |
| 377 | + return {XGate(): 1, CNOT(): 4, And(uncompute=self.is_adjoint): 1} |
385 | 378 |
|
386 | 379 |
|
387 | 380 | @bloq_example |
@@ -575,13 +568,13 @@ def decompose_from_registers( |
575 | 568 | all_ancilla = set([q for op in adjoint for q in op.qubits if q not in input_qubits]) |
576 | 569 | context.qubit_manager.qfree(all_ancilla) |
577 | 570 |
|
578 | | - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: |
| 571 | + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': |
579 | 572 | if is_symbolic(self.x_bitsize, self.y_bitsize): |
580 | 573 | return { |
581 | | - (BiQubitsMixer(), self.x_bitsize), |
582 | | - (BiQubitsMixer().adjoint(), self.x_bitsize), |
583 | | - (SingleQubitCompare(), 1), |
584 | | - (SingleQubitCompare().adjoint(), 1), |
| 574 | + BiQubitsMixer(): self.x_bitsize, |
| 575 | + BiQubitsMixer().adjoint(): self.x_bitsize, |
| 576 | + SingleQubitCompare(): 1, |
| 577 | + SingleQubitCompare().adjoint(): 1, |
585 | 578 | } |
586 | 579 |
|
587 | 580 | n = min(self.x_bitsize, self.y_bitsize) |
@@ -613,7 +606,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: |
613 | 606 | ret[And(1, 0).adjoint()] += 1 |
614 | 607 | ret[CNOT()] += 1 |
615 | 608 |
|
616 | | - return set(ret.items()) |
| 609 | + return ret |
617 | 610 |
|
618 | 611 | def _has_unitary_(self): |
619 | 612 | return True |
@@ -691,8 +684,8 @@ def build_composite_bloq( |
691 | 684 | target = bb.add(XGate(), q=target) |
692 | 685 | return {'a': a, 'b': b, 'target': target} |
693 | 686 |
|
694 | | - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: |
695 | | - return {(LessThanEqual(self.a_bitsize, self.b_bitsize), 1), (XGate(), 1)} |
| 687 | + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': |
| 688 | + return {LessThanEqual(self.a_bitsize, self.b_bitsize): 1, XGate(): 1} |
696 | 689 |
|
697 | 690 |
|
698 | 691 | @bloq_example |
@@ -885,23 +878,23 @@ def wire_symbol( |
885 | 878 | return TextBox('t⨁(a>b)') |
886 | 879 | raise ValueError(f'Unknown register name {reg.name}') |
887 | 880 |
|
888 | | - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: |
| 881 | + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': |
889 | 882 | if self.bitsize == 1: |
890 | | - return {(MultiControlX(cvs=(1, 0)), 1)} |
| 883 | + return {MultiControlX(cvs=(1, 0)): 1} |
891 | 884 |
|
892 | 885 | if self.signed: |
893 | 886 | return { |
894 | | - (CNOT(), 6 * self.bitsize - 7), |
895 | | - (XGate(), 2 * self.bitsize + 2), |
896 | | - (And(), self.bitsize - 1), |
897 | | - (And(uncompute=True), self.bitsize - 1), |
| 887 | + CNOT(): 6 * self.bitsize - 7, |
| 888 | + XGate(): 2 * self.bitsize + 2, |
| 889 | + And(): self.bitsize - 1, |
| 890 | + And(uncompute=True): self.bitsize - 1, |
898 | 891 | } |
899 | 892 |
|
900 | 893 | return { |
901 | | - (CNOT(), 6 * self.bitsize - 1), |
902 | | - (XGate(), 2 * self.bitsize + 4), |
903 | | - (And(), self.bitsize), |
904 | | - (And(uncompute=True), self.bitsize), |
| 894 | + CNOT(): 6 * self.bitsize - 1, |
| 895 | + XGate(): 2 * self.bitsize + 4, |
| 896 | + And(): self.bitsize, |
| 897 | + And(uncompute=True): self.bitsize, |
905 | 898 | } |
906 | 899 |
|
907 | 900 |
|
@@ -941,8 +934,8 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) - |
941 | 934 | return TextBox(f"⨁(x > {self.val})") |
942 | 935 | raise ValueError(f'Unknown register symbol {reg.name}') |
943 | 936 |
|
944 | | - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: |
945 | | - return {(LessThanConstant(self.bitsize, less_than_val=self.val), 1)} |
| 937 | + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': |
| 938 | + return {LessThanConstant(self.bitsize, less_than_val=self.val): 1} |
946 | 939 |
|
947 | 940 |
|
948 | 941 | @bloq_example |
@@ -1007,8 +1000,8 @@ def build_composite_bloq( |
1007 | 1000 | x = bb.join(xs) |
1008 | 1001 | return {'x': x, 'target': target} |
1009 | 1002 |
|
1010 | | - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: |
1011 | | - return {(MultiControlX(self.bits_k), 1)} |
| 1003 | + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': |
| 1004 | + return {MultiControlX(self.bits_k): 1} |
1012 | 1005 |
|
1013 | 1006 |
|
1014 | 1007 | def _make_equals_a_constant(): |
@@ -1134,21 +1127,22 @@ def on_classical_vals( |
1134 | 1127 | return {'ctrl': ctrl, 'a': a, 'b': b, 'target': target ^ (a > b)} |
1135 | 1128 | return {'ctrl': ctrl, 'a': a, 'b': b, 'target': target} |
1136 | 1129 |
|
1137 | | - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: |
1138 | | - signed_ops = [] |
| 1130 | + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT': |
| 1131 | + signed_ops: 'MutableBloqCountDictT' = {} |
1139 | 1132 | if isinstance(self.dtype, QInt): |
1140 | | - signed_ops = [ |
1141 | | - (SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), 2), |
1142 | | - (SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(), 2), |
1143 | | - ] |
| 1133 | + signed_ops = { |
| 1134 | + SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)): 2, |
| 1135 | + SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(): 2, |
| 1136 | + } |
1144 | 1137 | dtype = attrs.evolve(self.dtype, bitsize=self.dtype.bitsize + 1) |
1145 | 1138 | return { |
1146 | | - (BitwiseNot(dtype), 2), |
1147 | | - (BitwiseNot(QUInt(dtype.bitsize + 1)), 2), |
1148 | | - (OutOfPlaceAdder(self.dtype.bitsize + 1).adjoint(), 1), |
1149 | | - (OutOfPlaceAdder(self.dtype.bitsize + 1), 1), |
1150 | | - (MultiControlX((self.cv, 1)), 1), |
1151 | | - }.union(signed_ops) |
| 1139 | + BitwiseNot(dtype): 2, |
| 1140 | + BitwiseNot(QUInt(dtype.bitsize + 1)): 2, |
| 1141 | + OutOfPlaceAdder(self.dtype.bitsize + 1).adjoint(): 1, |
| 1142 | + OutOfPlaceAdder(self.dtype.bitsize + 1): 1, |
| 1143 | + MultiControlX((self.cv, 1)): 1, |
| 1144 | + **signed_ops, |
| 1145 | + } |
1152 | 1146 |
|
1153 | 1147 |
|
1154 | 1148 | @bloq_example(generalizer=ignore_split_join) |
|
0 commit comments