Skip to content

Commit 567446d

Browse files
committed
added option to use quantiles as an argument to _create_classes() & incorporated that change into train_test_validate(). Also added _cerate_classes' arguments as kwargs to train_test_validate()
1 parent 9f360d1 commit 567446d

1 file changed

Lines changed: 26 additions & 9 deletions

File tree

coderdata/split/splitter.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def train_test_validate(
2222
ratio: tuple[int, int, int]=(8,1,1),
2323
stratify_by: (str | None)=None,
2424
random_state: (int | RandomState | None)=None,
25+
**kwargs: dict,
2526
) -> tuple[DatasetLoader, DatasetLoader, DatasetLoader]:
2627
"""
2728
Splits a `CoderData` object (see also
@@ -76,6 +77,12 @@ def train_test_validate(
7677
7778
"""
7879

80+
# reading in the potential keyword arguments that will be passed to
81+
# _create_classes().
82+
thresh = kwargs.get('thresh', None)
83+
num_classes = kwargs.get('num_classes', 2)
84+
quantiles = kwargs.get('quantiles', True)
85+
7986
# Type checking split_type
8087
if split_type not in [
8188
'mixed-set', 'drug-blind', 'cancer-blind'
@@ -84,12 +91,13 @@ def train_test_validate(
8491
f"{split_type} not an excepted input for 'split_type'"
8592
)
8693

87-
df_full = data.experiments.copy()
94+
8895
# A wide (pivoted) table is more easy to work with in this instance.
8996
# The pivot is done using all columns but the 'dose_respones_value'
9097
# and 'dose_respones_metric' as index. df.pivot will generate a
9198
# MultiIndex which complicates things further down the line. To that
9299
# end 'reset_index()' is used to remove the MultiIndex
100+
df_full = data.experiments.copy()
93101
df_full = df_full.pivot(
94102
index = [
95103
'source',
@@ -225,8 +233,9 @@ def train_test_validate(
225233
df_full = _create_classes(
226234
data=df_full,
227235
metric=stratify_by,
228-
num_classes=2,
229-
thresh=0.5,
236+
num_classes=num_classes,
237+
thresh=thresh,
238+
quantiles=quantiles,
230239
)
231240
if split_type == 'mixed-set':
232241
# Using ShuffleSplit to generate randomized train and
@@ -316,7 +325,7 @@ def _filter(data: DatasetLoader, split: pd.DataFrame) -> DatasetLoader:
316325
'improve_drug_id',
317326
'study',
318327
'time',
319-
'time_unit'
328+
'time_unit',
320329
],
321330
var_name='dose_response_metric',
322331
value_name='dose_response_value'
@@ -352,7 +361,8 @@ def _filter(data: DatasetLoader, split: pd.DataFrame) -> DatasetLoader:
352361
def _create_classes(
353362
data: pd.DataFrame,
354363
metric: str,
355-
num_classes: int=3,
364+
num_classes: int=2,
365+
quantiles: bool=True,
356366
thresh: float=None,
357367
) -> pd.DataFrame:
358368
"""
@@ -372,11 +382,18 @@ def _create_classes(
372382
)
373383

374384
if thresh is None:
375-
data['split_class'] = pd.cut(
376-
data[metric],
377-
bins=num_classes,
378-
labels=False
385+
if quantiles:
386+
data['split_class'] = pd.qcut(
387+
data[metric],
388+
q=num_classes,
389+
labels=False,
379390
)
391+
else:
392+
data['split_class'] = pd.cut(
393+
data[metric],
394+
bins=num_classes,
395+
labels=False
396+
)
380397
elif num_classes == 2:
381398
data['split_class'] = pd.cut(
382399
data[metric],

0 commit comments

Comments
 (0)