This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Add Reduce Operation to the Computation Graph

Introduction

This is part of the work of porting DeepSpeed's approach into MXNet. Because of the difference between symbolic and imperative execution, we divide the whole process into two phases:

Phase 1: add a reduce operation to the graph. The reduce operation does nothing in the forward pass, but in the backward pass it reduces the gradient to the correct GPU (according to the POS_Trainer partition).

Phase 2: in the backward graph, delete the reduced outputs from the gradient arrays so the memory planner can reuse that memory.
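Conceptually, the inserted op is an identity in the forward pass, while its backward sums the gradient copies from all GPUs onto the owning one. A minimal pure-Python sketch of that semantics (the worker list and owner index here are illustrative, not the real NCCL kernel):

```python
def reduce_op_forward(x):
    # Forward pass: identity -- the op does not change activations.
    return x

def reduce_op_backward(grads_per_gpu, owner):
    # Backward pass: sum the gradient copies from every GPU and keep
    # the result only on the owning GPU (per the POS partition).
    total = sum(grads_per_gpu)
    return [total if rank == owner else None
            for rank in range(len(grads_per_gpu))]

# Simulated gradient copies of one parameter on 4 GPUs:
reduced = reduce_op_backward([1.0, 2.0, 3.0, 4.0], owner=2)
# reduced == [None, None, 10.0, None]
```

Phase 2 then corresponds to the `None` entries: the non-owning copies are dropped so their storage can be reused.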

Getting started

Prepare NCCL and Horovod

We use Horovod for communication and NCCL for the reduce operation, so please install both first.

Compile and load the graph pass

First compile it like the lib_pass example: run make, which generates the dynamic library add_reduce_op_lib.so from add_reduce_op.cc. Then load the library in your Python code:

import mxnet as mx
mx.library.load('add_reduce_op_lib.so')

Prepare options

Next, we need to know the correct partitioning of parameters and gradients across GPUs. Please use POS_Trainer from pos_trainer.py in place of the normal MXNet trainer:

from pos_trainer import POS_Trainer
trainer = POS_Trainer(params_dict, "adam", optimizer_params)

The trainer can then generate the corresponding options:

options = trainer.generate_graph_pass_options()
backward_options = trainer.generate_backward_options()
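The options are essentially a mapping from each parameter or gradient to the rank that owns its optimizer state. As a hedged sketch (the real key format lives in pos_trainer.py; this round-robin helper is hypothetical), such a partition could be built like this:

```python
def partition_params(param_names, num_gpus):
    # Round-robin assignment of each parameter to an owning GPU,
    # as a partitioned-optimizer-state trainer might do (illustrative).
    return {name: i % num_gpus for i, name in enumerate(param_names)}

options = partition_params(["w0", "b0", "w1", "b1"], num_gpus=2)
# options == {"w0": 0, "b0": 1, "w1": 0, "b1": 1}
```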

Modify the graph

Before the forward pass, we call

model.optimize_for(x, backend = "add_reduce_op", **options)

to insert the reduce operations into the graph.
(Figure: example graph after adding the reduce op.)

Then we call backward with the backward options:

loss.backward(backward_option = backward_options)
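Putting the two phases together, the intended per-step data flow can be simulated in plain Python (two fake GPUs with one gradient dict each; all names here are illustrative, not the real API):

```python
def training_step(local_grads, owners):
    # local_grads: list over GPUs of {param_name: grad} dicts.
    # owners: {param_name: owning GPU rank}, from the POS partition.
    num_gpus = len(local_grads)
    reduced = [dict() for _ in range(num_gpus)]
    for name, owner in owners.items():
        # Phase 1: NCCL-style reduce -- sum each parameter's gradient
        # copies and place the result only on the owning GPU.
        reduced[owner][name] = sum(g[name] for g in local_grads)
        # Phase 2: non-owners simply do not keep a copy, so the memory
        # planner can reuse that storage (modeled here by omission).
    return reduced

grads = [{"w": 1.0, "b": 0.5}, {"w": 3.0, "b": 1.5}]
out = training_step(grads, owners={"w": 0, "b": 1})
# out == [{"w": 4.0}, {"b": 2.0}]
```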

Simple Example

Please see test_reduce.py

Current problems

  1. The reduce operation can cause a deadlock (this does not happen with NaiveEngine). Moreover, it runs into an invalid-address problem in complex models such as BERT-Base.
  2. We do remove the outputs from the backward graph via the backward options, but we still need to verify whether this actually decreases memory consumption.