Skip to content

Commit b5ffbe5

Browse files
authored
Updates to getitem (#196)
1 parent fb228b6 commit b5ffbe5

2 files changed

Lines changed: 83 additions & 12 deletions

File tree

cf_xarray/accessor.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
List,
1313
Mapping,
1414
MutableMapping,
15-
Optional,
1615
Set,
1716
Tuple,
1817
TypeVar,
@@ -605,24 +604,41 @@ def _getitem(
605604
if skip is None:
606605
skip = []
607606

608-
def check_results(names, k):
607+
def drop_bounds(names):
608+
# sometimes bounds variables have the same standard_name as the
609+
# actual variable. It seems practical to ignore them when indexing
610+
# with a scalar key. Hopefully these will soon get decoded to IntervalIndex
611+
# and we can move on...
612+
if scalar_key:
613+
bounds = set([obj[k].attrs.get("bounds", None) for k in names])
614+
names = set(names) - bounds
615+
return names
616+
617+
def check_results(names, key):
609618
if scalar_key and len(names) > 1:
610619
raise KeyError(
611-
f"Receive multiple variables for key {k!r}: {names}. "
612-
f"Expected only one. Please pass a list [{k!r}] "
613-
f"instead to get all variables matching {k!r}."
620+
f"Receive multiple variables for key {key!r}: {names}. "
621+
f"Expected only one. Please pass a list [{key!r}] "
622+
f"instead to get all variables matching {key!r}."
614623
)
615624

625+
try:
626+
measures = accessor._get_all_cell_measures()
627+
except ValueError:
628+
measures = []
629+
warnings.warn("Ignoring bad cell_measures attribute.", UserWarning)
630+
616631
varnames: List[Hashable] = []
617632
coords: List[Hashable] = []
618633
successful = dict.fromkeys(key, False)
619634
for k in key:
620635
if "coords" not in skip and k in _AXIS_NAMES + _COORD_NAMES:
621636
names = _get_all(obj, k)
637+
names = drop_bounds(names)
622638
check_results(names, k)
623639
successful[k] = bool(names)
624640
coords.extend(names)
625-
elif "measures" not in skip and k in accessor._get_all_cell_measures():
641+
elif "measures" not in skip and k in measures:
626642
measure = _get_all(obj, k)
627643
check_results(measure, k)
628644
successful[k] = bool(measure)
@@ -631,6 +647,7 @@ def check_results(names, k):
631647
else:
632648
stdnames = set(_get_with_standard_name(obj, k))
633649
objcoords = set(obj.coords)
650+
stdnames = drop_bounds(stdnames)
634651
if "coords" in skip:
635652
stdnames -= objcoords
636653
check_results(stdnames, k)
@@ -646,7 +663,7 @@ def check_results(names, k):
646663
try:
647664
for name in allnames:
648665
extravars = accessor.get_associated_variable_names(
649-
name, skip_bounds=scalar_key
666+
name, skip_bounds=scalar_key, error=False
650667
)
651668
coords.extend(itertools.chain(*extravars.values()))
652669

@@ -1238,7 +1255,7 @@ def standard_names(self) -> Dict[str, List[str]]:
12381255
return {k: sorted(v) for k, v in vardict.items()}
12391256

12401257
def get_associated_variable_names(
1241-
self, name: Hashable, skip_bounds: Optional[bool] = None
1258+
self, name: Hashable, skip_bounds: bool = False, error: bool = True
12421259
) -> Dict[str, List[str]]:
12431260
"""
12441261
Returns a dict mapping
@@ -1252,6 +1269,9 @@ def get_associated_variable_names(
12521269
----------
12531270
name : Hashable
12541271
skip_bounds : bool, optional
1272+
error : bool, optional
1273+
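If True, raise a ValueError when the "cell_measures" attribute cannot be parsed. If False, emit a UserWarning and treat the cell measures as empty.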
1274+
12551275
Returns
12561276
-------
12571277
Dict with keys "ancillary_variables", "cell_measures", "coordinates", "bounds"
@@ -1264,9 +1284,20 @@ def get_associated_variable_names(
12641284
coords["coordinates"] = attrs_or_encoding["coordinates"].split(" ")
12651285

12661286
if "cell_measures" in attrs_or_encoding:
1267-
coords["cell_measures"] = list(
1268-
parse_cell_methods_attr(attrs_or_encoding["cell_measures"]).values()
1269-
)
1287+
try:
1288+
coords["cell_measures"] = list(
1289+
parse_cell_methods_attr(attrs_or_encoding["cell_measures"]).values()
1290+
)
1291+
except ValueError as e:
1292+
if error:
1293+
msg = e.args[0] + " Ignore this error by passing 'error=False'"
1294+
raise ValueError(msg)
1295+
else:
1296+
warnings.warn(
1297+
f"Ignoring bad cell_measures attribute: {attrs_or_encoding['cell_measures']}",
1298+
UserWarning,
1299+
)
1300+
coords["cell_measures"] = []
12701301

12711302
if (
12721303
isinstance(self._obj, Dataset)

cf_xarray/tests/test_accessor.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,51 @@ def test_getitem_errors(obj):
502502
obj2.cf["X"]
503503

504504

505-
def test_getitem_regression():
505+
def test_getitem_ignores_bad_measure_attribute():
506+
air2 = airds.copy(deep=True)
507+
air2.air.attrs["cell_measures"] = "asd"
508+
with pytest.warns(UserWarning):
509+
assert_identical(air2.air.drop_vars("cell_area"), air2.cf["air"])
510+
511+
with pytest.raises(ValueError):
512+
air2.cf.cell_measures
513+
with pytest.raises(ValueError):
514+
air2.air.cf.cell_measures
515+
with pytest.raises(ValueError):
516+
air2.cf.get_associated_variable_names("air", error=True)
517+
with pytest.warns(UserWarning):
518+
air2.cf.get_associated_variable_names("air", error=False)
519+
520+
521+
def test_getitem_clash_standard_name():
506522
ds = xr.Dataset()
507523
ds.coords["area"] = xr.DataArray(np.ones(10), attrs={"standard_name": "cell_area"})
508524
assert_identical(ds.cf["cell_area"], ds["area"].reset_coords(drop=True))
509525

526+
ds = xr.Dataset()
527+
ds["time"] = (
528+
"time",
529+
np.arange(10),
530+
{"standard_name": "time", "bounds": "time_bounds"},
531+
)
532+
ds["time_bounds"] = (
533+
("time", "bounds"),
534+
np.ones((10, 2)),
535+
{"standard_name": "time"},
536+
)
537+
538+
ds["lat"] = (
539+
"lat",
540+
np.arange(10),
541+
{"units": "degrees_north", "bounds": "lat_bounds"},
542+
)
543+
ds["lat_bounds"] = (
544+
("lat", "bounds"),
545+
np.ones((10, 2)),
546+
{"units": "degrees_north"},
547+
)
548+
assert_identical(ds["lat"], ds.cf["latitude"])
549+
510550

511551
def test_getitem_uses_coordinates():
512552
# POP-like dataset

0 commit comments

Comments
 (0)