-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathmm.cc
More file actions
512 lines (438 loc) · 19.4 KB
/
mm.cc
File metadata and controls
512 lines (438 loc) · 19.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#define NOCPP
#include <stdio.h>
#include <stdlib.h>
#define REL_WRITE 0
#define REL_READ 1
#include "zero.cc"
#include <aie_api/aie.hpp>
/* Scalar reference matrix multiplication: C += A * B.
 *
 * A is a rowA x colA matrix read in row-major order, B is colA x colB and C
 * is rowA x colB. B is read row-major when b_row_maj is true and column-major
 * otherwise; C is read/written row-major when c_row_maj is true and
 * column-major otherwise. Each dot product is accumulated in T_out and added
 * to the value already stored in C. event0()/event1() bracket the work.
 */
template <typename T_in, typename T_out, int rowA, int colA, int colB, bool b_row_maj = true, bool c_row_maj = true>
static inline void matmul_scalar(T_in *a, T_in *b, T_out *c)
{
    event0();
    for (int m_idx = 0; m_idx < rowA; ++m_idx) {
        const T_in *a_row = a + m_idx * colA;
        for (int n_idx = 0; n_idx < colB; ++n_idx) {
            T_out dot = 0;
            for (int k_idx = 0; k_idx < colA; ++k_idx) {
                // Element (k_idx, n_idx) of B in the selected layout.
                const int b_offset = b_row_maj ? (k_idx * colB + n_idx) : (k_idx + n_idx * colA);
                const T_in lhs = a_row[k_idx];
                const T_in rhs = b[b_offset];
                dot += lhs * rhs;
            }
            // Element (m_idx, n_idx) of C in the selected layout.
            const int c_offset = c_row_maj ? (m_idx * colB + n_idx) : (m_idx + n_idx * rowA);
            c[c_offset] += dot;
        }
    }
    event1();
}
/* Blocked MatMul kernel (vectorized) utilizing the aie::mmul class.
 * The matrices are assumed to be pre-tiled with the following shapes
 * for the aie::mmul class: A => rxs, B => sxt, C => rxt.
 *
 * The matrix dimensions of the kernel are defined by rowA, colA and colB,
 * expressed in units of mmul tiles: A is rowA x colA tiles, B is colA x colB
 * tiles and C is rowA x colB tiles. In this particular kernel we expand the
 * aie::mmul two times in each input matrix A (in the 'm' dimension, or rowA)
 * and B (in the 'n' dimension, or colB), leading to a 2x2 expansion in output
 * matrix C (see C00, C01, C10, C11 below). This expansion helps with
 * accumulator register usage, which helps attain high kernel efficiency
 * (SIMD utilization). As a consequence rowA and colB must both be even.
 *
 * Data within each tile (rxs, sxt and rxt) are assumed to be in row-major
 * order. Also, the entire tiles themselves are stored in row-major order, as
 * shown in the example below for matrix A:
 *
 * <-s->
 * _ ________________________
 * r | 1 | 2 | 3 | ...
 * _ |____|____|____|
 * | x | x+1| x+2| ...
 * |____|____|____|
 * |.
 * |.
 * |.
 *
 * When b_row_maj is false, each B tile is instead stored column-major (it is
 * transposed on load) and the tiles of B are laid out column by column. When
 * c_row_maj is false the same holds for C: tiles are transposed on load and
 * store, and the rowA tiles of each tile column are contiguous.
 *
 * The kernel accumulates into C (C += A * B): the partial results already in
 * pC are loaded, the products are added, and the sums are stored back.
 *
 * A simplified example of this kernel can be found in the AIE-API
 * documentation: https://xilinx.github.io/aie_api/group__group__mmul.html
 */
template <typename T_in,
          typename T_out,
          unsigned rowA,
          unsigned colA,
          unsigned colB,
          unsigned r,
          unsigned s,
          unsigned t,
          bool b_row_maj = true,
          bool c_row_maj = true>
