|
1 | 1 | # Copyright 2023 Google LLC |
2 | 2 | # |
3 | | -# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# Licensed under the Apache License, Version 2.0 (the 'License'); |
4 | 4 | # you may not use this file except in compliance with the License. |
5 | 5 | # You may obtain a copy of the License at |
6 | 6 | # |
7 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 |
8 | 8 | # |
9 | 9 | # Unless required by applicable law or agreed to in writing, software |
10 | | -# distributed under the License is distributed on an "AS IS" BASIS, |
| 10 | +# distributed under the License is distributed on an 'AS IS' BASIS, |
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | from absl.testing import absltest |
15 | 15 | from absl.testing import parameterized |
16 | 16 | import numpy as np |
17 | 17 | import pandas as pd |
| 18 | +import pytest |
18 | 19 | import tensorstore |
19 | 20 | import xarray |
| 21 | +from xarray.core import indexing |
20 | 22 | import xarray_tensorstore |
21 | 23 |
|
22 | 24 |
|
@@ -136,9 +138,7 @@ def test_compute(self): |
136 | 138 | self.assertNotIsInstance(computed_data, tensorstore.TensorStore) |
137 | 139 |
|
138 | 140 | def test_open_zarr_from_uri(self): |
139 | | - source = xarray.Dataset( |
140 | | - {'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))} |
141 | | - ) |
| 141 | + source = xarray.Dataset({'baz': (('x', 'y', 'z'), np.arange(24).reshape(2, 3, 4))}) |
142 | 142 | path = self.create_tempdir().full_path |
143 | 143 | source.chunk().to_zarr(path) |
144 | 144 |
|
@@ -221,6 +221,61 @@ def test_mask_and_scale(self): |
221 | 221 | xarray.testing.assert_identical(actual, source) |
222 | 222 | self.assertEqual(actual.coords['x'].encoding['add_offset'], -1) |
223 | 223 |
|
| 224 | + @parameterized.named_parameters( |
| 225 | + { |
| 226 | + 'testcase_name': 'basic_indexing', |
| 227 | + 'key': (slice(1, None), slice(None), slice(None)), |
| 228 | + 'value': np.full((1, 2, 3), -1), |
| 229 | + }, |
| 230 | + { |
| 231 | + 'testcase_name': 'outer_indexing', |
| 232 | + 'key': (np.array([0]), np.array([1]), slice(None)), |
| 233 | + 'value': np.full((1, 1, 3), -2), |
| 234 | + }, |
| 235 | + { |
| 236 | + 'testcase_name': 'vectorized_indexing', |
| 237 | + 'key': (np.array([0]), np.array([0, 1]), slice(None)), |
| 238 | + 'value': np.full((2, 3), -3), |
| 239 | + }, |
| 240 | + ) |
| 241 | + def test_setitem(self, key, value): |
| 242 | + source_data = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) |
| 243 | + source = xarray.DataArray( |
| 244 | + source_data, |
| 245 | + dims=('x', 'y', 'z'), |
| 246 | + name='baz', |
| 247 | + ) |
| 248 | + path = self.create_tempdir().full_path |
| 249 | + source.to_dataset().chunk().to_zarr(path) |
| 250 | + |
| 251 | + opened = xarray_tensorstore.open_zarr(path, write=True)['baz'] |
| 252 | + |
| 253 | + opened[key] = value |
| 254 | + read = xarray_tensorstore.read(opened) |
| 255 | + |
| 256 | + expected_data = source_data.copy() |
| 257 | + expected_data[key] = value |
| 258 | + expected = xarray.DataArray( |
| 259 | + expected_data, |
| 260 | + dims=('x', 'y', 'z'), |
| 261 | + name='baz', |
| 262 | + ) |
| 263 | + |
| 264 | + xarray.testing.assert_equal(read, expected) |
| 265 | + |
| 266 | + def test_setitem_readonly(self): |
| 267 | + source = xarray.DataArray( |
| 268 | + np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), |
| 269 | + dims=('x', 'y', 'z'), |
| 270 | + name='baz', |
| 271 | + ) |
| 272 | + path = self.create_tempdir().full_path |
| 273 | + source.to_dataset().chunk().to_zarr(path) |
| 274 | + |
| 275 | + opened = xarray_tensorstore.open_zarr(path)['baz'] |
| 276 | + with pytest.raises(ValueError): |
| 277 | + opened[1:, ...] = np.full((1, 2, 3), -1) |
| 278 | + |
224 | 279 |
|
225 | 280 | if __name__ == '__main__': |
226 | 281 | absltest.main() |
0 commit comments