@@ -6176,6 +6176,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at
61766176template [[host_name(" kernel_flash_attn_ext_f32_dk192_dv192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 192 , 192 >;
61776177template [[host_name(" kernel_flash_attn_ext_f32_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 192 , 128 >;
61786178template [[host_name(" kernel_flash_attn_ext_f32_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 256 , 256 >;
6179+ template [[host_name(" kernel_flash_attn_ext_f32_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 320 , 256 >;
61796180template [[host_name(" kernel_flash_attn_ext_f32_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1 , dequantize_f32, float4x4, 1 , dequantize_f32, 576 , 512 >;
61806181
61816182template [[host_name(" kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 32 , 32 >;
@@ -6190,6 +6191,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_at
61906191template [[host_name(" kernel_flash_attn_ext_f16_dk192_dv192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 192 , 192 >;
61916192template [[host_name(" kernel_flash_attn_ext_f16_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 192 , 128 >;
61926193template [[host_name(" kernel_flash_attn_ext_f16_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 256 , 256 >;
6194+ template [[host_name(" kernel_flash_attn_ext_f16_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 320 , 256 >;
61936195template [[host_name(" kernel_flash_attn_ext_f16_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 576 , 512 >;
61946196
61956197#if defined(GGML_METAL_HAS_BF16)
@@ -6205,6 +6207,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at
62056207template [[host_name(" kernel_flash_attn_ext_bf16_dk192_dv192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 192 >;
62066208template [[host_name(" kernel_flash_attn_ext_bf16_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 128 >;
62076209template [[host_name(" kernel_flash_attn_ext_bf16_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 , 256 >;
6210+ template [[host_name(" kernel_flash_attn_ext_bf16_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 320 , 256 >;
62086211template [[host_name(" kernel_flash_attn_ext_bf16_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 576 , 512 >;
62096212#endif
62106213
@@ -6220,6 +6223,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at
62206223template [[host_name(" kernel_flash_attn_ext_q4_0_dk192_dv192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 192 , 192 >;
62216224template [[host_name(" kernel_flash_attn_ext_q4_0_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 192 , 128 >;
62226225template [[host_name(" kernel_flash_attn_ext_q4_0_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 256 , 256 >;
6226+ template [[host_name(" kernel_flash_attn_ext_q4_0_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 320 , 256 >;
62236227template [[host_name(" kernel_flash_attn_ext_q4_0_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 576 , 512 >;
62246228
62256229template [[host_name(" kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 32 , 32 >;
@@ -6234,6 +6238,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at
62346238template [[host_name(" kernel_flash_attn_ext_q4_1_dk192_dv192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 192 , 192 >;
62356239template [[host_name(" kernel_flash_attn_ext_q4_1_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 192 , 128 >;
62366240template [[host_name(" kernel_flash_attn_ext_q4_1_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 256 , 256 >;
6241+ template [[host_name(" kernel_flash_attn_ext_q4_1_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 320 , 256 >;
62376242template [[host_name(" kernel_flash_attn_ext_q4_1_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 576 , 512 >;
62386243
62396244template [[host_name(" kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 32 , 32 >;
@@ -6248,6 +6253,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at
62486253template [[host_name(" kernel_flash_attn_ext_q5_0_dk192_dv192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 192 , 192 >;
62496254template [[host_name(" kernel_flash_attn_ext_q5_0_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 192 , 128 >;
62506255template [[host_name(" kernel_flash_attn_ext_q5_0_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 256 , 256 >;
6256+ template [[host_name(" kernel_flash_attn_ext_q5_0_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 320 , 256 >;
62516257template [[host_name(" kernel_flash_attn_ext_q5_0_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 576 , 512 >;
62526258
62536259template [[host_name(" kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 32 , 32 >;
@@ -6262,6 +6268,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at
62626268template [[host_name(" kernel_flash_attn_ext_q5_1_dk192_dv192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 192 , 192 >;
62636269template [[host_name(" kernel_flash_attn_ext_q5_1_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 192 , 128 >;
62646270template [[host_name(" kernel_flash_attn_ext_q5_1_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 256 , 256 >;
6271+ template [[host_name(" kernel_flash_attn_ext_q5_1_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 320 , 256 >;
62656272template [[host_name(" kernel_flash_attn_ext_q5_1_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 576 , 512 >;
62666273
62676274template [[host_name(" kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 32 , 32 >;
@@ -6276,6 +6283,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at
62766283template [[host_name(" kernel_flash_attn_ext_q8_0_dk192_dv192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 192 , 192 >;
62776284template [[host_name(" kernel_flash_attn_ext_q8_0_dk192_dv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 192 , 128 >;
62786285template [[host_name(" kernel_flash_attn_ext_q8_0_dk256_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 256 , 256 >;
6286+ template [[host_name(" kernel_flash_attn_ext_q8_0_dk320_dv256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 320 , 256 >;
62796287template [[host_name(" kernel_flash_attn_ext_q8_0_dk576_dv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 576 , 512 >;
62806288
62816289#undef FA_TYPES
@@ -6846,6 +6854,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas
68466854template [[host_name(" kernel_flash_attn_ext_vec_q5_1_dk256_dv256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8 , dequantize_q5_1_t4, block_q5_1, 8 , dequantize_q5_1_t4, 256 , 256 , 1 >;
68476855template [[host_name(" kernel_flash_attn_ext_vec_q8_0_dk256_dv256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8 , dequantize_q8_0_t4, block_q8_0, 8 , dequantize_q8_0_t4, 256 , 256 , 1 >;
68486856
6857+ template [[host_name(" kernel_flash_attn_ext_vec_f32_dk320_dv256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1 , dequantize_f32_t4, float4, 1 , dequantize_f32_t4, 320 , 256 , 2 >;
6858+ template [[host_name(" kernel_flash_attn_ext_vec_f16_dk320_dv256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 320 , 256 , 2 >;
6859+ #if defined(GGML_METAL_HAS_BF16)
6860+ template [[host_name(" kernel_flash_attn_ext_vec_bf16_dk320_dv256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1 , dequantize_bf16_t4, bfloat4, 1 , dequantize_bf16_t4, 320 , 256 , 2 >;
6861+ #endif
6862+ template [[host_name(" kernel_flash_attn_ext_vec_q4_0_dk320_dv256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8 , dequantize_q4_0_t4, block_q4_0, 8 , dequantize_q4_0_t4, 320 , 256 , 2 >;
6863+ template [[host_name(" kernel_flash_attn_ext_vec_q4_1_dk320_dv256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8 , dequantize_q4_1_t4, block_q4_1, 8 , dequantize_q4_1_t4, 320 , 256 , 2 >;
6864+ template [[host_name(" kernel_flash_attn_ext_vec_q5_0_dk320_dv256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8 , dequantize_q5_0_t4, block_q5_0, 8 , dequantize_q5_0_t4, 320 , 256 , 2 >;
6865+ template [[host_name(" kernel_flash_attn_ext_vec_q5_1_dk320_dv256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8 , dequantize_q5_1_t4, block_q5_1, 8 , dequantize_q5_1_t4, 320 , 256 , 2 >;
6866+ template [[host_name(" kernel_flash_attn_ext_vec_q8_0_dk320_dv256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8 , dequantize_q8_0_t4, block_q8_0, 8 , dequantize_q8_0_t4, 320 , 256 , 2 >;
6867+
68496868template [[host_name(" kernel_flash_attn_ext_vec_f32_dk576_dv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1 , dequantize_f32_t4, float4, 1 , dequantize_f32_t4, 576 , 512 , 2 >;
68506869template [[host_name(" kernel_flash_attn_ext_vec_f16_dk576_dv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 576 , 512 , 2 >;
68516870#if defined(GGML_METAL_HAS_BF16)
0 commit comments