Skip to content

Commit 6a74456

Browse files
committed
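Fixes to the splitting function and the format function: replace the `use_polars` argument with `remove_na`, and drop rows with a missing stratification value before splitting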
1 parent 314cdf9 commit 6a74456

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

coderdata/dataset/dataset.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,10 @@ def format(
305305
'experiments', 'combinations', 'drug_descriptor', 'drugs',
306306
'genes', 'samples',
307307
],
308-
use_polars: bool=False,
308+
remove_na: bool=False,
309309
**kwargs: dict,
310310
):
311-
return format(self, data_type=data_type, use_polars=use_polars, **kwargs)
311+
return format(self, data_type=data_type, remove_na=False, **kwargs)
312312

313313

314314
def split_train_other(
@@ -512,7 +512,7 @@ def format(
512512
'experiments', 'combinations', 'drug_descriptor', 'drugs',
513513
'genes', 'samples',
514514
],
515-
use_polars: bool=False,
515+
remove_na: bool=False,
516516
**kwargs: dict,
517517
):
518518

@@ -618,6 +618,8 @@ def format(
618618
columns = 'dose_response_metric',
619619
values = 'dose_response_value'
620620
).reset_index().rename_axis(None, axis=1)
621+
if remove_na:
622+
ret.dropna(axis='index', inplace=True)
621623
elif shape == 'matrix':
622624
if len(metrics) > 1:
623625
raise ValueError(
@@ -1182,7 +1184,8 @@ def _split_two_way(
11821184
columns = 'dose_response_metric',
11831185
values = 'dose_response_value'
11841186
).reset_index()
1185-
1187+
if stratify_by is not None:
1188+
df_full.dropna(axis='index', subset=[stratify_by], inplace=True)
11861189
# Defining the split sizes.
11871190
train_size = float(ratio[0]) / sum(ratio)
11881191
test_val_size = float(ratio[1]) / sum(ratio)

0 commit comments

Comments
 (0)