1414
1515"""Quantum read-only memory."""
1616import numbers
17- from typing import (
18- Callable ,
19- cast ,
20- Iterable ,
21- Iterator ,
22- Optional ,
23- Sequence ,
24- Set ,
25- Tuple ,
26- TYPE_CHECKING ,
27- Union ,
28- )
17+ from typing import cast , Iterable , Iterator , Optional , Sequence , Set , Tuple , TYPE_CHECKING , Union
2918
3019import attrs
3120import cirq
3221import numpy as np
3322import sympy
3423from numpy .typing import ArrayLike , NDArray
3524
36- from qualtran import bloq_example , BloqDocSpec , Register
25+ from qualtran import bloq_example , BloqDocSpec , QUInt , Register
3726from qualtran ._infra .gate_with_registers import merge_qubits
27+ from qualtran .bloqs .arithmetic import XorK
3828from qualtran .bloqs .basic_gates import CNOT
3929from qualtran .bloqs .data_loading .qrom_base import QROMBase
4030from qualtran .bloqs .mcmt .and_bloq import And , MultiAnd
@@ -127,7 +117,7 @@ def build_from_bitsize(
127117 def _load_nth_data (
128118 self ,
129119 selection_idx : Tuple [int , ...],
130- gate : Callable [[ cirq .Qid ], cirq . Operation ] ,
120+ ctrl_qubits : Tuple [ cirq .Qid , ...] = () ,
131121 ** target_regs : NDArray [cirq .Qid ], # type: ignore[type-var]
132122 ) -> Iterator [cirq .OP_TREE ]:
133123 for i , d in enumerate (self .data ):
@@ -136,20 +126,18 @@ def _load_nth_data(
136126 assert all (isinstance (x , (int , numbers .Integral )) for x in target_shape )
137127 for idx in np .ndindex (cast (Tuple [int , ...], target_shape )):
138128 data_to_load = int (d [selection_idx + idx ])
139- for q , bit in zip ( target [idx ], f' { data_to_load :0{ target_bitsize }b } ' ):
140- if int ( bit ):
141- yield gate ( q )
129+ yield XorK ( QUInt ( target_bitsize ), data_to_load ). on ( * target [idx ]). controlled_by (
130+ * ctrl_qubits
131+ )
142132
143133 def decompose_zero_selection (
144134 self , context : cirq .DecompositionContext , ** quregs : NDArray [cirq .Qid ]
145135 ) -> Iterator [cirq .OP_TREE ]:
146- controls = merge_qubits (self .control_registers , ** quregs )
136+ controls = tuple ( merge_qubits (self .control_registers , ** quregs ) )
147137 target_regs = {reg .name : quregs [reg .name ] for reg in self .target_registers }
148138 zero_indx = (0 ,) * len (self .data_shape )
149- if self .num_controls == 0 :
150- yield self ._load_nth_data (zero_indx , cirq .X , ** target_regs )
151- elif self .num_controls == 1 :
152- yield self ._load_nth_data (zero_indx , lambda q : CNOT ().on (controls [0 ], q ), ** target_regs )
139+ if self .num_controls <= 1 :
140+ yield self ._load_nth_data (zero_indx , ctrl_qubits = controls , ** target_regs )
153141 else :
154142 ctrl = np .array (controls )[:, np .newaxis ]
155143 junk = np .array (context .qubit_manager .qalloc (len (controls ) - 2 ))[:, np .newaxis ]
@@ -161,7 +149,7 @@ def decompose_zero_selection(
161149 ctrl = ctrl , junk = junk , target = and_target
162150 )
163151 yield multi_controlled_and
164- yield self ._load_nth_data (zero_indx , lambda q : CNOT (). on ( and_target , q ), ** target_regs )
152+ yield self ._load_nth_data (zero_indx , ctrl_qubits = ( and_target ,), ** target_regs )
165153 yield cirq .inverse (multi_controlled_and )
166154 context .qubit_manager .qfree (list (junk .flatten ()) + [and_target ])
167155
@@ -182,7 +170,7 @@ def nth_operation(
182170 ) -> Iterator [cirq .OP_TREE ]:
183171 selection_idx = tuple (kwargs [reg .name ] for reg in self .selection_registers )
184172 target_regs = {reg .name : kwargs [reg .name ] for reg in self .target_registers }
185- yield self ._load_nth_data (selection_idx , lambda q : CNOT (). on ( control , q ), ** target_regs )
173+ yield self ._load_nth_data (selection_idx , ctrl_qubits = ( control ,), ** target_regs )
186174
187175 def _circuit_diagram_info_ (self , args ) -> cirq .CircuitDiagramInfo :
188176 from qualtran .cirq_interop ._bloq_to_cirq import _wire_symbol_to_cirq_diagram_info
0 commit comments