Skip to content

Commit 4094c7e

Browse files
committed
✨ Include _options, _index and _score in RedcapImportFormat
1 parent 0c7b569 commit 4094c7e

1 file changed

Lines changed: 143 additions & 51 deletions

File tree

src/mindlogger_data_export/outputs.py

Lines changed: 143 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22

33
from __future__ import annotations
44

5+
import inspect
56
import logging
67
import re
78
from abc import ABC
89
from collections.abc import Callable, Generator
910
from dataclasses import dataclass
11+
from functools import partial
1012
from pathlib import Path
13+
from typing import Protocol, Self
1114

1215
import polars as pl
1316
import polars.selectors as cs
@@ -19,6 +22,20 @@
1922
LOG = logging.getLogger(__name__)
2023

2124

25+
class PivotFunction(Protocol):
26+
"""Pivot Functions with keyword-only `include_options`."""
27+
28+
def __call__(
29+
self,
30+
df: pl.DataFrame,
31+
option_scores: pl.DataFrame,
32+
*,
33+
include_options: bool = False,
34+
) -> Self:
35+
"""Signature for Pivot Functions with keyword-only `include_options`."""
36+
...
37+
38+
2239
@dataclass
2340
class NamedOutput:
2441
"""Represents named output data to be written."""
@@ -72,21 +89,35 @@ class WideFormat(Output):
7289

7390
NAME = "wide"
7491

92+
def __init__(self, *args, include_options: bool = False, **kwargs) -> None:
93+
"""Initialize wide data format without options by default."""
94+
super().__init__(*args, **kwargs)
95+
self._include_options = include_options
96+
7597
@staticmethod
7698
def _pivot_multiselect(
77-
df: pl.DataFrame, option_scores: pl.DataFrame
99+
df: pl.DataFrame, option_scores: pl.DataFrame, *, include_options: bool = False
78100
) -> pl.DataFrame:
79101
del option_scores
80-
return (
81-
df.with_columns(item_option=pl.col("item").struct.field("response_options"))
82-
.explode("item_option")
102+
103+
# Extract response_options before dropping item
104+
if include_options:
105+
df = df.with_columns(
106+
item_option=pl.col("item").struct.field("response_options"),
107+
response_options=pl.col("item").struct.field("response_options"),
108+
)
109+
else:
110+
df = df.with_columns(
111+
item_option=pl.col("item").struct.field("response_options")
112+
)
113+
114+
df = (
115+
df.explode("item_option")
83116
# Generate value column indicating presence of response.
84117
.with_columns(
85118
response_present=pl.col("item_option")
86119
.struct.field("value")
87120
.is_in(pl.col("response_value").struct.field("value")),
88-
# response_index=pl.col("item_option").struct.field("value"),
89-
# response_name=pl.col("item_option").struct.field("name"),
90121
)
91122
.drop("response_value")
92123
# Generate pivot column.
@@ -98,14 +129,21 @@ def _pivot_multiselect(
98129
)
99130
)
100131
.drop("item_option", "item")
101-
.pivot(
102-
on=["item_option_pivot"], values="response_present", sort_columns=True
103-
)
132+
)
133+
134+
pivot_values = ["response_present"]
135+
if include_options:
136+
pivot_values.append("response_options")
137+
138+
return df.pivot(
139+
on=["item_option_pivot"], values=pivot_values, sort_columns=True
104140
)
105141

106142
@staticmethod
107143
def _map_response_column_names(cname: str) -> str:
108144
parts = cname.split("__", 1)
145+
if parts[0] == "response_options":
146+
return f"{parts[1]}_options"
109147
return "_".join([parts[1], parts[0].removeprefix("response")])
110148

