Skip to content

Commit 229e3b3

Browse files
committed
move concurrency limiting logic to stores
1 parent 6aeb834 commit 229e3b3

10 files changed

Lines changed: 304 additions & 170 deletions

File tree

src/zarr/abc/codec.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from abc import abstractmethod
45
from collections.abc import Mapping
56
from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar
@@ -8,7 +9,7 @@
89

910
from zarr.abc.metadata import Metadata
1011
from zarr.core.buffer import Buffer, NDBuffer
11-
from zarr.core.common import NamedConfig, concurrent_map
12+
from zarr.core.common import NamedConfig
1213

1314
if TYPE_CHECKING:
1415
from collections.abc import Awaitable, Callable, Iterable
@@ -224,10 +225,8 @@ async def decode_partial(
224225
-------
225226
Iterable[NDBuffer | None]
226227
"""
227-
return await concurrent_map(
228-
list(batch_info),
229-
self._decode_partial_single,
230-
)
228+
# Store handles concurrency limiting internally
229+
return await asyncio.gather(*[self._decode_partial_single(*info) for info in batch_info])
231230

232231

233232
class ArrayBytesCodecPartialEncodeMixin:
@@ -260,10 +259,8 @@ async def encode_partial(
260259
The ByteSetter is used to write the necessary bytes and fetch bytes for existing chunk data.
261260
The chunk spec contains information about the chunk.
262261
"""
263-
await concurrent_map(
264-
list(batch_info),
265-
self._encode_partial_single,
266-
)
262+
# Store handles concurrency limiting internally
263+
await asyncio.gather(*[self._encode_partial_single(*info) for info in batch_info])
267264

268265

269266
class CodecPipeline:
@@ -461,10 +458,8 @@ async def _batching_helper(
461458
func: Callable[[CodecInput, ArraySpec], Awaitable[CodecOutput | None]],
462459
batch_info: Iterable[tuple[CodecInput | None, ArraySpec]],
463460
) -> list[CodecOutput | None]:
464-
return await concurrent_map(
465-
list(batch_info),
466-
_noop_for_none(func),
467-
)
461+
# Store handles concurrency limiting internally
462+
return await asyncio.gather(*[_noop_for_none(func)(chunk, spec) for chunk, spec in batch_info])
468463

469464

470465
def _noop_for_none(

src/zarr/abc/store.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from abc import ABC, abstractmethod
45
from asyncio import gather
56
from dataclasses import dataclass
@@ -462,11 +463,8 @@ async def getsize_prefix(self, prefix: str) -> int:
462463
# improve tail latency and might reduce memory pressure (since not all keys
463464
# would be in memory at once).
464465

465-
# avoid circular import
466-
from zarr.core.common import concurrent_map
467-
468-
keys = [(x,) async for x in self.list_prefix(prefix)]
469-
sizes = await concurrent_map(keys, self.getsize)
466+
keys = [x async for x in self.list_prefix(prefix)]
467+
sizes = await asyncio.gather(*[self.getsize(key) for key in keys])
470468
return sum(sizes)
471469

472470

src/zarr/core/array.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import json
45
import warnings
56
from asyncio import gather
@@ -59,7 +60,6 @@
5960
_default_zarr_format,
6061
_warn_order_kwarg,
6162
ceildiv,
62-
concurrent_map,
6363
parse_shapelike,
6464
product,
6565
)
@@ -1847,12 +1847,12 @@ async def resize(self, new_shape: ShapeLike, delete_outside_chunks: bool = True)
18471847
async def _delete_key(key: str) -> None:
18481848
await (self.store_path / key).delete()
18491849

1850-
await concurrent_map(
1851-
[
1852-
(self.metadata.encode_chunk_key(chunk_coords),)
1850+
# Store handles concurrency limiting internally
1851+
await asyncio.gather(
1852+
*[
1853+
_delete_key(self.metadata.encode_chunk_key(chunk_coords))
18531854
for chunk_coords in old_chunk_coords.difference(new_chunk_coords)
1854-
],
1855-
_delete_key,
1855+
]
18561856
)
18571857

18581858
# Write new metadata
@@ -4533,19 +4533,19 @@ async def _copy_array_region(
45334533
await result.setitem(chunk_coords, arr)
45344534

45354535
# Stream data from the source array to the new array
4536-
await concurrent_map(
4537-
[(region, data) for region in result._iter_shard_regions()],
4538-
_copy_array_region,
4536+
# Store handles concurrency limiting internally
4537+
await asyncio.gather(
4538+
*[_copy_array_region(region, data) for region in result._iter_shard_regions()]
45394539
)
45404540
else:
45414541

45424542
async def _copy_arraylike_region(chunk_coords: slice, _data: NDArrayLike) -> None:
45434543
await result.setitem(chunk_coords, _data[chunk_coords])
45444544

45454545
# Stream data from the source array to the new array
4546-
await concurrent_map(
4547-
[(region, data) for region in result._iter_shard_regions()],
4548-
_copy_arraylike_region,
4546+
# Store handles concurrency limiting internally
4547+
await asyncio.gather(
4548+
*[_copy_arraylike_region(region, data) for region in result._iter_shard_regions()]
45494549
)
45504550
return result
45514551

src/zarr/core/codec_pipeline.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from dataclasses import dataclass
45
from itertools import islice, pairwise
56
from typing import TYPE_CHECKING, Any, TypeVar
@@ -14,7 +15,6 @@
1415
Codec,
1516
CodecPipeline,
1617
)
17-
from zarr.core.common import concurrent_map
1818
from zarr.core.config import config
1919
from zarr.core.indexing import SelectorTuple, is_scalar
2020
from zarr.errors import ZarrUserWarning
@@ -267,9 +267,12 @@ async def read_batch(
267267
else:
268268
out[out_selection] = fill_value_or_default(chunk_spec)
269269
else:
270-
chunk_bytes_batch = await concurrent_map(
271-
[(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info],
272-
lambda byte_getter, prototype: byte_getter.get(prototype),
270+
# Store handles concurrency limiting internally
271+
chunk_bytes_batch = await asyncio.gather(
272+
*[
273+
byte_getter.get(array_spec.prototype)
274+
for byte_getter, array_spec, *_ in batch_info
275+
]
273276
)
274277
chunk_array_batch = await self.decode_batch(
275278
[
@@ -367,15 +370,15 @@ async def _read_key(
367370
return await byte_setter.get(prototype=prototype)
368371

369372
chunk_bytes_batch: Iterable[Buffer | None]
370-
chunk_bytes_batch = await concurrent_map(
371-
[
372-
(
373+
# Store handles concurrency limiting internally
374+
chunk_bytes_batch = await asyncio.gather(
375+
*[
376+
_read_key(
373377
None if is_complete_chunk else byte_setter,
374378
chunk_spec.prototype,
375379
)
376380
for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info
377-
],
378-
_read_key,
381+
]
379382
)
380383
chunk_array_decoded = await self.decode_batch(
381384
[
@@ -433,14 +436,14 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non
433436
else:
434437
await byte_setter.set(chunk_bytes)
435438

436-
await concurrent_map(
437-
[
438-
(byte_setter, chunk_bytes)
439+
# Store handles concurrency limiting internally
440+
await asyncio.gather(
441+
*[
442+
_write_key(byte_setter, chunk_bytes)
439443
for chunk_bytes, (byte_setter, *_) in zip(
440444
chunk_bytes_batch, batch_info, strict=False
441445
)
442-
],
443-
_write_key,
446+
]
444447
)
445448

446449
async def decode(
@@ -467,12 +470,12 @@ async def read(
467470
out: NDBuffer,
468471
drop_axes: tuple[int, ...] = (),
469472
) -> None:
470-
await concurrent_map(
471-
[
472-
(single_batch_info, out, drop_axes)
473+
# Process mini-batches concurrently - stores handle I/O concurrency internally
474+
await asyncio.gather(
475+
*[
476+
self.read_batch(single_batch_info, out, drop_axes)
473477
for single_batch_info in batched(batch_info, self.batch_size)
474-
],
475-
self.read_batch,
478+
]
476479
)
477480

478481
async def write(
@@ -481,12 +484,12 @@ async def write(
481484
value: NDBuffer,
482485
drop_axes: tuple[int, ...] = (),
483486
) -> None:
484-
await concurrent_map(
485-
[
486-
(single_batch_info, value, drop_axes)
487+
# Process mini-batches concurrently - stores handle I/O concurrency internally
488+
await asyncio.gather(
489+
*[
490+
self.write_batch(single_batch_info, value, drop_axes)
487491
for single_batch_info in batched(batch_info, self.batch_size)
488-
],
489-
self.write_batch,
492+
]
490493
)
491494

492495

src/zarr/storage/_fsspec.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import json
45
import warnings
56
from contextlib import suppress
@@ -17,6 +18,7 @@
1718
from zarr.core.buffer import Buffer
1819
from zarr.errors import ZarrUserWarning
1920
from zarr.storage._common import _dereference_path
21+
from zarr.storage._utils import with_concurrency_limit
2022

2123
if TYPE_CHECKING:
2224
from collections.abc import AsyncIterator, Iterable
@@ -82,6 +84,9 @@ class FsspecStore(Store):
8284
filesystem scheme.
8385
allowed_exceptions : tuple[type[Exception], ...]
8486
When fetching data, these cases will be deemed to correspond to missing keys.
87+
concurrency_limit : int, optional
88+
Maximum number of concurrent I/O operations. Default is 50.
89+
Set to None for unlimited concurrency.
8590
8691
Attributes
8792
----------
@@ -117,18 +122,24 @@ class FsspecStore(Store):
117122
fs: AsyncFileSystem
118123
allowed_exceptions: tuple[type[Exception], ...]
119124
path: str
125+
_semaphore: asyncio.Semaphore | None
120126

121127
def __init__(
122128
self,
123129
fs: AsyncFileSystem,
130+
*,
124131
read_only: bool = False,
125132
path: str = "/",
126133
allowed_exceptions: tuple[type[Exception], ...] = ALLOWED_EXCEPTIONS,
134+
concurrency_limit: int | None = 50,
127135
) -> None:
128136
super().__init__(read_only=read_only)
129137
self.fs = fs
130138
self.path = path
131139
self.allowed_exceptions = allowed_exceptions
140+
self._semaphore = (
141+
asyncio.Semaphore(concurrency_limit) if concurrency_limit is not None else None
142+
)
132143

133144
if not self.fs.async_impl:
134145
raise TypeError("Filesystem needs to support async operations.")
@@ -273,6 +284,7 @@ def __eq__(self, other: object) -> bool:
273284
and self.fs == other.fs
274285
)
275286

287+
@with_concurrency_limit()
276288
async def get(
277289
self,
278290
key: str,
@@ -315,6 +327,7 @@ async def get(
315327
else:
316328
return value
317329

330+
@with_concurrency_limit()
318331
async def set(
319332
self,
320333
key: str,
@@ -335,6 +348,27 @@ async def set(
335348
raise NotImplementedError
336349
await self.fs._pipe_file(path, value.to_bytes())
337350

351+
async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
352+
# Override to avoid deadlock from calling decorated set() method
353+
if not self._is_open:
354+
await self._open()
355+
self._check_writable()
356+
357+
async def _set_with_limit(key: str, value: Buffer) -> None:
358+
if not isinstance(value, Buffer):
359+
raise TypeError(
360+
f"FsspecStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead."
361+
)
362+
path = _dereference_path(self.path, key)
363+
if self._semaphore:
364+
async with self._semaphore:
365+
await self.fs._pipe_file(path, value.to_bytes())
366+
else:
367+
await self.fs._pipe_file(path, value.to_bytes())
368+
369+
await asyncio.gather(*[_set_with_limit(key, value) for key, value in values])
370+
371+
@with_concurrency_limit()
338372
async def delete(self, key: str) -> None:
339373
# docstring inherited
340374
self._check_writable()

0 commit comments

Comments (0)