88import time
99
1010from mljar .client .project import ProjectClient
11+ from mljar .client .dataset import DatasetClient
1112from project_based_test import ProjectBasedTest
1213from mljar .exceptions import BadValueException , IncorrectInputDataException
1314from mljar .utils import MLJAR_DEFAULT_TUNING_MODE
@@ -26,9 +27,10 @@ def setUp(self):
2627 self .X = df [cols ]
2728 self .y = df [target ]
2829
29- #def tearDown(self):
30- # # clean
31- # ProjectBasedTest.clean_projects()
def tearDown(self):
    '''
    Clean up after a test by delegating to ProjectBasedTest.clean_projects.

    NOTE(review): presumably this removes the projects the test created on
    the remote service -- confirm against ProjectBasedTest.
    '''
    ProjectBasedTest.clean_projects()
33+
3234
3335 def mse (self , predictions , targets ):
3436 predictions = np .array (predictions )
@@ -38,12 +40,10 @@ def mse(self, predictions, targets):
3840
3941
4042 def test_compute_prediction (self ):
41- '''
42- Test the most common usage.
43- '''
4443 model = Mljar (project = self .proj_title , experiment = self .expt_title ,
45- algorithms = ['rfc' ], metric = 'logloss' ,
46- validation_kfolds = 3 , tuning_mode = 'Normal' )
44+ algorithms = ['rfc' ], metric = 'logloss' ,
45+ validation_kfolds = 3 , tuning_mode = 'Normal' ,
46+ single_algorithm_time_limit = 1 )
4747 self .assertTrue (model is not None )
4848 # fit models and wait till all models are trained
4949 model .fit (X = self .X , y = self .y )
@@ -52,19 +52,29 @@ def test_compute_prediction(self):
5252 project_id = model .project .hid
5353 # get model id
5454 model_id = model .selected_algorithm .hid
55+
56+ dc = DatasetClient (project_id )
57+ init_datasets_cnt = len (dc .get_datasets ())
5558 # compute predictions
5659 pred = Mljar .compute_prediction (self .X , model_id , project_id )
5760 # compute score
5861 score = self .mse (pred , self .y )
5962 self .assertTrue (score < 0.1 )
63+ # check if dataset was removed
64+ self .assertEqual (init_datasets_cnt , len (dc .get_datasets ()))
65+ # run predictions again, but keep dataset
66+ pred = Mljar .compute_prediction (self .X , model_id , project_id , keep_dataset = True )
67+ self .assertEqual (init_datasets_cnt + 1 , len (dc .get_datasets ())) # should be one more
68+
6069
6170 def test_basic_usage (self ):
6271 '''
6372 Test the most common usage.
6473 '''
6574 model = Mljar (project = self .proj_title , experiment = self .expt_title ,
66- algorithms = ['xgb' ], metric = 'logloss' ,
67- validation_kfolds = 3 , tuning_mode = 'Normal' )
75+ algorithms = ['xgb' ], metric = 'logloss' ,
76+ validation_kfolds = 3 , tuning_mode = 'Normal' ,
77+ single_algorithm_time_limit = 1 )
6878 self .assertTrue (model is not None )
6979 # fit models and wait till all models are trained
7080 model .fit (X = self .X , y = self .y )
@@ -97,7 +107,8 @@ def test_usage_with_train_split(self):
97107 Test usage with train split.
98108 '''
99109 model = Mljar (project = self .proj_title , experiment = self .expt_title ,
100- validation_train_split = 0.8 , algorithms = ['xgb' ], tuning_mode = 'Normal' )
110+ validation_train_split = 0.8 , algorithms = ['xgb' ], tuning_mode = 'Normal' ,
111+ single_algorithm_time_limit = 1 )
101112 self .assertTrue (model is not None )
102113 # fit models and wait till all models are trained
103114 model .fit (X = self .X , y = self .y , wait_till_all_done = False )
@@ -117,7 +128,8 @@ def test_usage_with_validation_dataset(self):
117128 Test usage with validation dataset.
118129 '''
119130 model = Mljar (project = self .proj_title , experiment = self .expt_title ,
120- algorithms = ['xgb' ], tuning_mode = 'Normal' )
131+ algorithms = ['xgb' ], tuning_mode = 'Normal' ,
132+ single_algorithm_time_limit = 1 )
121133 self .assertTrue (model is not None )
122134 # load validation data
123135 df = pd .read_csv ('tests/data/test_1_vald.csv' )
@@ -174,7 +186,8 @@ def test_non_wait_fit(self):
174186 '''
175187 model = Mljar (project = self .proj_title , experiment = self .expt_title ,
176188 algorithms = ['xgb' ], metric = 'logloss' ,
177- validation_kfolds = 3 , tuning_mode = 'Normal' )
189+ validation_kfolds = 3 , tuning_mode = 'Normal' ,
190+ single_algorithm_time_limit = 1 )
178191 self .assertTrue (model is not None )
179192 # fit models, just start computation and do not wait
180193 start_time = time .time ()
@@ -211,8 +224,9 @@ def test_retrive_models(self):
211224 all models will be simply retrieved from existing project.
212225 '''
213226 model = Mljar (project = self .proj_title , experiment = self .expt_title ,
214- algorithms = ['xgb' ], metric = 'logloss' ,
215- validation_kfolds = 3 , tuning_mode = 'Normal' )
227+ algorithms = ['xgb' ], metric = 'logloss' ,
228+ validation_kfolds = 3 , tuning_mode = 'Normal' ,
229+ single_algorithm_time_limit = 1 )
216230 self .assertTrue (model is not None )
217231 # fit models and wait till all models are trained
218232 model .fit (X = self .X , y = self .y )
@@ -240,8 +254,9 @@ def test_retrive_models(self):
240254 # re-use project
241255 start_time = time .time ()
242256 model_2 = Mljar (project = self .proj_title , experiment = self .expt_title ,
243- algorithms = ['xgb' ], metric = 'logloss' ,
244- validation_kfolds = 3 , tuning_mode = 'Normal' )
257+ algorithms = ['xgb' ], metric = 'logloss' ,
258+ validation_kfolds = 3 , tuning_mode = 'Normal' ,
259+ single_algorithm_time_limit = 1 )
245260 self .assertTrue (model_2 is not None )
246261 # re-use trained models
247262 model_2 .fit (X = self .X , y = self .y )
0 commit comments