Skip to content

Commit 2794fb0

Browse files
keller-mark, pre-commit-ci[bot], and LucaMarconato
authored
Geometry encoding parameter for shapes (#951)
* Geometry encoding parameter for shapes. * [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci). * Update spatialdata.py. * Update io_shapes.py. * Update io_shapes.py. * Add setting for geometry_encoding; add tests. --- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Luca Marconato <m.lucalmer@gmail.com>
1 parent 0731edd commit 2794fb0

7 files changed

Lines changed: 148 additions & 16 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,4 @@ node_modules/
5252

5353
.mypy_cache
5454
.ruff_cache
55+
uv.lock

src/spatialdata/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"deepcopy",
4141
"sanitize_table",
4242
"sanitize_name",
43+
"settings",
4344
]
4445

4546
from spatialdata import dataloader, datasets, models, transformations
@@ -70,3 +71,4 @@
7071
from spatialdata._io.format import SpatialDataFormatType
7172
from spatialdata._io.io_zarr import read_zarr
7273
from spatialdata._utils import get_pyramid_levels, unpad_raster
74+
from spatialdata.config import settings

src/spatialdata/_core/spatialdata.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,7 @@ def write(
11101110
consolidate_metadata: bool = True,
11111111
update_sdata_path: bool = True,
11121112
sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
1113+
shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
11131114
) -> None:
11141115
"""
11151116
Write the `SpatialData` object to a Zarr store.
@@ -1154,6 +1155,9 @@ def write(
11541155
unspecified, the element formats will be set to the latest element format compatible with the specified
11551156
SpatialData container format. All the formats and relationships between them are defined in
11561157
`spatialdata._io.format.py`.
1158+
shapes_geometry_encoding
1159+
Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet`
1160+
for details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`.
11571161
"""
11581162
from spatialdata._io._utils import _resolve_zarr_store
11591163
from spatialdata._io.format import _parse_formats
@@ -1179,6 +1183,7 @@ def write(
11791183
element_name=element_name,
11801184
overwrite=False,
11811185
parsed_formats=parsed,
1186+
shapes_geometry_encoding=shapes_geometry_encoding,
11821187
)
11831188

11841189
if self.path != file_path and update_sdata_path:
@@ -1195,6 +1200,7 @@ def _write_element(
11951200
element_name: str,
11961201
overwrite: bool,
11971202
parsed_formats: dict[str, SpatialDataFormatType] | None = None,
1203+
shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
11981204
) -> None:
11991205
from spatialdata._io.io_zarr import _get_groups_for_element
12001206

