
Commit 6bd9bca

pytorchbot and ssjia authored

[ET-VK] Fix pack_fp_linear_weight for devices without VK_KHR_16bit_storage (pytorch#18653)

This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: pytorch#18642 by @SS-JIA (use this as the source of truth for the PR details, comments, and reviews)
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/515/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/515/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/515/orig
Differential Revision: [D99133993](https://our.internmc.facebook.com/intern/diff/D99133993/)

@diff-train-skip-merge

Co-authored-by: ssjia <ssjia@devvm26340.ftw0.facebook.com>

1 parent 7004989 commit 6bd9bca

5 files changed: 17 additions & 10 deletions
backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.glsl

Lines changed: 5 additions & 4 deletions
```diff
@@ -9,15 +9,16 @@
 #version 450 core

 #define PRECISION ${PRECISION}
-#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
-#define T ${texel_load_component_type(DTYPE, "buffer")}
+#define BUF_T ${buffer_scalar_type(BUF_DTYPE)}
+#define VEC4_T ${texel_load_type(DTYPE, PACKED_STORAGE)}
+#define T ${texel_load_component_type(DTYPE, PACKED_STORAGE)}

 $if PACKED_STORAGE == "buffer":
   #define OUTPUT_BUFFER

 #extension GL_EXT_control_flow_attributes : require

-${define_required_extensions("buffer", DTYPE)}
+${define_required_extensions("buffer", BUF_DTYPE)}
 $if PACKED_STORAGE != "buffer":
   ${define_required_extensions(PACKED_STORAGE, DTYPE)}

@@ -29,7 +30,7 @@ $if PACKED_STORAGE == "buffer":
   ${layout_declare_tensor(B, "w", "t_weight_packed", DTYPE, "buffer", is_scalar_array=False)}
 $else:
   ${layout_declare_tensor(B, "w", "t_weight_packed", DTYPE, PACKED_STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_weight_src", DTYPE, "buffer", is_scalar_array=True)}
+${layout_declare_tensor(B, "r", "t_weight_src", BUF_DTYPE, "buffer", is_scalar_array=True)}

 layout(push_constant) uniform restrict Block {
   int N;
```
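The motivation for the separate `BUF_DTYPE` template parameter: without the VK_KHR_16bit_storage extension, a device cannot access 16-bit scalar types in storage buffers, so the staging buffer feeding this shader must hold fp32 data even when the packed weight output is fp16. A minimal sketch of that fallback logic, with hypothetical names (not the ExecuTorch API):

```cpp
#include <cassert>

// Hypothetical sketch: choose the staging-buffer dtype for a weight tensor.
// Without VK_KHR_16bit_storage, half values cannot be read from an SSBO, so
// staging falls back to fp32 and the shader converts to fp16 when writing
// the packed output.
enum class DType { Float, Half };

DType staging_dtype_for(DType tensor_dtype, bool has_16bit_storage) {
  if (tensor_dtype == DType::Half && !has_16bit_storage) {
    return DType::Float;  // read fp32 from staging; packed output stays fp16
  }
  return tensor_dtype;
}
```

This is why the shader now distinguishes `BUF_T` (staging scalar type) from `T` (packed output component type): the two may legitimately differ on devices lacking the extension.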

backends/vulkan/runtime/graph/ops/glsl/pack_fp_linear_weight.yaml

Lines changed: 9 additions & 6 deletions
```diff
@@ -7,13 +7,16 @@
 pack_fp_linear_weight:
   parameter_names_with_default_values:
     DTYPE: float
+    BUF_DTYPE: float
     PACKED_STORAGE: texture2d
   generate_variant_forall:
-    PACKED_STORAGE:
-      - VALUE: texture2d
-      - VALUE: buffer
-    DTYPE:
-      - VALUE: float
-      - VALUE: half
+    combination:
+      parameter_names: [PACKED_STORAGE, DTYPE, BUF_DTYPE]
+      combos:
+        - parameter_values: [texture2d, float, float]
+        - parameter_values: [texture2d, half, half]
+        - parameter_values: [texture2d, half, float]
+        - parameter_values: [buffer, float, float]
+        - parameter_values: [buffer, half, half]
   shader_variants:
     - NAME: pack_fp_linear_weight
```
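The old `generate_variant_forall` form emitted the full cross product of storage types and dtypes; the explicit `combination` list instead enumerates only the supported triples, adding the mixed `texture2d/half/float` variant (fp32 staging, fp16 packed texture) needed when VK_KHR_16bit_storage is absent. A sketch of the combo table with a sanity check (struct and helper names are illustrative, not from the codegen):

```cpp
#include <cassert>
#include <string>
#include <vector>

// The five variant combos from the YAML above. Sanity check: BUF_DTYPE
// diverges from DTYPE only in the texture2d/half case, i.e. the fp32-staging
// fallback for devices without VK_KHR_16bit_storage.
struct Combo {
  std::string storage, dtype, buf_dtype;
};

std::vector<Combo> variant_combos() {
  return {
      {"texture2d", "float", "float"},
      {"texture2d", "half", "half"},
      {"texture2d", "half", "float"},  // mixed: fp32 staging -> fp16 texture
      {"buffer", "float", "float"},
      {"buffer", "half", "half"},
  };
}

int mixed_dtype_combo_count() {
  int n = 0;
  for (const auto& c : variant_combos()) {
    if (c.dtype != c.buf_dtype) ++n;
  }
  return n;
}
```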

backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -71,6 +71,7 @@ static ValueRef prepack_conv1d_pw_weight(
   std::string kernel_name = "pack_fp_linear_weight";
   add_storage_type_suffix(kernel_name, weight_storage);
   add_dtype_suffix(kernel_name, graph.dtype_of(weight_data));
+  add_dtype_suffix(kernel_name, graph.get_staging_dtype_for(weight_data));

   graph.prepack_nodes().emplace_back(new PrepackNode(
       graph,
```

backends/vulkan/runtime/graph/ops/impl/Conv2dPW.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -130,6 +130,7 @@ ValueRef prepack_conv2d_pw_weight(
   std::string pack_kernel_name = "pack_fp_linear_weight";
   add_storage_type_suffix(pack_kernel_name, weight_storage);
   add_dtype_suffix(pack_kernel_name, graph.dtype_of(weight_data));
+  add_dtype_suffix(pack_kernel_name, graph.get_staging_dtype_for(weight_data));

   graph.prepack_nodes().emplace_back(new PrepackNode(
       graph,
```

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -81,6 +81,7 @@ ValueRef prepack_fp_linear_weight(
   std::string kernel_name = "pack_fp_linear_weight";
   add_storage_type_suffix(kernel_name, weight_storage);
   add_dtype_suffix(kernel_name, graph.dtype_of(weight_data));
+  add_dtype_suffix(kernel_name, graph.get_staging_dtype_for(weight_data));

   graph.prepack_nodes().emplace_back(new PrepackNode(
       graph,
```
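All three prepack sites (Conv1dPW, Conv2dPW, Linear) now append a second dtype suffix for the staging dtype, so the runtime can resolve the mixed-dtype shader variant. A sketch of the resulting name construction, under an assumed suffix spelling (the real `add_storage_type_suffix`/`add_dtype_suffix` spellings may differ):

```cpp
#include <cassert>
#include <string>

// Hypothetical sketch of the kernel-name construction at the prepack sites:
// base name, then storage suffix, then tensor dtype, then staging dtype.
// Suffix spellings here are assumed for illustration only.
void append_suffix(std::string& name, const std::string& suffix) {
  name += "_" + suffix;
}

std::string build_kernel_name(const std::string& storage,
                              const std::string& tensor_dtype,
                              const std::string& staging_dtype) {
  std::string kernel_name = "pack_fp_linear_weight";
  append_suffix(kernel_name, storage);       // add_storage_type_suffix
  append_suffix(kernel_name, tensor_dtype);  // add_dtype_suffix(dtype_of)
  append_suffix(kernel_name, staging_dtype); // new suffix added in this PR
  return kernel_name;
}
```

Under this scheme, a half-precision weight on a device without VK_KHR_16bit_storage would select the `texture2d, half, float` variant declared in the YAML.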
