Skip to content

Commit 2f29aa5

Browse files
authored
Fix recursive phased classical simulation (#1731)
The classical simulators should recurse into bloq decompositions. The phased classical simulator would use the ordinary classical simulator when recursing since that was "baked into" the `Bloq` method -- and it would throw an error even if it should have been able to simulate the (sub-)circuit. This changes the responsibility of dealing with the possible recursion to the simulator classes themselves.
1 parent fed39dc commit 2f29aa5

5 files changed

Lines changed: 130 additions & 45 deletions

File tree

qualtran/_infra/bloq.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -272,19 +272,10 @@ def on_classical_vals(
272272
registers with `shape` will be an ndarray of values of the expected type.
273273
274274
Returns:
275-
A dictionary mapping right (or thru) register name to output classical values.
275+
A dictionary mapping right (or thru) register name to output classical values or
276+
NotImplemented if this is not classical-reversible logic.
276277
"""
277-
try:
278-
return self.decompose_bloq().on_classical_vals(**vals)
279-
except DecomposeTypeError as e:
280-
raise NotImplementedError(f"{self} is not classically simulable.") from e
281-
except DecomposeNotImplementedError as e:
282-
raise NotImplementedError(
283-
f"{self} has no decomposition and does not "
284-
f"support classical simulation directly"
285-
) from e
286-
except NotImplementedError as e:
287-
raise NotImplementedError(f"{self} does not support classical simulation: {e}") from e
278+
return NotImplemented
288279

289280
def basis_state_phase(
290281
self, **vals: 'ClassicalValT'

qualtran/_infra/controlled.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -440,12 +440,12 @@ def on_classical_vals(self, **vals: 'ClassicalValT') -> Mapping[str, 'ClassicalV
440440
ctrl_vals = [vals[reg_name] for reg_name in self.ctrl_reg_names]
441441
other_vals = {reg.name: vals[reg.name] for reg in self.subbloq.signature}
442442
if self.ctrl_spec.is_active(*ctrl_vals):
443-
rets = {
444-
**self.subbloq.on_classical_vals(**other_vals),
445-
**{
446-
reg_name: ctrl_val for reg_name, ctrl_val in zip(self.ctrl_reg_names, ctrl_vals)
447-
},
448-
}
443+
rets = self.subbloq.on_classical_vals(**other_vals)
444+
if rets is NotImplemented:
445+
return NotImplemented
446+
rets |= {
447+
reg_name: ctrl_val for reg_name, ctrl_val in zip(self.ctrl_reg_names, ctrl_vals)
448+
} # type: ignore[operator]
449449
return rets
450450

451451
return vals

qualtran/bloqs/basic_gates/z_basis.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,15 @@ def my_tensors(
268268
)
269269
]
270270

271+
def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
272+
# Diagonal, but causes phases: see `basis_state_phase`
273+
return vals
274+
275+
def basis_state_phase(self, q: int) -> Optional[complex]:
276+
if q == 1:
277+
return -1
278+
return 1
279+
271280
def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']:
272281
if ctrl_spec != CtrlSpec():
273282
# Delegate to the general superclass behavior

qualtran/simulation/classical_sim.py

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,14 @@
3434
import numpy as np
3535
import sympy
3636
from numpy.typing import NDArray
37+
from typing_extensions import Self
3738

3839
from qualtran import (
3940
Bloq,
4041
BloqInstance,
4142
DanglingT,
43+
DecomposeNotImplementedError,
44+
DecomposeTypeError,
4245
LeftDangle,
4346
Register,
4447
RightDangle,
@@ -301,31 +304,49 @@ def _update_assign_from_vals(
301304
soq = Soquet(binst, reg)
302305
self.soq_assign[soq] = val
303306

304-
def _binst_on_classical_vals(self, binst, in_vals) -> None:
305-
"""Call `on_classical_vals` on a given bloq instance."""
306-
bloq = binst.bloq
307+
def _recurse_impl(self, cbloq: 'CompositeBloq', in_vals):
308+
"""Overridable function to recursively simulate a composite bloq."""
309+
out_vals, _ = call_cbloq_classically(cbloq.signature, in_vals, cbloq._binst_graph)
310+
phase = None
307311

308-
out_vals = bloq.on_classical_vals(**in_vals)
309-
if not isinstance(out_vals, dict):
310-
raise TypeError(
311-
f"{bloq.__class__.__name__}.on_classical_vals should return a dictionary."
312-
)
313-
self._update_assign_from_vals(bloq.signature.rights(), binst, out_vals)
312+
return out_vals, phase
314313

315-
def _binst_basis_state_phase(self, binst, in_vals) -> None:
316-
"""Call `basis_state_phase` on a given bloq instance.
314+
def _recurse(self, binst: 'BloqInstance', in_vals) -> Self:
315+
"""Recursively simulate a composite bloq.
317316
318-
This base simulation class will raise an error if the bloq reports any phasing.
319-
This method is overwritten in `PhasedClassicalSimState` to support phasing.
317+
This handles decomposing the bloq and using the results of the sub-simulation
318+
to update the simulator state. Developers should override `_recurse_impl` to
319+
customize the sub-simulation.
320320
"""
321321
bloq = binst.bloq
322-
bloq_phase = bloq.basis_state_phase(**in_vals)
322+
try:
323+
cbloq = bloq.decompose_bloq()
324+
out_vals, bloq_phase = self._recurse_impl(cbloq, in_vals)
325+
326+
except DecomposeTypeError as e:
327+
raise NotImplementedError(f"{bloq} is not classically simulable.") from e
328+
except DecomposeNotImplementedError as e:
329+
raise NotImplementedError(
330+
f"{bloq} has no decomposition and does not "
331+
f"support classical simulation directly"
332+
) from e
333+
except NotImplementedError as e:
334+
raise NotImplementedError(f"{bloq} does not support classical simulation: {e}") from e
335+
336+
self._update(binst, out_vals, bloq_phase)
337+
return self
338+
339+
def _update(self, binst: 'BloqInstance', out_vals, bloq_phase: Union[complex, None]) -> None:
340+
"""Overridable method to update the current simulator state."""
341+
self._update_assign_from_vals(binst.bloq.signature.rights(), binst, out_vals)
342+
323343
if bloq_phase is not None:
324344
raise ValueError(
325-
f"{bloq} imparts a phase, and can't be simulated purely classically. Consider using `do_phased_classical_simulation`."
345+
f"{binst.bloq} imparts a phase of {bloq_phase}, and can't be simulated purely classically. "
346+
f"Consider using `do_phased_classical_simulation`."
326347
)
327348

328-
def step(self) -> 'ClassicalSimState':
349+
def step(self) -> Self:
329350
"""Advance the simulation by one bloq instance.
330351
331352
After calling this method, `self.last_binst` will contain the bloq instance that
@@ -349,11 +370,32 @@ def _in_vals(reg: Register):
349370
return _get_in_vals(binst, reg, soq_assign=self.soq_assign)
350371

351372
bloq = binst.bloq
373+
bcls_name = bloq.__class__.__name__
352374
in_vals = {reg.name: _in_vals(reg) for reg in bloq.signature.lefts()}
375+
out_vals = bloq.on_classical_vals(**in_vals)
376+
bloq_phase = bloq.basis_state_phase(**in_vals)
377+
378+
# +--+ basis_state_phase
379+
# +- on_classical_vals
380+
# | |
381+
# dict None Use classical values.
382+
# dict number Use classical values and phase only if doing phased sim
383+
# NotImplemented None decompose and use the correct simulator type
384+
# NotImplemented number error
385+
386+
if out_vals is NotImplemented:
387+
if bloq_phase is not None:
388+
raise ValueError(
389+
f"`basis_state_phase` defined on {bcls_name}, but not `on_classical_vals`"
390+
)
391+
392+
return self._recurse(binst, in_vals)
393+
394+
if not isinstance(out_vals, dict):
395+
raise TypeError(f"`{bcls_name}.on_classical_vals` should return a dictionary.")
396+
397+
self._update(binst, out_vals, bloq_phase)
353398

354-
# Apply methods
355-
self._binst_on_classical_vals(binst, in_vals)
356-
self._binst_basis_state_phase(binst, in_vals)
357399
return self
358400

359401
def finalize(self) -> Dict[str, 'ClassicalValT']:
@@ -466,14 +508,19 @@ def from_cbloq(
466508
random_handler=rnd_handler,
467509
)
468510

469-
def _binst_basis_state_phase(self, binst, in_vals):
470-
"""Call `basis_state_phase` on a given bloq instance.
511+
def _recurse_impl(self, cbloq, in_vals):
512+
"""Use phased classical simulation when recursing."""
513+
sim = PhasedClassicalSimState(
514+
cbloq.signature, cbloq._binst_graph, in_vals, random_handler=self._random_handler
515+
)
516+
final_vals = sim.simulate()
517+
phase = sim.phase
518+
return final_vals, phase
519+
520+
def _update(self, binst: 'BloqInstance', out_vals, bloq_phase: Union[complex, None]) -> None:
521+
"""Update the current simulator state, including phase tracking."""
522+
self._update_assign_from_vals(binst.bloq.signature.rights(), binst, out_vals)
471523

472-
If this method returns a value, the current phase will be updated. Otherwise, we
473-
leave the phase as-is.
474-
"""
475-
bloq = binst.bloq
476-
bloq_phase = bloq.basis_state_phase(**in_vals)
477524
if isinstance(bloq_phase, MeasurementPhase):
478525
# In this special case, there is a coupling between the classical result and the
479526
# phase result (because the classical result is stochastic). We look up the measurement

qualtran/simulation/classical_sim_test.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import itertools
1616
from typing import Dict, Union
1717

18+
import attrs
1819
import networkx as nx
1920
import numpy as np
2021
import pytest
@@ -38,8 +39,8 @@
3839
Register,
3940
Side,
4041
Signature,
42+
SoquetT,
4143
)
42-
from qualtran.bloqs.basic_gates import CNOT
4344
from qualtran.simulation.classical_sim import (
4445
_BannedClassicalValHandler,
4546
_FixedClassicalValHandler,
@@ -153,6 +154,8 @@ def test_normal_classical_on_phased():
153154

154155

155156
def test_cnot_assign_dict():
157+
from qualtran.bloqs.basic_gates import CNOT
158+
156159
cbloq = CNOT().as_composite_bloq()
157160
binst_graph = cbloq._binst_graph # pylint: disable=protected-access
158161
vals = dict(ctrl=1, target=0)
@@ -338,3 +341,38 @@ def test_phased_classical_distribution():
338341
)
339342
assert final_values['c'] == 1
340343
assert phase == 1
344+
345+
346+
@attrs.frozen
347+
class ComposedPhasing(Bloq):
348+
n: int = 0
349+
350+
@property
351+
def signature(self) -> 'Signature':
352+
return Signature([Register('x', QBit(), side=Side.RIGHT)])
353+
354+
def build_composite_bloq(self, bb: 'BloqBuilder') -> Dict[str, 'SoquetT']:
355+
from qualtran.bloqs.basic_gates import OneState, ZGate
356+
357+
x = bb.add(OneState())
358+
for _ in range(self.n):
359+
x = bb.add(ZGate(), q=x)
360+
return {'x': x}
361+
362+
363+
def test_derive_phasing_from_composed_bloq():
364+
vals, phase = do_phased_classical_simulation(ComposedPhasing(0), {})
365+
assert vals == {'x': 1}
366+
assert phase == +1.0
367+
368+
vals, phase = do_phased_classical_simulation(ComposedPhasing(1), {})
369+
assert vals == {'x': 1}
370+
assert phase == -1.0
371+
372+
vals, phase = do_phased_classical_simulation(ComposedPhasing(2), {})
373+
assert vals == {'x': 1}
374+
assert phase == +1.0
375+
376+
vals, phase = do_phased_classical_simulation(ComposedPhasing(3), {})
377+
assert vals == {'x': 1}
378+
assert phase == -1.0

0 commit comments

Comments
 (0)