|
29 | 29 |
|
30 | 30 | #include <dmlc/parameter.h> |
31 | 31 |
|
32 | | -#include <algorithm> |
33 | 32 | #include <cstdlib> |
34 | 33 | #include <iomanip> |
35 | 34 | #include <iterator> |
36 | 35 | #include <limits> |
37 | | -#include <numeric> |
38 | 36 | #include <sstream> |
39 | 37 | #include <string> |
40 | 38 | #include <utility> |
@@ -79,10 +77,6 @@ size_t LayoutInfo::ChannelIdx() const { |
79 | 77 | return channel_last ? 1 + n_space_dims : 1; |
80 | 78 | } |
81 | 79 |
|
82 | | -std::vector<int64_t> LayoutInfo::Strides(const std::vector<int64_t>& dims) const { |
83 | | - return PackedStrides(Order(), dims); |
84 | | -} |
85 | | - |
86 | 80 | LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout) { |
87 | 81 | static std::unordered_map<mshadow::LayoutFlag, LayoutInfo> layout_map{ |
88 | 82 | {mshadow::kNCW, {1, false}}, |
@@ -165,14 +159,8 @@ Descriptor MakeTensorDesc(int64_t uid, |
165 | 159 | for (size_t i = 0; i < dims.size(); ++i) |
166 | 160 | dims[i] = blob.shape_[rev_order[i]]; |
167 | 161 | auto strides = li.Strides(dims); |
168 | | - if (li.n_space_dims == 1 && expand_1d) { |
169 | | - dims.insert(dims.begin() + 2, 1); |
170 | | - std::vector<size_t> order(dims.size()); |
171 | | - std::iota(order.begin(), order.end(), 0); |
172 | | - if (li.channel_last) |
173 | | - std::rotate(order.begin() + 1, order.begin() + 2, order.end()); |
174 | | - strides = PackedStrides(order, dims); |
175 | | - } |
| 162 | + if (expand_1d) |
| 163 | + li.ExpandIf1d(&dims, &strides); |
176 | 164 | return MakeTensorDesc( |
177 | 165 | uid, CudnnType(static_cast<mshadow::TypeFlag>(blob.type_flag_)), dims, strides, is_virtual); |
178 | 166 | } |
@@ -758,6 +746,109 @@ void ConvWgrad::Exec(const cudnn_cxx::Descriptor& plan, |
758 | 746 | CUDNN_CALL(cudnnBackendExecute(s->dnn_handle_, plan.get(), var_pack.get())); |
759 | 747 | } |
760 | 748 |
|
| 749 | +struct LegacyTensorDestroyer { |
| 750 | + using pointer = cudnnTensorDescriptor_t; |
| 751 | + |
| 752 | + void operator()(cudnnTensorDescriptor_t desc) { |
| 753 | + CUDNN_CALL_NONFATAL(cudnnDestroyTensorDescriptor(desc)); |
| 754 | + } |
| 755 | +}; |
| 756 | + |
| 757 | +using LegacyTensor = std::unique_ptr<cudnnTensorDescriptor_t, LegacyTensorDestroyer>; |
| 758 | + |
| 759 | +LegacyTensor MakeLegacyTensor() { |
| 760 | + cudnnTensorDescriptor_t desc{}; |
| 761 | + CUDNN_CALL(cudnnCreateTensorDescriptor(&desc)); |
| 762 | + return LegacyTensor(desc); |
| 763 | +} |
| 764 | + |
// Scaling factor (alpha/beta) passed to cuDNN calls by address as const void*.
// The active member must match the tensor's compute type: `d` for double
// tensors, `f` for float and half tensors (see AlphaBeta below).
union ScalingParam {
  float f;
  double d;
};
| 769 | + |
| 770 | +std::pair<ScalingParam, ScalingParam> AlphaBeta(int type_flag, double init_a, double init_b) { |
| 771 | + ScalingParam a, b; |
| 772 | + switch (type_flag) { |
| 773 | + case kFloat64: |
| 774 | + a.d = init_a; |
| 775 | + b.d = init_b; |
| 776 | + break; |
| 777 | + case kFloat32: // fallthrough |
| 778 | + case kFloat16: |
| 779 | + a.f = init_a; |
| 780 | + b.f = init_b; |
| 781 | + break; |
| 782 | + default: |
| 783 | + LOG(FATAL) << "Unexpected type: " << type_flag; |
| 784 | + } |
| 785 | + return {a, b}; |
| 786 | +} |
| 787 | + |
| 788 | +void SetLegacyTensor(cudnnTensorDescriptor_t desc, const TBlob& blob, const LayoutInfo& li) { |
| 789 | + std::vector<int> dims(blob.shape_.ndim()); |
| 790 | + CHECK_EQ(dims.size(), li.n_space_dims + 2); |
| 791 | + auto rev_order = ReverseOrder(li.Order()); |
| 792 | + for (size_t i = 0; i < dims.size(); ++i) |
| 793 | + dims[i] = blob.shape_[rev_order[i]]; |
| 794 | + auto strides = li.Strides(dims); |
| 795 | + li.ExpandIf1d(&dims, &strides); |
| 796 | + auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_); |
| 797 | + CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0])); |
| 798 | +} |
| 799 | + |
| 800 | +void SetLegacyCTensorExpandDims(cudnnTensorDescriptor_t desc, |
| 801 | + const TBlob& blob, |
| 802 | + const LayoutInfo& li) { |
| 803 | + std::vector<int> dims(li.n_space_dims + 2, 1); |
| 804 | + dims[1] = blob.shape_[0]; |
| 805 | + std::vector<int> strides(dims.size(), 1); |
| 806 | + strides[0] = blob.shape_[0]; |
| 807 | + li.ExpandIf1d(&dims, &strides); |
| 808 | + auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_); |
| 809 | + CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0])); |
| 810 | +} |
| 811 | + |
| 812 | +bool LegacyAddBias(const OpContext& ctx, const LayoutInfo& li, const TBlob& y, const TBlob& b) { |
| 813 | + thread_local auto y_desc = MakeLegacyTensor(); |
| 814 | + thread_local auto b_desc = MakeLegacyTensor(); |
| 815 | + |
| 816 | + auto s = ctx.get_stream<gpu>(); |
| 817 | + auto [alpha, beta] = AlphaBeta(y.type_flag_, 1.0, 1.0); // NOLINT(whitespace/braces) |
| 818 | + |
| 819 | + SetLegacyTensor(y_desc.get(), y, li); |
| 820 | + SetLegacyCTensorExpandDims(b_desc.get(), b, li); |
| 821 | + |
| 822 | + auto err = |
| 823 | + cudnnAddTensor(s->dnn_handle_, &alpha, b_desc.get(), b.dptr_, &beta, y_desc.get(), y.dptr_); |
| 824 | + if (err == CUDNN_STATUS_NOT_SUPPORTED) |
| 825 | + return false; |
| 826 | + CHECK_EQ(err, CUDNN_STATUS_SUCCESS); |
| 827 | + return true; |
| 828 | +} |
| 829 | + |
| 830 | +bool LegacyBiasGrad(const OpContext& ctx, |
| 831 | + const LayoutInfo& li, |
| 832 | + bool add_to, |
| 833 | + const TBlob& db, |
| 834 | + const TBlob& dy) { |
| 835 | + thread_local auto db_desc = MakeLegacyTensor(); |
| 836 | + thread_local auto dy_desc = MakeLegacyTensor(); |
| 837 | + |
| 838 | + auto s = ctx.get_stream<gpu>(); |
| 839 | + auto [alpha, beta] = AlphaBeta(dy.type_flag_, 1.0, add_to ? 1.0 : 0.0); // NOLINT(*) |
| 840 | + |
| 841 | + SetLegacyCTensorExpandDims(db_desc.get(), db, li); |
| 842 | + SetLegacyTensor(dy_desc.get(), dy, li); |
| 843 | + |
| 844 | + auto err = cudnnConvolutionBackwardBias( |
| 845 | + s->dnn_handle_, &alpha, dy_desc.get(), dy.dptr_, &beta, db_desc.get(), db.dptr_); |
| 846 | + if (err == CUDNN_STATUS_NOT_SUPPORTED) |
| 847 | + return false; |
| 848 | + CHECK_EQ(err, CUDNN_STATUS_SUCCESS); |
| 849 | + return true; |
| 850 | +} |
| 851 | + |
761 | 852 | } // namespace cudnn |
762 | 853 | } // namespace op |
763 | 854 | } // namespace mxnet |
|
0 commit comments