@@ -1247,6 +1253,7 @@ def _write_element(
12471253
shapes=element,
12481254
group=element_group,
12491255
element_format=parsed_formats["shapes"],
1256+
geometry_encoding=shapes_geometry_encoding,
12501257
)
12511258
elif element_type == "tables":
12521259
write_table(
@@ -1263,6 +1270,7 @@ def write_element(
12631270
element_name: str | list[str],
12641271
overwrite: bool = False,
12651272
sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
1273+
shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
12661274
) -> None:
12671275
"""
12681276
Write a single element, or a list of elements, to the Zarr store used for backing.
@@ -1278,6 +1286,9 @@ def write_element(
12781286
sdata_formats
12791287
It is recommended to leave this parameter equal to `None`. See more details in the documentation of
12801288
`SpatialData.write()`.
1289+
shapes_geometry_encoding
1290+
Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet`
1291+
for details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`.
12811292
12821293
Notes
12831294
-----
@@ -1291,7 +1302,12 @@ def write_element(
12911302
if isinstance(element_name, list):
12921303
for name in element_name:
12931304
assert isinstance(name, str)
1294-
self.write_element(name, overwrite=overwrite, sdata_formats=sdata_formats)
1305+
self.write_element(
1306+
name,
1307+
overwrite=overwrite,
1308+
sdata_formats=sdata_formats,
1309+
shapes_geometry_encoding=shapes_geometry_encoding,
1310+
)
12951311
return
12961312

12971313
check_valid_name(element_name)
@@ -1325,6 +1341,7 @@ def write_element(
13251341
element_name=element_name,
13261342
overwrite=overwrite,
13271343
parsed_formats=parsed_formats,
1344+
shapes_geometry_encoding=shapes_geometry_encoding,
13281345
)
13291346
# After every write, metadata should be consolidated, otherwise this can lead to IO problems like when deleting.
13301347
if self.has_consolidated_metadata():

src/spatialdata/_io/io_shapes.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Any
2+
from typing import Any, Literal
33

44
import numpy as np
55
import zarr
@@ -70,6 +70,7 @@ def write_shapes(
7070
group: zarr.Group,
7171
group_type: str = "ngff:shapes",
7272
element_format: Format = CurrentShapesFormat(),
73+
geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
7374
) -> None:
7475
"""Write shapes to spatialdata zarr store.
7576
@@ -86,15 +87,23 @@ def write_shapes(
8687
The type of the element.
8788
element_format
8889
The format of the shapes element used to store it.
90+
geometry_encoding
91+
Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet` for
92+
details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`.
8993
"""
94+
from spatialdata.config import settings
95+
96+
if geometry_encoding is None:
97+
geometry_encoding = settings.shapes_geometry_encoding
98+
9099
axes = get_axes_names(shapes)
91100
transformations = _get_transformations(shapes)
92101
if transformations is None:
93102
raise ValueError(f"{group.basename} does not have any transformations and can therefore not be written.")
94103
if isinstance(element_format, ShapesFormatV01):
95104
attrs = _write_shapes_v01(shapes, group, element_format)
96105
elif isinstance(element_format, ShapesFormatV02 | ShapesFormatV03):
97-
attrs = _write_shapes_v02_v03(shapes, group, element_format)
106+
attrs = _write_shapes_v02_v03(shapes, group, element_format, geometry_encoding=geometry_encoding)
98107
else:
99108
raise ValueError(f"Unsupported format version {element_format.version}. Please update the spatialdata library.")
100109

@@ -139,7 +148,9 @@ def _write_shapes_v01(shapes: GeoDataFrame, group: zarr.Group, element_format: F
139148
return attrs
140149

141150

142-
def _write_shapes_v02_v03(shapes: GeoDataFrame, group: zarr.Group, element_format: Format) -> Any:
151+
def _write_shapes_v02_v03(
152+
shapes: GeoDataFrame, group: zarr.Group, element_format: Format, geometry_encoding: Literal["WKB", "geoarrow"]
153+
) -> Any:
143154
"""Write shapes to spatialdata zarr store using format ShapesFormatV02 or ShapesFormatV03.
144155
145156
Parameters
@@ -150,6 +161,9 @@ def _write_shapes_v02_v03(shapes: GeoDataFrame, group: zarr.Group, element_forma
150161
The zarr group in the 'shapes' zarr group to write the shapes element to.
151162
element_format
152163
The format of the shapes element used to store it.
164+
geometry_encoding
165+
Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet` for
166+
details.
153167
"""
154168
from spatialdata.models._utils import TRANSFORM_KEY
155169

@@ -159,7 +173,7 @@ def _write_shapes_v02_v03(shapes: GeoDataFrame, group: zarr.Group, element_forma
159173
# Temporarily remove transformations from attrs to avoid serialization issues
160174
transforms = shapes.attrs[TRANSFORM_KEY]
161175
del shapes.attrs[TRANSFORM_KEY]
162-
shapes.to_parquet(path)
176+
shapes.to_parquet(path, geometry_encoding=geometry_encoding)
163177
shapes.attrs[TRANSFORM_KEY] = transforms
164178

165179
attrs = element_format.attrs_to_dict(shapes.attrs)

src/spatialdata/config.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,28 @@
1-
# chunk sizes bigger than this value (bytes) can trigger a compression error
2-
# https://github.com/scverse/spatialdata/issues/812#issuecomment-2559380276
3-
# so if we detect this during parsing/validation we raise a warning
4-
LARGE_CHUNK_THRESHOLD_BYTES = 2147483647
1+
from dataclasses import dataclass
2+
from typing import Literal
3+
4+
5+
@dataclass
6+
class Settings:
7+
"""Global settings for spatialdata.
8+
9+
Attributes
10+
----------
11+
shapes_geometry_encoding
12+
Default geometry encoding for GeoParquet files when writing shapes.
13+
Can be "WKB" (Well-Known Binary) or "geoarrow".
14+
See :meth:`geopandas.GeoDataFrame.to_parquet` for details.
15+
large_chunk_threshold_bytes
16+
Chunk sizes bigger than this value (bytes) can trigger a compression error.
17+
See https://github.com/scverse/spatialdata/issues/812#issuecomment-2559380276
18+
If detected during parsing/validation, a warning is raised.
19+
"""
20+
21+
shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB"
22+
large_chunk_threshold_bytes: int = 2147483647
23+
24+
25+
settings = Settings()
26+
27+
# Backwards compatibility alias
28+
LARGE_CHUNK_THRESHOLD_BYTES = settings.large_chunk_threshold_bytes

src/spatialdata/models/models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from spatialdata._logging import logger
3636
from spatialdata._types import ArrayLike
3737
from spatialdata._utils import _check_match_length_channels_c_dim
38-
from spatialdata.config import LARGE_CHUNK_THRESHOLD_BYTES
38+
from spatialdata.config import settings
3939
from spatialdata.models import C, X, Y, Z, get_axes_names
4040
from spatialdata.models._utils import (
4141
DEFAULT_COORDINATE_SYSTEM,
@@ -315,9 +315,9 @@ def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None:
315315
return
316316
n_elems = np.array(list(max_per_dimension.values())).prod().item()
317317
usage = n_elems * data.dtype.itemsize
318-
if usage > LARGE_CHUNK_THRESHOLD_BYTES:
318+
if usage > settings.large_chunk_threshold_bytes:
319319
warnings.warn(
320-
f"Detected chunks larger than: {usage} > {LARGE_CHUNK_THRESHOLD_BYTES} bytes. "
320+
f"Detected chunks larger than: {usage} > {settings.large_chunk_threshold_bytes} bytes. "
321321
"This can lead to low "
322322
"performance and memory issues downstream, and sometimes cause compression errors when writing "
323323
"(https://github.com/scverse/spatialdata/issues/812#issuecomment-2575983527). Please consider using"
@@ -327,7 +327,7 @@ def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None:
327327
"2) Multiscale representations can be achieved by using the `scale_factors` argument in the "
328328
"`parse()` function.\n"
329329
"You can suppress this warning by increasing the value of "
330-
"`spatialdata.config.LARGE_CHUNK_THRESHOLD_BYTES`.",
330+
"`spatialdata.settings.large_chunk_threshold_bytes`.",
331331
UserWarning,
332332
stacklevel=2,
333333
)

tests/io/test_readwrite.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,21 @@
33
import tempfile
44
from collections.abc import Callable
55
from pathlib import Path
6-
from typing import Any
6+
from typing import Any, Literal
77

88
import dask.dataframe as dd
99
import numpy as np
10+
import pandas as pd
11+
import pyarrow.parquet as pq
1012
import pytest
1113
import zarr
1214
from anndata import AnnData
1315
from numpy.random import default_rng
16+
from shapely import MultiPolygon, Polygon
1417
from upath import UPath
1518
from zarr.errors import GroupNotFoundError
1619

20+
import spatialdata.config
1721
from spatialdata import SpatialData, deepcopy, read_zarr
1822
from spatialdata._core.validation import ValidationError
1923
from spatialdata._io._utils import _are_directories_identical, get_dask_backing_files
@@ -74,20 +78,90 @@ def test_labels(
7478
sdata = SpatialData.read(tmpdir)
7579
assert_spatial_data_objects_are_identical(labels, sdata)
7680

81+
@pytest.mark.parametrize("geometry_encoding", ["WKB", "geoarrow"])
7782
def test_shapes(
7883
self,
7984
tmp_path: str,
8085
shapes: SpatialData,
8186
sdata_container_format: SpatialDataContainerFormatType,
87+
geometry_encoding: Literal["WKB", "geoarrow"],
8288
) -> None:
8389
tmpdir = Path(tmp_path) / "tmp.zarr"
8490

8591
# check the index is correctly written and then read
8692
shapes["circles"].index = np.arange(1, len(shapes["circles"]) + 1)
8793

88-
shapes.write(tmpdir, sdata_formats=sdata_container_format)
94+
# add a mixed Polygon + MultiPolygon element
95+
shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]])
96+
97+
shapes.write(tmpdir, sdata_formats=sdata_container_format, shapes_geometry_encoding=geometry_encoding)
8998
sdata = SpatialData.read(tmpdir)
90-
assert_spatial_data_objects_are_identical(shapes, sdata)
99+
100+
if geometry_encoding == "WKB":
101+
assert_spatial_data_objects_are_identical(shapes, sdata)
102+
else:
103+
# convert each Polygon to a MultiPolygon
104+
mixed_multipolygon = shapes["mixed"].assign(
105+
geometry=lambda df: df.geometry.apply(lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g)
106+
)
107+
assert sdata["mixed"].equals(mixed_multipolygon)
108+
assert not sdata["mixed"].equals(shapes["mixed"])
109+
110+
del shapes["mixed"]
111+
del sdata["mixed"]
112+
assert_spatial_data_objects_are_identical(shapes, sdata)
113+
114+
@pytest.mark.parametrize("geometry_encoding", ["WKB", "geoarrow"])
115+
def test_shapes_geometry_encoding_write_element(
116+
self,
117+
tmp_path: str,
118+
shapes: SpatialData,
119+
sdata_container_format: SpatialDataContainerFormatType,
120+
geometry_encoding: Literal["WKB", "geoarrow"],
121+
) -> None:
122+
"""Test shapes geometry encoding with write_element() and global settings."""
123+
tmpdir = Path(tmp_path) / "tmp.zarr"
124+
125+
# First write an empty SpatialData to create the zarr store
126+
empty_sdata = SpatialData()
127+
empty_sdata.write(tmpdir, sdata_formats=sdata_container_format)
128+
129+
shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]])
130+
131+
# Add shapes to the empty sdata
132+
for shape_name in shapes.shapes:
133+
empty_sdata[shape_name] = shapes[shape_name]
134+
135+
# Store original setting and set global encoding
136+
original_encoding = spatialdata.config.settings.shapes_geometry_encoding
137+
try:
138+
spatialdata.config.settings.shapes_geometry_encoding = geometry_encoding
139+
140+
# Write each shape element - should use global setting
141+
for shape_name in shapes.shapes:
142+
empty_sdata.write_element(shape_name, sdata_formats=sdata_container_format)
143+
144+
# Verify the encoding metadata in the parquet file
145+
parquet_file = tmpdir / "shapes" / shape_name / "shapes.parquet"
146+
with pq.ParquetFile(parquet_file) as pf:
147+
md = pf.metadata
148+
d = json.loads(md.metadata[b"geo"].decode("utf-8"))
149+
found_encoding = d["columns"]["geometry"]["encoding"]
150+
if geometry_encoding == "WKB":
151+
expected_encoding = "WKB"
152+
elif shape_name == "circles":
153+
expected_encoding = "point"
154+
elif shape_name == "poly":
155+
expected_encoding = "polygon"
156+
elif shape_name in ["multipoly", "mixed"]:
157+
expected_encoding = "multipolygon"
158+
else:
159+
raise ValueError(
160+
f"Uncovered case for shape_name: {shape_name}, found encoding: {found_encoding}."
161+
)
162+
assert found_encoding == expected_encoding
163+
finally:
164+
spatialdata.config.settings.shapes_geometry_encoding = original_encoding
91165

92166
def test_points(
93167
self,

0 commit comments

Comments (0)