static inline void
matmul_vectorized_2x2_mmul(const T_in *__restrict pA, const T_in *__restrict pB, T_out *__restrict pC)
{
    // The loops below step through tile rows of C (z) and tile columns of C (j)
    // two at a time; an odd count would make the last iteration touch tiles
    // beyond the bounds of the A/B/C buffers.
    static_assert(rowA % 2 == 0, "rowA (in mmul tiles) must be even for the 2x2 mmul expansion");
    static_assert(colB % 2 == 0, "colB (in mmul tiles) must be even for the 2x2 mmul expansion");
    using MMUL = aie::mmul<r, s, t, T_in, T_in, accauto>;
    event0();
    // NOTE(review): chess_loop_range(4, ) tells the compiler this loop runs at
    // least 4 times, i.e. rowA >= 8. Confirm every instantiation (DIM_M / r)
    // satisfies this.
    for (unsigned z = 0; z < rowA; z += 2)
        chess_prepare_for_pipelining chess_loop_range(4, )
        {
            T_out *__restrict pC1;
            T_out *__restrict pC2;
            if constexpr (c_row_maj) {
                // Row-major C: tile rows z and z+1 start colB tiles apart.
                pC1 = pC + (z * colB) * MMUL::size_C;
                pC2 = pC + ((z + 1) * colB) * MMUL::size_C;
            }
            for (unsigned j = 0; j < colB; j += 2)
#ifdef OPT_PERF_ENABLED
                chess_flatten_loop
#endif
                {
                    if constexpr (!c_row_maj) {
                        // Column-major C: pC1 -> tile (z, j), pC2 -> tile (z, j+1);
                        // tile row z+1 directly follows tile row z in each column.
                        pC1 = pC + j * rowA * MMUL::size_C + z * MMUL::size_C;
                        pC2 = pC + (j + 1) * rowA * MMUL::size_C + z * MMUL::size_C;
                    }
                    const T_in *__restrict pA1 = pA + (z * colA) * MMUL::size_A;
                    const T_in *__restrict pA2 = pA + ((z + 1) * colA) * MMUL::size_A;
                    const T_in *__restrict pB1;
                    const T_in *__restrict pB2;
                    if constexpr (b_row_maj) {
                        pB1 = pB + (j)*MMUL::size_B;
                        pB2 = pB + (j + 1) * MMUL::size_B;
                    } else {
                        pB1 = pB + (j * colA) * MMUL::size_B;
                        pB2 = pB + ((j + 1) * colA) * MMUL::size_B;
                    }
                    aie::vector<T_in, MMUL::size_A> A0;
                    aie::vector<T_in, MMUL::size_A> A1;
                    aie::vector<T_in, MMUL::size_B> B0;
                    aie::vector<T_in, MMUL::size_B> B1;
                    // Load partial results from C buffer for accumulation in-place. The
                    // zero.cc function handles the zeroing of data when a new
                    // accumulation is needed (after the 'K' reduction dimension)
                    aie::vector<T_out, MMUL::size_C> acc_C00;
                    aie::vector<T_out, MMUL::size_C> acc_C01;
                    aie::vector<T_out, MMUL::size_C> acc_C10;
                    aie::vector<T_out, MMUL::size_C> acc_C11;
                    if constexpr (c_row_maj) {
                        acc_C00 = aie::load_v<MMUL::size_C>(pC1);
                        acc_C01 = aie::load_v<MMUL::size_C>(pC1 + MMUL::size_C);
                        acc_C10 = aie::load_v<MMUL::size_C>(pC2);
                        acc_C11 = aie::load_v<MMUL::size_C>(pC2 + MMUL::size_C);
                    } else {
                        // Stored tiles are t x r (column-major r x t); transpose to r x t.
                        acc_C00 = aie::transpose(aie::load_v<MMUL::size_C>(pC1), t, r);
                        acc_C01 = aie::transpose(aie::load_v<MMUL::size_C>(pC2), t, r);
                        acc_C10 = aie::transpose(aie::load_v<MMUL::size_C>(pC1 + MMUL::size_C), t, r);
                        acc_C11 = aie::transpose(aie::load_v<MMUL::size_C>(pC2 + MMUL::size_C), t, r);
                    }
                    MMUL C00(acc_C00);
                    MMUL C01(acc_C01);
                    MMUL C10(acc_C10);
                    MMUL C11(acc_C11);
                    // Reduction over the colA tiles of the K dimension.
                    for (unsigned i = 0; i < colA; ++i)
#ifdef OPT_PERF_ENABLED
                        chess_flatten_loop
#endif
                        {
                            A0 = aie::load_v<MMUL::size_A>(pA1);
                            pA1 += MMUL::size_A;
                            A1 = aie::load_v<MMUL::size_A>(pA2);
                            pA2 += MMUL::size_A;
                            if constexpr (b_row_maj) {
                                // Next tile down the same tile column of B.
                                B0 = aie::load_v<MMUL::size_B>(pB1);
                                pB1 += MMUL::size_B * colB;
                                B1 = aie::load_v<MMUL::size_B>(pB2);
                                pB2 += MMUL::size_B * colB;
                            } else {
                                // Stored tiles are t x s (column-major s x t); transpose to s x t.
                                B0 = aie::transpose(aie::load_v<MMUL::size_B>(pB1), t, s);
                                pB1 += MMUL::size_B;
                                B1 = aie::transpose(aie::load_v<MMUL::size_B>(pB2), t, s);
                                pB2 += MMUL::size_B;
                            }
                            C00.mac(A0, B0);
                            C01.mac(A0, B1);
                            C10.mac(A1, B0);
                            C11.mac(A1, B1);
                        }
                    // TODO make shift right here to keep most significant bits
                    // when lowering the output
                    // example below shows how to shift right 10 bits
                    // #define SHIFT 10
                    // aie::store_v(pC1, C00.template to_vector<T_out>(SHIFT));
                    if constexpr (c_row_maj) {
                        aie::store_v(pC1, C00.template to_vector<T_out>());
                        pC1 += MMUL::size_C;
                        aie::store_v(pC1, C01.template to_vector<T_out>());
                        pC1 += MMUL::size_C;
                        aie::store_v(pC2, C10.template to_vector<T_out>());
                        pC2 += MMUL::size_C;
                        aie::store_v(pC2, C11.template to_vector<T_out>());
                        pC2 += MMUL::size_C;
                    } else {
                        aie::store_v(pC1, aie::transpose(C00.template to_vector<T_out>(), r, t));
                        pC1 += MMUL::size_C;
                        aie::store_v(pC2, aie::transpose(C01.template to_vector<T_out>(), r, t));
                        pC2 += MMUL::size_C;
                        aie::store_v(pC1, aie::transpose(C10.template to_vector<T_out>(), r, t));
                        pC1 += MMUL::size_C;
                        aie::store_v(pC2, aie::transpose(C11.template to_vector<T_out>(), r, t));
                        pC2 += MMUL::size_C;
                    }
                }
        }
    event1();
}
// Compile-time memory layout of the B and C operands. Define B_COL_MAJ and/or
// C_COL_MAJ to select a column-major layout for that operand; both default to
// row-major.
constexpr bool is_b_row_maj =
#ifndef B_COL_MAJ
    true;
