Skip to content

Commit 1e73be9

Browse files
committed
use protocol
1 parent 9cec350 commit 1e73be9

5 files changed

Lines changed: 82 additions & 53 deletions

File tree

src/zarr/storage/_fsspec.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,13 +255,19 @@ def from_url(
255255

256256
return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)
257257

258+
def get_semaphore(self) -> asyncio.Semaphore | None:
259+
return self._semaphore
260+
258261
def with_read_only(self, read_only: bool = False) -> FsspecStore:
259262
# docstring inherited
263+
sem = self.get_semaphore()
264+
concurrency_limit = sem._value if sem else None
260265
return type(self)(
261266
fs=self.fs,
262267
path=self.path,
263268
allowed_exceptions=self.allowed_exceptions,
264269
read_only=read_only,
270+
concurrency_limit=concurrency_limit,
265271
)
266272

267273
async def clear(self) -> None:
@@ -353,15 +359,16 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
353359
if not self._is_open:
354360
await self._open()
355361
self._check_writable()
362+
semaphore = self.get_semaphore()
356363

357364
async def _set_with_limit(key: str, value: Buffer) -> None:
358365
if not isinstance(value, Buffer):
359366
raise TypeError(
360367
f"FsspecStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead."
361368
)
362369
path = _dereference_path(self.path, key)
363-
if self._semaphore:
364-
async with self._semaphore:
370+
if semaphore:
371+
async with semaphore:
365372
await self.fs._pipe_file(path, value.to_bytes())
366373
else:
367374
await self.fs._pipe_file(path, value.to_bytes())

src/zarr/storage/_local.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,13 @@ def __init__(
134134
asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None
135135
)
136136

