From 89abb6ff14f6696454db811d473b48d21a9be63a Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 9 Sep 2025 13:30:10 -0700 Subject: [PATCH 1/3] [ET-VK] Implement SDPA with fused ops ## Context As title; optimize the SDPA operator by introducing shaders to perform the operation in 3 steps: 1. Compute attention weights, multiplying QT x K_cache, and applying scale and mask 2. Compute softmax normalization of computed attention weights 3. Compute final output by multiplying attention weights with V cache This new implementation is much more efficient than the existing one, which performed slicing, repeat_interleave, and transposition of projected and cache tensors as separate steps. The fusion of scale and mask with the computation of attention weights also allows for the computation of elements within the mask region to be skipped. ## Impact Decode latency for LLMs is much improved. For llama 3.2 3B generating ~250 tokens, decode latency increases from ~15 tok/s to ~21.5 tok/s Differential Revision: [D82053493](https://our.internmc.facebook.com/intern/diff/D82053493/) [ghstack-poisoned] --- .github/workflows/pull.yml | 4 + backends/vulkan/op_registry.py | 2 +- .../ops/glsl/flash_attention_buffer.glsl | 227 ------- .../ops/glsl/flash_attention_texture3d.glsl | 332 --------- .../graph/ops/glsl/kv_cache_update.glsl | 80 --- .../glsl/sdpa_attn_weight_scale_and_mask.glsl | 120 ---- .../glsl/sdpa_attn_weight_scale_and_mask.yaml | 13 - .../ops/glsl/sdpa_attn_weights_softmax.glsl | 164 +++++ ...te.yaml => sdpa_attn_weights_softmax.yaml} | 9 +- .../glsl/sdpa_compute_attn_weights_coop.glsl | 213 ++++++ .../glsl/sdpa_compute_attn_weights_coop.yaml | 21 + .../glsl/sdpa_compute_attn_weights_tiled.glsl | 203 ++++++ .../glsl/sdpa_compute_attn_weights_tiled.yaml | 22 + .../graph/ops/glsl/sdpa_compute_out_coop.glsl | 195 ++++++ ...ture3d.yaml => sdpa_compute_out_coop.yaml} | 12 +- .../ops/glsl/sdpa_compute_out_tiled.glsl | 165 +++++ .../ops/glsl/sdpa_compute_out_tiled.yaml | 22 + .../glsl/sdpa_fp_attn_weight_tile_load.glslh | 74 ++ .../glsl/sdpa_fp_attn_weight_tile_store.glslh | 105 +++ .../ops/glsl/sdpa_fp_k_cache_tile_load.glslh | 93 +++ .../ops/glsl/sdpa_fp_out_tile_store.glslh | 57 ++ .../glsl/sdpa_fp_q_projected_tile_load.glslh | 74 ++ .../ops/glsl/sdpa_fp_v_cache_tile_load.glslh | 76 +++ .../graph/ops/glsl/sdpa_kv_cache_update.glsl | 90 +++ ..._buffer.yaml => sdpa_kv_cache_update.yaml} | 10 +- .../glsl/sdpa_q_projected_input_tile.glslh | 42 ++ .../vulkan/runtime/graph/ops/impl/SDPA.cpp | 641 +++++++++--------- backends/vulkan/test/op_tests/CMakeLists.txt | 2 +- backends/vulkan/test/op_tests/sdpa_test.cpp | 497 +++----------- backends/vulkan/test/scripts/test_op.sh | 62 +- 30 files changed, 2080 insertions(+), 1547 deletions(-) delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.glsl delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.glsl delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.glsl delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.glsl delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl rename backends/vulkan/runtime/graph/ops/glsl/{kv_cache_update.yaml => sdpa_attn_weights_softmax.yaml} (82%) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl rename backends/vulkan/runtime/graph/ops/glsl/{flash_attention_texture3d.yaml => sdpa_compute_out_coop.yaml} (56%) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl rename backends/vulkan/runtime/graph/ops/glsl/{flash_attention_buffer.yaml => sdpa_kv_cache_update.yaml} (61%) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/sdpa_q_projected_input_tile.glslh diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 68d7f90d09c..cb0a3d9b679 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -971,6 +971,10 @@ jobs: ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row + # "Classic" Operator tests + PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build + ./cmake-out/backends/vulkan/test/op_tests/vulkan_sdpa_test + # Run e2e testing for selected operators. More operators will be tested via this # route in the future. python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*" diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 934e02eb7be..9f1561fb05e 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -571,7 +571,7 @@ def register_sdpa_with_kv_cache_op(): ) def register_sdpa_ops(): return OpFeatures( - inputs_storage=utils.WIDTH_PACKED_TEXTURE, + inputs_storage=utils.CONTIGUOUS_ANY, supports_resize=True, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.glsl deleted file mode 100644 index 8509fdf1f49..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.glsl +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} -#define T ${buffer_scalar_type(DTYPE)} -${define_required_extensions(DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -// Flash Attention inputs: Query, Key, Value tensors -${layout_declare_tensor(B, "rw", "t_O", DTYPE, "buffer")} -${layout_declare_tensor(B, "rw", "t_l", "float", "buffer")} -${layout_declare_tensor(B, "rw", "t_m", "float", "buffer")} -${layout_declare_tensor(B, "r", "t_Q", DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_K", DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_V", DTYPE, "buffer")} - -${layout_declare_ubo(B, "ivec4", "Q_sizes")} // [B, H, N, D] -${layout_declare_ubo(B, "ivec4", "K_sizes")} -${layout_declare_ubo(B, "ivec4", "V_sizes")} -${layout_declare_ubo(B, "ivec4", "O_sizes")} - -${layout_declare_ubo(B, "ivec3", "l_sizes")} // [B, H, N] -${layout_declare_ubo(B, "ivec3", "m_sizes")} // [B, H, N] - -${layout_declare_ubo(B, "float", "scale")} -${layout_declare_ubo(B, "int", "block_size_r")} // Br (num rows in Q block) -${layout_declare_ubo(B, "int", "block_size_c")} // Bc (num cols in K/V block) -${layout_declare_ubo(B, "int", "input_pos")} // Starting position for causal masking -${layout_declare_ubo(B, "int", "num_heads")} // Number of query heads -${layout_declare_ubo(B, "int", "num_kv_heads")} // Number of key/value heads -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -// Maximum block sizes to prevent array overflow -#define MAX_BR 64 -#define MAX_BC 128 - -void main() { - // Each thread processes one row block - const int thread_id = int(gl_GlobalInvocationID.x); - - // Tensor dimensions: Q_sizes = [D, H, N, B] from graph.sizes_ubo() - // The UBO layout is different from the PyTorch tensor layout - const int head_dim = Q_sizes.x; // D (head dim) - const int num_heads = Q_sizes.y; // H (num heads) - const int seq_len = Q_sizes.z; // N (sequence length) - const int batch_size = Q_sizes.w; // B (batch) - - // Block sizes - const int Br = block_size_r; - const int Bc = block_size_c; - - const int Tr = (seq_len + Br - 1) / Br; // Number of row blocks - const int total_row_blocks = batch_size * num_heads * Tr; - - if (thread_id >= total_row_blocks) { - return; - } - - // Decode thread_id to (batch, head, row_block) - const int batch = thread_id / (num_heads * Tr); - const int remaining = thread_id % (num_heads * Tr); - const int head = remaining / Tr; - const int row_block = remaining % Tr; - - // Calculate row range for this block - const int row_start = row_block * Br; - const int row_end = min(row_start + Br, seq_len); - const int actual_Br = row_end - row_start; - - // Base indices for this batch - const int q_base = batch * (seq_len * num_heads * head_dim); - const int k_base = batch * (seq_len * num_heads * head_dim); - const int v_base = batch * (seq_len * num_heads * head_dim); - const int o_base = batch * (seq_len * num_heads * head_dim); - const int lm_base = batch * (seq_len * num_heads); - - // STEP 2: Initialize O = 0, l = 0, m = -inf for this row block - for (int r = 0; r < actual_Br; r++) { - const int seq_pos = row_start + r; - const int lm_idx = lm_base + head * seq_len + seq_pos; - - t_l[lm_idx] = 0.0; - t_m[lm_idx] = -1.0 / 0.0; // -infinity - - for (int dim = 0; dim < head_dim; dim++) { - const int o_idx = o_base + seq_pos * (num_heads * head_dim) + head * head_dim + dim; - t_O[o_idx] = T(0.0); - } - } - - // STEP 5: Outer loop over column blocks (For K, V tensors) - const int Tc = (seq_len + Bc - 1) / Bc; // Number of column blocks - for (int j = 0; j < Tc; j++) { - const int col_start = j * Bc; - const int col_end = min(col_start + Bc, seq_len); - const int actual_Bc = col_end - col_start; - - // STEP 6-8 done implicitly below - - // Load current statistics for all rows in this block - float m_i[MAX_BR]; - float l_i[MAX_BR]; - for (int r = 0; r < actual_Br; r++) { - const int seq_pos = row_start + r; - const int lm_idx = lm_base + head * seq_len + seq_pos; - m_i[r] = t_m[lm_idx]; - l_i[r] = t_l[lm_idx]; - } - - // STEP 9: Compute Sij = Qi * Kj^T - T S_block[MAX_BR][MAX_BC]; // Use MAX_BR and MAX_BC constants - float m_tilde_ij[MAX_BR]; // Row maxes (float to match l/m) - float l_tilde_ij[MAX_BR]; // Row sums (float to match l/m) - - // Initialize row statistics - for (int r = 0; r < actual_Br; r++) { - m_tilde_ij[r] = -1.0 / 0.0; // -infinity - l_tilde_ij[r] = 0.0; - } - - // Compute attention scores Sij = Qi @ Kj^T - for (int r = 0; r < actual_Br; r++) { - const int global_row = row_start + r; - for (int c = 0; c < actual_Bc; c++) { - const int global_col = col_start + c; - - // For multi-query attention: map query head to KV head - const int kv_head = (head * num_kv_heads) / num_heads; - - // Dot product: Q[seq_pos, :] · K[col_pos, :] - T score = T(0.0); - for (int dim = 0; dim < head_dim; dim++) { - const int q_idx = q_base + global_row * (num_heads * head_dim) + head * head_dim + dim; - const int k_idx = k_base + global_col * (num_kv_heads * head_dim) + kv_head * head_dim + dim; - score += t_Q[q_idx] * t_K[k_idx]; - } - score *= scale; - - - // Apply causal masking: mask if global_col > global_row + input_pos - if (global_col > global_row + input_pos) { - score = T(-1.0 / 0.0); // Set to negative infinity - } - - S_block[r][c] = score; - - // Track row maximum (after masking) - m_tilde_ij[r] = max(m_tilde_ij[r], float(score)); - } - } - - // STEP 10: Compute P'ij = exp(Sij − m'ij) and l'ij = rowsum(P'ij) - for (int r = 0; r < actual_Br; r++) { - // Handle the case where all scores are -inf (fully masked row) - if (isinf(m_tilde_ij[r]) && m_tilde_ij[r] < 0.0) { - // All scores are -inf, so all probabilities are 0 - for (int c = 0; c < actual_Bc; c++) { - S_block[r][c] = T(0.0); - } - l_tilde_ij[r] = 0.0; - } else { - // Normal case: compute softmax - for (int c = 0; c < actual_Bc; c++) { - S_block[r][c] = exp(S_block[r][c] - T(m_tilde_ij[r])); - l_tilde_ij[r] += float(S_block[r][c]); - } - } - } - - // STEP 11: Softmax update - float m_new_i[MAX_BR]; - float l_new_i[MAX_BR]; - for (int r = 0; r < actual_Br; r++) { - m_new_i[r] = max(m_i[r], m_tilde_ij[r]); - - l_new_i[r] = exp(m_i[r] - m_new_i[r]) * l_i[r] + exp(m_tilde_ij[r] - m_new_i[r]) * l_tilde_ij[r]; - } - - // STEP 12: Update Oi - for (int r = 0; r < actual_Br; r++) { - const int global_row = row_start + r; - float alpha = exp(m_i[r] - m_new_i[r]); - float beta = exp(m_tilde_ij[r] - m_new_i[r]); - - // For multi-query attention: map query head to KV head - const int kv_head = (head * num_kv_heads) / num_heads; - - for (int dim = 0; dim < head_dim; dim++) { - const int o_idx = o_base + global_row * (num_heads * head_dim) + head * head_dim + dim; - - // Compute P'ij @ Vj for this dimension - T pv_sum = T(0.0); - for (int c = 0; c < actual_Bc; c++) { - const int global_col = col_start + c; - const int v_idx = v_base + global_col * (num_kv_heads * head_dim) + kv_head * head_dim + dim; - pv_sum += S_block[r][c] * t_V[v_idx]; - } - - // Check for division by zero before updating output - if (l_new_i[r] <= 0.0) { - t_O[o_idx] = T(0.0); // Set to zero to avoid NaN - } else { - // Oi = (alpha * l_i * Oi + beta * P'ij @ Vj) / l_new_i - t_O[o_idx] = (T(alpha) * T(l_i[r]) * t_O[o_idx] + T(beta) * pv_sum) / T(l_new_i[r]); - } - } - } - - // STEP 13: Update li, mi - for (int r = 0; r < actual_Br; r++) { - const int seq_pos = row_start + r; - const int lm_idx = lm_base + head * seq_len + seq_pos; - t_l[lm_idx] = l_new_i[r]; - t_m[lm_idx] = m_new_i[r]; - } - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.glsl deleted file mode 100644 index 1f72a583410..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.glsl +++ /dev/null @@ -1,332 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} -#define T ${buffer_scalar_type(DTYPE)} -${define_required_extensions(DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -// Flash Attention inputs: Query, Key, Value tensors using texture storage -${layout_declare_tensor(B, "rw", "t_O", DTYPE, "texture3d")} -${layout_declare_tensor(B, "rw", "t_l", "float", "texture3d")} -${layout_declare_tensor(B, "rw", "t_m", "float", "texture3d")} -${layout_declare_tensor(B, "r", "t_Q", DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "t_K", DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "t_V", DTYPE, "texture3d")} - -${layout_declare_ubo(B, "ivec4", "Q_sizes")} // [B, H, N, D] -${layout_declare_ubo(B, "ivec4", "K_sizes")} -${layout_declare_ubo(B, "ivec4", "V_sizes")} -${layout_declare_ubo(B, "ivec4", "O_sizes")} - -${layout_declare_ubo(B, "ivec3", "l_sizes")} // [B, H, N] -${layout_declare_ubo(B, "ivec3", "m_sizes")} // [B, H, N] - -${layout_declare_ubo(B, "float", "scale")} -${layout_declare_ubo(B, "int", "block_size_r")} // Br (num rows in Q block) -${layout_declare_ubo(B, "int", "block_size_c")} // Bc (num cols in K/V block) -${layout_declare_ubo(B, "int", "input_pos")} // Starting position for causal masking -${layout_declare_ubo(B, "int", "num_heads")} // Number of query heads -${layout_declare_ubo(B, "int", "num_kv_heads")} // Number of key/value heads - -// Axis mapping setup for proper texture indexing -${layout_declare_spec_const(C, "int", "Q_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 Q_axis_map = unhash_axis_map(Q_layout); -const lowp int Q_packed_dim = unhash_packed_dim(Q_layout); - -${layout_declare_spec_const(C, "int", "K_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 K_axis_map = unhash_axis_map(K_layout); -const lowp int K_packed_dim = unhash_packed_dim(K_layout); - -${layout_declare_spec_const(C, "int", "V_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 V_axis_map = unhash_axis_map(V_layout); -const lowp int V_packed_dim = unhash_packed_dim(V_layout); - -${layout_declare_spec_const(C, "int", "O_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 O_axis_map = unhash_axis_map(O_layout); -const lowp int O_packed_dim = unhash_packed_dim(O_layout); - -${layout_declare_spec_const(C, "int", "l_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 l_axis_map = unhash_axis_map(l_layout); -const lowp int l_packed_dim = unhash_packed_dim(l_layout); - -${layout_declare_spec_const(C, "int", "m_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 m_axis_map = unhash_axis_map(m_layout); -const lowp int m_packed_dim = unhash_packed_dim(m_layout); - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -// Maximum block sizes to prevent array overflow -#define MAX_BR 64 -#define MAX_BC 128 - -// Texture access helper functions using proper axis mapping -// Q_sizes, K_sizes, V_sizes, O_sizes are [D, H, N, B] (UBO layout) -// l_sizes, m_sizes are [B, H, N] (UBO layout) -T load_tensor_Q(int batch, int seq_pos, int head, int dim) { - ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order - ivec3 pos = tidx_to_pos(tidx, Q_sizes, Q_axis_map, Q_packed_dim); - int component = tidx[Q_packed_dim] % 4; - vec4 texel = texelFetch(t_Q, pos, 0); - return T(texel[component]); -} - -T load_tensor_K(int batch, int seq_pos, int head, int dim) { - ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order - ivec3 pos = tidx_to_pos(tidx, K_sizes, K_axis_map, K_packed_dim); - int component = tidx[K_packed_dim] % 4; - vec4 texel = texelFetch(t_K, pos, 0); - return T(texel[component]); -} - -T load_tensor_V(int batch, int seq_pos, int head, int dim) { - ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order - ivec3 pos = tidx_to_pos(tidx, V_sizes, V_axis_map, V_packed_dim); - int component = tidx[V_packed_dim] % 4; - vec4 texel = texelFetch(t_V, pos, 0); - return T(texel[component]); -} - -T load_tensor_O(int batch, int seq_pos, int head, int dim) { - ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order - ivec3 pos = tidx_to_pos(tidx, O_sizes, O_axis_map, O_packed_dim); - int component = tidx[O_packed_dim] % 4; - vec4 texel = imageLoad(t_O, pos); - return T(texel[component]); -} - -void store_tensor_O(int batch, int seq_pos, int head, int dim, T value) { - ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order - ivec3 pos = tidx_to_pos(tidx, O_sizes, O_axis_map, O_packed_dim); - int component = tidx[O_packed_dim] % 4; - vec4 texel = imageLoad(t_O, pos); - texel[component] = float(value); - imageStore(t_O, pos, texel); -} - -float load_tensor_l(int batch, int head, int seq_pos) { - ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding) - ivec3 pos = tidx_to_pos(tidx, ivec4(l_sizes, 1), l_axis_map, l_packed_dim); - int component = tidx[l_packed_dim] % 4; - vec4 texel = imageLoad(t_l, pos); - return texel[component]; -} - -void store_tensor_l(int batch, int head, int seq_pos, float value) { - ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding) - ivec3 pos = tidx_to_pos(tidx, ivec4(l_sizes, 1), l_axis_map, l_packed_dim); - int component = tidx[l_packed_dim] % 4; - vec4 texel = imageLoad(t_l, pos); - texel[component] = value; - imageStore(t_l, pos, texel); -} - -float load_tensor_m(int batch, int head, int seq_pos) { - ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding) - ivec3 pos = tidx_to_pos(tidx, ivec4(m_sizes, 1), m_axis_map, m_packed_dim); - int component = tidx[m_packed_dim] % 4; - vec4 texel = imageLoad(t_m, pos); - return texel[component]; -} - -void store_tensor_m(int batch, int head, int seq_pos, float value) { - ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding) - ivec3 pos = tidx_to_pos(tidx, ivec4(m_sizes, 1), m_axis_map, m_packed_dim); - int component = tidx[m_packed_dim] % 4; - vec4 texel = imageLoad(t_m, pos); - texel[component] = value; - imageStore(t_m, pos, texel); - -} - -void main() { - // Each thread processes one row block - same as buffer version - const int thread_id = int(gl_GlobalInvocationID.x); - - // Tensor dimensions: Q_sizes = [D, H, N, B] - const int head_dim = Q_sizes.x; // D (head dim) - const int num_heads_val = Q_sizes.y; // H (num heads) - const int seq_len = Q_sizes.z; // N (sequence length) - const int batch_size = Q_sizes.w; // B (batch) - - // Block sizes - const int Br = block_size_r; - const int Bc = block_size_c; - - const int Tr = (seq_len + Br - 1) / Br; // Number of row blocks - const int total_row_blocks = batch_size * num_heads_val * Tr; - - if (thread_id >= total_row_blocks) { - return; - } - - // Decode thread_id to (batch, head, row_block) - const int batch = thread_id / (num_heads_val * Tr); - const int remaining = thread_id % (num_heads_val * Tr); - const int head = remaining / Tr; - const int row_block = remaining % Tr; - - // Calculate row range for this block - const int row_start = row_block * Br; - const int row_end = min(row_start + Br, seq_len); - const int actual_Br = row_end - row_start; - - // STEP 1: Initialize only this thread's row block - // Each thread initializes its own rows to avoid cross-workgroup synchronization issues - for (int r = 0; r < actual_Br; r++) { - const int seq_pos = row_start + r; - - // Initialize l and m textures for this row block's positions - ivec4 l_tidx = ivec4(batch, head, seq_pos, 0); - ivec3 l_pos = tidx_to_pos(l_tidx, ivec4(l_sizes, 1), l_axis_map, l_packed_dim); - vec4 l_texel = vec4(0.0); - imageStore(t_l, l_pos, l_texel); - - ivec4 m_tidx = ivec4(batch, head, seq_pos, 0); - ivec3 m_pos = tidx_to_pos(m_tidx, ivec4(m_sizes, 1), m_axis_map, m_packed_dim); - vec4 m_texel = vec4(-1e10); - imageStore(t_m, m_pos, m_texel); - - // Initialize output tensor for this row block - for (int dim = 0; dim < head_dim; dim++) { - store_tensor_O(batch, seq_pos, head, dim, T(0.0)); - } - } - - // STEP 5: Outer loop over column blocks (For K, V tensors) - const int Tc = (seq_len + Bc - 1) / Bc; // Number of column blocks - for (int j = 0; j < Tc; j++) { - const int col_start = j * Bc; - const int col_end = min(col_start + Bc, seq_len); - const int actual_Bc = col_end - col_start; - - // Load current statistics for all rows in this block - float m_i[MAX_BR]; - float l_i[MAX_BR]; - for (int r = 0; r < actual_Br; r++) { - const int seq_pos = row_start + r; - m_i[r] = load_tensor_m(batch, head, seq_pos); - l_i[r] = load_tensor_l(batch, head, seq_pos); - } - - // STEP 9: Compute Sij = Qi * Kj^T - T S_block[MAX_BR][MAX_BC]; - float m_tilde_ij[MAX_BR]; // Row maxes - float l_tilde_ij[MAX_BR]; // Row sums - - // Initialize row statistics - for (int r = 0; r < actual_Br; r++) { - m_tilde_ij[r] = -1.0 / 0.0; // -infinity - l_tilde_ij[r] = 0.0; - } - - // Compute attention scores Sij = Qi @ Kj^T - for (int r = 0; r < actual_Br; r++) { - const int global_row = row_start + r; - for (int c = 0; c < actual_Bc; c++) { - const int global_col = col_start + c; - - // For multi-query attention: map query head to KV head - const int kv_head = (head * num_kv_heads) / num_heads_val; - - // Dot product: Q[seq_pos, :] · K[col_pos, :] - T score = T(0.0); - for (int dim = 0; dim < head_dim; dim++) { - T q_val = load_tensor_Q(batch, global_row, head, dim); - T k_val = load_tensor_K(batch, global_col, kv_head, dim); - score += q_val * k_val; - } - score *= scale; - - - // Apply causal masking: mask if global_col > global_row + input_pos - bool masked = (global_col > global_row + input_pos); - if (masked) { - score = T(-1.0 / 0.0); // Set to negative infinity - } - - S_block[r][c] = score; - - - // Track row maximum (after masking) - m_tilde_ij[r] = max(m_tilde_ij[r], float(score)); - } - } - - // STEP 10: Compute P'ij = exp(Sij − m'ij) and l'ij = rowsum(P'ij) - for (int r = 0; r < actual_Br; r++) { - // Handle the case where all scores are -inf (fully masked row) - if (isinf(m_tilde_ij[r]) && m_tilde_ij[r] < 0.0) { - // All scores are -inf, so all probabilities are 0 - for (int c = 0; c < actual_Bc; c++) { - S_block[r][c] = 0.0; - } - l_tilde_ij[r] = 0.0; - } else { - // Normal case: compute softmax - for (int c = 0; c < actual_Bc; c++) { - S_block[r][c] = exp(S_block[r][c] - T(m_tilde_ij[r])); - l_tilde_ij[r] += float(S_block[r][c]); - } - } - } - - // STEP 11: Softmax update - float m_new_i[MAX_BR]; - float l_new_i[MAX_BR]; - for (int r = 0; r < actual_Br; r++) { - m_new_i[r] = max(m_i[r], m_tilde_ij[r]); - l_new_i[r] = exp(m_i[r] - m_new_i[r]) * l_i[r] + exp(m_tilde_ij[r] - m_new_i[r]) * l_tilde_ij[r]; - - } - - // STEP 12: Update Oi - for (int r = 0; r < actual_Br; r++) { - const int global_row = row_start + r; - float alpha = exp(m_i[r] - m_new_i[r]); - float beta = exp(m_tilde_ij[r] - m_new_i[r]); - - // For multi-query attention: map query head to KV head - const int kv_head = (head * num_kv_heads) / num_heads_val; - - for (int dim = 0; dim < head_dim; dim++) { - // Compute P'ij @ Vj for this dimension - T pv_sum = T(0.0); - for (int c = 0; c < actual_Bc; c++) { - const int global_col = col_start + c; - T v_val = load_tensor_V(batch, global_col, kv_head, dim); - pv_sum += S_block[r][c] * v_val; - } - - // Check for division by zero before updating output - if (l_new_i[r] <= 0.0) { - store_tensor_O(batch, global_row, head, dim, T(0.0)); - } else { - // Oi = (alpha * l_i * Oi + beta * P'ij @ Vj) / l_new_i - T current_o = load_tensor_O(batch, global_row, head, dim); - T new_o = (T(alpha) * T(l_i[r]) * current_o + T(beta) * pv_sum) / T(l_new_i[r]); - store_tensor_O(batch, global_row, head, dim, new_o); - - } - } - } - - // STEP 13: Update li, mi - for (int r = 0; r < actual_Br; r++) { - const int seq_pos = row_start + r; - store_tensor_l(batch, head, seq_pos, l_new_i[r]); - store_tensor_m(batch, head, seq_pos, m_new_i[r]); - } - - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.glsl b/backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.glsl deleted file mode 100644 index 8028362c3e5..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.glsl +++ /dev/null @@ -1,80 +0,0 @@ -#version 450 core - -#define PRECISION ${PRECISION} - -#define T ${buffer_scalar_type(DTYPE)} - -${define_active_storage_type(STORAGE)} -${define_required_extensions(DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "cache", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "projected", DTYPE, STORAGE)} -$if STORAGE == "buffer": - ${layout_declare_ubo(B, "int", "projected_numel")} - ${layout_declare_ubo(B, "ivec4", "cache_strides")} - ${layout_declare_ubo(B, "int", "input_pos")} -$else: - ${layout_declare_ubo(B, "ivec3", "projected_limits")} - ${layout_declare_ubo(B, "int", "input_pos")} - - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -/* - * t_cache will have sizes of (max_batch_size, max_seq_len, n_heads, head_dim). - * t_projected will have sizes of (batch_size, seq_len, n_heads, head_dim). - * - * The cache update inserts the values of t_projected into t_cache at the index - * specified by input_pos at the seq_len dimension. It is equivalent to calling - - * t_cache = t_cache.slice_scatter( - * t_projected, dim=1, start=input_pos, end=input_pos+seq_len) - * - * Note that this shader is implemented assuming that max_batch_size is 1. - */ - -#ifdef USING_BUFFER - -/*************************** - ** Buffer Implementation ** - ***************************/ - -void main() { - int projected_bufi = int(gl_GlobalInvocationID.x); - // Bump cache index forward by input_pos elements along the seq_len dimension. - // cache_strides contains the strides of the cache tensor. - int cache_bufi = input_pos * cache_strides.z + projected_bufi; - if (projected_bufi >= projected_numel) { - return; - } - cache[cache_bufi] = projected[projected_bufi]; -} - -#else - -/**************************** - ** Texture Implementation ** - ****************************/ - -// Note that this shader assumes the that tensors are width packed, i.e. -// packed_dim = 0 -void main() { - const ivec3 projected_pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(projected_pos, projected_limits))) { - return; - } - - const ivec3 cache_pos = ivec3( - projected_pos.x, - projected_pos.y, - projected_pos.z + input_pos); - - write_texel(cache, cache_pos, load_texel(projected, projected_pos)); -} - -#endif // USING_BUFFER diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.glsl deleted file mode 100644 index 1e854bf7f85..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.glsl +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define T ${buffer_scalar_type(DTYPE)} - -${define_active_storage_type(STORAGE)} -${define_required_extensions(DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -${layout_declare_tensor(B, "rw", "attn_weight", DTYPE, STORAGE)} - -$if STORAGE == "buffer": - ${layout_declare_ubo(B, "ivec4", "attn_weight_sizes")} - ${layout_declare_ubo(B, "ivec4", "attn_weight_strides")} -$else: - ${layout_declare_ubo(B, "ivec3", "attn_weight_limits")} - -${layout_declare_ubo(B, "int", "input_pos")} -${layout_declare_ubo(B, "float", "scale")} - - -#include "indexing_utils.h" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -// Negative infinity is represented by having sign bit be 1, all exponent bits -// be 1, all mantissa bits be 0. -#define NEGATIVE_INF_BITS 0xFF800000 -const float negative_infinity = NEGATIVE_INF_BITS; - -#ifdef USING_BUFFER - -/* - * This implementations applies a scale and mask to the attention weight tensor - * of an SDPA block. The sizes of the attention weight is - * (batch_size, n_heads, seq_len, input_pos + seq_len) - * Conceptually the weights represent the relationship between each token in the - * sequence with each token preceding it. - * - * The scale applied is 1.0 / sqrt(head_dim_length) - * - * The mask applied is a bit more complicated. Imagine you create a square - * matrix of size (input_pos + seq_len, input_pos + seq_len), and then set the - * lower triangular section of the matrix to -inf. Then, slice the matrix along - * the row dimension starting from input_pos to input_pos + seq_len. You end up - * with a partial mask with size (seq_len, input_pos + seq_len). This is the - * mask that is applied to the attention weight. - * - * In the shader, instead of generating the mask, the index of the elment is - * inspected to determine if it would have been masked. Given an element at - * tensor index (n, c, h, w), it would be masked if w < h + input_pos. - */ - -/*************************** - ** Buffer Implementation ** - ***************************/ - -void main() { - const ivec4 attn_weight_idx = ivec4( - gl_GlobalInvocationID.x, - gl_GlobalInvocationID.y, - gl_GlobalInvocationID.z, - 0); - - if (any(greaterThanEqual(attn_weight_idx, attn_weight_sizes))) { - return; - } - - const T scale_conv = T(scale); - - const int attn_weight_id = tidx_to_bufi(attn_weight_idx, attn_weight_strides); - if (attn_weight_idx.x <= attn_weight_idx.y + input_pos) { - attn_weight[attn_weight_id] = attn_weight[attn_weight_id] * scale_conv; - } else { - attn_weight[attn_weight_id] = T(negative_infinity); - } -} - -#else - -/**************************** - ** Texture Implementation ** - ****************************/ - -/* - * This implementation assumes that the attention weight is width packed, i.e. - * the packed dim of the attn_weight is 0. - */ -void main() { - const ivec3 attn_weight_pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(attn_weight_pos, attn_weight_limits))) { - return; - } - - vec4 outtex = imageLoad(attn_weight, attn_weight_pos) * scale; - - // Mask out the upper triangular of attn_weight to -inf - [[unroll]] for (int i = 0; i < 4; ++i) { - if (attn_weight_pos.x * 4 + i > attn_weight_pos.y + input_pos) { - outtex[i] = negative_infinity; - } - } - - write_texel(attn_weight, attn_weight_pos, outtex); -} - -#endif // USING_BUFFER diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.yaml deleted file mode 100644 index ca8806fe000..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.yaml +++ /dev/null @@ -1,13 +0,0 @@ -sdpa_attn_weight_scale_and_mask: - parameter_names_with_default_values: - DTYPE: float - STORAGE: buffer - generate_variant_forall: - STORAGE: - - VALUE: buffer - - VALUE: texture3d - DTYPE: - - VALUE: half - - VALUE: float - shader_variants: - - NAME: sdpa_attn_weight_scale_and_mask diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl new file mode 100644 index 00000000000..1dff0017f30 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl @@ -0,0 +1,164 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +#define NUM_WORKERS_PER_WG 64 + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_attn_weights_softmax", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_attn_weights", DTYPE, STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} +${layout_declare_ubo(B, "int", "input_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Shared memory for cooperative exp sum finding +shared T shared_exp_sum[NUM_WORKERS_PER_WG]; + +VEC4_T load_attn_weights_c4( + const int c4, + const int s, + const int q_h, + const int C4, + const int S, + const int Q_H) { +#ifdef USING_BUFFER + return t_attn_weights[(q_h * S * C4) + (s * C4) + c4]; +#else + return texelFetch(t_attn_weights, ivec3(c4, s, q_h), 0); +#endif +} + +void store_attn_weights_softmax_c4( + const VEC4_T out_texel, + const int c4, + const int s, + const int q_h, + const int C4, + const int S, + const int Q_H) { +#ifdef USING_BUFFER + t_attn_weights_softmax[(q_h * S * C4) + (s * C4) + c4] = out_texel; +#else + imageStore(t_attn_weights_softmax, ivec3(c4, s, q_h), out_texel); +#endif +} + +void main() { + const int worker_id = int(gl_LocalInvocationID.x); + + // Index along attention weight's sequence_len dim + const int s = int(gl_GlobalInvocationID.y); + // idx along attention weight's num_q_heads dim + const int q_h = int(gl_GlobalInvocationID.z); + + // number of Q heads + const int Q_H = q_projected_sizes.y; + // sequence length + const int S = q_projected_sizes.z; + // manually determine size of the context_len dim of the attention weight. + // The "actual" tensor sizes may have been aligned to a multiple of 4 to allow + // memory loads to be aligned to texel boundaries. + const int context_len = input_pos + S; + const int context_texel_len = div_up_4(context_len); + + if (s >= S || q_h >= Q_H) { + return; + } + + // Initialize thread-local min/max + T local_exp_sum = 0; + + const int context_len_aligned_down = context_len - mod_4(context_len); + const int C4_limit = div_4(context_len_aligned_down); + + // Each thread processes elements along a context_len row with a stride of the + // number of threads in the work group. + for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { + VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, S, Q_H); + + for (int comp = 0; comp < 4; comp++) { + local_exp_sum += exp(in_texel[comp]); + } + } + // First thread in the work group responsible for handling last texel if it + // contains any padded elements + if (worker_id == 0) { + for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { + const int c_base = mul_4(c4); + VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, S, Q_H); + + [[unroll]] for (int comp = 0; comp < 4; comp++) { + if (c_base + comp < context_len) { + local_exp_sum += exp(in_texel[comp]); + } + } + } + } + + // Store thread-local results in shared memory + shared_exp_sum[worker_id] = local_exp_sum; + + memoryBarrierShared(); + barrier(); + + // Tree reduction to compute the overall result + for (int i = NUM_WORKERS_PER_WG / 2; i > 0; i >>= 1) { + if (worker_id < i) { + shared_exp_sum[worker_id] = shared_exp_sum[worker_id] + + shared_exp_sum[worker_id + i]; + } + memoryBarrierShared(); + barrier(); + } + + local_exp_sum = shared_exp_sum[0]; + // Now go back through each element in the row and normalize + for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { + VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, S, Q_H); + + VEC4_T out_texel = exp(in_texel) / local_exp_sum; + store_attn_weights_softmax_c4( + out_texel, c4, s, q_h, context_texel_len, S, Q_H); + } + // First thread in the work group responsible for handling last texel if it + // contains any padded elements + if (worker_id == 0) { + for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { + const int c_base = mul_4(c4); + VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, S, Q_H); + + // Ensure that padding elements are set to 0. + VEC4_T out_texel = VEC4_T(0); + [[unroll]] for (int comp = 0; comp < 4; comp++) { + if (c_base + comp < context_len) { + out_texel[comp] = exp(in_texel[comp]) / local_exp_sum; + } + } + store_attn_weights_softmax_c4( + out_texel, c4, s, q_h, context_texel_len, S, Q_H); + } + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml similarity index 82% rename from backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.yaml rename to backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml index e2a96234465..8abf50399e0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml @@ -4,16 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -kv_cache_update: +sdpa_attn_weights_softmax: parameter_names_with_default_values: DTYPE: float - STORAGE: buffer + STORAGE: texture3d generate_variant_forall: STORAGE: - - VALUE: buffer - VALUE: texture3d + - VALUE: buffer DTYPE: - - VALUE: half - VALUE: float shader_variants: - - NAME: kv_cache_update + - NAME: sdpa_attn_weights_softmax diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl new file mode 100644 index 00000000000..4b7e3e0ddd2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl @@ -0,0 +1,213 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if K_CACHE_STORAGE == "buffer": + #define K_CACHE_BUFFER + +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M 1 +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +#define NUM_WORKERS_PER_OUT 64 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_q_projected", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_k_cache", DTYPE, K_CACHE_STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} +${layout_declare_ubo(B, "ivec4", "k_cache_sizes")} +${layout_declare_ubo(B, "int", "input_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "float", "inv_scale", "1.0")} + +#include "sdpa_fp_q_projected_tile_load.glslh" +#include "sdpa_fp_k_cache_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "sdpa_fp_attn_weight_tile_store.glslh" + +shared FPOutTile partial_sums[NUM_WORKERS_PER_OUT]; + +/* + * See the tiled variant of this shader for the implemented behavior. This + * shader is implements an optimization for cases where sequence length is 1; in + * these cases, the matrix multiplication being performed is akin to gemv, which + * benefits from using a co-operative algorithm for reduction. For this shader + * the entire work group co-operates to compute one reduction output. + */ + +#extension GL_EXT_debug_printf : enable + +void main() { + const int worker_id = int(gl_LocalInvocationID.y); + + const int tile_idx_x = int(gl_GlobalInvocationID.x); + // idx along output num_q_heads dim + const int q_h = int(gl_GlobalInvocationID.z); + + // idx along the output context_len dim + const int c = tile_idx_x * TILE_N; + const int c4 = div_4(c); + + // idx along the output seq_len dim. Note that for this shader seq_len will be + // 1. + const int s = 0; + + // texel size of head_dim, over which the dot product is accumulated + const int D4 = div_up_4(q_projected_sizes.x); + // number of Q heads + const int Q_H = q_projected_sizes.y; + // sequence length + const int S = q_projected_sizes.z; + + // number of K/V heads + const int KV_H = k_cache_sizes.y; + // Max context length + const int C = k_cache_sizes.z; + const int C4 = div_up_4(C); + + int kv_h = q_h; + if (KV_H < Q_H) { + kv_h = q_h / (Q_H / KV_H); + } + + // current context length + const int context_len = input_pos + S; + const int context_texel_len = div_up_4(context_len); + + // bounds check + if (c >= context_len || s >= S || q_h >= Q_H) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile q_tile; + FPWeightTile w_tile; + + // If the tile is completely inside the mask region, then there is no need to + // compute the output tile. All the elements in the output tile can be set to + // negative infinity. + bool tile_in_mask_region = c > (input_pos + s + (TILE_M - 1)); + if (tile_in_mask_region) { + const VEC4_T negative_infinity_vec = VEC4_T(negative_infinity_val); + set_out_tile_to_vec(out_tile, negative_infinity_vec); + } + // Otherwise, need to actually compute output tile + else { + const bool dont_check_bounds = (S - s) >= TILE_M && + (context_len - c) >= TILE_N; + + if (dont_check_bounds) { + for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { + load_q_projected_tile_no_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_no_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); + } + } else { + for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); + } + } + } + + partial_sums[worker_id] = out_tile; + + memoryBarrierShared(); + barrier(); + + // Tree reduction to compute the overall result. + for (int i = NUM_WORKERS_PER_OUT / 2; i > 0; i /= 2) { + if (worker_id < i) { + accumulate_out_tile_with_out_tile( + partial_sums[worker_id], partial_sums[worker_id + i]); + } + memoryBarrierShared(); + barrier(); + } + + // Only the first thread will write out the result + if (worker_id == 0) { + out_tile = partial_sums[0]; + // Apply scale and mask if the tile was not entirely in the mask region + if (!tile_in_mask_region) { + VEC4_T inv_scale_vec = VEC4_T(inv_scale); + apply_scale_and_mask( + out_tile, + inv_scale_vec, + input_pos, + c, + s); + } + + store_attn_weight_tile_with_checks( + out_tile, + c4, + s, + q_h, + context_texel_len, + S, + Q_H); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml new file mode 100644 index 00000000000..6a4cffcc913 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +sdpa_compute_attn_weights_coop: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + K_CACHE_STORAGE: texture3d + TILE_K4: 1 + TILE_N4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: sdpa_compute_attn_weights_coop_texture3d_texture3d + - NAME: sdpa_compute_attn_weights_coop_buffer_texture3d + IO_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl new file mode 100644 index 00000000000..577d7dea749 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl @@ -0,0 +1,203 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if K_CACHE_STORAGE == "buffer": + #define K_CACHE_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_q_projected", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_k_cache", DTYPE, K_CACHE_STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} +${layout_declare_ubo(B, "ivec4", "k_cache_sizes")} +${layout_declare_ubo(B, "int", "input_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "float", "inv_scale", "1.0")} + +#include "sdpa_fp_q_projected_tile_load.glslh" +#include "sdpa_fp_k_cache_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "sdpa_fp_attn_weight_tile_store.glslh" + +/* + * Compute attention weights given the q_projected and k_cache tensors. + * q_projected has shape (batches, seq_len, num_q_heads, head_dim) + * k_cache has shape (batches, max_context_len, num_kv_heads, head_dim) + * output has shape (batches, num_q_heads, seq_len, context_len) + * + * This shader also applies scales and masking to the computed attention + * weights. + * + * The scale applied is 1.0 / sqrt(head_dim_length). + * + * The mask applied is a bit more complicated. Imagine you create a square + * matrix of size (input_pos + seq_len, input_pos + seq_len), and then set the + * lower triangular section of the matrix to -inf. Then, slice the matrix along + * the row dimension starting from input_pos to input_pos + seq_len. You end up + * with a partial mask with size (seq_len, input_pos + seq_len). This is the + * mask that is applied to the attention weight. + * + * In the shader, instead of generating the mask, the index of the elment is + * inspected to determine if it would have been masked. Given an element at + * tensor index (n, c, h, w), it would be masked if w < h + input_pos. + * + */ + +#extension GL_EXT_debug_printf : enable + +void main() { + const int tile_idx_x = int(gl_GlobalInvocationID.x); + const int tile_idx_y = int(gl_GlobalInvocationID.y); + // idx along output num_q_heads dim + const int q_h = int(gl_GlobalInvocationID.z); + + // idx along the output context_len dim + const int c = tile_idx_x * TILE_N; + const int c4 = div_4(c); + + // idx along the output seq_len dim + const int s = tile_idx_y * TILE_M; + const int s4 = div_4(s); + + // texel size of head_dim, over which the dot product is accumulated + const int D4 = div_up_4(q_projected_sizes.x); + // number of Q heads + const int Q_H = q_projected_sizes.y; + // sequence length + const int S = q_projected_sizes.z; + + // number of K/V heads + const int KV_H = k_cache_sizes.y; + // Max context length + const int C = k_cache_sizes.z; + const int C4 = div_up_4(C); + + int kv_h = q_h; + if (KV_H < Q_H) { + kv_h = q_h / (Q_H / KV_H); + } + + const int context_len = input_pos + S; + const int context_texel_len = div_up_4(context_len); + + // bounds check + if (c >= context_len || s >= S || q_h >= Q_H) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile q_tile; + FPWeightTile w_tile; + + // If the tile is completely inside the mask region, then there is no need to + // compute the output tile. All the elements in the output tile can be set to + // negative infinity. + bool tile_in_mask_region = c > (input_pos + s + (TILE_M - 1)); + if (tile_in_mask_region) { + const VEC4_T negative_infinity_vec = VEC4_T(negative_infinity_val); + set_out_tile_to_vec(out_tile, negative_infinity_vec); + } + // Otherwise, need to actually compute output tile + else { + const bool dont_check_bounds = (S - s) >= TILE_M && + (context_len - c) >= TILE_N; + + if (dont_check_bounds) { + for (int d4 = 0; d4 < D4; d4++) { + load_q_projected_tile_no_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_no_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); + } + } else { + for (int d4 = 0; d4 < D4; d4++) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); + } + } + + // Apply scale and mask + VEC4_T inv_scale_vec = VEC4_T(inv_scale); + apply_scale_and_mask( + out_tile, + inv_scale_vec, + input_pos, + c, + s); + } + + store_attn_weight_tile_with_checks( + out_tile, + c4, + s, + q_h, + context_texel_len, + S, + Q_H); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml new file mode 100644 index 00000000000..6aadbbc379e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +sdpa_compute_attn_weights_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + K_CACHE_STORAGE: texture3d + TILE_M4: 1 + TILE_K4: 1 + TILE_N4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: sdpa_compute_attn_weights_tiled_texture3d_texture3d + - NAME: sdpa_compute_attn_weights_tiled_buffer_texture3d + IO_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl new file mode 100644 index 00000000000..1fdd803d02b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl @@ -0,0 +1,195 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if V_CACHE_STORAGE == "buffer": + #define V_CACHE_BUFFER + +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M 1 +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +#define NUM_WORKERS_PER_OUT 64 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_v_cache", DTYPE, V_CACHE_STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} +${layout_declare_ubo(B, "ivec4", "v_cache_sizes")} +${layout_declare_ubo(B, "int", "input_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "float", "inv_scale", "1.0")} + +#include "sdpa_fp_attn_weight_tile_load.glslh" +#include "sdpa_fp_v_cache_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "sdpa_fp_out_tile_store.glslh" + +shared FPOutTile partial_sums[NUM_WORKERS_PER_OUT]; + +/* + * See the tiled variant of this shader for the implemented behavior. This + * shader is implements an optimization for cases where sequence length is 1; in + * these cases, the matrix multiplication being performed is akin to gemv, which + * benefits from using a co-operative algorithm for reduction. For this shader + * the entire work group co-operates to compute one reduction output. + */ + +#extension GL_EXT_debug_printf : enable + +void main() { + const int worker_id = int(gl_LocalInvocationID.y); + + const int tile_idx_x = int(gl_GlobalInvocationID.x); + // idx along output num_q_heads dim + const int q_h = int(gl_GlobalInvocationID.z); + + // idx along the output head_dim dim + const int d = tile_idx_x * TILE_N; + const int d4 = div_4(d); + + // idx along the output seq_len dim. Note that for this shader seq_len will be + // 1. + const int s = 0; + + // texel size of head_dim + const int D4 = div_up_4(q_projected_sizes.x); + // number of Q heads + const int Q_H = q_projected_sizes.y; + // sequence length + const int S = q_projected_sizes.z; + + // number of K/V heads + const int KV_H = v_cache_sizes.y; + // Max context length + const int C = v_cache_sizes.z; + const int C4 = div_up_4(C); + + int kv_h = q_h; + if (KV_H < Q_H) { + kv_h = q_h / (Q_H / KV_H); + } + + // current context length + const int context_len = input_pos + S; + const int context_texel_len = div_up_4(context_len); + + // bounds check + if (d4 >= D4 || s >= S || q_h >= Q_H) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile attn_weight_tile; + FPWeightTile w_tile; + + const int context_len_aligned_down = context_len - mod_4(context_len); + const int C4_limit = div_up_4(context_len_aligned_down); + + for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_OUT) { + const int c = mul_4(c4); + + load_attn_weight_tile_no_checks( + attn_weight_tile, + c4, + s, + q_h, + context_texel_len, + S, + Q_H); + + load_v_cache_tile_no_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, attn_weight_tile, w_tile); + } + // first worker in the work group will handle final texel, which may contain + // padding elements. + if (worker_id == 0) { + for (int c4 = C4_limit; c4 < context_texel_len; c4++) { + const int c = mul_4(c4); + load_attn_weight_tile_with_checks( + attn_weight_tile, + c4, + s, + q_h, + context_texel_len, + S, + Q_H); + + load_v_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, attn_weight_tile, w_tile); + } + } + + partial_sums[worker_id] = out_tile; + + memoryBarrierShared(); + barrier(); + + // Tree reduction to compute the overall result. + for (int i = NUM_WORKERS_PER_OUT / 2; i > 0; i /= 2) { + if (worker_id < i) { + accumulate_out_tile_with_out_tile( + partial_sums[worker_id], partial_sums[worker_id + i]); + } + memoryBarrierShared(); + barrier(); + } + + // Only the first thread will write out the result + if (worker_id == 0) { + out_tile = partial_sums[0]; + store_sdpa_out_tile_with_checks( + out_tile, + d4, + s, + q_h, + D4, + S, + Q_H); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml similarity index 56% rename from backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.yaml rename to backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml index 909b8bfd3a9..ccebf8f7c1c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml @@ -4,12 +4,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -flash_attention_texture3d: +sdpa_compute_out_coop: parameter_names_with_default_values: DTYPE: float - STORAGE: texture3d + IO_STORAGE: texture3d + V_CACHE_STORAGE: texture3d + TILE_K4: 1 + TILE_N4: 1 generate_variant_forall: DTYPE: - VALUE: float + - VALUE: half shader_variants: - - NAME: flash_attention_texture3d + - NAME: sdpa_compute_out_coop_texture3d_texture3d + - NAME: sdpa_compute_out_coop_buffer_texture3d + IO_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl new file mode 100644 index 00000000000..fb4eaded826 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl @@ -0,0 +1,165 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if V_CACHE_STORAGE == "buffer": + #define V_CACHE_BUFFER + +#define TILE_M4 ${TILE_M4} +// Equvalent to K4 in matrix multiplication +#define TILE_K4 ${TILE_K4} +// Equvalent to N4 in matrix multiplication +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_v_cache", DTYPE, V_CACHE_STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} +${layout_declare_ubo(B, "ivec4", "v_cache_sizes")} +${layout_declare_ubo(B, "int", "input_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "sdpa_fp_attn_weight_tile_load.glslh" +#include "sdpa_fp_v_cache_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "sdpa_fp_out_tile_store.glslh" + +/* + * Compute SDPA output given the attention weights and v_cache tensors. + * attention weights has shape (batches, num_q_heads, seq_len, context_len) + * v_cache has shape (batches, max_context_len, num_kv_heads, head_dim) + * output has shape (batches, seq_len, num_q_heads, head_dim) + */ + +#extension GL_EXT_debug_printf : enable + +void main() { + const int tile_idx_x = int(gl_GlobalInvocationID.x); + const int tile_idx_y = int(gl_GlobalInvocationID.y); + // idx along output num_q_heads dim + const int q_h = int(gl_GlobalInvocationID.z); + + // idx along the output head_dim dim + const int d = tile_idx_x * TILE_N; + const int d4 = div_4(d); + + // idx along the output seq_len dim + const int s = tile_idx_y * TILE_M; + + // texel size of head_dim + const int D4 = div_up_4(q_projected_sizes.x); + // number of Q heads + const int Q_H = q_projected_sizes.y; + // sequence length + const int S = q_projected_sizes.z; + + // number of K/V heads + const int KV_H = v_cache_sizes.y; + // Max context length + const int C = v_cache_sizes.z; + const int C4 = div_up_4(C); + + int kv_h = q_h; + if (KV_H < Q_H) { + kv_h = q_h / (Q_H / KV_H); + } + + // current context length + const int context_len = input_pos + S; + const int context_texel_len = div_up_4(context_len); + + // bounds check + if (d4 >= D4 || s >= S || q_h >= Q_H) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile attn_weight_tile; + FPWeightTile w_tile; + + const int context_len_aligned_down = context_len - mod_4(context_len); + const int C4_limit = div_4(context_len_aligned_down); + + for (int c4 = 0; c4 < C4_limit; c4++) { + const int c = mul_4(c4); + load_attn_weight_tile_no_checks( + attn_weight_tile, + c4, + s, + q_h, + context_texel_len, + S, + Q_H); + + load_v_cache_tile_no_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, attn_weight_tile, w_tile); + } + for (int c4 = C4_limit; c4 < context_texel_len; c4++) { + const int c = mul_4(c4); + load_attn_weight_tile_with_checks( + attn_weight_tile, + c4, + s, + q_h, + context_texel_len, + S, + Q_H); + + load_v_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, attn_weight_tile, w_tile); + } + + store_sdpa_out_tile_with_checks( + out_tile, + d4, + s, + q_h, + D4, + S, + Q_H); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml new file mode 100644 index 00000000000..7fbce29e908 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +sdpa_compute_out_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + V_CACHE_STORAGE: texture3d + TILE_M4: 1 + TILE_K4: 1 + TILE_N4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: sdpa_compute_out_tiled_texture3d_texture3d + - NAME: sdpa_compute_out_tiled_buffer_texture3d + IO_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh new file mode 100644 index 00000000000..12b2292fa45 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_attn_weights + * + * Macro Settings: + * - INPUT_BUFFER + */ + +#ifndef SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH +#define SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_input_tile.glslh" + +VEC4_T load_attn_weight_c4( + const int c4, + const int s, + const int q_h, + const int C4, + const int S, + const int Q_H) { +#ifdef INPUT_BUFFER + return t_attn_weights[(q_h * S * C4) + (s * C4) + c4]; +#else + return texelFetch(t_attn_weights, ivec3(c4, s, q_h), 0); +#endif +} + +void load_attn_weight_tile_no_checks( + out FPInputTile tile, + const int c4_start, + const int s_start, + const int q_h, + const int C4, + const int S, + const int Q_H) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + tile.data[s][c4] = + load_attn_weight_c4(c4_start + c4, s_start + s, q_h, C4, S, Q_H); + } + } +} + +void load_attn_weight_tile_with_checks( + out FPInputTile tile, + const int c4_start, + const int s_start, + const int q_h, + const int C4, + const int S, + const int Q_H) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + if (c4_start + c4 < C4 && s_start + s < S) { + tile.data[s][c4] = + load_attn_weight_c4(c4_start + c4, s_start + s, q_h, C4, S, Q_H); + } else { + tile.data[s][c4] = VEC4_T(0.0); + } + } + } +} + +#endif // SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh new file mode 100644 index 00000000000..c64d9af8cfb --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh @@ -0,0 +1,105 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_attn_weights + * + * Macro Settings: + * - OUTPUT_BUFFER + */ + +#ifndef SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH +#define SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_output_tile.glslh" + +T negative_infinity_val = T(-1.0 / 0.0); + +void store_attn_weight_c4( + const VEC4_T out_texel, + const int c4, + const int s, + const int q_h, + const int C4, + const int S, + const int Q_H) { +#ifdef OUTPUT_BUFFER + t_attn_weights[(q_h * S * C4) + (s * C4) + c4] = out_texel; +#else + imageStore(t_attn_weights, ivec3(c4, s, q_h), out_texel); +#endif +} + +void store_attn_weight_tile_no_checks( + const FPOutTile tile, + const int c4_start, + const int s_start, + const int q_h, + const int C4, + const int S, + const int Q_H) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + store_attn_weight_c4( + tile.data[s][c4], c4_start + c4, s_start + s, q_h, C4, S, Q_H); + } + } +} + +void store_attn_weight_tile_with_checks( + const FPOutTile tile, + const int c4_start, + const int s_start, + const int q_h, + const int C4, + const int S, + const int Q_H) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + if (c4_start + c4 < C4 && s_start + s < S) { + store_attn_weight_c4( + tile.data[s][c4], c4_start + c4, s_start + s, q_h, C4, S, Q_H); + } + } + } +} + +void set_out_tile_to_vec(out FPOutTile tile, const VEC4_T vec) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { tile.data[s][c4] = vec; } + } +} + +void apply_scale_and_mask( + inout FPOutTile tile, + const VEC4_T inv_scale_vec, + const int input_pos, + const int c_idx_start, + const int s_idx_start) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + tile.data[s][c4] = tile.data[s][c4] * inv_scale_vec; + + const int c_base = mul_4(c4); + [[unroll]] for (int c4i = 0; c4i < 4; ++c4i) { + const int c = c_base + c4i; + // Indices of the tile element in the overall output tensor + const int c_idx = c_idx_start + c; + const int s_idx = s_idx_start + s; + if (c_idx > s_idx + input_pos) { + tile.data[s][c4][c4i] = negative_infinity_val; + } + } + } + } +} + +#endif // SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh new file mode 100644 index 00000000000..03132db1348 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_k_cache + * + * Macro Settings: + * - K_CACHE_BUFFER + */ + +#ifndef SDPA_FP_K_CACHE_TILE_LOAD_GLSLH +#define SDPA_FP_K_CACHE_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_weight_tile.glslh" + +VEC4_T load_k_cache_d4( + const int d4, + const int c, + const int kv_h, + const int D4, + const int C, + const int KV_H) { +#ifdef K_CACHE_BUFFER + return VEC4_T(t_k_cache[(c * KV_H * D4) + (kv_h * D4) + d4]); +#else + return VEC4_T(texelFetch(t_k_cache, ivec3(d4, kv_h, c), 0)); +#endif +} + +void load_k_cache_tile_no_checks( + out FPWeightTile tile, + const int d4_start, + const int c_start, + const int kv_h, + const int D4, + const int context_len, + const int C, + const int KV_H) { + bool should_print = d4_start == 0 && c_start == 0 && kv_h == 0; + [[unroll]] for (int c = 0; c < TILE_N; ++c) { + const int c4 = div_4(c); + const int c4i = mod_4(c); + [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { + VEC4_T d4_row = + load_k_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H); + + // Transpose in-place + const int d_base = mul_4(d4); + tile.data[d_base][c4][c4i] = d4_row[0]; + tile.data[d_base + 1][c4][c4i] = d4_row[1]; + tile.data[d_base + 2][c4][c4i] = d4_row[2]; + tile.data[d_base + 3][c4][c4i] = d4_row[3]; + } + } +} + +void load_k_cache_tile_with_checks( + out FPWeightTile tile, + const int d4_start, + const int c_start, + const int kv_h, + const int D4, + const int context_len, + const int C, + const int KV_H) { + [[unroll]] for (int c = 0; c < TILE_N; ++c) { + const int c4 = div_4(c); + const int c4i = mod_4(c); + [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { + VEC4_T d4_row = VEC4_T(0.0); + if (d4_start + d4 < D4 && c_start + c < context_len) { + d4_row = load_k_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H); + } + + // Transpose in-place + const int d_base = mul_4(d4); + tile.data[d_base][c4][c4i] = d4_row[0]; + tile.data[d_base + 1][c4][c4i] = d4_row[1]; + tile.data[d_base + 2][c4][c4i] = d4_row[2]; + tile.data[d_base + 3][c4][c4i] = d4_row[3]; + } + } +} + +#endif // SDPA_FP_K_CACHE_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh new file mode 100644 index 00000000000..17e0988a6a4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh @@ -0,0 +1,57 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_attn_weights + * + * Macro Settings: + * - OUTPUT_BUFFER + */ + +#ifndef SDPA_FP_OUT_TILE_LOAD_GLSLH +#define SDPA_FP_OUT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_output_tile.glslh" + +void store_out_d4( + const VEC4_T out_texel, + const int d4, + const int q_h, + const int s, + const int D4, + const int Q_H, + const int S) { +#ifdef OUTPUT_BUFFER + t_output[(s * Q_H * D4) + (q_h * D4) + d4] = out_texel; +#else + imageStore(t_output, ivec3(d4, q_h, s), out_texel); +#endif +} + +void store_sdpa_out_tile_with_checks( + const FPOutTile tile, + const int d4_start, + const int s_start, + const int q_h, + const int D4, + const int S, + const int Q_H) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int d4 = 0; d4 < TILE_N4; ++d4) { + if (d4_start + d4 < D4 && s_start + s < S) { + store_out_d4( + tile.data[s][d4], d4_start + d4, q_h, s_start + s, D4, Q_H, S); + } + } + } +} + +#endif // SDPA_FP_OUT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh new file mode 100644 index 00000000000..a304e5019e9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_input + * + * Macro Settings: + * - INPUT_BUFFER + */ + +#ifndef SDPA_FP_Q_PROJECTED_TILE_LOAD_GLSLH +#define SDPA_FP_Q_PROJECTED_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_input_tile.glslh" + +VEC4_T load_q_projected_d4( + const int d4, + const int q_h, + const int s, + const int D4, + const int Q_H, + const int S) { +#ifdef INPUT_BUFFER + return t_q_projected[(s * Q_H * D4) + (q_h * D4) + d4]; +#else + return texelFetch(t_q_projected, ivec3(d4, q_h, s), 0); +#endif +} + +void load_q_projected_tile_no_checks( + out FPInputTile tile, + const int d4_start, + const int s_start, + const int q_h, + const int D4, + const int Q_H, + const int S) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { + tile.data[s][d4] = + load_q_projected_d4(d4_start + d4, q_h, s_start + s, D4, Q_H, S); + } + } +} + +void load_q_projected_tile_with_checks( + out FPInputTile tile, + const int d4_start, + const int s_start, + const int q_h, + const int D4, + const int Q_H, + const int S) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { + if (d4_start + d4 < D4 && s_start + s < S) { + tile.data[s][d4] = + load_q_projected_d4(d4_start + d4, q_h, s_start + s, D4, Q_H, S); + } else { + tile.data[s][d4] = VEC4_T(0.0); + } + } + } +} + +#endif // SDPA_FP_Q_PROJECTED_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh new file mode 100644 index 00000000000..bf94b251c43 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh @@ -0,0 +1,76 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_v_cache + * + * Macro Settings: + * - V_CACHE_BUFFER + */ + +#ifndef SDPA_FP_V_CACHE_TILE_LOAD_GLSLH +#define SDPA_FP_V_CACHE_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_weight_tile.glslh" + +VEC4_T load_v_cache_d4( + const int d4, + const int c, + const int kv_h, + const int D4, + const int C, + const int KV_H) { +#ifdef V_CACHE_BUFFER + return VEC4_T(t_v_cache[(c * KV_H * D4) + (kv_h * D4) + d4]); +#else + return VEC4_T(texelFetch(t_v_cache, ivec3(d4, kv_h, c), 0)); +#endif +} + +void load_v_cache_tile_no_checks( + out FPWeightTile tile, + const int d4_start, + const int c_start, + const int kv_h, + const int D4, + const int context_len, + const int C, + const int KV_H) { + [[unroll]] for (int c = 0; c < TILE_N; ++c) { + [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { + tile.data[c][d4] = + load_v_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H); + } + } +} + +void load_v_cache_tile_with_checks( + out FPWeightTile tile, + const int d4_start, + const int c_start, + const int kv_h, + const int D4, + const int context_len, + const int C, + const int KV_H) { + [[unroll]] for (int c = 0; c < TILE_N; ++c) { + [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { + if (d4_start + d4 < D4 && c_start + c < context_len) { + tile.data[c][d4] = + load_v_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H); + } else { + tile.data[c][d4] = VEC4_T(0.0); + } + } + } +} + +#endif // SDPA_FP_V_CACHE_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl new file mode 100644 index 00000000000..932696fff02 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl @@ -0,0 +1,90 @@ +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${buffer_scalar_type(DTYPE)} + +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_cache", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_projected", DTYPE, INPUT_STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "cache_sizes")} +${layout_declare_ubo(B, "ivec4", "projected_sizes")} +${layout_declare_ubo(B, "int", "input_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * t_cache will have sizes of (batches, n_heads, max_context_len, head_dim). + * t_projected will have sizes of (batches, seq_len, n_heads, head_dim). + * + * Note that the cache tensor swaps the order of the n_heads and seq len + * dimensions. This is to faciliate more optimal memory access patterns when + * using the caches to compute matrix multiplications. + * + * The cache update inserts the values of t_projected into t_cache at the index + * specified by input_pos at the seq_len dimension. It is equivalent to calling + + * t_cache = t_cache.slice_scatter( + * t_projected, dim=1, start=input_pos, end=input_pos+seq_len) + * + * Note that this shader is implemented assuming that max_batch_size is 1. + */ + +IN_VEC4_T read_projected_d4( + const int d4, + const int h, + const int s, + const int D4, + const int H, + const int S) { +#ifdef INPUT_BUFFER + return t_projected[(s * H * D4) + (h * D4) + d4]; +#else + return texelFetch(t_projected, ivec3(d4, h, s), 0); +#endif +} + +void write_cache_d4( + const IN_VEC4_T texel, + const int d4, + const int c, + const int h, + const int D4, + const int C, + const int H) { +#ifdef OUTPUT_BUFFER + t_cache[(c * H * D4) + (h * D4) + d4] = texel; +#else + imageStore(t_cache, ivec3(d4, h, c), texel); +#endif +} + +void main() { + const int d4 = int(gl_GlobalInvocationID.x); // idx along the head_dim dim + const int s = int(gl_GlobalInvocationID.y); // idx along the seq_len dim + const int h = int(gl_GlobalInvocationID.z); // idx along the n_heads dim + + const int D4 = div_up_4(projected_sizes.x); + const int S = projected_sizes.z; + const int H = projected_sizes.y; + + if (d4 >= D4 || s >= S || h >= H) { + return; + } + + const int c = s + input_pos; // idx along max_context_len dim + const int C = cache_sizes.y; + + IN_VEC4_T in_texel = read_projected_d4(d4, h, s, D4, H, S); + write_cache_d4(in_texel, d4, c, h, D4, C, H); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml similarity index 61% rename from backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.yaml rename to backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml index 795ab906caa..85f4ce090f8 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml @@ -4,12 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -flash_attention_buffer: +sdpa_kv_cache_update: parameter_names_with_default_values: DTYPE: float - STORAGE: buffer + INPUT_STORAGE: texture3d + OUTPUT_STORAGE: texture3d generate_variant_forall: DTYPE: + - VALUE: half - VALUE: float shader_variants: - - NAME: flash_attention_buffer + - NAME: sdpa_kv_cache_update_texture3d + - NAME: sdpa_kv_cache_update_buffer + INPUT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_q_projected_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_q_projected_input_tile.glslh new file mode 100644 index 00000000000..da5dcd63b31 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_q_projected_input_tile.glslh @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef SDPA_FP_Q_PROJECTED_TILE_GLSLH +#define SDPA_FP_Q_PROJECTED_TILE_GLSLH + +/* + * Macro Settings: + * - TILE_S + * - TILE_D4 + */ + +#extension GL_EXT_control_flow_attributes : require + +struct FPQProjectedTile { + VEC4_T data[TILE_S][TILE_D4]; +}; + +#ifdef DEBUG_MODE + +void printFPQProjectedTile(const FPQProjectedTile in_tile) { + debugPrintfEXT("input_tile: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + in_tile.data[m][k4].x, + in_tile.data[m][k4].y, + in_tile.data[m][k4].z, + in_tile.data[m][k4].w); + } + } +} + +#endif // DEBUG_MODE + +#endif // SDPA_FP_Q_PROJECTED_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 2cc7455cd4a..8edaebd11ff 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -24,6 +24,58 @@ namespace vkcompute { +bool is_single_token(ComputeGraph* graph, const ValueRef& q_projected) { + return graph->size_at(-3, q_projected) == 1; +} + +// +// Resize functions +// + +void resize_compute_attn_weights_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef attn_weights = args.at(0).refs.at(0); + const ValueRef q_projected = args.at(1).refs.at(0); + const ValueRef input_pos_symint = resize_args.at(0); + + const uint32_t num_q_heads = graph->size_at(-2, q_projected); + const uint32_t seq_len = graph->size_at(-3, q_projected); + + const int32_t input_pos_val = graph->read_symint(input_pos_symint); + + const uint32_t context_len = seq_len + input_pos_val; + + std::vector out_sizes = { + 1, // batch + num_q_heads, + seq_len, + utils::align_up_4(context_len)}; + + graph->virtual_resize(attn_weights, out_sizes); +} + +void resize_sdpa_attn_weights_softmax_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef attn_weights_softmax = args.at(0).refs.at(0); + const ValueRef attn_weights = args.at(1).refs.at(0); + + graph->virtual_resize(attn_weights_softmax, graph->sizes_of(attn_weights)); +} + +void resize_sdpa_compute_out_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef q_projected = resize_args.at(0); + + graph->virtual_resize(out, graph->sizes_of(q_projected)); +} + void resize_sdpa_out( ComputeGraph* graph, const std::vector& args, @@ -36,195 +88,207 @@ void resize_sdpa_out( graph->virtual_resize(out, graph->sizes_of(q_projected)); } -void resize_flash_attention_out( +// +// Shader dispatch pick functions +// + +utils::uvec3 kv_cache_update_global_wg_size( ComputeGraph* graph, + const vkapi::ShaderInfo& shader, const std::vector& args, const std::vector& resize_args) { + (void)shader; (void)resize_args; - // Find the output tensor in the args - it's the first tensor in the first - // ArgGroup - const ValueRef out = args.at(0).refs.at(0); - const ValueRef q_projected = args.at(1).refs.at(0); - graph->virtual_resize(out, graph->sizes_of(q_projected)); + const ValueRef projected = args.at(1).refs.at(0); + + const uint32_t head_dim_size = graph->size_at(-1, projected); + const uint32_t num_heads = graph->size_at(-2, projected); + const uint32_t seq_len = graph->size_at(-3, projected); + + return {utils::div_up_4(head_dim_size), seq_len, num_heads}; } -utils::uvec3 flash_attention_global_wg_size( +utils::uvec3 attn_weight_scale_and_mask_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const std::vector& args, const std::vector& resize_args) { (void)shader; + (void)resize_args; - const ValueRef q_projected = resize_args.at(0); - const ValueRef block_size_r = resize_args.at(1); + const ValueRef attn_weight = args.at(0).refs.at(0); - // Get tensor dimensions - PyTorch format is [B, N, H, D] - // But Vulkan uses negative indexing: -4=B, -3=N, -2=H, -1=D - const int32_t B = graph->size_at(-4, q_projected); // batch - const int32_t N = graph->size_at(-3, q_projected); // sequence length - const int32_t H = graph->size_at(-2, q_projected); // num heads - const int32_t Br = - static_cast(graph->extract_scalar(block_size_r)); + if (graph->is_buffer_storage(attn_weight)) { + return { + graph->size_at(-1, attn_weight), + graph->size_at(-2, attn_weight), + graph->size_at(-3, attn_weight), + }; + } else { + return graph->logical_limits_of(attn_weight); + } +} + +vkapi::ShaderInfo pick_sdpa_compute_attn_weights_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef q_projected = args.at(1).refs.at(0); + const ValueRef k_cache = args.at(1).refs.at(1); - // Calculate number of row blocks - const int32_t Tr = (N + Br - 1) / Br; + const bool is_gemv = is_single_token(graph, q_projected); - return {static_cast(B * H * Tr), 1, 1}; + std::string shader_name = "sdpa_compute_attn_weights"; + if (is_gemv) { + shader_name += "_coop"; + } else { + shader_name += "_tiled"; + } + + add_storage_type_suffix(shader_name, graph->storage_type_of(q_projected)); + add_storage_type_suffix(shader_name, graph->storage_type_of(k_cache)); + add_dtype_suffix(shader_name, graph->dtype_of(q_projected)); + + return VK_KERNEL_FROM_STR(shader_name); } -void flash_attention_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef q_projected = args[arg_idx++]; - const ValueRef k_cache = args[arg_idx++]; - const ValueRef v_cache = args[arg_idx++]; - const ValueRef input_pos_symint = args[arg_idx++]; - const ValueRef attn_mask = args[arg_idx++]; - const ValueRef dropout_p = args[arg_idx++]; - const ValueRef is_causal = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; +utils::uvec3 pick_sdpa_compute_attn_weights_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef q_projected = args.at(1).refs.at(0); + const ValueRef input_pos_symint = resize_args.at(0); - const ValueRef out = args[arg_idx++]; + const uint32_t num_q_heads = graph->size_at(-2, q_projected); + const uint32_t seq_len = graph->size_at(-3, q_projected); - // Extract input_pos value for causal masking - const int32_t input_pos_val = graph.read_symint(input_pos_symint); + const int32_t input_pos_val = graph->read_symint(input_pos_symint); - const ValueRef k_cache_tensor = k_cache; - const ValueRef v_cache_tensor = v_cache; + const uint32_t context_len = seq_len + input_pos_val; - // Validation checks - re-enable with correct indexing - VK_CHECK_COND(graph.size_at(-4, q_projected) == 1); // batch size = 1 - VK_CHECK_COND(graph.size_at(-4, k_cache_tensor) == 1); - VK_CHECK_COND(graph.size_at(-4, v_cache_tensor) == 1); - VK_CHECK_COND( - graph.sizes_of(k_cache_tensor) == graph.sizes_of(v_cache_tensor)); - VK_CHECK_COND( - graph.size_at(-1, q_projected) == - graph.size_at(-1, k_cache_tensor)); // head_dim must match - VK_CHECK_COND( - graph.val_is_none(dropout_p) || - graph.extract_scalar(dropout_p) == 0); - VK_CHECK_COND(graph.val_is_none(scale)); - VK_CHECK_COND( - graph.val_is_none(is_causal) || graph.extract_scalar(is_causal)); - VK_CHECK_COND(graph.val_is_none(attn_mask)); + const uint32_t N4 = utils::div_up_4(context_len); + const uint32_t M4 = utils::div_up_4(seq_len); + + return {N4, M4, num_q_heads}; +} + +utils::uvec3 pick_sdpa_compute_attn_weights_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; - if (graph.is_buffer_storage(q_projected)) { - VK_CHECK_COND(graph.is_buffer_storage(k_cache_tensor)); - VK_CHECK_COND(graph.is_buffer_storage(v_cache_tensor)); - VK_CHECK_COND(graph.is_buffer_storage(out)); + if (use_coop_algorithm) { + return {1, 64, 1}; + } else { + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); } +} - // Calculate scale factor - const int32_t head_dim_size = graph.size_at(-1, q_projected); - const float scale_val = 1.0f / std::sqrt(static_cast(head_dim_size)); +utils::uvec3 pick_sdpa_attn_weights_softmax_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef q_projected = resize_args.at(0); - // Get number of heads for multi-query attention support - const int32_t num_heads = graph.size_at(-2, q_projected); - const int32_t num_kv_heads = graph.size_at(-2, k_cache_tensor); + const uint32_t num_q_heads = graph->size_at(-2, q_projected); + const uint32_t seq_len = graph->size_at(-3, q_projected); - const int32_t block_size_r = 32; // Row block size - const int32_t block_size_c = 32; // Column block size + return {1, seq_len, num_q_heads}; +} - // l and m have shape [B, H, N] - std::vector lm_sizes = { - graph.size_at(-4, q_projected), // B (batch) - graph.size_at(-2, q_projected), // H (num heads) - graph.size_at(-3, q_projected) // N (sequence length) - }; +utils::uvec3 pick_sdpa_attn_weights_softmax_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + return {64, 1, 1}; +} - // t_l stores row-wise normalization sums for softmax computation - // t_m stores row-wise maximum values for numerical stability in softmax - TmpTensor t_l(&graph, lm_sizes, vkapi::kFloat, graph.storage_type_of(out)); - TmpTensor t_m(&graph, lm_sizes, vkapi::kFloat, graph.storage_type_of(out)); +vkapi::ShaderInfo pick_sdpa_compute_out_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef v_cache = args.at(1).refs.at(1); - // Choose kernel name based on storage type - std::string kernel_name = "flash_attention"; - add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); + const ValueRef q_projected = resize_args.at(0); - vkapi::ParamsBindList param_ubos = { - graph.sizes_ubo(q_projected), // Q_sizes - graph.sizes_ubo(k_cache_tensor), // K_sizes - graph.sizes_ubo(v_cache_tensor), // V_sizes - graph.sizes_ubo(out), // O_sizes - graph.sizes_ubo(t_l), // l_sizes - graph.sizes_ubo(t_m), // m_sizes - graph.create_params_buffer(scale_val), // scale - graph.create_params_buffer(block_size_r), // block_size_r - graph.create_params_buffer(block_size_c), // block_size_c - graph.create_params_buffer(input_pos_val), // input_pos - graph.create_params_buffer(num_heads), // num_heads - graph.create_params_buffer(num_kv_heads) // num_kv_heads - }; - - // Create block size references for dispatch calculation - const ValueRef block_size_r_ref = - graph.add_scalar(static_cast(block_size_r)); + const bool is_gemv = is_single_token(graph, q_projected); - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - flash_attention_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - { - {{out, t_l, t_m}, vkapi::kReadWrite}, - {{q_projected, k_cache_tensor, v_cache_tensor}, vkapi::kRead}, - }, - // Shader param buffers - param_ubos, - // Push Constants - {}, - // Specialization Constants - {}, - // Resize Args - {q_projected, block_size_r_ref}, - // Resizing Logic - resize_flash_attention_out)); + std::string shader_name = "sdpa_compute_out"; + if (is_gemv) { + shader_name += "_coop"; + } else { + shader_name += "_tiled"; + } + + add_storage_type_suffix(shader_name, graph->storage_type_of(out)); + add_storage_type_suffix(shader_name, graph->storage_type_of(v_cache)); + add_dtype_suffix(shader_name, graph->dtype_of(out)); + + return VK_KERNEL_FROM_STR(shader_name); } -utils::uvec3 kv_cache_update_global_wg_size( +utils::uvec3 pick_sdpa_compute_out_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const std::vector& args, const std::vector& resize_args) { - (void)shader; - (void)resize_args; + const ValueRef q_projected = resize_args.at(0); - const ValueRef cache = args.at(0).refs.at(0); - const ValueRef projected = args.at(1).refs.at(0); + const uint32_t head_dim = graph->size_at(-1, q_projected); + const uint32_t num_q_heads = graph->size_at(-2, q_projected); + const uint32_t seq_len = graph->size_at(-3, q_projected); - if (graph->is_buffer_storage(cache)) { - return graph->create_global_wg_size(projected); + const uint32_t N4 = utils::div_up_4(head_dim); + const uint32_t M4 = utils::div_up_4(seq_len); + + return {N4, M4, num_q_heads}; +} + +utils::uvec3 pick_sdpa_compute_out_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + + if (use_coop_algorithm) { + return {1, 64, 1}; } else { - return graph->logical_limits_of(projected); + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); } } -void add_kv_cache_update_node( +// +// Dispatch nodes +// + +void add_sdpa_kv_cache_update_node( ComputeGraph& graph, const ValueRef input_pos_symint, const ValueRef projected, const ValueRef cache) { - std::string kernel_name("kv_cache_update"); + std::string kernel_name("sdpa_kv_cache_update"); add_storage_type_suffix(kernel_name, graph.storage_type_of(projected)); add_dtype_suffix(kernel_name, graph.dtype_of(projected)); - vkapi::ParamsBindList param_ubos; - - if (graph.is_buffer_storage(cache)) { - param_ubos = { - graph.numel_ubo(projected), - graph.strides_ubo(cache), - graph.get_or_create_int_param_buffer(input_pos_symint)}; - } else { - param_ubos = { - graph.logical_limits_ubo(projected), - graph.get_or_create_int_param_buffer(input_pos_symint)}; - } + vkapi::ParamsBindList param_ubos = { + graph.sizes_ubo(cache), + graph.sizes_ubo(projected), + graph.get_or_create_int_param_buffer(input_pos_symint)}; graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, @@ -240,121 +304,113 @@ void add_kv_cache_update_node( // Specialization Constants {}, // Resize Args - {}, + {input_pos_symint}, // Resizing Logic nullptr)); } -utils::uvec3 attn_weight_scale_and_mask_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef attn_weight = args.at(0).refs.at(0); - - if (graph->is_buffer_storage(attn_weight)) { - return { - graph->size_at(-1, attn_weight), - graph->size_at(-2, attn_weight), - graph->size_at(-3, attn_weight), - }; - } else { - return graph->logical_limits_of(attn_weight); - } -} - -void add_attn_weight_scale_and_mask_node( +void add_sdpa_compute_attn_weights_node( ComputeGraph& graph, - const ValueRef input_pos_symint, const ValueRef q_projected, - const ValueRef attn_weight) { - std::string kernel_name("sdpa_attn_weight_scale_and_mask"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(attn_weight)); - add_dtype_suffix(kernel_name, graph.dtype_of(attn_weight)); - + const ValueRef k_cache, + const ValueRef input_pos_symint, + const ValueRef attn_weights) { const int32_t head_dim_size = graph.size_at(-1, q_projected); const float scale_val = 1.0f / std::sqrt(static_cast(head_dim_size)); - vkapi::ParamsBindList param_ubos; - - if (graph.is_buffer_storage(attn_weight)) { - param_ubos = { - graph.sizes_ubo(attn_weight), - graph.strides_ubo(attn_weight), - graph.create_params_buffer(scale_val)}; - } else { - param_ubos = { - graph.logical_limits_ubo(attn_weight), - graph.get_or_create_int_param_buffer(input_pos_symint), - graph.create_params_buffer(scale_val)}; - } + vkapi::ParamsBindList param_ubos = { + graph.sizes_ubo(q_projected), + graph.sizes_ubo(k_cache), + graph.get_or_create_int_param_buffer(input_pos_symint)}; graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - VK_KERNEL_FROM_STR(kernel_name), - attn_weight_scale_and_mask_global_wg_size, - default_pick_local_wg_size, + pick_sdpa_compute_attn_weights_shader, + pick_sdpa_compute_attn_weights_global_wg_size, + pick_sdpa_compute_attn_weights_local_wg_size, // Inputs and Outputs - {{attn_weight, vkapi::kReadWrite}}, + {{attn_weights, vkapi::kWrite}, {{q_projected, k_cache}, vkapi::kRead}}, // Shader param buffers param_ubos, // Push Constants {}, // Specialization Constants - {}, + {scale_val}, // Resize Args - {}, + {input_pos_symint}, // Resizing Logic - nullptr)); + resize_compute_attn_weights_node)); } -std::vector get_cache_slice_sizes( +void add_sdpa_attn_weights_softmax_node( ComputeGraph& graph, - ValueRef cache, - ValueRef input_pos_symint, - ValueRef q_projected) { - std::vector slice_sizes = graph.sizes_of(cache); - - // Cache slicing will always be in the channels dim - const int32_t input_pos_val = graph.read_symint(input_pos_symint); - const int64_t q_seq_len = graph.size_at(1, q_projected); - slice_sizes.at(1) = input_pos_val + q_seq_len; - return slice_sizes; -} + const ValueRef attn_weights, + const ValueRef q_projected, + const ValueRef input_pos_symint, + const ValueRef attn_weights_softmax) { + std::string shader_name = "sdpa_attn_weights_softmax"; + add_storage_type_suffix( + shader_name, graph.storage_type_of(attn_weights_softmax)); + add_dtype_suffix(shader_name, graph.dtype_of(attn_weights_softmax)); -void resize_cache_slice_view_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)args; - std::vector slice_sizes = get_cache_slice_sizes( - *graph, extra_args[0], extra_args[1], extra_args[2]); + vkapi::ParamsBindList param_ubos = { + graph.sizes_ubo(q_projected), + graph.get_or_create_int_param_buffer(input_pos_symint)}; - graph->virtual_resize(extra_args[3], slice_sizes); + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(shader_name), + pick_sdpa_attn_weights_softmax_global_wg_size, + pick_sdpa_attn_weights_softmax_local_wg_size, + // Inputs and Outputs + {{attn_weights_softmax, vkapi::kWrite}, {attn_weights, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {q_projected, input_pos_symint}, + // Resizing Logic + resize_sdpa_attn_weights_softmax_node)); } -void add_cache_slice_view_node( +void add_sdpa_compute_out_node( ComputeGraph& graph, - ValueRef cache, - ValueRef input_pos_symint, - ValueRef q_projected, - ValueRef cache_sliced, - const int64_t max_seq_len) { - std::vector slice_sizes = - get_cache_slice_sizes(graph, cache, input_pos_symint, q_projected); - // Initialize the slice to the maximum possible size to start - slice_sizes.at(1) = max_seq_len; - - graph.virtual_resize(cache_sliced, slice_sizes); - - graph.execute_nodes().emplace_back(new ExecuteNode( - resize_cache_slice_view_node, - {cache, input_pos_symint, q_projected, cache_sliced})); + const ValueRef attn_weights_softmax, + const ValueRef v_cache, + const ValueRef q_projected, + const ValueRef input_pos_symint, + const ValueRef out) { + vkapi::ParamsBindList param_ubos = { + graph.sizes_ubo(q_projected), + graph.sizes_ubo(v_cache), + graph.get_or_create_int_param_buffer(input_pos_symint)}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + pick_sdpa_compute_out_shader, + pick_sdpa_compute_out_global_wg_size, + pick_sdpa_compute_out_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{attn_weights_softmax, v_cache}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {q_projected, input_pos_symint}, + // Resizing Logic + resize_sdpa_compute_out_node)); } +// +// High level operator impl +// + void update_cache_impl(ComputeGraph& graph, const std::vector& args) { int arg_idx = 0; const ValueRef value = args[arg_idx++]; @@ -372,7 +428,7 @@ void update_cache_impl(ComputeGraph& graph, const std::vector& args) { VK_CHECK_COND( graph.size_at(-2, value) == graph.size_at(-2, cache)); - add_kv_cache_update_node(graph, input_pos_symint, value, cache); + add_sdpa_kv_cache_update_node(graph, input_pos_symint, value, cache); } void sdpa_impl(ComputeGraph& graph, const std::vector& args) { @@ -413,105 +469,39 @@ void sdpa_impl(ComputeGraph& graph, const std::vector& args) { graph.val_is_none(is_causal) || graph.extract_scalar(is_causal)); VK_CHECK_COND(graph.val_is_none(attn_mask)); - const int32_t max_seq_len = graph.size_at(1, k_cache); + const int64_t num_q_heads = graph.size_at(-2, q_projected); + const int64_t max_seq_len = graph.size_at(-3, q_projected); - // Slice caches from 0 to input_pos + sequence_len - const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache); - const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache); - add_cache_slice_view_node( - graph, - k_cache, - input_pos_symint, - q_projected, - k_cache_sliced, - max_seq_len); - add_cache_slice_view_node( - graph, - v_cache, - input_pos_symint, - q_projected, - v_cache_sliced, - max_seq_len); - - // Scalar values for various dims - const ValueRef channels = graph.add_scalar(1); - const ValueRef height = graph.add_scalar(2); - const ValueRef width = graph.add_scalar(3); - - // Repeat interleave - const int64_t num_heads = graph.size_at(2, q_projected); - const int64_t num_kv_heads = graph.size_at(2, k_cache); - - const ValueRef num_repeats = - graph.add_scalar(num_heads / num_kv_heads); - - std::vector cache_slice_repeated_sizes(graph.sizes_of(q_projected)); - cache_slice_repeated_sizes.at(1) = max_seq_len; - - TmpTensor k_cache_sliced_repeated( - &graph, cache_slice_repeated_sizes, graph.dtype_of(k_cache_sliced)); - TmpTensor v_cache_sliced_repeated( - &graph, cache_slice_repeated_sizes, graph.dtype_of(v_cache_sliced)); - - add_repeat_interleave_node( - graph, k_cache_sliced, num_repeats, height, k_cache_sliced_repeated); - add_repeat_interleave_node( - graph, v_cache_sliced, num_repeats, height, v_cache_sliced_repeated); - - // Transpose sequence and head dims - const ValueRef q_transposed = graph.add_tensor_view(q_projected); - const ValueRef k_transposed = graph.add_tensor_view(k_cache_sliced_repeated); - const ValueRef v_transposed = graph.add_tensor_view(v_cache_sliced_repeated); - - add_transpose_view_node(graph, q_projected, channels, height, q_transposed); - add_transpose_view_node( - graph, k_cache_sliced_repeated, channels, height, k_transposed); - add_transpose_view_node( - graph, v_cache_sliced_repeated, channels, height, v_transposed); - - // Transpose K again to prepare for matmul - const ValueRef k_transposed_2 = graph.add_tensor_view(k_transposed); - add_transpose_view_node(graph, k_transposed, height, width, k_transposed_2); - - // Initialize attn_weight to the maximum possible size - std::vector attn_weight_full_sizes = graph.sizes_of(q_transposed); - attn_weight_full_sizes.at(2) = max_seq_len; - attn_weight_full_sizes.at(3) = max_seq_len; - TmpTensor attn_weight( - &graph, attn_weight_full_sizes, graph.dtype_of(q_transposed)); - - // Resize attn_weight to the correct dim - std::vector attn_weight_sizes = attn_weight_full_sizes; - attn_weight_sizes.at(2) = graph.size_at(2, q_transposed); - attn_weight_sizes.at(3) = graph.size_at(2, k_transposed); - graph.virtual_resize(attn_weight, attn_weight_sizes); - - // Calculate attention weight, which is a matmul of Q and K - const ValueRef mat2_is_transposed = graph.add_scalar(false); - add_matmul_node( - graph, q_transposed, k_transposed_2, attn_weight, mat2_is_transposed); - - // Apply scale and mask to the attention weight - add_attn_weight_scale_and_mask_node( - graph, input_pos_symint, q_projected, attn_weight); - - TmpTensor attn_weight_softmax( - &graph, attn_weight_full_sizes, graph.dtype_of(q_transposed)); - graph.virtual_resize(attn_weight_softmax, attn_weight_sizes); - add_softmax_node(graph, attn_weight, width, attn_weight_softmax, false); - - // Calculate final output - const ValueRef out_transposed = graph.add_tensor_view(out); - add_transpose_view_node(graph, out, channels, height, out_transposed); - add_matmul_node( - graph, - attn_weight_softmax, - v_transposed, - out_transposed, - mat2_is_transposed); + const int64_t max_context_len = graph.size_at(-3, k_cache); + + std::vector attn_weight_full_sizes = { + 1, // batch + num_q_heads, + max_seq_len, + max_context_len}; + + TmpTensor attn_weights( + &graph, + attn_weight_full_sizes, + graph.dtype_of(q_projected), + graph.storage_type_of(q_projected), + utils::kWidthPacked); + + TmpTensor attn_weights_softmax( + &graph, + attn_weight_full_sizes, + graph.dtype_of(q_projected), + graph.storage_type_of(q_projected), + utils::kWidthPacked); + + add_sdpa_compute_attn_weights_node( + graph, q_projected, k_cache, input_pos_symint, attn_weights); + + add_sdpa_attn_weights_softmax_node( + graph, attn_weights, q_projected, input_pos_symint, attn_weights_softmax); - graph.execute_nodes().emplace_back( - new ExecuteNode(resize_sdpa_out, {q_projected, out})); + add_sdpa_compute_out_node( + graph, attn_weights_softmax, v_cache, q_projected, input_pos_symint, out); } void sdpa_with_kv_cache_impl( @@ -535,10 +525,10 @@ void sdpa_with_kv_cache_impl( (void)sequence_len; - const ValueRef k_cache = - prepack_standard_like(graph, k_cache_data, q_projected); - const ValueRef v_cache = - prepack_standard_like(graph, v_cache_data, q_projected); + const ValueRef k_cache = prepack_standard( + graph, k_cache_data, utils::kTexture3D, utils::kWidthPacked); + const ValueRef v_cache = prepack_standard( + graph, v_cache_data, utils::kTexture3D, utils::kWidthPacked); update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); @@ -560,7 +550,6 @@ REGISTER_OPERATORS { VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl); VK_REGISTER_OP(update_cache.default, update_cache_impl); VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl); - VK_REGISTER_OP(llama.flash_attention.default, flash_attention_impl); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/CMakeLists.txt b/backends/vulkan/test/op_tests/CMakeLists.txt index 07a13c3f260..5e8991f8e50 100644 --- a/backends/vulkan/test/op_tests/CMakeLists.txt +++ b/backends/vulkan/test/op_tests/CMakeLists.txt @@ -47,7 +47,7 @@ find_library(LIB_C10 c10 HINTS ${TORCH_INSTALL_PREFIX}/lib) # Third party include paths -set(VULKAN_THIRD_PARTY_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../third-party) +set(VULKAN_THIRD_PARTY_PATH ${EXECUTORCH_ROOT}/backends/vulkan/third-party) set(GTEST_INCLUDE_PATH ${EXECUTORCH_ROOT}/third-party/googletest/googletest/include diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp index e4b3f662c04..a94e68a53af 100644 --- a/backends/vulkan/test/op_tests/sdpa_test.cpp +++ b/backends/vulkan/test/op_tests/sdpa_test.cpp @@ -215,16 +215,13 @@ at::Tensor sdpa_reference_impl( void test_reference_sdpa( const int start_input_pos, const int sequence_len, - const int embedding_dim, + const int head_dim, const int num_heads, const int num_kv_heads, const int batch_size, const int max_seq_len, at::ScalarType dtype = at::kFloat) { - const int head_dim = embedding_dim / num_heads; - // K and V caches. Need an extra set for the reference implementation - at::Tensor k_cache = at::zeros( {batch_size, max_seq_len, num_kv_heads, head_dim}, at::device(at::kCPU).dtype(dtype)); @@ -265,19 +262,23 @@ void test_reference_sdpa( void test_vulkan_sdpa( const int start_input_pos, - const int base_sequence_len, - const int embedding_dim, + const std::vector& sequence_lens, + const int head_dim, const int num_heads, const int num_kv_heads, const int batch_size, - const int max_seq_len, - const bool dynamic_seq_len = true, + vkcompute::utils::StorageType storage_type, at::ScalarType dtype = at::kFloat) { - const int head_dim = embedding_dim / num_heads; + // compute the max sequence length + int max_seq_len = start_input_pos; + for (int i = 0; i < sequence_lens.size(); ++i) { + max_seq_len += sequence_lens[i]; + } + // Add some extra space to the max sequence length + max_seq_len += 128; - const int init_seq_len = dynamic_seq_len ? max_seq_len : base_sequence_len; + const int init_seq_len = max_seq_len; // K and V caches - at::Tensor k_cache = at::zeros( {batch_size, max_seq_len, num_kv_heads, head_dim}, at::device(at::kCPU).dtype(dtype)); @@ -300,7 +301,6 @@ void test_vulkan_sdpa( using namespace vkcompute; GraphConfig config; - config.set_storage_type_override(utils::kTexture3D); ComputeGraph graph(config); // "Data" variant for vulkan initialization @@ -319,7 +319,7 @@ void test_vulkan_sdpa( #define MAKE_INPUT_FOR(x) \ IOValueRef r_##x = graph.add_input_tensor( \ - x.sizes().vec(), from_at_scalartype(x.scalar_type())); + x.sizes().vec(), from_at_scalartype(x.scalar_type()), storage_type); MAKE_INPUT_FOR(q); MAKE_INPUT_FOR(k); @@ -328,7 +328,7 @@ void test_vulkan_sdpa( const ValueRef r_input_pos_symint = graph.add_symint(start_input_pos); const ValueRef r_out = graph.add_tensor( - out.sizes().vec(), from_at_scalartype(out.scalar_type())); + out.sizes().vec(), from_at_scalartype(out.scalar_type()), storage_type); VK_GET_OP_FN("sdpa_with_kv_cache.default") (graph, @@ -365,10 +365,10 @@ void test_vulkan_sdpa( graph.copy_from_staging( \ staging_##x, vk_##x.mutable_data_ptr(), vk_##x.numel()); - int seq_len = base_sequence_len; - for (int i = 0, input_pos = start_input_pos; - input_pos + seq_len < max_seq_len; - input_pos += seq_len, i++) { + torch::manual_seed(0); + + int input_pos = start_input_pos; + for (auto seq_len : sequence_lens) { q = at::rand( {batch_size, seq_len, num_heads, head_dim}, at::device(at::kCPU).dtype(dtype)); @@ -398,6 +398,46 @@ void test_vulkan_sdpa( const bool output_correct = at::allclose(reference_out, vk_out); if (!output_correct) { + // Print only differing tensor elements side by side for easier comparison + auto ref_flat = reference_out.flatten(); + auto vk_flat = vk_out.flatten(); + auto numel = ref_flat.numel(); + std::cout << "reference_out\tvk_out\tindex" << std::endl; + int first_diff_idx = -1; + auto sizes = reference_out.sizes(); + int d0 = sizes[0], d1 = sizes[1], d2 = sizes[2], d3 = sizes[3]; + for (int i = 0; i < numel; ++i) { + if (std::abs(ref_flat[i].item() - vk_flat[i].item()) > + 1e-4) { + // Compute 4-D index from flat index + int i0 = i / (d1 * d2 * d3); + int rem0 = i % (d1 * d2 * d3); + int i1 = rem0 / (d2 * d3); + int rem1 = rem0 % (d2 * d3); + int i2 = rem1 / d3; + int i3 = rem1 % d3; + std::cout << ref_flat[i].item() << "\t" << vk_flat[i].item() << "\t[" + << i0 << ", " << i1 << ", " << i2 << ", " << i3 << "]" + << std::endl; + if (first_diff_idx == -1) { + first_diff_idx = i; + } + break; + } + } + if (first_diff_idx != -1) { + // Compute 4-D index from flat index + int i0 = first_diff_idx / (d1 * d2 * d3); + int rem0 = first_diff_idx % (d1 * d2 * d3); + int i1 = rem0 / (d2 * d3); + int rem1 = rem0 % (d2 * d3); + int i2 = rem1 / d3; + int i3 = rem1 % d3; + std::cout << "First difference at flat index " << first_diff_idx + << " which is tensor index [" << i0 << ", " << i1 << ", " + << i2 << ", " << i3 << "]" << std::endl; + } + at::Tensor diffs = at::abs(reference_out - vk_out); std::cout << "Failed at input_pos " << input_pos << " with seq_len " @@ -414,426 +454,65 @@ void test_vulkan_sdpa( } ASSERT_TRUE(output_correct); - if (dynamic_seq_len) { - seq_len = base_sequence_len + (i % 3); - } + input_pos += seq_len; } } -TEST(VulkanSDPATest, test_sdpa_op_small_params) { - const int starting_input_pos = 0; - const int base_sequence_len = 3; - const int embedding_dim = 18; - const int num_heads = 6; - const int num_kv_heads = 2; - const int batch_size = 1; - const int max_seq_len = 7; - - test_vulkan_sdpa( - starting_input_pos, - base_sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len, - false); -} - -TEST(VulkanSDPATest, test_sdpa_op_small_params_dynamic) { - const int starting_input_pos = 0; - const int base_sequence_len = 3; - const int embedding_dim = 18; - const int num_heads = 6; - const int num_kv_heads = 2; - const int batch_size = 1; - const int max_seq_len = 12; - - test_vulkan_sdpa( - starting_input_pos, - base_sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); -} - -TEST(VulkanSDPATest, test_sdpa_op_llama3_params_dynamic) { - const int starting_input_pos = 0; - const int base_sequence_len = 3; - const int embedding_dim = 2048; - const int num_heads = 32; - const int num_kv_heads = 8; - const int batch_size = 1; - const int max_seq_len = 128; - - test_vulkan_sdpa( - starting_input_pos, - base_sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); -} - -TEST(VulkanSDPATest, test_reference_impl) { - const int starting_input_pos = 0; - const int base_sequence_len = 3; - const int embedding_dim = 2048; - const int num_heads = 32; - const int num_kv_heads = 8; - const int batch_size = 1; - const int max_seq_len = 128; - - test_reference_sdpa( - starting_input_pos, - base_sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); -} - -void test_vulkan_flash_attention_impl( - const int start_input_pos, - const int sequence_len, - const int embedding_dim, - const int num_heads, - const int num_kv_heads, - const int batch_size, - const int max_seq_len, - vkcompute::utils::StorageType storage_type, - at::ScalarType dtype = at::kFloat) { - const int head_dim = embedding_dim / num_heads; - - at::Tensor k_cache = at::zeros( - {batch_size, max_seq_len, num_kv_heads, head_dim}, - at::device(at::kCPU).dtype(dtype)); - at::Tensor v_cache = at::zeros_like(k_cache); - - at::Tensor q = at::rand( - {batch_size, sequence_len, num_heads, head_dim}, - at::device(at::kCPU).dtype(dtype)); - at::Tensor k = at::rand( - {batch_size, sequence_len, num_kv_heads, head_dim}, - at::device(at::kCPU).dtype(dtype)); - at::Tensor v = at::rand_like(k); - - // Get reference output using existing SDPA - at::Tensor reference_out = sdpa_reference_impl( - q, - k, - v, - k_cache, - v_cache, - start_input_pos, - sequence_len, - {}, - 0.0, - true, - {}); - - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(storage_type); - ComputeGraph graph(config); - - // Create input references - IOValueRef r_q = graph.add_input_tensor( - q.sizes().vec(), from_at_scalartype(q.scalar_type())); - IOValueRef r_k = graph.add_input_tensor( - k.sizes().vec(), from_at_scalartype(k.scalar_type())); - IOValueRef r_v = graph.add_input_tensor( - v.sizes().vec(), from_at_scalartype(v.scalar_type())); - - // Create cache tensors (these would be updated by cache update operations in - // practice) - ValueRef r_k_cache = graph.add_tensorref( - k_cache.sizes().vec(), - from_at_scalartype(k_cache.scalar_type()), - k_cache.const_data_ptr()); - ValueRef r_v_cache = graph.add_tensorref( - v_cache.sizes().vec(), - from_at_scalartype(v_cache.scalar_type()), - v_cache.const_data_ptr()); - - const ValueRef r_input_pos_symint = graph.add_symint(start_input_pos); - const ValueRef r_out = - graph.add_tensor(q.sizes().vec(), from_at_scalartype(q.scalar_type())); - - // Call Flash Attention implementation - VK_GET_OP_FN("llama.flash_attention.default") - (graph, - { - r_q.value, - r_k.value, // Use actual K tensor, not cache - r_v.value, // Use actual V tensor, not cache - r_input_pos_symint, - kDummyValueRef, // attn_mask - kDummyValueRef, // dropout_p - kDummyValueRef, // is_causal - kDummyValueRef, // scale - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - graph.prepack(); - - // Copy inputs and run - graph.copy_into_staging(r_q.staging, q.const_data_ptr(), q.numel()); - graph.copy_into_staging(r_k.staging, k.const_data_ptr(), k.numel()); - graph.copy_into_staging(r_v.staging, v.const_data_ptr(), v.numel()); - - graph.execute(); - - // Extract output - at::Tensor vk_out = at::zeros_like(q).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare results - const bool output_correct = at::allclose(reference_out, vk_out, 1e-3, 1e-3); - - if (!output_correct) { - at::Tensor diffs = at::abs(reference_out - vk_out); - std::cout << "Maximum difference: " << at::max(diffs).item() << std::endl; - std::cout << "Maximum value observed: " - << at::max(at::abs(at::cat({reference_out, vk_out}, -1))).item() - << std::endl; - } - ASSERT_TRUE(output_correct); -} - -void test_vulkan_flash_attention( +void test_vulkan_sdpa( const int start_input_pos, - const int sequence_len, - const int embedding_dim, + const std::vector& sequence_lens, + const int head_dim, const int num_heads, const int num_kv_heads, const int batch_size, - const int max_seq_len, at::ScalarType dtype = at::kFloat) { - test_vulkan_flash_attention_impl( + // Test texture + test_vulkan_sdpa( start_input_pos, - sequence_len, - embedding_dim, + sequence_lens, + head_dim, num_heads, num_kv_heads, batch_size, - max_seq_len, - vkcompute::utils::kBuffer, + vkcompute::utils::kTexture3D, dtype); - test_vulkan_flash_attention_impl( + // Test buffer + test_vulkan_sdpa( start_input_pos, - sequence_len, - embedding_dim, + sequence_lens, + head_dim, num_heads, num_kv_heads, batch_size, - max_seq_len, - vkcompute::utils::kTexture3D, + vkcompute::utils::kBuffer, dtype); } -// Flash Attention Tests (both Buffer and Texture) -TEST(VulkanSDPATest, test_flash_attention_small_params) { - const int starting_input_pos = 0; - const int sequence_len = 2; - const int embedding_dim = 4; - const int num_heads = 2; - const int num_kv_heads = 2; - const int batch_size = 1; - const int max_seq_len = 4; - - test_vulkan_flash_attention( - starting_input_pos, - sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); -} - -TEST(VulkanSDPATest, test_flash_attention_multi_tile) { - const int starting_input_pos = 0; - const int sequence_len = 48; - const int embedding_dim = 32; - const int num_heads = 2; - const int num_kv_heads = 2; - const int batch_size = 1; - const int max_seq_len = 64; - - test_vulkan_flash_attention( - starting_input_pos, - sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); -} - -TEST(VulkanSDPATest, test_flash_attention_op_small_params) { - const int starting_input_pos = 0; - const int sequence_len = 3; - const int embedding_dim = 18; - const int num_heads = 6; - const int num_kv_heads = 2; - const int batch_size = 1; - const int max_seq_len = 7; +TEST(VulkanSDPATest, test_sdpa_op_small_params) { + const int base_sequence_len = 3; + const int num_heads = 8; + const int head_dim = 4; + const int num_kv_heads = 4; - test_vulkan_flash_attention( - starting_input_pos, - sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); + test_vulkan_sdpa( + 0, {3, 1, 1, 5, 1, 1, 2}, head_dim, num_heads, num_kv_heads, 1); } -TEST(VulkanSDPATest, test_flash_attention_op_small_params_dynamic) { - const int starting_input_pos = 0; - const int sequence_len = 3; - const int embedding_dim = 18; +TEST(VulkanSDPATest, test_sdpa_op_small_params_dynamic) { + const int base_sequence_len = 3; + const int head_dim = 8; const int num_heads = 6; const int num_kv_heads = 2; - const int batch_size = 1; - const int max_seq_len = 12; - test_vulkan_flash_attention( - starting_input_pos, - sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); + test_vulkan_sdpa(0, {3, 1, 1, 5, 1, 1}, head_dim, num_heads, num_kv_heads, 1); } -TEST(VulkanSDPATest, test_flash_attention_op_llama3_params) { - const int starting_input_pos = 0; - const int sequence_len = 3; - const int embedding_dim = 2048; - const int num_heads = 32; - const int num_kv_heads = 8; - const int batch_size = 1; - const int max_seq_len = 128; - - test_vulkan_flash_attention( - starting_input_pos, - sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); -} - -TEST(VulkanSDPATest, test_flash_attention_op_llama3_params_dynamic) { - const int starting_input_pos = 0; - const int embedding_dim = 2048; - const int num_heads = 32; - const int num_kv_heads = 8; - const int batch_size = 1; - const int max_seq_len = 128; - - // Test with different sequence lengths - std::vector sequence_lengths = {1, 3, 5, 7, 16, 32}; - - for (int seq_len : sequence_lengths) { - if (seq_len < max_seq_len) { - test_vulkan_flash_attention( - starting_input_pos, - seq_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); - } - } -} - -TEST(VulkanSDPATest, test_flash_attention_reference_impl) { - const int starting_input_pos = 0; - const int sequence_len = 3; - const int embedding_dim = 2048; - const int num_heads = 32; +TEST(VulkanSDPATest, test_sdpa_op_llama3_params_dynamic) { + const int head_dim = 128; + const int num_heads = 24; const int num_kv_heads = 8; - const int batch_size = 1; - const int max_seq_len = 128; - - test_vulkan_flash_attention( - starting_input_pos, - sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); -} - -TEST(VulkanSDPATest, test_flash_attention_reference_impl_small) { - const int starting_input_pos = 0; - const int sequence_len = 2; - const int embedding_dim = 32; - const int num_heads = 4; - const int num_kv_heads = 2; - const int batch_size = 1; - const int max_seq_len = 16; - - test_vulkan_flash_attention( - starting_input_pos, - sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); -} - -TEST(VulkanSDPATest, test_flash_attention_vec4_alignment) { - const int starting_input_pos = 0; - const int sequence_len = 8; - const int embedding_dim = 64; - const int num_heads = 4; - const int num_kv_heads = 2; - const int batch_size = 1; - const int max_seq_len = 16; - test_vulkan_flash_attention( - starting_input_pos, - sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); -} - -TEST(VulkanSDPATest, test_flash_attention_edge_cases) { - // Test with single head (no multi-query complexity) - test_vulkan_flash_attention(0, 1, 8, 1, 1, 1, 4); - - // Test with equal heads (no multi-query complexity) - test_vulkan_flash_attention(0, 2, 16, 4, 4, 1, 8); - - // Test with large head dimension - test_vulkan_flash_attention(0, 2, 128, 2, 1, 1, 8); - - // Test with sequence length that exactly matches block size (32) - test_vulkan_flash_attention(0, 32, 64, 2, 1, 1, 64); - - // Test with sequence length slightly larger than block size - test_vulkan_flash_attention( - 0, 33, 68, 2, 1, 1, 64); // 68 = 4*17, good for vec4 + test_vulkan_sdpa( + 0, {111, 1, 1, 1, 57, 1, 1}, head_dim, num_heads, num_kv_heads, 1); } diff --git a/backends/vulkan/test/scripts/test_op.sh b/backends/vulkan/test/scripts/test_op.sh index 36920cb73cc..1ec07b7f75f 100755 --- a/backends/vulkan/test/scripts/test_op.sh +++ b/backends/vulkan/test/scripts/test_op.sh @@ -141,6 +141,8 @@ build_core_libraries() { -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_BUILD_KERNELS_LLM=ON \ + -DEXECUTORCH_BUILD_KERNELS_LLM_AOT=ON \ -DEXECUTORCH_BUILD_DEVTOOLS=ON \ -DEXECUTORCH_BUILD_VULKAN=ON \ -DEXECUTORCH_BUILD_XNNPACK=ON \ @@ -152,39 +154,45 @@ build_core_libraries() { build_operator_tests() { echo "Building Vulkan operator tests..." - # Check if TORCH_OPS_YAML_PATH is set, if not use default - if [[ -z "${TORCH_OPS_YAML_PATH:-}" ]]; then - TORCH_OPS_YAML_PATH="$HOME/Github/pytorch/aten/src/ATen/native" - echo "Using default TORCH_OPS_YAML_PATH: $TORCH_OPS_YAML_PATH" - fi + # Prepare CMAKE arguments + CMAKE_ARGS=( + "backends/vulkan/test/op_tests" + "-DCMAKE_INSTALL_PREFIX=cmake-out" + "-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE" + "-DCMAKE_CXX_STANDARD=17" + ) - # Verify that TORCH_OPS_YAML_PATH exists - if [[ ! -d "$TORCH_OPS_YAML_PATH" ]]; then - echo "Error: TORCH_OPS_YAML_PATH directory does not exist: $TORCH_OPS_YAML_PATH" - echo "Please set TORCH_OPS_YAML_PATH to a valid PyTorch native operations directory" - echo "Example: export TORCH_OPS_YAML_PATH=/path/to/pytorch/aten/src/ATen/native" - exit 1 - fi + # Check if TORCH_OPS_YAML_PATH is set + if [[ -n "${TORCH_OPS_YAML_PATH:-}" ]]; then + # Verify that TORCH_OPS_YAML_PATH exists + if [[ ! -d "$TORCH_OPS_YAML_PATH" ]]; then + echo "Error: TORCH_OPS_YAML_PATH directory does not exist: $TORCH_OPS_YAML_PATH" + echo "Please set TORCH_OPS_YAML_PATH to a valid PyTorch native operations directory" + echo "Example: export TORCH_OPS_YAML_PATH=/path/to/pytorch/aten/src/ATen/native" + exit 1 + fi - # Verify required YAML files exist - if [[ ! -f "$TORCH_OPS_YAML_PATH/native_functions.yaml" ]]; then - echo "Error: Required file not found: $TORCH_OPS_YAML_PATH/native_functions.yaml" - exit 1 - fi + # Verify required YAML files exist + if [[ ! -f "$TORCH_OPS_YAML_PATH/native_functions.yaml" ]]; then + echo "Error: Required file not found: $TORCH_OPS_YAML_PATH/native_functions.yaml" + exit 1 + fi - if [[ ! -f "$TORCH_OPS_YAML_PATH/tags.yaml" ]]; then - echo "Error: Required file not found: $TORCH_OPS_YAML_PATH/tags.yaml" - exit 1 - fi + if [[ ! -f "$TORCH_OPS_YAML_PATH/tags.yaml" ]]; then + echo "Error: Required file not found: $TORCH_OPS_YAML_PATH/tags.yaml" + exit 1 + fi - echo "Using TORCH_OPS_YAML_PATH: $TORCH_OPS_YAML_PATH" + echo "Using TORCH_OPS_YAML_PATH: $TORCH_OPS_YAML_PATH" + CMAKE_ARGS+=("-DTORCH_OPS_YAML_PATH=$TORCH_OPS_YAML_PATH") + else + echo "WARNING: TORCH_OPS_YAML_PATH is not set. Building without PyTorch operator definitions." + echo "Some functionality may be limited. To enable full functionality, set TORCH_OPS_YAML_PATH to point to PyTorch's native operations directory." + echo "Example: export TORCH_OPS_YAML_PATH=/path/to/pytorch/aten/src/ATen/native" + fi # Build operator tests - cmake backends/vulkan/test/op_tests \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \ - -DTORCH_OPS_YAML_PATH="$TORCH_OPS_YAML_PATH" \ - -DCMAKE_CXX_STANDARD=17 \ + cmake "${CMAKE_ARGS[@]}" \ -Bcmake-out/backends/vulkan/test/op_tests && \ cmake --build cmake-out/backends/vulkan/test/op_tests -j16 } From fb1fff5db9483becf986e32a0f4408dd7cc4632e Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 9 Sep 2025 15:04:36 -0700 Subject: [PATCH 2/3] Update on "[ET-VK] Implement SDPA with fused ops" ## Context As title; optimize the SDPA operator by introducing shaders to perform the operation in 3 steps: 1. Compute attention weights, multiplying QT x K_cache, and applying scale and mask 2. Compute softmax normalization of computed attention weights 3. Compute final output by multiplying attention weights with V cache This new implementation is much more efficient than the existing one, which performed slicing, repeat_interleave, and transposition of projected and cache tensors as separate steps. The fusion of scale and mask with the computation of attention weights also allows for the computation of elements within the mask region to be skipped. ## Impact Decode latency for LLMs is much improved. For llama 3.2 3B generating ~250 tokens, decode latency increases from ~15 tok/s to ~21.5 tok/s Differential Revision: [D82053493](https://our.internmc.facebook.com/intern/diff/D82053493/) [ghstack-poisoned] --- .github/workflows/pull.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index cb0a3d9b679..7cb22d90f60 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -973,7 +973,10 @@ jobs: # "Classic" Operator tests PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build - ./cmake-out/backends/vulkan/test/op_tests/vulkan_sdpa_test + # TODO(ssjia): figure out how to run custom op tests in CI. Currently, they are + # failing due to to the libstdc++.so.6 installed with conda not supporting + # GLIBCXX_3.4.30. These tests are still run in Meta internal CI. + # ./cmake-out/backends/vulkan/test/op_tests/vulkan_sdpa_test # Run e2e testing for selected operators. More operators will be tested via this # route in the future. From 12b0f12792133cc7c10f1170e3f191a9be112d8a Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 9 Sep 2025 17:29:32 -0700 Subject: [PATCH 3/3] Update on "[ET-VK] Implement SDPA with fused ops" ## Context As title; optimize the SDPA operator by introducing shaders to perform the operation in 3 steps: 1. Compute attention weights, multiplying QT x K_cache, and applying scale and mask 2. Compute softmax normalization of computed attention weights 3. Compute final output by multiplying attention weights with V cache This new implementation is much more efficient than the existing one, which performed slicing, repeat_interleave, and transposition of projected and cache tensors as separate steps. The fusion of scale and mask with the computation of attention weights also allows for the computation of elements within the mask region to be skipped. ## Impact Decode latency for LLMs is much improved. For llama 3.2 3B generating ~250 tokens, decode latency increases from ~15 tok/s to ~21.5 tok/s Differential Revision: [D82053493](https://our.internmc.facebook.com/intern/diff/D82053493/) [ghstack-poisoned]