Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 82577bf

Browse files
committed
compute_predict
1 parent c631a59 commit 82577bf

4 files changed

Lines changed: 90 additions & 25 deletions

File tree

mljar/client/dataset.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import sys
66
import time
7+
import copy
78
from zipfile import ZipFile, ZIP_DEFLATED
89
from os.path import basename
910
from base import MljarHttpClient
@@ -45,6 +46,12 @@ def get_dataset(self, dataset_hid):
4546
logger.error('Dataset not found')
4647
return None
4748

49+
def delete_dataset(self, dataset_hid):
50+
'''
51+
Deletes dataset
52+
'''
53+
response = self.request("DELETE", '/'.join([self.url, dataset_hid]))
54+
return response.status_code == 204 or response.status_code == 200
4855

4956
def _prepare_data(self, X, y):
5057
'''
@@ -55,22 +62,23 @@ def _prepare_data(self, X, y):
5562
if isinstance(X, np.ndarray):
5663
cols = {}
5764
col_names = []
58-
for i in xrange(X.shape[1]):
65+
X_cpy = copy.deepcopy(X)
66+
for i in xrange(X_cpy.shape[1]):
5967
c = 'attribute_'+str(i+1)
60-
cols[c] = X[:,i]
68+
cols[c] = X_cpy[:,i]
6169
col_names += [c]
6270
if y is not None:
63-
cols['target'] = y
71+
cols['target'] = copy.deepcopy(y)
6472
col_names.append('target')
6573
data = pd.DataFrame(cols, columns=col_names)
6674
if isinstance(X, pd.DataFrame):
6775
if y is not None:
68-
data = X
69-
data['target'] = y
76+
data = copy.deepcopy(X)
77+
data['target'] = copy.deepcopy(y)
7078
# todo: add search for target like attributes and rename
7179
# "target", "class", "loss"
7280
else:
73-
data = X
81+
data = copy.deepcopy(X)
7482

7583
dataset_hash = str(make_hash(data))
7684
return data, dataset_hash

mljar/mljar.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ def predict(self, X):
286286

287287
if self.selected_algorithm is not None:
288288

289+
return Mljar.compute_prediction(X, self.selected_algorithm.hid, self.project.hid)
290+
'''
289291
# check if dataset exists in mljar; if not, upload dataset for prediction
290292
dataset = DatasetClient(self.project.hid).add_dataset_if_not_exists(X, y = None)
291293
@@ -317,10 +319,11 @@ def predict(self, X):
317319
logger.error('Sorry, there was some problem with computing prediction for your dataset. \
318320
Please login to mljar.com to your account and check details.')
319321
return None
322+
'''
320323

321324

322325
@staticmethod
323-
def compute_prediction(X, model_id, project_id):
326+
def compute_prediction(X, model_id, project_id, keep_dataset = False):
324327

325328

326329
# check if dataset exists in mljar; if not, upload dataset for prediction
@@ -344,6 +347,8 @@ def compute_prediction(X, model_id, project_id):
344347
if prediction is not None:
345348
pred = PredictionDownloadClient().download(prediction.hid)
346349
#sys.stdout.write('\r\n')
350+
if not keep_dataset:
351+
DatasetClient(project_id).delete_dataset(dataset.hid)
347352
return pred
348353

349354
#sys.stdout.write('\rFetch predictions: {0}%'.format(round(i/(total_checks*0.01))))

tests/dataset_client_test.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@ def setUp(self):
2323
df = pd.read_csv('tests/data/test_1.csv')
2424
cols = ['sepal length', 'sepal width', 'petal length', 'petal width']
2525
target = 'class'
26-
self.X = df[cols]
26+
self.X = df.loc[:,cols]
2727
self.y = df[target]
2828

2929
def tearDown(self):
3030
# clean
3131
self.project_client.delete_project(self.project.hid)
3232

33+
3334
def test_get_datasests(self):
3435
"""
3536
Get empty list of datasets in project.
@@ -117,3 +118,39 @@ def test_add_existing_dataset(self):
117118
# number of all datasets in project should be 1
118119
datasets = dc.get_datasets()
119120
self.assertEqual(len(datasets), init_datasets_cnt+1)
121+
122+
123+
def test_prepare_data_two_sources(self):
124+
dc = DatasetClient(self.project.hid)
125+
data_1, data_hash_1 = dc._prepare_data(self.X, self.y)
126+
data_2, data_hash_2 = dc._prepare_data(self.X, None)
127+
self.assertNotEqual(data_hash_1, data_hash_2)
128+
129+
130+
def test_prepare_data_two_sources_numpy(self):
131+
dc = DatasetClient(self.project.hid)
132+
data_1, data_hash_1 = dc._prepare_data(np.array(self.X), np.array(self.y))
133+
data_2, data_hash_2 = dc._prepare_data(np.array(self.X), None)
134+
self.assertNotEqual(data_hash_1, data_hash_2)
135+
136+
def test_create_and_delete(self):
137+
# setup dataset client
138+
dc = DatasetClient(self.project.hid)
139+
self.assertNotEqual(dc, None)
140+
# get initial number of datasets
141+
init_datasets_cnt = len(dc.get_datasets())
142+
# add dataset
143+
my_dataset_1 = dc.add_dataset_if_not_exists(self.X, self.y)
144+
my_dataset_2 = dc.add_dataset_if_not_exists(self.X, y = None)
145+
# get datasets
146+
datasets = dc.get_datasets()
147+
self.assertEqual(len(datasets), init_datasets_cnt+2)
148+
# delete added dataset
149+
dc.delete_dataset(my_dataset_1.hid)
150+
# check number of datasets
151+
datasets = dc.get_datasets()
152+
self.assertEqual(len(datasets), init_datasets_cnt+1)
153+
154+
155+
if __name__ == "__main__":
156+
unittest.main()

tests/mljar_test.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import time
99

1010
from mljar.client.project import ProjectClient
11+
from mljar.client.dataset import DatasetClient
1112
from project_based_test import ProjectBasedTest
1213
from mljar.exceptions import BadValueException, IncorrectInputDataException
1314
from mljar.utils import MLJAR_DEFAULT_TUNING_MODE
@@ -26,9 +27,10 @@ def setUp(self):
2627
self.X = df[cols]
2728
self.y = df[target]
2829

29-
#def tearDown(self):
30-
# # clean
31-
# ProjectBasedTest.clean_projects()
30+
def tearDown(self):
31+
# clean
32+
ProjectBasedTest.clean_projects()
33+
3234

3335
def mse(self, predictions, targets):
3436
predictions = np.array(predictions)
@@ -38,12 +40,10 @@ def mse(self, predictions, targets):
3840

3941

4042
def test_compute_prediction(self):
41-
'''
42-
Test the most common usage.
43-
'''
4443
model = Mljar(project = self.proj_title, experiment = self.expt_title,
45-
algorithms = ['rfc'], metric='logloss',
46-
validation_kfolds=3, tuning_mode='Normal')
44+
algorithms = ['rfc'], metric = 'logloss',
45+
validation_kfolds = 3, tuning_mode = 'Normal',
46+
single_algorithm_time_limit = 1)
4747
self.assertTrue(model is not None)
4848
# fit models and wait till all models are trained
4949
model.fit(X = self.X, y = self.y)
@@ -52,19 +52,29 @@ def test_compute_prediction(self):
5252
project_id = model.project.hid
5353
# get model id
5454
model_id = model.selected_algorithm.hid
55+
56+
dc = DatasetClient(project_id)
57+
init_datasets_cnt = len(dc.get_datasets())
5558
# compute predictions
5659
pred = Mljar.compute_prediction(self.X, model_id, project_id)
5760
# compute score
5861
score = self.mse(pred, self.y)
5962
self.assertTrue(score < 0.1)
63+
# check if dataset was removed
64+
self.assertEqual(init_datasets_cnt, len(dc.get_datasets()))
65+
# run predictions again, but keep dataset
66+
pred = Mljar.compute_prediction(self.X, model_id, project_id, keep_dataset = True)
67+
self.assertEqual(init_datasets_cnt+1, len(dc.get_datasets())) # should be one more
68+
6069

6170
def test_basic_usage(self):
6271
'''
6372
Test the most common usage.
6473
'''
6574
model = Mljar(project = self.proj_title, experiment = self.expt_title,
66-
algorithms = ['xgb'], metric='logloss',
67-
validation_kfolds=3, tuning_mode='Normal')
75+
algorithms = ['xgb'], metric = 'logloss',
76+
validation_kfolds = 3, tuning_mode = 'Normal',
77+
single_algorithm_time_limit = 1)
6878
self.assertTrue(model is not None)
6979
# fit models and wait till all models are trained
7080
model.fit(X = self.X, y = self.y)
@@ -97,7 +107,8 @@ def test_usage_with_train_split(self):
97107
Test usage with train split.
98108
'''
99109
model = Mljar(project = self.proj_title, experiment = self.expt_title,
100-
validation_train_split = 0.8, algorithms = ['xgb'], tuning_mode='Normal')
110+
validation_train_split = 0.8, algorithms = ['xgb'], tuning_mode='Normal',
111+
single_algorithm_time_limit=1)
101112
self.assertTrue(model is not None)
102113
# fit models and wait till all models are trained
103114
model.fit(X = self.X, y = self.y, wait_till_all_done = False)
@@ -117,7 +128,8 @@ def test_usage_with_validation_dataset(self):
117128
Test usage with validation dataset.
118129
'''
119130
model = Mljar(project = self.proj_title, experiment = self.expt_title,
120-
algorithms = ['xgb'], tuning_mode='Normal')
131+
algorithms = ['xgb'], tuning_mode='Normal',
132+
single_algorithm_time_limit = 1)
121133
self.assertTrue(model is not None)
122134
# load validation data
123135
df = pd.read_csv('tests/data/test_1_vald.csv')
@@ -174,7 +186,8 @@ def test_non_wait_fit(self):
174186
'''
175187
model = Mljar(project = self.proj_title, experiment = self.expt_title,
176188
algorithms = ['xgb'], metric='logloss',
177-
validation_kfolds=3, tuning_mode='Normal')
189+
validation_kfolds=3, tuning_mode='Normal',
190+
single_algorithm_time_limit = 1)
178191
self.assertTrue(model is not None)
179192
# fit models, just start computation and do not wait
180193
start_time = time.time()
@@ -211,8 +224,9 @@ def test_retrive_models(self):
211224
all models will be simply retrieved from existing project.
212225
'''
213226
model = Mljar(project = self.proj_title, experiment = self.expt_title,
214-
algorithms = ['xgb'], metric='logloss',
215-
validation_kfolds=3, tuning_mode='Normal')
227+
algorithms = ['xgb'], metric = 'logloss',
228+
validation_kfolds = 3, tuning_mode = 'Normal',
229+
single_algorithm_time_limit = 1)
216230
self.assertTrue(model is not None)
217231
# fit models and wait till all models are trained
218232
model.fit(X = self.X, y = self.y)
@@ -240,8 +254,9 @@ def test_retrive_models(self):
240254
# re-use project
241255
start_time = time.time()
242256
model_2 = Mljar(project = self.proj_title, experiment = self.expt_title,
243-
algorithms = ['xgb'], metric='logloss',
244-
validation_kfolds=3, tuning_mode='Normal')
257+
algorithms = ['xgb'], metric = 'logloss',
258+
validation_kfolds = 3, tuning_mode = 'Normal',
259+
single_algorithm_time_limit = 1)
245260
self.assertTrue(model_2 is not None)
246261
# re-use trained models
247262
model_2.fit(X = self.X, y = self.y)

0 commit comments

Comments
 (0)