Skip to content

Commit 8f6def9

Browse files
shoyer authored and copybara-github committed
[xarray-tensorstore] add support for consolidated metadata
This should speed up opening Zarrs with many groups considerably. Because TensorStore doesn't have any notion of groups, we need to fetch this information from Zarr-Python.

PiperOrigin-RevId: 825739590
1 parent 34a52e6 commit 8f6def9

4 files changed

Lines changed: 161 additions & 97 deletions

File tree

.github/workflows/tests.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@ on:
1111

1212
jobs:
1313
tests:
14-
name: "python ${{ matrix.python-version }} tests"
14+
name: "python=${{ matrix.python-version }} zarr=${{ matrix.zarr-version }} tests"
1515
runs-on: ubuntu-latest
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
python-version: ["3.10", "3.11", "3.12", "3.13"]
19+
python-version: ["3.11", "3.12", "3.13"]
20+
zarr-version: [">=2,<3", ">=3"]
2021
steps:
2122
- name: Cancel previous
2223
uses: styfle/cancel-workflow-action@0.7.0
@@ -40,7 +41,7 @@ jobs:
4041
key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
4142
- name: Install Xarray-Tensorstore
4243
run: |
43-
pip install -e .[tests]
44+
pip install -e .[tests] "zarr${{ matrix.zarr-version }}"
4445
- name: Run unit tests
4546
run: |
4647
pytest .

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
setuptools.setup(
2020
name='xarray-tensorstore',
21-
version='0.2.0', # keep in sync with xarray_tensorstore.py
21+
version='0.3.0', # keep in sync with xarray_tensorstore.py
2222
license='Apache-2.0',
2323
author='Google LLC',
2424
author_email='noreply@google.com',

xarray_tensorstore.py

Lines changed: 80 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import zarr
2929

3030

31-
__version__ = '0.2.0' # keep in sync with setup.py
31+
__version__ = '0.3.0' # keep in sync with setup.py
3232

3333

3434
Index = TypeVar('Index', int, slice, np.ndarray, None)
@@ -217,12 +217,49 @@ def _get_zarr_format(path: str) -> int:
217217
return 2
218218

219219

220+
def _open_tensorstore_arrays(
221+
path: str,
222+
names: list[str],
223+
group: zarr.Group | None,
224+
zarr_format: int,
225+
write: bool,
226+
context: tensorstore.Context | None = None,
227+
) -> dict[str, tensorstore.Future]:
228+
"""Open all arrays in a Zarr group using TensorStore."""
229+
specs = {
230+
k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in names
231+
}
232+
233+
assume_metadata = False
234+
if packaging.version.parse(zarr.__version__).major >= 3 and group is not None:
235+
consolidated_metadata = group.metadata.consolidated_metadata
236+
if consolidated_metadata is not None:
237+
assume_metadata = True
238+
for name in names:
239+
metadata = consolidated_metadata.metadata[name].to_dict()
240+
metadata.pop('attributes', None) # not supported by TensorStore
241+
specs[name]['metadata'] = metadata
242+
243+
array_futures = {}
244+
for k, spec in specs.items():
245+
array_futures[k] = tensorstore.open(
246+
spec,
247+
read=True,
248+
write=write,
249+
open=True,
250+
context=context,
251+
assume_metadata=assume_metadata,
252+
)
253+
return array_futures
254+
255+
220256
def open_zarr(
221257
path: str,
222258
*,
223259
context: tensorstore.Context | None = None,
224260
mask_and_scale: bool = True,
225261
write: bool = False,
262+
consolidated: bool | None = None,
226263
) -> xarray.Dataset:
227264
"""Open an xarray.Dataset from Zarr using TensorStore.
228265
@@ -252,6 +289,9 @@ def open_zarr(
252289
xarray.open_zarr(). This is only supported for coordinate variables and
253290
otherwise will raise an error.
254291
write: Allow write access. Defaults to False.
292+
consolidated: If True, read consolidated metadata. By default, an attempt to
293+
use consolidated metadata is made with a fallback to non-consolidated
294+
metadata, like in Xarray.
255295
256296
Returns:
257297
Dataset with all data variables opened via TensorStore.
@@ -272,8 +312,19 @@ def open_zarr(
272312
if context is None:
273313
context = tensorstore.Context()
274314

275-
# chunks=None means avoid using dask
276-
ds = xarray.open_zarr(path, chunks=None, mask_and_scale=mask_and_scale)
315+
# Open Xarray's backends.ZarrStore directly so we can get access to the
316+
# underlying Zarr group's consolidated metadata.
317+
store = xarray.backends.ZarrStore.open_group(
318+
path, consolidated=consolidated
319+
)
320+
group = store.zarr_group
321+
ds = xarray.open_dataset(
322+
filename_or_obj='', # ignored in favor of store=
323+
chunks=None, # avoid using dask
324+
mask_and_scale=mask_and_scale,
325+
store=store,
326+
engine='zarr',
327+
)
277328

278329
if mask_and_scale:
279330
# Data variables get replaced below with _TensorStoreAdapter arrays, which
@@ -282,13 +333,9 @@ def open_zarr(
282333
_raise_if_mask_and_scale_used_for_data_vars(ds)
283334

284335
zarr_format = _get_zarr_format(path)
285-
specs = {
286-
k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in ds
287-
}
288-
array_futures = {
289-
k: tensorstore.open(spec, read=True, write=write, context=context)
290-
for k, spec in specs.items()
291-
}
336+
array_futures = _open_tensorstore_arrays(
337+
path, list(ds), group, zarr_format, write=write, context=context
338+
)
292339
arrays = {k: v.result() for k, v in array_futures.items()}
293340
new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}
294341

@@ -304,20 +351,26 @@ def _tensorstore_open_concatenated_zarrs(
304351
"""Open multiple zarrs with TensorStore.
305352
306353
Args:
307-
paths: List of paths to zarr stores.
308-
data_vars: List of data variable names to open.
309-
concat_axes: List of axes along which to concatenate the data variables.
310-
context: TensorStore context.
354+
paths: List of paths to zarr stores.
355+
data_vars: List of data variable names to open.
356+
concat_axes: List of axes along which to concatenate the data variables.
357+
context: TensorStore context.
358+
359+
Returns:
360+
Dictionary of data variable names to concatenated TensorStore arrays.
311361
"""
312362
# Open all arrays in all datasets using tensorstore
313363
arrays_list = []
314364
for path in paths:
315365
zarr_format = _get_zarr_format(path)
316-
specs = {k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in data_vars}
317-
array_futures = {
318-
k: tensorstore.open(spec, read=True, write=False, context=context)
319-
for k, spec in specs.items()
320-
}
366+
# TODO(shoyer): Figure out how to support opening concatenated Zarrs with
367+
# consolidated metadata. xarray.open_mfdataset() doesn't support opening
368+
# from an existing store, so we'd have to replicate that functionality for
369+
# figuring out the structure of the concatenated dataset.
370+
group = None
371+
array_futures = _open_tensorstore_arrays(
372+
path, data_vars, group, zarr_format, write=False, context=context
373+
)
321374
arrays_list.append(array_futures)
322375

323376
# Concatenate the tensorstore arrays
@@ -354,11 +407,11 @@ def open_concatenated_zarrs(
354407
context = tensorstore.Context()
355408

356409
ds = xarray.open_mfdataset(
357-
paths,
358-
concat_dim=concat_dim,
359-
combine="nested",
360-
mask_and_scale=mask_and_scale,
361-
engine="zarr"
410+
paths,
411+
concat_dim=concat_dim,
412+
combine='nested',
413+
mask_and_scale=mask_and_scale,
414+
engine='zarr',
362415
)
363416

364417
if mask_and_scale:
@@ -369,7 +422,9 @@ def open_concatenated_zarrs(
369422

370423
data_vars = list(ds.data_vars)
371424
concat_axes = [ds[v].dims.index(concat_dim) for v in data_vars]
372-
arrays = _tensorstore_open_concatenated_zarrs(paths, data_vars, concat_axes, context)
425+
arrays = _tensorstore_open_concatenated_zarrs(
426+
paths, data_vars, concat_axes, context
427+
)
373428
new_data = {k: _TensorStoreAdapter(v) for k, v in arrays.items()}
374429

375430
return ds.copy(data=new_data)

xarray_tensorstore_test.py

Lines changed: 76 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -26,63 +26,63 @@
2626
_USING_ZARR_PYTHON_3 = packaging.version.parse(zarr.__version__).major >= 3
2727

2828
test_cases = [
29-
{
30-
'testcase_name': 'base',
31-
'transform': lambda ds: ds,
32-
},
33-
{
34-
'testcase_name': 'transposed',
35-
'transform': lambda ds: ds.transpose('z', 'x', 'y'),
36-
},
37-
{
38-
'testcase_name': 'basic_int',
39-
'transform': lambda ds: ds.isel(y=1),
40-
},
41-
{
42-
'testcase_name': 'negative_int',
43-
'transform': lambda ds: ds.isel(y=-1),
44-
},
45-
{
46-
'testcase_name': 'basic_slice',
47-
'transform': lambda ds: ds.isel(z=slice(2)),
48-
},
49-
{
50-
'testcase_name': 'full_slice',
51-
'transform': lambda ds: ds.isel(z=slice(0, 4)),
52-
},
53-
{
54-
'testcase_name': 'out_of_bounds_slice',
55-
'transform': lambda ds: ds.isel(z=slice(0, 10)),
56-
},
57-
{
58-
'testcase_name': 'strided_slice',
59-
'transform': lambda ds: ds.isel(z=slice(0, None, 2)),
60-
},
61-
{
62-
'testcase_name': 'negative_stride_slice',
63-
'transform': lambda ds: ds.isel(z=slice(None, None, -1)),
64-
},
65-
{
66-
'testcase_name': 'repeated_indexing',
67-
'transform': lambda ds: ds.isel(z=slice(1, None)).isel(z=0),
68-
},
69-
{
70-
'testcase_name': 'oindex',
71-
# includes repeated, negative and out of order indices
72-
'transform': lambda ds: ds.isel(x=[0], y=[1, 1], z=[1, -1, 0]),
73-
},
74-
{
75-
'testcase_name': 'vindex',
76-
'transform': lambda ds: ds.isel(x=('w', [0, 1]), y=('w', [1, 2])),
77-
},
78-
{
79-
'testcase_name': 'mixed_indexing_types',
80-
'transform': lambda ds: ds.isel(x=0, y=slice(2), z=[-1]),
81-
},
82-
{
83-
'testcase_name': 'select_a_variable',
84-
'transform': lambda ds: ds['foo'],
85-
},
29+
{
30+
'testcase_name': 'base',
31+
'transform': lambda ds: ds,
32+
},
33+
{
34+
'testcase_name': 'transposed',
35+
'transform': lambda ds: ds.transpose('z', 'x', 'y'),
36+
},
37+
{
38+
'testcase_name': 'basic_int',
39+
'transform': lambda ds: ds.isel(y=1),
40+
},
41+
{
42+
'testcase_name': 'negative_int',
43+
'transform': lambda ds: ds.isel(y=-1),
44+
},
45+
{
46+
'testcase_name': 'basic_slice',
47+
'transform': lambda ds: ds.isel(z=slice(2)),
48+
},
49+
{
50+
'testcase_name': 'full_slice',
51+
'transform': lambda ds: ds.isel(z=slice(0, 4)),
52+
},
53+
{
54+
'testcase_name': 'out_of_bounds_slice',
55+
'transform': lambda ds: ds.isel(z=slice(0, 10)),
56+
},
57+
{
58+
'testcase_name': 'strided_slice',
59+
'transform': lambda ds: ds.isel(z=slice(0, None, 2)),
60+
},
61+
{
62+
'testcase_name': 'negative_stride_slice',
63+
'transform': lambda ds: ds.isel(z=slice(None, None, -1)),
64+
},
65+
{
66+
'testcase_name': 'repeated_indexing',
67+
'transform': lambda ds: ds.isel(z=slice(1, None)).isel(z=0),
68+
},
69+
{
70+
'testcase_name': 'oindex',
71+
# includes repeated, negative and out of order indices
72+
'transform': lambda ds: ds.isel(x=[0], y=[1, 1], z=[1, -1, 0]),
73+
},
74+
{
75+
'testcase_name': 'vindex',
76+
'transform': lambda ds: ds.isel(x=('w', [0, 1]), y=('w', [1, 2])),
77+
},
78+
{
79+
'testcase_name': 'mixed_indexing_types',
80+
'transform': lambda ds: ds.isel(x=0, y=slice(2), z=[-1]),
81+
},
82+
{
83+
'testcase_name': 'select_a_variable',
84+
'transform': lambda ds: ds['foo'],
85+
},
8686
]
8787

8888

@@ -128,16 +128,18 @@ def test_open_concatenated_zarrs(self, transform):
128128
},
129129
attrs={'global': 'global metadata'},
130130
)
131-
for x in [range(0,2), range(3, 5)]
131+
for x in [range(0, 2), range(3, 5)]
132132
]
133133

134134
zarr_dir = self.create_tempdir().full_path
135-
paths = [f"{zarr_dir}/{i}" for i in range(len(sources))]
135+
paths = [f'{zarr_dir}/{i}' for i in range(len(sources))]
136136
for source, path in zip(sources, paths, strict=True):
137137
source.chunk().to_zarr(path)
138138

139-
expected = transform(xarray.concat(sources, dim="x"))
140-
actual = transform(xarray_tensorstore.open_concatenated_zarrs(paths, concat_dim="x")).compute()
139+
expected = transform(xarray.concat(sources, dim='x'))
140+
actual = transform(
141+
xarray_tensorstore.open_concatenated_zarrs(paths, concat_dim='x')
142+
).compute()
141143
xarray.testing.assert_identical(actual, expected)
142144

143145
@parameterized.parameters(
@@ -172,26 +174,32 @@ def test_compute(self):
172174
self.assertNotIsInstance(computed_data, tensorstore.TensorStore)
173175

174176
def test_open_zarr_from_uri(self):
175-
source = xarray.Dataset({'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))})
177+
source = xarray.Dataset(
178+
{'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))}
179+
)
176180
path = self.create_tempdir().full_path
177181
source.chunk().to_zarr(path)
178182

179183
opened = xarray_tensorstore.open_zarr('file://' + path)
180184
xarray.testing.assert_identical(source, opened)
181185

182186
@parameterized.parameters(
183-
{'zarr_format': 2},
184-
{'zarr_format': 3},
187+
{'zarr_format': 2, 'consolidated': True},
188+
{'zarr_format': 3, 'consolidated': True},
189+
{'zarr_format': 2, 'consolidated': False},
190+
{'zarr_format': 3, 'consolidated': False},
185191
)
186-
def test_read_dataset(self, zarr_format):
192+
def test_read_dataset(self, zarr_format: int, consolidated: bool):
187193
if not _USING_ZARR_PYTHON_3 and zarr_format == 3:
188194
self.skipTest('zarr format 3 is not supported in zarr < 3.0.0')
189195
source = xarray.Dataset(
190196
{'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))},
191197
coords={'x': np.arange(2)},
192198
)
193199
path = self.create_tempdir().full_path
194-
source.chunk().to_zarr(path, zarr_format=zarr_format)
200+
source.chunk().to_zarr(
201+
path, zarr_format=zarr_format, consolidated=consolidated
202+
)
195203

196204
opened = xarray_tensorstore.open_zarr(path)
197205
read = xarray_tensorstore.read(opened)
@@ -204,8 +212,8 @@ def test_read_dataset(self, zarr_format):
204212
{'zarr_format': 2},
205213
{'zarr_format': 3},
206214
)
207-
def test_read_dataarray(self, zarr_format):
208-
if not _USING_ZARR_PYTHON_3 and zarr_format == 3:
215+
def test_read_dataarray(self, zarr_format: int):
216+
if not _USING_ZARR_PYTHON_3 and zarr_format == 3:
209217
self.skipTest('zarr format 3 is not supported in zarr < 3.0.0')
210218
source = xarray.DataArray(
211219
np.arange(24).reshape(2, 3, 4),

0 commit comments

Comments
 (0)