Skip to content

Commit b6c83aa

Browse files
authored
[SYCL] enhance UPSCALE to support all UT cases (ggml-org#20637)
* [SYCL] enhance UPSCALE to support more cases * rm test case result of SYCL1
1 parent 2e4a6ed commit b6c83aa

8 files changed

Lines changed: 712 additions & 114 deletions

File tree

docs/ops.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,5 +117,5 @@ Legend:
117117
| TOP_K ||||||| 🟡 | 🟡 ||||
118118
| TRI ||||||||||||
119119
| TRUNC |||| 🟡 ||| 🟡 | 🟡 ||||
120-
| UPSCALE || 🟡 ||| 🟡 | 🟡 | 🟡 |||||
120+
| UPSCALE || 🟡 ||| 🟡 | 🟡 | |||||
121121
| XIELU ||||||||||||

docs/ops/SYCL.csv

Lines changed: 288 additions & 18 deletions
Large diffs are not rendered by default.

ggml/src/ggml-sycl/backend.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@
2424
#include "dmmv.hpp"
2525
#include "element_wise.hpp"
2626
#include "fattn.hpp"
27+
#include "gated_delta_net.hpp"
2728
#include "gla.hpp"
2829
#include "im2col.hpp"
2930
#include "mmq.hpp"
3031
#include "mmvq.hpp"
3132
#include "norm.hpp"
3233
#include "outprod.hpp"
3334
#include "pad.hpp"
35+
#include "pad_reflect_1d.hpp"
3436
#include "quantize.hpp"
3537
#include "quants.hpp"
3638
#include "roll.hpp"
@@ -39,8 +41,8 @@
3941
#include "ssm_conv.hpp"
4042
#include "softmax.hpp"
4143
#include "tsembd.hpp"
44+
#include "upscale.hpp"
4245
#include "wkv.hpp"
43-
#include "pad_reflect_1d.hpp"
4446

4547

4648
#endif // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 0 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -294,30 +294,6 @@ static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl:
294294
}
295295
}
296296

297-
template<typename T>
298-
static void upscale(const T *x, T *dst, const int nb00, const int nb01,
299-
const int nb02, const int nb03, const int ne10, const int ne11,
300-
const int ne12, const int ne13, const float sf0, const float sf1,
301-
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
302-
int index = item_ct1.get_local_id(0) +
303-
item_ct1.get_group(0) * item_ct1.get_local_range(0);
304-
if (index >= ne10 * ne11 * ne12 * ne13) {
305-
return;
306-
}
307-
// operation
308-
int i10 = index % ne10;
309-
int i11 = (index / ne10) % ne11;
310-
int i12 = (index / (ne10 * ne11)) % ne12;
311-
int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
312-
313-
int i00 = static_cast<int>(i10 / sf0);
314-
int i01 = static_cast<int>(i11 / sf1);
315-
int i02 = static_cast<int>(i12 / sf2);
316-
int i03 = static_cast<int>(i13 / sf3);
317-
318-
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
319-
}
320-
321297
template<typename T>
322298
static void clamp(const T * x, T * dst, const float min, const float max, const int k,
323299
const sycl::nd_item<1> &item_ct1) {
@@ -392,20 +368,6 @@ static void arange_kernel(T * dst, const int k, T start, T step,
392368
}
393369
}
394370

395-
template<typename T>
396-
static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
397-
const int nb02, const int nb03, const int ne10, const int ne11,
398-
const int ne12, const int ne13, const float sf0, const float sf1,
399-
const float sf2, const float sf3, queue_ptr stream) {
400-
int dst_size = ne10 * ne11 * ne12 * ne13;
401-
int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE);
402-
sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
403-
stream->parallel_for(
404-
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
405-
upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
406-
});
407-
}
408-
409371
template<typename KernelInvoker, typename... Args>
410372
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
411373
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -505,42 +467,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
505467
}
506468
}
507469

508-
template<typename KernelInvoker, typename... Args>
509-
static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
510-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
511-
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
512-
513-
GGML_ASSERT(dst->src[0]->type == dst->type);
514-
515-
dpct::queue_ptr main_stream = ctx.stream();
516-
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
517-
518-
const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
519-
const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
520-
const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
521-
const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
522-
switch (dst->type) {
523-
case GGML_TYPE_F16:
524-
{
525-
auto data_pts = cast_data<sycl::half>(dst);
526-
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
527-
(int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
528-
main_stream, std::forward<Args>(args)...);
529-
break;
530-
}
531-
case GGML_TYPE_F32:
532-
{
533-
auto data_pts = cast_data<float>(dst);
534-
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
535-
(int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
536-
main_stream, std::forward<Args>(args)...);
537-
break;
538-
}
539-
default:
540-
GGML_ABORT("GGML tensor type not supported!\n");
541-
}
542-
}
543-
544470
template<typename F>
545471
static inline void ggml_sycl_op_unary(
546472
ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) {
@@ -784,15 +710,6 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor
784710
});
785711
}
786712

787-
static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
788-
ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst,
789-
[](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03,
790-
int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3,
791-
queue_ptr stream) {
792-
ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream);
793-
});
794-
}
795-
796713
static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
797714
float min_val;
798715
float max_val;
@@ -1131,12 +1048,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
11311048
ggml_sycl_op_sqr(ctx, dst);
11321049
}
11331050

1134-
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1135-
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1136-
ggml_sycl_op_upscale(ctx, dst);
1137-
}
1138-
1139-
11401051
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
11411052
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
11421053
ggml_sycl_op_clamp(ctx, dst);

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7171

7272
void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7373

74-
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
75-
7674
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7775

7876
void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
#include "ggml-sycl/backend.hpp"
4545
#include "ggml-sycl/common.hpp"
4646
#include "ggml-sycl/element_wise.hpp"
47-
#include "ggml-sycl/gated_delta_net.hpp"
4847
#include "ggml-sycl/gemm.hpp"
4948
#include "ggml-sycl/getrows.hpp"
5049
#include "ggml-sycl/norm.hpp"
@@ -4863,9 +4862,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
48634862
case GGML_OP_ROPE:
48644863
case GGML_OP_ROPE_BACK:
48654864
case GGML_OP_IM2COL:
4866-
return true;
48674865
case GGML_OP_UPSCALE:
4868-
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
4866+
return true;
48694867
case GGML_OP_SUM:
48704868
case GGML_OP_SUM_ROWS:
48714869
case GGML_OP_MEAN:

0 commit comments

Comments
 (0)