diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index f01ce985658aa..46d3e7e675e85 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -379,11 +379,6 @@ Status CheckCustomAttentionInputs(const T* position_ids, } if (head_sink != nullptr) { - if (parameters.use_smooth_softmax) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_sink should not be provided when use_smooth_softmax is true."); - } - const auto& head_sink_shape = head_sink->Shape(); if (head_sink_shape.NumDimensions() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "head_sink must be a 1D tensor"); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 691391ccef0d0..e08d120750a40 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -156,6 +156,7 @@ struct GroupQueryAttentionData { int* seqlens_k = nullptr; const T* cos_cache = nullptr; const T* sin_cache = nullptr; + const T* head_sink = nullptr; // Flash buffers T* softmax_lse = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index c24bf88fa729b..09ead61e7d80d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -123,6 +123,7 @@ struct Flash_fwd_params : public Qkv_params { bool is_rotary_interleaved = false; + void* __restrict__ head_sink_ptr = nullptr; bool smooth_softmax = false; int num_splits = 0; // For split-KV version diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index b0241c26aafc6..76704b5b29fcd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -28,6 +28,7 @@ void set_params_fprop(Flash_fwd_params& params, void* q, void* k, void* v, + void* head_sink, void* out, void* cu_seqlens_q_d, void* cu_seqlens_k_d, @@ -50,7 +51,9 @@ void set_params_fprop(Flash_fwd_params& params, params.o_ptr = out; params.is_bf16 = is_bf16; + params.smooth_softmax = use_smooth_softmax; + params.head_sink_ptr = head_sink; // All stride are in elements, not bytes. 
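Context for the plumbing above: ORT's existing smooth_softmax option adds an implicit zero logit to every softmax row, and the new head_sink input generalizes that to a learned per-head logit, which is why the CPU helper's mutual-exclusion check is removed and a head_sink pointer is threaded through GroupQueryAttentionData, Flash_fwd_params, and set_params_fprop. A minimal PyTorch sketch of the intended semantics (the function name and shapes below are illustrative, not part of the ORT API):

```python
import torch

def softmax_with_sink(scores: torch.Tensor, sink_logit: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax with one extra 'sink' column that absorbs probability mass.

    scores:     (..., total_seq_len) attention logits for a single head.
    sink_logit: scalar tensor; 0.0 reproduces smooth_softmax, while a learned
                per-head value reproduces the new head_sink path.
    """
    sink = sink_logit.reshape(1).expand(*scores.shape[:-1], 1)
    probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    return probs[..., :-1]  # drop the sink column before multiplying by V

# With a zero sink the kept probabilities match smooth softmax and sum to < 1 per row.
s = torch.randn(2, 5)
print(softmax_with_sink(s, torch.tensor(0.0)).sum(dim=-1))
```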
if (kv_bsnh) { @@ -297,6 +300,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + constexpr void* head_sink = nullptr; + Flash_fwd_params params; set_params_fprop(params, batch_size, @@ -304,7 +309,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, - q, k, v, out, + q, k, v, head_sink, out, /*cu_seqlens_q*/ nullptr, /*cu_seqlens_k*/ nullptr, /*seqused_k=*/nullptr, @@ -376,6 +381,8 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); const bool paged_KV = block_table != nullptr; + constexpr void* head_sink = nullptr; + Flash_fwd_params params; set_params_fprop(params, batch_size, @@ -383,7 +390,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, - q, k, v, out, + q, k, v, head_sink, out, cu_seqlens_q, cu_seqlens_k, seqused_k, @@ -443,6 +450,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* seqlens_k_, // batch_size void* rotary_cos, // seqlen_ro x (rotary_dim / 2) void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* head_sink, // num_heads int* block_table, // batch_size x max_num_blocks_per_seq int batch_size, int num_heads, @@ -480,7 +488,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, - q, kcache, vcache, out, + q, kcache, vcache, head_sink, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, /*seqused_k=*/nullptr, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index e28e38ea3ed93..e29dd7c1c231d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -98,6 +98,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* seqlens_k_, // batch_size void* rotary_cos, // seqlen_ro x (rotary_dim / 2) void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* head_sink, // num_heads int* block_table, // batch_size x max_num_blocks_per_seq int batch_size, int num_heads, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index 4110e715c4391..91104b8c3dfe0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -369,8 +369,10 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi } // Epilogue - - Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, params.smooth_softmax); + float sink = (params.head_sink_ptr != nullptr) + ? reinterpret_cast(params.head_sink_ptr)[bidh] + : (params.smooth_softmax ? 0.0f : -kInfinity); + Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, sink); // Convert acc_o from fp32 to fp16/bf16 Tensor rO = flash::convert_type(acc_o); @@ -928,8 +930,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons } // Epilogue - - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.smooth_softmax); + float sink = (params.head_sink_ptr != nullptr) + ? 
reinterpret_cast(params.head_sink_ptr)[bidh] + : (params.smooth_softmax ? 0.0f : -std::numeric_limits::infinity()); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, sink); Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index 7fe506e01a9b9..c7a8476f5beae 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -18,6 +18,7 @@ namespace flash { using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// +constexpr float kInfinity = std::numeric_limits::infinity(); template __device__ __forceinline__ void thread_reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { @@ -72,9 +73,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor& tenso // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. // If we don't have float around M_LOG2E the multiplication is done in fp64. - const float max_scaled = max(mi) == -std::numeric_limits::infinity() - ? 0.f - : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + const float max_scaled = max(mi) == -kInfinity ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); #pragma unroll for (int ni = 0; ni < size<1>(tensor); ++ni) { // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - @@ -85,38 +84,6 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor& tenso } } -// Apply the exp to all the elements. -template -__forceinline__ __device__ void max_scale_exp2_sum(Tensor& tensor, Tensor& max, Tensor& sum, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); -#pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - MaxOp max_op; - max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); -#pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - max(mi) = max_op(max(mi), tensor(mi, ni)); - } - max(mi) = Allreduce<4>::run(max(mi), max_op); - // If max is -inf, then all elements must have been -inf (possibly due to masking). - // We don't want (-inf - (-inf)) since that would give NaN. - const float max_scaled = max(mi) == -std::numeric_limits::infinity() ? 0.f : max(mi) * scale; - sum(mi) = 0; -#pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. 
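The sink value computed in the two kernel epilogues above is a pre-scaled logit: a real per-head entry when head_sink_ptr is set, 0.0f for plain smooth softmax, and -inf when neither feature is active so the sink contributes nothing. The rewritten normalize_softmax_lse further below folds that logit into the online-softmax state by re-basing the accumulator and row sum onto a new running maximum. A plain-Python sketch of that arithmetic, assuming row_max and row_sum hold the usual flash-attention running statistics (this is the math only, not the CUDA code):

```python
import math

def finalize_with_sink(acc_o, row_sum, row_max, softmax_scale, sink):
    """acc_o[j] = sum_k exp(softmax_scale * (x_k - row_max)) * v_k[j]; sink is a pre-scaled logit."""
    max_scaled = row_max * softmax_scale
    if sink != float("-inf"):                      # head_sink or smooth_softmax is active
        true_max = max(max_scaled, sink)           # new reference point for the exponentials
        rescale = math.exp(max_scaled - true_max)  # re-base the existing partial results
        acc_o = [o * rescale for o in acc_o]
        row_sum = row_sum * rescale + math.exp(sink - true_max)
        max_scaled = true_max
    lse = max_scaled + math.log(row_sum) if row_sum > 0.0 else float("inf")
    inv_sum = 1.0 / row_sum if row_sum > 0.0 and math.isfinite(row_sum) else 1.0
    return [o * inv_sum for o in acc_o], lse
```

With sink = 0 and a non-negative row maximum this reduces to the old smooth_softmax adjustment (row_sum + expf(-row_max * softmax_scale)); the extra max() just keeps the exponentials bounded when the scaled maximum is smaller than the sink logit.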
- tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - sum(mi) += tensor(mi, ni); - } - SumOp sum_op; - sum(mi) = Allreduce<4>::run(sum(mi), sum_op); - } -} - //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -143,10 +110,10 @@ struct Softmax { Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { + for (int mi = 0; mi < size<0>(row_max); ++mi) { float scores_max_cur = !Check_inf ? row_max(mi) - : (row_max(mi) == -std::numeric_limits::infinity() ? 0.0f : row_max(mi)); + : (row_max(mi) == -kInfinity ? 0.0f : row_max(mi)); float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); row_sum(mi) *= scores_scale; #pragma unroll @@ -154,6 +121,7 @@ struct Softmax { acc_o_rowcol(mi, ni) *= scores_scale; } } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); // We don't do the reduce across threads here since we don't need to use the row_sum. // We do that reduce at the end when we need to normalize the softmax. @@ -162,27 +130,62 @@ struct Softmax { }; template - __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale, bool smooth_softmax) { + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, + float softmax_scale, + float sink) { // IMPORTANT: sink is a pre-scaled logit + SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); TensorT lse = make_fragment_like(row_sum); Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + + const bool use_sink = (sink != -kInfinity); + #pragma unroll for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = smooth_softmax ? row_sum(mi) + expf(-row_max(mi) * softmax_scale) : row_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + float sum = row_sum(mi); + float max_unscaled = row_max(mi); // Max of the qk scores, NOT scaled. + + if (use_sink) { + const float max_scaled = (max_unscaled == -kInfinity) + ? -kInfinity + : max_unscaled * softmax_scale; + + const float true_max_scaled = max(max_scaled, sink); + + // Rescale the intermediate the output accumulator (acc_o) and sum. + // They were calculated relative to `max_scaled` and must be + // rescaled to be relative to `true_max_scaled`. + const float rescale_factor = expf(max_scaled - true_max_scaled); + +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= rescale_factor; + } + + sum *= rescale_factor; + + // Add the sink to the sum. + sum += expf(sink - true_max_scaled); + + // The unscaled max that reflects the sink. It is used for the below LSE calculation. + max_unscaled = true_max_scaled / softmax_scale; + } + lse(mi) = (sum == 0.f || sum != sum) - ? (Split ? -std::numeric_limits::infinity() : std::numeric_limits::infinity()) - : row_max(mi) * softmax_scale + __logf(sum); - float scale = inv_sum; + ? (Split ? -kInfinity : kInfinity) + : max_unscaled * softmax_scale + __logf(sum); + + float inv_sum = (sum == 0.f || !isfinite(sum)) ? 
1.f : 1.f / sum; #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scale; + acc_o_rowcol(mi, ni) *= inv_sum; } } + return lse; - }; + } }; } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 9cb93cbcd3f32..e5d2434a31808 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -78,6 +78,14 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* total_seqlen = context->Input(6); const Tensor* cos_cache = context->Input(7); const Tensor* sin_cache = context->Input(8); + const Tensor* position_ids = context->Input(9); + const Tensor* attention_bias = context->Input(10); + const Tensor* head_sink = context->Input(11); + + if (position_ids != nullptr || attention_bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "position_ids and attention_bias are not supported in GroupQueryAttention cuda kernel."); + } auto& device_prop = GetDeviceProp(); GroupQueryAttentionParameters parameters; @@ -99,12 +107,17 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { scale_, softcap_, device_prop.maxThreadsPerBlock)); + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, + attention_bias, + head_sink, + parameters)); parameters.local_window_size = local_window_size_; parameters.is_unidirectional = is_unidirectional_; - parameters.use_smooth_softmax = use_smooth_softmax_; + parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr; parameters.zeros_count = kZerosCount; parameters.zero_ptr = zeros_.get(); - // parameters.left_padding = left_padding_; + int sequence_length = parameters.sequence_length; parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; @@ -276,6 +289,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.sin_cache = reinterpret_cast(sin_cache->Data()); } + if (head_sink != nullptr) { + data.head_sink = reinterpret_cast(head_sink->Data()); + } + cublasHandle_t cublas = GetCublasHandle(context); return QkvToContext( diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index bb450e476d5ba..19d496569f79e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -460,11 +460,18 @@ Status FlashAttention( void* present_value = reinterpret_cast(const_cast(data.present_value)); void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); + void* head_sink = reinterpret_cast(const_cast(data.head_sink)); bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("Q", reinterpret_cast(query), batch_size, sequence_length, num_heads, head_size); + DUMP_TENSOR("K", reinterpret_cast(present_key), batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); + DUMP_TENSOR("V", reinterpret_cast(present_value), batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( device_prop, stream, query, present_key, present_value, key, value, data.output, - reinterpret_cast(data.softmax_lse), 
seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr, + reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, head_sink, /*block_table*/ nullptr, batch_size, num_heads, kv_num_heads, head_size, sequence_length, parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim, scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, @@ -475,7 +482,6 @@ Status FlashAttention( // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); // } - DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); return Status::OK(); @@ -680,6 +686,11 @@ template Status QkvToContext( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data); +template Status LaunchUnpackQKV( + const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, + const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, + cudaStream_t stream, const int max_threads_per_block); + template struct GroupQueryAttentionData; template Status QkvToContext( @@ -689,11 +700,6 @@ template Status QkvToContext( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data); -template Status LaunchUnpackQKV( - const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, - const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, - cudaStream_t stream, const int max_threads_per_block); - template Status LaunchUnpackQKV( const BFloat16* packed_qkv, BFloat16* unpacked_q, BFloat16* unpacked_k, BFloat16* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py new file mode 100644 index 0000000000000..3163bb33a3a82 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -0,0 +1,1167 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# ------------------------------------------------------------------------- +import math +import os +import platform +import random +import unittest +from dataclasses import dataclass + +import numpy +import torch +from einops import rearrange, repeat +from onnx import TensorProto, helper +from parameterized import parameterized + +from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers + +# Set seed for reproducibility +torch.manual_seed(0) +random.seed(69) + +# Reduces number of tests to run for faster pipeline checks +pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" + +# ################################################################################################# +# Configuration and Helper Classes +# ################################################################################################# + + +@dataclass +class GQAConfig: + batch_size: int + q_sequence_length: int + kv_sequence_length: int + num_heads: int + kv_num_heads: int + head_size: int + past_kv_sequence_length: int = 0 + buffer_sequence_length: int = 0 + # Test-specific parameters + local_window_size: int = -1 + rotary: bool = False + rotary_interleaved: bool = False + packed: bool = False + softcap: float = 0.0 + use_smooth_softmax: bool = False + # CPU-only parameters + has_position_ids: bool = False + has_attention_bias: bool = False + has_head_sink: bool = False + + +# ################################################################################################# +# Rotary Embedding Implementations (CPU and CUDA) +# ################################################################################################# + + +# PyTorch implementation for CPU and fallback +class LlamaMSRotaryEmbedding(torch.nn.Module): + def __init__(self): + super().__init__() + + def rotate_tensor(self, x, cos, sin, pos, interleaved): + rot_dim = 2 * cos.shape[3] + x_rot = x[:, :, :, :rot_dim] + + if interleaved: + x1 = x_rot[:, :, :, 0::2] + x2 = x_rot[:, :, :, 1::2] + else: + half = x_rot.shape[-1] // 2 + x1 = x_rot[:, :, :, 0:half] + x2 = x_rot[:, :, :, half : 2 * half] + + seq_len = x.shape[1] + batch_size = x.shape[0] + + cos = cos.squeeze(0).squeeze(1) + sin = sin.squeeze(0).squeeze(1) + + if seq_len == 1: + pos_i = pos.long() + cos_x = cos[pos_i].unsqueeze(1) + sin_x = sin[pos_i].unsqueeze(1) + else: + cos_x_list = [] + sin_x_list = [] + for b in range(batch_size): + pos_b = pos[b] + cos_x_list.append(cos[pos_b : pos_b + seq_len]) + sin_x_list.append(sin[pos_b : pos_b + seq_len]) + cos_x = torch.stack(cos_x_list, dim=0) + sin_x = torch.stack(sin_x_list, dim=0) + + cos_x = cos_x.unsqueeze(2) + sin_x = sin_x.unsqueeze(2) + + real = cos_x * x1 - sin_x * x2 + imag = sin_x * x1 + cos_x * x2 + + if interleaved: + x_rot[:, :, :, 0::2] = real + x_rot[:, :, :, 1::2] = imag + else: + x_rot = torch.cat((real, imag), dim=-1) + + return torch.cat((x_rot, x[:, :, :, rot_dim:]), dim=-1) + + def forward(self, x, cos, sin, pos, interleaved): + return self.rotate_tensor(x, cos, sin, pos, interleaved) + + +# Triton-based implementation for CUDA +def rotary_embedding_cuda(*args, **kwargs): + from rotary_flash import apply_rotary_emb # noqa: PLC0415 + + return apply_rotary_emb(*args, **kwargs) + + +# Unified wrapper for rotary embeddings +def apply_rotary_embedding(x, cos, sin, pos, interleaved, device="cpu"): + """Applies rotary embedding, using Triton for CUDA if available, otherwise fallback to PyTorch.""" + 
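+    # The Triton kernel (rotary_flash.apply_rotary_emb) is only attempted on Linux with a CUDA
+    # device; if the rotary_flash package is not installed we fall back to the pure-PyTorch
+    # LlamaMSRotaryEmbedding above, which applies the same rotation on either CPU or GPU.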
use_cuda_triton = device == "cuda" and platform.system() == "Linux" + if use_cuda_triton: + try: + return rotary_embedding_cuda(x, cos, sin, seqlen_offsets=pos, interleaved=interleaved) + except ImportError: + print("WARNING: Triton-based rotary embedding not found. Falling back to PyTorch version.") + + # PyTorch implementation for CPU or as a fallback for CUDA + rot = LlamaMSRotaryEmbedding().to(device) + # Unsqueeze to match the expected shape in the PyTorch version + cos_unsqueezed = cos.unsqueeze(0).unsqueeze(2) + sin_unsqueezed = sin.unsqueeze(0).unsqueeze(2) + return rot(x, cos_unsqueezed, sin_unsqueezed, pos, interleaved) + + +# ################################################################################################# +# ONNX Graph Creation +# ################################################################################################# + + +def create_group_query_attention_graph_prompt( + config: GQAConfig, + ort_type, + share_buffer=True, +): + assert not (config.has_head_sink and config.use_smooth_softmax) + past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 + present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length + + nodes = [ + helper.make_node( + op_type="GroupQueryAttention", + inputs=[ + "query", + "key" if not config.packed else "", + "value" if not config.packed else "", + "past_key" if share_buffer else "", + "past_value" if share_buffer else "", + "seqlens_k", + "total_sequence_length", + "cos_cache" if config.rotary else "", + "sin_cache" if config.rotary else "", + "position_ids" if config.has_position_ids else "", + "attention_bias" if config.has_attention_bias else "", + "head_sink" if config.has_head_sink else "", + ], + outputs=["output", "present_key", "present_value"], + name="GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + local_window_size=config.local_window_size, + do_rotary=config.rotary, + rotary_interleaved=config.rotary_interleaved, + softcap=config.softcap, + smooth_softmax=1 if config.use_smooth_softmax else 0, + domain="com.microsoft", + ), + ] + + q_hidden_size = ( + (config.num_heads * config.head_size) + if not config.packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size) + ) + graph_input = [ + helper.make_tensor_value_info("query", ort_type, [config.batch_size, config.q_sequence_length, q_hidden_size]), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [config.batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + ] + + if not config.packed: + graph_input.extend( + [ + helper.make_tensor_value_info( + "key", + ort_type, + [config.batch_size, config.kv_sequence_length, config.kv_num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "value", + ort_type, + [config.batch_size, config.kv_sequence_length, config.kv_num_heads * config.head_size], + ), + ] + ) + if share_buffer: + # Shape is (batch_size, kv_num_heads, sequence_length, head_size) + k_shape = [config.batch_size, config.kv_num_heads, past_kv_seqlen, config.head_size] + v_shape = k_shape + graph_input.extend( + [ + helper.make_tensor_value_info("past_key", ort_type, k_shape), + helper.make_tensor_value_info("past_value", ort_type, v_shape), + ] + ) + if config.rotary: + rotary_dim = (math.floor(config.head_size / 16) * 16) // 2 + cache_seq_len = config.buffer_sequence_length if share_buffer else config.kv_sequence_length + graph_input.extend( + [ + 
helper.make_tensor_value_info("cos_cache", ort_type, [cache_seq_len, rotary_dim]), + helper.make_tensor_value_info("sin_cache", ort_type, [cache_seq_len, rotary_dim]), + ] + ) + if config.has_position_ids: + graph_input.append( + helper.make_tensor_value_info( + "position_ids", TensorProto.INT64, [config.batch_size, config.q_sequence_length] + ) + ) + if config.has_attention_bias: + graph_input.append( + helper.make_tensor_value_info( + "attention_bias", ort_type, [config.batch_size, 1, config.q_sequence_length, config.kv_sequence_length] + ) + ) + if config.has_head_sink: + graph_input.append(helper.make_tensor_value_info("head_sink", ort_type, [config.num_heads])) + + # Shape is (batch_size, kv_num_heads, sequence_length, head_size) + output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] + output_v_shape = output_k_shape + + graph_output = [ + helper.make_tensor_value_info( + "output", ort_type, [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size] + ), + helper.make_tensor_value_info("present_key", ort_type, output_k_shape), + helper.make_tensor_value_info("present_value", ort_type, output_v_shape), + ] + + graph = helper.make_graph(nodes, "GroupQueryAttention_Graph", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_group_query_attention_graph_past( + config: GQAConfig, + ort_type, + share_buffer=True, +): + assert not (config.has_head_sink and config.use_smooth_softmax) + + if share_buffer: + past_kv_seqlen = config.buffer_sequence_length + present_kv_seqlen = config.buffer_sequence_length + else: + past_kv_seqlen = config.past_kv_sequence_length + present_kv_seqlen = config.past_kv_sequence_length + config.kv_sequence_length + + nodes = [ + helper.make_node( + "GroupQueryAttention", + [ + "query", + "key" if not config.packed else "", + "value" if not config.packed else "", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + "cos_cache" if config.rotary else "", + "sin_cache" if config.rotary else "", + "position_ids" if config.has_position_ids else "", + "attention_bias" if config.has_attention_bias else "", + "head_sink" if config.has_head_sink else "", + ], + ["output", "present_key", "present_value"], + "GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + local_window_size=config.local_window_size, + do_rotary=config.rotary, + rotary_interleaved=config.rotary_interleaved, + softcap=config.softcap, + smooth_softmax=1 if config.use_smooth_softmax else 0, + domain="com.microsoft", + ), + ] + + q_hidden_size = ( + (config.num_heads * config.head_size) + if not config.packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size) + ) + # Shape is (batch_size, kv_num_heads, sequence_length, head_size) + past_k_shape = [config.batch_size, config.kv_num_heads, past_kv_seqlen, config.head_size] + graph_input = [ + helper.make_tensor_value_info("query", ort_type, [config.batch_size, config.q_sequence_length, q_hidden_size]), + helper.make_tensor_value_info("past_key", ort_type, past_k_shape), + helper.make_tensor_value_info("past_value", ort_type, past_k_shape), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [config.batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + ] + + if not config.packed: + graph_input.extend( + [ + helper.make_tensor_value_info( + "key", + ort_type, + [config.batch_size, 
config.q_sequence_length, config.kv_num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "value", + ort_type, + [config.batch_size, config.q_sequence_length, config.kv_num_heads * config.head_size], + ), + ] + ) + + if config.rotary: + rotary_dim = (math.floor(config.head_size / 16) * 16) // 2 + cache_len = config.buffer_sequence_length + graph_input.extend( + [ + helper.make_tensor_value_info("cos_cache", ort_type, [cache_len, rotary_dim]), + helper.make_tensor_value_info("sin_cache", ort_type, [cache_len, rotary_dim]), + ] + ) + + if config.has_position_ids: + graph_input.append( + helper.make_tensor_value_info( + "position_ids", TensorProto.INT64, [config.batch_size, config.q_sequence_length] + ) + ) + if config.has_attention_bias: + graph_input.append( + helper.make_tensor_value_info( + "attention_bias", ort_type, [config.batch_size, 1, config.q_sequence_length, present_kv_seqlen] + ) + ) + if config.has_head_sink: + graph_input.append(helper.make_tensor_value_info("head_sink", ort_type, [config.num_heads])) + + output_k_shape = [ + config.batch_size, + config.kv_num_heads, + present_kv_seqlen, + config.head_size, + ] + + graph_output = [ + helper.make_tensor_value_info( + "output", ort_type, [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size] + ), + helper.make_tensor_value_info("present_key", ort_type, output_k_shape), + helper.make_tensor_value_info("present_value", ort_type, output_k_shape), + ] + + graph = helper.make_graph(nodes, "GroupQueryAttention_Graph", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + +# ################################################################################################# +# ONNX Runtime Execution Functions +# ################################################################################################# + + +def gqa_prompt_func( + q, + k, + v, + config: GQAConfig, + new_k, + new_v, + cos, + sin, + seqlens_k, + position_ids, + attention_bias, + head_sink, + ep, + device, + share_buffer=True, + ort_type=TensorProto.FLOAT16, + numpy_type=numpy.float16, +): + onnx_model_str = create_group_query_attention_graph_prompt( + config=config, + ort_type=ort_type, + share_buffer=share_buffer, + ) + + q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) + io_binding = ort_session.io_binding() + + # Common inputs + ort_inputs = { + "query": q.detach().cpu().numpy(), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), + } + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + 
io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + + # CPU-specific inputs + if config.has_position_ids: + ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() + io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + if config.has_attention_bias: + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) + if config.has_head_sink: + ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() + io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + + if share_buffer: + past_k_ort = OrtValue.ortvalue_from_numpy(k.detach().cpu().numpy(), device, 0) + past_v_ort = OrtValue.ortvalue_from_numpy(v.detach().cpu().numpy(), device, 0) + io_binding.bind_input("past_key", device, 0, numpy_type, past_k_ort.shape(), past_k_ort.data_ptr()) + io_binding.bind_input("past_value", device, 0, numpy_type, past_v_ort.shape(), past_v_ort.data_ptr()) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("present_key", past_k_ort) + io_binding.bind_ortvalue_output("present_value", past_v_ort) + else: + io_binding.bind_output("output") + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + return torch.tensor(ort_output), present_k, present_v + + +def gqa_past_func( + q, + k, + v, + config: GQAConfig, + new_k, + new_v, + cos, + sin, + seqlens_k, + position_ids, + attention_bias, + head_sink, + ep, + device, + share_buffer=True, + ort_type=TensorProto.FLOAT16, + numpy_type=numpy.float16, +): + onnx_model_str = create_group_query_attention_graph_past( + config=config, + ort_type=ort_type, + share_buffer=share_buffer, + ) + + q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.q_sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.q_sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) + io_binding = ort_session.io_binding() + + # Common inputs + total_seq_len = ( + config.past_kv_sequence_length if share_buffer else config.past_kv_sequence_length + config.q_sequence_length + ) + ort_inputs = { + "query": q.detach().cpu().numpy(), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([total_seq_len], dtype=torch.int32).detach().cpu().numpy(), + } + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + + # CPU-specific inputs + if config.has_position_ids: + ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() + 
io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + if config.has_attention_bias: + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) + if config.has_head_sink: + ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() + io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + + # Binding past and present KV + if share_buffer: + past_k_ort = OrtValue.ortvalue_from_numpy(k.detach().cpu().numpy(), device, 0) + past_v_ort = OrtValue.ortvalue_from_numpy(v.detach().cpu().numpy(), device, 0) + io_binding.bind_input("past_key", device, 0, numpy_type, past_k_ort.shape(), past_k_ort.data_ptr()) + io_binding.bind_input("past_value", device, 0, numpy_type, past_v_ort.shape(), past_v_ort.data_ptr()) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("present_key", past_k_ort) + io_binding.bind_ortvalue_output("present_value", past_v_ort) + else: + ort_inputs["past_key"] = k.detach().cpu().numpy() + ort_inputs["past_value"] = v.detach().cpu().numpy() + io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) + io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) + io_binding.bind_output("output") + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + return torch.tensor(ort_output), present_k, present_v + + +# ################################################################################################# +# Reference Attention Implementation +# ################################################################################################# + + +def construct_local_mask(seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, device): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + +def smooth_softmax_ref(x, head_sink): + b, n, s, t = x.shape + if head_sink is not None: + sink = head_sink.reshape(1, n, 1, 1).expand(b, -1, s, -1) + else: + sink = torch.zeros(b, n, s, 1, dtype=x.dtype, device=x.device) + + y = torch.cat([x, sink], dim=-1) + y = torch.softmax(y, dim=-1) + return y[..., :-1] + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attention_bias=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + use_smooth_softmax=False, + head_sink=None, +): + if causal: + window_size = (window_size[0], 0) + + dtype_og = q.dtype + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + + # Repeat K/V heads for Grouped-Query Attention + if k.shape[2] != q.shape[2]: + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + if v.shape[2] != q.shape[2]: + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + + scores = torch.einsum("bthd,bshd->bhts", q, k) / 
math.sqrt(q.shape[-1]) + + if softcap > 0: + scores = (scores / softcap).tanh() * softcap + + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + + local_mask = None + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, q.device + ) + scores.masked_fill_(local_mask, float("-inf")) + + # Add custom attention bias if provided (for CPU tests) + if attention_bias is not None: + # The bias should only be applied to the relevant part of the scores matrix, + # matching the sequence length of the bias tensor. + scores[..., : attention_bias.shape[-1]] += attention_bias + + if use_smooth_softmax or (head_sink is not None): + # Note that the sink directly joins softmax. No scaling and softcap is needed! + attention = smooth_softmax_ref(scores, head_sink) + else: + attention = torch.softmax(scores, dim=-1) + + # Fill NaNs with 0 + if local_mask is not None: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + + output = torch.einsum("bhts,bshd->bthd", attention, v) + + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +# ################################################################################################# +# Parity Check (Core Test Logic) +# ################################################################################################# + + +def parity_check_gqa_prompt( + config: GQAConfig, + ep, + device, + torch_type, + numpy_type, + ort_type, + causal, + rtol, + atol, +): + # Q/K/V have normal distribution with mean = 0 and standard deviation = 0.02. + # If we use standard deviation = 1, numerical stability issues may occur. 
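A quick, self-contained check of the smooth_softmax_ref helper above (the numbers are illustrative only): with head_sink=None the extra column is a zero logit, matching ORT's smooth_softmax, and with a per-head sink the column carries that head's learned logit instead; either way the kept probabilities sum to less than 1, and the sink never goes through scaling or softcap.

```python
import torch

scores = torch.tensor([[[[1.0, 2.0, 3.0]]]])              # (batch=1, heads=1, seq_q=1, seq_k=3)
plain = torch.softmax(scores, dim=-1)                      # rows sum to 1.0
smooth = smooth_softmax_ref(scores, None)                  # implicit sink logit of 0
sinked = smooth_softmax_ref(scores, torch.tensor([0.5]))   # learned per-head sink logit
print(plain.sum(-1).item(), smooth.sum(-1).item(), sinked.sum(-1).item())  # 1.0, ~0.968, ~0.948
```

The parity checks below compare exactly this reference against the ORT kernels, starting from small-magnitude inputs: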
+ std = 0.02 + + # --- Test Data Generation --- + q = ( + torch.randn( + config.batch_size, + config.q_sequence_length, + config.num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + + # k and v are the cache buffers, created in BNSH format + k = ( + torch.randn( + config.batch_size, + config.kv_num_heads, + config.buffer_sequence_length, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + v = torch.randn_like(k) + + new_k = ( + torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + new_v = torch.randn_like(new_k) * std + + head_sink = torch.rand(config.num_heads, dtype=torch_type, device=device) if config.has_head_sink else None + + window_size = (-1, -1) + if config.local_window_size > 0: + window_size = (config.local_window_size, 0) + elif causal: + window_size = (-1, 0) + + # --- PyTorch Reference Path --- + # Transpose BNSH cache to BSNH format for reference implementation + k_cache_ref = k.clone().transpose(1, 2) + v_cache_ref = v.clone().transpose(1, 2) + + cache_seqlens = torch.full((config.batch_size,), config.kv_sequence_length, device=device, dtype=torch.int32) + rotary_seqlens = torch.zeros(config.batch_size, device=device, dtype=torch.long) + + cos, sin, q_ro, k_ro = None, None, q, new_k + if config.rotary: + rotary_dim = math.floor(config.head_size / 16) * 16 + angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device=device) * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch_type) + sin = torch.sin(angle).to(dtype=torch_type) + q_ro = apply_rotary_embedding(q.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) + k_ro = apply_rotary_embedding(new_k.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) + + position_ids = None + attention_bias = None + if ep == "CPUExecutionProvider": + if config.has_position_ids: + position_ids = ( + torch.arange(config.q_sequence_length, device=device).unsqueeze(0).expand(config.batch_size, -1) + ) + if config.has_attention_bias: + attention_bias = torch.zeros( + config.batch_size, + 1, + config.q_sequence_length, + config.kv_sequence_length, + device=device, + dtype=torch_type, + ) + + arange = rearrange(torch.arange(config.buffer_sequence_length, device=device), "s -> 1 s") + kv_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + update_mask = arange < kv_seqlens_expanded + + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...").to(dtype=torch_type) + v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
-> (b s) ...").to(dtype=torch_type) + key_padding_mask = arange < kv_seqlens_expanded + + out_ref, _ = attention_ref( + q=q_ro, + k=k_cache_ref, + v=v_cache_ref, + query_padding_mask=None, + key_padding_mask=key_padding_mask, + attention_bias=attention_bias, + causal=True, + window_size=window_size, + softcap=config.softcap, + use_smooth_softmax=config.use_smooth_softmax, + head_sink=head_sink, + ) + out_ref_np = out_ref.detach().cpu().numpy() + + # Transpose reference cache back to BNSH for comparison + k_cache_ref_np = k_cache_ref.transpose(1, 2).detach().cpu().numpy() + v_cache_ref_np = v_cache_ref.transpose(1, 2).detach().cpu().numpy() + + # --- ONNX Runtime Path --- + q_ort, k_ort, v_ort, new_k_ort, new_v_ort = q, k, v, new_k, new_v + if config.packed: + q_ort = torch.cat([q, new_k, new_v], dim=2) + new_k_ort, new_v_ort = None, None + + # seqlens_k for GQA op is past_seq_len + seq_len - 1 + ort_seqlens = cache_seqlens - 1 + out, present_k, present_v = gqa_prompt_func( + q=q_ort, + k=k_ort, + v=v_ort, + config=config, + new_k=new_k_ort, + new_v=new_v_ort, + cos=cos, + sin=sin, + seqlens_k=ort_seqlens, + position_ids=position_ids, + attention_bias=attention_bias, + head_sink=head_sink, + ep=ep, + device=device, + share_buffer=True, + ort_type=ort_type, + numpy_type=numpy_type, + ) + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) + out_np = out.detach().cpu().numpy() + + # --- Comparison --- + numpy.testing.assert_allclose(present_k, k_cache_ref_np, rtol=rtol, atol=atol) + numpy.testing.assert_allclose(present_v, v_cache_ref_np, rtol=rtol, atol=atol) + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) + + +def parity_check_gqa_past( + config: GQAConfig, + ep, + device, + torch_type, + numpy_type, + ort_type, + causal, + rtol, + atol, +): + # --- Test Data Generation --- + q = torch.randn( + config.batch_size, + config.q_sequence_length, + config.num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + # k and v are the cache buffers, created in BNSH format + k = torch.randn( + config.batch_size, + config.kv_num_heads, + config.buffer_sequence_length, + config.head_size, + device=device, + dtype=torch_type, + ) + v = torch.randn_like(k) + new_k = torch.randn( + config.batch_size, + config.q_sequence_length, + config.kv_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + new_v = torch.randn_like(new_k) + + head_sink = torch.rand(config.num_heads, dtype=torch_type, device=device) if config.has_head_sink else None + window_size = (-1, -1) + if config.local_window_size > 0: + window_size = (config.local_window_size, 0) + elif causal: + window_size = (-1, 0) + + # --- PyTorch Reference Path --- + # Transpose BNSH cache to BSNH format for reference implementation + k_cache_ref = k.clone().transpose(1, 2) + v_cache_ref = v.clone().transpose(1, 2) + + cache_seqlens = torch.randint( + 0, + config.past_kv_sequence_length - config.q_sequence_length + 1, + (config.batch_size,), + device=device, + dtype=torch.long, + ) + + cos, sin, q_ro, k_ro = None, None, q, new_k + if config.rotary: + rotary_dim = math.floor(config.head_size / 16) * 16 + angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device=device) * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch_type) + sin = torch.sin(angle).to(dtype=torch_type) + q_ro = apply_rotary_embedding(q.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device) + k_ro = apply_rotary_embedding(new_k.clone(), cos, sin, 
cache_seqlens, config.rotary_interleaved, device) + + position_ids = None + attention_bias = None + total_seq_len = config.past_kv_sequence_length + if ep == "CPUExecutionProvider": + if config.has_position_ids: + position_ids = (cache_seqlens.unsqueeze(1) + torch.arange(config.q_sequence_length, device=device)).long() + if config.has_attention_bias: + attention_bias = torch.zeros( + config.batch_size, 1, config.q_sequence_length, total_seq_len, device=device, dtype=torch_type + ) + for b in range(config.batch_size): + end_pos = cache_seqlens[b] + config.q_sequence_length + attention_bias[b, :, :, end_pos:] = float("-inf") + + arange = rearrange(torch.arange(config.buffer_sequence_length, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.q_sequence_length + ) + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...").to(dtype=torch_type) + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...").to(dtype=torch_type) + key_padding_mask = arange < cache_seqlens_expanded + config.q_sequence_length + + out_ref, _ = attention_ref( + q=q_ro, + k=k_cache_ref, + v=v_cache_ref, + query_padding_mask=None, + key_padding_mask=key_padding_mask, + attention_bias=attention_bias, + causal=True, + window_size=window_size, + softcap=config.softcap, + use_smooth_softmax=config.use_smooth_softmax, + head_sink=head_sink, + ) + out_ref_np = out_ref.detach().cpu().numpy() + + # Transpose reference cache back to BNSH for comparison + k_cache_ref_np = k_cache_ref.transpose(1, 2).detach().cpu().numpy() + v_cache_ref_np = v_cache_ref.transpose(1, 2).detach().cpu().numpy() + + # --- ONNX Runtime Path --- + q_ort, k_ort, v_ort, new_k_ort, new_v_ort = q, k, v, new_k, new_v + if config.packed: + q_ort = torch.cat([q, new_k, new_v], dim=2) + new_k_ort, new_v_ort = None, None + + ort_seqlens = cache_seqlens + config.q_sequence_length - 1 + out, present_k, present_v = gqa_past_func( + q=q_ort, + k=k_ort, + v=v_ort, + config=config, + new_k=new_k_ort, + new_v=new_v_ort, + cos=cos, + sin=sin, + seqlens_k=ort_seqlens.int(), + position_ids=position_ids, + attention_bias=attention_bias, + head_sink=head_sink, + ep=ep, + device=device, + share_buffer=True, + ort_type=ort_type, + numpy_type=numpy_type, + ) + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) + out_np = out.detach().cpu().numpy() + + numpy.testing.assert_allclose(present_k, k_cache_ref_np, rtol=rtol, atol=atol) + numpy.testing.assert_allclose(present_v, v_cache_ref_np, rtol=rtol, atol=atol) + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) + + +# ################################################################################################# +# Test Case Generators +# ################################################################################################# + + +def get_cuda_rotary_options(): + return [(False, False)] if pipeline_mode else [(True, False), (True, True), (False, False)] + + +def get_cpu_rotary_options(): + return [(False, False), (True, False), (True, True)] + + +def get_softmax_options(allow_head_sink: bool = True): + head_sink_option = (False, True) if allow_head_sink else (False, False) + return [(False, False), head_sink_option] if pipeline_mode else [(False, False), (False, True), (True, False)] + + +def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True): + batches = [3] if 
pipeline_mode else [1, 3, 5] + seqs = [(35, 35)] if pipeline_mode else [(35, 35), (127, 127), (240, 240), (2000, 2000)] + num_h = [(6, 3)] if pipeline_mode else [(6, 3), (9, 9), (32, 8)] + h_sizes = [32] if pipeline_mode else [32, 64, 128, 256] + smooth_softmax__head_sink = get_softmax_options(allow_head_sink) + + for b in batches: + for sq, skv in seqs: + for n, n2 in num_h: + for h in h_sizes: + for lws in [-1, random.randint(1, skv)]: + for rotary, rotary_interleaved in get_cuda_rotary_options(): + for packed in [False, True]: + for softcap in [0.0, 50.0]: + if rotary and h % 16 > 0: + continue + for use_smooth_softmax, has_head_sink in smooth_softmax__head_sink: + if softcap > 0 and (use_smooth_softmax or has_head_sink): + continue + config = GQAConfig( + batch_size=b, + q_sequence_length=sq, + kv_sequence_length=skv, + past_kv_sequence_length=0, + buffer_sequence_length=sq + skv + 8, + num_heads=n, + kv_num_heads=n2, + head_size=h, + local_window_size=lws, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + has_head_sink=has_head_sink, + ) + name = f"b{b}_sq{sq}_skv{skv}_nh{n}_{n2}_h{h}_w{lws}_rot{rotary}{rotary_interleaved}_pkd{packed}_sc{softcap}_sm{use_smooth_softmax}_{has_head_sink}" + yield name, config + + +def gqa_cuda_past_test_cases(allow_head_sink: bool = True): + batches = [5] if pipeline_mode else [1, 3, 5] + # s: new sequence length, s2: past sequence length + seqs = [(1, 1024)] if pipeline_mode else [(1, 128), (1, 1024), (1, 2048), (1, 5000)] + num_h = [(32, 8)] if pipeline_mode else [(6, 3), (9, 9), (32, 8)] + h_sizes = [256] if pipeline_mode else [64, 128, 256] + smooth_softmax__head_sink = get_softmax_options(allow_head_sink) + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for lws in [-1, random.randint(1, s2)]: + for rotary, rotary_interleaved in get_cuda_rotary_options(): + for packed in [False, True]: + for softcap in [0.0, 50.0]: + if rotary and h % 16 > 0: + continue + for use_smooth_softmax, has_head_sink in smooth_softmax__head_sink: + config = GQAConfig( + batch_size=b, + q_sequence_length=s, + kv_sequence_length=s, + past_kv_sequence_length=s2, + buffer_sequence_length=s + s2 + 8, + num_heads=n, + kv_num_heads=n2, + head_size=h, + local_window_size=lws, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + has_head_sink=has_head_sink, + ) + name = f"b{b}_s{s}_{s2}_nh{n}_{n2}_h{h}_w{lws}_rot{rotary}{rotary_interleaved}_pkd{packed}_sc{softcap}_sm{use_smooth_softmax}_{has_head_sink}" + yield name, config + + +# ################################################################################################# +# Unit Test Classes +# ################################################################################################# + + +def has_cuda_provider(): + return "CUDAExecutionProvider" in get_available_providers() + + +def has_flash_attention(): + if not has_cuda_provider() or not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 8 + + +def has_memory_efficient(): + if not has_cuda_provider() or not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 5 + + +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +class TestFlashGQA(unittest.TestCase): + @parameterized.expand(gqa_cuda_prompt_test_cases())
+ def test_gqa_prompt_flash_attention(self, name, config): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + numpy_type=numpy.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=5e-3, + atol=5e-3, + ) + + @parameterized.expand(gqa_cuda_past_test_cases()) + def test_gqa_past_flash_attention(self, name, config): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + numpy_type=numpy.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=5e-3, + atol=5e-3, + ) + + +@unittest.skipIf(not has_memory_efficient(), "Memory Efficient Attention is not available, skipping tests.") +class TestMemoryEfficientGQA(unittest.TestCase): + @parameterized.expand(gqa_cuda_prompt_test_cases(allow_head_sink=False)) + def test_gqa_prompt_memory_efficient(self, name, config): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + numpy_type=numpy.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=5e-3, + atol=5e-3, + ) + + @parameterized.expand(gqa_cuda_past_test_cases(allow_head_sink=False)) + def test_gqa_past_memory_efficient(self, name, config): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + numpy_type=numpy.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=5e-3, + atol=5e-3, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gqa_cuda.py b/onnxruntime/test/python/transformers/test_gqa_cuda.py deleted file mode 100644 index 79976a92e54bf..0000000000000 --- a/onnxruntime/test/python/transformers/test_gqa_cuda.py +++ /dev/null @@ -1,2046 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright 2020 The HuggingFace Inc. team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# ------------------------------------------------------------------------- -import math -import os -import platform -import random -import unittest - -import numpy -import torch -from einops import rearrange, repeat -from onnx import TensorProto, helper -from packaging import version -from parameterized import parameterized -from test_gqa_cpu import smooth_softmax_ref - -from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers - -torch.manual_seed(0) - -pipeline_mode = True # Reduces number of tests so pipeline doesn't time out - - -class Formats: - BSNH = 0 - BNSH = 1 - - -class Config: - batch_size = 0 - sequence_length = 0 - kv_sequence_length = 0 # this is past sequence length when there is past state. 
- num_heads = 0 - kv_num_heads = 0 - head_size = 0 - ep = "CUDAExecutionProvider" - - def __init__(self, batch_size, sequence_length, kv_sequence_length, num_heads, kv_num_heads, head_size): - self.batch_size = batch_size - self.sequence_length = sequence_length - self.kv_sequence_length = kv_sequence_length - self.num_heads = num_heads - self.kv_num_heads = kv_num_heads - self.head_size = head_size - - def __repr__(self): - short_ep = self.ep[: -len("ExecutionProvider")].lower() - return ( - f"Config(batch_size={self.batch_size}, sequence_length={self.sequence_length}, " - f"kv_sequence_length={self.kv_sequence_length}, " - f"num_heads={self.num_heads}, kv_num_heads={self.kv_num_heads}, head_size={self.head_size}, ep={short_ep})" - ) - - -class PromptConfig: - batch_size = 0 - q_sequence_length = 0 - kv_sequence_length = 0 - buffer_sequence_length = 0 - num_heads = 0 - kv_num_heads = 0 - head_size = 0 - ep = "CUDAExecutionProvider" - - def __init__( - self, - batch_size, - q_sequence_length, - kv_sequence_length, - buffer_sequence_length, - num_heads, - kv_num_heads, - head_size, - ): - self.batch_size = batch_size - self.q_sequence_length = q_sequence_length - self.kv_sequence_length = kv_sequence_length - self.buffer_sequence_length = buffer_sequence_length - self.num_heads = num_heads - self.kv_num_heads = kv_num_heads - self.head_size = head_size - - def __repr__(self): - short_ep = self.ep[: -len("ExecutionProvider")].lower() - return ( - f"PromptConfig(batch_size={self.batch_size}, q_sequence_length={self.q_sequence_length}, " - f"kv_sequence_length={self.kv_sequence_length}, buffer_sequence_length={self.buffer_sequence_length}, " - f"num_heads={self.num_heads}, kv_num_heads={self.kv_num_heads}, head_size={self.head_size}, ep={short_ep})" - ) - - -def create_group_query_attention_graph_prompt( - config, - past_kv_format=Formats.BSNH, - share_buffer=True, - local_window_size=-1, - rotary=False, - rotary_interleaved=False, - packed=False, - interactive=False, - softcap=0.0, - use_smooth_softmax=False, -): - past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 - present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length - nodes = [ - helper.make_node( - "GroupQueryAttention", - [ - "query", - "key" if not packed else "", - "value" if not packed else "", - "past_key" if share_buffer else "", - "past_value" if share_buffer else "", - "seqlens_k", - "total_sequence_length", - "cos_cache" if rotary else "", - "sin_cache" if rotary else "", - ], - ["output", "present_key", "present_value"], - "GroupQueryAttention_0", - num_heads=config.num_heads, - kv_num_heads=config.kv_num_heads, - local_window_size=local_window_size, - do_rotary=rotary, - rotary_interleaved=rotary_interleaved, - softcap=softcap, - smooth_softmax=1 if use_smooth_softmax else 0, - # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, - # kv_share_buffer=1 if share_buffer else 0, - domain="com.microsoft", - ), - ] - - graph_input = [ - helper.make_tensor_value_info( - "query", - TensorProto.FLOAT16, - [ - config.batch_size, - config.q_sequence_length, - ( - (config.num_heads * config.head_size) - if not packed - else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size) - ), - ], - ), - helper.make_tensor_value_info( - "seqlens_k", - TensorProto.INT32, - [config.batch_size], - ), - helper.make_tensor_value_info( - "total_sequence_length", - TensorProto.INT32, - [1], - ), - ] - if not packed: - graph_input += [ - 
helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - ] - if share_buffer: - graph_input += [ - helper.make_tensor_value_info( - "past_key", - TensorProto.FLOAT16, - [ - config.batch_size, - past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "past_value", - TensorProto.FLOAT16, - [ - config.batch_size, - past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, - config.head_size, - ], - ), - ] - if rotary: - graph_input += [ - helper.make_tensor_value_info( - "cos_cache", - TensorProto.FLOAT16, - [ - config.buffer_sequence_length if share_buffer else config.kv_sequence_length, - (math.floor(config.head_size / 16) * 16) // 2, - ], - ), - helper.make_tensor_value_info( - "sin_cache", - TensorProto.FLOAT16, - [ - config.buffer_sequence_length if share_buffer else config.kv_sequence_length, - (math.floor(config.head_size / 16) * 16) // 2, - ], - ), - ] - - graph_output = [ - helper.make_tensor_value_info( - "output", - TensorProto.FLOAT16, - [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size], - ), - helper.make_tensor_value_info( - "present_key", - TensorProto.FLOAT16, - [ - config.batch_size, - present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "present_value", - TensorProto.FLOAT16, - [ - config.batch_size, - present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "present_key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "present_value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, - config.head_size, - ], - ), - ] - - graph = helper.make_graph( - nodes, - "GroupQueryAttention_Graph", - graph_input, - graph_output, - ) - - model = helper.make_model(graph) - return model.SerializeToString() - - -def create_group_query_attention_graph_past( - config, - past_kv_format=Formats.BSNH, - share_buffer=True, - local_window_size=-1, - rotary=False, - rotary_interleaved=False, - packed=False, - softcap=0.0, - use_smooth_softmax=False, -): - past_kv_seqlen = config.kv_sequence_length - present_kv_seqlen = ( - config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length - ) - nodes = [ - helper.make_node( - "GroupQueryAttention", - [ - "query", - "key" if not packed else "", - "value" if not packed else "", - "past_key", - "past_value", - 
"seqlens_k", - "total_sequence_length", - "cos_cache" if rotary else "", - "sin_cache" if rotary else "", - ], - ["output", "present_key", "present_value"], - "GroupQueryAttention_0", - num_heads=config.num_heads, - kv_num_heads=config.kv_num_heads, - local_window_size=local_window_size, - do_rotary=rotary, - rotary_interleaved=rotary_interleaved, - softcap=softcap, - smooth_softmax=1 if use_smooth_softmax else 0, - # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, - # kv_share_buffer=1 if share_buffer else 0, - domain="com.microsoft", - ), - ] - - graph_input = [ - helper.make_tensor_value_info( - "query", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - ( - (config.num_heads * config.head_size) - if not packed - else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size) - ), - ], - ), - helper.make_tensor_value_info( - "past_key", - TensorProto.FLOAT16, - [ - config.batch_size, - past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "past_value", - TensorProto.FLOAT16, - [ - config.batch_size, - past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "seqlens_k", - TensorProto.INT32, - [config.batch_size], - ), - helper.make_tensor_value_info( - "total_sequence_length", - TensorProto.INT32, - [1], - ), - ] - if not packed: - graph_input += [ - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - ] - if rotary: - graph_input += [ - helper.make_tensor_value_info( - "cos_cache", - TensorProto.FLOAT16, - [ - config.kv_sequence_length + (0 if share_buffer else config.sequence_length), - (math.floor(config.head_size / 16) * 16) // 2, - ], - ), - helper.make_tensor_value_info( - "sin_cache", - TensorProto.FLOAT16, - [ - config.kv_sequence_length + (0 if share_buffer else config.sequence_length), - (math.floor(config.head_size / 16) * 16) // 2, - ], - ), - ] - - graph_output = [ - helper.make_tensor_value_info( - "output", - TensorProto.FLOAT16, - [config.batch_size, config.sequence_length, config.num_heads * config.head_size], - ), - helper.make_tensor_value_info( - "present_key", - TensorProto.FLOAT16, - [ - config.batch_size, - present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "present_value", - TensorProto.FLOAT16, - [ - config.batch_size, - present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, - config.head_size, - ], - ), - ] - - graph = helper.make_graph( - nodes, - "GroupQueryAttention_Graph", - graph_input, - graph_output, - ) - - model = helper.make_model(graph) - return model.SerializeToString() - - -def rotary_options_for_current_os(): - # Reference implementation of rotary uses triton, which is not available in Windows. 
- # So we only test rotary in Linux right now. - return [(False, False)] if platform.system() != "Linux" else [(True, False), (True, True), (False, False)] - - -def gqa_prompt_func( - q, - k, - v, - config, - new_k, - new_v, - cos=None, - sin=None, - seqlens_k=None, - window_size=-1, - past_kv_format=Formats.BSNH, - share_buffer=True, - rotary_interleaved=False, - softcap=0.0, - use_smooth_softmax=False, -): - onnx_model_str = create_group_query_attention_graph_prompt( - config, - past_kv_format, - share_buffer, - local_window_size=window_size, - rotary=cos is not None, - rotary_interleaved=rotary_interleaved, - packed=new_k is None, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) - past_k = k.clone() if share_buffer else None - past_v = v.clone() if share_buffer else None - if new_k is not None: - new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) - if share_buffer: - ort_inputs = { - "query": q.detach().cpu().numpy(), - "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), - "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), - "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), - "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), - } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) - io_binding = ort_session.io_binding() - if new_k is not None: - ort_inputs["key"] = new_k.detach().cpu().numpy() - ort_inputs["value"] = new_v.detach().cpu().numpy() - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: - ort_inputs["cos_cache"] = cos.detach().cpu().numpy() - ort_inputs["sin_cache"] = sin.detach().cpu().numpy() - io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) - io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_input( - "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() - ) - io_binding.bind_input( - "past_value", - "cuda", - 0, - numpy.float16, - ort_inputs["past_value"].shape(), - ort_inputs["past_value"].data_ptr(), - ) - io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) - io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) - io_binding.bind_output("output") - io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) - io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v - else: - ort_inputs = { - "query": q.detach().cpu().numpy(), - "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), - "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), - } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) - io_binding = ort_session.io_binding() - if new_k is not None: - ort_inputs["key"] = new_k.detach().cpu().numpy() 
- ort_inputs["value"] = new_v.detach().cpu().numpy() - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: - ort_inputs["cos_cache"] = cos.detach().cpu().numpy() - ort_inputs["sin_cache"] = sin.detach().cpu().numpy() - io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) - io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) - io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) - io_binding.bind_output("output") - io_binding.bind_output("present_key") - io_binding.bind_output("present_value") - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v - - -def gqa_past_func( - q, - k, - v, - config, - new_k, - new_v, - cos=None, - sin=None, - seqlens_k=None, - past_kv_format=Formats.BSNH, - share_buffer=True, - window_size=-1, - rotary_interleaved=False, - softcap=0.0, - use_smooth_softmax=False, -): - onnx_model_str = create_group_query_attention_graph_past( - config, - past_kv_format, - share_buffer, - local_window_size=window_size, - rotary=cos is not None, - rotary_interleaved=rotary_interleaved, - packed=new_k is None, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) - past_k = k.clone() - past_v = v.clone() - if new_k is not None: - new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) - if share_buffer: - ort_inputs = { - "query": q.detach().cpu().numpy(), - "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), - "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), - "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), - "total_sequence_length": torch.tensor([config.kv_sequence_length], dtype=torch.int32) - .detach() - .cpu() - .numpy(), - } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) - io_binding = ort_session.io_binding() - if new_k is not None: - ort_inputs["key"] = new_k.detach().cpu().numpy() - ort_inputs["value"] = new_v.detach().cpu().numpy() - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: - ort_inputs["cos_cache"] = cos.detach().cpu().numpy() - ort_inputs["sin_cache"] = sin.detach().cpu().numpy() - io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) - io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_input( - "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() - ) - io_binding.bind_input( - "past_value", - "cuda", - 0, - numpy.float16, - ort_inputs["past_value"].shape(), - ort_inputs["past_value"].data_ptr(), - ) - io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) - io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) - io_binding.bind_output("output") - io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) - 
io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v - else: - ort_inputs = { - "query": q.detach().cpu().numpy(), - "past_key": past_k.detach().cpu().numpy(), - "past_value": past_v.detach().cpu().numpy(), - "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), - "total_sequence_length": torch.tensor( - [config.kv_sequence_length + config.sequence_length], dtype=torch.int32 - ) - .detach() - .cpu() - .numpy(), - } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) - io_binding = ort_session.io_binding() - if new_k is not None: - ort_inputs["key"] = new_k.detach().cpu().numpy() - ort_inputs["value"] = new_v.detach().cpu().numpy() - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: - ort_inputs["cos_cache"] = cos.detach().cpu().numpy() - ort_inputs["sin_cache"] = sin.detach().cpu().numpy() - io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) - io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) - io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) - io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) - io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) - io_binding.bind_output("output") - io_binding.bind_output("present_key") - io_binding.bind_output("present_value") - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v - - -def construct_local_mask( - seqlen_q, - seqlen_k, - window_size=(-1, -1), # -1 means infinite window size - query_padding_mask=None, - key_padding_mask=None, - device=None, -): - row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") - col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) - sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") - sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") - if window_size[0] < 0: - return col_idx > row_idx + sk - sq + window_size[1] - else: - sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk - return torch.logical_or( - col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), - col_idx < row_idx + sk - sq - window_size[0], - ) - - -def attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - dropout_p=0.0, - dropout_mask=None, - causal=False, - window_size=(-1, -1), # -1 means infinite window size - softcap=0.0, - upcast=True, - reorder_ops=False, - use_smooth_softmax=False, -): - """ - Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads_k, head_dim) - v: (batch_size, seqlen_k, nheads_k, head_dim) - query_padding_mask: (batch_size, seqlen_q) - key_padding_mask: (batch_size, seqlen_k) - dropout_p: float - dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) - causal: whether to apply causal masking - 
window_size: (int, int), left and right window size - upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast - output back to fp16/bf16. - reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) - without changing the math. This is to estimate the numerical error from operation - reordering. - Output: - output: (batch_size, seqlen_q, nheads, head_dim) - attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout - """ - if causal: - window_size = (window_size[0], 0) - dtype_og = q.dtype - if upcast: - q, k, v = q.float(), k.float(), v.float() - seqlen_q, seqlen_k = q.shape[1], k.shape[1] - k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) - v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) - d = q.shape[-1] - if not reorder_ops: - scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) - else: - scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) - if softcap > 0: - scores = scores / softcap - scores = scores.tanh() - scores = scores * softcap - if key_padding_mask is not None: - scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - if window_size[0] >= 0 or window_size[1] >= 0: - local_mask = construct_local_mask( - seqlen_q, - seqlen_k, - window_size, - query_padding_mask, - key_padding_mask, - q.device, - ) - scores.masked_fill_(local_mask, float("-inf")) - - if use_smooth_softmax: - head_sink = None - attention = smooth_softmax_ref(scores, head_sink) - else: - attention = torch.softmax(scores, dim=-1) - - # Some rows might be completely masked out so we fill them with zero instead of NaN - if window_size[0] >= 0 or window_size[1] >= 0: - attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) - # We want to mask here so that the attention matrix doesn't have any NaNs - # Otherwise we'll get NaN in dV - if query_padding_mask is not None: - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) - dropout_scaling = 1.0 / (1 - dropout_p) - if dropout_mask is not None: - attention_drop = attention.masked_fill(~dropout_mask, 0.0) - else: - attention_drop = attention - output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) - if query_padding_mask is not None: - output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) - - -def rotary_embedding(*args, **kwargs): - # Use local import since triton is not available in Windows. 
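attention_ref above hands the smooth-softmax branch to smooth_softmax_ref (imported from test_gqa_cpu) with head_sink=None. For orientation, a standalone sketch of the usual formulation, where smooth softmax adds an implicit zero logit to every row and a head sink replaces that zero with a per-head value; this is an assumption for illustration, and the reference in test_gqa_cpu.py remains authoritative:

import torch


def smooth_softmax_sketch(scores, head_sink=None):
    # scores: (batch, num_heads, seqlen_q, seqlen_k); head_sink: (num_heads,) or None
    b, h, sq, _ = scores.shape
    if head_sink is None:
        sink = scores.new_zeros(h)           # plain smooth softmax: sink logit of 0 per head
    else:
        sink = head_sink.to(scores.dtype)    # learned per-head sink logit
    sink = sink.view(1, h, 1, 1).expand(b, h, sq, 1)
    # Softmax over [sink, scores] along the key axis, then drop the sink column,
    # so attention rows are allowed to sum to less than one.
    probs = torch.softmax(torch.cat([sink, scores], dim=-1), dim=-1)
    return probs[..., 1:]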
- from rotary_flash import apply_rotary_emb # noqa: PLC0415 - - return apply_rotary_emb(*args, **kwargs) - - -def parity_check_gqa_prompt( - config: PromptConfig, - causal=True, - local=False, - past_format=Formats.BNSH, - rotary=False, - rotary_interleaved=False, - packed=False, - softcap=0.0, - use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, -): - q = torch.randn( - config.batch_size, - config.q_sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - k = torch.randn( - config.batch_size, - config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - v = torch.randn( - config.batch_size, - config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_k = torch.randn( - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_v = torch.randn( - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - - window_size = (-1, -1) - left_window_size = -1 - if local: - left_window_size = random.randint(0, config.kv_sequence_length) - window_size = (left_window_size, 0) - elif causal: - left_window_size = -1 - window_size = (-1, 0) - - # Pytorch to compare - k_cache_ref = k.clone() - v_cache_ref = v.clone() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - cache_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) - # cache_seqlens = torch.randint( - # 0, - # config.kv_sequence_length, - # (config.batch_size,), - # dtype=torch.int32, - # device="cuda", - # ) - # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length - rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) - - if rotary: - rotary_fraction = 1.0 - rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 - angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float16) - sin = torch.sin(angle).to(dtype=torch.float16) - - if causal or local: - q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) - else: - q_ro = rearrange( - rotary_embedding( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=rotary_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=config.q_sequence_length, - ) - # q_ro = q - k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) - else: - cos, sin = None, None - q_ro, k_ro = q, new_k - - rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") - arange = rearrange(torch.arange(config.buffer_sequence_length, device="cuda"), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - kv_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) - 
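rotary_embedding above defers to the triton-based apply_rotary_emb, which is why rotary cases only run on Linux. As a plain-torch illustration of the two layouts the tests toggle with rotary_interleaved, here is a sketch that ignores the seqlen_offsets handling and simply rotates from position zero; it is not the kernel the tests execute:

import torch


def apply_rotary_sketch(x, cos, sin, interleaved=False):
    # x: (batch, seqlen, num_heads, head_dim); cos, sin: (seqlen, rotary_dim // 2)
    rotary_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    cos = cos[: x.shape[1]].unsqueeze(1)  # broadcast over the heads dimension
    sin = sin[: x.shape[1]].unsqueeze(1)
    if interleaved:
        x1, x2 = x_rot[..., 0::2], x_rot[..., 1::2]    # adjacent (even, odd) pairs rotate together
    else:
        half = rotary_dim // 2
        x1, x2 = x_rot[..., :half], x_rot[..., half:]  # first half pairs with second half
    o1 = x1 * cos - x2 * sin
    o2 = x1 * sin + x2 * cos
    if interleaved:
        rotated = torch.stack((o1, o2), dim=-1).flatten(-2)  # re-interleave even/odd lanes
    else:
        rotated = torch.cat((o1, o2), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)


q = torch.randn(2, 16, 4, 64)
cos, sin = torch.rand(16, 32), torch.rand(16, 32)
q_rot = apply_rotary_sketch(q, cos, sin, interleaved=True)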
kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") - update_mask = arange < kv_seqlens_expanded - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") - v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - key_padding_mask = arange < cache_seqlens_expanded - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - None, - key_padding_mask, - 0.0, - None, - causal=True, - window_size=window_size, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out_ref = out_ref.detach().cpu().numpy() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - - # Flash function - if packed: - packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_prompt_func( - packed_qkv, - k, - v, - config, - None, - None, - cos, - sin, - cache_seqlens, - left_window_size, - past_format, - True, - rotary_interleaved, - softcap, - use_smooth_softmax=use_smooth_softmax, - ) - else: - out, present_k, present_v = gqa_prompt_func( - q, - k, - v, - config, - new_k, - new_v, - cos, - sin, - cache_seqlens, - left_window_size, - past_format, - True, - rotary_interleaved, - softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - - err_msg = ( - f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" - ) - # Make sure past-present buffer updating correctly - numpy.testing.assert_allclose( - present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg - ) - numpy.testing.assert_allclose( - present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg - ) - - numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) - - -def parity_check_gqa_prompt_no_buff( - config: PromptConfig, - causal=True, - local=False, - past_format=Formats.BNSH, - rotary=False, - rotary_interleaved=False, - packed=False, - softcap=0.0, - use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, -): - q = torch.randn( - config.batch_size, - config.q_sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_k = torch.randn( - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_v = torch.randn( - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - - window_size = (-1, -1) - left_window_size = -1 - if local: - left_window_size = random.randint(0, config.kv_sequence_length) - window_size = (left_window_size, 0) - elif causal: - left_window_size = -1 - window_size = (-1, 0) - - # Pytorch to compare - k_cache_ref = new_k.clone() - v_cache_ref = new_v.clone() - # if past_format == Formats.BNSH: - # k_cache_ref = k_cache_ref.transpose(1, 2) - # v_cache_ref = v_cache_ref.transpose(1, 2) - cache_seqlens = 
torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) - # cache_seqlens = torch.randint( - # 0, - # config.kv_sequence_length, - # (config.batch_size,), - # dtype=torch.int32, - # device="cuda", - # ) - # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length - rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) - - if rotary: - rotary_fraction = 1.0 - rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 - angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float16) - sin = torch.sin(angle).to(dtype=torch.float16) - - if causal or local: - q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) - else: - q_ro = rearrange( - rotary_embedding( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=rotary_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=config.q_sequence_length, - ) - # q_ro = q - k_ro = rotary_embedding(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) - else: - cos, sin = None, None - q_ro, k_ro = q, k_cache_ref - k_cache_ref = k_ro - - brange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - new_mask = brange < cache_seqlens_expanded - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - None, - new_mask, - 0.0, - None, - causal=True, - window_size=window_size, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out_ref = out_ref.detach().cpu().numpy() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - - # Flash function - if packed: - packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_prompt_func( - packed_qkv, - None, - None, - config, - None, - None, - cos, - sin, - cache_seqlens, - left_window_size, - past_format, - False, - rotary_interleaved, - softcap, - use_smooth_softmax=use_smooth_softmax, - ) - else: - out, present_k, present_v = gqa_prompt_func( - q, - None, - None, - config, - new_k, - new_v, - cos, - sin, - cache_seqlens, - left_window_size, - past_format, - False, - rotary_interleaved, - softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - - err_msg = ( - f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}, use_smooth_softmax={use_smooth_softmax}" - ) - # Make sure past-present buffer updating correctly - numpy.testing.assert_allclose( - present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg - ) - numpy.testing.assert_allclose( - present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg - ) - - numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) - - -def 
parity_check_gqa_past( - config: Config, - causal=True, - local=False, - past_format=Formats.BNSH, - rotary=False, - rotary_interleaved=False, - packed=False, - softcap=0.0, - use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, -): - q = torch.randn( - config.batch_size, - config.sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - k = torch.randn( - config.batch_size, - config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - v = torch.randn( - config.batch_size, - config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_k = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_v = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - - window_size = (-1, -1) - left_window_size = -1 - if local: - left_window_size = random.randint(0, config.kv_sequence_length) - window_size = (left_window_size, 0) - elif causal: - left_window_size = -1 - window_size = (-1, 0) - - # Pytorch to compare - k_cache_ref = k.clone() - v_cache_ref = v.clone() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - cache_seqlens = torch.randint( - 0, - config.kv_sequence_length - config.sequence_length + 1, - (config.batch_size,), - dtype=torch.int32, - device="cuda", - ) - - if rotary: - rotary_fraction = 1.0 - rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 - angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float16) - sin = torch.sin(angle).to(dtype=torch.float16) - if causal or local: - q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) - else: - q_ro = rearrange( - rotary_embedding( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=cache_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=config.sequence_length, - ) - k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) - else: - cos, sin = None, None - q_ro, k_ro = q, new_k - - arange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length - ) - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") - v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
-> (b s) ...") - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - None, - key_padding_mask, - 0.0, - None, - causal=True, - window_size=window_size, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out_ref = out_ref.detach().cpu().numpy() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - - cache_seqlens += config.sequence_length - 1 - - # Flash function - if packed: - packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_past_func( - packed_qkv, - k, - v, - config, - None, - None, - cos, - sin, - cache_seqlens, - past_format, - True, - left_window_size, - rotary_interleaved, - softcap, - use_smooth_softmax=use_smooth_softmax, - ) - else: - out, present_k, present_v = gqa_past_func( - q, - k, - v, - config, - new_k, - new_v, - cos, - sin, - cache_seqlens, - past_format, - True, - left_window_size, - rotary_interleaved, - softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - - err_msg = ( - f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" - ) - # Make sure past-present buffer updating correctly - numpy.testing.assert_allclose( - present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg - ) - numpy.testing.assert_allclose( - present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg - ) - numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) - - -def parity_check_gqa_past_no_buff( - config: Config, - causal=True, - local=False, - past_format=Formats.BNSH, - rotary=False, - rotary_interleaved=False, - packed=False, - softcap=0.0, - use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, -): - torch.manual_seed(69) - q = torch.randn( - config.batch_size, - config.sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - k = torch.randn( - config.batch_size, - config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - v = torch.randn( - config.batch_size, - config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_k = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_v = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - - window_size = (-1, -1) 
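The token-generation parity checks splice the new keys and values into the reference cache with a boolean update mask over absolute positions, then limit attention to the first cache_seqlens + sequence_length keys. A tiny standalone example of that masking pattern with made-up sizes:

import torch
from einops import rearrange

batch, max_seq, new_seq, heads, dim = 2, 8, 2, 1, 4
cache = torch.zeros(batch, max_seq, heads, dim)
new_kv = torch.ones(batch, new_seq, heads, dim)
cache_seqlens = torch.tensor([3, 5])  # tokens already present in each sequence

arange = rearrange(torch.arange(max_seq), "s -> 1 s")             # (1, max_seq)
seqlens = rearrange(cache_seqlens, "b -> b 1")                    # (batch, 1)
update_mask = (seqlens <= arange) & (arange < seqlens + new_seq)  # (batch, max_seq)

# Boolean indexing visits the selected (batch, position) slots in row-major order,
# which matches the order produced by rearrange(new_kv, "b s ... -> (b s) ...").
cache[update_mask] = rearrange(new_kv, "b s ... -> (b s) ...")
key_padding_mask = arange < seqlens + new_seq  # keys each query row is allowed to attend to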
- left_window_size = -1 - if local: - left_window_size = random.randint(0, config.kv_sequence_length) - window_size = (left_window_size, 0) - elif causal: - left_window_size = -1 - window_size = (-1, 0) - - # Pytorch to compare - k_cache_ref = k.clone() - v_cache_ref = v.clone() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - k_cache_ref = torch.cat((k_cache_ref, new_k), 1) - v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - cache_seqlens = torch.randint( - 0, - config.kv_sequence_length, - (config.batch_size,), - dtype=torch.int32, - device="cuda", - ) - cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length - - if rotary: - rotary_fraction = 1.0 - rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 - angle = ( - torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi - ) - cos = torch.cos(angle).to(dtype=torch.float16) - sin = torch.sin(angle).to(dtype=torch.float16) - if causal or local: - q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) - else: - q_ro = rearrange( - rotary_embedding( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=cache_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=config.sequence_length, - ) - k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) - else: - cos, sin = None, None - q_ro, k_ro = q, new_k - - arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cuda"), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length - ) - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") - v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
-> (b s) ...") - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - None, - key_padding_mask, - 0.0, - None, - causal=True, - window_size=window_size, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out_ref = out_ref.detach().cpu().numpy() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - - cache_seqlens += config.sequence_length - 1 - - # Flash function - if packed: - packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_past_func( - packed_qkv, - k, - v, - config, - None, - None, - cos, - sin, - cache_seqlens, - past_format, - False, - window_size=left_window_size, - rotary_interleaved=rotary_interleaved, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - else: - out, present_k, present_v = gqa_past_func( - q, - k, - v, - config, - new_k, - new_v, - cos, - sin, - cache_seqlens, - past_format, - False, - window_size=left_window_size, - rotary_interleaved=rotary_interleaved, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - - err_msg = ( - f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" - ) - for b in range(config.batch_size): - numpy.testing.assert_allclose( - present_k[b, :, : (cache_seqlens + 1)[b]], - k_cache_ref[b, :, : (cache_seqlens + 1)[b]].detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - err_msg=err_msg, - ) - numpy.testing.assert_allclose( - present_v[b, :, : (cache_seqlens + 1)[b]], - v_cache_ref[b, :, : (cache_seqlens + 1)[b]].detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - err_msg=err_msg, - ) - numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) - - -def has_flash_attention(): - if not torch.cuda.is_available(): - return False - if "CUDAExecutionProvider" not in get_available_providers(): - return False - major, _ = torch.cuda.get_device_capability() - return major >= 8 and ( - platform.system() == "Linux" - or (platform.system() == "Windows" and version.parse(torch.version.cuda) >= version.parse("12.0")) - ) - - -def has_memory_efficient(): - if not torch.cuda.is_available(): - return False - if "CUDAExecutionProvider" not in get_available_providers(): - return False - major, minor = torch.cuda.get_device_capability() - if major < 5 or (major == 5 and minor < 3): - return False - return True - - -def gqa_no_past_memory_efficient_test_cases(): - batches = [3] if pipeline_mode else [1, 3, 5] - seqs = ( - [ - (2000, 2000), - ] - if pipeline_mode - else [ - (127, 127), - (35, 35), - (2000, 2000), - (200, 200), - (240, 240), - ] - ) - num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - torch.manual_seed(69) - - for b in batches: - for sq, skv in seqs: - for n, n2 in num_h: - for h in h_sizes: - for local in [False, True]: - for 
rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - for softcap in [0.0, 50.0]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - if rotary and h % 16 > 0: - continue - yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", - config, - local, - rotary, - rotary_interleaved, - packed, - softcap, - ) - - -def gqa_no_past_flash_attention_test_cases(): - batches = [3] if pipeline_mode else [1, 3, 5] - seqs = ( - [ - (240, 240), - ] - if pipeline_mode - else [ - (127, 127), - (35, 35), - (2000, 2000), - (200, 200), - (240, 240), - ] - ) - num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - torch.manual_seed(69) - - for b in batches: - for sq, skv in seqs: - for n, n2 in num_h: - for h in h_sizes: - for local in [False, True]: - for rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - for softcap in [0.0, 50.0]: - if rotary and h % 16 > 0: - continue - - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}_{softcap}", - config, - local, - rotary, - rotary_interleaved, - packed, - softcap, - ) - - -def gqa_past_memory_efficient_test_cases(): - batches = [5] if pipeline_mode else [1, 3, 5] - seqs = ( - [(1, 1024)] - if pipeline_mode - else [ - (1, 128), - (1, 339), - (1, 1024), - (1, 5000), - (1, 800), - (1, 256), - (1, 799), - (1, 2048), - # (1, 128 * 512), - # (16, 128 * 512), - # (128, 128), - ] - ) - num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - random.seed(69) - - for b in batches: - for s, s2 in seqs: - for n, n2 in num_h: - for h in h_sizes: - for local in [False, True]: - for rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - for softcap in [0.0, 50.0]: - if rotary and h % 16 > 0: - continue - config = Config(b, s, s2, n, n2, h) - yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}_{softcap}", - config, - local, - rotary, - rotary_interleaved, - packed, - softcap, - ) - - -def gqa_past_flash_attention_test_cases(): - batches = [5] if pipeline_mode else [1, 3, 5] - seqs = ( - [(1, 2048)] - if pipeline_mode - else [ - (1, 128), - (1, 339), - (1, 1024), - (1, 5000), - (1, 800), - (1, 256), - (1, 799), - (1, 2048), - # (1, 128 * 512), - # (16, 128 * 512), - # (128, 128), - ] - ) - num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - random.seed(69) - - for b in batches: - for s, s2 in seqs: - for n, n2 in num_h: - for h in h_sizes: - for local in [False, True]: - for rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - for softcap in [0.0, 50.0]: - if rotary and h % 16 > 0: - continue - - config = Config(b, s, s2, n, n2, h) - yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}_{softcap}", - config, - local, - rotary, - rotary_interleaved, - packed, - softcap, - ) - - -def gqa_interactive_one_batch_flash_attention_test_cases(): - batches = [1] - seqs = ( - [(128, 2048)] - if pipeline_mode - else [ - (1, 128), - (32, 128), - (128, 2048), - (1235, 5000), - (40, 800), - (1, 256), - (2, 799), - (41, 2048), - # (1, 128 * 512), - # (16, 128 * 512), - # 
(128, 128), - ] - ) - num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - random.seed(69) - - for b in batches: - for s, s2 in seqs: - for n, n2 in num_h: - for h in h_sizes: - for local in [False, True]: - for rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - if rotary and h % 16 > 0: - continue - - config = Config(b, s, s2, n, n2, h) - yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", - config, - local, - rotary, - rotary_interleaved, - packed, - ) - - -def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): - batches = [1] - seqs = ( - [(32, 128)] - if pipeline_mode - else [ - (1, 128), - (32, 128), - (128, 2048), - (1235, 5000), - (40, 800), - (1, 256), - (2, 799), - (41, 2048), - # (1, 128 * 512), - # (16, 128 * 512), - # (128, 128), - ] - ) - num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - random.seed(69) - - for b in batches: - for s, s2 in seqs: - for n, n2 in num_h: - for h in h_sizes: - for rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - if rotary and h % 16 > 0: - continue - - config = Config(b, s, s2, n, n2, h) - yield ( - str(config) + f"{rotary}_{rotary_interleaved}_{packed}", - config, - rotary, - rotary_interleaved, - packed, - ) - - -@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.") -class TestFlashGQA(unittest.TestCase): - @parameterized.expand(gqa_no_past_flash_attention_test_cases()) - def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - print("------- FLASH ATTENTION (PROMPT CASE) --------") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - - parity_check_gqa_prompt( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=True, - ) - parity_check_gqa_prompt_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=False, - ) - - @parameterized.expand(gqa_past_flash_attention_test_cases()) - def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - print("------- FLASH ATTENTION (TOKEN GEN) -------") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - - parity_check_gqa_past( - config, - local=local, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=False, - ) - parity_check_gqa_past_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=True, - ) - - @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases()) - def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): - print("------- FLASH ATTENTION (INTERACTIVE) -------") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - - parity_check_gqa_past( - config, - local=local, - past_format=Formats.BNSH, - rtol=5e-3, - atol=5e-3, - rotary=rotary, - 
rotary_interleaved=rotary_interleaved, - packed=packed, - ) - parity_check_gqa_past_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rtol=5e-3, - atol=5e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - - -@unittest.skipIf(not has_memory_efficient(), reason="Memory efficient FMHA is not available, skipping tests.") -class TestMemoryEfficientGQA(unittest.TestCase): - @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) - def test_gqa_no_past_memory_efficient(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") - - parity_check_gqa_prompt( - config, - local=local, - rtol=5e-3, - atol=5e-3, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=False, - ) - parity_check_gqa_prompt_no_buff( - config, - local=local, - rtol=5e-3, - atol=5e-3, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=True, - ) - - @parameterized.expand(gqa_past_memory_efficient_test_cases()) - def test_gqa_past_memory_efficient(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") - - parity_check_gqa_past( - config, - local=local, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=True, - ) - parity_check_gqa_past_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=False, - ) - - @parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases()) - def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed): - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("-------- MEMORY EFFICIENT (INTERACTIVE) --------") - - parity_check_gqa_past( - config, - past_format=Formats.BNSH, - rtol=5e-3, - atol=5e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - parity_check_gqa_past_no_buff( - config, - past_format=Formats.BNSH, - rtol=5e-3, - atol=5e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gqa_rocm.py b/onnxruntime/test/python/transformers/test_gqa_rocm.py deleted file mode 100644 index 29ae1b6e44a78..0000000000000 --- a/onnxruntime/test/python/transformers/test_gqa_rocm.py +++ /dev/null @@ -1,81 +0,0 @@ -import platform -import unittest - -import torch -from parameterized import parameterized -from test_gqa_cuda import ( - Formats, - gqa_no_past_flash_attention_test_cases, - gqa_past_flash_attention_test_cases, - parity_check_gqa_past, - parity_check_gqa_past_no_buff, - parity_check_gqa_prompt, - parity_check_gqa_prompt_no_buff, -) - -import onnxruntime - - -@unittest.skipIf( - (not torch.cuda.is_available()) - or (platform.system() != "Linux") - or ("ROCMExecutionProvider" not in onnxruntime.get_available_providers()), - reason="ROCm is not available, skipping tests.", -) -class TestRocmGQA(unittest.TestCase): - 
@parameterized.expand(gqa_no_past_flash_attention_test_cases()) - def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - config.ep = "ROCMExecutionProvider" - print("------- FLASH ATTENTION (PROMPT CASE) --------") - - parity_check_gqa_prompt( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - rtol=0.001, - atol=0.005, - ) - - parity_check_gqa_prompt_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - rtol=0.001, - atol=0.005, - ) - - @parameterized.expand(gqa_past_flash_attention_test_cases()) - def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - config.ep = "ROCMExecutionProvider" - print("------- FLASH ATTENTION (TOKEN GEN) -------") - - parity_check_gqa_past( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - rtol=0.001, - atol=0.005, - ) - parity_check_gqa_past_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - rtol=0.001, - atol=0.005, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxruntime/test/python/transformers/test_mha_flash_attn.py b/onnxruntime/test/python/transformers/test_mha_flash_attn.py index f87370e37d21a..a015ce6979f91 100644 --- a/onnxruntime/test/python/transformers/test_mha_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_mha_flash_attn.py @@ -12,7 +12,7 @@ from einops import rearrange, repeat from onnx import TensorProto, helper from parameterized import parameterized -from test_gqa_cuda import attention_ref, has_flash_attention +from test_gqa import attention_ref, has_flash_attention from onnxruntime import InferenceSession, SessionOptions @@ -303,24 +303,16 @@ def mha_func(q, k, v, config): def attention_qkvpacked_ref( qkv, key_padding_mask=None, - dropout_p=0.0, - dropout_mask=None, causal=False, - upcast=True, - reorder_ops=False, use_smooth_softmax=False, ): return attention_ref( qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], - key_padding_mask, - key_padding_mask, - dropout_p, - dropout_mask, - upcast=upcast, + query_padding_mask=key_padding_mask, + key_padding_mask=key_padding_mask, causal=causal, - reorder_ops=reorder_ops, use_smooth_softmax=use_smooth_softmax, ) @@ -344,7 +336,7 @@ def parity_check_mha( ) out = out.detach().cpu().numpy() # Pytorch to compare - out_ref, _ = attention_qkvpacked_ref(qkv, key_padding_mask, 0.0, None, causal=False) + out_ref, _ = attention_qkvpacked_ref(qkv, key_padding_mask, causal=False) out_ref = out_ref.detach().cpu().numpy() else: q = torch.randn(
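The trimmed attention_qkvpacked_ref at the end forwards qkv[:, :, 0], qkv[:, :, 1] and qkv[:, :, 2] to attention_ref, which implies a packed layout of (batch, seqlen, 3, num_heads, head_dim); the trailing (num_heads, head_dim) part is inferred from attention_ref's documented q shape rather than stated in this hunk. A small shape-only illustration of packing and unpacking that tensor:

import torch

batch, seqlen, num_heads, head_dim = 2, 16, 4, 64
q = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Packed layout assumed by attention_qkvpacked_ref: (batch, seqlen, 3, num_heads, head_dim).
qkv = torch.stack((q, k, v), dim=2)
assert qkv.shape == (batch, seqlen, 3, num_heads, head_dim)
assert torch.equal(qkv[:, :, 0], q) and torch.equal(qkv[:, :, 1], k) and torch.equal(qkv[:, :, 2], v)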