@@ -75,6 +75,14 @@ def main():
7575 type = int ,
7676 default = 10
7777 )
78+ p_process_datasets .add_argument (
79+ '-r' , '--random_seeds' , dest = 'RANDOM_SEEDS' ,
80+ type = _random_seed_list ,
81+ default = None ,
82+ help = "Defines a list of random seeds. Must be comma separated "
83+ "integers. Must be same length as <NUM_SPLITS>. If omitted will "
84+ "default to randomized seeds."
85+ )
7886
7987 p_all = command_parsers .add_parser (
8088 "all" ,
@@ -116,6 +124,12 @@ def full_workflow(args):
116124
117125def process_datasets (args ):
118126
127+ if args .RANDOM_SEEDS is not None and len (args .RANDOM_SEEDS ) != args .NUM_SPLITS :
128+ sys .exit (
129+ "<RANDOM_SEEDS> must contain same number of random seed values as "
130+ "<NUM_SPLITS>."
131+ )
132+
119133
120134 local_path = args .WORKDIR .joinpath ('data_in_tmp' )
121135
@@ -481,7 +495,10 @@ def split_data_sets(
481495 split_type = args .SPLIT_TYPE
482496 ratio = (8 ,1 ,1 )
483497 stratify_by = None
484- random_state = None
498+ if args .RANDOM_SEEDS is not None :
499+ random_seeds = args .RANDOM_SEEDS
500+ else :
501+ random_seeds = [None ] * args .NUM_SPLITS
485502
486503 for data_set in data_sets_info .keys ():
487504 if data_sets [data_set ].experiments is not None :
@@ -525,7 +542,7 @@ def split_data_sets(
525542 split_type = split_type ,
526543 ratio = ratio ,
527544 stratify_by = stratify_by ,
528- random_state = random_state
545+ random_state = random_seeds [ i ]
529546 )
530547 train_keys = (
531548 splits [i ]
@@ -768,6 +785,17 @@ def _check_folder(path: Union[str, PathLike, Path]) -> Path:
768785
769786 return abs_path
770787
788+ def _random_seed_list (list : str ) -> list :
789+
790+ if not isinstance (list , str ):
791+ raise TypeError (
792+ f"'random_seed' must be of type str. Supplied argument is of type "
793+ f"{ type (list )} ."
794+ )
795+ list_ = list .split (',' )
796+ return [int (item ) for item in list_ ]
797+
798+
771799if __name__ == '__main__' :
772800 try : main ()
773801 except KeyboardInterrupt : pass
0 commit comments