3131import numpy as np
3232from numpy .typing import NDArray
3333
34+ from ..symbolics import is_symbolic , prod , Shaped , SymbolicInt
3435from .bloq import Bloq , DecomposeNotImplementedError , DecomposeTypeError
3536from .data_types import QBit , QDType
3637from .gate_with_registers import GateWithRegisters
@@ -55,18 +56,21 @@ def _cvs_convert(
5556 int ,
5657 np .integer ,
5758 NDArray [np .integer ],
59+ Shaped ,
5860 Sequence [Union [int , np .integer ]],
5961 Sequence [Sequence [Union [int , np .integer ]]],
60- Sequence [NDArray [np .integer ]],
62+ Sequence [Union [ NDArray [np .integer ], Shaped ]],
6163 ]
62- ) -> Tuple [NDArray [np .integer ], ...]:
64+ ) -> Tuple [Union [NDArray [np .integer ], Shaped ], ...]:
65+ if isinstance (cvs , Shaped ):
66+ return (cvs ,)
6367 if isinstance (cvs , (int , np .integer )):
6468 return (np .array (cvs ),)
6569 if isinstance (cvs , np .ndarray ):
6670 return (cvs ,)
6771 if all (isinstance (cv , (int , np .integer )) for cv in cvs ):
6872 return (np .asarray (cvs ),)
69- return tuple (np .asarray (cv ) for cv in cvs )
73+ return tuple (cv if isinstance ( cv , Shaped ) else np .asarray (cv ) for cv in cvs )
7074
7175
7276@attrs .frozen (eq = False )
@@ -115,7 +119,9 @@ class CtrlSpec:
115119 qdtypes : Tuple [QDType , ...] = attrs .field (
116120 default = QBit (), converter = lambda qt : (qt ,) if isinstance (qt , QDType ) else tuple (qt )
117121 )
118- cvs : Tuple [NDArray [np .integer ], ...] = attrs .field (default = 1 , converter = _cvs_convert )
122+ cvs : Tuple [Union [NDArray [np .integer ], Shaped ], ...] = attrs .field (
123+ default = 1 , converter = _cvs_convert
124+ )
119125
120126 def __attrs_post_init__ (self ):
121127 assert len (self .qdtypes ) == len (self .cvs )
@@ -125,19 +131,29 @@ def num_ctrl_reg(self) -> int:
125131 return len (self .qdtypes )
126132
127133 @cached_property
128- def shapes (self ) -> Tuple [Tuple [int , ...], ...]:
134+ def shapes (self ) -> Tuple [Tuple [SymbolicInt , ...], ...]:
129135 """Tuple of shapes of control registers represented by this CtrlSpec."""
130136 return tuple (cv .shape for cv in self .cvs )
131137
132138 @cached_property
133- def num_qubits (self ) -> int :
139+ def concrete_shapes (self ) -> tuple [tuple [int , ...], ...]:
140+ """Tuple of shapes of control registers represented by this CtrlSpec."""
141+ shapes = self .shapes
142+ if is_symbolic (* shapes ):
143+ raise ValueError (f"cannot get concrete shapes: found symbolic { self .shapes } " )
144+ return shapes # type: ignore
145+
146+ @cached_property
147+ def num_qubits (self ) -> SymbolicInt :
134148 """Total number of qubits required for control registers represented by this CtrlSpec."""
135149 return sum (
136- dtype .num_qubits * int (np .prod (shape ))
137- for dtype , shape in zip (self .qdtypes , self .shapes )
150+ dtype .num_qubits * prod (shape ) for dtype , shape in zip (self .qdtypes , self .shapes )
138151 )
139152
140- def activation_function_dtypes (self ) -> Sequence [Tuple [QDType , Tuple [int , ...]]]:
153+ def is_symbolic (self ):
154+ return is_symbolic (* self .qdtypes ) or is_symbolic (* self .cvs )
155+
156+ def activation_function_dtypes (self ) -> Sequence [Tuple [QDType , Tuple [SymbolicInt , ...]]]:
141157 """The data types that serve as input to the 'activation function'.
142158
143159 The activation function takes in (quantum) inputs of these types and shapes and determines
@@ -165,6 +181,8 @@ def is_active(self, *vals: 'ClassicalValT') -> bool:
165181 Returns:
166182 True if the specific input values evaluate to `True` for this CtrlSpec.
167183 """
184+ if self .is_symbolic ():
185+ raise ValueError (f"Cannot compute activation for symbolic { self } " )
168186 if len (vals ) != self .num_ctrl_reg :
169187 raise ValueError (f"Incorrect number of inputs for { self } : { len (vals )} ." )
170188
@@ -180,19 +198,31 @@ def is_active(self, *vals: 'ClassicalValT') -> bool:
180198 return True
181199
182200 def wire_symbol (self , i : int , reg : Register , idx : Tuple [int , ...] = tuple ()) -> 'WireSymbol' :
183- # Return a circle for bits; a box otherwise.
184201 from qualtran .drawing import Circle , TextBox
185202
203+ cvs = self .cvs [i ]
204+
205+ if is_symbolic (cvs ):
206+ # control value is not given
207+ return TextBox ('ctrl' )
208+
209+ # Return a circle for bits; a box otherwise.
210+ cv = cvs [idx ]
186211 if reg .bitsize == 1 :
187- cv = self .cvs [i ][idx ]
188212 return Circle (filled = (cv == 1 ))
189-
190- cv = self .cvs [i ][idx ]
191- return TextBox (f'{ cv } ' )
213+ else :
214+ return TextBox (f'{ cv } ' )
192215
193216 @cached_property
194- def _cvs_tuple (self ) -> Tuple [int , ...]:
195- return tuple (cv for cvs in self .cvs for cv in tuple (cvs .reshape (- 1 )))
217+ def __cvs_tuple (self ) -> Tuple [Union [tuple [int , ...], Shaped ], ...]:
218+ """Serialize the control values for hashing and equality checking."""
219+
220+ def _serialize (cvs ) -> Union [tuple [int , ...], Shaped ]:
221+ if isinstance (cvs , Shaped ):
222+ return cvs
223+ return tuple (cvs .reshape (- 1 ))
224+
225+ return tuple (_serialize (cvs ) for cvs in self .cvs )
196226
197227 def __eq__ (self , other : Any ) -> bool :
198228 if not isinstance (other , CtrlSpec ):
@@ -201,18 +231,22 @@ def __eq__(self, other: Any) -> bool:
201231 return (
202232 other .qdtypes == self .qdtypes
203233 and other .shapes == self .shapes
204- and other ._cvs_tuple == self ._cvs_tuple
234+ and other .__cvs_tuple == self .__cvs_tuple
205235 )
206236
207237 def __hash__ (self ):
208- return hash ((self .qdtypes , self .shapes , self ._cvs_tuple ))
238+ return hash ((self .qdtypes , self .shapes , self .__cvs_tuple ))
209239
210240 def to_cirq_cv (self ) -> 'cirq.SumOfProducts' :
211241 """Convert CtrlSpec to cirq.SumOfProducts representation of control values."""
212242 import cirq
213243
244+ if self .is_symbolic ():
245+ raise ValueError (f"Cannot convert symbolic { self } to cirq control values." )
246+
214247 cirq_cv = []
215248 for qdtype , cv in zip (self .qdtypes , self .cvs ):
249+ assert isinstance (cv , np .ndarray )
216250 for idx in Register ('' , qdtype , cv .shape ).all_idxs ():
217251 cirq_cv += [* qdtype .to_bits (cv [idx ])]
218252 return cirq .SumOfProducts ([tuple (cirq_cv )])
@@ -256,11 +290,14 @@ def from_cirq_cv(
256290
257291 def get_single_ctrl_bit (self ) -> ControlBit :
258292 """If controlled by a single qubit, return the control bit, otherwise raise"""
293+ if self .is_symbolic ():
294+ raise ValueError (f"cannot get ctrl bit for symbolic { self } " )
259295 if self .num_qubits != 1 :
260296 raise ValueError (f"expected a single qubit control, got { self .num_qubits } " )
261297
262298 (qdtype ,) = self .qdtypes
263299 (cv ,) = self .cvs
300+ assert isinstance (cv , np .ndarray )
264301 (idx ,) = Register ('' , qdtype , cv .shape ).all_idxs ()
265302 (control_bit ,) = qdtype .to_bits (cv [idx ])
266303
0 commit comments