Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 2ea0c08

Browse files
committed
add dataset_title in experiment
1 parent c53f071 commit 2ea0c08

3 files changed

Lines changed: 14 additions & 9 deletions

File tree

mljar/client/dataset.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _wait_till_all_datasets_are_valid(self):
110110

111111

112112

113-
def add_dataset_if_not_exists(self, X, y, title_prefix = 'dataset-'):
113+
def add_dataset_if_not_exists(self, X, y, title_prefix = 'dataset-', dataset_title = None):
114114
'''
115115
Checks if dataset already exists, if not it add dataset to project.
116116
'''
@@ -127,7 +127,7 @@ def add_dataset_if_not_exists(self, X, y, title_prefix = 'dataset-'):
127127
# dataset with specified hash does not exist
128128
if len(dataset_details) == 0:
129129
# add new dataset
130-
dataset_details = self.add_new_dataset(data, y, title_prefix)
130+
dataset_details = self.add_new_dataset(data, y, title_prefix, dataset_title)
131131
else:
132132
dataset_details = dataset_details[0]
133133

@@ -157,9 +157,12 @@ def _accept_dataset_column_usage(self, dataset_hid):
157157
return response.status_code == 200
158158

159159

160-
def add_new_dataset(self, data, y, title_prefix = 'dataset-'):
160+
def add_new_dataset(self, data, y, title_prefix = 'dataset-', dataset_title = None):
161161
logger.info('Add new dataset')
162-
title = title_prefix + str(uuid.uuid4())[:4] # set some random name
162+
if dataset_title is None:
163+
title = title_prefix + str(uuid.uuid4())[:4] # set some random name
164+
else:
165+
title = dataset_title
163166
file_path = '/tmp/dataset-'+ str(uuid.uuid4())[:8]+'.csv'
164167

165168
logger.info('Compress data before export')

mljar/mljar.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __init__(self, project,
122122
raise MljarException('Wrong validation_train_split parameter value, it should be in (0.05, 0.95) range.')
123123

124124

125-
def fit(self, X, y, validation_data = None, wait_till_all_done = True):
125+
def fit(self, X, y, validation_data = None, wait_till_all_done = True, dataset_title = None):
126126
'''
127127
Fit models with MLJAR engine.
128128
Args:
@@ -132,6 +132,8 @@ def fit(self, X, y, validation_data = None, wait_till_all_done = True):
132132
the k-fold CV or train split validation will be used.
133133
wait_till_all_done: The flag which decides if fit function will wait
134134
till experiment is done.
135+
dataset_title: The title of your dataset. It is optional. If missing the
136+
random title will be generated.
135137
'''
136138
self.wait_till_all_done = wait_till_all_done
137139
# check input data dimensions
@@ -141,12 +143,12 @@ def fit(self, X, y, validation_data = None, wait_till_all_done = True):
141143
raise IncorrectInputDataException('Sorry, there is a missmatch between X and y matrices shapes')
142144

143145
try:
144-
self._start_experiment(X, y, validation_data)
146+
self._start_experiment(X, y, validation_data, dataset_title)
145147
except Exception as e:
146148
print 'Ups, %s' % str(e)
147149

148150

149-
def _start_experiment(self, X, y, validation_data = None):
151+
def _start_experiment(self, X, y, validation_data = None, dataset_title = None):
150152

151153
# define project task
152154
self.project_task = 'bin_class' if len(np.unique(y)) == 2 else 'reg'
@@ -159,7 +161,7 @@ def _start_experiment(self, X, y, validation_data = None):
159161
# add a dataset to project
160162
#
161163
logger.info('MLJAR: add training dataset')
162-
self.dataset = DatasetClient(self.project.hid).add_dataset_if_not_exists(X, y, title_prefix = 'Training-')
164+
self.dataset = DatasetClient(self.project.hid).add_dataset_if_not_exists(X, y, title_prefix = 'Training-', dataset_title = dataset_title)
163165

164166
self.dataset_vald = None
165167
if validation_data is not None:

tests/mljar_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_compute_prediction(self):
4646
single_algorithm_time_limit = 1)
4747
self.assertTrue(model is not None)
4848
# fit models and wait till all models are trained
49-
model.fit(X = self.X, y = self.y)
49+
model.fit(X = self.X, y = self.y, dataset_title = 'My dataset')
5050

5151
# get project id
5252
project_id = model.project.hid

0 commit comments

Comments
 (0)