Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/ppvm-python-native/src/interface_tableau.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ macro_rules! create_interface {
self.cz(targets)
}

pub fn cz_block(&mut self, control_base: usize, target_base: usize, count: usize) {
self.inner.cz_block(control_base, target_base, count);
}

// rot1
pub fn rx(&mut self, targets: Vec<usize>, theta: f64) {
self.inner.rx_many(targets.as_slice(), theta);
Expand Down
38 changes: 17 additions & 21 deletions crates/ppvm-tableau/benches/tableau-msd-fused.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,32 @@ fn msd_func_fused<const MEASURE: bool>() -> (String, Tab) {
tab.sqrt_x_many(ql[1]);
tab.sqrt_x_many(ql[4]);

// ql[0] x ql[1]: pairs (0,17)...(16,33) — all in word 0
tab.cz_block_pairs(0, 17, 17);
// Cross-block CZ layers entangle two contiguous registers with a constant
// offset. `cz_block` takes plain qubit indices (control_base, target_base,
// count) and splits each run at u64-word boundaries internally, so it emits
// the same within-word / cross-word kernels the hand-written calls did.
let block_len = qubits_per_code_block;

// ql[2] x ql[3]: pairs (34,51)...(50,67)
tab.cz_block_pairs(34, 17, 13); // (34,51)...(46,63) in word 0
// (47,64)...(50,67): controls word 0 bits 47-50, targets word 1 bits 0-3
tab.cz_block_pairs_cross_word(0, 47, 1, 0, 4);
// ql[0] x ql[1]
tab.cz_block(ql[0][0], ql[1][0], block_len);
// ql[2] x ql[3]
tab.cz_block(ql[2][0], ql[3][0], block_len);

// sqrt_y on ql[0] and ql[3]
tab.sqrt_y_many(ql[0]);
tab.sqrt_y_many(ql[3]);

// ql[0] x ql[2]: pairs (0,34)...(16,50) — all in word 0
tab.cz_block_pairs(0, 34, 17);

// ql[3] x ql[4]: (51,68)...(67,84)
// (51,68)...(63,80): controls word 0 bits 51-63, targets word 1 bits 4-16
tab.cz_block_pairs_cross_word(0, 51, 1, 4, 13);
tab.cz_block_pairs(64, 17, 4); // (64,81)...(67,84) both in word 1
// ql[0] x ql[2]
tab.cz_block(ql[0][0], ql[2][0], block_len);
// ql[3] x ql[4]
tab.cz_block(ql[3][0], ql[4][0], block_len);

tab.sqrt_x_dag_many(ql[0]);

// ql[0] x ql[4]: (0,68)...(16,84)
// controls word 0 bits 0-16, targets word 1 bits 4-20
tab.cz_block_pairs_cross_word(0, 0, 1, 4, 17);

// ql[1] x ql[3]: (17,51)...(33,67)
tab.cz_block_pairs(17, 34, 13); // (17,51)...(29,63) in word 0
// (30,64)...(33,67): controls word 0 bits 30-33, targets word 1 bits 0-3
tab.cz_block_pairs_cross_word(0, 30, 1, 0, 4);
// ql[0] x ql[4]
tab.cz_block(ql[0][0], ql[4][0], block_len);
// ql[1] x ql[3]
tab.cz_block(ql[1][0], ql[3][0], block_len);

// sqrt_x_dag on all blocks
for block in ql.iter().take(5) {
Expand Down
79 changes: 79 additions & 0 deletions crates/ppvm-tableau/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,47 @@ where
}
}

/// Apply CZ to `count` pairs with a constant offset, given in qubit-index
/// terms: `(control_base + i, target_base + i)` for `i in 0..count`.
///
/// This is the high-level entry point for a fused block of CZs: it splits
/// the run at storage-word boundaries internally and dispatches each
/// segment to [`Self::cz_block_pairs`] (control and target in the same
/// word) or [`Self::cz_block_pairs_cross_word`] (straddling two words), so
/// callers never need to reason about the `u64` packing. CZ is symmetric,
/// so the two bases may be passed in either order.
pub fn cz_block(&mut self, control_base: usize, target_base: usize, count: usize)
where
<<T::Storage as BitView>::Store as TryFrom<usize>>::Error: Debug,
<T::Storage as BitView>::Store: PrimInt + TryFrom<usize>,
{
if count == 0 {
return;
}
// cz_block_pairs needs a non-negative offset; CZ is symmetric, so order
// the two bases.
let (lo, hi) = if control_base <= target_base {
(control_base, target_base)
} else {
(target_base, control_base)
};
let bits_per_word = std::mem::size_of::<<T::Storage as BitView>::Store>() * 8;
let mut i = 0;
while i < count {
let (c, t) = (lo + i, hi + i);
let (wc, bc) = (c / bits_per_word, c % bits_per_word);
let (wt, bt) = (t / bits_per_word, t % bits_per_word);
// Longest run before either index crosses into the next word.
let run = (bits_per_word - bc).min(bits_per_word - bt).min(count - i);
if wc == wt {
self.cz_block_pairs(c, t - c, run);
} else {
self.cz_block_pairs_cross_word(wc, bc, wt, bt, run);
}
i += run;
}
}

// helper functions

/// Compute the decomposition of a pauli into stabilizer destabilizer products
Expand Down Expand Up @@ -1415,4 +1456,42 @@ mod tests {
snapshot_tableau(&tab2.tableau)
);
}

