diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 00f60142df159..606dbfde15c2c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -16,6 +16,32 @@ namespace onnxruntime { namespace contrib { namespace webgpu { +Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform); + const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); + const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform); + const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform); + + const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform); + const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform); + const auto& present_value = sh.AddOutput("present_value", ShaderUsage::UseUniform); + + if (prepare_indirect_dispatch_) { + sh.AddOutput("indirect_buffer", ShaderUsage::None); + } + + return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template", + WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_), + WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_), + WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache), + WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv), + WGSL_TEMPLATE_VARIABLE(present_key, present_key), + WGSL_TEMPLATE_VARIABLE(present_value, present_value), + WGSL_TEMPLATE_VARIABLE(query, query), + WGSL_TEMPLATE_VARIABLE(seqlens, seqlens), + WGSL_TEMPLATE_VARIABLE(sin_cache, sin_cache)); +} + Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { // Expectations are // qkv have same number of heads and hidden dimension (head size). 
@@ -351,17 +377,54 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, - const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { + const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, + const Tensor* cos_cache, const Tensor* sin_cache) { + constexpr uint32_t tile_size = 64; + // Extract present_sequence_length directly from present_key tensor shape: // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size) const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]); const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled(); + // Declare query_output at function scope to ensure it persists throughout the function + Tensor query_output; + + // Create indirect dispatch buffer if using indirect dispatch + Tensor* indirect_buffer_ptr = nullptr; + Tensor indirect_buffer; + + // Prepare indirect dispatch buffer for decode path with static KV cache + const bool use_indirect_dispatch = parameters.sequence_length_ == 1 && + parameters.past_present_share_buffer_ && + seqlen_k != nullptr && + context.IsGraphCaptureEnabled(); + if (use_indirect_dispatch) { + const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions + indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape); + indirect_buffer_ptr = &indirect_buffer; + } + + const bool do_rotary = (cos_cache != nullptr && sin_cache != nullptr); + + if (do_rotary) { + ORT_ENFORCE(parameters.is_packed_qkv_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires packed QKV input."); + ORT_ENFORCE(parameters.past_present_share_buffer_,
"Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires static KV cache."); + + // Q points to the packed QKV tensor in this case, create query output tensor + query_output = context.CreateGPUTensor(Q->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); + + ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(context, parameters, + Q, seqlen_k, + cos_cache, sin_cache, + &query_output, present_key, present_value, + indirect_buffer_ptr, tile_size)); + Q = &query_output; + } else { + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, indirect_buffer_ptr)); + } + if (parameters.sequence_length_ > 1) { - const uint32_t tile_size = 64; - // For encode path, use the original CopyKVCache without indirect dispatch preparation - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? 
seqlen_k : nullptr, nullptr)); bool has_attention_bias = attention_bias != nullptr; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"}; @@ -406,29 +469,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co parameters.sequence_length_, present_sequence_length}); const TensorShape qk_shape(qk_dims); Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape); - constexpr uint32_t tile_size = 64; const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size; const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size; - // Determine if we should use indirect dispatch - const bool use_indirect_dispatch = parameters.past_present_share_buffer_ && - seqlen_k != nullptr && - context.IsGraphCaptureEnabled(); - - // Create indirect dispatch buffer if using indirect dispatch - Tensor* indirect_buffer_ptr = nullptr; - Tensor indirect_buffer; - if (use_indirect_dispatch) { - const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions - indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape); - indirect_buffer_ptr = &indirect_buffer; - // Use the fused CopyKVCache that also prepares the indirect dispatch buffer - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr)); - } else { - // Use the original CopyKVCache without indirect dispatch preparation - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr)); - } - // The metadata is used to store the max and sum of each tile.
const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_, num_present_sequence_length_tile, 2}); @@ -467,6 +510,78 @@ bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } +Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context, + const WebgpuAttentionParameters& params, + const Tensor* packedQKV, + const Tensor* seqlen_k, + const Tensor* cos_cache, + const Tensor* sin_cache, + Tensor* query, + Tensor* present_key, + Tensor* present_value, + Tensor* indirect_buffer, + uint32_t tile_size) { + const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]); + const auto head_size = params.head_size_; + + int components = 1; + // Currently we only support vectorization when RoPE is not interleaved + if (!params.rotary_interleaved_) { + if ((params.head_size_ % 4 == 0) && (half_rotary_embedding_dim % 4 == 0)) { + components = 4; + } else if ((params.head_size_ % 2 == 0) && (half_rotary_embedding_dim % 2 == 0)) { + components = 2; + } + } + // Adjust dimensions for vectorization + const auto half_rotary_embedding_dim_vec = half_rotary_embedding_dim / components; + const auto head_size_vec = head_size / components; + + // Dispatch: batch_size * sequence_length * num_heads * (half_rotary_dim + need_copy_dim) + // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim) + // = head_size - half_rotary_dim + const auto work_per_head = head_size_vec - half_rotary_embedding_dim_vec; + auto dispatch_size = static_cast<uint32_t>(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head); + + // Extract present_sequence_length from present_key tensor shape + const uint32_t present_sequence_length = gsl::narrow_cast<uint32_t>(present_key->Shape()[2]); + + const bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
+ + SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch); + program + .CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch) + .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components}) + .AddInputs({ + {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, + {cos_cache, ProgramTensorMetadataDependency::Rank, components}, + {sin_cache, ProgramTensorMetadataDependency::Rank, components}, + }); + program.AddOutputs({{query, ProgramTensorMetadataDependency::None, components}, + {present_key, ProgramTensorMetadataDependency::None, components}, + {present_value, ProgramTensorMetadataDependency::None, components}}); + + if (prepare_indirect_dispatch) { + program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None}); + } + + program.AddUniformVariables({ + {static_cast<uint32_t>(params.sequence_length_)}, + {static_cast<uint32_t>(params.hidden_size_ / components)}, + {static_cast<uint32_t>(params.kv_hidden_size_ / components)}, + {static_cast<uint32_t>(params.num_heads_)}, + {static_cast<uint32_t>(params.kv_num_heads_)}, + {head_size_vec}, + {half_rotary_embedding_dim_vec}, + {present_sequence_length}, + {tile_size}, + {static_cast<uint32_t>(dispatch_size)}, + }); + + program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + return context.RunProgram(program); +} + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 9599c10533351..a936a91695921 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -15,6 +15,32 @@ namespace webgpu { using namespace onnxruntime::webgpu; +class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program<SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram> { + public: + SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram(bool interleaved, bool prepare_indirect_dispatch) + :
Program{"SplitPackedQKVWithRotaryEmbeddingAndCopyKV"}, + interleaved_(interleaved), + prepare_indirect_dispatch_(prepare_indirect_dispatch) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"sequence_length", ProgramUniformVariableDataType::Uint32}, + {"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"kv_num_heads", ProgramUniformVariableDataType::Uint32}, + {"head_size", ProgramUniformVariableDataType::Uint32}, + {"half_rotary_dim", ProgramUniformVariableDataType::Uint32}, + {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"tile_size", ProgramUniformVariableDataType::Uint32}, + {"dispatch_size", ProgramUniformVariableDataType::Uint32}); + + private: + const bool interleaved_; + const bool prepare_indirect_dispatch_; +}; + class CopyKVCacheProgram final : public Program { public: CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, @@ -145,10 +171,24 @@ class FlashAttentionDecodeVxReduceProgram final : public ProgramShape().Size(); program - .AddInput({packedQKV, ProgramTensorMetadataDependency::Rank}) - .AddOutputs({{query, ProgramTensorMetadataDependency::Rank}, {key, ProgramTensorMetadataDependency::Rank}, {val, ProgramTensorMetadataDependency::Rank}}) + .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank}) + .AddOutputs({{query, ProgramTensorMetadataDependency::None}, {key, ProgramTensorMetadataDependency::None}, {val, ProgramTensorMetadataDependency::None}}) .AddUniformVariables({ {static_cast(params.hidden_size_)}, {static_cast(params.kv_hidden_size_)}, @@ -90,32 +90,46 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; + int components = 1; + // 
Currently we only support vectorization when RoPE is not interleaved + if (!params.rotary_interleaved_) { + if ((params.head_size_ % 4 == 0) && (half_rotary_embedding_dim % 4 == 0)) { + components = 4; + } else if ((params.head_size_ % 2 == 0) && (half_rotary_embedding_dim % 2 == 0)) { + components = 2; + } + } + + // Adjust dimensions for vectorization + const auto half_rotary_embedding_dim_vec = half_rotary_embedding_dim / components; + const auto head_size_vec = head_size / components; + // Dispatch: batch_size * sequence_length * num_heads * (half_rotary_dim + need_copy_dim) // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim) // = head_size - half_rotary_dim - const auto work_per_head = head_size - half_rotary_embedding_dim; - auto dispatch_size = static_cast<uint32_t>(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head); + const auto work_per_head_vec = head_size_vec - half_rotary_embedding_dim_vec; + auto dispatch_size = static_cast<uint32_t>(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head_vec); SplitPackedQKVWithRotaryEmbeddingProgram program(params.rotary_interleaved_); program .CacheHint(params.rotary_interleaved_) - .AddInput({packedQKV, ProgramTensorMetadataDependency::Rank}) + .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components}) .AddInputs({ - {seqlen_k, ProgramTensorMetadataDependency::Rank}, - {cos_cache, ProgramTensorMetadataDependency::Rank}, - {sin_cache, ProgramTensorMetadataDependency::Rank}, + {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, + {cos_cache, ProgramTensorMetadataDependency::Rank, components}, + {sin_cache, ProgramTensorMetadataDependency::Rank, components}, }) - .AddOutputs({{query, ProgramTensorMetadataDependency::Rank}, - {key, ProgramTensorMetadataDependency::Rank}, - {val, ProgramTensorMetadataDependency::Rank}}) + .AddOutputs({{query, ProgramTensorMetadataDependency::None, components}, + {key,
ProgramTensorMetadataDependency::None, components}, + {val, ProgramTensorMetadataDependency::None, components}}) .AddUniformVariables({ {static_cast<uint32_t>(params.sequence_length_)}, - {static_cast<uint32_t>(params.hidden_size_)}, - {static_cast<uint32_t>(params.kv_hidden_size_)}, + {static_cast<uint32_t>(params.hidden_size_ / components)}, + {static_cast<uint32_t>(params.kv_hidden_size_ / components)}, {static_cast<uint32_t>(params.num_heads_)}, {static_cast<uint32_t>(params.kv_num_heads_)}, - {static_cast<uint32_t>(head_size)}, - {half_rotary_embedding_dim}, + {head_size_vec}, + {half_rotary_embedding_dim_vec}, {static_cast<uint32_t>(dispatch_size)}, }) .SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); @@ -177,15 +191,15 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, program .CacheHint(params.rotary_interleaved_) .AddInputs({ - {query_in, ProgramTensorMetadataDependency::Rank}, + {query_in, ProgramTensorMetadataDependency::TypeAndRank}, {key_in, ProgramTensorMetadataDependency::Rank}, - {seqlen_k, ProgramTensorMetadataDependency::Rank}, + {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, {cos_cache, ProgramTensorMetadataDependency::Rank}, {sin_cache, ProgramTensorMetadataDependency::Rank}, }) .AddOutputs({ - {query_out, ProgramTensorMetadataDependency::Rank}, - {key_out, ProgramTensorMetadataDependency::Rank}, + {query_out, ProgramTensorMetadataDependency::None}, + {key_out, ProgramTensorMetadataDependency::None}, }) .SetDispatchGroupSize((q_domain_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .AddUniformVariables({ @@ -265,7 +279,26 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor qRotary; Tensor kRotary; + + // Use a sliding window if the total sequence exceeds the window's length.
+ bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_); + bool will_use_flash_attention = false; + if (head_sink == nullptr && !use_smooth_softmax_ && !use_sliding_window) { + // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking + WebgpuAttentionParameters temp_params = parameters; + temp_params.is_packed_qkv_ = false; + will_use_flash_attention = CanApplyFlashAttention(attention_bias, present_key, present_value, temp_params, context); + } + if (parameters.is_packed_qkv_ && do_rotary_) { + // Use the ultimate fused operation when FlashAttention and static KV cache is enabled. + if (will_use_flash_attention && parameters.past_present_share_buffer_) { + // Directly call ApplyFlashAttention with fused split/rotary/copyKV enabled + // query points to packed QKV, K and V are nullptr since they're not needed + return ApplyFlashAttention(query, nullptr, nullptr, attention_bias, output, past_key, present_key, past_value, + present_value, parameters, context, seqlen_k, cos_cache, sin_cache); + } + // Fused: splitQKV + rotary QK qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); @@ -279,8 +312,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& key = &kSplit; value = &vSplit; } else { - // Original separate path if (parameters.is_packed_qkv_) { + // splitQKV qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); kSplit = 
context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); @@ -292,6 +325,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& value = &vSplit; } if (do_rotary_) { + // rotary QK qRotary = context.CreateGPUTensor(query->DataType(), query->Shape()); kRotary = context.CreateGPUTensor(key->DataType(), key->Shape()); ORT_RETURN_IF_ERROR(RunFusedQKRotaryEmbedding(context, parameters, @@ -304,11 +338,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& } } - // Use a sliding window if the total sequence exceeds the window's length. - bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_); - if (head_sink == nullptr && !use_smooth_softmax_ && - !use_sliding_window && - CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) { + if (will_use_flash_attention) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context, seqlen_k); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template index b64448611079f..777be41ffb456 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template @@ -36,11 +36,11 @@ $MAIN { // Calculate actual indices in the head for i and j #if interleaved - let idx_i = in_head_idx; - let idx_j = in_head_idx + 1u; + let idx_i = in_head_idx + in_head_idx; + let idx_j = idx_i + 1u; #else let idx_i = in_head_idx; - let idx_j = 
in_head_idx + uniforms.half_rotary_dim; + let idx_j = idx_i + uniforms.half_rotary_dim; #endif // Process Q pair diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template new file mode 100644 index 0000000000000..d6cb654afa756 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -0,0 +1,111 @@ +#param interleaved +#param prepare_indirect_dispatch + +#use guardAgainstOutOfBoundsWorkgroupSizes +#use .setByIndices .getByIndices .getByOffset + +$MAIN { + guardAgainstOutOfBoundsWorkgroupSizes(uniforms.dispatch_size); + + // Dispatch: batch * seq * num_heads * (half_rotary_dim + need_copy_dim) + // work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim) + let work_per_head = uniforms.head_size - uniforms.half_rotary_dim; + let total_work = uniforms.num_heads * work_per_head; + + let batch_idx = global_idx / (uniforms.sequence_length * total_work); + let remainder1 = global_idx % (uniforms.sequence_length * total_work); + let seq_idx = remainder1 / total_work; + let remainder2 = remainder1 % total_work; + let head_idx = remainder2 / work_per_head; + let in_head_idx = remainder2 % work_per_head; + + // Calculate base offset in packed_qkv for this token + // Layout per token: [Q(hidden_size), K(kv_hidden_size), V(kv_hidden_size)] + let token_size = uniforms.hidden_size + 2u * uniforms.kv_hidden_size; + let base_offset = batch_idx * uniforms.sequence_length * token_size + seq_idx * token_size; + + // Calculate position_id (needed for rotary embedding) + let seqlen_i = seqlens.getByOffset(batch_idx); + let seqlen = u32(seqlen_i); + let total_seqlen = seqlen + 1u; + + let past_seqlen = total_seqlen - uniforms.sequence_length; + // `position_id` is used to get cos/sin cache and also as the time step index in present_key/present_value + let 
position_id = past_seqlen + seq_idx; + +#if prepare_indirect_dispatch + // Prepare indirect dispatch buffer for thread 0 + if (global_idx == 0u) { + let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size; + indirect_buffer[0] = num_total_seq_length_tile; + indirect_buffer[1] = uniforms.num_heads; + indirect_buffer[2] = 1u; + } +#endif + + if (in_head_idx < uniforms.half_rotary_dim) { + // Process a rotary pair (i, j) + let cos_v = cos_cache.getByIndices(vec2(position_id, in_head_idx)); + let sin_v = sin_cache.getByIndices(vec2(position_id, in_head_idx)); + + // Calculate actual indices in the head for i and j +#if interleaved + let idx_i = in_head_idx + in_head_idx; + let idx_j = idx_i + 1u; +#else + let idx_i = in_head_idx; + let idx_j = idx_i + uniforms.half_rotary_dim; +#endif + + // Process Q pair + let q_base = base_offset + head_idx * uniforms.head_size; + let q_i_offset = q_base + idx_i; + let q_j_offset = q_base + idx_j; + let q_i = packed_qkv.getByOffset(q_i_offset); + let q_j = packed_qkv.getByOffset(q_j_offset); + let q_re = q_i * cos_v - q_j * sin_v; + let q_im = q_i * sin_v + q_j * cos_v; + query.setByIndices(vec3(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_i), q_re); + query.setByIndices(vec3(batch_idx, seq_idx, head_idx * uniforms.head_size + idx_j), q_im); + + // Process K and V pairs if within kv_num_heads + if (head_idx < uniforms.kv_num_heads) { + let k_base = base_offset + uniforms.hidden_size + head_idx * uniforms.head_size; + let k_i_offset = k_base + idx_i; + let k_j_offset = k_base + idx_j; + let k_i = packed_qkv.getByOffset(k_i_offset); + let k_j = packed_qkv.getByOffset(k_j_offset); + let k_re = k_i * cos_v - k_j * sin_v; + let k_im = k_i * sin_v + k_j * cos_v; + // Write K directly to present_key cache + present_key.setByIndices(vec4(batch_idx, head_idx, position_id, idx_i), k_re); + present_key.setByIndices(vec4(batch_idx, head_idx, position_id, idx_j), k_im); + + // V doesn't need 
rotary, just copy the pair to present_value cache + let v_base = base_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head_idx * uniforms.head_size; + let v_i = packed_qkv.getByOffset(v_base + idx_i); + let v_j = packed_qkv.getByOffset(v_base + idx_j); + present_value.setByIndices(vec4(batch_idx, head_idx, position_id, idx_i), v_i); + present_value.setByIndices(vec4(batch_idx, head_idx, position_id, idx_j), v_j); + } + } else { + // Process non-rotary elements (direct copy) + let actual_idx = uniforms.half_rotary_dim + in_head_idx; + + // Copy Q + let q_offset = base_offset + head_idx * uniforms.head_size + actual_idx; + let q_data = packed_qkv.getByOffset(q_offset); + query.setByIndices(vec3(batch_idx, seq_idx, head_idx * uniforms.head_size + actual_idx), q_data); + + // Copy K and V if within kv_num_heads directly to present cache + if (head_idx < uniforms.kv_num_heads) { + let k_offset = base_offset + uniforms.hidden_size + head_idx * uniforms.head_size + actual_idx; + let k_data = packed_qkv.getByOffset(k_offset); + present_key.setByIndices(vec4(batch_idx, head_idx, position_id, actual_idx), k_data); + + let v_offset = base_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head_idx * uniforms.head_size + actual_idx; + let v_data = packed_qkv.getByOffset(v_offset); + present_value.setByIndices(vec4(batch_idx, head_idx, position_id, actual_idx), v_data); + } + } +} // MAIN