Skip to content

Commit 05d191a

Browse files
committed
add store concurrency tests
1 parent 229e3b3 commit 05d191a

4 files changed

Lines changed: 284 additions & 7 deletions

File tree

src/zarr/storage/_utils.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,22 +55,26 @@ async def get(self, key: str) -> Buffer | None:
5555
def decorator(
    func: Callable[P, Coroutine[Any, Any, T_co]],
) -> Callable[P, Coroutine[Any, Any, T_co]]:
    """
    This decorator wraps the invocation of `func` in an `async with semaphore` context manager.
    The semaphore object is resolved by getting the `semaphore_attr` attribute from the first
    argument to func. When this decorator is used on a method of a class, that first argument
    is a reference to the class instance (`self`).
    """

    @functools.wraps(func)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T_co:
        # First arg should be 'self'; without it there is nothing to read the
        # semaphore from, so fail loudly instead of raising a confusing AttributeError.
        if not args:
            raise TypeError(f"{func.__name__} requires at least one argument (self)")

        self = args[0]

        # Resolve the semaphore from the instance; getattr without a default so a
        # missing attribute surfaces as an error rather than silently skipping limits.
        semaphore: asyncio.Semaphore = getattr(self, semaphore_attr)

        # Apply concurrency limit
        async with semaphore:
            return await func(*args, **kwargs)

    return wrapper
