Skip to content

Commit bfa5bc1

Browse files
authored
Use symbolics for computing register bitsizes (#1353)
* use symbolics for computing register bitsizes * symbolic `n` in `Partition`
1 parent 862e7e7 commit bfa5bc1

2 files changed

Lines changed: 10 additions & 7 deletions

File tree

qualtran/_infra/registers.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@
1919
from typing import cast, Dict, Iterable, Iterator, List, overload, Tuple, Union
2020

2121
import attrs
22-
import numpy as np
2322
import sympy
2423
from attrs import field, frozen
2524

26-
from qualtran.symbolics import is_symbolic, smax, SymbolicInt
25+
from qualtran.symbolics import is_symbolic, prod, smax, ssum, SymbolicInt
2726

2827
from .data_types import QAny, QBit, QDType
2928

@@ -99,7 +98,7 @@ def total_bits(self) -> int:
9998
10099
This is the product of bitsize and each of the dimensions in `shape`.
101100
"""
102-
return self.bitsize * int(np.prod(self.shape))
101+
return self.bitsize * prod(self.shape_symbolic)
103102

104103
def adjoint(self) -> 'Register':
105104
"""Return the 'adjoint' of this register by switching RIGHT and LEFT registers."""
@@ -202,8 +201,8 @@ def n_qubits(self) -> int:
202201
is taken to be the greater of the number of left or right qubits. A bloq with this
203202
signature uses at least this many qubits.
204203
"""
205-
left_size = sum(reg.total_bits() for reg in self.lefts())
206-
right_size = sum(reg.total_bits() for reg in self.rights())
204+
left_size = ssum(reg.total_bits() for reg in self.lefts())
205+
right_size = ssum(reg.total_bits() for reg in self.rights())
207206
return smax(left_size, right_size)
208207

209208
def __repr__(self):

qualtran/bloqs/bookkeeping/partition.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from qualtran.bloqs.bookkeeping._bookkeeping_bloq import _BookkeepingBloq
3434
from qualtran.drawing import directional_text_box, Text, WireSymbol
35+
from qualtran.symbolics import is_symbolic, ssum, SymbolicInt
3536

3637
if TYPE_CHECKING:
3738
import quimb.tensor as qtn
@@ -54,14 +55,14 @@ class Partition(_BookkeepingBloq):
5455
[user spec]: The registers provided by the `regs` argument. RIGHT by default.
5556
"""
5657

57-
n: int
58+
n: SymbolicInt
5859
regs: Tuple[Register, ...] = field(
5960
converter=lambda x: x if isinstance(x, tuple) else tuple(x), validator=validators.min_len(1)
6061
)
6162
partition: bool = True
6263

6364
def __attrs_post_init__(self):
64-
if self.n != sum(r.total_bits() for r in self.regs):
65+
if self.n != ssum(r.total_bits() for r in self.regs):
6566
raise ValueError("Total bitsize not equal to sum of registers to partition into")
6667
if len(set(r.name for r in self.regs)) != len(self.regs):
6768
raise ValueError("Duplicate register names")
@@ -104,6 +105,9 @@ def my_tensors(
104105
) -> List['qtn.Tensor']:
105106
import quimb.tensor as qtn
106107

108+
if is_symbolic(self.n):
109+
raise DecomposeTypeError(f"cannot compute tensors for symbolic {self}")
110+
107111
grouped = incoming['x'] if self.partition else outgoing['x']
108112
partitioned = outgoing if self.partition else incoming
109113

0 commit comments

Comments
 (0)