|
| 1 | +"""Base test class for store concurrency limiting behavior.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import asyncio |
| 6 | +from typing import TYPE_CHECKING, Generic, TypeVar |
| 7 | + |
| 8 | +import pytest |
| 9 | + |
| 10 | +from zarr.core.buffer import Buffer, default_buffer_prototype |
| 11 | + |
| 12 | +if TYPE_CHECKING: |
| 13 | + from zarr.abc.store import Store |
| 14 | + |
__all__ = ["StoreConcurrencyTests"]


# Type variables binding a concrete Store subclass (S) and Buffer subclass (B)
# to the generic test mixin defined below; subclasses pin them via
# ``store_cls`` / ``buffer_cls``.
S = TypeVar("S", bound="Store")
B = TypeVar("B", bound="Buffer")
| 20 | + |
| 21 | + |
class StoreConcurrencyTests(Generic[S, B]):
    """Base class for testing store concurrency limiting behavior.

    This mixin provides tests for verifying that stores correctly implement
    concurrency limiting: that the configured limit is honored, that batch
    operations larger than the limit complete without deadlock, and that
    concurrent operations produce correct results.

    Subclasses should set:
    - store_cls: The store class being tested
    - buffer_cls: The buffer class to use (e.g., cpu.Buffer)
    - expected_concurrency_limit: Expected default concurrency limit (or None for unlimited)
    """

    # The store class being tested.
    store_cls: type[S]
    # The buffer class used to construct test payloads (e.g. cpu.Buffer).
    buffer_cls: type[B]
    # Expected default concurrency limit, or None when the store is unlimited.
    expected_concurrency_limit: int | None

    @pytest.fixture
    async def store(self, store_kwargs: dict) -> S:
        """Create and open a store instance from the ``store_kwargs`` fixture."""
        return await self.store_cls.open(**store_kwargs)

    def test_concurrency_limit_default(self, store: S) -> None:
        """Test that store has the expected default concurrency limit."""
        # Stores that do not expose a ``_semaphore`` attribute have nothing to
        # verify here; the test is a no-op for them.
        if hasattr(store, "_semaphore"):
            if self.expected_concurrency_limit is None:
                assert store._semaphore is None, "Expected no concurrency limit"
            else:
                assert store._semaphore is not None, "Expected concurrency limit to be set"
                # asyncio.Semaphore's internal counter (_value) equals the
                # configured limit while no slot is held.
                assert store._semaphore._value == self.expected_concurrency_limit, (
                    f"Expected limit {self.expected_concurrency_limit}, got {store._semaphore._value}"
                )

    def test_concurrency_limit_custom(self, store_kwargs: dict) -> None:
        """Test that custom concurrency limits can be set.

        Skips when ``store_cls.__init__`` does not accept a
        ``concurrency_limit`` parameter.
        """
        # FIX: the previous detection used ``__init__.__code__.co_varnames``,
        # which also lists *local variables* of ``__init__`` -- a local named
        # ``concurrency_limit`` produced a false positive -- and raises
        # AttributeError for wrapped or C-implemented __init__.
        # ``inspect.signature`` reports only the actual parameters.
        import inspect

        try:
            params = inspect.signature(self.store_cls.__init__).parameters
        except (TypeError, ValueError):
            pytest.skip("Store __init__ signature is not introspectable")
        if "concurrency_limit" not in params:
            pytest.skip("Store does not support custom concurrency limits")

        # Test with custom limit
        store = self.store_cls(**store_kwargs, concurrency_limit=42)
        if hasattr(store, "_semaphore"):
            assert store._semaphore is not None
            assert store._semaphore._value == 42

        # Test with None (unlimited)
        store = self.store_cls(**store_kwargs, concurrency_limit=None)
        if hasattr(store, "_semaphore"):
            assert store._semaphore is None

    async def test_concurrency_limit_enforced(self, store: S) -> None:
        """Test that the concurrency limit is actually enforced during execution.

        This test verifies that when many operations are submitted concurrently,
        only up to the concurrency limit are actually executing at once.
        """
        if not hasattr(store, "_semaphore") or store._semaphore is None:
            pytest.skip("Store has no concurrency limit")

        limit = store._semaphore._value

        # We'll monitor the semaphore's available count.
        # When it reaches 0, that means `limit` operations are running.
        min_available = limit

        async def monitored_operation(key: str, value: B) -> None:
            nonlocal min_available
            # Check semaphore state right after we're scheduled
            await asyncio.sleep(0)  # Yield to ensure we're in the queue
            available = store._semaphore._value
            min_available = min(min_available, available)

            # Now do the actual operation (which will acquire the semaphore)
            await store.set(key, value)

        # Launch more operations than the limit to ensure contention
        num_ops = limit * 2
        items = [
            (f"limit_test_key_{i}", self.buffer_cls.from_bytes(f"value_{i}".encode()))
            for i in range(num_ops)
        ]

        await asyncio.gather(*[monitored_operation(k, v) for k, v in items])

        # The semaphore should have been fully utilized (reached 0 or close to it).
        # This indicates that `limit` operations were running concurrently.
        assert min_available < limit, (
            f"Semaphore was never fully utilized. "
            f"Min available: {min_available}, Limit: {limit}. "
            f"This suggests operations aren't running concurrently."
        )

        # Ideally it should reach 0, but allow some slack for timing
        assert min_available <= 5, (
            f"Semaphore only reached {min_available} available slots. "
            f"Expected close to 0 with limit {limit}."
        )

    async def test_batch_write_no_deadlock(self, store: S) -> None:
        """Test that batch writes don't deadlock when exceeding concurrency limit."""
        # Create more items than any reasonable concurrency limit
        num_items = 200
        items = [
            (f"test_key_{i}", self.buffer_cls.from_bytes(f"test_value_{i}".encode()))
            for i in range(num_items)
        ]

        # This should complete without deadlock, even if num_items > concurrency_limit.
        # The timeout converts a deadlock into a visible test failure.
        await asyncio.wait_for(store._set_many(items), timeout=30.0)

        # Verify all items were written correctly
        for key, expected_value in items:
            result = await store.get(key, default_buffer_prototype())
            assert result is not None
            assert result.to_bytes() == expected_value.to_bytes()

    async def test_batch_read_no_deadlock(self, store: S) -> None:
        """Test that batch reads don't deadlock when exceeding concurrency limit."""
        # Write test data
        num_items = 200
        test_data = {
            f"test_key_{i}": self.buffer_cls.from_bytes(f"test_value_{i}".encode())
            for i in range(num_items)
        }

        for key, value in test_data.items():
            await store.set(key, value)

        # Read all items concurrently - should not deadlock.
        # get_partial_values returns results in request order, which matches
        # test_data's (insertion-ordered) iteration order below.
        keys_and_ranges = [(key, None) for key in test_data]
        results = await asyncio.wait_for(
            store.get_partial_values(default_buffer_prototype(), keys_and_ranges),
            timeout=30.0,
        )

        # Verify results (strict=True guards against silent truncation)
        assert len(results) == num_items
        for result, (key, expected_value) in zip(results, test_data.items(), strict=True):
            assert result is not None
            assert result.to_bytes() == expected_value.to_bytes()

    async def test_batch_delete_no_deadlock(self, store: S) -> None:
        """Test that batch deletes don't deadlock when exceeding concurrency limit."""
        if not store.supports_deletes:
            pytest.skip("Store does not support deletes")

        # Write test data
        num_items = 200
        keys = [f"test_key_{i}" for i in range(num_items)]
        for key in keys:
            await store.set(key, self.buffer_cls.from_bytes(b"test_value"))

        # Delete all items concurrently - should not deadlock
        await asyncio.wait_for(asyncio.gather(*[store.delete(key) for key in keys]), timeout=30.0)

        # Verify all items were deleted
        for key in keys:
            result = await store.get(key, default_buffer_prototype())
            assert result is None

    async def test_concurrent_operations_correctness(self, store: S) -> None:
        """Test that concurrent operations produce correct results."""
        num_operations = 100

        # Mix of reads and writes
        write_keys = [f"write_key_{i}" for i in range(num_operations)]
        write_values = [
            self.buffer_cls.from_bytes(f"value_{i}".encode()) for i in range(num_operations)
        ]

        # Write all concurrently
        await asyncio.gather(*[store.set(k, v) for k, v in zip(write_keys, write_values)])

        # Read all concurrently
        results = await asyncio.gather(
            *[store.get(k, default_buffer_prototype()) for k in write_keys]
        )

        # Verify correctness (results are in the same order as write_keys)
        for result, expected in zip(results, write_values, strict=True):
            assert result is not None
            assert result.to_bytes() == expected.to_bytes()

    @pytest.mark.parametrize("batch_size", [1, 10, 50, 100])
    async def test_various_batch_sizes(self, store: S, batch_size: int) -> None:
        """Test that various batch sizes work correctly."""
        items = [
            (f"batch_key_{i}", self.buffer_cls.from_bytes(f"batch_value_{i}".encode()))
            for i in range(batch_size)
        ]

        # Should complete without issues for any batch size
        await asyncio.wait_for(store._set_many(items), timeout=10.0)

        # Verify
        for key, expected_value in items:
            result = await store.get(key, default_buffer_prototype())
            assert result is not None
            assert result.to_bytes() == expected_value.to_bytes()

    async def test_empty_batch_operations(self, store: S) -> None:
        """Test that empty batch operations don't cause issues."""
        # Empty batch should not raise
        await store._set_many([])

        # Empty read batch
        results = await store.get_partial_values(default_buffer_prototype(), [])
        assert results == []

    async def test_mixed_success_failure_batch(self, store: S) -> None:
        """Test batch operations with mix of successful and failing items."""
        # Write some initial data
        await store.set("existing_key", self.buffer_cls.from_bytes(b"existing_value"))

        # Try to read mix of existing and non-existing keys
        key_ranges = [
            ("existing_key", None),
            ("non_existing_key_1", None),
            ("non_existing_key_2", None),
        ]

        results = await store.get_partial_values(default_buffer_prototype(), key_ranges)

        # First should exist, others should be None (missing keys are reported
        # as None rather than raising)
        assert results[0] is not None
        assert results[0].to_bytes() == b"existing_value"
        assert results[1] is None
        assert results[2] is None
0 commit comments