Skip to content

Commit debcafa

Browse files
Merge pull request #10 from rhoadesScholar:main
PiperOrigin-RevId: 693804633
2 parents 1e4cd91 + fa7ea05 commit debcafa

3 files changed

Lines changed: 76 additions & 7 deletions

File tree

conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
try:
1818
app.run(lambda argv: None)
1919
except SystemExit:
20-
pass
20+
pass

xarray_tensorstore.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,18 @@ def __getitem__(self, key: indexing.ExplicitIndexer) -> _TensorStoreAdapter:
9999
translated = indexed[tensorstore.d[:].translate_to[0]]
100100
return type(self)(translated)
101101

102+
def __setitem__(self, key: indexing.ExplicitIndexer, value) -> None:
103+
index_tuple = tuple(map(_numpy_to_tensorstore_index, key.tuple, self.shape))
104+
if isinstance(key, indexing.OuterIndexer):
105+
self.array.oindex[index_tuple] = value
106+
elif isinstance(key, indexing.VectorizedIndexer):
107+
self.array.vindex[index_tuple] = value
108+
else:
109+
assert isinstance(key, indexing.BasicIndexer)
110+
self.array[index_tuple] = value
111+
# Invalidate the future so that the next read will pick up the new value
112+
object.__setattr__(self, 'future', None)
113+
102114
# xarray>2024.02.0 uses oindex and vindex properties, which are expected to
103115
# return objects whose __getitem__ method supports the appropriate form of
104116
# indexing.
@@ -200,6 +212,7 @@ def open_zarr(
200212
*,
201213
context: tensorstore.Context | None = None,
202214
mask_and_scale: bool = True,
215+
write: bool = False,
203216
) -> xarray.Dataset:
204217
"""Open an xarray.Dataset from Zarr using TensorStore.
205218
@@ -228,6 +241,7 @@ def open_zarr(
228241
mask_and_scale: if True (default), attempt to apply masking and scaling like
229242
xarray.open_zarr(). This is only supported for coordinate variables and
230243
otherwise will raise an error.
244+
write: Allow write access. Defaults to False.
231245
232246
Returns:
233247
Dataset with all data variables opened via TensorStore.
@@ -259,7 +273,7 @@ def open_zarr(
259273

260274
specs = {k: _zarr_spec_from_path(os.path.join(path, k)) for k in ds}
261275
array_futures = {
262-
k: tensorstore.open(spec, read=True, write=False, context=context)
276+
k: tensorstore.open(spec, read=True, write=write, context=context)
263277
for k, spec in specs.items()
264278
}
265279
arrays = {k: v.result() for k, v in array_futures.items()}

xarray_tensorstore_test.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
# Copyright 2023 Google LLC
22
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# Licensed under the Apache License, Version 2.0 (the 'License');
44
# you may not use this file except in compliance with the License.
55
# You may obtain a copy of the License at
66
#
77
# https://www.apache.org/licenses/LICENSE-2.0
88
#
99
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# distributed under the License is distributed on an 'AS IS' BASIS,
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from absl.testing import absltest
1515
from absl.testing import parameterized
1616
import numpy as np
1717
import pandas as pd
18+
import pytest
1819
import tensorstore
1920
import xarray
21+
from xarray.core import indexing
2022
import xarray_tensorstore
2123

2224

@@ -136,9 +138,7 @@ def test_compute(self):
136138
self.assertNotIsInstance(computed_data, tensorstore.TensorStore)
137139

138140
def test_open_zarr_from_uri(self):
139-
source = xarray.Dataset(
140-
{'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))}
141-
)
141+
source = xarray.Dataset({'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))})
142142
path = self.create_tempdir().full_path
143143
source.chunk().to_zarr(path)
144144

@@ -221,6 +221,61 @@ def test_mask_and_scale(self):
221221
xarray.testing.assert_identical(actual, source)
222222
self.assertEqual(actual.coords['x'].encoding['add_offset'], -1)
223223

224+
@parameterized.named_parameters(
225+
{
226+
'testcase_name': 'basic_indexing',
227+
'key': (slice(1, None), slice(None), slice(None)),
228+
'value': np.full((1, 2, 3), -1),
229+
},
230+
{
231+
'testcase_name': 'outer_indexing',
232+
'key': (np.array([0]), np.array([1]), slice(None)),
233+
'value': np.full((1, 1, 3), -2),
234+
},
235+
{
236+
'testcase_name': 'vectorized_indexing',
237+
'key': (np.array([0]), np.array([0, 1]), slice(None)),
238+
'value': np.full((2, 3), -3),
239+
},
240+
)
241+
def test_setitem(self, key, value):
242+
source_data = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
243+
source = xarray.DataArray(
244+
source_data,
245+
dims=('x', 'y', 'z'),
246+
name='baz',
247+
)
248+
path = self.create_tempdir().full_path
249+
source.to_dataset().chunk().to_zarr(path)
250+
251+
opened = xarray_tensorstore.open_zarr(path, write=True)['baz']
252+
253+
opened[key] = value
254+
read = xarray_tensorstore.read(opened)
255+
256+
expected_data = source_data.copy()
257+
expected_data[key] = value
258+
expected = xarray.DataArray(
259+
expected_data,
260+
dims=('x', 'y', 'z'),
261+
name='baz',
262+
)
263+
264+
xarray.testing.assert_equal(read, expected)
265+
266+
def test_setitem_readonly(self):
267+
source = xarray.DataArray(
268+
np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]),
269+
dims=('x', 'y', 'z'),
270+
name='baz',
271+
)
272+
path = self.create_tempdir().full_path
273+
source.to_dataset().chunk().to_zarr(path)
274+
275+
opened = xarray_tensorstore.open_zarr(path)['baz']
276+
with pytest.raises(ValueError):
277+
opened[1:, ...] = np.full((1, 2, 3), -1)
278+
224279

225280
if __name__ == '__main__':
226281
absltest.main()

0 commit comments

Comments
 (0)