137+
def get_semaphore(self) -> asyncio.Semaphore | None:
138+
return self._semaphore
139+
137140
def with_read_only(self, read_only: bool = False) -> Self:
138141
# docstring inherited
139-
concurrency_limit = self._semaphore._value if self._semaphore else None
142+
sem = self.get_semaphore()
143+
concurrency_limit = sem._value if sem else None
140144
return type(self)(
141145
root=self.root,
142146
read_only=read_only,
@@ -232,11 +236,13 @@ async def get_partial_values(
232236
# Note: We directly call the I/O functions here, wrapped with semaphore
233237
# to avoid deadlock from calling the decorated get() method
234238

239+
semaphore = self.get_semaphore()
240+
235241
async def _get_with_limit(key: str, byte_range: ByteRequest | None) -> Buffer | None:
236242
path = self.root / key
237243
try:
238-
if self._semaphore:
239-
async with self._semaphore:
244+
if semaphore:
245+
async with semaphore:
240246
return await asyncio.to_thread(_get, path, prototype, byte_range)
241247
else:
242248
return await asyncio.to_thread(_get, path, prototype, byte_range)

src/zarr/storage/_obstore.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,13 @@ def __init__(
8686
asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None
8787
)
8888

89+
def get_semaphore(self) -> asyncio.Semaphore | None:
90+
return self._semaphore
91+
8992
def with_read_only(self, read_only: bool = False) -> Self:
9093
# docstring inherited
91-
concurrency_limit = self._semaphore._value if self._semaphore else None
94+
sem = self.get_semaphore()
95+
concurrency_limit = sem._value if sem else None
9296
return type(self)(
9397
store=self.store,
9498
read_only=read_only,
@@ -134,6 +138,7 @@ async def get_partial_values(
134138
import obstore as obs
135139

136140
key_ranges = list(key_ranges)
141+
semaphore = self.get_semaphore()
137142
# Group bounded range requests by path for batched fetching
138143
per_file_bounded: dict[str, list[tuple[int, RangeByteRequest]]] = defaultdict(list)
139144
other_requests: list[tuple[int, str, ByteRequest | None]] = []
@@ -150,8 +155,8 @@ async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]])
150155
"""Batch multiple range requests for the same file using get_ranges_async."""
151156
starts = [r.start for _, r in requests]
152157
ends = [r.end for _, r in requests]
153-
if self._semaphore:
154-
async with self._semaphore:
158+
if semaphore:
159+
async with semaphore:
155160
responses = await obs.get_ranges_async(
156161
self.store, path=path, starts=starts, ends=ends
157162
)
@@ -165,8 +170,8 @@ async def _fetch_ranges(path: str, requests: list[tuple[int, RangeByteRequest]])
165170
async def _fetch_one(idx: int, path: str, byte_range: ByteRequest | None) -> None:
166171
"""Fetch a single non-range request with semaphore limiting."""
167172
try:
168-
if self._semaphore:
169-
async with self._semaphore:
173+
if semaphore:
174+
async with semaphore:
170175
buffers[idx] = await self._get_impl(path, prototype, byte_range, obs)
171176
else:
172177
buffers[idx] = await self._get_impl(path, prototype, byte_range, obs)
@@ -250,11 +255,12 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
250255
import obstore as obs
251256

252257
self._check_writable()
258+
semaphore = self.get_semaphore()
253259

254260
async def _set_with_limit(key: str, value: Buffer) -> None:
255261
buf = value.as_buffer_like()
256-
if self._semaphore:
257-
async with self._semaphore:
262+
if semaphore:
263+
async with semaphore:
258264
await obs.put_async(self.store, key, buf)
259265
else:
260266
await obs.put_async(self.store, key, buf)
@@ -268,8 +274,9 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None:
268274

269275
self._check_writable()
270276
buf = value.as_buffer_like()
271-
if self._semaphore:
272-
async with self._semaphore:
277+
semaphore = self.get_semaphore()
278+
if semaphore:
279+
async with semaphore:
273280
with contextlib.suppress(obs.exceptions.AlreadyExistsError):
274281
await obs.put_async(self.store, key, buf, mode="create")
275282
else:
@@ -304,11 +311,12 @@ async def delete_dir(self, prefix: str) -> None:
304311
prefix += "/"
305312

306313
metas = await obs.list(self.store, prefix).collect_async()
314+
semaphore = self.get_semaphore()
307315

308316
# Delete with semaphore limiting to avoid deadlock
309317
async def _delete_with_limit(path: str) -> None:
310-
if self._semaphore:
311-
async with self._semaphore:
318+
if semaphore:
319+
async with semaphore:
312320
with contextlib.suppress(FileNotFoundError):
313321
await obs.delete_async(self.store, path)
314322
else:

src/zarr/storage/_utils.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import functools
44
import re
55
from pathlib import Path
6-
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
6+
from typing import TYPE_CHECKING, Any, ParamSpec, Protocol, TypeVar, runtime_checkable
77

88
from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest
99

@@ -18,20 +18,24 @@
1818
T_co = TypeVar("T_co", covariant=True)
1919

2020

21-
def with_concurrency_limit(
22-
semaphore_attr: str = "_semaphore",
23-
) -> Callable[[Callable[P, Coroutine[Any, Any, T_co]]], Callable[P, Coroutine[Any, Any, T_co]]]:
21+
@runtime_checkable
22+
class HasConcurrencyLimit(Protocol):
23+
"""Protocol for stores that support concurrency limiting via a semaphore."""
24+
25+
def get_semaphore(self) -> asyncio.Semaphore | None:
26+
"""Return the semaphore used for concurrency limiting, or None for unlimited."""
27+
...
28+
29+
30+
def with_concurrency_limit() -> Callable[
31+
[Callable[P, Coroutine[Any, Any, T_co]]], Callable[P, Coroutine[Any, Any, T_co]]
32+
]:
2433
"""
2534
Decorator that applies a semaphore-based concurrency limit to an async method.
2635
27-
This decorator is designed for Store methods that need to limit concurrent operations.
28-
The store instance should have a `_semaphore` attribute (or custom attribute name)
29-
that is either an asyncio.Semaphore or None (for unlimited concurrency).
30-
31-
Parameters
32-
----------
33-
semaphore_attr : str, optional
34-
Name of the semaphore attribute on the class instance. Default is "_semaphore".
36+
This decorator is designed for methods on classes that implement the
37+
``HasConcurrencyLimit`` protocol. The class must define a ``get_semaphore()``
38+
method returning either an ``asyncio.Semaphore`` or ``None``.
3539
3640
Returns
3741
-------
@@ -45,6 +49,9 @@ class MyStore(Store):
4549
def __init__(self, concurrency_limit: int = 100):
4650
self._semaphore = asyncio.Semaphore(concurrency_limit) if concurrency_limit else None
4751
52+
def get_semaphore(self) -> asyncio.Semaphore | None:
53+
return self._semaphore
54+
4855
@with_concurrency_limit()
4956
async def get(self, key: str) -> Buffer | None:
5057
# This will only run when semaphore permits
@@ -55,22 +62,14 @@ async def get(self, key: str) -> Buffer | None:
5562
def decorator(
5663
func: Callable[P, Coroutine[Any, Any, T_co]],
5764
) -> Callable[P, Coroutine[Any, Any, T_co]]:
58-
"""
59-
This decorator wraps the invocation of `func` in an `async with semaphore` context manager.
60-
The semaphore object is resolved by getting the `semaphor_attr` attribute from the first
61-
argument to func. When this decorator is used on a method of a class, that first argument
62-
is a reference to the class instance (`self`).
63-
"""
64-
6565
@functools.wraps(func)
6666
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co:
6767
# First arg should be 'self'
6868
if not args:
6969
raise TypeError(f"{func.__name__} requires at least one argument (self)")
7070

7171
self = args[0]
72-
73-
semaphore: asyncio.Semaphore | None = getattr(self, semaphore_attr)
72+
semaphore: asyncio.Semaphore | None = self.get_semaphore() # type: ignore[attr-defined]
7473

7574
if semaphore is not None:
7675
async with semaphore:

src/zarr/testing/store_concurrency.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,25 @@ async def store(self, store_kwargs: dict[str, Any]) -> S:
4040
"""Create and open a store instance."""
4141
return await self.store_cls.open(**store_kwargs)
4242

43+
@staticmethod
44+
def _get_semaphore(store: Store) -> asyncio.Semaphore | None:
45+
"""Get the semaphore from a store, or None if the store doesn't support concurrency limiting."""
46+
get_semaphore = getattr(store, "get_semaphore", None)
47+
if get_semaphore is not None:
48+
return get_semaphore() # type: ignore[no-any-return]
49+
return None
50+
4351
def test_concurrency_limit_default(self, store: S) -> None:
4452
"""Test that store has the expected default concurrency limit."""
45-
if hasattr(store, "_semaphore"):
46-
if self.expected_concurrency_limit is None:
47-
assert store._semaphore is None, "Expected no concurrency limit"
48-
else:
49-
assert store._semaphore is not None, "Expected concurrency limit to be set"
50-
assert store._semaphore._value == self.expected_concurrency_limit, (
51-
f"Expected limit {self.expected_concurrency_limit}, got {store._semaphore._value}"
52-
)
53+
semaphore = self._get_semaphore(store)
54+
if semaphore is None and self.expected_concurrency_limit is not None:
55+
pytest.fail("Expected concurrency limit to be set")
56+
if semaphore is not None and self.expected_concurrency_limit is None:
57+
pytest.fail("Expected no concurrency limit")
58+
if semaphore is not None and self.expected_concurrency_limit is not None:
59+
assert semaphore._value == self.expected_concurrency_limit, (
60+
f"Expected limit {self.expected_concurrency_limit}, got {semaphore._value}"
61+
)
5362

5463
def test_concurrency_limit_custom(self, store_kwargs: dict[str, Any]) -> None:
5564
"""Test that custom concurrency limits can be set."""
@@ -58,25 +67,25 @@ def test_concurrency_limit_custom(self, store_kwargs: dict[str, Any]) -> None:
5867

5968
# Test with custom limit
6069
store = self.store_cls(**{**store_kwargs, "concurrency_limit": 42})
61-
if hasattr(store, "_semaphore"):
62-
assert store._semaphore is not None
63-
assert store._semaphore._value == 42
70+
semaphore = self._get_semaphore(store)
71+
assert semaphore is not None
72+
assert semaphore._value == 42
6473

6574
# Test with None (unlimited)
6675
store = self.store_cls(**{**store_kwargs, "concurrency_limit": None})
67-
if hasattr(store, "_semaphore"):
68-
assert store._semaphore is None
76+
assert self._get_semaphore(store) is None
6977

7078
async def test_concurrency_limit_enforced(self, store: S) -> None:
7179
"""Test that the concurrency limit is actually enforced during execution.
7280
7381
This test verifies that when many operations are submitted concurrently,
7482
only up to the concurrency limit are actually executing at once.
7583
"""
76-
if not hasattr(store, "_semaphore") or store._semaphore is None:
84+
semaphore = self._get_semaphore(store)
85+
if semaphore is None:
7786
pytest.skip("Store has no concurrency limit")
7887

79-
limit = store._semaphore._value
88+
limit = semaphore._value
8089

8190
# We'll monitor the semaphore's available count
8291
# When it reaches 0, that means `limit` operations are running
@@ -86,7 +95,7 @@ async def monitored_operation(key: str, value: B) -> None:
8695
nonlocal min_available
8796
# Check semaphore state right after we're scheduled
8897
await asyncio.sleep(0) # Yield to ensure we're in the queue
89-
available = store._semaphore._value
98+
available = semaphore._value
9099
min_available = min(min_available, available)
91100

92101
# Now do the actual operation (which will acquire the semaphore)

0 commit comments

Comments
 (0)