At this point, the script could technically already run in MLBench, but so far it doesn't report any metrics back to MLBench.
The PyTorch script reports loss to ``stdout``, but we can easily report the loss to MLBench as well. First we need to import the relevant MLBench functionality by adding the following line to the imports at the top of the file:
{% highlight python %}
from mlbench_core.utils import Tracker
from mlbench_core.evaluation.goals import task1_time_to_accuracy_goal
from mlbench_core.evaluation.pytorch.metrics import TopKAccuracy
from mlbench_core.controlflow.pytorch import validation_round
{% endhighlight %}

Then we can simply create a ``Tracker`` object and use it to report the loss, and add the metrics to track (``TopKAccuracy``). We add code to record the timing of individual steps with ``tracker.record_batch_step()``.

We have to tell the tracker that we're in the training loop by calling ``tracker.train()`` and that the epoch is done by calling ``tracker.epoch_end()``. The loss is recorded with ``tracker.record_loss()``.
Make sure to change ``default`` on line 12 to the namespace MLBench is running under in Kubernetes.
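To make the call sequence described above concrete, here is a minimal, self-contained sketch. The ``Tracker`` class below is only a stand-in that mimics the methods we call — the real one comes from ``mlbench_core.utils`` and takes constructor arguments (metrics, run id, rank) not shown here:

```python
# Stand-in Tracker that mimics the call sequence described above.
# The real Tracker lives in mlbench_core.utils; this stub only exists
# so the sketch runs without a Kubernetes cluster.
class Tracker:
    def __init__(self):
        self.losses = []

    def train(self):
        pass  # tell the tracker we are in the training loop

    def record_batch_step(self, name):
        pass  # record the timing of an individual step, e.g. "forward"

    def record_loss(self, loss, n):
        self.losses.append(loss)  # report the loss over n samples

    def epoch_end(self):
        pass  # tell the tracker the epoch is done


tracker = Tracker()
tracker.train()
for epoch in range(2):
    for batch_idx in range(3):
        tracker.record_batch_step("forward")
        tracker.record_loss(loss=0.5, n=32)  # one loss value per batch
    tracker.epoch_end()
```

The point is the ordering: ``train()`` once before the loop, ``record_batch_step()`` and ``record_loss()`` inside the batch loop, ``epoch_end()`` once per epoch.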
That's it. Now the training will report the loss of each worker back to the Dashboard and show it in a nice graph.
For the official tasks, we also need to report validation stats to the tracker and use the official validation code. Rename the current ``partition_dataset()`` method to ``partition_dataset_train()`` and add a new partition method to load the validation set:

{% highlight python linenos %}
def partition_dataset_val():
    """ Partitioning MNIST validation set"""
    dataset = datasets.MNIST(
        './data',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]))
    size = dist.get_world_size()
    bsz = int(128 / float(size))
    partition_sizes = [1.0 / size for _ in range(size)]
{% endhighlight %}
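The last lines of the method determine the per-worker batch size and equal partition fractions: the global batch size of 128 is split evenly across workers. For example, with 4 workers each one gets a batch size of 32 and a quarter of the data. ``split_batch`` below is a hypothetical helper that only mirrors this arithmetic, not part of the tutorial code:

```python
def split_batch(global_bsz, world_size):
    """Per-worker batch size and equal partition fractions,
    mirroring the computation in partition_dataset_val()."""
    bsz = int(global_bsz / float(world_size))
    partition_sizes = [1.0 / world_size for _ in range(world_size)]
    return bsz, partition_sizes

bsz, sizes = split_batch(128, 4)
# bsz == 32, sizes == [0.25, 0.25, 0.25, 0.25]
```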
Now all that's needed is to add the validation loop (``validation_round()``) to the ``run()`` function. We also check whether the goal is reached and stop training if it is.
``validation_round()`` evaluates the metrics on the validation set and reports the results to the Dashboard.
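The stop-on-goal check can be sketched as follows; the accuracy values and the 97% target are made-up placeholders standing in for what the validation round would actually report:

```python
goal_accuracy = 97.0  # placeholder target, not the official task goal value
reached_at = None
for epoch in range(10):
    val_accuracy = 90.0 + epoch  # stand-in for the validation result
    if val_accuracy >= goal_accuracy:
        reached_at = epoch  # goal reached: stop training early
        break
```

Stopping as soon as the goal is met is what makes the official tasks time-to-accuracy benchmarks rather than fixed-epoch runs.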
The full code (with some additional improvements) is in our [GitHub repo](https://github.com/mlbench/mlbench-benchmarks/blob/master/examples/mlbench-pytorch-tutorial/).