@@ -99,20 +99,22 @@ def _pivot_multiselect(
9999 df : pl .DataFrame , option_scores : pl .DataFrame , * , include_options : bool = False
100100 ) -> pl .DataFrame :
101101 del option_scores
102-
103- # Extract response_options before dropping item
102+ item_options_map : pl . DataFrame = pl . DataFrame ()
103+ # Extract `response_options` before exploding (all options share the same list)
104104 if include_options :
105- df = df .with_columns (
106- item_option = pl .col ("item" ).struct .field ("response_options" ),
107- response_options = pl .col ("item" ).struct .field ("response_options" ),
108- )
109- else :
110- df = df .with_columns (
111- item_option = pl .col ("item" ).struct .field ("response_options" )
112- )
105+ # Get unique `response_options` per item (before exploding)
106+ item_options_map = df .select (
107+ [
108+ pl .col ("item" ).struct .field ("name" ).alias ("item_name" ),
109+ pl .col ("item" )
110+ .struct .field ("response_options" )
111+ .alias ("response_options" ),
112+ ]
113+ ).unique (subset = ["item_name" ])
113114
114115 df = (
115- df .explode ("item_option" )
116+ df .with_columns (item_option = pl .col ("item" ).struct .field ("response_options" ))
117+ .explode ("item_option" )
116118 # Generate value column indicating presence of response.
117119 .with_columns (
118120 response_present = pl .col ("item_option" )
@@ -129,15 +131,21 @@ def _pivot_multiselect(
129131 )
130132 )
131133 .drop ("item_option" , "item" )
134+ .pivot (
135+ on = ["item_option_pivot" ], values = "response_present" , sort_columns = True
136+ )
132137 )
133138
134- pivot_values = [ "response_present" ]
139+ # Join back the `response_options` for each item
135140 if include_options :
136- pivot_values .append ("response_options" )
141+ for row in item_options_map .iter_rows (named = True ):
142+ item_name = row ["item_name" ]
143+ options_col = f"{ item_name } _options"
144+ df = df .with_columns (
145+ [pl .lit (row ["response_options" ]).alias (options_col )]
146+ )
137147
138- return df .pivot (
139- on = ["item_option_pivot" ], values = pivot_values , sort_columns = True
140- )
148+ return df
141149
142150 @staticmethod
143151 def _map_response_column_names (cname : str ) -> str :
0 commit comments