@@ -22,6 +22,7 @@ def train_test_validate(
2222 ratio : tuple [int , int , int ]= (8 ,1 ,1 ),
2323 stratify_by : (str | None )= None ,
2424 random_state : (int | RandomState | None )= None ,
25+ ** kwargs : dict ,
2526 ) -> tuple [DatasetLoader , DatasetLoader , DatasetLoader ]:
2627 """
2728 Splits a `CoderData` object (see also
@@ -76,6 +77,12 @@ def train_test_validate(
7677
7778 """
7879
80+ # reading in the potential keyword arguments that will be passed to
81+ # _create_classes().
82+ thresh = kwargs .get ('thresh' , None )
83+ num_classes = kwargs .get ('num_classes' , 2 )
84+ quantiles = kwargs .get ('quantiles' , True )
85+
7986 # Type checking split_type
8087 if split_type not in [
8188 'mixed-set' , 'drug-blind' , 'cancer-blind'
@@ -84,12 +91,13 @@ def train_test_validate(
8491 f"{ split_type } not an excepted input for 'split_type'"
8592 )
8693
87- df_full = data . experiments . copy ()
94+
8895 # A wide (pivoted) table is easier to work with in this instance.
8996 # The pivot is done using all columns but the 'dose_response_value'
9097 # and 'dose_response_metric' as index. df.pivot will generate a
9198 # MultiIndex which complicates things further down the line. To that
9299 # end 'reset_index()' is used to remove the MultiIndex
100+ df_full = data .experiments .copy ()
93101 df_full = df_full .pivot (
94102 index = [
95103 'source' ,
@@ -225,8 +233,9 @@ def train_test_validate(
225233 df_full = _create_classes (
226234 data = df_full ,
227235 metric = stratify_by ,
228- num_classes = 2 ,
229- thresh = 0.5 ,
236+ num_classes = num_classes ,
237+ thresh = thresh ,
238+ quantiles = quantiles ,
230239 )
231240 if split_type == 'mixed-set' :
232241 # Using ShuffleSplit to generate randomized train and
@@ -316,7 +325,7 @@ def _filter(data: DatasetLoader, split: pd.DataFrame) -> DatasetLoader:
316325 'improve_drug_id' ,
317326 'study' ,
318327 'time' ,
319- 'time_unit'
328+ 'time_unit' ,
320329 ],
321330 var_name = 'dose_response_metric' ,
322331 value_name = 'dose_response_value'
@@ -352,7 +361,8 @@ def _filter(data: DatasetLoader, split: pd.DataFrame) -> DatasetLoader:
352361def _create_classes (
353362 data : pd .DataFrame ,
354363 metric : str ,
355- num_classes : int = 3 ,
364+ num_classes : int = 2 ,
365+ quantiles : bool = True ,
356366 thresh : float = None ,
357367 ) -> pd .DataFrame :
358368 """
@@ -372,11 +382,18 @@ def _create_classes(
372382 )
373383
374384 if thresh is None :
375- data ['split_class' ] = pd .cut (
376- data [metric ],
377- bins = num_classes ,
378- labels = False
385+ if quantiles :
386+ data ['split_class' ] = pd .qcut (
387+ data [metric ],
388+ q = num_classes ,
389+ labels = False ,
379390 )
391+ else :
392+ data ['split_class' ] = pd .cut (
393+ data [metric ],
394+ bins = num_classes ,
395+ labels = False
396+ )
380397 elif num_classes == 2 :
381398 data ['split_class' ] = pd .cut (
382399 data [metric ],
0 commit comments