@@ -305,10 +305,10 @@ def format(
305305 'experiments' , 'combinations' , 'drug_descriptor' , 'drugs' ,
306306 'genes' , 'samples' ,
307307 ],
308- use_polars : bool = False ,
308+ remove_na : bool = False ,
309309 ** kwargs : dict ,
310310 ):
311- return format (self , data_type = data_type , use_polars = use_polars , ** kwargs )
311+ return format (self , data_type = data_type , remove_na = False , ** kwargs )
312312
313313
314314 def split_train_other (
@@ -512,7 +512,7 @@ def format(
512512 'experiments' , 'combinations' , 'drug_descriptor' , 'drugs' ,
513513 'genes' , 'samples' ,
514514 ],
515- use_polars : bool = False ,
515+ remove_na : bool = False ,
516516 ** kwargs : dict ,
517517 ):
518518
@@ -618,6 +618,8 @@ def format(
618618 columns = 'dose_response_metric' ,
619619 values = 'dose_response_value'
620620 ).reset_index ().rename_axis (None , axis = 1 )
621+ if remove_na :
622+ ret .dropna (axis = 'index' , inplace = True )
621623 elif shape == 'matrix' :
622624 if len (metrics ) > 1 :
623625 raise ValueError (
@@ -1182,7 +1184,8 @@ def _split_two_way(
11821184 columns = 'dose_response_metric' ,
11831185 values = 'dose_response_value'
11841186 ).reset_index ()
1185-
1187+ if stratify_by is not None :
1188+ df_full .dropna (axis = 'index' , subset = [stratify_by ], inplace = True )
11861189 # Defining the split sizes.
11871190 train_size = float (ratio [0 ]) / sum (ratio )
11881191 test_val_size = float (ratio [1 ]) / sum (ratio )
0 commit comments