#else
    false;
#endif
constexpr bool is_c_row_maj =
#ifndef C_COL_MAJ
    true;
#else
    false;
#endif
// Rounding mode installed with ::aie::set_rounding by the bfloat16 kernels
// below; choosing conv_even can improve bfloat16 mmul accuracy. Define
// ROUND_CONV_EVEN to select aie::rounding_mode::conv_even instead of the
// default aie::rounding_mode::floor.
#ifndef ROUND_CONV_EVEN
constexpr aie::rounding_mode round_mode = aie::rounding_mode::floor; // default
#else
constexpr aie::rounding_mode round_mode = aie::rounding_mode::conv_even;
#endif
// The following kernel definitions use mmul shapes that have been found to be
// optimal for AIE2P in combination with the 2x2 mmul expanded kernel.
//
// All available matrix multiplication shapes in the AIE-API can be found here:
// https://xilinx.github.io/aie_api/group__group__mmul.html
//
// They are all defined based on the shape of the mmul, the input data format
// and the output data format.
//
// Additionally, they check for the correct
// divisibility of the tile dimensions. Note that while both the 'm' and 'n'
// dimensions of the mmul are expanded, the 'k' dimension is not.
// int16 x int16 -> int16 matmul on a 4x4x8 aie::mmul shape.
// m, k, n are element dimensions of A (m x k), B (k x n) and C (m x n); they
// are converted to tile counts for the 2x2-expanded kernel.
template <unsigned m, unsigned k, unsigned n>
static inline void
matmul_vectorized_4x4x8_i16_i16(const int16 *__restrict pA, const int16 *__restrict pB, int16 *__restrict pC)
{
    constexpr unsigned r = 4; // rows of an A / C tile
    constexpr unsigned s = 4; // cols of an A tile, rows of a B tile
    constexpr unsigned t = 8; // cols of a B / C tile
    static_assert(m % (2 * r) == 0, "m must be a multiple of 2*r");
    static_assert(k % s == 0, "k must be a multiple of s");
    static_assert(n % (2 * t) == 0, "n must be a multiple of 2*t");
    constexpr unsigned tiles_m = m / r;
    constexpr unsigned tiles_k = k / s;
    constexpr unsigned tiles_n = n / t;
    matmul_vectorized_2x2_mmul<int16, int16, tiles_m, tiles_k, tiles_n, r, s, t, is_b_row_maj, is_c_row_maj>(
        pA, pB, pC);
}
// int16 x int16 -> int32 matmul on a 4x4x8 aie::mmul shape.
// m, k, n are element dimensions of A (m x k), B (k x n) and C (m x n); they
// are converted to tile counts for the 2x2-expanded kernel.
template <unsigned m, unsigned k, unsigned n>
static inline void
matmul_vectorized_4x4x8_i16_i32(const int16 *__restrict pA, const int16 *__restrict pB, int32 *__restrict pC)
{
    constexpr unsigned r = 4; // rows of an A / C tile
    constexpr unsigned s = 4; // cols of an A tile, rows of a B tile
    constexpr unsigned t = 8; // cols of a B / C tile
    static_assert(m % (2 * r) == 0, "m must be a multiple of 2*r");
    static_assert(k % s == 0, "k must be a multiple of s");
    static_assert(n % (2 * t) == 0, "n must be a multiple of 2*t");
    constexpr unsigned tiles_m = m / r;
    constexpr unsigned tiles_k = k / s;
    constexpr unsigned tiles_n = n / t;
    matmul_vectorized_2x2_mmul<int16, int32, tiles_m, tiles_k, tiles_n, r, s, t, is_b_row_maj, is_c_row_maj>(
        pA, pB, pC);
}
// bfloat16 x bfloat16 -> bfloat16 matmul on a 4x8x8 aie::mmul shape.
// m, k, n are element dimensions of A (m x k), B (k x n) and C (m x n); they
// are converted to tile counts for the 2x2-expanded kernel. The configured
// round_mode is installed before the kernel runs.
template <unsigned m, unsigned k, unsigned n>
static inline void
matmul_vectorized_4x8x8_bf16_bf16(const bfloat16 *__restrict pA, const bfloat16 *__restrict pB, bfloat16 *__restrict pC)
{
    constexpr unsigned r = 4; // rows of an A / C tile
    constexpr unsigned s = 8; // cols of an A tile, rows of a B tile
    constexpr unsigned t = 8; // cols of a B / C tile
    static_assert(m % (2 * r) == 0, "m must be a multiple of 2*r");
    static_assert(k % s == 0, "k must be a multiple of s");
    static_assert(n % (2 * t) == 0, "n must be a multiple of 2*t");
    constexpr unsigned tiles_m = m / r;
    constexpr unsigned tiles_k = k / s;
    constexpr unsigned tiles_n = n / t;
    ::aie::set_rounding(round_mode);
    matmul_vectorized_2x2_mmul<bfloat16, bfloat16, tiles_m, tiles_k, tiles_n, r, s, t, is_b_row_maj, is_c_row_maj>(
        pA, pB, pC);
}
// bfloat16 x bfloat16 -> bfloat16 matmul on an 8x8x8 aie::mmul shape.
// Note that this shape is only possible for bf16 when using bfp16 emulation
// during matmuls.
// m, k, n are element dimensions of A (m x k), B (k x n) and C (m x n); they
// are converted to tile counts for the 2x2-expanded kernel. The configured
// round_mode is installed before the kernel runs.
template <unsigned m, unsigned k, unsigned n>
static inline void
matmul_vectorized_8x8x8_bf16_bf16(const bfloat16 *__restrict pA, const bfloat16 *__restrict pB, bfloat16 *__restrict pC)
{
    constexpr unsigned r = 8; // rows of an A / C tile
    constexpr unsigned s = 8; // cols of an A tile, rows of a B tile
    constexpr unsigned t = 8; // cols of a B / C tile
    static_assert(m % (2 * r) == 0, "m must be a multiple of 2*r");
    static_assert(k % s == 0, "k must be a multiple of s");
    static_assert(n % (2 * t) == 0, "n must be a multiple of 2*t");
    constexpr unsigned tiles_m = m / r;
    constexpr unsigned tiles_k = k / s;
    constexpr unsigned tiles_n = n / t;
    ::aie::set_rounding(round_mode);
    matmul_vectorized_2x2_mmul<bfloat16, bfloat16, tiles_m, tiles_k, tiles_n, r, s, t, is_b_row_maj, is_c_row_maj>(
        pA, pB, pC);
}
// bfloat16 x bfloat16 -> float matmul on a 4x8x8 aie::mmul shape.
// m, k, n are element dimensions of A (m x k), B (k x n) and C (m x n); they
// are converted to tile counts for the 2x2-expanded kernel. The configured
// round_mode is installed before the kernel runs.
template <unsigned m, unsigned k, unsigned n>
static inline void
matmul_vectorized_4x8x8_bf16_f32(const bfloat16 *__restrict pA, const bfloat16 *__restrict pB, float *__restrict pC)
{
    constexpr unsigned r = 4; // rows of an A / C tile
    constexpr unsigned s = 8; // cols of an A tile, rows of a B tile
    constexpr unsigned t = 8; // cols of a B / C tile
    static_assert(m % (2 * r) == 0, "m must be a multiple of 2*r");
    static_assert(k % s == 0, "k must be a multiple of s");
    static_assert(n % (2 * t) == 0, "n must be a multiple of 2*t");
    constexpr unsigned tiles_m = m / r;
    constexpr unsigned tiles_k = k / s;
    constexpr unsigned tiles_n = n / t;
    ::aie::set_rounding(round_mode);
    matmul_vectorized_2x2_mmul<bfloat16, float, tiles_m, tiles_k, tiles_n, r, s, t, is_b_row_maj, is_c_row_maj>(
        pA, pB, pC);
}
// bfloat16 x bfloat16 -> float matmul on an 8x8x8 aie::mmul shape (selected by
// the combos list when bfp16 emulation is enabled).
// m, k, n are element dimensions of A (m x k), B (k x n) and C (m x n); they
// are converted to tile counts for the 2x2-expanded kernel. The configured
// round_mode is installed before the kernel runs.
template <unsigned m, unsigned k, unsigned n>
static inline void
matmul_vectorized_8x8x8_bf16_f32(const bfloat16 *__restrict pA, const bfloat16 *__restrict pB, float *__restrict pC)
{
    constexpr unsigned r = 8; // rows of an A / C tile
    constexpr unsigned s = 8; // cols of an A tile, rows of a B tile
    constexpr unsigned t = 8; // cols of a B / C tile
    static_assert(m % (2 * r) == 0, "m must be a multiple of 2*r");
    static_assert(k % s == 0, "k must be a multiple of s");
    static_assert(n % (2 * t) == 0, "n must be a multiple of 2*t");
    constexpr unsigned tiles_m = m / r;
    constexpr unsigned tiles_k = k / s;
    constexpr unsigned tiles_n = n / t;
    ::aie::set_rounding(round_mode);
    matmul_vectorized_2x2_mmul<bfloat16, float, tiles_m, tiles_k, tiles_n, r, s, t, is_b_row_maj, is_c_row_maj>(
        pA, pB, pC);
}
// int8 x int8 -> int8 matmul on an 8x8x8 aie::mmul shape.
// m, k, n are element dimensions of A (m x k), B (k x n) and C (m x n); they
// are converted to tile counts for the 2x2-expanded kernel.
template <unsigned m, unsigned k, unsigned n>
static inline void
matmul_vectorized_8x8x8_i8_i8(const int8 *__restrict pA, const int8 *__restrict pB, int8 *__restrict pC)
{
    constexpr unsigned r = 8; // rows of an A / C tile
    constexpr unsigned s = 8; // cols of an A tile, rows of a B tile
    constexpr unsigned t = 8; // cols of a B / C tile
    static_assert(m % (2 * r) == 0, "m must be a multiple of 2*r");
    static_assert(k % s == 0, "k must be a multiple of s");
    static_assert(n % (2 * t) == 0, "n must be a multiple of 2*t");
    constexpr unsigned tiles_m = m / r;
    constexpr unsigned tiles_k = k / s;
    constexpr unsigned tiles_n = n / t;
    matmul_vectorized_2x2_mmul<int8, int8, tiles_m, tiles_k, tiles_n, r, s, t, is_b_row_maj, is_c_row_maj>(
        pA, pB, pC);
}
// int8 x int8 -> int16 matmul on an 8x8x8 aie::mmul shape.
// m, k, n are element dimensions of A (m x k), B (k x n) and C (m x n); they
// are converted to tile counts for the 2x2-expanded kernel.
template <unsigned m, unsigned k, unsigned n>
static inline void
matmul_vectorized_8x8x8_i8_i16(const int8 *__restrict pA, const int8 *__restrict pB, int16 *__restrict pC)
{
    constexpr unsigned r = 8; // rows of an A / C tile
    constexpr unsigned s = 8; // cols of an A tile, rows of a B tile
    constexpr unsigned t = 8; // cols of a B / C tile
    static_assert(m % (2 * r) == 0, "m must be a multiple of 2*r");
    static_assert(k % s == 0, "k must be a multiple of s");
    static_assert(n % (2 * t) == 0, "n must be a multiple of 2*t");
    constexpr unsigned tiles_m = m / r;
    constexpr unsigned tiles_k = k / s;
    constexpr unsigned tiles_n = n / t;
    matmul_vectorized_2x2_mmul<int8, int16, tiles_m, tiles_k, tiles_n, r, s, t, is_b_row_maj, is_c_row_maj>(
        pA, pB, pC);
}
// int8 x int8 -> int32 matmul on an 8x8x8 aie::mmul shape.
// m, k, n are element dimensions of A (m x k), B (k x n) and C (m x n); they
// are converted to tile counts for the 2x2-expanded kernel.
template <unsigned m, unsigned k, unsigned n>
static inline void
matmul_vectorized_8x8x8_i8_i32(const int8 *__restrict pA, const int8 *__restrict pB, int32 *__restrict pC)
{
    constexpr unsigned r = 8; // rows of an A / C tile
    constexpr unsigned s = 8; // cols of an A tile, rows of a B tile
    constexpr unsigned t = 8; // cols of a B / C tile
    static_assert(m % (2 * r) == 0, "m must be a multiple of 2*r");
    static_assert(k % s == 0, "k must be a multiple of s");
    static_assert(n % (2 * t) == 0, "n must be a multiple of 2*t");
    constexpr unsigned tiles_m = m / r;
    constexpr unsigned tiles_k = k / s;
    constexpr unsigned tiles_n = n / t;
    matmul_vectorized_2x2_mmul<int8, int32, tiles_m, tiles_k, tiles_n, r, s, t, is_b_row_maj, is_c_row_maj>(
        pA, pB, pC);
}
extern "C" {
// Element dimensions of the A (DIM_M x DIM_K), B (DIM_K x DIM_N) and
// C (DIM_M x DIM_N) buffers handled by each call. To compile microkernels with
// different inner tile sizes, define DIM_M, DIM_K and DIM_N at compile time,
// e.g. -DDIM_M=32. For the vectorized kernels, DIM_M must be a multiple of
// 2*r, DIM_K a multiple of s and DIM_N a multiple of 2*t for the mmul shape
// (r, s, t) chosen below; the kernel wrappers static_assert this.
#ifndef DIM_M
#define DIM_M 64
#endif
#ifndef DIM_K
#define DIM_K 64
#endif
#ifndef DIM_N
#define DIM_N 64
#endif
// combos(X) is an X-macro listing the supported type combinations as
// X(ctype_in, mlir_type_in, ctype_out, mlir_type_out, r, s, t), where r x s x t
// selects the matmul_vectorized_<r>x<s>x<t>_<in>_<out> wrapper. Defining one of
// the <in>_<out>_ONLY macros restricts the build to that single combination;
// otherwise the default list further below is used (i8_i16 and i8_i32 are
// only available through their _ONLY macros).
#ifdef i8_i8_ONLY
#define combos(X) X(int8, i8, int8, i8, 8, 8, 8)
#endif
#ifdef i8_i16_ONLY
#define combos(X) X(int8, i8, int16, i16, 8, 8, 8)
#endif
#ifdef i8_i32_ONLY
#define combos(X) X(int8, i8, int32, i32, 8, 8, 8)
#endif
#ifdef i16_i16_ONLY
#define combos(X) X(int16, i16, int16, i16, 4, 4, 8)
#endif
#ifdef i16_i32_ONLY
#define combos(X) X(int16, i16, int32, i32, 4, 4, 8)
#endif
// The emulation of bf16 changes the available shapes for matrix multiplication
#ifdef bf16_bf16_ONLY
#ifdef AIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16
#define combos(X) X(bfloat16, bf16, bfloat16, bf16, 8, 8, 8)
#else
#define combos(X) X(bfloat16, bf16, bfloat16, bf16, 4, 8, 8)
#endif
#endif
#ifdef bf16_f32_ONLY
#ifdef AIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16
#define combos(X) X(bfloat16, bf16, float, f32, 8, 8, 8)
#else
#define combos(X) X(bfloat16, bf16, float, f32, 4, 8, 8)
#endif
#endif
// Default list used when no <in>_<out>_ONLY macro is defined.
#ifndef combos
#ifdef AIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16
#define combos(X) \
X(int8, i8, int8, i8, 8, 8, 8) \
X(int16, i16, int16, i16, 4, 4, 8) \
X(int16, i16, int32, i32, 4, 4, 8) \
X(bfloat16, bf16, bfloat16, bf16, 8, 8, 8) \
X(bfloat16, bf16, float, f32, 8, 8, 8)
#else
#define combos(X) \
X(int8, i8, int8, i8, 8, 8, 8) \
X(int16, i16, int16, i16, 4, 4, 8) \
X(int16, i16, int32, i32, 4, 4, 8) \
X(bfloat16, bf16, bfloat16, bf16, 4, 8, 8) \
X(bfloat16, bf16, float, f32, 4, 8, 8)
#endif
#endif
// Each *_c_func macro below is expanded once per combos entry (last line of
// this block) to emit an unmangled C entry point:
//   matmul_<in>_<out>         - vectorized kernel, accumulates C += A * B
//   matmul_scalar_<in>_<out>  - scalar reference kernel, accumulates C += A * B
//   zero_<out>, zero_scalar_<out> - call zero_vectorized / zero_scalar from
//                               zero.cc on the DIM_M x DIM_N C buffer
// The r, s, t macro parameters are only used by matmul_vectorized_c_func.
// Because zero_<out> is named after the output type alone, every combos entry
// must have a distinct output type or the zero functions would be redefined.
#define matmul_vectorized_c_func(ctype_in, mlir_type_in, ctype_out, mlir_type_out, r, s, t) \
void matmul_##mlir_type_in##_##mlir_type_out(ctype_in *a_in, ctype_in *b_in, ctype_out *c_out) \
{ \
matmul_vectorized_##r##x##s##x##t##_##mlir_type_in##_##mlir_type_out<DIM_M, DIM_K, DIM_N>(a_in, b_in, c_out); \
}
#define matmul_scalar_c_func(ctype_in, mlir_type_in, ctype_out, mlir_type_out, r, s, t) \
void matmul_scalar_##mlir_type_in##_##mlir_type_out(ctype_in *a_in, ctype_in *b_in, ctype_out *c_out) \
{ \
matmul_scalar<ctype_in, ctype_out, DIM_M, DIM_K, DIM_N, is_b_row_maj, is_c_row_maj>(a_in, b_in, c_out); \
}
#define zero_vectorized_c_func(ctype_in, mlir_type_in, ctype_out, mlir_type_out, r, s, t) \
void zero_##mlir_type_out(ctype_out *c_out) \
{ \
zero_vectorized<ctype_out, DIM_M, DIM_N>(c_out); \
}
#define zero_scalar_c_func(ctype_in, mlir_type_in, ctype_out, mlir_type_out, r, s, t) \
void zero_scalar_##mlir_type_out(ctype_out *c_out) \
{ \
zero_scalar<ctype_out, DIM_M, DIM_N>(c_out); \
}
combos(matmul_vectorized_c_func) combos(matmul_scalar_c_func) combos(zero_vectorized_c_func) combos(zero_scalar_c_func)
} // extern "C"