#[test]
fn test_cz_block_matches_individual_across_word_boundary() {
// cz_block must split a run that straddles the u64 boundary into the
// right within-word + cross-word segments. control_base=34,
// target_base=51, count=17 reproduces the MSD ql[2]xql[3] sweep:
// (34,51)..(46,63) in word 0, then (47,64)..(50,67) cross-word.
use ppvm_pauli_sum::config::fx64hash::Byte8F64;
type GTab = GeneralizedTableau<Byte8F64<2>>;
let n = 85;
let mut tab1: GTab = GeneralizedTableau::new(n, 1e-12);
for i in 0..n {
Clifford::h(&mut tab1.tableau, i);
}
let mut tab2 = tab1.clone();

let (control_base, target_base, count) = (34, 51, 17);
for i in 0..count {
Clifford::cz(&mut tab1, control_base + i, target_base + i);
}
tab2.cz_block(control_base, target_base, count);

assert_eq!(
snapshot_tableau(&tab1.tableau),
snapshot_tableau(&tab2.tableau)
);

// Reversed bases (CZ is symmetric) must give the same result.
let mut tab3 = GeneralizedTableau::<Byte8F64<2>>::new(n, 1e-12);
for i in 0..n {
Clifford::h(&mut tab3.tableau, i);
}
tab3.cz_block(target_base, control_base, count);
assert_eq!(
snapshot_tableau(&tab1.tableau),
snapshot_tableau(&tab3.tableau)
);
}
}
1 change: 1 addition & 0 deletions ppvm-python/src/ppvm/_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class _GeneralizedTableauBase:
def zcy(self, targets: Sequence[int]) -> None: ...
def cz(self, targets: Sequence[int]) -> None: ...
def zcz(self, targets: Sequence[int]) -> None: ...
def cz_block(self, control_base: int, target_base: int, count: int) -> None: ...
def rx(self, targets: Sequence[int], theta: float) -> None: ...
def ry(self, targets: Sequence[int], theta: float) -> None: ...
def rz(self, targets: Sequence[int], theta: float) -> None: ...
Expand Down
39 changes: 32 additions & 7 deletions ppvm-python/src/ppvm/generalized_tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ class MeasurementResult(enum.IntEnum):
LOST = 2


# Indexed by integer outcome value (0/1/2) to reuse the singleton enum members.
# This is much faster than calling ``MeasurementResult(i)`` per element: the
# IntEnum constructor dominates large readouts, while a tuple index just bumps a
# refcount. Shared with ``GeneralizedTableauSum``.
_BY_VALUE = (MeasurementResult.ZERO, MeasurementResult.ONE, MeasurementResult.LOST)