7680

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
"""Base test class for store concurrency limiting behavior."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
from typing import TYPE_CHECKING, Generic, TypeVar
7+
8+
import pytest
9+
10+
from zarr.core.buffer import Buffer, default_buffer_prototype
11+
12+
if TYPE_CHECKING:
13+
from zarr.abc.store import Store
14+
15+
__all__ = ["StoreConcurrencyTests"]
16+
17+
18+
S = TypeVar("S", bound="Store")
19+
B = TypeVar("B", bound="Buffer")
20+
21+
22+
class StoreConcurrencyTests(Generic[S, B]):
    """Base class for testing store concurrency limiting behavior.

    This mixin provides tests for verifying that stores correctly implement
    concurrency limiting via an ``asyncio.Semaphore`` held on the store's
    private ``_semaphore`` attribute (``None`` means unlimited).

    Subclasses should set:
    - store_cls: The store class being tested
    - buffer_cls: The buffer class to use (e.g., cpu.Buffer)
    - expected_concurrency_limit: Expected default concurrency limit (or None for unlimited)
    """

    # Class of the store under test; must provide an async ``open`` classmethod.
    store_cls: type[S]
    # Buffer class used to construct test payloads.
    buffer_cls: type[B]
    # Expected default semaphore size, or None when the store is unlimited.
    expected_concurrency_limit: int | None

    @pytest.fixture
    async def store(self, store_kwargs: dict) -> S:
        """Create and open a store instance."""
        return await self.store_cls.open(**store_kwargs)

    def test_concurrency_limit_default(self, store: S) -> None:
        """Test that store has the expected default concurrency limit."""
        # Skip explicitly instead of silently passing, so a store that stops
        # exposing ``_semaphore`` cannot make this test vacuously succeed.
        if not hasattr(store, "_semaphore"):
            pytest.skip("Store does not expose a _semaphore attribute")
        if self.expected_concurrency_limit is None:
            assert store._semaphore is None, "Expected no concurrency limit"
        else:
            assert store._semaphore is not None, "Expected concurrency limit to be set"
            assert store._semaphore._value == self.expected_concurrency_limit, (
                f"Expected limit {self.expected_concurrency_limit}, got {store._semaphore._value}"
            )

    def test_concurrency_limit_custom(self, store_kwargs: dict) -> None:
        """Test that custom concurrency limits can be set."""
        if "concurrency_limit" not in self.store_cls.__init__.__code__.co_varnames:
            pytest.skip("Store does not support custom concurrency limits")

        # Test with custom limit
        store = self.store_cls(**store_kwargs, concurrency_limit=42)
        # Skip explicitly rather than silently passing when the attribute is absent.
        if not hasattr(store, "_semaphore"):
            pytest.skip("Store does not expose a _semaphore attribute")
        assert store._semaphore is not None
        assert store._semaphore._value == 42

        # Test with None (unlimited)
        store = self.store_cls(**store_kwargs, concurrency_limit=None)
        assert store._semaphore is None

    async def test_concurrency_limit_enforced(self, store: S) -> None:
        """Test that the concurrency limit is actually enforced during execution.

        This test verifies that when many operations are submitted concurrently,
        only up to the concurrency limit are actually executing at once.
        """
        if not hasattr(store, "_semaphore") or store._semaphore is None:
            pytest.skip("Store has no concurrency limit")

        limit = store._semaphore._value

        # We'll monitor the semaphore's available count.
        # When it reaches 0, that means `limit` operations are running.
        min_available = limit

        async def monitored_operation(key: str, value: B) -> None:
            nonlocal min_available
            # Check semaphore state right after we're scheduled
            await asyncio.sleep(0)  # Yield to ensure we're in the queue
            available = store._semaphore._value
            min_available = min(min_available, available)

            # Now do the actual operation (which will acquire the semaphore)
            await store.set(key, value)

        # Launch more operations than the limit to ensure contention
        num_ops = limit * 2
        items = [
            (f"limit_test_key_{i}", self.buffer_cls.from_bytes(f"value_{i}".encode()))
            for i in range(num_ops)
        ]

        await asyncio.gather(*[monitored_operation(k, v) for k, v in items])

        # The semaphore should have been fully utilized (reached 0 or close to it)
        # This indicates that `limit` operations were running concurrently
        assert min_available < limit, (
            f"Semaphore was never fully utilized. "
            f"Min available: {min_available}, Limit: {limit}. "
            f"This suggests operations aren't running concurrently."
        )

        # Ideally it should reach 0, but allow some slack for timing
        assert min_available <= 5, (
            f"Semaphore only reached {min_available} available slots. "
            f"Expected close to 0 with limit {limit}."
        )

    async def test_batch_write_no_deadlock(self, store: S) -> None:
        """Test that batch writes don't deadlock when exceeding concurrency limit."""
        # Create more items than any reasonable concurrency limit
        num_items = 200
        items = [
            (f"test_key_{i}", self.buffer_cls.from_bytes(f"test_value_{i}".encode()))
            for i in range(num_items)
        ]

        # This should complete without deadlock, even if num_items > concurrency_limit
        await asyncio.wait_for(store._set_many(items), timeout=30.0)

        # Verify all items were written correctly
        for key, expected_value in items:
            result = await store.get(key, default_buffer_prototype())
            assert result is not None
            assert result.to_bytes() == expected_value.to_bytes()

    async def test_batch_read_no_deadlock(self, store: S) -> None:
        """Test that batch reads don't deadlock when exceeding concurrency limit."""
        # Write test data
        num_items = 200
        test_data = {
            f"test_key_{i}": self.buffer_cls.from_bytes(f"test_value_{i}".encode())
            for i in range(num_items)
        }

        for key, value in test_data.items():
            await store.set(key, value)

        # Read all items concurrently - should not deadlock
        keys_and_ranges = [(key, None) for key in test_data]
        results = await asyncio.wait_for(
            store.get_partial_values(default_buffer_prototype(), keys_and_ranges),
            timeout=30.0,
        )

        # Verify results; strict=True guards against a silently truncated zip.
        assert len(results) == num_items
        for result, (key, expected_value) in zip(results, test_data.items(), strict=True):
            assert result is not None
            assert result.to_bytes() == expected_value.to_bytes()

    async def test_batch_delete_no_deadlock(self, store: S) -> None:
        """Test that batch deletes don't deadlock when exceeding concurrency limit."""
        if not store.supports_deletes:
            pytest.skip("Store does not support deletes")

        # Write test data
        num_items = 200
        keys = [f"test_key_{i}" for i in range(num_items)]
        for key in keys:
            await store.set(key, self.buffer_cls.from_bytes(b"test_value"))

        # Delete all items concurrently - should not deadlock
        await asyncio.wait_for(asyncio.gather(*[store.delete(key) for key in keys]), timeout=30.0)

        # Verify all items were deleted
        for key in keys:
            result = await store.get(key, default_buffer_prototype())
            assert result is None

    async def test_concurrent_operations_correctness(self, store: S) -> None:
        """Test that concurrent operations produce correct results."""
        num_operations = 100

        # Mix of reads and writes
        write_keys = [f"write_key_{i}" for i in range(num_operations)]
        write_values = [
            self.buffer_cls.from_bytes(f"value_{i}".encode()) for i in range(num_operations)
        ]

        # Write all concurrently
        await asyncio.gather(
            *[store.set(k, v) for k, v in zip(write_keys, write_values, strict=True)]
        )

        # Read all concurrently
        results = await asyncio.gather(
            *[store.get(k, default_buffer_prototype()) for k in write_keys]
        )

        # Verify correctness
        for result, expected in zip(results, write_values, strict=True):
            assert result is not None
            assert result.to_bytes() == expected.to_bytes()

    @pytest.mark.parametrize("batch_size", [1, 10, 50, 100])
    async def test_various_batch_sizes(self, store: S, batch_size: int) -> None:
        """Test that various batch sizes work correctly."""
        items = [
            (f"batch_key_{i}", self.buffer_cls.from_bytes(f"batch_value_{i}".encode()))
            for i in range(batch_size)
        ]

        # Should complete without issues for any batch size
        await asyncio.wait_for(store._set_many(items), timeout=10.0)

        # Verify
        for key, expected_value in items:
            result = await store.get(key, default_buffer_prototype())
            assert result is not None
            assert result.to_bytes() == expected_value.to_bytes()

    async def test_empty_batch_operations(self, store: S) -> None:
        """Test that empty batch operations don't cause issues."""
        # Empty batch should not raise
        await store._set_many([])

        # Empty read batch
        results = await store.get_partial_values(default_buffer_prototype(), [])
        assert results == []

    async def test_mixed_success_failure_batch(self, store: S) -> None:
        """Test batch operations with mix of successful and failing items."""
        # Write some initial data
        await store.set("existing_key", self.buffer_cls.from_bytes(b"existing_value"))

        # Try to read mix of existing and non-existing keys
        key_ranges = [
            ("existing_key", None),
            ("non_existing_key_1", None),
            ("non_existing_key_2", None),
        ]

        results = await store.get_partial_values(default_buffer_prototype(), key_ranges)

        # First should exist, others should be None
        assert results[0] is not None
        assert results[0].to_bytes() == b"existing_value"
        assert results[1] is None
        assert results[2] is None

