Skip to content

Commit b11e807

Browse files
committed
Updated documentation
1 parent 567446d commit b11e807

1 file changed

Lines changed: 140 additions & 35 deletions

File tree

coderdata/split/splitter.py

Lines changed: 140 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Contains functions that help in generating training, testing and
3+
validation sets for CoderData Objects.
4+
"""
15

26
from copy import deepcopy
37
from typing import Literal
@@ -14,6 +18,7 @@
1418

1519
from coderdata.load.loader import DatasetLoader
1620

21+
1722
def train_test_validate(
1823
data: DatasetLoader,
1924
split_type: Literal[
@@ -27,11 +32,16 @@ def train_test_validate(
2732
"""
2833
Splits a `CoderData` object (see also
2934
`coderdata.load.loader.DatasetLoader`) into three subsets for
30-
training, testing and validating machine learning algorithms. The
31-
Size of the splits is fixed to 80:10:10 for train:test:validate. The
32-
function allows for additional optional arguments, that define the
33-
type of split that is performed as well as a random seed to enable
34-
the creation of reproducable splits.
35+
training, testing and validating machine learning algorithms.
36+
37+
The size of the splits can be adjusted to be different from 80:10:10
38+
(the default) for train:test:validate. The function also allows for
39+
additional optional arguments, that define the type of split that is
40+
performed ('mixed-set', 'drug-blind', 'cancer-blind'), if the splits
41+
should be stratified (and which drug response metric to use), as
42+
well as a random seed to enable the creation of reproducible splits.
43+
Furthermore, a list of keyword arguments can be defined that will be
44+
passed to the stratification function if so desired.
3545
3646
Parameters
3747
----------
@@ -41,8 +51,9 @@ def train_test_validate(
4151
`coderdata.download.downloader.download_data_by_prefix`) or
4252
built locally via the `build_all` process. The object must first
4353
be loaded via `coderdata.load.loader.DatasetLoader`.
44-
split_type : Literal['mixed-set', 'drug-blind', 'cancer-blind', \
45-
'disjoint'], optional
54+
split_type : {'mixed-set', 'drug-blind', 'cancer-blind'}, \
55+
default='mixed-set'
56+
4657
Defines the type of split that should be generated:
4758
4859
- *mixed-set*: Splits randomly independent of drug / cancer
@@ -55,15 +66,23 @@ def train_test_validate(
5566
- *cancer-blind*: Splits according to cancer association.
5667
Equivalent to drug-blind, except cancer types will be unique
5768
to splits.
58-
59-
Defaults to *mixed-set*.
60-
61-
random_state : int | RandomState | None, optional
69+
ratio : tuple[int, int, int], default=(8,1,1)
70+
Defines the size ratio of the resulting test, train and
71+
validation sets.
72+
stratify_by : str | None, default=None
73+
Defines if the training, testing and validation sets should be
74+
stratified. Any value other than None indicates stratification
75+
and defines which drug response value should be used as basis
76+
for the stratification. _None_ indicates that no stratification
77+
should be performed.
78+
random_state : int | RandomState | None, default=None
6279
Defines a seed value for the randomization of the splits. Will
6380
get passed to internal functions. Providing the seed will enable
6481
reproducibility of the generated splits.
65-
66-
Defaults to None
82+
**kwargs
83+
Additional keyword arguments that will be passed to the function
84+
that generates classes for the stratification
85+
(see also ``_create_classes``).
6786
6887
Returns
6988
-------
@@ -91,7 +110,6 @@ def train_test_validate(
91110
f"{split_type} not an excepted input for 'split_type'"
92111
)
93112

94-
95113
# A wide (pivoted) table is more easy to work with in this instance.
96114
# The pivot is done using all columns but the 'dose_respones_value'
97115
# and 'dose_respones_metric' as index. df.pivot will generate a
@@ -163,11 +181,13 @@ def train_test_validate(
163181
test_size=validate_size,
164182
random_state=random_state
165183
)
166-
sgk = StratifiedGroupKFold(
167-
n_splits=sum(ratio),
168-
shuffle=True,
169-
random_state=random_state
170-
)
184+
185+
# StratifiedShuffleSplit is similar to ShuffleSplit with the added
186+
# functionality to also stratify the splits according to defined
187+
# class labels.
188+
#
189+
# StratifiedShuffleSplit will be used for stratified mixed-set
190+
# train/test/validate sets.
171191

172192
sss_1 = StratifiedShuffleSplit(
173193
n_splits=1,
@@ -182,6 +202,25 @@ def train_test_validate(
182202
random_state=random_state
183203
)
184204

205+
# StratifiedGroupKFold generates K folds that take the group into
206+
# account when generating folds, i.e. a group will only be present
207+
# in one fold. It further tries to stratify the folds based on the
208+
# defined classes.
209+
#
210+
# StratifiedGroupKFold will be used for stratified drug-/sample-
211+
# blind splitting.
212+
#
213+
# The way the K folds are utilized is to combine i, j, & k folds
214+
# (according to the defined ratio) into training, testing and
215+
# validation sets.
216+
sgk = StratifiedGroupKFold(
217+
n_splits=sum(ratio),
218+
shuffle=True,
219+
random_state=random_state
220+
)
221+
222+
# The "actual" splitting logic using the defined Splitters as above
223+
# follows here starting with the non-stratified splitting:
185224
if stratify_by is None:
186225
if split_type == 'mixed-set':
187226
# Using ShuffleSplit to generate randomized train and
@@ -229,7 +268,12 @@ def train_test_validate(
229268
# sampled indices
230269
df_test = df_other.iloc[idx1]
231270
df_val = df_other.iloc[idx2]
271+
272+
# The following block contains the stratified splitting logic
232273
else:
274+
# First the classes that are needed for the stratification are
275+
# generated. `num_classes`, `thresh` and `quantiles` were
276+
# previously defined as possible keyword arguments.
233277
df_full = _create_classes(
234278
data=df_full,
235279
metric=stratify_by,
@@ -238,20 +282,23 @@ def train_test_validate(
238282
quantiles=quantiles,
239283
)
240284
if split_type == 'mixed-set':
241-
# Using ShuffleSplit to generate randomized train and
242-
# 'other' set, since there is no need for grouping.
285+
# Using StratifiedShuffleSplit to generate randomized train
286+
# and 'other' set, since there is no need for grouping.
243287
idx_train, idx_other = next(
244288
sss_1.split(X=df_full, y=df_full['split_class'])
245289
)
246290
df_train = df_full.iloc[idx_train]
247291
df_other = df_full.iloc[idx_other]
292+
# Splitting 'other' further into test and validate
248293
idx_test, idx_val = next(
249294
sss_2.split(X=df_other, y=df_other['split_class'])
250295
)
251296
df_test = df_other.iloc[idx_test]
252297
df_val = df_other.iloc[idx_val]
298+
299+
# using StratifiedGroupKFold for the stratified drug-/sample-
300+
# blind splits.
253301
elif split_type == 'drug-blind' or split_type == 'cancer-blind':
254-
255302
if split_type == 'drug-blind':
256303
splitter = enumerate(
257304
sgk.split(
@@ -268,11 +315,18 @@ def train_test_validate(
268315
groups=df_full.improve_sample_id
269316
)
270317
)
271-
318+
319+
# StratifiedGroupKFold is set up to generate K splits where
320+
# K=sum(ratios) (e.g. 10 if ratio=8:1:1). To obtain three
321+
# sets (train/test/validate) the individual splits need to
322+
# be combined (e.g. k=[1:8] -> train, k=9 -> test, k=10 ->
323+
# validate). The code block below does that by combining
324+
# all indices (row numbers) that go into individual sets and
325+
# then extracting and adding those rows into the individual
326+
# sets.
272327
idx_train = []
273328
idx_test = []
274329
idx_val = []
275-
276330
for i, (idx1, idx2) in splitter:
277331
if i < ratio[0]:
278332
idx_train.extend(idx2)
@@ -302,13 +356,22 @@ def _filter(data: DatasetLoader, split: pd.DataFrame) -> DatasetLoader:
302356
Helper function to filter down the CoderData object(s) to create
303357
independent, more concise CoderData objects for training, testing
304358
and validation splits.
305-
"""
306359
307-
# cd.drugs -> reduce based on improve_drug_id
308-
# cd.mutations -> reduce based on improve_sample_id
309-
# cd.proteomics -> reduce based on improve_sample_id
310-
# cd.samples -> reduce based on improve_sample_id
311-
# cd.transcriptomics -> reduce based on improve_sample_id
360+
Parameters
361+
----------
362+
data : DatasetLoader
363+
CoderData object containing the "full" dataset, e.g. the dataset
364+
that splits are based on.
365+
split : pandas.DataFrame
366+
Contains a subset of rows from the data.experiments DataFrame
367+
that correspond to the generated split.
368+
369+
Returns
370+
-------
371+
DatasetLoader
372+
A CoderData object that is a subset of ``data`` containing only
373+
the data points that pertain to the information in ``split``.
374+
"""
312375

313376
# extracting improve sample and drug ids from the provided split
314377
sample_ids = np.unique(split['improve_sample_id'].values)
@@ -338,6 +401,16 @@ def _filter(data: DatasetLoader, split: pd.DataFrame) -> DatasetLoader:
338401

339402
# filtering each individual data type down by only the improve
340403
# sample / drug ids that are present in the split (extracted above)
404+
#
405+
# Datapoints in the individual datasets in the CoderData object are
406+
# filtered by the following "rules":
407+
#
408+
# cd.drugs -> reduce based on improve_drug_id
409+
# cd.mutations -> reduce based on improve_sample_id
410+
# cd.proteomics -> reduce based on improve_sample_id
411+
# cd.samples -> reduce based on improve_sample_id
412+
# cd.transcriptomics -> reduce based on improve_sample_id
413+
341414
data_ret.drugs = data_ret.drugs[
342415
data_ret.drugs['improve_drug_id'].isin(drug_ids)
343416
]
@@ -369,6 +442,43 @@ def _create_classes(
369442
Helper function that bins experiment data into a number of defined
370443
classes for use with Stratified Splits.
371444
445+
Parameters
446+
----------
447+
data : pandas.DataFrame
448+
The DataFrame containing drug response data (experiment) which
449+
is subject to binning / creating classes based on the
450+
drug response metric
451+
metric : str
452+
The drug response metric upon which the class generation should
453+
be based. Needs to be in data['drug_response_metric'].
454+
num_classes : int, default=2
455+
Number of classes that should be generated. Defaults to 2.
456+
quantiles : bool, default=True
457+
Defines whether the individual bins should be based on quantiles
458+
(or percentiles) instead of "evenly" spaced. If true then the
459+
"bin" size will be chosen such that roughly the same number of
460+
data points fall into each class
461+
thresh : float, default=None
462+
Optional argument that defines a threshold other than the mean
463+
of the drug response metric if ``num_classes=2``. Can be used to
464+
generate "uneven" classes.
465+
466+
Returns
467+
-------
468+
pandas.DataFrame
469+
DataFrame that is the same as the input with additional column
470+
that defines the established class association of each data
471+
point.
472+
473+
Raises
474+
------
475+
ValueError
476+
If the chosen ``metric`` is not present in the
477+
`drug_response_metric` column of ``data``.
478+
ValueError
479+
If ``num_classes`` < 2.
480+
ValueError
481+
If ``thresh`` is defined but ``num_classes`` > 2.
372482
"""
373483

374484
if metric not in data.columns:
@@ -408,8 +518,3 @@ def _create_classes(
408518
)
409519

410520
return data
411-
412-
413-
def get_subset(df_full: pd.DataFrame, df_subset: pd.DataFrame):
414-
idx = df_subset.index
415-
return df_full.drop(idx, axis='index')

0 commit comments

Comments
 (0)