Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f0ef9d8

Browse files
Use cuDNN for conv bias and bias grad (#20771)
* Use cuDNN for conv bias and bias grad
* Add environment variables to use native add-bias and bias-grad
* Handle 3D tensors in the cuDNN legacy API
* Fix AMP for ndarray.numpy
* Remove env vars that were used only for benchmarking

Co-authored-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
1 parent b555b54 commit f0ef9d8

7 files changed

Lines changed: 201 additions & 67 deletions

File tree

python/mxnet/amp/loss_scaler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ def has_overflow(self, params):
4646
"""Check gradients for overflow."""
4747
if is_np_array():
4848
all_finite_f = ndarray.numpy._internal.multi_all_finite
49-
ones_f = ndarray.numpy.ones
49+
ones_f = lambda ctx: ndarray.numpy.ones((1,), device=ctx)
5050
else:
5151
all_finite_f = ndarray.multi_all_finite
52-
ones_f = ndarray.ones
52+
ones_f = lambda ctx: ndarray.ones((1,), ctx=ctx)
5353
with ag.pause():
5454
chunk_size = 200
5555
valid_params = [p._grad[0] for p in params if p._grad is not None]
56-
gpu_output = ones_f((1,), ctx=valid_params[0].context)
56+
gpu_output = ones_f(valid_params[0].context)
5757
nb_params = len(valid_params)
5858
for idx in range(0, nb_params, chunk_size):
5959
all_finite_f(*valid_params[idx:idx+chunk_size],

src/common/cuda/cudnn_cxx.cc

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,6 @@ std::vector<Descriptor> GetSomeAttrs(size_t max_n,
112112
return ret;
113113
}
114114

115-
std::vector<int64_t> PackedStrides(const std::vector<size_t>& order,
116-
const std::vector<int64_t>& dims) {
117-
CHECK_EQ(order.size(), dims.size());
118-
std::vector<int64_t> ret(dims.size(), 1);
119-
for (size_t i = dims.size() - 1; i--;)
120-
ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
121-
return ret;
122-
}
123-
124115
std::vector<Descriptor> GetPlans(cudnnBackendHeurMode_t h_mode,
125116
cudnnHandle_t handle,
126117
const Descriptor& op_graph,

src/common/cuda/cudnn_cxx.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,14 @@ std::vector<Descriptor> GetSomeAttrs(size_t max_n,
244244
cudnnBackendDescriptorType_t type);
245245

246246
// Order sets layout, as a permutation of dims, with N,C,<spatial dims> being identity.
247-
std::vector<int64_t> PackedStrides(const std::vector<size_t>& order,
248-
const std::vector<int64_t>& dims);
247+
template <typename T>
248+
std::vector<T> PackedStrides(const std::vector<size_t>& order, const std::vector<T>& dims) {
249+
CHECK_EQ(order.size(), dims.size());
250+
std::vector<T> ret(dims.size(), 1);
251+
for (size_t i = dims.size() - 1; i--;)
252+
ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
253+
return ret;
254+
}
249255

250256
// Given an engine config's `notes`, return whether that config is compatible, i.e. does
251257
// the config have all of the required notes and none of the notes that are being excluded.

src/operator/cudnn_ops.cc

Lines changed: 105 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,10 @@
2929

3030
#include <dmlc/parameter.h>
3131

32-
#include <algorithm>
3332
#include <cstdlib>
3433
#include <iomanip>
3534
#include <iterator>
3635
#include <limits>
37-
#include <numeric>
3836
#include <sstream>
3937
#include <string>
4038
#include <utility>
@@ -79,10 +77,6 @@ size_t LayoutInfo::ChannelIdx() const {
7977
return channel_last ? 1 + n_space_dims : 1;
8078
}
8179

82-
std::vector<int64_t> LayoutInfo::Strides(const std::vector<int64_t>& dims) const {
83-
return PackedStrides(Order(), dims);
84-
}
85-
8680
LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout) {
8781
static std::unordered_map<mshadow::LayoutFlag, LayoutInfo> layout_map{
8882
{mshadow::kNCW, {1, false}},
@@ -165,14 +159,8 @@ Descriptor MakeTensorDesc(int64_t uid,
165159
for (size_t i = 0; i < dims.size(); ++i)
166160
dims[i] = blob.shape_[rev_order[i]];
167161
auto strides = li.Strides(dims);
168-
if (li.n_space_dims == 1 && expand_1d) {
169-
dims.insert(dims.begin() + 2, 1);
170-
std::vector<size_t> order(dims.size());
171-
std::iota(order.begin(), order.end(), 0);
172-
if (li.channel_last)
173-
std::rotate(order.begin() + 1, order.begin() + 2, order.end());
174-
strides = PackedStrides(order, dims);
175-
}
162+
if (expand_1d)
163+
li.ExpandIf1d(&dims, &strides);
176164
return MakeTensorDesc(
177165
uid, CudnnType(static_cast<mshadow::TypeFlag>(blob.type_flag_)), dims, strides, is_virtual);
178166
}
@@ -758,6 +746,109 @@ void ConvWgrad::Exec(const cudnn_cxx::Descriptor& plan,
758746
CUDNN_CALL(cudnnBackendExecute(s->dnn_handle_, plan.get(), var_pack.get()));
759747
}
760748

749+
struct LegacyTensorDestroyer {
750+
using pointer = cudnnTensorDescriptor_t;
751+
752+
void operator()(cudnnTensorDescriptor_t desc) {
753+
CUDNN_CALL_NONFATAL(cudnnDestroyTensorDescriptor(desc));
754+
}
755+
};
756+
757+
using LegacyTensor = std::unique_ptr<cudnnTensorDescriptor_t, LegacyTensorDestroyer>;
758+
759+
LegacyTensor MakeLegacyTensor() {
760+
cudnnTensorDescriptor_t desc{};
761+
CUDNN_CALL(cudnnCreateTensorDescriptor(&desc));
762+
return LegacyTensor(desc);
763+
}
764+
765+
union ScalingParam {
766+
double d;
767+
float f;
768+
};
769+
770+
std::pair<ScalingParam, ScalingParam> AlphaBeta(int type_flag, double init_a, double init_b) {
771+
ScalingParam a, b;
772+
switch (type_flag) {
773+
case kFloat64:
774+
a.d = init_a;
775+
b.d = init_b;
776+
break;
777+
case kFloat32: // fallthrough
778+
case kFloat16:
779+
a.f = init_a;
780+
b.f = init_b;
781+
break;
782+
default:
783+
LOG(FATAL) << "Unexpected type: " << type_flag;
784+
}
785+
return {a, b};
786+
}
787+
788+
void SetLegacyTensor(cudnnTensorDescriptor_t desc, const TBlob& blob, const LayoutInfo& li) {
789+
std::vector<int> dims(blob.shape_.ndim());
790+
CHECK_EQ(dims.size(), li.n_space_dims + 2);
791+
auto rev_order = ReverseOrder(li.Order());
792+
for (size_t i = 0; i < dims.size(); ++i)
793+
dims[i] = blob.shape_[rev_order[i]];
794+
auto strides = li.Strides(dims);
795+
li.ExpandIf1d(&dims, &strides);
796+
auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_);
797+
CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0]));
798+
}
799+
800+
void SetLegacyCTensorExpandDims(cudnnTensorDescriptor_t desc,
801+
const TBlob& blob,
802+
const LayoutInfo& li) {
803+
std::vector<int> dims(li.n_space_dims + 2, 1);
804+
dims[1] = blob.shape_[0];
805+
std::vector<int> strides(dims.size(), 1);
806+
strides[0] = blob.shape_[0];
807+
li.ExpandIf1d(&dims, &strides);
808+
auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_);
809+
CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0]));
810+
}
811+
812+
bool LegacyAddBias(const OpContext& ctx, const LayoutInfo& li, const TBlob& y, const TBlob& b) {
813+
thread_local auto y_desc = MakeLegacyTensor();
814+
thread_local auto b_desc = MakeLegacyTensor();
815+
816+
auto s = ctx.get_stream<gpu>();
817+
auto [alpha, beta] = AlphaBeta(y.type_flag_, 1.0, 1.0); // NOLINT(whitespace/braces)
818+
819+
SetLegacyTensor(y_desc.get(), y, li);
820+
SetLegacyCTensorExpandDims(b_desc.get(), b, li);
821+
822+
auto err =
823+
cudnnAddTensor(s->dnn_handle_, &alpha, b_desc.get(), b.dptr_, &beta, y_desc.get(), y.dptr_);
824+
if (err == CUDNN_STATUS_NOT_SUPPORTED)
825+
return false;
826+
CHECK_EQ(err, CUDNN_STATUS_SUCCESS);
827+
return true;
828+
}
829+
830+
bool LegacyBiasGrad(const OpContext& ctx,
831+
const LayoutInfo& li,
832+
bool add_to,
833+
const TBlob& db,
834+
const TBlob& dy) {
835+
thread_local auto db_desc = MakeLegacyTensor();
836+
thread_local auto dy_desc = MakeLegacyTensor();
837+
838+
auto s = ctx.get_stream<gpu>();
839+
auto [alpha, beta] = AlphaBeta(dy.type_flag_, 1.0, add_to ? 1.0 : 0.0); // NOLINT(*)
840+
841+
SetLegacyCTensorExpandDims(db_desc.get(), db, li);
842+
SetLegacyTensor(dy_desc.get(), dy, li);
843+
844+
auto err = cudnnConvolutionBackwardBias(
845+
s->dnn_handle_, &alpha, dy_desc.get(), dy.dptr_, &beta, db_desc.get(), db.dptr_);
846+
if (err == CUDNN_STATUS_NOT_SUPPORTED)
847+
return false;
848+
CHECK_EQ(err, CUDNN_STATUS_SUCCESS);
849+
return true;
850+
}
851+
761852
} // namespace cudnn
762853
} // namespace op
763854
} // namespace mxnet

