Skip to content

Commit 0e21991

Browse files
ymckiggerganov
andauthored
fix vulkan ggml_acc only works in 3d but not 4d (ggml-org#19426)
* fix vulkan ggml_acc only works in 3d but not 4d * removed clamp in test_acc_block * use the correct stride and its test case * cuda : fix "supports op" condition * change src0 to src1 in ggml_vk_acc. Update acc.comp with jeffbolznv\'s suggestion except to keep the boundary check * version without boundary check * revert back to boundary check version --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent b2ecc0c commit 0e21991

4 files changed

Lines changed: 60 additions & 23 deletions

File tree

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4822,8 +4822,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
48224822
case GGML_OP_CONV_2D_DW:
48234823
case GGML_OP_CONV_TRANSPOSE_2D:
48244824
case GGML_OP_POOL_2D:
4825-
case GGML_OP_ACC:
48264825
return true;
4826+
case GGML_OP_ACC:
4827+
// TODO: extend support like so:
4828+
//return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
4829+
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
48274830
case GGML_OP_SUM:
48284831
return ggml_is_contiguous_rows(op->src[0]);
48294832
case GGML_OP_TOP_K:

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9801,16 +9801,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
98019801
const uint32_t src1_type_size = ggml_type_size(src1->type);
98029802
const uint32_t dst_type_size = ggml_type_size(dst->type);
98039803

9804-
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
9805-
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
9806-
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
9807-
int offset = dst->op_params[3] / 4; // offset in bytes
9804+
int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32
9805+
int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32
9806+
int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
9807+
int offset = dst->op_params[3] / src0_type_size; // offset in bytes
98089808

98099809
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
98109810
(uint32_t)ggml_nelements(src0),
9811-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
9811+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
98129812
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
9813-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
9813+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
98149814
0,
98159815
0.0f, 0.0f, offset,
98169816
});

ggml/src/ggml-vulkan/vulkan-shaders/acc.comp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,18 @@ void main() {
1313

1414
const uint offset = p.param3;
1515
const uint src1_i = idx - offset;
16-
const uint oz = src1_i / p.nb02;
17-
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
18-
const uint ox = src1_i % p.nb01;
16+
const uint i3 = src1_i / p.nb03;
17+
const uint rem2 = src1_i - i3 * p.nb03;
18+
const uint i2 = rem2 / p.nb02;
19+
const uint rem1 = rem2 - i2 * p.nb02;
20+
const uint i1 = rem1 / p.nb01;
21+
const uint i0 = rem1 % p.nb01;
1922

2023
uint i00, i01, i02, i03;
21-
get_indices(idx, i00, i01, i02, i03);
2224

23-
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
24-
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
25+
if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
26+
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
2527
} else {
26-
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
28+
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
2729
}
2830
}
29-

tests/test-backend-ops.cpp

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5847,26 +5847,46 @@ struct test_acc : public test_case {
58475847
const ggml_type type;
58485848
const std::array<int64_t, 4> ne_a;
58495849
const std::array<int64_t, 4> ne_b;
5850+
const int64_t stride_dim;
58505851

58515852
std::string vars() override {
5852-
return VARS_TO_STR3(type, ne_a, ne_b);
5853+
return VARS_TO_STR4(type, ne_a, ne_b, stride_dim);
58535854
}
58545855

58555856
test_acc(ggml_type type = GGML_TYPE_F32,
5856-
std::array<int64_t, 4> ne_a = {256, 17, 1, 1},
5857-
std::array<int64_t, 4> ne_b = {256, 16, 1, 1})
5858-
: type(type), ne_a(ne_a), ne_b(ne_b) {}
5857+
std::array<int64_t, 4> ne_a = {256, 17, 2, 3},
5858+
std::array<int64_t, 4> ne_b = {256, 16, 2, 3},
5859+
uint64_t stride_dim = -1)
5860+
: type(type), ne_a(ne_a), ne_b(ne_b), stride_dim(stride_dim) {}
58595861

58605862
ggml_tensor * build_graph(ggml_context * ctx) override {
58615863
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
58625864
ggml_set_param(a);
58635865
ggml_set_name(a, "a");
58645866

5865-
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
5866-
ggml_set_param(b);
5867+
ggml_tensor * b;
5868+
if (stride_dim == 1 || stride_dim == 2 || stride_dim == 3) {
5869+
// Create a larger tensor and take a view at a non-zero offset.
5870+
// This tests that the backend correctly handles b's data offset
5871+
std::array<int64_t, 4> ne_b_pad = {ne_b[0], ne_b[1], ne_b[2], ne_b[3]};
5872+
ne_b_pad[stride_dim] += 1;
5873+
ggml_tensor * b_pad = ggml_new_tensor(ctx, type, 4, ne_b_pad.data());
5874+
ggml_set_param(b_pad);
5875+
ggml_set_name(b_pad, "b_pad");
5876+
// View that skips the first row, so b has a non-zero byte offset
5877+
b = ggml_view_4d(ctx, b_pad,
5878+
ne_b[0], ne_b[1], ne_b[2], ne_b[3],
5879+
b_pad->nb[1], b_pad->nb[2], b_pad->nb[3],
5880+
b_pad->nb[1]);
5881+
} else {
5882+
b = ggml_new_tensor(ctx, type, 4, ne_b.data());
5883+
ggml_set_param(b);
5884+
}
58675885
ggml_set_name(b, "b");
58685886

5869-
ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
5887+
// When ne_b[0] < ne_a[0], a->nb[1] != b->nb[1], so the stride
5888+
// parameters to ggml_acc don't match b's natural stride.
5889+
ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 0);
58705890
ggml_set_name(out, "out");
58715891

58725892
return out;
@@ -8170,7 +8190,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
81708190
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
81718191
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));
81728192
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));
8173-
test_cases.emplace_back(new test_acc());
8193+
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 1, 1}, {256, 16, 1, 1}, -1));
8194+
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, -1));
8195+
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, -1));
8196+
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1));
8197+
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));
8198+
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));
81748199
test_cases.emplace_back(new test_pad());
81758200
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular
81768201
test_cases.emplace_back(new test_pad_ext());
@@ -8605,6 +8630,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
86058630
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
86068631
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate
86078632

8633+
// acc
8634+
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 1, 1}, {256, 16, 1, 1}, -1));
8635+
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, -1));
8636+
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, -1));
8637+
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1));
8638+
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));
8639+
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));
8640+
86088641
return test_cases;
86098642
}
86108643

0 commit comments

Comments
 (0)