Skip to content

Commit d811be0

Browse files
authored
Make BloqBuilder.join accept array-like of soquets (#1509)
1 parent 47a36c8 commit d811be0

2 files changed

Lines changed: 9 additions & 2 deletions

File tree

qualtran/_infra/composite_bloq.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,12 +1234,13 @@ def split(self, soq: Soquet) -> NDArray[Soquet]: # type: ignore[type-var]
12341234

12351235
return self.add(Split(dtype=soq.reg.dtype), reg=soq)
12361236

1237-
def join(self, soqs: NDArray[Soquet], dtype: Optional[QDType] = None) -> Soquet: # type: ignore[type-var]
1237+
def join(self, soqs: SoquetInT, dtype: Optional[QDType] = None) -> Soquet:
12381238
from qualtran.bloqs.bookkeeping import Join
12391239

12401240
try:
1241+
soqs = np.asarray(soqs)
12411242
(n,) = soqs.shape
1242-
except AttributeError:
1243+
except (AttributeError, ValueError):
12431244
raise ValueError("`join` expects a 1-d array of input soquets to join.") from None
12441245

12451246
if not all(soq.reg.bitsize == 1 for soq in soqs):

qualtran/_infra/composite_bloq_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,12 @@ def test_util_convenience_methods():
389389
assert len(cbloq.connections) == 1 + 10 + 1
390390

391391

392+
def test_join_list():
393+
bb = BloqBuilder()
394+
qs = [bb.allocate() for _ in range(10)]
395+
_ = bb.join(qs)
396+
397+
392398
def test_util_convenience_methods_errors():
393399
bb = BloqBuilder()
394400

0 commit comments

Comments
 (0)