Skip to content

Commit 76fdc9d

Browse files
authored
Change return type of build_call_graph to set | dict (#1356)
* Change return type of build_call_graph to set | dict - Changes build_call_graph to return a Union of Dict[Bloq, Union[int, Expr]] and Set[BloqCountT] - Adds checks in main call sites to check both cases.
1 parent dde12e2 commit 76fdc9d

11 files changed

Lines changed: 85 additions & 27 deletions

File tree

qualtran/_infra/adjoint.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections import Counter
1516
from functools import cached_property
16-
from typing import Dict, List, Optional, Set, Tuple, TYPE_CHECKING
17+
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
1718

1819
import cirq
1920
from attrs import frozen
@@ -26,7 +27,7 @@
2627
if TYPE_CHECKING:
2728
from qualtran import Bloq, CompositeBloq, Register, Signature, SoquetT
2829
from qualtran.drawing import WireSymbol
29-
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
30+
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
3031

3132

3233
def _adjoint_final_soqs(cbloq: 'CompositeBloq', new_signature: Signature) -> Dict[str, 'SoquetT']:
@@ -158,9 +159,17 @@ def adjoint(self) -> 'Bloq':
158159
"""The 'double adjoint' brings you back to the original bloq."""
159160
return self.subbloq
160161

161-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
162+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
162163
"""The call graph takes the adjoint of each of the bloqs in `subbloq`'s call graph."""
163-
return {(bloq.adjoint(), n) for bloq, n in self.subbloq.build_call_graph(ssa=ssa)}
164+
sub_cg = self.subbloq.build_call_graph(ssa=ssa)
165+
counts = Counter['Bloq']()
166+
if isinstance(sub_cg, set):
167+
for bloq, n in sub_cg:
168+
counts[bloq.adjoint()] += n
169+
else:
170+
for bloq, n in sub_cg.items():
171+
counts[bloq.adjoint()] += n
172+
return counts
164173

165174
def pretty_name(self) -> str:
166175
"""The subbloq's pretty_name with a dagger."""

qualtran/_infra/bloq.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,13 @@
3939
from qualtran.cirq_interop import CirqQuregT
4040
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
4141
from qualtran.drawing import WireSymbol
42-
from qualtran.resource_counting import BloqCountT, CostKey, GeneralizerT, SympySymbolAllocator
42+
from qualtran.resource_counting import (
43+
BloqCountDictT,
44+
BloqCountT,
45+
CostKey,
46+
GeneralizerT,
47+
SympySymbolAllocator,
48+
)
4349
from qualtran.simulation.classical_sim import ClassicalValT
4450

4551

@@ -279,7 +285,9 @@ def my_tensors(
279285
"""
280286
raise NotImplementedError(f"{self} does not support tensor simulation.")
281287

282-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
288+
def build_call_graph(
289+
self, ssa: 'SympySymbolAllocator'
290+
) -> Union['BloqCountDictT', Set['BloqCountT']]:
283291
"""Override this method to build the bloq call graph.
284292
285293
This method must return a set of `(bloq, n)` tuples where `bloq` is called `n` times in

qualtran/_infra/controlled.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from collections import Counter
1415
from functools import cached_property
1516
from typing import (
1617
Any,
@@ -20,7 +21,6 @@
2021
Optional,
2122
Protocol,
2223
Sequence,
23-
Set,
2424
Tuple,
2525
TYPE_CHECKING,
2626
Union,
@@ -39,10 +39,10 @@
3939
if TYPE_CHECKING:
4040
import quimb.tensor as qtn
4141

42-
from qualtran import BloqBuilder, CompositeBloq, ConnectionT, SoquetT
42+
from qualtran import Bloq, BloqBuilder, CompositeBloq, ConnectionT, SoquetT
4343
from qualtran.cirq_interop import CirqQuregT
4444
from qualtran.drawing import WireSymbol
45-
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
45+
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
4646
from qualtran.simulation.classical_sim import ClassicalValT
4747

4848

@@ -386,7 +386,7 @@ def build_composite_bloq(
386386
fsoqs |= dict(zip(self.ctrl_reg_names, ctrl_soqs))
387387
return fsoqs
388388

389-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
389+
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
390390
try:
391391
sub_cg = self.subbloq.build_call_graph(ssa=ssa)
392392
except DecomposeTypeError as e1:
@@ -396,7 +396,14 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
396396
f"Could not build call graph for {self}: {e2}"
397397
) from e2
398398

399-
return {(bloq.controlled(self.ctrl_spec), n) for bloq, n in sub_cg}
399+
counts = Counter['Bloq']()
400+
if isinstance(sub_cg, set):
401+
for bloq, n in sub_cg:
402+
counts[bloq.controlled(self.ctrl_spec)] += n
403+
else:
404+
for bloq, n in sub_cg.items():
405+
counts[bloq.controlled(self.ctrl_spec)] += n
406+
return counts
400407

401408
def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
402409
ctrl_vals = [vals[reg_name] for reg_name in self.ctrl_reg_names]

qualtran/bloqs/arithmetic/addition.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262

6363
if TYPE_CHECKING:
6464
from qualtran.drawing import WireSymbol
65-
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
65+
from qualtran.resource_counting import BloqCountDictT, BloqCountT, SympySymbolAllocator
6666
from qualtran.simulation.classical_sim import ClassicalValT
6767
from qualtran.symbolics import SymbolicInt
6868

@@ -500,7 +500,9 @@ def build_composite_bloq(
500500
else:
501501
return {'x': x}
502502

503-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
503+
def build_call_graph(
504+
self, ssa: 'SympySymbolAllocator'
505+
) -> Union['BloqCountDictT', Set['BloqCountT']]:
504506
loading_cost: Tuple[Bloq, SymbolicInt]
505507
if len(self.cvs) == 0:
506508
loading_cost = (XGate(), self.bitsize) # upper bound; depends on the data.

qualtran/bloqs/arithmetic/permutation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
import sympy
5353

5454
from qualtran import BloqBuilder, SoquetT
55-
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
55+
from qualtran.resource_counting import BloqCountDictT, BloqCountT, SympySymbolAllocator
5656

5757
SymbolicCycleT: TypeAlias = Union[CycleT, Shaped]
5858

@@ -122,7 +122,9 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: 'SoquetT') -> dict[str, 'So
122122

123123
return {'x': x}
124124

125-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
125+
def build_call_graph(
126+
self, ssa: 'SympySymbolAllocator'
127+
) -> Union['BloqCountDictT', Set['BloqCountT']]:
126128
if is_symbolic(self.cycle):
127129
x = ssa.new_symbol('x')
128130
cycle_len = slen(self.cycle)
@@ -267,7 +269,9 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: 'Soquet') -> dict[str, 'Soq
267269

268270
return {'x': x}
269271

270-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
272+
def build_call_graph(
273+
self, ssa: 'SympySymbolAllocator'
274+
) -> Union['BloqCountDictT', Set['BloqCountT']]:
271275
if is_symbolic(self.cycles):
272276
# worst case cost: single cycle of length N
273277
cycle = Shaped((self.N,))

qualtran/bloqs/data_loading/qrom.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from qualtran.symbolics import prod, SymbolicInt
3434

3535
if TYPE_CHECKING:
36-
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
36+
from qualtran.resource_counting import BloqCountDictT, BloqCountT, SympySymbolAllocator
3737

3838

3939
def _to_tuple(x: Iterable[NDArray]) -> Sequence[NDArray]:
@@ -208,7 +208,9 @@ def nth_operation_callgraph(self, **kwargs: int) -> Set['BloqCountT']:
208208
ret += data_to_load.bit_count()
209209
return {(CNOT(), ret)}
210210

211-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
211+
def build_call_graph(
212+
self, ssa: 'SympySymbolAllocator'
213+
) -> Union['BloqCountDictT', Set['BloqCountT']]:
212214
if self.has_data():
213215
return super().build_call_graph(ssa=ssa)
214216
n_and = prod(self.data_shape) - 2 + self.num_controls

qualtran/bloqs/multiplexers/unary_iteration_bloq.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import sympy
3131

3232
from qualtran import Bloq
33-
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
33+
from qualtran.resource_counting import BloqCountDictT, BloqCountT, SympySymbolAllocator
3434
from qualtran.symbolics import SymbolicInt
3535

3636

@@ -591,7 +591,9 @@ def nth_operation_callgraph(self, **selection_regs_name_to_val) -> Set['BloqCoun
591591
f"Derived class {type(self)} does not implement `nth_operation_callgraph`."
592592
)
593593

594-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
594+
def build_call_graph(
595+
self, ssa: 'SympySymbolAllocator'
596+
) -> Union['BloqCountDictT', Set['BloqCountT']]:
595597
if total_bits(self.selection_registers) == 0 or self._break_early(
596598
(), 0, self.selection_registers[0].dtype.iteration_length_or_zero()
597599
):

qualtran/bloqs/state_preparation/prepare_uniform_superposition.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
)
3838

3939
if TYPE_CHECKING:
40-
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
40+
from qualtran.resource_counting import BloqCountDictT, BloqCountT, SympySymbolAllocator
4141

4242

4343
@attrs.frozen
@@ -148,7 +148,9 @@ def decompose_from_registers(
148148
yield cirq.H.on_each(*logL_qubits)
149149
context.qubit_manager.qfree([*and_target, *and_ancilla])
150150

151-
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
151+
def build_call_graph(
152+
self, ssa: 'SympySymbolAllocator'
153+
) -> Union['BloqCountDictT', Set['BloqCountT']]:
152154
if not is_symbolic(self.n, self.cvs):
153155
# build from decomposition
154156
return super().build_call_graph(ssa)

qualtran/cirq_interop/t_complexity_protocol.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import cirq
2020

2121
from qualtran import Bloq, DecomposeNotImplementedError, DecomposeTypeError
22-
from qualtran.resource_counting import SympySymbolAllocator
22+
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
2323
from qualtran.symbolics import ceil, log2, SymbolicFloat, SymbolicInt
2424

2525
from .decompose_protocol import _decompose_once_considering_known_decomposition
@@ -166,7 +166,11 @@ def _from_bloq_build_call_graph(bloq: Bloq) -> Optional[TComplexity]:
166166
return None
167167

168168
ret = TComplexity()
169-
for callee, n in callee_counts:
169+
if isinstance(callee_counts, set):
170+
callee_iterator: Iterable[BloqCountT] = callee_counts
171+
else:
172+
callee_iterator = callee_counts.items()
173+
for callee, n in callee_iterator:
170174
r = t_complexity(callee)
171175
if r is None:
172176
return None

qualtran/resource_counting/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ._generalization import GeneralizerT
2121

2222
from ._call_graph import (
23+
BloqCountDictT,
2324
BloqCountT,
2425
big_O,
2526
SympySymbolAllocator,

0 commit comments

Comments
 (0)