111149
@staticmethod
@@ -115,7 +153,7 @@ def _fill_item_response(*null_score_columns: str) -> Generator[pl.Expr, None, No
115153

116154
@staticmethod
117155
def _pivot_singleselect(
118-
df: pl.DataFrame, option_scores: pl.DataFrame
156+
df: pl.DataFrame, option_scores: pl.DataFrame, *, include_options: bool = False
119157
) -> pl.DataFrame:
120158
# Rename columns in scores table.
121159
response_options = option_scores.rename(
@@ -145,12 +183,30 @@ def _pivot_singleselect(
145183
how="left",
146184
validate="m:1",
147185
)
148-
# Extract item name for pivot.
149-
.with_columns(item_name=pl.col("item").struct.field("name"))
150-
.drop("item")
151-
# Pivot on item_name producing 3 columns for each item.
152-
.pivot(on="item_name", values=cs.starts_with("response"), separator="__")
153-
# Rename pivoted columns to
186+
)
187+
188+
# Optionally extract response_options before dropping item
189+
if include_options:
190+
df = df.with_columns(
191+
response_options=pl.col("item").struct.field("response_options")
192+
)
193+
194+
# Extract item name for pivot.
195+
df = df.with_columns(item_name=pl.col("item").struct.field("name")).drop("item")
196+
197+
# Determine which columns to pivot
198+
pivot_values = cs.starts_with("response")
199+
if include_options:
200+
pivot_values = [
201+
"response_index",
202+
"response_score",
203+
"response_response",
204+
"response_options",
205+
]
206+
207+
df = (
208+
df.pivot(on="item_name", values=pivot_values, separator="__")
209+
# Rename pivoted columns
154210
.with_columns(
155211
cs.starts_with("response").name.map(
156212
WideFormat._map_response_column_names
@@ -173,8 +229,14 @@ def _pivot_singleselect(
173229
return df.with_columns(WideFormat._fill_item_response(*null_score_columns))
174230

175231
@staticmethod
176-
def _pivot_text(df: pl.DataFrame, option_scores: pl.DataFrame) -> pl.DataFrame:
232+
def _pivot_text(
233+
df: pl.DataFrame, option_scores: pl.DataFrame, *, include_options=NotImplemented
234+
) -> pl.DataFrame:
235+
current_frame = inspect.currentframe()
236+
_qualname = current_frame.f_code.co_qualname if current_frame else __name__
237+
LOG.debug("`include_options` parameter %s for %s", include_options, _qualname)
177238
del option_scores
239+
178240
return (
179241
df.with_columns(
180242
response_value=pl.col("response_value").struct.field("text"),
@@ -197,30 +259,41 @@ def _pivot_text(df: pl.DataFrame, option_scores: pl.DataFrame) -> pl.DataFrame:
197259
)
198260

199261
@staticmethod
200-
def _pivot_subscale(df: pl.DataFrame, option_scores: pl.DataFrame) -> pl.DataFrame:
262+
def _pivot_subscale(
263+
df: pl.DataFrame, option_scores: pl.DataFrame, *, include_options: bool = False
264+
) -> pl.DataFrame:
201265
del option_scores
202-
return (
203-
df.with_columns(
204-
response_value=pl.col("response_value").struct.field("subscale"),
205-
item_name=pl.col("item").struct.field("name"),
206-
)
207-
.with_columns(response_response=pl.col("response_value"))
208-
.drop("item")
209-
.pivot(
210-
on="item_name",
211-
values=["response_value", "response_response"],
212-
separator="__",
213-
)
214-
.rename(
215-
lambda s: s.removesuffix("__response_response")
216-
if s.endswith("__response_response")
217-
else s.removesuffix("_value")
218-
if s.endswith("__response_value")
219-
else s
266+
267+
df = df.with_columns(
268+
response_value=pl.col("response_value").struct.field("subscale"),
269+
item_name=pl.col("item").struct.field("name"),
270+
).with_columns(response_response=pl.col("response_value"))
271+
272+
# Optionally extract response_options before dropping item
273+
if include_options:
274+
df = df.with_columns(
275+
response_options=pl.col("item").struct.field("response_options")
220276
)
277+
278+
df = df.drop("item")
279+
280+
pivot_values = ["response_value", "response_response"]
281+
if include_options:
282+
pivot_values.append("response_options")
283+
284+
return df.pivot(
285+
on="item_name",
286+
values=pivot_values,
287+
separator="__",
288+
).rename(
289+
lambda s: s.removesuffix("__response_response")
290+
if s.endswith("__response_response")
291+
else s.removesuffix("_value")
292+
if s.endswith("__response_value")
293+
else s
221294
)
222295

223-
PIVOT_FNS = {
296+
PIVOT_FNS: dict[tuple[ItemType], PivotFunction] = {
224297
(ItemType.MultipleSelection,): _pivot_multiselect,
225298
(ItemType.SingleSelection,): _pivot_singleselect,
226299
(ItemType.Text,): _pivot_text,
@@ -230,7 +303,9 @@ def _pivot_subscale(df: pl.DataFrame, option_scores: pl.DataFrame) -> pl.DataFra
230303
def _get_pivot_fn(
231304
self, partition_type: tuple[ItemType]
232305
) -> Callable[[pl.DataFrame, pl.DataFrame], pl.DataFrame]:
233-
return self.PIVOT_FNS.get(partition_type, self._pivot_text)
306+
pivot_fn = self.PIVOT_FNS.get(partition_type, self._pivot_text)
307+
# Use partial to bind the keyword-only argument
308+
return partial(pivot_fn, include_options=self._include_options)
234309

235310
def _typed_pivot(
236311
self, df: pl.DataFrame, option_scores: pl.DataFrame
@@ -319,7 +394,8 @@ class RedcapImportFormat(WideFormat):
319394

320395
NAME = "redcap"
321396

322-
def __init__(self, project: str = "curious_parent_arm_1", *args, **kwargs):
397+
def __init__(self, project: str = "curious_parent_arm_1", *args, **kwargs) -> None:
398+
kwargs.setdefault("include_options", True)
323399
super().__init__(*args, **kwargs)
324400
self._project = project
325401
self._instrument_row_count: dict[str, int | None] = {}
@@ -346,12 +422,25 @@ def _prepare_activity_columns(
346422
{
347423
col: col[:-5]
348424
for col in df.columns
349-
if col.endswith("_start_time", "_end_time")
425+
if col.endswith(("_start_time", "_end_time"))
350426
}
351427
)
352428

353-
# Handle response columns
354-
# Text remains the same, selections = index + 1
429+
# Stringify `_options` columns
430+
options_cols = [col for col in df.columns if col.endswith("_options")]
431+
for col in options_cols:
432+
df = df.with_columns(
433+
[
434+
pl.format(
435+
"[{}]",
436+
pl.col(col)
437+
.list.eval(pl.element().struct.json_encode())
438+
.list.join(", "),
439+
).alias(col)
440+
]
441+
)
442+
443+
# For non-text items, drop the `_response` columns
355444
response_cols = [col for col in df.columns if col.endswith("_response")]
356445
index_cols = [col for col in df.columns if col.endswith("_index")]
357446
index_bases = {col.replace("_index", "") for col in index_cols}
@@ -369,19 +458,22 @@ def _prepare_activity_columns(
369458
)
370459
]
371460
)
461+
462+
# For items with `_index` but no `_score`, create `_score` from `_index`
463+
score_cols = [col for col in df.columns if col.endswith("_score")]
464+
score_bases = {col.replace("_score", "") for col in score_cols}
465+
for col in index_cols:
466+
base_name = col.replace("_index", "")
467+
if base_name not in score_bases:
468+
score_col = f"{base_name}_score"
469+
df = df.with_columns([pl.col(col).alias(score_col)])
470+
471+
# Create REDCap `_response` columns from `_index` for select items (`_index + 1`)
372472
for col in index_cols:
373473
response_col = col.replace("_index", "_response")
374474
df = df.with_columns([(pl.col(col) + 1).alias(response_col)])
375-
df = df.select([col for col in df.columns if not col.endswith("_index")])
376475

377-
# Drop bare item columns that have a corresponding _response column
378-
# These are score columns that we don't need
379-
response_bases = {
380-
col.replace("_response", "")
381-
for col in df.columns
382-
if col.endswith("_response")
383-
}
384-
return df.select([col for col in df.columns if col not in response_bases])
476+
return df
385477

386478
def _format_activity(self, df: pl.DataFrame, activity_name: str) -> pl.DataFrame:
387479
"""Format a single activity's data for REDCap import."""

0 commit comments

Comments
 (0)