1+ """
2+ Contains functions that help in generating training, testing and
3+ validation sets for CoderData Objects.
4+ """
15
26from copy import deepcopy
37from typing import Literal
1418
1519from coderdata .load .loader import DatasetLoader
1620
21+
1722def train_test_validate (
1823 data : DatasetLoader ,
1924 split_type : Literal [
@@ -27,11 +32,16 @@ def train_test_validate(
2732 """
2833 Splits a `CoderData` object (see also
2934 `coderdata.load.loader.DatasetLoader`) into three subsets for
30- training, testing and validating machine learning algorithms. The
31- Size of the splits is fixed to 80:10:10 for train:test:validate. The
32- function allows for additional optional arguments, that define the
33- type of split that is performed as well as a random seed to enable
34- the creation of reproducable splits.
35+ training, testing and validating machine learning algorithms.
36+
37+ The size of the splits can be adjusted to be different from 80:10:10
38+ (the default) for train:test:validate. The function also allows for
39+ additional optional arguments, that define the type of split that is
40+ performed ('mixed-set', 'drug-blind', 'cancer-blind'), if the splits
41+ should be stratified (and which drug response metric to use), as
42+ well as a random seed to enable the creation of reproducible splits.
43+ Furthermore, a list of keyword arguments can be defined that will be
44+ passed to the stratification function if so desired.
3545
3646 Parameters
3747 ----------
@@ -41,8 +51,9 @@ def train_test_validate(
4151 `coderdata.download.downloader.download_data_by_prefix`) or
4252 built locally via the `build_all` process. The object must first
4353 be loaded via `coderdata.load.loader.DatasetLoader`.
44- split_type : Literal['mixed-set', 'drug-blind', 'cancer-blind', \
45- 'disjoint'], optional
54+ split_type : {'mixed-set', 'drug-blind', 'cancer-blind'}, \
55+ default='mixed-set'
56+
4657 Defines the type of split that should be generated:
4758
4859 - *mixed-set*: Splits randomly independent of drug / cancer
@@ -55,15 +66,23 @@ def train_test_validate(
5566 - *cancer-blind*: Splits according to cancer association.
5667 Equivalent to drug-blind, except cancer types will be unique
5768 to splits.
58-
59- Defaults to *mixed-set*.
60-
61- random_state : int | RandomState | None, optional
69+ ratio : tuple[int, int, int], default=(8,1,1)
70+ Defines the size ratio of the resulting test, train and
71+ validation sets.
72+ stratify_by : str | None, default=None
73+ Defines if the training, testing and validation sets should be
74+ stratified. Any value other than None indicates stratification
75+ and defines which drug response value should be used as basis
76+ for the stratification. _None_ indicates that no stratification
77+ should be performed.
78+ random_state : int | RandomState | None, default=None
6279 Defines a seed value for the randomization of the splits. Will
6380 get passed to internal functions. Providing the seed will enable
6481 reproducability of the generated splits.
65-
66- Defaults to None
82+ **kwargs
83+ Additional keyword arguments that will be passed to the function
84+ that generates classes for the stratification
85+ (see also ``_create_classes``).
6786
6887 Returns
6988 -------
@@ -91,7 +110,6 @@ def train_test_validate(
91110 f"{ split_type } not an excepted input for 'split_type'"
92111 )
93112
94-
95113 # A wide (pivoted) table is more easy to work with in this instance.
96114 # The pivot is done using all columns but the 'dose_respones_value'
97115 # and 'dose_respones_metric' as index. df.pivot will generate a
@@ -163,11 +181,13 @@ def train_test_validate(
163181 test_size = validate_size ,
164182 random_state = random_state
165183 )
166- sgk = StratifiedGroupKFold (
167- n_splits = sum (ratio ),
168- shuffle = True ,
169- random_state = random_state
170- )
184+
185+ # StratifiedShuffleSplit is similar to ShuffleSplit with the added
186+ # functionality to also stratify the splits according to defined
187+ # class labels.
188+ #
189+ # StratifiedShuffleSplit will be used for stratified mixed-set
190+ # train/test/validate sets.
171191
172192 sss_1 = StratifiedShuffleSplit (
173193 n_splits = 1 ,
@@ -182,6 +202,25 @@ def train_test_validate(
182202 random_state = random_state
183203 )
184204
205+ # StratifiedGroupKFold generates K folds that take the group into
206+ # account when generating folds, i.e. a group will only be present
207+ # in one fold. It further tries to stratify the folds based on the
208+ # defined classes.
209+ #
210+ # StratifiedGroupKFold will be used for stratified drug-/cancer-
211+ # blind splitting.
212+ #
213+ # The way the K folds are utilized is to combine i, j, & k folds
214+ # (according to the defined ratio) into training, testing and
215+ # validation sets.
216+ sgk = StratifiedGroupKFold (
217+ n_splits = sum (ratio ),
218+ shuffle = True ,
219+ random_state = random_state
220+ )
221+
222+ # The "actual" splitting logic using the defined Splitters as above
223+ # follows here starting with the non-stratified splitting:
185224 if stratify_by is None :
186225 if split_type == 'mixed-set' :
187226 # Using ShuffleSplit to generate randomized train and
@@ -229,7 +268,12 @@ def train_test_validate(
229268 # sampled indices
230269 df_test = df_other .iloc [idx1 ]
231270 df_val = df_other .iloc [idx2 ]
271+
272+ # The following block contains the stratified splitting logic
232273 else :
274+ # First the classes that are needed for the stratification are
275+ # generated. `num_classes`, `thresh` and `quantiles` were
276+ # previously defined as possible keyword arguments.
233277 df_full = _create_classes (
234278 data = df_full ,
235279 metric = stratify_by ,
@@ -238,20 +282,23 @@ def train_test_validate(
238282 quantiles = quantiles ,
239283 )
240284 if split_type == 'mixed-set' :
241- # Using ShuffleSplit to generate randomized train and
242- # 'other' set, since there is no need for grouping.
285+ # Using StratifiedShuffleSplit to generate randomized train
286+ # and 'other' set, since there is no need for grouping.
243287 idx_train , idx_other = next (
244288 sss_1 .split (X = df_full , y = df_full ['split_class' ])
245289 )
246290 df_train = df_full .iloc [idx_train ]
247291 df_other = df_full .iloc [idx_other ]
292+ # Splitting 'other' further into test and validate
248293 idx_test , idx_val = next (
249294 sss_2 .split (X = df_other , y = df_other ['split_class' ])
250295 )
251296 df_test = df_other .iloc [idx_test ]
252297 df_val = df_other .iloc [idx_val ]
298+
299+ # using StratifiedGroupKFold for the stratified drug-/cancer-
300+ # blind splits.
253301 elif split_type == 'drug-blind' or split_type == 'cancer-blind' :
254-
255302 if split_type == 'drug-blind' :
256303 splitter = enumerate (
257304 sgk .split (
@@ -268,11 +315,18 @@ def train_test_validate(
268315 groups = df_full .improve_sample_id
269316 )
270317 )
271-
318+
319+ # StratifiedGroupKFold is set up to generate K splits where
320+ # K=sum(ratios) (e.g. 10 if ratio=8:1:1). To obtain three
321+ # sets (train/test/validate) the individual splits need to
322+ # be combined (e.g. k=[1:8] -> train, k=9 -> test, k=10 ->
323+ # validate). The code block below does that by combining
324+ # all indices (row numbers) that go into individual sets and
325+ # then extracting and adding those rows into the individual
326+ # sets.
272327 idx_train = []
273328 idx_test = []
274329 idx_val = []
275-
276330 for i , (idx1 , idx2 ) in splitter :
277331 if i < ratio [0 ]:
278332 idx_train .extend (idx2 )
@@ -302,13 +356,22 @@ def _filter(data: DatasetLoader, split: pd.DataFrame) -> DatasetLoader:
302356 Helper function to filter down the CoderData object(s) to create
303357 independent, more concise CoderData objects for training, testing
304358 and validation splits.
305- """
306359
307- # cd.drugs -> reduce based on improve_drug_id
308- # cd.mutations -> reduce based on improve_sample_id
309- # cd.proteomics -> reduce based on improve_sample_id
310- # cd.samples -> reduce based on improve_sample_id
311- # cd.transcriptomics -> reduce based on improve_sample_id
360+ Parameters
361+ ----------
362+ data : DatasetLoader
363+ CoderData object containing the "full" dataset, e.g. the dataset
364+ that splits are based on.
365+ split : pandas.DataFrame
366+ Contains a subset of rows from the data.experiments DataFrame
367+ that correspond to the generated split.
368+
369+ Returns
370+ -------
371+ DatasetLoader
372+ A CoderData object that is a subset of ``data`` containing only
373+ the data points that pertain to the information in ``split``.
374+ """
312375
313376 # extracting improve sample and drug ids from the provided split
314377 sample_ids = np .unique (split ['improve_sample_id' ].values )
@@ -338,6 +401,16 @@ def _filter(data: DatasetLoader, split: pd.DataFrame) -> DatasetLoader:
338401
339402 # filtering each individual data type down by only the improve
340403 # sample / drug ids that are present in the split (extracted above)
404+ #
405+ # Datapoints in the individual datasets in the CoderData object are
406+ # filtered by the following "rules":
407+ #
408+ # cd.drugs -> reduce based on improve_drug_id
409+ # cd.mutations -> reduce based on improve_sample_id
410+ # cd.proteomics -> reduce based on improve_sample_id
411+ # cd.samples -> reduce based on improve_sample_id
412+ # cd.transcriptomics -> reduce based on improve_sample_id
413+
341414 data_ret .drugs = data_ret .drugs [
342415 data_ret .drugs ['improve_drug_id' ].isin (drug_ids )
343416 ]
@@ -369,6 +442,43 @@ def _create_classes(
369442 Helper function that bins experiment data into a number of defined
370443 classes for use with Stratified Splits.
371444
445+ Parameters
446+ ----------
447+ data : pandas.DataFrame
448+ The DataFrame containing drug response data (experiment) which
449+ is subject to binning / creating classes based on the
450+ drug response metric
451+ metric : str
452+ The drug response metric upon which the class generation should
453+ be based on. Needs to be in data['drug_response_metric'].
454+ num_classes : int, default=2
455+ Number of classes that should be generated. Defaults to 2.
456+ quantiles : bool, default=True
457+ Defines whether the individual bins should be based on quantiles
458+ (or percentiles) instead of "evenly" spaced. If true then the
459+ "bin" size will be chosen such that roughly the same number of
460+ data points fall into each class.
461+ thresh : float, default=None
462+ Optional argument that defines a threshold other than the mean
463+ of the drug response metric if ``num_classes=2``. Can be used to
464+ generate "uneven" classes.
465+
466+ Returns
467+ -------
468+ pandas.DataFrame
469+ DataFrame that is the same as the input with additional column
470+ that defines the established class association of each data
471+ point.
472+
473+ Raises
474+ ------
475+ ValueError
476+ If the chosen ``metric`` is not present in the
477+ `drug_response_metric` column of ``data``.
478+ ValueError
479+ If ``num_classes`` < 2.
480+ ValueError
481+ If ``thresh`` is defined but ``num_classes`` > 2.
372482 """
373483
374484 if metric not in data .columns :
@@ -408,8 +518,3 @@ def _create_classes(
408518 )
409519
410520 return data
411-
412-
def get_subset(df_full: pd.DataFrame, df_subset: pd.DataFrame):
    """Return a copy of ``df_full`` without the rows present in ``df_subset``.

    The rows to remove are identified via the index of ``df_subset``;
    the returned DataFrame is a new object (``df_full`` is unmodified).
    """
    return df_full.drop(df_subset.index, axis='index')
0 commit comments