Skip to content

Commit ee3c8c5

Browse files
sylmak authored and copybara-github committed
Check zarr format before opening.
PiperOrigin-RevId: 791314170
1 parent ea6dc16 commit ee3c8c5

2 files changed

Lines changed: 34 additions & 8 deletions

File tree

xarray_tensorstore.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
from typing import Optional, TypeVar
2222

2323
import numpy as np
24+
import packaging
2425
import tensorstore
2526
import xarray
2627
from xarray.core import indexing
28+
import zarr
2729

2830

2931
__version__ = '0.1.5' # keep in sync with setup.py
@@ -176,12 +178,12 @@ def read(xarraydata: XarrayData, /) -> XarrayData:
176178
_DEFAULT_STORAGE_DRIVER = 'file'
177179

178180

179-
def _zarr_spec_from_path(path: str) -> ...:
181+
def _zarr_spec_from_path(path: str, zarr_format: int) -> ...:
180182
if re.match(r'\w+\://', path): # path is a URI
181183
kv_store = path
182184
else:
183185
kv_store = {'driver': _DEFAULT_STORAGE_DRIVER, 'path': path}
184-
return {'driver': 'zarr', 'kvstore': kv_store}
186+
return {'driver': f'zarr{zarr_format}', 'kvstore': kv_store}
185187

186188

187189
def _raise_if_mask_and_scale_used_for_data_vars(ds: xarray.Dataset):
@@ -207,6 +209,14 @@ def _raise_if_mask_and_scale_used_for_data_vars(ds: xarray.Dataset):
207209
)
208210

209211

212+
def _get_zarr_format(path: str) -> int:
  """Returns the Zarr format of the group stored at the given path.

  With zarr-python >= 3 the group at `path` is opened read-only and the
  format recorded in its metadata is returned. With older zarr-python
  releases the store is not inspected and format 2 is returned.

  Args:
    path: location of the Zarr group, as accepted by `zarr.open_group`.

  Returns:
    The Zarr format number.
  """
  # `import packaging` alone does not load the `version` submodule, so
  # `packaging.version` would only resolve if some other library had already
  # imported it. Import the submodule explicitly.
  import packaging.version  # pylint: disable=g-import-not-at-top

  installed = packaging.version.parse(zarr.__version__)
  if installed.major < 3:
    # Nothing to detect here: older zarr-python is treated as format 2 only.
    return 2
  return zarr.open_group(path, mode='r').metadata.zarr_format
218+
219+
210220
def open_zarr(
211221
path: str,
212222
*,
@@ -271,7 +281,10 @@ def open_zarr(
271281
# incorrect data values.
272282
_raise_if_mask_and_scale_used_for_data_vars(ds)
273283

274-
specs = {k: _zarr_spec_from_path(os.path.join(path, k)) for k in ds}
284+
zarr_format = _get_zarr_format(path)
285+
specs = {
286+
k: _zarr_spec_from_path(os.path.join(path, k), zarr_format) for k in ds
287+
}
275288
array_futures = {
276289
k: tensorstore.open(spec, read=True, write=write, context=context)
277290
for k, spec in specs.items()

xarray_tensorstore_test.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
from absl.testing import absltest
1515
from absl.testing import parameterized
1616
import numpy as np
17+
import packaging
1718
import pandas as pd
1819
import pytest
1920
import tensorstore
2021
import xarray
21-
from xarray.core import indexing
2222
import xarray_tensorstore
23+
import zarr
2324

2425

2526
class XarrayTensorstoreTest(parameterized.TestCase):
@@ -145,13 +146,19 @@ def test_open_zarr_from_uri(self):
145146
opened = xarray_tensorstore.open_zarr('file://' + path)
146147
xarray.testing.assert_identical(source, opened)
147148

148-
def test_read_dataset(self):
149+
@parameterized.parameters(
150+
{'zarr_format': 2},
151+
{'zarr_format': 3},
152+
)
153+
def test_read_dataset(self, zarr_format):
154+
if packaging.version.parse(zarr.__version__).major < 3 and zarr_format == 3:
155+
self.skipTest('zarr format 3 is not supported in zarr < 3.0.0')
149156
source = xarray.Dataset(
150157
{'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))},
151158
coords={'x': np.arange(2)},
152159
)
153160
path = self.create_tempdir().full_path
154-
source.chunk().to_zarr(path)
161+
source.chunk().to_zarr(path, zarr_format=zarr_format)
155162

156163
opened = xarray_tensorstore.open_zarr(path)
157164
read = xarray_tensorstore.read(opened)
@@ -160,15 +167,21 @@ def test_read_dataset(self):
160167
self.assertIsNotNone(read.variables['baz']._data.future)
161168
xarray.testing.assert_identical(read, source)
162169

163-
def test_read_dataarray(self):
170+
@parameterized.parameters(
171+
{'zarr_format': 2},
172+
{'zarr_format': 3},
173+
)
174+
def test_read_dataarray(self, zarr_format):
175+
if packaging.version.parse(zarr.__version__).major < 3 and zarr_format == 3:
176+
self.skipTest('zarr format 3 is not supported in zarr < 3.0.0')
164177
source = xarray.DataArray(
165178
np.arange(24).reshape(2, 3, 4),
166179
dims=('x', 'y', 'z'),
167180
name='baz',
168181
coords={'x': np.arange(2)},
169182
)
170183
path = self.create_tempdir().full_path
171-
source.to_dataset().chunk().to_zarr(path)
184+
source.to_dataset().chunk().to_zarr(path, zarr_format=zarr_format)
172185

173186
opened = xarray_tensorstore.open_zarr(path)['baz']
174187
read = xarray_tensorstore.read(opened)

0 commit comments

Comments
 (0)