Skip to content

Commit 179396d

Browse files
authored
Support symbolic CtrlSpec (#1491)
* support symbolic `CtrlSpec` * fix bug in `CtrlSpecAnd` decomposition, add more bloq examples
1 parent 188b663 commit 179396d

13 files changed

Lines changed: 309 additions & 114 deletions

File tree

qualtran/_infra/controlled.py

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import numpy as np
3232
from numpy.typing import NDArray
3333

34+
from ..symbolics import is_symbolic, prod, Shaped, SymbolicInt
3435
from .bloq import Bloq, DecomposeNotImplementedError, DecomposeTypeError
3536
from .data_types import QBit, QDType
3637
from .gate_with_registers import GateWithRegisters
@@ -55,18 +56,21 @@ def _cvs_convert(
5556
int,
5657
np.integer,
5758
NDArray[np.integer],
59+
Shaped,
5860
Sequence[Union[int, np.integer]],
5961
Sequence[Sequence[Union[int, np.integer]]],
60-
Sequence[NDArray[np.integer]],
62+
Sequence[Union[NDArray[np.integer], Shaped]],
6163
]
62-
) -> Tuple[NDArray[np.integer], ...]:
64+
) -> Tuple[Union[NDArray[np.integer], Shaped], ...]:
65+
if isinstance(cvs, Shaped):
66+
return (cvs,)
6367
if isinstance(cvs, (int, np.integer)):
6468
return (np.array(cvs),)
6569
if isinstance(cvs, np.ndarray):
6670
return (cvs,)
6771
if all(isinstance(cv, (int, np.integer)) for cv in cvs):
6872
return (np.asarray(cvs),)
69-
return tuple(np.asarray(cv) for cv in cvs)
73+
return tuple(cv if isinstance(cv, Shaped) else np.asarray(cv) for cv in cvs)
7074

7175

7276
@attrs.frozen(eq=False)
@@ -115,7 +119,9 @@ class CtrlSpec:
115119
qdtypes: Tuple[QDType, ...] = attrs.field(
116120
default=QBit(), converter=lambda qt: (qt,) if isinstance(qt, QDType) else tuple(qt)
117121
)
118-
cvs: Tuple[NDArray[np.integer], ...] = attrs.field(default=1, converter=_cvs_convert)
122+
cvs: Tuple[Union[NDArray[np.integer], Shaped], ...] = attrs.field(
123+
default=1, converter=_cvs_convert
124+
)
119125

120126
def __attrs_post_init__(self):
121127
assert len(self.qdtypes) == len(self.cvs)
@@ -125,19 +131,29 @@ def num_ctrl_reg(self) -> int:
125131
return len(self.qdtypes)
126132

127133
@cached_property
128-
def shapes(self) -> Tuple[Tuple[int, ...], ...]:
134+
def shapes(self) -> Tuple[Tuple[SymbolicInt, ...], ...]:
129135
"""Tuple of shapes of control registers represented by this CtrlSpec."""
130136
return tuple(cv.shape for cv in self.cvs)
131137

132138
@cached_property
133-
def num_qubits(self) -> int:
139+
def concrete_shapes(self) -> tuple[tuple[int, ...], ...]:
140+
"""Tuple of shapes of control registers represented by this CtrlSpec."""
141+
shapes = self.shapes
142+
if is_symbolic(*shapes):
143+
raise ValueError(f"cannot get concrete shapes: found symbolic {self.shapes}")
144+
return shapes # type: ignore
145+
146+
@cached_property
147+
def num_qubits(self) -> SymbolicInt:
134148
"""Total number of qubits required for control registers represented by this CtrlSpec."""
135149
return sum(
136-
dtype.num_qubits * int(np.prod(shape))
137-
for dtype, shape in zip(self.qdtypes, self.shapes)
150+
dtype.num_qubits * prod(shape) for dtype, shape in zip(self.qdtypes, self.shapes)
138151
)
139152

140-
def activation_function_dtypes(self) -> Sequence[Tuple[QDType, Tuple[int, ...]]]:
153+
def is_symbolic(self):
154+
return is_symbolic(*self.qdtypes) or is_symbolic(*self.cvs)
155+
156+
def activation_function_dtypes(self) -> Sequence[Tuple[QDType, Tuple[SymbolicInt, ...]]]:
141157
"""The data types that serve as input to the 'activation function'.
142158
143159
The activation function takes in (quantum) inputs of these types and shapes and determines
@@ -165,6 +181,8 @@ def is_active(self, *vals: 'ClassicalValT') -> bool:
165181
Returns:
166182
True if the specific input values evaluate to `True` for this CtrlSpec.
167183
"""
184+
if self.is_symbolic():
185+
raise ValueError(f"Cannot compute activation for symbolic {self}")
168186
if len(vals) != self.num_ctrl_reg:
169187
raise ValueError(f"Incorrect number of inputs for {self}: {len(vals)}.")
170188