src/operator/cudnn_ops.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929

3030
#include <mxnet/op_attr_types.h>
3131

32+
#include <algorithm>
3233
#include <mutex>
34+
#include <numeric>
3335
#include <tuple>
3436
#include <unordered_map>
3537
#include <utility>
@@ -89,7 +91,23 @@ struct LayoutInfo {
8991

9092
std::vector<size_t> Order() const;
9193
size_t ChannelIdx() const;
92-
std::vector<int64_t> Strides(const std::vector<int64_t>& dims) const;
94+
95+
template <typename T>
96+
std::vector<T> Strides(const std::vector<T>& dims) const {
97+
return cudnn_cxx::PackedStrides(Order(), dims);
98+
}
99+
100+
template <typename T>
101+
void ExpandIf1d(std::vector<T>* dims, std::vector<T>* strides) const {
102+
if (n_space_dims != 1)
103+
return;
104+
dims->insert(dims->begin() + 2, 1);
105+
std::vector<size_t> order(dims->size());
106+
std::iota(order.begin(), order.end(), 0);
107+
if (channel_last)
108+
std::rotate(order.begin() + 1, order.begin() + 2, order.end());
109+
*strides = cudnn_cxx::PackedStrides(order, *dims);
110+
}
93111
};
94112

95113
LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout);
@@ -246,6 +264,14 @@ struct ConvWgrad {
246264
const TBlob& dw);
247265
};
248266