@dataclass(frozen=True)
class GeneralizedTableau(
CliffordMixin,
Expand Down Expand Up @@ -148,6 +155,26 @@ def t_dag(self, *targets: int | Iterable[int]) -> None:
"""
self._interface.t_dag(_normalize_targets(targets))

def cz_block(self, control_base: int, target_base: int, count: int) -> None:
"""Apply a fused block of CZ gates over constant-offset qubit pairs.

Applies CZ to ``(control_base + i, target_base + i)`` for ``i`` in
``range(count)`` -- i.e. the gates ``zip(range(control_base, ...),
range(target_base, ...))`` would produce. This uses a word-level kernel
Comment on lines +161 to +163
that is much faster than the equivalent `cz` call when the pairs form a
contiguous, constant-offset block (e.g. entangling two adjacent qubit
registers). For scattered pairs, use `cz`.

CZ is symmetric, so ``control_base`` and ``target_base`` may be given in
either order.

Args:
control_base: First qubit of the control run.
target_base: First qubit of the target run.
count: Number of CZ pairs.
"""
self._interface.cz_block(control_base, target_base, count)

def measure(self, addr0: int) -> MeasurementResult:
"""Measure the specified qubit in the Z basis.

Expand All @@ -158,7 +185,7 @@ def measure(self, addr0: int) -> MeasurementResult:
The measurement outcome as a ``MeasurementResult``, which is
``LOST`` if the qubit has been lost, ``ZERO`` or ``ONE`` otherwise.
"""
return MeasurementResult(self._interface.measure(addr0))
return _BY_VALUE[self._interface.measure(addr0)]

def measure_many(self, *targets: int | Iterable[int]) -> list[MeasurementResult]:
"""Measure several qubits in the Z basis.
Expand All @@ -169,17 +196,15 @@ def measure_many(self, *targets: int | Iterable[int]) -> list[MeasurementResult]
Returns:
A list of ``MeasurementResult`` outcomes, one per target.
"""
return [
MeasurementResult(v) for v in self._interface.measure_many(_normalize_targets(targets))
]
return [_BY_VALUE[v] for v in self._interface.measure_many(_normalize_targets(targets))]

def current_measurement_record(self) -> list[MeasurementResult]:
"""Return all measurement outcomes recorded so far.

Returns:
A list of ``MeasurementResult`` outcomes in measurement order.
"""
return [MeasurementResult(v) for v in self._interface.current_measurement_record()]
return [_BY_VALUE[v] for v in self._interface.current_measurement_record()]

def coefficients(self) -> dict[int, complex]:
"""Return a snapshot of the sparse coefficient vector.
Expand Down Expand Up @@ -316,7 +341,7 @@ def run(self, prog: StimProgram) -> list[MeasurementResult]:
fresh tableau per shot).
"""
raw = self._interface.run(prog)
return [MeasurementResult(x) for x in raw]
return [_BY_VALUE[x] for x in raw]

# stim familiarity alias
do = run
Expand Down Expand Up @@ -345,7 +370,7 @@ def sample(
"""
native_cls = _native_tableau_cls(n_qubits)
raw = native_cls.sample(prog, n_qubits, min_abs_coeff, num_shots, seed)
return [[MeasurementResult(x) for x in shot] for shot in raw]
return [[_BY_VALUE[x] for x in shot] for shot in raw]