@@ -180,19 +198,31 @@ def is_active(self, *vals: 'ClassicalValT') -> bool:
180198
return True
181199

182200
def wire_symbol(self, i: int, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
183-
# Return a circle for bits; a box otherwise.
184201
from qualtran.drawing import Circle, TextBox
185202

203+
cvs = self.cvs[i]
204+
205+
if is_symbolic(cvs):
206+
# control value is not given
207+
return TextBox('ctrl')
208+
209+
# Return a circle for bits; a box otherwise.
210+
cv = cvs[idx]
186211
if reg.bitsize == 1:
187-
cv = self.cvs[i][idx]
188212
return Circle(filled=(cv == 1))
189-
190-
cv = self.cvs[i][idx]
191-
return TextBox(f'{cv}')
213+
else:
214+
return TextBox(f'{cv}')
192215

193216
@cached_property
194-
def _cvs_tuple(self) -> Tuple[int, ...]:
195-
return tuple(cv for cvs in self.cvs for cv in tuple(cvs.reshape(-1)))
217+
def __cvs_tuple(self) -> Tuple[Union[tuple[int, ...], Shaped], ...]:
218+
"""Serialize the control values for hashing and equality checking."""
219+
220+
def _serialize(cvs) -> Union[tuple[int, ...], Shaped]:
221+
if isinstance(cvs, Shaped):
222+
return cvs
223+
return tuple(cvs.reshape(-1))
224+
225+
return tuple(_serialize(cvs) for cvs in self.cvs)
196226

197227
def __eq__(self, other: Any) -> bool:
198228
if not isinstance(other, CtrlSpec):
@@ -201,18 +231,22 @@ def __eq__(self, other: Any) -> bool:
201231
return (
202232
other.qdtypes == self.qdtypes
203233
and other.shapes == self.shapes
204-
and other._cvs_tuple == self._cvs_tuple
234+
and other.__cvs_tuple == self.__cvs_tuple
205235
)
206236

207237
def __hash__(self):
208-
return hash((self.qdtypes, self.shapes, self._cvs_tuple))
238+
return hash((self.qdtypes, self.shapes, self.__cvs_tuple))
209239

210240
def to_cirq_cv(self) -> 'cirq.SumOfProducts':
211241
"""Convert CtrlSpec to cirq.SumOfProducts representation of control values."""
212242
import cirq
213243

244+
if self.is_symbolic():
245+
raise ValueError(f"Cannot convert symbolic {self} to cirq control values.")
246+
214247
cirq_cv = []
215248
for qdtype, cv in zip(self.qdtypes, self.cvs):
249+
assert isinstance(cv, np.ndarray)
216250
for idx in Register('', qdtype, cv.shape).all_idxs():
217251
cirq_cv += [*qdtype.to_bits(cv[idx])]
218252
return cirq.SumOfProducts([tuple(cirq_cv)])
@@ -256,11 +290,14 @@ def from_cirq_cv(
256290

257291
def get_single_ctrl_bit(self) -> ControlBit:
258292
"""If controlled by a single qubit, return the control bit, otherwise raise"""
293+
if self.is_symbolic():
294+
raise ValueError(f"cannot get ctrl bit for symbolic {self}")
259295
if self.num_qubits != 1:
260296
raise ValueError(f"expected a single qubit control, got {self.num_qubits}")
261297

262298
(qdtype,) = self.qdtypes
263299
(cv,) = self.cvs
300+
assert isinstance(cv, np.ndarray)
264301
(idx,) = Register('', qdtype, cv.shape).all_idxs()
265302
(control_bit,) = qdtype.to_bits(cv[idx])
266303

qualtran/_infra/controlled_test.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import attrs
1717
import numpy as np
1818
import pytest
19+
import sympy
1920

2021
import qualtran.testing as qlt_testing
2122
from qualtran import (
@@ -24,6 +25,7 @@
2425
CompositeBloq,
2526
Controlled,
2627
CtrlSpec,
28+
DecomposeTypeError,
2729
QBit,
2830
QInt,
2931
QUInt,
@@ -52,6 +54,7 @@
5254
from qualtran.drawing import get_musical_score_data
5355
from qualtran.drawing.musical_score import Circle, SoqData, TextBox
5456
from qualtran.simulation.tensor import cbloq_to_quimb, get_right_and_left_inds
57+
from qualtran.symbolics import Shaped
5558

5659
if TYPE_CHECKING:
5760
import cirq
@@ -73,8 +76,10 @@ def test_ctrl_spec():
7376
cspec3 = CtrlSpec(QInt(64), cvs=np.int64(234234))
7477
assert cspec3 != cspec1
7578
assert cspec3.qdtypes[0].num_qubits == 64
76-
assert cspec3.cvs[0] == 234234
77-
assert cspec3.cvs[0][tuple()] == 234234
79+
(cvs,) = cspec3.cvs
80+
assert isinstance(cvs, np.ndarray)
81+
assert cvs == 234234
82+
assert cvs[tuple()] == 234234
7883

7984

8085
def test_ctrl_spec_shape():
@@ -97,7 +102,9 @@ def test_ctrl_spec_to_cirq_cv_roundtrip():
97102

98103
for ctrl_spec in ctrl_specs:
99104
assert ctrl_spec.to_cirq_cv() == cirq_cv.expand()
100-
assert CtrlSpec.from_cirq_cv(cirq_cv, qdtypes=ctrl_spec.qdtypes, shapes=ctrl_spec.shapes)
105+
assert CtrlSpec.from_cirq_cv(
106+
cirq_cv, qdtypes=ctrl_spec.qdtypes, shapes=ctrl_spec.concrete_shapes
107+
)
101108

102109

103110
@pytest.mark.parametrize(
@@ -120,6 +127,32 @@ def test_ctrl_spec_single_bit_raises(ctrl_spec: CtrlSpec):
120127
ctrl_spec.get_single_ctrl_bit()
121128

122129

130+
@pytest.mark.parametrize("shape", [(1,), (10,), (10, 10)])
131+
def test_ctrl_spec_symbolic_cvs(shape: tuple[int, ...]):
132+
ctrl_spec = CtrlSpec(cvs=Shaped(shape))
133+
assert ctrl_spec.is_symbolic()
134+
assert ctrl_spec.num_qubits == np.prod(shape)
135+
assert ctrl_spec.shapes == (shape,)
136+
137+
138+
@pytest.mark.parametrize("shape", [(1,), (10,), (10, 10)])
139+
def test_ctrl_spec_symbolic_dtype(shape: tuple[int, ...]):
140+
n = sympy.Symbol("n")
141+
dtype = QUInt(n)
142+
143+
ctrl_spec = CtrlSpec(qdtypes=dtype, cvs=Shaped(shape))
144+
145+
assert ctrl_spec.is_symbolic()
146+
assert ctrl_spec.num_qubits == n * np.prod(shape)
147+
assert ctrl_spec.shapes == (shape,)
148+
149+
150+
def test_ctrl_spec_symbolic_wire_symbol():
151+
ctrl_spec = CtrlSpec(cvs=Shaped((10,)))
152+
reg = Register('q', QBit())
153+
assert ctrl_spec.wire_symbol(0, reg) == TextBox('ctrl')
154+
155+
123156
def _test_cirq_equivalence(bloq: Bloq, gate: 'cirq.Gate'):
124157
import cirq
125158

@@ -431,11 +464,15 @@ def signature(self) -> 'Signature':
431464
return Signature([Register('x', QBit(), shape=(3,), side=Side.RIGHT)])
432465

433466
def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']:
467+
if self.ctrl_spec.is_symbolic():
468+
raise DecomposeTypeError(f"cannot decompose {self} with symbolic {self.ctrl_spec=}")
469+
434470
one_or_zero = [ZeroState(), OneState()]
435471
ctrl_bloq = Controlled(And(*self.and_ctrl), ctrl_spec=self.ctrl_spec)
436472

437473
ctrl_soqs = {}
438474
for reg, cvs in zip(ctrl_bloq.ctrl_regs, self.ctrl_spec.cvs):
475+
assert isinstance(cvs, np.ndarray)
439476
soqs = np.empty(shape=reg.shape, dtype=object)
440477
for idx in reg.all_idxs():
441478
soqs[idx] = bb.add(IntState(val=cvs[idx], bitsize=reg.dtype.num_qubits))
@@ -447,6 +484,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']:
447484
out_soqs = np.asarray([*ctrl_soqs.pop('ctrl'), ctrl_soqs.pop('target')]) # type: ignore[misc]
448485

449486
for reg, cvs in zip(ctrl_bloq.ctrl_regs, self.ctrl_spec.cvs):
487+
assert isinstance(cvs, np.ndarray)
450488
for idx in reg.all_idxs():
451489
ctrl_soq = np.asarray(ctrl_soqs[reg.name])[idx]
452490
bb.add(IntEffect(val=cvs[idx], bitsize=reg.dtype.num_qubits), val=ctrl_soq)

0 commit comments

Comments
 (0)