@@ -40,16 +40,25 @@ async def store(self, store_kwargs: dict[str, Any]) -> S:
4040 """Create and open a store instance."""
4141 return await self .store_cls .open (** store_kwargs )
4242
43+ @staticmethod
44+ def _get_semaphore (store : Store ) -> asyncio .Semaphore | None :
45+ """Get the semaphore from a store, or None if the store doesn't support concurrency limiting."""
46+ get_semaphore = getattr (store , "get_semaphore" , None )
47+ if get_semaphore is not None :
48+ return get_semaphore () # type: ignore[no-any-return]
49+ return None
50+
4351 def test_concurrency_limit_default (self , store : S ) -> None :
4452 """Test that store has the expected default concurrency limit."""
45- if hasattr (store , "_semaphore" ):
46- if self .expected_concurrency_limit is None :
47- assert store ._semaphore is None , "Expected no concurrency limit"
48- else :
49- assert store ._semaphore is not None , "Expected concurrency limit to be set"
50- assert store ._semaphore ._value == self .expected_concurrency_limit , (
51- f"Expected limit { self .expected_concurrency_limit } , got { store ._semaphore ._value } "
52- )
53+ semaphore = self ._get_semaphore (store )
54+ if semaphore is None and self .expected_concurrency_limit is not None :
55+ pytest .fail ("Expected concurrency limit to be set" )
56+ if semaphore is not None and self .expected_concurrency_limit is None :
57+ pytest .fail ("Expected no concurrency limit" )
58+ if semaphore is not None and self .expected_concurrency_limit is not None :
59+ assert semaphore ._value == self .expected_concurrency_limit , (
60+ f"Expected limit { self .expected_concurrency_limit } , got { semaphore ._value } "
61+ )
5362
5463 def test_concurrency_limit_custom (self , store_kwargs : dict [str , Any ]) -> None :
5564 """Test that custom concurrency limits can be set."""
@@ -58,25 +67,25 @@ def test_concurrency_limit_custom(self, store_kwargs: dict[str, Any]) -> None:
5867
5968 # Test with custom limit
6069 store = self .store_cls (** {** store_kwargs , "concurrency_limit" : 42 })
61- if hasattr (store , "_semaphore" ):
62- assert store . _semaphore is not None
63- assert store . _semaphore ._value == 42
70+ semaphore = self . _get_semaphore (store )
71+ assert semaphore is not None
72+ assert semaphore ._value == 42
6473
6574 # Test with None (unlimited)
6675 store = self .store_cls (** {** store_kwargs , "concurrency_limit" : None })
67- if hasattr (store , "_semaphore" ):
68- assert store ._semaphore is None
76+ assert self ._get_semaphore (store ) is None
6977
7078 async def test_concurrency_limit_enforced (self , store : S ) -> None :
7179 """Test that the concurrency limit is actually enforced during execution.
7280
7381 This test verifies that when many operations are submitted concurrently,
7482 only up to the concurrency limit are actually executing at once.
7583 """
76- if not hasattr (store , "_semaphore" ) or store ._semaphore is None :
84+ semaphore = self ._get_semaphore (store )
85+ if semaphore is None :
7786 pytest .skip ("Store has no concurrency limit" )
7887
79- limit = store . _semaphore ._value
88+ limit = semaphore ._value
8089
8190 # We'll monitor the semaphore's available count
8291 # When it reaches 0, that means `limit` operations are running
@@ -86,7 +95,7 @@ async def monitored_operation(key: str, value: B) -> None:
8695 nonlocal min_available
8796 # Check semaphore state right after we're scheduled
8897 await asyncio .sleep (0 ) # Yield to ensure we're in the queue
89- available = store . _semaphore ._value
98+ available = semaphore ._value
9099 min_available = min (min_available , available )
91100
92101 # Now do the actual operation (which will acquire the semaphore)
0 commit comments