tests/test_store/test_local.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from zarr.storage import LocalStore
1313
from zarr.storage._local import _atomic_write
1414
from zarr.testing.store import StoreTests
15+
from zarr.testing.store_concurrency import StoreConcurrencyTests
1516
from zarr.testing.utils import assert_bytes_equal
1617

1718

@@ -150,3 +151,15 @@ def test_atomic_write_exclusive_preexisting(tmp_path: pathlib.Path) -> None:
150151
f.write(b"abc")
151152
assert path.read_bytes() == b"xyz"
152153
assert list(path.parent.iterdir()) == [path] # no temp files
154+
155+
156+
class TestLocalStoreConcurrency(StoreConcurrencyTests[LocalStore, cpu.Buffer]):
    """Concurrency-limiting tests for LocalStore."""

    # Configuration consumed by the StoreConcurrencyTests mixin.
    store_cls = LocalStore
    buffer_cls = cpu.Buffer
    expected_concurrency_limit = 100  # LocalStore default

    @pytest.fixture
    def store_kwargs(self, tmpdir: str) -> dict[str, str]:
        # Root the store in a per-test temporary directory.
        kwargs = {"root": str(tmpdir)}
        return kwargs

tests/test_store/test_memory.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from zarr.errors import ZarrUserWarning
1313
from zarr.storage import GpuMemoryStore, MemoryStore
1414
from zarr.testing.store import StoreTests
15+
from zarr.testing.store_concurrency import StoreConcurrencyTests
1516
from zarr.testing.utils import gpu_test
1617

1718
if TYPE_CHECKING:
@@ -130,3 +131,15 @@ def test_from_dict(self) -> None:
130131
result = GpuMemoryStore.from_dict(d)
131132
for v in result._store_dict.values():
132133
assert type(v) is gpu.Buffer
134+
135+
136+
class TestMemoryStoreConcurrency(StoreConcurrencyTests[MemoryStore, cpu.Buffer]):
    """Concurrency-limiting tests for MemoryStore."""

    # Configuration consumed by the StoreConcurrencyTests mixin.
    store_cls = MemoryStore
    buffer_cls = cpu.Buffer
    expected_concurrency_limit = None  # MemoryStore has no limit (fast in-memory ops)

    @pytest.fixture
    def store_kwargs(self) -> dict[str, Any]:
        # Let the store create its own backing dict.
        kwargs: dict[str, Any] = {"store_dict": None}
        return kwargs

0 commit comments

Comments
 (0)