def sample_stim(
Expand Down
8 changes: 1 addition & 7 deletions ppvm-python/src/ppvm/generalized_tableau_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import cast

from . import _core
from .generalized_tableau import MeasurementResult
from .generalized_tableau import _BY_VALUE, MeasurementResult
from .mixins import (
CliffordExtensionMixin,
CliffordMixin,
Expand All @@ -17,12 +17,6 @@
)
from .types import GeneralizedTableauSumInterface, TableauSumSamplerInterface

# Indexed by integer outcome value (0/1/2) to reuse the singleton enum members.
# This is much faster than calling ``MeasurementResult(i)`` per element: the
# IntEnum constructor dominates large shot batches, while a tuple index just
# bumps a refcount.
_BY_VALUE = (MeasurementResult.ZERO, MeasurementResult.ONE, MeasurementResult.LOST)


@dataclass(frozen=True)
class GeneralizedTableauSum(
Expand Down
54 changes: 37 additions & 17 deletions ppvm-python/src/ppvm/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,49 @@ def _is_sequence(obj: Any) -> bool:
A bare ``int`` — including a numpy integer scalar, which is not iterable —
is not a sequence, so it falls through to the variadic path.
"""
return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes))


def _normalize_targets(args: tuple[Any, ...]) -> list[int]:
# Concrete-type fast paths first: ``list``/``tuple`` (the overwhelmingly
# common splatted form) and bare ``int`` short-circuit before the slow ABC
# ``isinstance(obj, Iterable)`` dispatch, which is run only for the rare
# range / ndarray / generator cases. ``str``/``bytes`` are iterable but are
# never targets, so they report False.
if isinstance(obj, (list, tuple)):
return True
if isinstance(obj, (int, str, bytes)):
return False
return isinstance(obj, Iterable)


def _normalize_targets(args: tuple[Any, ...]) -> Sequence[int]:
"""Resolve gate targets passed either as variadic indices (``x(0, 1, 2)``)
or as a single sequence (``x([0, 1, 2])``, ``x(np.array([0, 1, 2]))``)."""
or as a single sequence (``x([0, 1, 2])``, ``x(np.array([0, 1, 2]))``).

Returns the targets as-is — the single sequence, or the variadic ``args``
tuple. The native layer extracts a ``Vec<usize>`` directly (PyO3 handles
Python ints, numpy integer scalars, ranges and ndarrays), so there is no
need to rebuild the list with a per-element ``int()`` on the hot path."""
if len(args) == 1 and _is_sequence(args[0]):
return [int(t) for t in args[0]]
return [int(t) for t in args]
return args[0]
return args


def _split_targets_parameter(
args: tuple[Any, ...],
value: Any | None,
name: str,
) -> tuple[list[int], Any]:
) -> tuple[Sequence[int], Any]:
Comment on lines 71 to +75
"""Split ``(*targets, value)`` accepting ``value=...`` and a single leading
sequence of targets (``([0, 1, 2], theta)`` as well as ``(0, 1, 2, theta)``)."""
sequence of targets (``([0, 1, 2], theta)`` as well as ``(0, 1, 2, theta)``).

Targets are returned as-is (sequence or tuple slice); the native layer does
the ``Vec<usize>`` extraction, so no per-element ``int()`` rebuild is needed."""
if args and _is_sequence(args[0]):
targets, rest = [int(t) for t in args[0]], args[1:]
targets, rest = args[0], args[1:]
elif value is None:
if not args:
raise TypeError(f"missing required argument: {name!r}")
targets, rest = [int(t) for t in args[:-1]], args[-1:]
targets, rest = args[:-1], args[-1:]
else:
targets, rest = [int(t) for t in args], ()
targets, rest = args, ()
if value is None:
if not rest:
raise TypeError(f"missing required argument: {name!r}")
Expand All @@ -81,11 +98,14 @@ def _split_targets_parameter_truncate(
value: Any | None,
name: str,
truncate: bool,
) -> tuple[list[int], Any, bool]:
) -> tuple[Sequence[int], Any, bool]:
"""Split ``(*targets, value[, truncate])`` for PauliSum methods, also
accepting a single leading sequence of targets."""
accepting a single leading sequence of targets.

Targets are returned as-is (sequence or tuple/list slice); the native layer
does the ``Vec<usize>`` extraction, so no per-element ``int()`` is needed."""
if args and _is_sequence(args[0]):
targets = [int(t) for t in args[0]]
targets = args[0]
rest = list(args[1:])
if rest and isinstance(rest[-1], bool):
truncate = rest.pop()
Expand All @@ -100,8 +120,8 @@ def _split_targets_parameter_truncate(
args_list = list(args)
if len(args_list) >= 2 and isinstance(args_list[-1], bool):
truncate = args_list.pop()
return [int(t) for t in args_list[:-1]], args_list[-1], truncate
return [int(t) for t in args], value, truncate
return args_list[:-1], args_list[-1], truncate
return args, value, truncate


class CliffordMixin:
Expand Down
Loading
Loading