Skip to content

Commit fc6855d

Browse files
authored
Use scalar fast path in optimized layer_norm for small tensors (pytorch#18636)
Differential Revision: D98795281 Pull Request resolved: pytorch#18636
1 parent c58fc28 commit fc6855d

3 files changed

Lines changed: 77 additions & 36 deletions

File tree

kernels/optimized/cpu/op_native_layer_norm.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,24 @@ void layer_norm(
7272
const bool gamma_null = gamma_data == nullptr;
7373
const bool beta_null = beta_data == nullptr;
7474

75+
// For small normalized dimensions, fall back to the portable scalar
76+
// implementation since SIMD vectorization setup/tail-handling overhead
77+
// exceeds the benefit for small N.
78+
constexpr size_t kSmallNThreshold = 256;
79+
if (N < kSmallNThreshold) {
80+
layer_norm_scalar<CTYPE>(
81+
input_data,
82+
gamma_data,
83+
beta_data,
84+
out_data,
85+
mean_data,
86+
rstd_data,
87+
M,
88+
N,
89+
eps);
90+
return;
91+
}
92+
7593
for (size_t i = 0; i < M; ++i) {
7694
const CTYPE* src_ptr = input_data + i * N;
7795
CTYPE* dst_ptr = out_data + i * N;

kernels/portable/cpu/op_native_layer_norm.cpp

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include <c10/util/irange.h>
99

1010
#include <executorch/kernels/portable/cpu/util/normalization_ops_util.h>
11-
#include <executorch/kernels/portable/cpu/vec_ops.h>
1211
#include <executorch/runtime/kernel/kernel_includes.h>
1312
#include <cmath>
1413
#include <tuple>
@@ -54,41 +53,21 @@ void layer_norm(
5453
}
5554

5655
const CTYPE* input_data = input.const_data_ptr<CTYPE>();
57-
const CTYPE* weight_data;
58-
if (weight.has_value()) {
59-
weight_data = weight.value().const_data_ptr<CTYPE>();
60-
} else {
61-
weight_data = nullptr;
62-
}
63-
const CTYPE* bias_data;
64-
if (bias.has_value()) {
65-
bias_data = bias.value().const_data_ptr<CTYPE>();
66-
} else {
67-
bias_data = nullptr;
68-
}
69-
70-
const CTYPE ct_normalized = static_cast<CTYPE>(normalized);
71-
for (const auto i : c10::irange(leading)) {
72-
const CTYPE* x = input_data + i * normalized;
73-
CTYPE* y = out_data + i * normalized;
74-
75-
// compute E[X] and Var[x] = E[x^2] - E[x]^2
76-
float sum = reduce_add(x, ct_normalized);
77-
float sq_sum = vec_powerf(x, ct_normalized);
78-
float mean_value = sum / ct_normalized;
79-
float variance = sq_sum / ct_normalized - mean_value * mean_value;
80-
float std = std::sqrt(variance + eps);
81-
82-
// Calculate the elements of output
83-
for (const auto j : c10::irange(normalized)) {
84-
CTYPE w = weight_data ? weight_data[j] : static_cast<CTYPE>(1);
85-
CTYPE b = bias_data ? bias_data[j] : static_cast<CTYPE>(0);
86-
y[j] = (x[j] - mean_value) / std * w + b;
87-
}
88-
89-
mean_data[i] = mean_value;
90-
rstd_data[i] = 1.0 / std;
91-
}
56+
const CTYPE* weight_data =
57+
weight.has_value() ? weight.value().const_data_ptr<CTYPE>() : nullptr;
58+
const CTYPE* bias_data =
59+
bias.has_value() ? bias.value().const_data_ptr<CTYPE>() : nullptr;
60+
61+
layer_norm_scalar<CTYPE>(
62+
input_data,
63+
weight_data,
64+
bias_data,
65+
out_data,
66+
mean_data,
67+
rstd_data,
68+
leading,
69+
normalized,
70+
eps);
9271
}
9372

9473
} // namespace

kernels/portable/cpu/util/normalization_ops_util.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,54 @@
99
#pragma once
1010

1111
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <cmath>
13+
#include <numeric>
1214

1315
namespace torch {
1416
namespace executor {
1517

18+
/**
 * Scalar layer_norm computation over M rows of N elements each.
 * Computes mean/variance in float, normalizes with (x - mean) / std * gamma +
 * beta. Caller must handle M==0 and N==0 edge cases before calling.
 *
 * @param input_data  Pointer to M*N input elements, row-major.
 * @param weight_data Optional per-element scale (gamma); may be null (=> 1).
 * @param bias_data   Optional per-element shift (beta); may be null (=> 0).
 * @param out_data    Output buffer of M*N elements.
 * @param mean_data   Per-row mean output, M elements.
 * @param rstd_data   Per-row reciprocal-stddev output, M elements.
 * @param M           Number of rows (leading dimension).
 * @param N           Normalized dimension (elements per row).
 * @param eps         Value added to variance for numerical stability.
 */
template <typename CTYPE>
inline void layer_norm_scalar(
    const CTYPE* input_data,
    const CTYPE* weight_data, // nullable
    const CTYPE* bias_data, // nullable
    CTYPE* out_data,
    CTYPE* mean_data,
    CTYPE* rstd_data,
    size_t M,
    size_t N,
    float eps) {
  for (size_t i = 0; i < M; ++i) {
    const CTYPE* x = input_data + i * N;
    CTYPE* y = out_data + i * N;

    // Compute E[X] and Var[x] = E[x^2] - E[x]^2, accumulating in float to
    // limit rounding error for low-precision CTYPEs.
    float sum = std::accumulate(x, x + N, 0.0f);
    float sq_sum = 0;
    for (size_t j = 0; j < N; ++j) {
      sq_sum += static_cast<float>(x[j]) * x[j];
    }
    float mean_value = sum / N;
    float variance = sq_sum / N - mean_value * mean_value;
    // E[x^2] - E[x]^2 can come out slightly negative due to floating-point
    // cancellation (e.g. near-constant rows), which would make
    // sqrt(variance + eps) NaN for small eps. Clamp to zero.
    if (variance < 0.0f) {
      variance = 0.0f;
    }
    float std = std::sqrt(variance + eps);

    // Normalize each element, then apply the optional affine transform.
    for (size_t j = 0; j < N; ++j) {
      CTYPE w = weight_data ? weight_data[j] : static_cast<CTYPE>(1);
      CTYPE b = bias_data ? bias_data[j] : static_cast<CTYPE>(0);
      y[j] = (x[j] - mean_value) / std * w + b;
    }

    mean_data[i] = mean_value;
    // Take the reciprocal in float: a 1.0 literal would force a double
    // division only to narrow the result back down to CTYPE.
    rstd_data[i] = 1.0f / std;
  }
}
59+
1660
bool check_batch_norm_args(
1761
const Tensor& in,
1862
const std::optional<Tensor>& weight,

0 commit comments

Comments
 (0)