Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 49edbfc

Browse files
authored
fix safe acc in the general case of layer norm (#19806)
1 parent 3122423 commit 49edbfc

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

src/operator/nn/layer_norm-inl.h

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -124,7 +124,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
124124
// Calculate mean
125125
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
126126
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
127-
if (safe_acc) {
127+
if (!safe_acc) {
128128
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
129129
s, mean_data, req[0], workspace, in_data);
130130
} else {
@@ -149,7 +149,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
149149
const TBlob centered_out = outputs[0].reshape(red_src_shape);
150150
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
151151
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
152-
if (safe_acc) {
152+
if (!safe_acc) {
153153
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, false>(
154154
s, std_data, req[0], workspace, centered_out);
155155
} else {
@@ -290,7 +290,7 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
290290
if (req[2] != kNullOp) {
291291
MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
292292
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
293-
if (safe_acc) {
293+
if (!safe_acc) {
294294
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
295295
s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
296296
ograd.reshape(red_exclude_src_shape));
@@ -313,7 +313,7 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
313313
if (req[1] != kNullOp) {
314314
MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
315315
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
316-
if (safe_acc) {
316+
if (!safe_acc) {
317317
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
318318
s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
319319
ograd_mult.reshape(red_exclude_src_shape));
@@ -347,7 +347,7 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
347347
#endif // !defined(__CUDACC__)
348348
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
349349
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
350-
if (safe_acc) {
350+
if (!safe_acc) {
351351
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
352352
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
353353
ograd_mult.reshape(red_src_shape));
@@ -375,7 +375,7 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
375375
#endif // !defined(__CUDACC__)
376376
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
377377
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
378-
if (safe_acc) {
378+
if (!safe_acc) {
379379
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
380380
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
381381
ograd_mult.reshape(red_src_shape));

0 commit comments

Comments
 (0)