Skip to content

Commit 8022f5c

Browse files
fix categories import (#1010)
* fix categories import
* ensure reproducibility
* add test for missing categories bug in PointsModel
* remove comment

---------

Co-authored-by: Luca Marconato <m.lucalmer@gmail.com>
1 parent 84db22b commit 8022f5c

2 files changed

Lines changed: 29 additions & 1 deletion

File tree

src/spatialdata/models/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ def _add_metadata_and_validate(
891891
# It also just changes the state of the series, so it is not a big deal.
892892
if isinstance(data[c].dtype, CategoricalDtype) and not data[c].cat.known:
893893
try:
894-
data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories)
894+
data[c] = data[c].cat.set_categories(data[c].compute().cat.categories)
895895
except ValueError:
896896
logger.info(f"Column `{c}` contains unknown categories. Consider casting it.")
897897

tests/models/test_models.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,3 +830,31 @@ def test_warning_on_large_chunks():
830830
assert len(w) == 1, "Warning should be raised for large chunk size"
831831
assert issubclass(w[-1].category, UserWarning)
832832
assert "Detected chunks larger than:" in str(w[-1].message)
833+
834+
835+
def test_categories_on_partitioned_dataframe(sdata_blobs: SpatialData):
836+
df = sdata_blobs["blobs_points"].compute()
837+
df["genes"] = RNG.choice([f"gene_{i}" for i in range(200)], len(df))
838+
N_PARTITIONS = 200
839+
ddf = dd.from_pandas(df, npartitions=N_PARTITIONS)
840+
ddf["genes"] = ddf["genes"].astype("category")
841+
842+
df["genes"] = df["genes"].astype("category")
843+
df_parsed = PointsModel.parse(df, npartitions=N_PARTITIONS)
844+
ddf_parsed = PointsModel.parse(ddf, npartitions=N_PARTITIONS)
845+
846+
assert df["genes"].equals(df_parsed["genes"].compute())
847+
assert df["genes"].cat.categories.equals(df_parsed["genes"].compute().cat.categories)
848+
849+
assert np.array_equal(df["genes"].to_numpy(), ddf_parsed["genes"].compute().to_numpy())
850+
assert set(df["genes"].cat.categories.tolist()) == set(ddf_parsed["genes"].compute().cat.categories.tolist())
851+
852+
# two behaviors to investigate later/report to dask (they originate in dask)
853+
# TODO: df['genes'].cat.categories has dtype 'object', while ddf_parsed['genes'].compute().cat.categories has dtype
854+
# 'string'
855+
# this problem should disappear after pandas 3.0 is released
856+
assert df["genes"].cat.categories.dtype == "object"
857+
assert ddf_parsed["genes"].compute().cat.categories.dtype == "string"
858+
859+
# TODO: the list of categories does not preserve the order
860+
assert df["genes"].cat.categories.tolist() != ddf_parsed["genes"].compute().cat.categories.tolist()

0 commit comments

Comments
 (0)