|
1 | | -from pydantic import BaseModel, Field, model_validator, field_validator |
| 1 | +from pydantic import BaseModel, Field, model_validator, field_validator, field_serializer |
2 | 2 | from typing import List, Dict, Union, Optional, Literal, Tuple |
3 | 3 | from itertools import chain, combinations |
4 | 4 | import numpy as np |
@@ -77,7 +77,7 @@ def get_full_spike_indices(self, sorting: BaseSorting): |
77 | 77 |
|
78 | 78 | class CurationModel(BaseModel): |
79 | 79 | supported_versions: Tuple[Literal["1"], Literal["2"]] = Field( |
80 | | - default=["1", "2"], description="Supported versions of the curation format" |
| 80 | + default=("1", "2"), description="Supported versions of the curation format" |
81 | 81 | ) |
82 | 82 | format_version: str = Field(description="Version of the curation format") |
83 | 83 | unit_ids: List[Union[int, str]] = Field(description="List of unit IDs") |
@@ -238,11 +238,11 @@ def check_splits(cls, values): |
238 | 238 | for i, split in enumerate(splits): |
239 | 239 | if isinstance(split, dict): |
240 | 240 | split = dict(split) |
241 | | - if "indices" in split: |
| 241 | + if "indices" in split and split["indices"] is not None: |
242 | 242 | split["indices"] = [list(indices) for indices in split["indices"]] |
243 | | - if "labels" in split: |
| 243 | + if "labels" in split and split["labels"] is not None: |
244 | 244 | split["labels"] = list(split["labels"]) |
245 | | - if "new_unit_ids" in split: |
| 245 | + if "new_unit_ids" in split and split["new_unit_ids"] is not None: |
246 | 246 | split["new_unit_ids"] = list(split["new_unit_ids"]) |
247 | 247 | splits[i] = Split(**split) |
248 | 248 |
|
|
0 commit comments