3434import numpy as np
3535import sympy
3636from numpy .typing import NDArray
37+ from typing_extensions import Self
3738
3839from qualtran import (
3940 Bloq ,
4041 BloqInstance ,
4142 DanglingT ,
43+ DecomposeNotImplementedError ,
44+ DecomposeTypeError ,
4245 LeftDangle ,
4346 Register ,
4447 RightDangle ,
@@ -301,31 +304,49 @@ def _update_assign_from_vals(
301304 soq = Soquet (binst , reg )
302305 self .soq_assign [soq ] = val
303306
304- def _binst_on_classical_vals (self , binst , in_vals ) -> None :
305- """Call `on_classical_vals` on a given bloq instance."""
306- bloq = binst .bloq
307+ def _recurse_impl (self , cbloq : 'CompositeBloq' , in_vals ):
308+ """Overridable function to recursively simulate a composite bloq."""
309+ out_vals , _ = call_cbloq_classically (cbloq .signature , in_vals , cbloq ._binst_graph )
310+ phase = None
307311
308- out_vals = bloq .on_classical_vals (** in_vals )
309- if not isinstance (out_vals , dict ):
310- raise TypeError (
311- f"{ bloq .__class__ .__name__ } .on_classical_vals should return a dictionary."
312- )
313- self ._update_assign_from_vals (bloq .signature .rights (), binst , out_vals )
312+ return out_vals , phase
314313
315- def _binst_basis_state_phase (self , binst , in_vals ) -> None :
316- """Call `basis_state_phase` on a given bloq instance .
314+ def _recurse (self , binst : 'BloqInstance' , in_vals ) -> Self :
315+ """Recursively simulate a composite bloq.
317316
318- This base simulation class will raise an error if the bloq reports any phasing.
319- This method is overwritten in `PhasedClassicalSimState` to support phasing.
317+ This handles decomposing the bloq and using the results of the sub-simulation
318+ to update the simulator state. Developers should override `_recurse_impl` to
319+ customize the sub-simulation.
320320 """
321321 bloq = binst .bloq
322- bloq_phase = bloq .basis_state_phase (** in_vals )
322+ try :
323+ cbloq = bloq .decompose_bloq ()
324+ out_vals , bloq_phase = self ._recurse_impl (cbloq , in_vals )
325+
326+ except DecomposeTypeError as e :
327+ raise NotImplementedError (f"{ bloq } is not classically simulable." ) from e
328+ except DecomposeNotImplementedError as e :
329+ raise NotImplementedError (
330+ f"{ bloq } has no decomposition and does not "
331+ f"support classical simulation directly"
332+ ) from e
333+ except NotImplementedError as e :
334+ raise NotImplementedError (f"{ bloq } does not support classical simulation: { e } " ) from e
335+
336+ self ._update (binst , out_vals , bloq_phase )
337+ return self
338+
339+ def _update (self , binst : 'BloqInstance' , out_vals , bloq_phase : Union [complex , None ]) -> None :
340+ """Overridable method to update the current simulator state."""
341+ self ._update_assign_from_vals (binst .bloq .signature .rights (), binst , out_vals )
342+
323343 if bloq_phase is not None :
324344 raise ValueError (
325- f"{ bloq } imparts a phase, and can't be simulated purely classically. Consider using `do_phased_classical_simulation`."
345+ f"{ binst .bloq } imparts a phase of { bloq_phase } , and can't be simulated purely classically. "
346+ f"Consider using `do_phased_classical_simulation`."
326347 )
327348
328- def step (self ) -> 'ClassicalSimState' :
349+ def step (self ) -> Self :
329350 """Advance the simulation by one bloq instance.
330351
331352 After calling this method, `self.last_binst` will contain the bloq instance that
@@ -349,11 +370,32 @@ def _in_vals(reg: Register):
349370 return _get_in_vals (binst , reg , soq_assign = self .soq_assign )
350371
351372 bloq = binst .bloq
373+ bcls_name = bloq .__class__ .__name__
352374 in_vals = {reg .name : _in_vals (reg ) for reg in bloq .signature .lefts ()}
375+ out_vals = bloq .on_classical_vals (** in_vals )
376+ bloq_phase = bloq .basis_state_phase (** in_vals )
377+
378+ # +--+ basis_state_phase
379+ # +- on_classical_vals
380+ # | |
381+ # dict None Use classical values.
382+ # dict number Use classical values and phase only if doing phased sim
383+ # NotImplemented None decompose and use the correct simulator type
384+ # NotImplemented number error
385+
386+ if out_vals is NotImplemented :
387+ if bloq_phase is not None :
388+ raise ValueError (
389+ f"`basis_state_phase` defined on { bcls_name } , but not `on_classical_vals`"
390+ )
391+
392+ return self ._recurse (binst , in_vals )
393+
394+ if not isinstance (out_vals , dict ):
395+ raise TypeError (f"`{ bcls_name } .on_classical_vals` should return a dictionary." )
396+
397+ self ._update (binst , out_vals , bloq_phase )
353398
354- # Apply methods
355- self ._binst_on_classical_vals (binst , in_vals )
356- self ._binst_basis_state_phase (binst , in_vals )
357399 return self
358400
359401 def finalize (self ) -> Dict [str , 'ClassicalValT' ]:
@@ -466,14 +508,19 @@ def from_cbloq(
466508 random_handler = rnd_handler ,
467509 )
468510
469- def _binst_basis_state_phase (self , binst , in_vals ):
470- """Call `basis_state_phase` on a given bloq instance.
511+ def _recurse_impl (self , cbloq , in_vals ):
512+ """Use phased classical simulation when recursing."""
513+ sim = PhasedClassicalSimState (
514+ cbloq .signature , cbloq ._binst_graph , in_vals , random_handler = self ._random_handler
515+ )
516+ final_vals = sim .simulate ()
517+ phase = sim .phase
518+ return final_vals , phase
519+
520+ def _update (self , binst : 'BloqInstance' , out_vals , bloq_phase : Union [complex , None ]) -> None :
521+ """Update the current simulator state, including phase tracking."""
522+ self ._update_assign_from_vals (binst .bloq .signature .rights (), binst , out_vals )
471523
472- If this method returns a value, the current phase will be updated. Otherwise, we
473- leave the phase as-is.
474- """
475- bloq = binst .bloq
476- bloq_phase = bloq .basis_state_phase (** in_vals )
477524 if isinstance (bloq_phase , MeasurementPhase ):
478525 # In this special case, there is a coupling between the classical result and the
479526 # phase result (because the classical result is stochastic). We look up the measurement
0 commit comments