@@ -128,22 +128,22 @@ def filter_predictions(all_preds_file, partition, out_file):
128128
129129
130130 if partition == "trn" :
131- df_preds_filt = df_preds_trn_sites [(df_preds_trn_sites .date >= config ['train_start_date' ][ 0 ] ) &
132- (df_preds_trn_sites .date < config ['train_end_date' ][ 0 ] )]
131+ df_preds_filt = df_preds_trn_sites [(df_preds_trn_sites .date >= config ['train_start_date' ]) &
132+ (df_preds_trn_sites .date < config ['train_end_date' ])]
133133 elif partition == "val" :
134134 # get all of the data in the validation sites and in the validation period
135135 # this assumes that the test period follows the validation period which follows the train period
136- df_preds_filt_val = df_preds_val_sites [df_preds_val_sites .date < config ['test_start_date' ][ 0 ] ]
137- df_preds_filt_trn = df_preds_trn_sites [(df_preds_trn_sites .date < config ['val_end_date' ][ 0 ] ) &
138- (df_preds_trn_sites .date >= config ['val_start_date' ][ 0 ] )]
136+ df_preds_filt_val = df_preds_val_sites [df_preds_val_sites .date < config ['test_start_date' ]]
137+ df_preds_filt_trn = df_preds_trn_sites [(df_preds_trn_sites .date < config ['val_end_date' ]) &
138+ (df_preds_trn_sites .date >= config ['val_start_date' ])]
139139 df_preds_filt = pd .concat ([df_preds_filt_val , df_preds_filt_trn ], axis = 0 )
140140
141141 elif partition == "val_times" :
142142 # get the data in just the validation times at train and val sites
143- df_preds_filt_val = df_preds_val_sites [(df_preds_val_sites .date < config ['val_end_date' ][ 0 ] ) &
144- (df_preds_val_sites .date >= config ['val_start_date' ][ 0 ] )]
145- df_preds_filt_trn = df_preds_trn_sites [(df_preds_trn_sites .date < config ['val_end_date' ][ 0 ] ) &
146- (df_preds_trn_sites .date >= config ['val_start_date' ][ 0 ] )]
143+ df_preds_filt_val = df_preds_val_sites [(df_preds_val_sites .date < config ['val_end_date' ]) &
144+ (df_preds_val_sites .date >= config ['val_start_date' ])]
145+ df_preds_filt_trn = df_preds_trn_sites [(df_preds_trn_sites .date < config ['val_end_date' ]) &
146+ (df_preds_trn_sites .date >= config ['val_start_date' ])]
147147 df_preds_filt = pd .concat ([df_preds_filt_val , df_preds_filt_trn ], axis = 0 )
148148
149149
0 commit comments