267+
bool LegacyAddBias(const OpContext& ctx, const LayoutInfo& li, const TBlob& y, const TBlob& b);
268+
269+
bool LegacyBiasGrad(const OpContext& ctx,
270+
const LayoutInfo& li,
271+
bool add_to,
272+
const TBlob& db,
273+
const TBlob& dy);
274+
249275
} // namespace cudnn
250276
} // namespace op
251277
} // namespace mxnet

src/operator/nn/convolution.cu

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,18 @@ void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
5757
if (ok && !param.no_bias) {
5858
CHECK_EQ(inputs[conv::kBias].shape_.ndim(), 1);
5959
auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
60-
int k = inputs[conv::kBias].shape_.Size();
61-
auto b = inputs[conv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
62-
BinaryBroadcastRTCCompute{"add"}( // NOLINT(whitespace/braces)
63-
attrs,
64-
ctx,
65-
{outputs[conv::kOut], b},
66-
{kWriteInplace},
67-
{outputs[conv::kOut]});
60+
auto li = cudnn::GetLayoutInfo(layout);
61+
if (li.channel_last ||
62+
!cudnn::LegacyAddBias(ctx, li, outputs[conv::kOut], inputs[conv::kBias])) {
63+
int k = inputs[conv::kBias].shape_.Size();
64+
auto b = inputs[conv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
65+
BinaryBroadcastRTCCompute{"add"}( // NOLINT(whitespace/braces)
66+
attrs,
67+
ctx,
68+
{outputs[conv::kOut], b},
69+
{kWriteInplace},
70+
{outputs[conv::kOut]});
71+
}
6872
}
6973
if (!ok) {
7074
if (!param.cudnn_off)
@@ -137,17 +141,21 @@ void ConvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
137141
cudnn::Exec<cudnn::ConvWgrad>(
138142
ctx, conv_param, inputs[1 + conv::kData], inputs[0], outputs[conv::kWeight]));
139143
if (ok && !param.no_bias && req[conv::kBias] != kNullOp) {
140-
auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
141-
if (li.channel_last) {
142-
// This kernel should be faster.
143-
auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
144-
AddBiasGrad(outputs[conv::kBias], y_grad, req[conv::kBias], param.num_filter, ctx);
145-
} else {
146-
TShape axes{static_cast<int>(li.ChannelIdx())};
147-
TShape small =
148-
ReduceAxesShapeImpl(inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
149-
ReduceAxesRTCComputeImpl(
150-
ctx, {inputs[0]}, {req[conv::kBias]}, {outputs[conv::kBias]}, small, "red::sum{}");
144+
auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
145+
auto add_to = req[conv::kBias] == kAddTo;
146+
if (li.channel_last ||
147+
!cudnn::LegacyBiasGrad(ctx, li, add_to, outputs[conv::kBias], inputs[0])) {
148+
if (li.channel_last) {
149+
// This kernel should be faster.
150+
auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
151+
AddBiasGrad(outputs[conv::kBias], y_grad, req[conv::kBias], param.num_filter, ctx);
152+
} else {
153+
TShape axes{static_cast<int>(li.ChannelIdx())};
154+
TShape small = ReduceAxesShapeImpl(
155+
inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
156+
ReduceAxesRTCComputeImpl(
157+
ctx, {inputs[0]}, {req[conv::kBias]}, {outputs[conv::kBias]}, small, "red::sum{}");
158+
}
151159
}
152160
}
153161
if (!ok) {

src/operator/nn/deconvolution.cu

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,18 @@ void DeconvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
5656
if (ok && !param.no_bias) {
5757
CHECK_EQ(inputs[deconv::kBias].shape_.ndim(), 1);
5858
auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
59-
int k = inputs[deconv::kBias].shape_.Size();
60-
auto b = inputs[deconv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
61-
BinaryBroadcastRTCCompute{"add"}( // NOLINT(whitespace/braces)
62-
attrs,
63-
ctx,
64-
{outputs[deconv::kOut], b},
65-
{kWriteInplace},
66-
{outputs[deconv::kOut]});
59+
auto li = cudnn::GetLayoutInfo(layout);
60+
if (li.channel_last ||
61+
!cudnn::LegacyAddBias(ctx, li, outputs[deconv::kOut], inputs[deconv::kBias])) {
62+
int k = inputs[deconv::kBias].shape_.Size();
63+
auto b = inputs[deconv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
64+
BinaryBroadcastRTCCompute{"add"}( // NOLINT(whitespace/braces)
65+
attrs,
66+
ctx,
67+
{outputs[deconv::kOut], b},
68+
{kWriteInplace},
69+
{outputs[deconv::kOut]});
70+
}
6771
}
6872
if (!ok) {
6973
if (!param.cudnn_off)
@@ -115,17 +119,25 @@ void DeconvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
115119
cudnn::Exec<cudnn::ConvWgrad>(
116120
ctx, conv_param, inputs[0], inputs[1 + deconv::kData], outputs[deconv::kWeight]));
117121
if (ok && !param.no_bias && req[deconv::kBias] != kNullOp) {
118-
auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
119-
if (li.channel_last) {
120-
// This kernel should be faster.
121-
auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
122-
AddBiasGrad(outputs[deconv::kBias], y_grad, req[deconv::kBias], param.num_filter, ctx);
123-
} else {
124-
TShape axes{static_cast<int>(li.ChannelIdx())};
125-
TShape small =
126-
ReduceAxesShapeImpl(inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
127-
ReduceAxesRTCComputeImpl(
128-
ctx, {inputs[0]}, {req[deconv::kBias]}, {outputs[deconv::kBias]}, small, "red::sum{}");
122+
auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
123+
auto add_to = req[deconv::kBias] == kAddTo;
124+
if (li.channel_last ||
125+
!cudnn::LegacyBiasGrad(ctx, li, add_to, outputs[deconv::kBias], inputs[0])) {
126+
if (li.channel_last) {
127+
// This kernel should be faster.
128+
auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
129+
AddBiasGrad(outputs[deconv::kBias], y_grad, req[deconv::kBias], param.num_filter, ctx);
130+
} else {
131+
TShape axes{static_cast<int>(li.ChannelIdx())};
132+
TShape small = ReduceAxesShapeImpl(
133+
inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
134+
ReduceAxesRTCComputeImpl(ctx,
135+
{inputs[0]},
136+
{req[deconv::kBias]},
137+
{outputs[deconv::kBias]},
138+
small,
139+
"red::sum{}");
140+
}
129141
}
130142
}
131143
if (!ok) {

0 commit comments

Comments
 (0)