@@ -169,7 +169,7 @@ def main() -> None:
169169 parser .add_argument ("input" , help = "ppim_combinatorics CSV file" )
170170 parser .add_argument ("-o" , "--output" , default = "." , help = "Output directory for model & plots" )
171171 parser .add_argument ("--epochs" , type = int , default = 50 , help = "Training epochs (default: 50)" )
172- parser .add_argument ("--batch-size" , type = int , default = 256 , help = "Batch size (default: 256)" )
172+ parser .add_argument ("--batch-size" , type = int , default = 32 , help = "Batch size (default: 256)" )
173173 parser .add_argument ("--seed" , type = int , default = 42 , help = "Random seed (default: 42)" )
174174 args = parser .parse_args ()
175175
@@ -179,9 +179,7 @@ def main() -> None:
179179 print (f"Keras version : { keras .__version__ } " )
180180
181181 # --- Data -----------------------------------------------------------------
182- X_train , X_test , y_train , y_test , mean , std = load_data (
183- args .input , seed = args .seed ,
184- )
182+ X_train , X_test , y_train , y_test , mean , std = load_data (args .input , seed = args .seed ,)
185183 print (f"Train samples : { len (y_train )} | Test samples : { len (y_test )} " )
186184
187185 # --- Class weights --------------------------------------------------------
0 commit comments