|
3 | 3 | import tempfile |
4 | 4 | from collections.abc import Callable |
5 | 5 | from pathlib import Path |
6 | | -from typing import Any |
| 6 | +from typing import Any, Literal |
7 | 7 |
|
8 | 8 | import dask.dataframe as dd |
9 | 9 | import numpy as np |
| 10 | +import pandas as pd |
| 11 | +import pyarrow.parquet as pq |
10 | 12 | import pytest |
11 | 13 | import zarr |
12 | 14 | from anndata import AnnData |
13 | 15 | from numpy.random import default_rng |
| 16 | +from shapely import MultiPolygon, Polygon |
14 | 17 | from upath import UPath |
15 | 18 | from zarr.errors import GroupNotFoundError |
16 | 19 |
|
| 20 | +import spatialdata.config |
17 | 21 | from spatialdata import SpatialData, deepcopy, read_zarr |
18 | 22 | from spatialdata._core.validation import ValidationError |
19 | 23 | from spatialdata._io._utils import _are_directories_identical, get_dask_backing_files |
@@ -74,20 +78,90 @@ def test_labels( |
74 | 78 | sdata = SpatialData.read(tmpdir) |
75 | 79 | assert_spatial_data_objects_are_identical(labels, sdata) |
76 | 80 |
|
| 81 | + @pytest.mark.parametrize("geometry_encoding", ["WKB", "geoarrow"]) |
77 | 82 | def test_shapes( |
78 | 83 | self, |
79 | 84 | tmp_path: str, |
80 | 85 | shapes: SpatialData, |
81 | 86 | sdata_container_format: SpatialDataContainerFormatType, |
| 87 | + geometry_encoding: Literal["WKB", "geoarrow"], |
82 | 88 | ) -> None: |
83 | 89 | tmpdir = Path(tmp_path) / "tmp.zarr" |
84 | 90 |
|
85 | 91 | # check the index is correctly written and then read |
86 | 92 | shapes["circles"].index = np.arange(1, len(shapes["circles"]) + 1) |
87 | 93 |
|
88 | | - shapes.write(tmpdir, sdata_formats=sdata_container_format) |
| 94 | + # add a mixed Polygon + MultiPolygon element |
| 95 | + shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]]) |
| 96 | + |
| 97 | + shapes.write(tmpdir, sdata_formats=sdata_container_format, shapes_geometry_encoding=geometry_encoding) |
89 | 98 | sdata = SpatialData.read(tmpdir) |
90 | | - assert_spatial_data_objects_are_identical(shapes, sdata) |
| 99 | + |
| 100 | + if geometry_encoding == "WKB": |
| 101 | + assert_spatial_data_objects_are_identical(shapes, sdata) |
| 102 | + else: |
| 103 | + # convert each Polygon to a MultiPolygon |
| 104 | + mixed_multipolygon = shapes["mixed"].assign( |
| 105 | + geometry=lambda df: df.geometry.apply(lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g) |
| 106 | + ) |
| 107 | + assert sdata["mixed"].equals(mixed_multipolygon) |
| 108 | + assert not sdata["mixed"].equals(shapes["mixed"]) |
| 109 | + |
| 110 | + del shapes["mixed"] |
| 111 | + del sdata["mixed"] |
| 112 | + assert_spatial_data_objects_are_identical(shapes, sdata) |
| 113 | + |
| 114 | + @pytest.mark.parametrize("geometry_encoding", ["WKB", "geoarrow"]) |
| 115 | + def test_shapes_geometry_encoding_write_element( |
| 116 | + self, |
| 117 | + tmp_path: str, |
| 118 | + shapes: SpatialData, |
| 119 | + sdata_container_format: SpatialDataContainerFormatType, |
| 120 | + geometry_encoding: Literal["WKB", "geoarrow"], |
| 121 | + ) -> None: |
| 122 | + """Test shapes geometry encoding with write_element() and global settings.""" |
| 123 | + tmpdir = Path(tmp_path) / "tmp.zarr" |
| 124 | + |
| 125 | + # First write an empty SpatialData to create the zarr store |
| 126 | + empty_sdata = SpatialData() |
| 127 | + empty_sdata.write(tmpdir, sdata_formats=sdata_container_format) |
| 128 | + |
| 129 | + shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]]) |
| 130 | + |
| 131 | + # Add shapes to the empty sdata |
| 132 | + for shape_name in shapes.shapes: |
| 133 | + empty_sdata[shape_name] = shapes[shape_name] |
| 134 | + |
| 135 | + # Store original setting and set global encoding |
| 136 | + original_encoding = spatialdata.config.settings.shapes_geometry_encoding |
| 137 | + try: |
| 138 | + spatialdata.config.settings.shapes_geometry_encoding = geometry_encoding |
| 139 | + |
| 140 | + # Write each shape element - should use global setting |
| 141 | + for shape_name in shapes.shapes: |
| 142 | + empty_sdata.write_element(shape_name, sdata_formats=sdata_container_format) |
| 143 | + |
| 144 | + # Verify the encoding metadata in the parquet file |
| 145 | + parquet_file = tmpdir / "shapes" / shape_name / "shapes.parquet" |
| 146 | + with pq.ParquetFile(parquet_file) as pf: |
| 147 | + md = pf.metadata |
| 148 | + d = json.loads(md.metadata[b"geo"].decode("utf-8")) |
| 149 | + found_encoding = d["columns"]["geometry"]["encoding"] |
| 150 | + if geometry_encoding == "WKB": |
| 151 | + expected_encoding = "WKB" |
| 152 | + elif shape_name == "circles": |
| 153 | + expected_encoding = "point" |
| 154 | + elif shape_name == "poly": |
| 155 | + expected_encoding = "polygon" |
| 156 | + elif shape_name in ["multipoly", "mixed"]: |
| 157 | + expected_encoding = "multipolygon" |
| 158 | + else: |
| 159 | + raise ValueError( |
| 160 | + f"Uncovered case for shape_name: {shape_name}, found encoding: {found_encoding}." |
| 161 | + ) |
| 162 | + assert found_encoding == expected_encoding |
| 163 | + finally: |
| 164 | + spatialdata.config.settings.shapes_geometry_encoding = original_encoding |
91 | 165 |
|
92 | 166 | def test_points( |
93 | 167 | self, |
|
0 commit comments