diff --git a/.markdownlint.yaml b/.markdownlint.yaml new file mode 100644 index 000000000..cd9df57cd --- /dev/null +++ b/.markdownlint.yaml @@ -0,0 +1,12 @@ +MD007: + indent: 4 +MD013: false +MD024: + siblings_only: true +MD033: false +MD045: false +MD046: false +MD051: false +MD052: false +MD053: false +MD059: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index be9c5bd18..292b93fc2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,30 +6,19 @@ default_stages: - manual # Run in CI exclude: 'vllm/third_party/.*' repos: -- repo: https://github.com/google/yapf - rev: v0.43.0 - hooks: - - id: yapf - args: [--in-place, --verbose] - # Keep the same list from yapfignore here to avoid yapf failing without any inputs - exclude: '(.buildkite|benchmarks|build|examples)/.*' - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.7 + rev: v0.14.0 hooks: - - id: ruff + - id: ruff-check args: [--output-format, github, --fix] - id: ruff-format - files: ^(.buildkite|benchmarks|examples)/.* - repo: https://github.com/crate-ci/typos - rev: v1.34.0 + rev: v1.38.1 hooks: - id: typos -- repo: https://github.com/PyCQA/isort - rev: 6.0.1 - hooks: - - id: isort + args: [--force-exclude] - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v20.1.3 + rev: v21.1.2 hooks: - id: clang-format exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' @@ -40,7 +29,7 @@ repos: hooks: - id: actionlint - repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.6.17 + rev: 0.9.1 hooks: - id: pip-compile args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cpu] @@ -48,36 +37,32 @@ repos: - repo: local hooks: - id: mypy-local - name: Run mypy for local Python installation - entry: tools/mypy.sh 0 "local" - language: python - types: [python] - additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic] + name: Run mypy locally for lowest supported Python version + entry: python tools/pre_commit/mypy.py 0 "3.10" stages: [pre-commit] # Don't run in CI + <<: &mypy_common + language: python + types_or: [python, pyi] + require_serial: true + additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.10 - entry: tools/mypy.sh 1 "3.10" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.10" + <<: *mypy_common stages: [manual] # Only run in CI - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.11 - entry: tools/mypy.sh 1 "3.11" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.11" + <<: *mypy_common stages: [manual] # Only run in CI - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.12 - entry: tools/mypy.sh 1 "3.12" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.12" + <<: *mypy_common stages: [manual] # Only run in CI - id: shellcheck name: Lint shell scripts - entry: tools/shellcheck.sh + 
entry: tools/pre_commit/shellcheck.sh language: script types: [shell] - id: png-lint @@ -116,7 +101,7 @@ repos: pass_filenames: false - id: enforce-import-regex-instead-of-re name: Enforce import regex as re - entry: python tools/enforce_regex_import.py + entry: python tools/pre_commit/enforce_regex_import.py language: python types: [python] pass_filenames: false diff --git a/.shellcheckrc b/.shellcheckrc new file mode 100644 index 000000000..f3b6eedf8 --- /dev/null +++ b/.shellcheckrc @@ -0,0 +1,9 @@ +# rules currently disabled: +# +# SC1091 (info): Not following: was not specified as input (see shellcheck -x) +# SC2004 (style): $/${} is unnecessary on arithmetic variables. +# SC2129 (style): Consider using { cmd1; cmd2; } >> file instead of individual redirects. +# SC2155 (warning): Declare and assign separately to avoid masking return values. +# SC2164 (warning): Use 'cd ... || exit' or 'cd ... || return' in case cd fails. +# +disable=SC1091,SC2004,SC2129,SC2155,SC2164 diff --git a/.yapfignore b/.yapfignore index 2d6dcf838..381582590 100644 --- a/.yapfignore +++ b/.yapfignore @@ -1 +1,2 @@ collect_env.py +vllm/model_executor/layers/fla/ops/*.py diff --git a/cmake/hipify.py b/cmake/hipify.py index a15577125..7067f6658 100755 --- a/cmake/hipify.py +++ b/cmake/hipify.py @@ -15,7 +15,7 @@ from torch.utils.hipify.hipify_python import hipify -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() # Project directory where all the source + include files live. @@ -33,15 +33,14 @@ ) # Source files to convert. - parser.add_argument("sources", - help="Source files to hipify.", - nargs="*", - default=[]) + parser.add_argument( + "sources", help="Source files to hipify.", nargs="*", default=[] + ) args = parser.parse_args() # Limit include scope to project_dir only - includes = [os.path.join(args.project_dir, '*')] + includes = [os.path.join(args.project_dir, "*")] # Get absolute path for all source files. extra_files = [os.path.abspath(s) for s in args.sources] @@ -50,25 +49,31 @@ # The directory might already exist to hold object files so we ignore that. shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True) - hipify_result = hipify(project_directory=args.project_dir, - output_directory=args.output_dir, - header_include_dirs=[], - includes=includes, - extra_files=extra_files, - show_detailed=True, - is_pytorch_extension=True, - hipify_extra_files_only=True) + hipify_result = hipify( + project_directory=args.project_dir, + output_directory=args.output_dir, + header_include_dirs=[], + includes=includes, + extra_files=extra_files, + show_detailed=True, + is_pytorch_extension=True, + hipify_extra_files_only=True, + ) hipified_sources = [] for source in args.sources: s_abs = os.path.abspath(source) - hipified_s_abs = (hipify_result[s_abs].hipified_path if - (s_abs in hipify_result - and hipify_result[s_abs].hipified_path is not None) - else s_abs) + hipified_s_abs = ( + hipify_result[s_abs].hipified_path + if ( + s_abs in hipify_result + and hipify_result[s_abs].hipified_path is not None + ) + else s_abs + ) hipified_sources.append(hipified_s_abs) - assert (len(hipified_sources) == len(args.sources)) + assert len(hipified_sources) == len(args.sources) # Print hipified source files. 
print("\n".join(hipified_sources)) diff --git a/csrc/cache.h b/csrc/cache.h index c06b64c98..1d324249f 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -70,3 +70,18 @@ void indexer_k_quant_and_cache( torch::Tensor& slot_mapping, // [num_tokens] int64_t quant_block_size, // quantization block size const std::string& scale_fmt); + +// Extract function to gather quantized K cache +void cp_gather_indexer_k_cache( + const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& dst_k, // [num_tokens, head_dim] + const torch::Tensor& block_table, // [batch_size, num_blocks] + const torch::Tensor& cu_seq_lens); // [batch_size + 1] + +// Extract function to gather quantized K cache +void cp_gather_indexer_k_quant_cache( + const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& dst_k, // [num_tokens, head_dim] + torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] + const torch::Tensor& block_table, // [batch_size, num_blocks] + const torch::Tensor& cu_seq_lens); // [batch_size + 1] \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 5cda43de9..9ed7859a2 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -938,6 +938,124 @@ __global__ void indexer_k_cache_kernel( kv_cache[dst_offset + i] = k_val_ptr[i]; } } + +template +__global__ void cp_gather_indexer_k_quant_cache_kernel( + const char* __restrict__ kv_cache, // [num_blocks, block_size, + // cache_stride] + char* __restrict__ dst_k, // [num_tokens, head_dim] + char* __restrict__ dst_scale, // [num_tokens, head_dim / quant_block_size * + // 4] + const int* __restrict__ block_table, // [batch_size, num_blocks] + const int* __restrict__ cu_seq_lens, // [batch_size + 1] + const int batch_size, // batch size + const int64_t token_stride, // stride for each token in dst_k + const int64_t head_dim, // dimension of each head + const int64_t block_stride, // stride for each block in kv_cache + const int64_t cache_token_stride, // stride for each token in kv_cache + const int64_t cache_block_size, // num_tokens for each block in kv_cache + const int num_blocks, // number of blocks + const int num_tokens, // number of tokens + const int quant_block_size // quantization block size +) { + constexpr int VEC_SIZE = sizeof(float4) / sizeof(char); + const int token_idx = blockIdx.x * blockDim.y + threadIdx.y; + const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; + // Find batch index within a block + __shared__ int batch_idx[BLOCK_Y_SIZE]; + for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x)); + iter++) { + int tid = iter * blockDim.x + threadIdx.x; + if (tid < batch_size) { + const int seq_start = cu_seq_lens[tid]; + const int seq_end = cu_seq_lens[tid + 1]; + if (token_idx >= seq_start && token_idx < seq_end) { + batch_idx[threadIdx.y] = tid; + } + } + } + +#ifndef USE_ROCM + __syncwarp(); +#endif + + if (head_idx >= head_dim || token_idx >= num_tokens) { + return; + } + const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]]; + const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks + + inbatch_seq_idx / cache_block_size]; + const int64_t src_block_offset = block_idx * block_stride; + const int64_t cache_inblock_offset = + (inbatch_seq_idx % cache_block_size) * head_dim + head_idx; + const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset; + const int64_t dst_inblock_offset = token_idx * token_stride + head_idx; + + 
reinterpret_cast(dst_k)[dst_inblock_offset / VEC_SIZE] = + reinterpret_cast(kv_cache)[src_inblock_offset / VEC_SIZE]; + ; + if (threadIdx.x == 0) { + const int64_t src_scale_offset = + src_block_offset + cache_block_size * head_dim + + cache_inblock_offset * 4 / quant_block_size; + reinterpret_cast(dst_scale)[dst_inblock_offset / quant_block_size] = + reinterpret_cast(kv_cache)[src_scale_offset / 4]; + } +} + +template +__global__ void cp_gather_indexer_k_cache_kernel( + const char* __restrict__ kv_cache, // [num_blocks, block_size, + // cache_stride] + char* __restrict__ dst_k, // [num_tokens, head_dim] + const int* __restrict__ block_table, // [batch_size, num_blocks] + const int* __restrict__ cu_seq_lens, // [batch_size + 1] + const int batch_size, // batch size + const int64_t token_stride, // stride for each token in dst_k + const int64_t head_dim, // dimension of each head + const int64_t block_stride, // stride for each block in kv_cache + const int64_t cache_token_stride, // stride for each token in kv_cache + const int64_t cache_block_size, // num_tokens for each block in kv_cache + const int num_blocks, // number of blocks + const int num_tokens // number of tokens +) { + constexpr int VEC_SIZE = sizeof(float4) / sizeof(char); + const int token_idx = blockIdx.x * blockDim.y + threadIdx.y; + const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; + // Find batch index within a block + __shared__ int batch_idx[BLOCK_Y_SIZE]; + for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x)); + iter++) { + int tid = iter * blockDim.x + threadIdx.x; + if (tid < batch_size) { + const int seq_start = cu_seq_lens[tid]; + const int seq_end = cu_seq_lens[tid + 1]; + if (token_idx >= seq_start && token_idx < seq_end) { + batch_idx[threadIdx.y] = tid; + } + } + } + +#ifndef USE_ROCM + __syncwarp(); +#endif + + if (head_idx >= head_dim || token_idx >= num_tokens) { + return; + } + const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]]; + const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks + + inbatch_seq_idx / cache_block_size]; + const int64_t src_block_offset = block_idx * block_stride; + const int64_t cache_inblock_offset = + (inbatch_seq_idx % cache_block_size) * head_dim + head_idx; + const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset; + const int64_t dst_inblock_offset = token_idx * token_stride + head_idx; + + reinterpret_cast(dst_k)[dst_inblock_offset / VEC_SIZE] = + reinterpret_cast(kv_cache)[src_inblock_offset / VEC_SIZE]; +} + } // namespace vllm // Macro to dispatch the kernel based on the data type. @@ -1083,4 +1201,114 @@ void indexer_k_cache( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "auto", CALL_INDEXER_K_CACHE); +} + +// Macro to dispatch the kernel based on the data amount. 
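To make the addressing in the new gather kernels above easier to follow, here is a hedged, non-vectorized PyTorch sketch of what cp_gather_indexer_k_cache computes. The function name mirrors the C++ entry point, the byte-tensor shapes come from the comments in the declarations, and the dense per-block K packing follows the kernel's offset math; this is a reference sketch, not the shipped implementation.

```python
import torch

def gather_indexer_k_cache_reference(
    kv_cache: torch.Tensor,     # [num_blocks, block_size, cache_stride], uint8
    dst_k: torch.Tensor,        # [num_tokens, head_dim], uint8
    block_table: torch.Tensor,  # [batch_size, max_blocks_per_seq], int32
    cu_seq_lens: torch.Tensor,  # [batch_size + 1], int32
) -> torch.Tensor:
    """Non-vectorized reference for the paged-K gather done by the kernel above."""
    head_dim = dst_k.shape[1]
    block_size = kv_cache.shape[1]
    for b in range(block_table.shape[0]):
        start, end = int(cu_seq_lens[b]), int(cu_seq_lens[b + 1])
        for tok in range(start, end):
            pos = tok - start                             # token position within its sequence
            blk = int(block_table[b, pos // block_size])  # physical cache block
            block_bytes = kv_cache[blk].reshape(-1)       # K bytes are packed first in the block
            off = (pos % block_size) * head_dim
            dst_k[tok] = block_bytes[off : off + head_dim]
    return dst_k
```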
+#define CALL_CP_GATHER_INDEXER_K_CACHE(BLOCK_Y_SIZE) \ + vllm::cp_gather_indexer_k_cache_kernel \ + <<>>( \ + reinterpret_cast(kv_cache.data_ptr()), \ + reinterpret_cast(dst_k.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0), \ + kv_cache.stride(1), kv_cache.size(1), block_table.size(1), \ + num_tokens); + +void cp_gather_indexer_k_cache( + const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& dst_k, // [num_tokens, head_dim] + const torch::Tensor& block_table, // [batch_size, num_blocks] + const torch::Tensor& cu_seq_lens // [batch_size + 1] +) { + int batch_size = block_table.size(0); + int num_tokens = dst_k.size(0); + int head_dim = dst_k.size(1); + // int quant_block_size = head_dim * 4 / dst_scale.size(1); + + TORCH_CHECK(kv_cache.device() == dst_k.device(), + "kv_cache and dst_k must be on the same device"); + // TORCH_CHECK(kv_cache.device() == dst_scale.device(), + // "kv_cache and dst_scale must be on the same device"); + TORCH_CHECK(kv_cache.device() == block_table.device(), + "kv_cache and block_table must be on the same device"); + TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(), + "kv_cache and cu_seq_lens must be on the same device"); + // TORCH_CHECK(head_dim % quant_block_size == 0, + // "head_dim must be divisible by quant_block_size"); + + constexpr int vec_size = 16; + const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (num_tokens < 32) { + CALL_CP_GATHER_INDEXER_K_CACHE(1); + } else if (num_tokens < 64) { + CALL_CP_GATHER_INDEXER_K_CACHE(2); + } else if (num_tokens < 128) { + CALL_CP_GATHER_INDEXER_K_CACHE(4); + } else if (num_tokens < 256) { + CALL_CP_GATHER_INDEXER_K_CACHE(8); + } else if (num_tokens < 512) { + CALL_CP_GATHER_INDEXER_K_CACHE(16); + } else { + CALL_CP_GATHER_INDEXER_K_CACHE(32); + } +} + +// Macro to dispatch the kernel based on the data amount. 
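For the quantized variant that follows, the scale-buffer layout is implied by the shape comment and by the quant_block_size derivation in the host wrapper: one fp32 scale (4 raw bytes) per quant_block_size elements of head_dim. A small, purely illustrative helper (not part of the patch):

```python
def expected_dst_scale_width(head_dim: int, quant_block_size: int) -> int:
    """Width in bytes of dst_scale per token: one 4-byte fp32 scale per quant block."""
    assert head_dim % quant_block_size == 0, "head_dim must be divisible by quant_block_size"
    return head_dim // quant_block_size * 4

# e.g. head_dim=128, quant_block_size=32 -> 16 bytes of scales per token,
# matching quant_block_size = head_dim * 4 / dst_scale.size(1) in the wrapper below.
assert expected_dst_scale_width(128, 32) == 16
```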
+#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \ + vllm::cp_gather_indexer_k_quant_cache_kernel \ + <<>>( \ + reinterpret_cast(kv_cache.data_ptr()), \ + reinterpret_cast(dst_k.data_ptr()), \ + reinterpret_cast(dst_scale.data_ptr()), \ + block_table.data_ptr(), cu_seq_lens.data_ptr(), \ + batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0), \ + kv_cache.stride(1), kv_cache.size(1), block_table.size(1), \ + num_tokens, quant_block_size); + +void cp_gather_indexer_k_quant_cache( + const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& dst_k, // [num_tokens, head_dim] + torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] + const torch::Tensor& block_table, // [batch_size, num_blocks] + const torch::Tensor& cu_seq_lens // [batch_size + 1] +) { + int batch_size = block_table.size(0); + int num_tokens = dst_k.size(0); + int head_dim = dst_k.size(1); + int quant_block_size = head_dim * 4 / dst_scale.size(1); + + TORCH_CHECK(kv_cache.device() == dst_k.device(), + "kv_cache and dst_k must be on the same device"); + TORCH_CHECK(kv_cache.device() == dst_scale.device(), + "kv_cache and dst_scale must be on the same device"); + TORCH_CHECK(kv_cache.device() == block_table.device(), + "kv_cache and block_table must be on the same device"); + TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(), + "kv_cache and cu_seq_lens must be on the same device"); + TORCH_CHECK(head_dim % quant_block_size == 0, + "head_dim must be divisible by quant_block_size"); + + constexpr int vec_size = 16; + const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (num_tokens < 32) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1); + } else if (num_tokens < 64) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2); + } else if (num_tokens < 128) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4); + } else if (num_tokens < 256) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8); + } else if (num_tokens < 512) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16); + } else { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32); + } } \ No newline at end of file diff --git a/csrc/cub_helpers.h b/csrc/cub_helpers.h new file mode 100644 index 000000000..18e4e343a --- /dev/null +++ b/csrc/cub_helpers.h @@ -0,0 +1,18 @@ +#pragma once + +#ifndef USE_ROCM + #include + #if CUB_VERSION >= 200800 + #include +using CubAddOp = cuda::std::plus<>; +using CubMaxOp = cuda::maximum<>; + #else // if CUB_VERSION < 200800 +using CubAddOp = cub::Sum; +using CubMaxOp = cub::Max; + #endif // CUB_VERSION +#else + #include +namespace cub = hipcub; +using CubAddOp = hipcub::Sum; +using CubMaxOp = hipcub::Max; +#endif // USE_ROCM diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index 1dd7101ac..7a81dd40c 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -27,7 +27,7 @@ class MixedInputKernelScheduleType(enum.Enum): **{ VLLMDataType.u4b8: "u4b8", VLLMDataType.u8b128: "u8b128", - } + }, } VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { @@ -35,7 +35,7 @@ class MixedInputKernelScheduleType(enum.Enum): **{ VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", - } + }, } VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = { @@ -43,7 +43,7 @@ class MixedInputKernelScheduleType(enum.Enum): **{ VLLMDataType.u4b8: 4, 
VLLMDataType.u8b128: 8, - } + }, } VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = { @@ -67,15 +67,13 @@ class MixedInputKernelScheduleType(enum.Enum): DataType.f32: "at::ScalarType::Float", } -VLLMKernelScheduleTag: dict[Union[ - MixedInputKernelScheduleType, KernelScheduleType], str] = { - **KernelScheduleTag, # type: ignore - **{ - MixedInputKernelScheduleType.TmaWarpSpecialized: - "cutlass::gemm::KernelTmaWarpSpecialized", - MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: - "cutlass::gemm::KernelTmaWarpSpecializedPingpong", - MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: - "cutlass::gemm::KernelTmaWarpSpecializedCooperative", - } - } +VLLMKernelScheduleTag: dict[ + Union[MixedInputKernelScheduleType, KernelScheduleType], str +] = { + **KernelScheduleTag, # type: ignore + **{ + MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative", + }, +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 17a5ffc4f..1aa39e85c 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -4,7 +4,7 @@ void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, - torch::Tensor& gating_output); + torch::Tensor& gating_output, bool renormalize); void moe_sum(torch::Tensor& input, torch::Tensor& output); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 512c4943e..af6e6fcd4 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -16,15 +16,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include #include "../cuda_compat.h" - -#include -#include -#include -using AddOp = cuda::std::plus; +#include "../cub_helpers.h" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include + typedef __hip_bfloat16 __nv_bfloat16; + typedef __hip_bfloat162 __nv_bfloat162; +#endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -40,16 +47,27 @@ template < /// Alignment requirement in bytes int Alignment = sizeof(T) * N > -class alignas(Alignment) AlignedArray { - float data[N]; +struct alignas(Alignment) AlignedArray { + T data[N]; }; +template +__device__ __forceinline__ float toFloat(T value) { + if constexpr (std::is_same_v) { + return value; + } else if constexpr (std::is_same_v) { + return __bfloat162float(value); + } else if constexpr (std::is_same_v) { + return __half2float(value); + } +} + // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. 
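Before the kernels themselves, a hedged PyTorch reference of the end-to-end behavior this file now implements, including the new renormalize flag plumbed through below. The function name and the use of torch.softmax/torch.topk are illustrative; the fused kernels compute the same result without materializing the full softmax.

```python
import torch

def topk_softmax_reference(
    gating_output: torch.Tensor,  # [num_tokens, num_experts]; float32, float16, or bfloat16
    topk: int,
    renormalize: bool,
):
    """Row-wise softmax over experts, then top-k, then optional renormalization."""
    logits = gating_output.float()         # bf16/fp16 inputs are promoted, as toFloat() does
    probs = torch.softmax(logits, dim=-1)  # numerically stable (max-subtracted) softmax
    topk_weights, topk_indices = torch.topk(probs, topk, dim=-1)
    if renormalize:
        # Divide the selected weights by their sum so they sum to 1
        # (the kernels guard against a zero denominator).
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_indices
```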
-template +template __launch_bounds__(TPB) __global__ - void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) + void moeSoftmax(const InputType* input, const bool* finished, float* output, const int num_cols) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -70,10 +88,11 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; - threadData = max(static_cast(input[idx]), threadData); + const float val = toFloat(input[idx]); + threadData = max(val, threadData); } - const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp()); if (threadIdx.x == 0) { float_max = maxElem; @@ -85,10 +104,11 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; - threadData += exp((static_cast(input[idx]) - float_max)); + const float val = toFloat(input[idx]); + threadData += expf(val - float_max); } - const auto Z = BlockReduce(tmpStorage).Reduce(threadData, AddOp()); + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp()); if (threadIdx.x == 0) { @@ -99,8 +119,9 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; - const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; - output[idx] = val; + const float val = toFloat(input[idx]); + const float softmax_val = expf(val - float_max) * normalizing_factor; + output[idx] = softmax_val; } } @@ -114,7 +135,8 @@ __launch_bounds__(TPB) __global__ void moeTopK( const int num_experts, const int k, const int start_expert, - const int end_expert) + const int end_expert, + const bool renormalize) { using cub_kvp = cub::KeyValuePair; @@ -129,6 +151,7 @@ __launch_bounds__(TPB) __global__ void moeTopK( const bool row_is_active = finished ? !finished[block_row] : true; const int thread_read_offset = blockIdx.x * num_experts; + float selected_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { thread_kvp.key = 0; @@ -167,9 +190,23 @@ __launch_bounds__(TPB) __global__ void moeTopK( indices[idx] = should_process_row ? (expert - start_expert) : num_experts; assert(indices[idx] >= 0); source_rows[idx] = k_idx * num_rows + block_row; + if (renormalize) { + selected_sum += result_kvp.value; + } } __syncthreads(); } + + // Renormalize the k weights for this row to sum to 1, if requested. + if (renormalize) { + if (threadIdx.x == 0) { + const float denom = selected_sum > 0.f ? selected_sum : 1.f; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * block_row + k_idx; + output[idx] = output[idx] / denom; + } + } + } } // ====================== TopK softmax things =============================== @@ -188,21 +225,30 @@ __launch_bounds__(TPB) __global__ void moeTopK( 2) This implementation assumes k is small, but will work for any k. 
*/ -template +template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ - void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices, - int* source_rows, const int k, const int start_expert, const int end_expert) + void topkGatingSoftmax(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices, + int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize) { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v, + "InputType must be float, __nv_bfloat16, or __half"); + // We begin by enforcing compile time assertions and setting up compile time constants. static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); // Number of bytes each thread pulls in per load - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType); static constexpr int ELTS_PER_ROW = NUM_EXPERTS; static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + if constexpr (std::is_same_v || std::is_same_v) { + static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0, + "ELTS_PER_LDG must be 1 or even for 16-bit conversion"); + } + // Restrictions based on previous section. static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); @@ -240,27 +286,71 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the // row it will read. - const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW; // Now, we compute the group each thread belong to in order to determine the first column to start loads. const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; - const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; - - // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, - // this can support all powers of 2 up to 16. - // NOTE(woosuk): The original implementation uses CUTLASS aligned array here. - // We defined our own aligned array and use it here to avoid the dependency on CUTLASS. 
- using AccessType = AlignedArray; + const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; // Finally, we pull in the data from global mem float row_chunk[VPT]; - AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); - const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); + + // NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert to float + if constexpr (std::is_same_v) { + using VecType = AlignedArray; + VecType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); + const VecType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); #pragma unroll - for (int ii = 0; ii < LDG_PER_THREAD; ++ii) - { - row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + } else if constexpr (std::is_same_v) { + if constexpr (ELTS_PER_LDG >= 2) { + using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>; + float2* row_chunk_f2 = reinterpret_cast(row_chunk); + const VecType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + int base_idx_f2 = ii * ELTS_PER_LDG / 2; +#pragma unroll + for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) { + row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2( + *reinterpret_cast(vec.data + jj * 2) + ); + } + } + } else { // ELTS_PER_LDG == 1 +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + const __nv_bfloat16* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW; + row_chunk[ii] = __bfloat162float(*scalar_ptr); + } + } + } else if constexpr (std::is_same_v) { + if constexpr (ELTS_PER_LDG >= 2) { + using VecType = AlignedArray<__half, ELTS_PER_LDG>; + float2* row_chunk_f2 = reinterpret_cast(row_chunk); + const VecType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + int base_idx_f2 = ii * ELTS_PER_LDG / 2; +#pragma unroll + for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) { + row_chunk_f2[base_idx_f2 + jj] = __half22float2( + *reinterpret_cast(vec.data + jj * 2) + ); + } + } + } else { // ELTS_PER_LDG == 1 +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW; + row_chunk[ii] = __half2float(*scalar_ptr); + } + } } // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just @@ -314,6 +404,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ int start_col = first_elt_read_by_thread; static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + float selected_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { // First, each thread does the local argmax @@ -367,6 +458,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ output[idx] = max_val; indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; source_rows[idx] = k_idx * num_rows + thread_row; + if (renormalize) { + selected_sum += max_val; + } } // Finally, we clear the value in the thread with the current max if there is another iteration to run. @@ -384,15 +478,28 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ } } } + + // Renormalize the k weights for this row to sum to 1, if requested. 
+ if (renormalize) { + if (thread_group_idx == 0) + { + const float denom = selected_sum > 0.f ? selected_sum : 1.f; + for (int k_idx = 0; k_idx < k; ++k_idx) + { + const int idx = k * thread_row + k_idx; + output[idx] = output[idx] / denom; + } + } + } } namespace detail { // Constructs some constants needed to partition the work across threads at compile time. -template +template struct TopkConstants { - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, ""); static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM)); static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; @@ -401,33 +508,48 @@ struct TopkConstants }; } // namespace detail -template -void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices, - int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) +template +void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices, + int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize, + cudaStream_t stream) { - static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); - using Constants = detail::TopkConstants; + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS); + using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; - dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB); - topkGatingSoftmax<<>>( - input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); + dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB); + topkGatingSoftmax<<>>( + input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize); } +#ifndef USE_ROCM #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ - static_assert(WARP_SIZE == 32, \ - "Unsupported warp size. Only 32 is supported."); \ + static_assert(WARP_SIZE == 32, \ + "Unsupported warp size. Only 32 is supported for CUDA"); \ topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, \ - token_expert_indices, num_tokens, topk, 0, num_experts, stream); - + gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ + num_tokens, topk, 0, num_experts, renormalize, stream); +#else +#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ + if (WARP_SIZE == 64) { \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ + num_tokens, topk, 0, num_experts, renormalize, stream); \ + } else if (WARP_SIZE == 32) { \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ + num_tokens, topk, 0, num_experts, renormalize, stream); \ + } else { \ + assert(false && "Unsupported warp size. 
Only 32 and 64 are supported for ROCm"); \ + } +#endif -template +template void topkGatingSoftmaxKernelLauncher( - const float* gating_output, + const InputType* gating_output, float* topk_weights, IndType* topk_indices, int* token_expert_indices, @@ -435,10 +557,16 @@ void topkGatingSoftmaxKernelLauncher( const int num_tokens, const int num_experts, const int topk, + const bool renormalize, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16; - static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8; +#ifndef USE_ROCM + // for bfloat16 dtype, we need 4 bytes loading to make sure num_experts + // elements can be loaded by a warp + static constexpr int BYTES_PER_LDG_MULTIPLE_64 = + (std::is_same_v || std::is_same_v) ? 4 : 8; +#endif switch (num_experts) { case 1: LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); @@ -473,6 +601,7 @@ void topkGatingSoftmaxKernelLauncher( // (CUDA only) support multiples of 64 when num_experts is not power of 2. // ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of num_experts, // alternatively we can test 4 bytes loading and enable it in future. +#ifndef USE_ROCM case 192: LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); break; @@ -488,15 +617,16 @@ void topkGatingSoftmaxKernelLauncher( case 576: LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); break; +#endif default: { TORCH_CHECK(softmax_workspace != nullptr, "softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64."); static constexpr int TPB = 256; - moeSoftmax<<>>( + moeSoftmax<<>>( gating_output, nullptr, softmax_workspace, num_experts); moeTopK<<>>( softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices, - num_experts, topk, 0, num_experts); + num_experts, topk, 0, num_experts, renormalize); } } } @@ -504,11 +634,50 @@ void topkGatingSoftmaxKernelLauncher( } // namespace moe } // namespace vllm + +template +void dispatch_topk_softmax_launch( + torch::Tensor& gating_output, + torch::Tensor& topk_weights, + torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& softmax_workspace, + int num_tokens, int num_experts, int topk, bool renormalize, cudaStream_t stream) +{ + if (topk_indices.scalar_type() == at::ScalarType::Int) { + vllm::moe::topkGatingSoftmaxKernelLauncher( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, num_experts, topk, renormalize, stream); + } else if (topk_indices.scalar_type() == at::ScalarType::UInt32) { + vllm::moe::topkGatingSoftmaxKernelLauncher( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, num_experts, topk, renormalize, stream); + } else { + TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long); + vllm::moe::topkGatingSoftmaxKernelLauncher( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, num_experts, topk, renormalize, stream); + } +} + void topk_softmax( torch::Tensor& topk_weights, // [num_tokens, topk] torch::Tensor& topk_indices, // [num_tokens, topk] torch::Tensor& token_expert_indices, // [num_tokens, topk] - torch::Tensor& gating_output) // [num_tokens, num_experts] + 
torch::Tensor& gating_output, // [num_tokens, num_experts] + bool renormalize) { const int num_experts = gating_output.size(-1); const auto num_tokens = gating_output.numel() / num_experts; @@ -520,45 +689,19 @@ void topk_softmax( const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); - - if(topk_indices.scalar_type() == at::ScalarType::Int) - { - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); - } - else if (topk_indices.scalar_type() == at::ScalarType::UInt32) - { - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); - } - else { - assert(topk_indices.scalar_type() == at::ScalarType::Int64); - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); + const auto workspace_options = gating_output.options().dtype(at::ScalarType::Float); + torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options); + + if (gating_output.scalar_type() == at::ScalarType::Float) { + dispatch_topk_softmax_launch(gating_output, topk_weights, topk_indices, + token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); + } else if (gating_output.scalar_type() == at::ScalarType::Half) { + dispatch_topk_softmax_launch<__half>(gating_output, topk_weights, topk_indices, + token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); + } else if (gating_output.scalar_type() == at::ScalarType::BFloat16) { + dispatch_topk_softmax_launch<__nv_bfloat16>(gating_output, topk_weights, topk_indices, + token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); + } else { + TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type()); } } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 031109a67..f3d61cc61 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -5,7 +5,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. m.def( "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! 
" - "token_expert_indices, Tensor gating_output) -> ()"); + "token_expert_indices, Tensor gating_output, bool renormalize) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); // Calculate the result of moe by summing up the partial results diff --git a/csrc/ops.h b/csrc/ops.h index 7d126d5c1..a152bd7a7 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -97,6 +97,14 @@ void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& output_mask, const torch::Tensor& repetition_penalties); +void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, + const torch::Tensor& rowEnds, torch::Tensor& indices, + int64_t numRows, int64_t stride0, int64_t stride1); + +void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, + const torch::Tensor& seq_lens, torch::Tensor& indices, + int64_t numRows, int64_t stride0, int64_t stride1); + void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, torch::Tensor& scale, double epsilon); diff --git a/csrc/sampler.cu b/csrc/sampler.cu index a8d7f7374..f5c5811b2 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -40,6 +40,256 @@ __global__ void apply_repetition_penalties_kernel( } } +static inline __device__ uint16_t extractBinIdx(float x) { + union { + __half h; + uint16_t u16; + } tmp; + tmp.h = __float2half_rn(x); + tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000); + return 511 - (tmp.u16 >> 7); +} + +template +__device__ void topKPerRowJob(const float* logits, const int rowStart, + const int rowEnd, const int rowIdx, + int* outIndices, int stride0, int stride1) { + // The number of elements per thread for the final top-k sort. + static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock; + // The class to sort the elements during the final top-k sort. + using TopKSort = cub::BlockRadixSort; + + // The number of slots for the final pass. + static constexpr int kNumFinalItems = 3072; + // The number of elements per thread for the final sort. + static constexpr int kNumFinalItemsPerThread = + kNumFinalItems / kNumThreadsPerBlock; + // The class to sort the elements during the final pass. + using FinalSort = cub::BlockRadixSort; + + // The class to compute the inclusive prefix-sum over the histogram. + using Scan = cub::BlockScan; + + // Shared memory to compute the block scan. + __shared__ typename Scan::TempStorage smemScan; + + // The structure to store the final items (for the final pass). + struct FinalItems { + // Shared memory to store the indices for the final pass. + int indices[kNumFinalItems]; + // Shared memory to store the logits for the final pass. + float logits[kNumFinalItems]; + }; + + // Shared memory to compute the block sort. + __shared__ union { + FinalItems items; + typename FinalSort::TempStorage finalSort; + typename TopKSort::TempStorage topKSort; + } smemFinal; + + // Shared memory to store the histogram. + __shared__ int smemHistogram[kNumBins]; + // Shared memory to store the selected indices. + __shared__ int smemIndices[kTopK]; + // Shared memory to store the threshold bin. + __shared__ int smemThresholdBinIdx[1]; + // Shared memory counter to register the candidates for the final phase. + __shared__ int smemFinalDstIdx[1]; + + // The length of the row. + int rowLen = rowEnd - rowStart; + + // Shortcut if the length of the row is smaller than Top-K. Indices are not + // sorted by their corresponding logit. 
+ if (rowLen <= kTopK) { + for (int rowIt = threadIdx.x; rowIt < rowLen; + rowIt += kNumThreadsPerBlock) { + int idx = rowStart + rowIt; + outIndices[rowIdx * kTopK + rowIt] = idx - rowStart; + } + for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK; + rowIt += kNumThreadsPerBlock) { + outIndices[rowIdx * kTopK + rowIt] = -1; + } + return; + } + + // Clear the histogram. + if (threadIdx.x < kNumBins) { + smemHistogram[threadIdx.x] = 0; + } + + // Make sure the histogram is ready. + __syncthreads(); + + // Fetch elements one-by-one. + for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd; + rowIt += kNumThreadsPerBlock) { + uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]); + atomicAdd(&smemHistogram[idx], 1); + } + + // Make sure the histogram is ready. + __syncthreads(); + + // Read the values from SMEM. + int binCount{0}; + if (threadIdx.x < kNumBins) { + binCount = smemHistogram[threadIdx.x]; + } + + // Make sure each thread has read its value. + __syncthreads(); + + // Compute the prefix sum. + int prefixSum{0}, totalSum{0}; + Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum); + + // Update the histogram with the prefix sums. + if (threadIdx.x < kNumBins) { + smemHistogram[threadIdx.x] = prefixSum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Find the last valid bin. + if (threadIdx.x < kNumBins) { + int nextPrefixSum = + threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1]; + if (prefixSum < kTopK && nextPrefixSum >= kTopK) { + smemThresholdBinIdx[0] = threadIdx.x; + } + } + + // Clear the counter to store the items for the final phase. + if (threadIdx.x == 0) { + smemFinalDstIdx[0] = 0; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The threshold bin. + int thresholdBinIdx = smemThresholdBinIdx[0]; + + // Fetch elements one-by-one and populate the shared memory buffers. + for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd; + rowIt += kNumThreadsPerBlock) { + float logit = logits[rowIdx * stride0 + rowIt * stride1]; + uint16_t idx = extractBinIdx(logit); + if (idx < thresholdBinIdx) { + int dstIdx = atomicAdd(&smemHistogram[idx], 1); + smemIndices[dstIdx] = rowIt; + } else if (idx == thresholdBinIdx) { + int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1); + if (dstIdx < kNumFinalItems) { + smemFinal.items.logits[dstIdx] = logit; + smemFinal.items.indices[dstIdx] = rowIt; + } + } + } + + // Make sure the elements are in shared memory. + __syncthreads(); + + // The logits of the elements to be sorted in the final pass. + float finalLogits[kNumFinalItemsPerThread]; + // The indices of the elements to be sorted in the final pass. + int finalIndices[kNumFinalItemsPerThread]; + +// Init. +#pragma unroll + for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { + finalLogits[ii] = -FLT_MAX; + } + +// Read the elements from SMEM. +#pragma unroll + for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { + int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; + if (srcIdx < smemFinalDstIdx[0]) { + finalLogits[ii] = smemFinal.items.logits[srcIdx]; + finalIndices[ii] = smemFinal.items.indices[srcIdx]; + } + } + + // Make sure the shared memory has been read. + __syncthreads(); + + // Sort the elements. + FinalSort(smemFinal.finalSort) + .SortDescendingBlockedToStriped(finalLogits, finalIndices); + + // Copy the data back to the shared memory storage. + int baseIdx = thresholdBinIdx > 0 ? 
smemHistogram[thresholdBinIdx - 1] : 0; +#pragma unroll + for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) { + int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; + int dstIdx = baseIdx + srcIdx; + if (dstIdx < kTopK) { + smemIndices[dstIdx] = finalIndices[ii]; + } + } + + // Make sure the data is in shared memory. + __syncthreads(); + +// Store to global memory. +#pragma unroll + for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) { + int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x; + outIndices[offset] = + smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart; + } +} + +template +static __global__ void topKPerRow(const float* logits, const int* rowStarts, + const int* rowEnds, int* outIndices, + int stride0, int stride1) { + // The number of bins in the histogram. + static constexpr int kNumBins = 512; + + // The top-k width. + static constexpr int kTopK = 2048; + + // The row computed by this block. + int rowIdx = blockIdx.x; + + // The range of logits within the row. + int rowStart = rowStarts[rowIdx]; + int rowEnd = rowEnds[rowIdx]; + + topKPerRowJob( + logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1); +} + +template +static __global__ void topKPerRowDecode(const float* logits, const int* seqLens, + int* outIndices, int stride0, + int stride1, int next_n) { + // The number of bins in the histogram. + static constexpr int kNumBins = 512; + + // The top-k width. + static constexpr int kTopK = 2048; + + // The row computed by this block. + int rowIdx = blockIdx.x; + + // The range of logits within the row. + int rowStart = 0; + int seq_len = seqLens[rowIdx / next_n]; + int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1; + + topKPerRowJob( + logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1); +} + } // namespace vllm void apply_repetition_penalties_( @@ -81,4 +331,32 @@ void apply_repetition_penalties_( repetition_penalties.data_ptr(), num_seqs, vocab_size, tile_size); }); +} + +void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, + const torch::Tensor& seqLens, torch::Tensor& indices, + int64_t numRows, int64_t stride0, int64_t stride1) { + // Compute the results on the device. + constexpr int kNumThreadsPerBlock = 512; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + vllm::topKPerRowDecode + <<>>( + logits.data_ptr(), seqLens.data_ptr(), + indices.data_ptr(), static_cast(stride0), + static_cast(stride1), static_cast(next_n)); +} + +void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, + const torch::Tensor& rowEnds, torch::Tensor& indices, + int64_t numRows, int64_t stride0, int64_t stride1) { + // Compute the results on the device. + constexpr int kNumThreadsPerBlock = 512; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + vllm::topKPerRow + <<>>( + logits.data_ptr(), rowStarts.data_ptr(), + rowEnds.data_ptr(), indices.data_ptr(), + static_cast(stride0), static_cast(stride1)); } \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 37a232b29..c4c265f1d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -169,6 +169,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("apply_repetition_penalties_", torch::kCUDA, &apply_repetition_penalties_); + // Optimized top-k per row operation + ops.def( + "top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, " + "Tensor! 
indices, int numRows, int stride0, " + "int stride1) -> ()"); + ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row); + + ops.def( + "top_k_per_row_decode(Tensor logits, int next_n, " + "Tensor seq_lens, Tensor! indices, int numRows, " + "int stride0, int stride1) -> ()"); + ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode); + // ┌------------------------ Not supported for Metax // -------------------------┐ Layernorm-quant Apply Root Mean Square (RMS) // Normalization to the input tensor. diff --git a/docs/mkdocs/hooks/url_schemes.py b/docs/mkdocs/hooks/url_schemes.py index c1d57cbc2..e3f7d50c3 100644 --- a/docs/mkdocs/hooks/url_schemes.py +++ b/docs/mkdocs/hooks/url_schemes.py @@ -46,9 +46,9 @@ relative_link = re.compile(rf"\[{TITLE}\]\({RELATIVE}\)") -def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, - files: Files) -> str: - +def on_page_markdown( + markdown: str, *, page: Page, config: MkDocsConfig, files: Files +) -> str: def replace_relative_link(match: re.Match) -> str: """Replace relative file links with URLs if they point outside the docs dir.""" title = match.group("title") diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 22cb8b057..53d69bbdb 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -10,7 +10,7 @@ import os from dataclasses import asdict -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple from huggingface_hub import snapshot_download from transformers import AutoTokenizer @@ -18,7 +18,7 @@ from vllm import LLM, EngineArgs, SamplingParams from vllm.assets.audio import AudioAsset from vllm.lora.request import LoRARequest -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] question_per_audio_count = { @@ -30,11 +30,11 @@ class ModelRequestData(NamedTuple): engine_args: EngineArgs - prompt: Optional[str] = None - prompt_token_ids: Optional[dict[str, list[int]]] = None - multi_modal_data: Optional[dict[str, Any]] = None - stop_token_ids: Optional[list[int]] = None - lora_requests: Optional[list[LoRARequest]] = None + prompt: str | None = None + prompt_token_ids: dict[str, list[int]] | None = None + multi_modal_data: dict[str, Any] | None = None + stop_token_ids: list[int] | None = None + lora_requests: list[LoRARequest] | None = None # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on @@ -45,10 +45,12 @@ class ModelRequestData(NamedTuple): # Voxtral def run_voxtral(question: str, audio_count: int) -> ModelRequestData: from mistral_common.audio import Audio - from mistral_common.protocol.instruct.messages import ( + from mistral_common.protocol.instruct.chunk import ( AudioChunk, RawAudio, TextChunk, + ) + from mistral_common.protocol.instruct.messages import ( UserMessage, ) from mistral_common.protocol.instruct.request import ChatCompletionRequest @@ -117,7 +119,7 @@ def run_gemma3n(question: str, audio_count: int) -> ModelRequestData: # Granite Speech def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: - # NOTE - the setting in this example are somehat different than what is + # NOTE - the setting in this example are somewhat different from what is # optimal for granite speech, and it is generally recommended to use beam # search. Check the model README for suggested settings. 
# https://huggingface.co/ibm-granite/granite-speech-3.3-8b @@ -146,6 +148,36 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: ) +# MiDashengLM +def run_midashenglm(question: str, audio_count: int): + model_name = "mispeech/midashenglm-7b" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}, + ) + + audio_in_prompt = "".join( + ["<|audio_bos|><|AUDIO|><|audio_eos|>" for idx in range(audio_count)] + ) + + default_system = "You are a helpful language and speech assistant." + + prompt = ( + f"<|im_start|>system\n{default_system}<|im_end|>\n" + "<|im_start|>user\n" + f"{audio_in_prompt}{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + ) + + # MiniCPM-O def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: model_name = "openbmb/MiniCPM-o-2_6" @@ -352,6 +384,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: "voxtral": run_voxtral, "gemma3n": run_gemma3n, "granite_speech": run_granite_speech, + "midashenglm": run_midashenglm, "minicpmo": run_minicpmo, "phi4_mm": run_phi4mm, "phi4_multimodal": run_phi4_multimodal, diff --git a/examples/offline_inference/basic/chat.py b/examples/offline_inference/basic/chat.py index d078c517d..c42b00730 100644 --- a/examples/offline_inference/basic/chat.py +++ b/examples/offline_inference/basic/chat.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm import LLM, EngineArgs -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def create_parser(): @@ -87,6 +87,7 @@ def print_outputs(outputs): use_tqdm=False, chat_template=chat_template, ) + print_outputs(outputs) if __name__ == "__main__": diff --git a/examples/offline_inference/basic/classify.py b/examples/offline_inference/basic/classify.py index dc3bc399c..b72ddde1f 100644 --- a/examples/offline_inference/basic/classify.py +++ b/examples/offline_inference/basic/classify.py @@ -4,7 +4,7 @@ from argparse import Namespace from vllm import LLM, EngineArgs -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def parse_args(): diff --git a/examples/offline_inference/basic/embed.py b/examples/offline_inference/basic/embed.py index 158836728..eeb7137ff 100644 --- a/examples/offline_inference/basic/embed.py +++ b/examples/offline_inference/basic/embed.py @@ -4,7 +4,7 @@ from argparse import Namespace from vllm import LLM, EngineArgs -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def parse_args(): diff --git a/examples/offline_inference/basic/generate.py b/examples/offline_inference/basic/generate.py index 6a41ef4d8..9650dcfe9 100644 --- a/examples/offline_inference/basic/generate.py +++ b/examples/offline_inference/basic/generate.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm import LLM, EngineArgs -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def create_parser(): diff --git a/examples/offline_inference/basic/reward.py b/examples/offline_inference/basic/reward.py index aa173cf96..e95085686 100644 --- a/examples/offline_inference/basic/reward.py +++ b/examples/offline_inference/basic/reward.py @@ -4,7 +4,7 @@ from argparse import Namespace 
from vllm import LLM, EngineArgs -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def parse_args(): diff --git a/examples/offline_inference/basic/score.py b/examples/offline_inference/basic/score.py index c9ca7a8bf..cbca50eb5 100644 --- a/examples/offline_inference/basic/score.py +++ b/examples/offline_inference/basic/score.py @@ -4,7 +4,7 @@ from argparse import Namespace from vllm import LLM, EngineArgs -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def parse_args(): diff --git a/examples/offline_inference/chat_with_tools.py b/examples/offline_inference/chat_with_tools.py index 6e56e24f2..3a95b1fdf 100644 --- a/examples/offline_inference/chat_with_tools.py +++ b/examples/offline_inference/chat_with_tools.py @@ -143,5 +143,5 @@ def get_current_weather(city: str, state: str, unit: "str"): print(outputs[0].outputs[0].text.strip()) # yields -# 'The weather in Dallas, TX is 85 degrees fahrenheit. ' +# 'The weather in Dallas, TX is 85 degrees Fahrenheit. ' # 'It is partly cloudly, with highs in the 90's.' diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index dd7559451..0b281fc41 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -33,7 +33,7 @@ from time import sleep from vllm import LLM, SamplingParams -from vllm.utils import get_open_port +from vllm.utils.network_utils import get_open_port def parse_args(): @@ -87,10 +87,27 @@ def parse_args(): default=0.8, help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."), ) + parser.add_argument( + "--enable-dbo", + action="store_true", + help=("Enable microbatched execution"), + ) + parser.add_argument( + "--compilation-config", + type=int, + help=("Compilation optimization (O) mode 0-3."), + ) parser.add_argument( "--quantization", type=str, ) + parser.add_argument( + "--disable-expert-parallel", + dest="enable_expert_parallel", + action="store_false", + help="Disable expert parallel (default: enabled).", + ) + parser.set_defaults(enable_expert_parallel=True) return parser.parse_args() @@ -103,10 +120,13 @@ def main( dp_master_port, GPUs_per_dp_rank, enforce_eager, + enable_expert_parallel, trust_remote_code, max_num_seqs, max_model_len, + compilation_config, gpu_memory_utilization, + enable_dbo, quantization, ): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) @@ -156,12 +176,14 @@ def start(rank): model=model, tensor_parallel_size=GPUs_per_dp_rank, enforce_eager=enforce_eager, - enable_expert_parallel=True, + enable_expert_parallel=enable_expert_parallel, trust_remote_code=trust_remote_code, max_num_seqs=max_num_seqs, max_model_len=max_model_len, gpu_memory_utilization=gpu_memory_utilization, + enable_dbo=enable_dbo, quantization=quantization, + compilation_config=compilation_config, ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. 
@@ -215,10 +237,13 @@ def start(rank): dp_master_port, tp_size, args.enforce_eager, + args.enable_expert_parallel, args.trust_remote_code, args.max_num_seqs, args.max_model_len, + args.compilation_config, args.gpu_memory_utilization, + args.enable_dbo, args.quantization, ), ) diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py index 05a361fee..f619fa584 100644 --- a/examples/offline_inference/disaggregated_prefill.py +++ b/examples/offline_inference/disaggregated_prefill.py @@ -30,12 +30,12 @@ def run_prefill(prefill_done): ] sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) - # Using PyNcclConnector to transmit KV caches between vLLM instances. + # Using P2pNcclConnector to transmit KV caches between vLLM instances. # This instance is the prefill node (kv_producer, rank 0). # The number of parallel instances for KV cache transfer is set to 2, - # as required for PyNcclConnector. + # as required for P2pNcclConnector. ktc = KVTransferConfig( - kv_connector="PyNcclConnector", + kv_connector="P2pNcclConnector", kv_role="kv_producer", kv_rank=0, kv_parallel_size=2, @@ -74,12 +74,12 @@ def run_decode(prefill_done): ] sampling_params = SamplingParams(temperature=0, top_p=0.95) - # Using PyNcclConnector to transmit KV caches between vLLM instances. + # Using P2pNcclConnector to transmit KV caches between vLLM instances. # This instance is the decode node (kv_consumer, rank 1). # The number of parallel instances for KV cache transfer is set to 2, - # as required for PyNcclConnector. + # as required for P2pNcclConnector. ktc = KVTransferConfig( - kv_connector="PyNcclConnector", + kv_connector="P2pNcclConnector", kv_role="kv_consumer", kv_rank=1, kv_parallel_size=2, diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py deleted file mode 100644 index df6c1eaf4..000000000 --- a/examples/offline_inference/encoder_decoder.py +++ /dev/null @@ -1,193 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Demonstrate prompting of text-to-text -encoder/decoder models, specifically BART and mBART. - -This script is refactored to allow model selection via command-line arguments. -""" - -import argparse -from typing import NamedTuple, Optional - -from vllm import LLM, SamplingParams -from vllm.inputs import ( - ExplicitEncoderDecoderPrompt, - TextPrompt, - TokensPrompt, - zip_enc_dec_prompts, -) - - -class ModelRequestData(NamedTuple): - """ - Holds the configuration for a specific model, including its - HuggingFace ID and the prompts to use for the demo. - """ - - model_id: str - encoder_prompts: list - decoder_prompts: list - hf_overrides: Optional[dict] = None - - -def get_bart_config() -> ModelRequestData: - """ - Returns the configuration for facebook/bart-large-cnn. - This uses the exact test cases from the original script. - """ - encoder_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "An encoder prompt", - ] - decoder_prompts = [ - "A decoder prompt", - "Another decoder prompt", - ] - return ModelRequestData( - model_id="facebook/bart-large-cnn", - encoder_prompts=encoder_prompts, - decoder_prompts=decoder_prompts, - ) - - -def get_mbart_config() -> ModelRequestData: - """ - Returns the configuration for facebook/mbart-large-en-ro. - This uses prompts suitable for an English-to-Romanian translation task. 
- """ - encoder_prompts = [ - "The quick brown fox jumps over the lazy dog.", - "How are you today?", - ] - decoder_prompts = ["", ""] - hf_overrides = {"architectures": ["MBartForConditionalGeneration"]} - return ModelRequestData( - model_id="facebook/mbart-large-en-ro", - encoder_prompts=encoder_prompts, - decoder_prompts=decoder_prompts, - hf_overrides=hf_overrides, - ) - - -MODEL_GETTERS = { - "bart": get_bart_config, - "mbart": get_mbart_config, -} - - -def create_all_prompt_types( - encoder_prompts_raw: list, - decoder_prompts_raw: list, - tokenizer, -) -> list: - """ - Generates a list of diverse prompt types for demonstration. - This function is generic and uses the provided raw prompts - to create various vLLM input objects. - """ - text_prompt_raw = encoder_prompts_raw[0] - text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)]) - tokens_prompt = TokensPrompt( - prompt_token_ids=tokenizer.encode( - encoder_prompts_raw[2 % len(encoder_prompts_raw)] - ) - ) - - decoder_tokens_prompt = TokensPrompt( - prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0]) - ) - single_prompt_examples = [ - text_prompt_raw, - text_prompt, - tokens_prompt, - ] - explicit_pair_examples = [ - ExplicitEncoderDecoderPrompt( - encoder_prompt=text_prompt_raw, - decoder_prompt=decoder_tokens_prompt, - ), - ExplicitEncoderDecoderPrompt( - encoder_prompt=text_prompt, - decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)], - ), - ExplicitEncoderDecoderPrompt( - encoder_prompt=tokens_prompt, - decoder_prompt=text_prompt, - ), - ] - zipped_prompt_list = zip_enc_dec_prompts( - encoder_prompts_raw, - decoder_prompts_raw, - ) - return single_prompt_examples + explicit_pair_examples + zipped_prompt_list - - -def create_sampling_params() -> SamplingParams: - """Create a sampling params object.""" - return SamplingParams( - temperature=0, - top_p=1.0, - min_tokens=0, - max_tokens=30, - ) - - -def print_outputs(outputs: list): - """Formats and prints the generation outputs.""" - print("-" * 80) - for i, output in enumerate(outputs): - prompt = output.prompt - encoder_prompt = output.encoder_prompt - generated_text = output.outputs[0].text - print(f"Output {i + 1}:") - print(f"Encoder Prompt: {encoder_prompt!r}") - print(f"Decoder Prompt: {prompt!r}") - print(f"Generated Text: {generated_text!r}") - print("-" * 80) - - -def main(args): - """Main execution function.""" - model_key = args.model - if model_key not in MODEL_GETTERS: - raise ValueError( - f"Unknown model: {model_key}. " - f"Available models: {list(MODEL_GETTERS.keys())}" - ) - config_getter = MODEL_GETTERS[model_key] - model_config = config_getter() - - print(f"🚀 Running demo for model: {model_config.model_id}") - llm = LLM( - model=model_config.model_id, - dtype="float", - hf_overrides=model_config.hf_overrides, - ) - tokenizer = llm.llm_engine.get_tokenizer_group() - prompts = create_all_prompt_types( - encoder_prompts_raw=model_config.encoder_prompts, - decoder_prompts_raw=model_config.decoder_prompts, - tokenizer=tokenizer, - ) - sampling_params = create_sampling_params() - outputs = llm.generate(prompts, sampling_params) - print_outputs(outputs) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="A flexible demo for vLLM encoder-decoder models." 
- ) - parser.add_argument( - "--model", - "-m", - type=str, - default="bart", - choices=MODEL_GETTERS.keys(), - help="The short name of the model to run.", - ) - args = parser.parse_args() - main(args) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index d27a902ed..c1d6c6db5 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -5,6 +5,7 @@ the explicit/implicit prompt format on enc-dec LMMs for text generation. """ +import os import time from collections.abc import Sequence from dataclasses import asdict @@ -12,8 +13,7 @@ from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset -from vllm.assets.image import ImageAsset -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser class ModelRequestData(NamedTuple): @@ -21,70 +21,9 @@ class ModelRequestData(NamedTuple): prompts: Sequence[PromptType] -def run_florence2(): - engine_args = EngineArgs( - model="microsoft/Florence-2-large", - tokenizer="Isotr0py/Florence-2-tokenizer", - max_num_seqs=8, - trust_remote_code=True, - limit_mm_per_prompt={"image": 1}, - dtype="half", - ) - - prompts = [ - { # implicit prompt with task token - "prompt": "", - "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image}, - }, - { # explicit encoder/decoder prompt - "encoder_prompt": { - "prompt": "Describe in detail what is shown in the image.", - "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image}, - }, - "decoder_prompt": "", - }, - ] - - return ModelRequestData( - engine_args=engine_args, - prompts=prompts, - ) - - -def run_mllama(): - engine_args = EngineArgs( - model="meta-llama/Llama-3.2-11B-Vision-Instruct", - max_model_len=8192, - max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, - dtype="half", - ) - - prompts = [ - { # Implicit prompt - "prompt": "<|image|><|begin_of_text|>What is the content of this image?", # noqa: E501 - "multi_modal_data": { - "image": ImageAsset("stop_sign").pil_image, - }, - }, - { # Explicit prompt - "encoder_prompt": { - "prompt": "<|image|>", - "multi_modal_data": { - "image": ImageAsset("stop_sign").pil_image, - }, - }, - "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501 - }, - ] - - return ModelRequestData( - engine_args=engine_args, - prompts=prompts, - ) - - def run_whisper(): + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + engine_args = EngineArgs( model="openai/whisper-large-v3-turbo", max_model_len=448, @@ -118,8 +57,6 @@ def run_whisper(): model_example_map = { - "florence2": run_florence2, - "mllama": run_mllama, "whisper": run_whisper, } @@ -133,7 +70,7 @@ def parse_args(): "--model-type", "-m", type=str, - default="mllama", + default="whisper", choices=model_example_map.keys(), help='Huggingface "model_type".', ) diff --git a/examples/offline_inference/kv_load_failure_recovery/README.md b/examples/offline_inference/kv_load_failure_recovery/README.md new file mode 100644 index 000000000..230a16812 --- /dev/null +++ b/examples/offline_inference/kv_load_failure_recovery/README.md @@ -0,0 +1,30 @@ +# KV Load Failure Recovery Test + +This example builds upon the `disaggregated-prefill-v1` example in `examples/offline_inference`. + +It demonstrates vLLM's ability to recover from KV load failures in both synchronous and asynchronous loading modes. 
The goal is to verify that vLLM correctly identifies invalid KV blocks, reschedules the affected requests, and ensures successful and consistent output. + +## Files + +- `prefill_example.py` – performs the prefill stage and saves KV data (same as in `disaggregated-prefill-v1`). +- `decode_example.py` – performs the decode stage. Accepts: + - `--simulate-failure`: simulates KV load failure using a custom connector. + - `--async-load`: enables asynchronous KV loading mode. +- `rogue_shared_storage_connector.py` – defines `RogueSharedStorageConnector`, a subclass of `SharedStorageConnector`, that simulates missing or corrupted external KV blocks by failing to load blocks for the first decode request. +- `run.sh` – orchestrates the test: runs the prefill stage, then three decode stages: + 1. Normal decode (baseline). + 2. Decode with simulated sync KV load failure. + 3. Decode with simulated async KV load failure. + + Finally, it compares the output of the baseline with the recovered outputs to verify correctness. + +## How It Works + +- The test dynamically loads `RogueSharedStorageConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector. +- The decode stages that simulate failure are expected to trigger recovery logic in vLLM, resulting in the same output as the baseline decode. +- If recovery fails, the script prints a unified diff of the output mismatch and exits with error. + +## Usage + +```bash +./run.sh diff --git a/examples/offline_inference/kv_load_failure_recovery/decode_example.py b/examples/offline_inference/kv_load_failure_recovery/decode_example.py new file mode 100644 index 000000000..69523f56e --- /dev/null +++ b/examples/offline_inference/kv_load_failure_recovery/decode_example.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +def read_prompts(): + """Read prompts from prefill_output.txt""" + prompts = [] + try: + with open("prefill_output.txt") as f: + for line in f: + prompts.append(line.strip()) + print(f"Loaded {len(prompts)} prompts from prefill_output.txt") + return prompts + except FileNotFoundError: + print("Error: prefill_output.txt file not found") + exit(-1) + + +def main(): + prompts = read_prompts() + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--simulate-failure", action="store_true", help="Simulate KV load failure." 
+ ) + parser.add_argument( + "--async-load", action="store_true", help="Simulate async KV load" + ) + args = parser.parse_args() + + if args.simulate_failure: + ktc = KVTransferConfig( + kv_connector="RogueSharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "shared_storage_path": "local_storage", + "async_load": args.async_load, + }, + kv_connector_module_path="rogue_shared_storage_connector", + ) + out_file = ( + "async_decode_recovered_output.txt" + if args.async_load + else "sync_decode_recovered_output.txt" + ) + else: + ktc = KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "shared_storage_path": "local_storage", + }, + ) + out_file = "decode_output.txt" + + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + max_num_batched_tokens=64, + max_num_seqs=16, + kv_transfer_config=ktc, + ) + + outputs = llm.generate(prompts, sampling_params) + + sep_str = "-" * 30 + with open(out_file, "w", encoding="utf-8") as f: + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}" + print(out_str) + print(sep_str) + f.write(out_str) + f.write(sep_str) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/kv_load_failure_recovery/prefill_example.py b/examples/offline_inference/kv_load_failure_recovery/prefill_example.py new file mode 100644 index 000000000..047b81c82 --- /dev/null +++ b/examples/offline_inference/kv_load_failure_recovery/prefill_example.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +def read_prompts(): + context = "Hi " * 1000 + context2 = "Hey " * 500 + return [ + context + "Hello, my name is", + context + "The capital of France is", + context2 + "Your name is", + context2 + "The capital of China is", + ] + + +def main(): + prompts = read_prompts() + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) + + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + kv_transfer_config=KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ), + ) # , max_model_len=2048, max_num_batched_tokens=2048) + + # 1ST generation (prefill instance) + outputs = llm.generate( + prompts, + sampling_params, + ) + + new_prompts = [] + print("-" * 30) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + new_prompts.append(prompt + generated_text) + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 30) + + # Write new_prompts to prefill_output.txt + with open("prefill_output.txt", "w") as f: + for prompt in new_prompts: + f.write(prompt + "\n") + print(f"Saved {len(new_prompts)} prompts to prefill_output.txt") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py b/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py new file mode 100644 index 000000000..5b2acea4c --- /dev/null +++ b/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( + SharedStorageConnector, + SharedStorageConnectorMetadata, +) +from vllm.forward_context import ForwardContext +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.request import Request + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + +logger = logging.getLogger() +logging.basicConfig(level=logging.INFO) + + +@dataclass +class RogueSharedStorageConnectorMetadata(SharedStorageConnectorMetadata): + req_to_block_ids: dict[str, set[int]] = field(default_factory=dict) + + @classmethod + def from_base(cls, base: SharedStorageConnectorMetadata): + return cls(requests=base.requests) + + +class RogueSharedStorageConnector(SharedStorageConnector): + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._async_load = vllm_config.kv_transfer_config.get_from_extra_config( + "async_load", False + ) + self._invalid_block_ids: set = None + self._seen_requests: set = set() + self._req_to_block_ids: dict[str, list[int]] = dict() + + def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: + assert isinstance(connector_metadata, RogueSharedStorageConnectorMetadata) + index, failed_request = next( + ( + (i, x) + for i, x in enumerate(connector_metadata.requests) + if not x.is_store + ), + (None, None), + ) + if index is not None: + del connector_metadata.requests[index] + self._invalid_block_ids = set( + ( + failed_request.slot_mapping[:: self._block_size] // self._block_size + ).tolist() + ) + logger.info( + "Simulating failure to load all KV blocks for the " + "first load request. Total blocks: %d", + len(self._invalid_block_ids), + ) + super().bind_connector_metadata(connector_metadata) + + def clear_connector_metadata(self) -> None: + self._invalid_block_ids = None + super().clear_connector_metadata() + + def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None: + if self._async_load and forward_context.attn_metadata is None: + # Bypass sanity check in super().start_load_kv + forward_context.attn_metadata = "None" + + super().start_load_kv(forward_context, **kwargs) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None]: + if self._async_load: + meta = self._get_connector_metadata() + assert isinstance(meta, RogueSharedStorageConnectorMetadata) + if meta.req_to_block_ids: + return None, set(meta.req_to_block_ids) + + return None, None + + def get_block_ids_with_load_errors(self) -> set[int]: + return self._invalid_block_ids + + def get_num_new_matched_tokens( + self, + request: Request, + num_computed_tokens: int, + ) -> tuple[int, bool]: + if request.request_id in self._seen_requests: + return 0, False + + self._seen_requests.add(request.request_id) + + num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens) + return num_tokens, self._async_load and num_tokens > 0 + + def update_state_after_alloc( + self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int + ): + """ + Update KVConnector state after block allocation. 
+ + If blocks were allocated, add to _requests_need_load, + such that we load the KVs in the next forward pass. + """ + super().update_state_after_alloc(request, blocks, num_external_tokens) + + if num_external_tokens > 0: + self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0] + + def build_connector_meta( + self, + scheduler_output: "SchedulerOutput", + ) -> KVConnectorMetadata: + if not self._async_load: + base = super().build_connector_meta(scheduler_output) + meta = RogueSharedStorageConnectorMetadata.from_base(base) + else: + meta = RogueSharedStorageConnectorMetadata() + if self._requests_need_load: + for req_id, request in self._requests_need_load.items(): + meta.add_request( + token_ids=request.prompt_token_ids, + block_ids=self._req_to_block_ids[req_id], + block_size=self._block_size, + is_store=False, + mm_hashes=[], + ) + # Clear state + self._requests_need_load.clear() + meta.req_to_block_ids = self._req_to_block_ids + self._req_to_block_ids = dict() + return meta diff --git a/examples/offline_inference/kv_load_failure_recovery/run.sh b/examples/offline_inference/kv_load_failure_recovery/run.sh new file mode 100644 index 000000000..53fe2385d --- /dev/null +++ b/examples/offline_inference/kv_load_failure_recovery/run.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Constants +SHARED_STORAGE_DIR="local_storage" +PREFILL_OUTPUT="prefill_output.txt" +DECODE_OUTPUT="decode_output.txt" +SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt" +ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt" + +# Cleanup +rm -rf "$SHARED_STORAGE_DIR" +rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT" + +# Run inference examples +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure +VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure --async-load + +# Compare outputs +if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then + echo "❌ Outputs differ: sync recovery failed." + diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" + exit 1 +fi + +if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then + echo "❌ Outputs differ: async recovery failed." + diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT" + exit 1 +fi + +echo "✅ Outputs match: recovery successful." 
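[Editorial aside, not part of the diff above.] The `run.sh` script just shown verifies recovery by comparing the baseline decode output against the sync- and async-recovered outputs with `cmp`, printing a unified diff on mismatch. The sketch below expresses the same check with only the Python standard library; the file names are the ones written by `decode_example.py` and cleaned up by `run.sh` above, and the script itself is an illustration of that comparison step, not something added by this change.

```python
# Minimal, stdlib-only sketch of the baseline-vs-recovered comparison
# that run.sh performs with cmp/diff. File names match those produced
# by decode_example.py in this example directory.
import difflib
import sys
from pathlib import Path

BASELINE = "decode_output.txt"
RECOVERED = [
    "sync_decode_recovered_output.txt",
    "async_decode_recovered_output.txt",
]


def matches_baseline(baseline: str, recovered: str) -> bool:
    base_lines = Path(baseline).read_text(encoding="utf-8").splitlines(keepends=True)
    rec_lines = Path(recovered).read_text(encoding="utf-8").splitlines(keepends=True)
    if base_lines == rec_lines:
        return True
    # Show exactly where the recovered output diverges from the baseline.
    sys.stdout.writelines(
        difflib.unified_diff(base_lines, rec_lines, fromfile=baseline, tofile=recovered)
    )
    return False


if __name__ == "__main__":
    # Evaluate both recovered outputs so every mismatch is reported.
    results = [matches_baseline(BASELINE, recovered) for recovered in RECOVERED]
    ok = all(results)
    print("recovery successful" if ok else "recovery failed")
    sys.exit(0 if ok else 1)
```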
diff --git a/examples/offline_inference/llm_engine_example.py b/examples/offline_inference/llm_engine_example.py index d7f2a1633..d9215255a 100644 --- a/examples/offline_inference/llm_engine_example.py +++ b/examples/offline_inference/llm_engine_example.py @@ -8,7 +8,7 @@ import argparse from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def create_test_prompts() -> list[tuple[str, SamplingParams]]: diff --git a/examples/offline_inference/load_sharded_state.py b/examples/offline_inference/load_sharded_state.py index cc78c0cbb..52c2363c8 100644 --- a/examples/offline_inference/load_sharded_state.py +++ b/examples/offline_inference/load_sharded_state.py @@ -25,7 +25,7 @@ import dataclasses from vllm import LLM, EngineArgs, SamplingParams -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def parse_args(): diff --git a/examples/offline_inference/logits_processor.py b/examples/offline_inference/logits_processor/custom.py similarity index 71% rename from examples/offline_inference/logits_processor.py rename to examples/offline_inference/logits_processor/custom.py index 7ef20efa7..72e7ce24d 100644 --- a/examples/offline_inference/logits_processor.py +++ b/examples/offline_inference/logits_processor/custom.py @@ -33,8 +33,6 @@ class object. ------------------------------------------------------------ """ -from typing import Optional - import torch from vllm import LLM, SamplingParams @@ -42,8 +40,8 @@ class object. from vllm.v1.sample.logits_processor import ( BatchUpdate, LogitsProcessor, - MoveDirectionality, ) +from vllm.v1.sample.logits_processor.builtin import process_dict_updates # Hypothetical custom logits processor @@ -53,51 +51,33 @@ class DummyLogitsProcessor(LogitsProcessor): def __init__( self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool ): - self.req_info: dict[int, SamplingParams] = {} + self.req_info: dict[int, int] = {} def is_argmax_invariant(self) -> bool: - """Never impacts greedy sampling""" return False - def update_state(self, batch_update: Optional[BatchUpdate]): - if not batch_update: - return - - # Process added requests. - for index, params, _, _ in batch_update.added: - assert params is not None - if params.extra_args and ( - target_token := params.extra_args.get("target_token") - ): - self.req_info[index] = target_token - - if self.req_info: - # Process removed requests. - for index in batch_update.removed: - self.req_info.pop(index, None) - - # Process moved requests, unidirectional move (a->b) and swap - # (a<->b) - for adx, bdx, direct in batch_update.moved: - a_val = self.req_info.pop(adx, None) - b_val = self.req_info.pop(bdx, None) - if a_val is not None: - self.req_info[bdx] = a_val - if direct == MoveDirectionality.SWAP and b_val is not None: - self.req_info[adx] = b_val + def update_state(self, batch_update: BatchUpdate | None): + process_dict_updates( + self.req_info, + batch_update, + # This function returns the LP's per-request state based on the + # request details, or None if this LP does not apply to the + # request. 
+ lambda params, _, __: params.extra_args + and (params.extra_args.get("target_token")), + ) def apply(self, logits: torch.Tensor) -> torch.Tensor: if not self.req_info: return logits # Save target values before modification - rows_list = list(self.req_info.keys()) cols = torch.tensor( - [self.req_info[i] for i in rows_list], - dtype=torch.long, - device=logits.device, + list(self.req_info.values()), dtype=torch.long, device=logits.device + ) + rows = torch.tensor( + list(self.req_info.keys()), dtype=torch.long, device=logits.device ) - rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device) values_to_keep = logits[rows, cols].clone() # Mask all but target tokens diff --git a/examples/offline_inference/logits_processor/custom_req.py b/examples/offline_inference/logits_processor/custom_req.py new file mode 100644 index 000000000..87cd7473f --- /dev/null +++ b/examples/offline_inference/logits_processor/custom_req.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates wrapping a request-level logits processor to be +compatible with vLLM's batch-level logits processing + +For demo purposes, a dummy logits processor is employed which, if +`target_token` is passed as a keyword argument to `SamplingParams.extra_args`, +will mask out all tokens except `target_token`. This logits processor can be +applied to a vector of logits associated with a single decode step for a single +request. The logits processor cannot be applied to a request which does not +pass in a `target_token` custom argument. + +The request-level dummy logits processor is wrapped to create a batch-level +logits processor, which can apply the logits processor to output logits from +all requests in the persistent batch in a given decode step. For requests which +do not provide a `target_token` argument, the corresponding row of `logits` +will not be modified. + +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect the `target_token` to be decoded in each step, yielding an output +similar to that shown below: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. 
He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ +""" + +from typing import Any + +import torch + +from vllm import LLM, SamplingParams +from vllm.logger import init_logger +from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, + RequestLogitsProcessor, +) + +logger = init_logger(__name__) + + +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + +class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of wrapping a fake request-level logit processor to create a + batch-level logits processor""" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> RequestLogitsProcessor | None: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value. + + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + target_token: Any | None = params.extra_args and params.extra_args.get( + "target_token" + ) + if target_token is None: + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", + target_token, + ) + return None + return DummyPerReqLogitsProcessor(target_token) + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. + llm = LLM( + model="facebook/opt-125m", + logits_processors=[WrappedPerReqLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. 
+ print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/logits_processor/custom_req_init.py b/examples/offline_inference/logits_processor/custom_req_init.py new file mode 100644 index 000000000..3bb82a786 --- /dev/null +++ b/examples/offline_inference/logits_processor/custom_req_init.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates a special case of wrapping a request-level logits +processor, namely the case where it is necessary to utilize engine config or +environment info passed to the constructor. The subclass must override the +wrapper base class `__init__()` method to access the engine config, the device +identifier, or the flag which indicates whether pinned memory is available. + +For demo purposes, a request-level dummy logits processor is employed which +causes the same token (`target_token`) to be decoded in each step. The +request-level dummy logits processor is wrapped to create a batch-level logits +processor, which can apply the logits processor to output logits from all +requests in the persistent batch in a given decode step. + +The wrapped dummy logits processor below models a scenario where we must +disable the logits processor on non-"cuda" platforms. The wrapper base class +`__init__()` is overridden in order to check this condition and set a flag. + +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect that on a "cuda" device the output will look something like: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ + +which indicates that the logits processor is running. However, on a non-"cuda" +device, the first and third requests would not repeat the same token. 
+""" + +import torch + +from vllm import LLM, SamplingParams +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, + RequestLogitsProcessor, +) + +logger = init_logger(__name__) + + +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + +class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of overriding the wrapper class `__init__()` in order to utilize + info about the device type""" + + def __init__( + self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool + ): + super().__init__(vllm_config, device, is_pin_memory) + self.is_cuda = device.type == "cuda" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> RequestLogitsProcessor | None: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value, and the device + must be "cuda"-type + + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + if ( + not self.is_cuda + or ( + target_token := params.extra_args + and params.extra_args.get("target_token") + ) + is None + ): + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", + target_token, + ) + return None + return DummyPerReqLogitsProcessor(target_token) + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. + llm = LLM( + model="facebook/opt-125m", + logits_processors=[WrappedPerReqLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. 
+ print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/lora_with_quantization_inference.py b/examples/offline_inference/lora_with_quantization_inference.py index 00d4cb9eb..dc5c6202f 100644 --- a/examples/offline_inference/lora_with_quantization_inference.py +++ b/examples/offline_inference/lora_with_quantization_inference.py @@ -8,7 +8,6 @@ """ import gc -from typing import Optional import torch from huggingface_hub import snapshot_download @@ -19,7 +18,7 @@ def create_test_prompts( lora_path: str, -) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: +) -> list[tuple[str, SamplingParams, LoRARequest | None]]: return [ # this is an example of using quantization without LoRA ( @@ -56,7 +55,7 @@ def create_test_prompts( def process_requests( engine: LLMEngine, - test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]], + test_prompts: list[tuple[str, SamplingParams, LoRARequest | None]], ): """Continuously process a list of prompts and handle the outputs.""" request_id = 0 @@ -78,7 +77,7 @@ def process_requests( def initialize_engine( - model: str, quantization: str, lora_repo: Optional[str] + model: str, quantization: str, lora_repo: str | None ) -> LLMEngine: """Initialize the LLMEngine.""" diff --git a/examples/offline_inference/multilora_inference.py b/examples/offline_inference/multilora_inference.py index f0c00bcaa..6c23cf342 100644 --- a/examples/offline_inference/multilora_inference.py +++ b/examples/offline_inference/multilora_inference.py @@ -7,8 +7,6 @@ Requires HuggingFace credentials for access to Llama2. """ -from typing import Optional - from huggingface_hub import snapshot_download from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams @@ -17,13 +15,13 @@ def create_test_prompts( lora_path: str, -) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: +) -> list[tuple[str, SamplingParams, LoRARequest | None]]: """Create a list of test prompts with their sampling parameters. 2 requests for base model, 4 requests for the LoRA. We define 2 different LoRA adapters (using the same model for demo purposes). Since we also set `max_loras=1`, the expectation is that the requests - with the second LoRA adapter will be ran after all requests with the + with the second LoRA adapter will be run after all requests with the first adapter have finished. """ return [ @@ -68,7 +66,7 @@ def create_test_prompts( def process_requests( engine: LLMEngine, - test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]], + test_prompts: list[tuple[str, SamplingParams, LoRARequest | None]], ): """Continuously process a list of prompts and handle the outputs.""" request_id = 0 diff --git a/examples/offline_inference/neuron.py b/examples/offline_inference/neuron.py deleted file mode 100644 index 7826629a3..000000000 --- a/examples/offline_inference/neuron.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. 
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - - -def main(): - # Create an LLM. - llm = LLM( - model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - max_num_seqs=8, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. - # Currently, this is a known limitation in continuous batching support - # in transformers-neuronx. - # TODO(liangfu): Support paged-attention in transformers-neuronx. - max_model_len=1024, - block_size=1024, - # ruff: noqa: E501 - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. - device="neuron", - tensor_parallel_size=2, - ) - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - print("-" * 50) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py deleted file mode 100644 index 8b1d235ff..000000000 --- a/examples/offline_inference/neuron_eagle.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This example shows how to run offline inference with an EAGLE speculative -decoding model on neuron. To use EAGLE speculative decoding, you must use -a draft model that is specifically fine-tuned for EAGLE speculation. -Additionally, to use EAGLE with NxD Inference, the draft model must include -the LM head weights from the target model. These weights are shared between -the draft and target model. -""" - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "What is annapurna labs?", -] - - -def main(): - # Create a sampling params object. - sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True) - - # Create an LLM. - llm = LLM( - model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct", - speculative_config={ - "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", - "num_speculative_tokens": 5, - "max_model_len": 2048, - }, - max_num_seqs=4, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. - # Currently, this is a known limitation in continuous batching support - # in neuronx-distributed-inference. - max_model_len=2048, - block_size=2048, - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. - device="neuron", - tensor_parallel_size=32, - override_neuron_config={ - "enable_eagle_speculation": True, - "enable_fused_speculation": True, - }, - ) - - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. 
- for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, \n\n\n Generated text: {generated_text!r}") - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_int8_quantization.py b/examples/offline_inference/neuron_int8_quantization.py deleted file mode 100644 index c0ecfac50..000000000 --- a/examples/offline_inference/neuron_int8_quantization.py +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os - -from vllm import LLM, SamplingParams - -# creates XLA hlo graphs for all the context length buckets. -os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048" -# creates XLA hlo graphs for all the token gen buckets. -os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048" -# Quantizes neuron model weight to int8 , -# The default config for quantization is int8 dtype. -os.environ["NEURON_QUANT_DTYPE"] = "s8" - -# Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) - - -def main(): - # Create an LLM. - llm = LLM( - model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - max_num_seqs=8, - # The max_model_len and block_size arguments are required to be same as - # max sequence length when targeting neuron device. - # Currently, this is a known limitation in continuous batching support - # in transformers-neuronx. - # TODO(liangfu): Support paged-attention in transformers-neuronx. - max_model_len=2048, - block_size=2048, - # ruff: noqa: E501 - # The device can be automatically detected when AWS Neuron SDK is installed. - # The device argument can be either unspecified for automated detection, - # or explicitly assigned. - device="neuron", - quantization="neuron_quant", - override_neuron_config={ - "cast_logits_dtype": "bfloat16", - }, - tensor_parallel_size=2, - ) - # Generate texts from the prompts. The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - print("-" * 50) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_multimodal.py b/examples/offline_inference/neuron_multimodal.py deleted file mode 100644 index 26f7505f2..000000000 --- a/examples/offline_inference/neuron_multimodal.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import requests -import torch -from neuronx_distributed_inference.models.mllama.utils import add_instruct -from PIL import Image - -from vllm import LLM, SamplingParams, TextPrompt - - -def get_image(image_url): - image = Image.open(requests.get(image_url, stream=True).raw) - return image - - -# Model Inputs -PROMPTS = [ - "What is in this image? 
Tell me a story", - "What is the recipe of mayonnaise in two sentences?", - "Describe this image", - "What is the capital of Italy famous for?", -] -IMAGES = [ - get_image( - "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" - ), - None, - get_image( - "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" - ), - None, -] -SAMPLING_PARAMS = [ - dict(top_k=1, temperature=1.0, top_p=1.0, max_tokens=16) - for _ in range(len(PROMPTS)) -] - - -def get_VLLM_mllama_model_inputs(prompt, single_image, sampling_params): - # Prepare all inputs for mllama generation, including: - # 1. put text prompt into instruct chat template - # 2. compose single text and single image prompt into Vllm's prompt class - # 3. prepare sampling parameters - input_image = single_image - has_image = torch.tensor([1]) - if isinstance(single_image, torch.Tensor) and single_image.numel() == 0: - has_image = torch.tensor([0]) - - instruct_prompt = add_instruct(prompt, has_image) - inputs = TextPrompt(prompt=instruct_prompt) - - if input_image is not None: - inputs["multi_modal_data"] = {"image": input_image} - - sampling_params = SamplingParams(**sampling_params) - return inputs, sampling_params - - -def print_outputs(outputs): - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -def main(): - assert ( - len(PROMPTS) == len(IMAGES) == len(SAMPLING_PARAMS) - ), f"""Text, image prompts and sampling parameters should have the - same batch size; but got {len(PROMPTS)}, {len(IMAGES)}, - and {len(SAMPLING_PARAMS)}""" - - # Create an LLM. - llm = LLM( - model="meta-llama/Llama-3.2-11B-Vision-Instruct", - max_num_seqs=1, - max_model_len=4096, - block_size=4096, - device="neuron", - tensor_parallel_size=32, - override_neuron_config={ - "sequence_parallel_enabled": False, - "skip_warmup": True, - "save_sharded_checkpoint": True, - "on_device_sampling_config": { - "global_topk": 1, - "dynamic": False, - "deterministic": False, - }, - }, - ) - - batched_inputs = [] - batched_sample_params = [] - for pmpt, img, params in zip(PROMPTS, IMAGES, SAMPLING_PARAMS): - inputs, sampling_params = get_VLLM_mllama_model_inputs(pmpt, img, params) - # test batch-size = 1 - outputs = llm.generate(inputs, sampling_params) - print_outputs(outputs) - batched_inputs.append(inputs) - batched_sample_params.append(sampling_params) - - # test batch-size = 4 - outputs = llm.generate(batched_inputs, batched_sample_params) - print_outputs(outputs) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py deleted file mode 100644 index 7fc22caee..000000000 --- a/examples/offline_inference/neuron_speculation.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This example shows how to run offline inference with a speculative -decoding model on neuron. -""" - -import os - -from vllm import LLM, SamplingParams - -# Sample prompts. -prompts = [ - "Hello, I am a language model and I can help", - "The president of the United States is", - "The capital of France is", -] - - -def config_buckets(): - """Configure context length and token gen buckets.""" - # creates XLA hlo graphs for all the context length buckets. 
- os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048" - # creates XLA hlo graphs for all the token gen buckets. - os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048" - - -def initialize_llm(): - """Create an LLM with speculative decoding.""" - return LLM( - model="openlm-research/open_llama_7b", - speculative_config={ - "model": "openlm-research/open_llama_3b", - "num_speculative_tokens": 4, - "max_model_len": 2048, - }, - max_num_seqs=4, - max_model_len=2048, - block_size=2048, - device="neuron", - tensor_parallel_size=32, - ) - - -def process_requests(llm: LLM, sampling_params: SamplingParams): - """Generate texts from prompts and print them.""" - outputs = llm.generate(prompts, sampling_params) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -def main(): - """Main function that sets up the llm and processes prompts.""" - config_buckets() - llm = initialize_llm() - # Create a sampling params object. - sampling_params = SamplingParams(max_tokens=100, top_k=1) - process_requests(llm, sampling_params) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/openai_batch/README.md b/examples/offline_inference/openai_batch/README.md index 3c6f6c7a6..7d5a1af8f 100644 --- a/examples/offline_inference/openai_batch/README.md +++ b/examples/offline_inference/openai_batch/README.md @@ -152,7 +152,9 @@ def generate_presigned_url(s3_client, client_method, method_parameters, expires_ """ try: url = s3_client.generate_presigned_url( - ClientMethod=client_method, Params=method_parameters, ExpiresIn=expires_in + ClientMethod=client_method, + Params=method_parameters, + ExpiresIn=expires_in, ) except ClientError: raise @@ -161,10 +163,16 @@ def generate_presigned_url(s3_client, client_method, method_parameters, expires_ s3_client = boto3.client("s3") input_url = generate_presigned_url( - s3_client, "get_object", {"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"}, 3600 + s3_client, + "get_object", + {"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"}, + expires_in=3600, ) output_url = generate_presigned_url( - s3_client, "put_object", {"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"}, 3600 + s3_client, + "put_object", + {"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"}, + expires_in=3600, ) print(f"{input_url=}") print(f"{output_url=}") diff --git a/examples/offline_inference/pooling/README.md b/examples/offline_inference/pooling/README.md new file mode 100644 index 000000000..ad78be387 --- /dev/null +++ b/examples/offline_inference/pooling/README.md @@ -0,0 +1,57 @@ +# Pooling models + +## Convert llm model to seq cls + +```bash +# for BAAI/bge-reranker-v2-gemma +# Caution: "Yes" and "yes" are two different tokens +python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls +# for mxbai-rerank-v2 +python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls +# for Qwen3-Reranker +python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls +``` + +## Embed jina_embeddings_v3 
usage + +Only text matching task is supported for now. See + +```bash +python examples/offline_inference/pooling/embed_jina_embeddings_v3.py +``` + +## Embed matryoshka dimensions usage + +```bash +python examples/offline_inference/pooling/embed_matryoshka_fy.py +``` + +## Multi vector retrieval usage + +```bash +python examples/offline_inference/pooling/multi_vector_retrieval.py +``` + +## Named Entity Recognition (NER) usage + +```bash +python examples/offline_inference/pooling/ner.py +``` + +## Prithvi Geospatial MAE usage + +```bash +python examples/offline_inference/pooling/prithvi_geospatial_mae.py +``` + +## IO Processor Plugins for Prithvi Geospatial MAE + +```bash +python examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py +``` + +## Qwen3 reranker usage + +```bash +python examples/offline_inference/pooling/qwen3_reranker.py +``` diff --git a/examples/offline_inference/convert_model_to_seq_cls.py b/examples/offline_inference/pooling/convert_model_to_seq_cls.py similarity index 100% rename from examples/offline_inference/convert_model_to_seq_cls.py rename to examples/offline_inference/pooling/convert_model_to_seq_cls.py diff --git a/examples/offline_inference/embed_jina_embeddings_v3.py b/examples/offline_inference/pooling/embed_jina_embeddings_v3.py similarity index 96% rename from examples/offline_inference/embed_jina_embeddings_v3.py rename to examples/offline_inference/pooling/embed_jina_embeddings_v3.py index 33a63deee..b117b0bd5 100644 --- a/examples/offline_inference/embed_jina_embeddings_v3.py +++ b/examples/offline_inference/pooling/embed_jina_embeddings_v3.py @@ -4,7 +4,7 @@ from argparse import Namespace from vllm import LLM, EngineArgs -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def parse_args(): diff --git a/examples/offline_inference/embed_matryoshka_fy.py b/examples/offline_inference/pooling/embed_matryoshka_fy.py similarity index 96% rename from examples/offline_inference/embed_matryoshka_fy.py rename to examples/offline_inference/pooling/embed_matryoshka_fy.py index 6871bcfcc..6544df852 100644 --- a/examples/offline_inference/embed_matryoshka_fy.py +++ b/examples/offline_inference/pooling/embed_matryoshka_fy.py @@ -4,7 +4,7 @@ from argparse import Namespace from vllm import LLM, EngineArgs, PoolingParams -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def parse_args(): diff --git a/examples/offline_inference/pooling/multi_vector_retrieval.py b/examples/offline_inference/pooling/multi_vector_retrieval.py new file mode 100644 index 000000000..fa7d1c3ba --- /dev/null +++ b/examples/offline_inference/pooling/multi_vector_retrieval.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from argparse import Namespace + +from vllm import LLM, EngineArgs +from vllm.utils.argparse_utils import FlexibleArgumentParser + + +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults( + model="BAAI/bge-m3", + runner="pooling", + enforce_eager=True, + ) + return parser.parse_args() + + +def main(args: Namespace): + # Sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # Create an LLM. 
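+    # This example produces both one sentence embedding per prompt (llm.embed)
+    # and per-token embeddings for multi-vector retrieval
+    # (llm.encode with pooling_task="token_embed"), as shown below.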
+ # You should pass runner="pooling" for embedding models + llm = LLM(**vars(args)) + + # Generate embedding. The output is a list of EmbeddingRequestOutputs. + outputs = llm.embed(prompts) + + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for prompt, output in zip(prompts, outputs): + embeds = output.outputs.embedding + print(len(embeds)) + + # Generate embedding for each token. The output is a list of PoolingRequestOutput. + outputs = llm.encode(prompts, pooling_task="token_embed") + + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for prompt, output in zip(prompts, outputs): + multi_vector = output.outputs.data + print(multi_vector.shape) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/offline_inference/pooling/ner.py b/examples/offline_inference/pooling/ner.py new file mode 100644 index 000000000..34c80e7cc --- /dev/null +++ b/examples/offline_inference/pooling/ner.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER + +from argparse import Namespace + +from vllm import LLM, EngineArgs +from vllm.utils.argparse_utils import FlexibleArgumentParser + + +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults( + model="boltuix/NeuroBERT-NER", + runner="pooling", + enforce_eager=True, + trust_remote_code=True, + ) + return parser.parse_args() + + +def main(args: Namespace): + # Sample prompts. + prompts = [ + "Barack Obama visited Microsoft headquarters in Seattle on January 2025." + ] + + # Create an LLM. + llm = LLM(**vars(args)) + tokenizer = llm.get_tokenizer() + label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label + + # Run inference + outputs = llm.encode(prompts, pooling_task="token_classify") + + for prompt, output in zip(prompts, outputs): + logits = output.outputs.data + predictions = logits.argmax(dim=-1) + + # Map predictions to labels + tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids) + labels = [label_map[p.item()] for p in predictions] + + # Print results + for token, label in zip(tokens, labels): + if token not in tokenizer.all_special_tokens: + print(f"{token:15} → {label}") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/pooling/prithvi_geospatial_mae.py similarity index 97% rename from examples/offline_inference/prithvi_geospatial_mae.py rename to examples/offline_inference/pooling/prithvi_geospatial_mae.py index b6007b9f4..b093c77c0 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/pooling/prithvi_geospatial_mae.py @@ -3,7 +3,6 @@ import argparse import datetime import os -from typing import Union import albumentations import numpy as np @@ -45,7 +44,12 @@ class PrithviMAE: def __init__(self, model): self.model = LLM( - model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True + model=model, + skip_tokenizer_init=True, + dtype="float16", + enforce_eager=True, + model_impl="terratorch", + enable_mm_embeds=True, ) def run(self, input_data, location_coords): @@ -60,7 +64,7 @@ def run(self, input_data, location_coords): } prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} - outputs = self.model.encode(prompt, use_tqdm=False) + outputs = 
self.model.encode(prompt, pooling_task="plugin", use_tqdm=False) return outputs[0].outputs.data @@ -156,7 +160,7 @@ def load_example( file_paths: list[str], mean: list[float] = None, std: list[float] = None, - indices: Union[list[int], None] = None, + indices: list[int] | None = None, ): """Build an input example by loading images in *file_paths*. diff --git a/examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py new file mode 100644 index 000000000..b8637b89e --- /dev/null +++ b/examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import base64 +import os + +import torch + +from vllm import LLM + +# This example shows how to perform an offline inference that generates +# multimodal data. In this specific case this example will take a geotiff +# image as input, process it using the multimodal data processor, and +# perform inference. +# Requirements: +# - install TerraTorch v1.1 (or later): +# pip install terratorch>=v1.1 + + +def main(): + torch.set_default_dtype(torch.float16) + image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 + + img_prompt = dict( + data=image_url, + data_format="url", + image_format="tiff", + out_data_format="b64_json", + ) + + llm = LLM( + model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", + skip_tokenizer_init=True, + trust_remote_code=True, + enforce_eager=True, + # Limit the maximum number of parallel requests + # to avoid the model going OOM. + # The maximum number depends on the available GPU memory + max_num_seqs=32, + io_processor_plugin="terratorch_segmentation", + model_impl="terratorch", + enable_mm_embeds=True, + ) + + pooler_output = llm.encode(img_prompt, pooling_task="plugin") + output = pooler_output[0].outputs + + print(output) + decoded_data = base64.b64decode(output.data) + + file_path = os.path.join(os.getcwd(), "offline_prediction.tiff") + with open(file_path, "wb") as f: + f.write(decoded_data) + + print(f"Output file path: {file_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/pooling/qwen3_reranker.py similarity index 100% rename from examples/offline_inference/qwen3_reranker.py rename to examples/offline_inference/pooling/qwen3_reranker.py diff --git a/examples/offline_inference/profiling.py b/examples/offline_inference/profiling.py deleted file mode 100644 index 392fba8fc..000000000 --- a/examples/offline_inference/profiling.py +++ /dev/null @@ -1,510 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import inspect -import json -import os -import sys -from argparse import RawTextHelpFormatter -from collections.abc import Generator -from dataclasses import asdict, dataclass -from typing import Any, Optional, TypeAlias - -import torch -import tqdm - -from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import EngineArgs -from vllm.profiler.layerwise_profile import layerwise_profile -from vllm.utils import FlexibleArgumentParser - -BATCH_SIZE_DEFAULT = 1 -PROMPT_LEN_DEFAULT = 256 - - -@dataclass -class ProfileContext: - engine_args: EngineArgs - prompt_len: int - batch_size: int - - # The profiler can run in 2 modes, - # 1. 
Run profiler for user specified num_steps - num_steps: Optional[int] = None - # 2. Run profiler until all requests complete - complete_num_requests_per_step: Optional[int] = None - - save_chrome_traces_folder: Optional[str] = None - - -def get_dtype(dtype: str): - if dtype == "torch.float": - return torch.float - else: - return dtype - - -OutputLen_NumReqs_Map: TypeAlias = dict[int, int] - - -def compute_request_output_lengths( - batch_size: int, step_requests: list[int] -) -> OutputLen_NumReqs_Map: - """ - Given the number of requests, batch_size, and the number of requests - that each engine-step should process, step_requests, determine the - output lengths of the requests such that step_request is honoured. - - Example: - if batch size = 128 and step_request = [128, 128, 96, 64, 32, 1] - then return, - {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning, - 32 requests should have output length 2, - 32 requests should have output length 3, - 32 requests should have output length 4, - 31 requests should have output length 5, - 1 request should have output length 6. - - Args: - batch_size (int): Number of requests submitted for profile. This is - args.batch_size. - step_requests (list[int]): step_requests[i] is the number of requests - that the ith engine step should process. - - Returns: - OutputLen_NumReqs_Map : A dictionary with output-length as keys and the - number of requests required to have that output-length as values. - """ - ol_nr: OutputLen_NumReqs_Map = {} - - # Number of request that are assigned an output-length - num_reqs_assigned: int = 0 - num_steps: int = len(step_requests) - - # sanity check. The first step (prefill-step), must process all requests. - assert step_requests[0] == batch_size - - # Begin assignments from the last step. - output_length: int = num_steps - for num_requests_at_step in reversed(step_requests): - if num_reqs_assigned == batch_size: - break - - assert num_reqs_assigned < batch_size - - # Remove the number of requests that have been determined - # to participate in this step and beyond. - num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned - assert num_reqs_unassigned_at_step >= 0 - - if num_reqs_unassigned_at_step > 0: - ol_nr[output_length] = num_reqs_unassigned_at_step - num_reqs_assigned += num_reqs_unassigned_at_step - - output_length -= 1 - - # sanity checks. - assert sum(ol_nr.values()) == batch_size, ( - "Number of requests in output-length assignment does not match " - f"batch-size.\n batch size {batch_size} - " - f"step requests {step_requests} - assignments {ol_nr}" - ) - - # Check that the output-length is in [1, num-steps]. Output length must be - # at least 1 as all requests must participate in the prefill-step. - assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), ( - "Output lengths of requests should be in range " - f"[1, num-engine-steps].\n batch size {batch_size} - " - f"step requests {step_requests} - assignments {ol_nr}" - ) - - return ol_nr - - -def determine_requests_per_step(context: ProfileContext) -> list[int]: - """ - Determine number of requests each engine step should process. - If context.num_steps is set, then all engine steps process the - same number of requests and the output list is of length - context.num_steps. - - If context.complete_num_requests_per_step is set, then each decode step - processes fewer and fewer requests until there are no requests to process. - In this case, the output list is as big as the number of steps - required to process all requests. 
- - Args: - context: ProfileContext object. - - Returns: - list[int]: Number of requests to process for all engine-steps. - output[i], contains the number of requests that the ith step - should process. - """ - if context.num_steps: - # All requests must run until num_engine_steps. This implies - # that their output lengths must be equal to num_engine_steps. - return [context.batch_size] * context.num_steps - - assert ( - context.complete_num_requests_per_step - and context.complete_num_requests_per_step > 0 - ), ( - f"Expected a positive complete_num_requests_per_step argument." - f"Instead got {context.complete_num_requests_per_step}" - ) - - # We start dropping after the first decode step. - step_requests = [ - context.batch_size, # prefill - context.batch_size, # decode - ] - - num_running_requests = context.batch_size - num_running_requests -= context.complete_num_requests_per_step - while num_running_requests > 0: - step_requests.append(num_running_requests) - num_running_requests -= context.complete_num_requests_per_step - - if step_requests[-1] != 1: - # have 1 request running at the last step. This is often - # useful - step_requests.append(1) - - return step_requests - - -def run_profile( - context: ProfileContext, csv_output: Optional[str], json_output: Optional[str] -): - print("Run profile with:") - for key, value in asdict(context).items(): - print(f" {key} = {value}") - - requests_per_step: list[int] = determine_requests_per_step(context) - - ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths( - context.batch_size, requests_per_step - ) - - num_steps_to_profile: int = len(requests_per_step) - max_output_len: int = max(ol_nr.keys()) - assert max_output_len >= 1 - - # Create sampling params - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - # max_tokens is set on a per-request basis. 
- max_tokens=None, - ignore_eos=True, - ) - - # Create LLM - llm = LLM(**asdict(context.engine_args)) - batch_size = context.batch_size - prompt_len = context.prompt_len - - scheduler_config = llm.llm_engine.vllm_config.scheduler_config - max_model_len = llm.llm_engine.model_config.max_model_len - max_num_batched_tokens = scheduler_config.max_num_batched_tokens - max_num_seqs = scheduler_config.max_num_seqs - - if batch_size * prompt_len > max_num_batched_tokens: - print( - f"ERROR: chosen batch_size * prompt_len " - f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " - f"larger than max_num_batched_tokens ({max_num_batched_tokens}) " - f"and therefore cannot be run in a single profile step, please " - f"choose a smaller batch size or prompt length, or increase " - f"--max-num-batched-tokens" - ) - sys.exit(-1) - if batch_size > max_num_seqs: - print( - f"ERROR: chosen batch_size ({batch_size}) is larger than " - f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " - f"single profile step, please choose a smaller batch size" - ) - sys.exit(-1) - print( - "llm.llm_engine.model_config.max_model_len: ", - llm.llm_engine.model_config.max_model_len, - ) - if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len: - print( - f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + " - f"{max_output_len} = {prompt_len + max_output_len}) is larger " - f"than the model's max_model_len ({max_model_len}), please " - f"choose a smaller prompt_len or max_output_len, or increase " - f"--max-model-len" - ) - sys.exit(-1) - - def add_requests(): - def get_output_len_generator() -> Generator[int, Any, Any]: - for output_len, num_reqs in ol_nr.items(): - for _ in range(num_reqs): - yield output_len - - output_len_generator = get_output_len_generator() - for i in range(batch_size): - sampling_params.max_tokens = next(output_len_generator) - assert isinstance(sampling_params.max_tokens, int) - - prompt_token_ids = torch.randint( - llm.get_tokenizer().vocab_size, size=(prompt_len,) - ).tolist() - - llm.llm_engine.add_request( - request_id=f"seq{i}", - prompt={"prompt_token_ids": prompt_token_ids}, - params=sampling_params, - ) - - def abort_requests(): - for i in range(batch_size): - llm.llm_engine.abort_request(f"seq{i}") - - # Warm up run - print("Warm up run ...") - add_requests() - llm.llm_engine.step() # Prefill - llm.llm_engine.step() # Decode - abort_requests() - - print("Profile run ...") - add_requests() - - with layerwise_profile() as prefill_prof: - llm.llm_engine.step() # First step is prefill - - decode_profs = [] - for _ in tqdm.tqdm(range(num_steps_to_profile - 1)): - num_running_seqs = llm.llm_engine.scheduler[0].get_num_unfinished_seq_groups() - with layerwise_profile(num_running_seqs=num_running_seqs) as decode_prof: - llm.llm_engine.step() - decode_profs.append(decode_prof) - - decode_results_list = [prof.results for prof in decode_profs] - prefill_results = prefill_prof.results - has_decode = len(decode_results_list) > 0 - - LINE_WIDTH = 80 - print("=" * LINE_WIDTH) - print(f"= Prefill Model Table (prompt_len={prompt_len}, batch_size={batch_size})") - print("=" * LINE_WIDTH) - print() - prefill_results.print_model_table() - - if has_decode: - print() - print("=" * LINE_WIDTH) - print( - f"= First Decode Step Model Table " - f"(prompt_len={prompt_len}, batch_size={batch_size})" - ) - print("=" * LINE_WIDTH) - print() - decode_results_list[0].print_model_table() - - print() - print("=" * LINE_WIDTH) - print(f"= Prefill Summary Table 
(prompt_len={prompt_len}, batch_size={batch_size})") - print("=" * LINE_WIDTH) - print() - prefill_results.print_summary_table() - - if has_decode: - print() - print("=" * LINE_WIDTH) - print( - f"= First Decode Step Summary Table " - f"(prompt_len={prompt_len}, batch_size={batch_size})" - ) - print("=" * LINE_WIDTH) - print() - decode_results_list[0].print_summary_table() - - if csv_output: - csv_filename_base = ( - csv_output[:-4] if csv_output.endswith(".csv") else csv_output - ) - prefill_results.export_model_stats_table_csv( - csv_filename_base + "_prefill_model_table.csv" - ) - prefill_results.export_summary_stats_table_csv( - csv_filename_base + "_prefill_summary_table.csv" - ) - - if has_decode: - decode_results_list[0].export_model_stats_table_csv( - csv_filename_base + "_decode_model_table.csv" - ) - decode_results_list[0].export_summary_stats_table_csv( - csv_filename_base + "_decode_summary_table.csv" - ) - - if json_output: - cuda_devices = [ - torch.cuda.get_device_properties(dev_idx) - for dev_idx in range(torch.cuda.device_count()) - ] - - json_dict = { - "context": { - "python_version": f"{sys.version}", - "torch_version": f"{torch.__version__}", - "torch_cuda_version": f"{torch.version.cuda}", - "cuda_devices": f"{cuda_devices}", - **asdict(context), - }, - "prefill": prefill_results.convert_stats_to_dict(), - } - - if has_decode: - for idx, dr in enumerate(decode_results_list): - json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() - - # Add .json to json_output filename if it doesn't exist already. - json_output_file = ( - json_output if json_output.endswith(".json") else json_output + ".json" - ) - with open(json_output_file, "w+") as f: - json.dump(json_dict, f, indent=2) - pass - - if context.save_chrome_traces_folder is not None: - os.makedirs(context.save_chrome_traces_folder, exist_ok=True) - prefill_prof.profiler.export_chrome_trace( - context.save_chrome_traces_folder + "/prefill.json" - ) - for idx, decode_prof in enumerate(decode_profs): - decode_prof.profiler.export_chrome_trace( - context.save_chrome_traces_folder + f"/decode_{idx + 1}.json" - ) - print( - "Traces saved as prefill.json and decode_1.json, etc." - f" in folder {context.save_chrome_traces_folder}" - ) - - -def parse_args(): - parser = FlexibleArgumentParser( - description=""" -Profile a model - - example: - ``` - python examples/offline_inference/profiling.py \\ - --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\ - --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\ - --enforce-eager run_num_steps -n 2 - ``` - - then you can use various tools to analyze the json output - terminal ascii tables: - ``` - python tools/profiler/print_layerwise_table.py \\ - --json-trace Llama31-8b-FP8.json --phase prefill --table summary - ``` - or create matplotlib stacked bar charts: - ``` - python tools/profiler/visualize_layerwise_profile.py \\ - --json-trace Llama31-8b-FP8.json \\ - --output-directory profile_breakdown --plot-metric pct_cuda_time - ``` -""", - formatter_class=RawTextHelpFormatter, - ) - parser.add_argument( - "--csv", - type=str, - default=None, - help="Export the results as multiple csv file. This should be the root " - "filename, will create _prefill_model_table.csv, " - "_prefill_summary_table.csv, " - "_decode_model_table.csv, and " - "_decode_summary_table.csv", - ) - parser.add_argument( - "--json", - type=str, - default=None, - help="Export the results as a json file. 
This should be the filename", - ) - parser.add_argument( - "--save-chrome-traces-folder", - type=str, - help="Save chrome traces for the prefill and decode " - "will save traces as prefill.json and decode_1.json, " - "etc. inside this folder", - ) - parser.add_argument( - "--prompt-len", - type=int, - default=PROMPT_LEN_DEFAULT, - help=f"Length of the random prompt to use when profiling, all batched " - f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}", - ) - parser.add_argument( - "--batch-size", - type=int, - default=BATCH_SIZE_DEFAULT, - help=f"Number of requests to run as a single batch, " - f"default={BATCH_SIZE_DEFAULT}", - ) - - subparsers = parser.add_subparsers(dest="cmd") - - run_num_steps_parser = subparsers.add_parser( - "run_num_steps", help="This variation profiles n engine.step() invocations." - ) - run_num_steps_parser.add_argument( - "-n", - "--num-steps", - type=int, - help="Number of engine steps to profile.\n" - "Setting it to 1, profiles only the prefill step.\n" - "Setting it to 2, profiles the prefill and first decode step\n" - "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n" - "and so on ...", - ) - - run_to_completion_parser = subparsers.add_parser( - "run_to_completion", - help="This variation profiles all the engine.step() invocations" - "until the engine exhausts all submitted requests.", - ) - run_to_completion_parser.add_argument( - "-n", - "--complete-num-requests-per-step", - type=int, - help="Complete complete_num_requests_per_step requests every decode step." - "For e.g., with batch_size 128 and complete_num_requests_per_step 32," - "the profiler is run for 6 engine steps, with the steps processing, " - "128, 128, 96, 64, 32, 1 requests respectively.\n" - "Note that we tack-on a one-request step at the end as it is often " - "useful.", - ) - - EngineArgs.add_cli_args(parser) - - return parser.parse_args() - - -def main(args): - context = ProfileContext( - engine_args=EngineArgs.from_cli_args(args), - **{ - k: v - for k, v in vars(args).items() - if k in inspect.signature(ProfileContext).parameters - }, - ) - run_profile(context, csv_output=args.csv, json_output=args.json) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/examples/offline_inference/profiling_tpu/profiling.py b/examples/offline_inference/profiling_tpu/profiling.py index dfcbd8c8d..3b127e4fd 100644 --- a/examples/offline_inference/profiling_tpu/profiling.py +++ b/examples/offline_inference/profiling_tpu/profiling.py @@ -13,7 +13,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser DURATION_MS = int(os.getenv("VLLM_TPU_PROFILE_DURATION_MS", 3000)) DELAY_MS = int(os.getenv("VLLM_TPU_PROFILE_DELAY_MS", 0)) diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py index 62effd5c8..6fbe1303f 100644 --- a/examples/offline_inference/qwen2_5_omni/only_thinker.py +++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py @@ -13,7 +13,7 @@ from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset from vllm.multimodal.image import convert_image_mode -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser class QueryResult(NamedTuple): diff --git a/examples/offline_inference/qwen_1m.py 
b/examples/offline_inference/qwen_1m.py index d8d61667f..c8d0d91ce 100644 --- a/examples/offline_inference/qwen_1m.py +++ b/examples/offline_inference/qwen_1m.py @@ -5,7 +5,6 @@ from vllm import LLM, SamplingParams -os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN" os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py index ed974b90b..0c09e6032 100644 --- a/examples/offline_inference/rlhf.py +++ b/examples/offline_inference/rlhf.py @@ -38,7 +38,7 @@ from transformers import AutoModelForCausalLM from vllm import LLM, SamplingParams -from vllm.utils import get_ip, get_open_port +from vllm.utils.network_utils import get_ip, get_open_port class MyLLM(LLM): diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 65621023a..360fd79b5 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -28,12 +28,15 @@ https://docs.ray.io/en/latest/placement-groups.html """ +import gc import os import ray import torch +import zmq from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from torch.multiprocessing.reductions import reduce_tensor from vllm import LLM @@ -86,20 +89,72 @@ def __init__(self): from vllm.platforms import current_platform self.device_uuid = current_platform.get_device_uuid(0) + self.zmq_context = zmq.Context() + self.zmq_address_counter = 0 + self.zmq_handle = None def report_device_id(self) -> str: return self.device_uuid - def get_weight_ipc_handles(self): - from torch.multiprocessing.reductions import reduce_tensor + def get_zmq_handles(self) -> dict[str, str]: + suffix = f"{self.device_uuid}-{self.zmq_address_counter}" + self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock" + self.zmq_address_counter += 1 + return {self.device_uuid: self.zmq_handle} - data = {} - for name, p in self.model.named_parameters(): - # A training actor might hold only a subset of the weights and may - # need to gather weights from other actors. For demonstration - # purposes, each training actor owns the full weight set. 
- data[name] = reduce_tensor(p.detach()) - return {self.device_uuid: data} + def update_weights(self): + # align size to avoid misaligned address + align_size = 256 + + def get_size(p: torch.Tensor) -> int: + return (p.nbytes + align_size - 1) // align_size * align_size + + named_parameters: dict[str, torch.nn.Parameter] = dict( + self.model.named_parameters() + ) + max_tensor_size = max(get_size(p) for p in named_parameters.values()) + # use max_tensor_size * 2 as buffer size + buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0") + s = self.zmq_context.socket(zmq.REQ) + s.bind(self.zmq_handle) + handle = reduce_tensor(buffer) + + offset = 0 + buckets: list[tuple[list[dict], list[torch.Tensor]]] = [] + named_tensors: list[dict] = [] + real_tensors: list[torch.Tensor] = [] + for name, p in named_parameters.items(): + size = get_size(p) + if offset + size > buffer.numel(): + buckets.append((named_tensors, real_tensors)) + named_tensors, real_tensors = [], [] + offset = 0 + # assume tensors are contiguous + named_tensors.append( + {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset} + ) + real_tensors.append(p) + offset += size + if named_tensors: + buckets.append((named_tensors, real_tensors)) + s.send_pyobj(handle) + s.recv() + for named_tensors, real_tensors in buckets: + offset = 0 + for p in real_tensors: + buffer[offset : offset + p.nbytes].data.copy_( + p.data.view(-1).view(dtype=torch.uint8), non_blocking=True + ) + offset += get_size(p) + torch.cuda.synchronize() + s.send_pyobj(named_tensors) + s.recv() + s.send_pyobj(None) + s.recv() + s.close() + del buffer + gc.collect() + torch.cuda.empty_cache() # Ray manages four GPUs. @@ -175,18 +230,22 @@ def get_weight_ipc_handles(self): # the second inference engine. 
assert training_actor_device_ids[2:] == inference_engine_device_ids[1] -print("Gather all the IPC handles from the training actors.") -ipc_handles = {} +print("Gather all the ZMQ handles from the training actors.") +zmq_handles = {} for actor in training_actors: - ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote())) + zmq_handles.update(ray.get(actor.get_zmq_handles.remote())) + +print(f"ZMQ handles: {zmq_handles}") print("Update the weights of the inference engines.") -for llm in inference_engines: - ray.get( - llm.collective_rpc.remote( - "update_weights_from_ipc_handles", args=(ipc_handles,) - ) - ) +ray.get( + [actor.update_weights.remote() for actor in training_actors] + + [ + llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,)) + for llm in inference_engines + ] +) + print("Check if the weights are updated.") for llm in inference_engines: assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple())) diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index d2a8419ff..13def8843 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -1,6 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc +from collections.abc import Callable +from typing import TypedDict + import torch +import zmq def stateless_init_process_group(master_address, master_port, rank, world_size, device): @@ -66,6 +71,27 @@ def check_weights_changed(self): return weights_updated +def rebuild_ipc( + handle: tuple[Callable, tuple], device_id: int | None = None +) -> torch.Tensor: + func, args = handle + list_args = list(args) + if device_id is not None: + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + buffer = func(*list_args) + return buffer + + +class FlattenedTensorMetadata(TypedDict): + name: str + shape: torch.Size + dtype: torch.dtype + # specify the start offset of this tensor in shared ipc_buffer tensor + offset: int + + class ColocateWorkerExtension: """ The class for vLLM's worker to inherit from, in the colocate setting. @@ -76,27 +102,62 @@ class ColocateWorkerExtension: should pass the full qualified name as `worker_extension_cls` argument. """ + def update_weights_from_ipc(self, zmq_handles: dict[str, str]): + from vllm.model_executor.model_loader.utils import process_weights_after_loading + + assert self.device is not None + if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None: + self._zmq_ctx = zmq.Context() + socket = self._zmq_ctx.socket(zmq.REP) + socket.connect(zmq_handles[self.report_device_id()]) + buffer: torch.Tensor | None = None + while True: + payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = ( + socket.recv_pyobj() + ) + if payload is None: + # means the update is done + process_weights_after_loading( + self.model_runner.model, self.model_config, self.device + ) + torch.cuda.synchronize() + socket.send(b"") + break + if isinstance(payload, tuple): + # an ipc handle that vLLM can use `func, args = handle` + # and `func(*args)` to rebuild GPU tensor. 
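+                # The rebuilt tensor is the trainer's shared staging buffer; the
+                # metadata lists that follow describe which byte ranges of this
+                # buffer hold which named weights.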
+ buffer = rebuild_ipc(payload, self.device.index) + assert buffer.dtype == torch.uint8 + socket.send(b"") + continue + assert isinstance(payload, list) + assert buffer is not None + weights = [] + for item in payload: + shape = item["shape"] + if isinstance(shape, (list, tuple)): + shape = torch.Size(shape) + assert isinstance(shape, torch.Size) + dtype, offset = item["dtype"], item["offset"] + size = dtype.itemsize * shape.numel() + tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) + weights.append((item["name"], tensor)) + self.model_runner.model.load_weights(weights=weights) + del weights + torch.cuda.synchronize() + socket.send(b"") + + socket.close() + del buffer + gc.collect() + torch.cuda.empty_cache() + def report_device_id(self) -> str: from vllm.platforms import current_platform self.device_uuid = current_platform.get_device_uuid(self.device.index) return self.device_uuid - def update_weights_from_ipc_handles(self, ipc_handles): - handles = ipc_handles[self.device_uuid] - device_id = self.device.index - weights = [] - for name, handle in handles.items(): - func, args = handle - list_args = list(args) - # the key is to change device id to the current device id - # in case two processes have different CUDA_VISIBLE_DEVICES - list_args[6] = device_id - tensor = func(*list_args) - weights.append((name, tensor)) - self.model_runner.model.load_weights(weights=weights) - torch.cuda.synchronize() - def check_weights_changed(self): """ Check if the weights are updated to 0. diff --git a/examples/offline_inference/save_sharded_state.py b/examples/offline_inference/save_sharded_state.py index 41d7a3492..e25f46b12 100644 --- a/examples/offline_inference/save_sharded_state.py +++ b/examples/offline_inference/save_sharded_state.py @@ -30,7 +30,7 @@ from vllm import LLM, EngineArgs from vllm.model_executor.model_loader import ShardedStateLoader -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def parse_args(): diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 184c30891..f5f6e28b5 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -5,10 +5,11 @@ from vllm import LLM, SamplingParams from vllm.benchmarks.datasets import add_dataset_parser, get_samples +from vllm.inputs import TokensPrompt from vllm.v1.metrics.reader import Counter, Vector try: - from vllm.utils import FlexibleArgumentParser + from vllm.utils.argparse_utils import FlexibleArgumentParser except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser @@ -48,6 +49,7 @@ def get_custom_mm_prompts(num_prompts): def parse_args(): parser = FlexibleArgumentParser() add_dataset_parser(parser) + parser.add_argument("--test", action="store_true") parser.add_argument( "--method", type=str, @@ -60,6 +62,7 @@ def parse_args(): parser.add_argument("--tp", type=int, default=1) parser.add_argument("--enforce-eager", action="store_true") parser.add_argument("--enable-chunked-prefill", action="store_true") + parser.add_argument("--max-model-len", type=int, default=16384) parser.add_argument("--temp", type=float, default=0) parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=-1) @@ -71,8 +74,7 @@ def parse_args(): return parser.parse_args() -def main(): - args = parse_args() +def main(args): args.endpoint_type = "openai-chat" model_dir = args.model_dir @@ -117,6 +119,11 @@ def main(): 
"prompt_lookup_max": args.prompt_lookup_max, "prompt_lookup_min": args.prompt_lookup_min, } + elif args.method == "mtp": + speculative_config = { + "method": "mtp", + "num_speculative_tokens": args.num_spec_tokens, + } else: raise ValueError(f"unknown method: {args.method}") @@ -129,7 +136,7 @@ def main(): gpu_memory_utilization=0.8, speculative_config=speculative_config, disable_log_stats=False, - max_model_len=16384, + max_model_len=args.max_model_len, limit_mm_per_prompt={"image": 5}, disable_chunked_mm_input=True, ) @@ -137,7 +144,8 @@ def main(): sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) if not args.custom_mm_prompts: outputs = llm.generate( - prompt_token_ids=prompt_ids, sampling_params=sampling_params + [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], + sampling_params=sampling_params, ) else: outputs = llm.chat(prompts, sampling_params=sampling_params) @@ -192,6 +200,39 @@ def main(): acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0 print(f"acceptance at token {i}: {acceptance_rate:.2f}") + return acceptance_length + if __name__ == "__main__": - main() + args = parse_args() + acceptance_length = main(args) + + if args.test: + # takes ~30s to run on 1xH100 + assert args.method in ["eagle", "eagle3"] + assert args.tp == 1 + assert args.num_spec_tokens == 3 + assert args.dataset_name == "hf" + assert args.dataset_path == "philschmid/mt-bench" + assert args.num_prompts == 80 + assert args.temp == 0 + assert args.top_p == 1.0 + assert args.top_k == -1 + assert args.enable_chunked_prefill + + # check acceptance length is within 2% of expected value + rtol = 0.02 + expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811 + + assert ( + acceptance_length <= (1 + rtol) * expected_acceptance_length + and acceptance_length >= (1 - rtol) * expected_acceptance_length + ), ( + f"acceptance_length {acceptance_length} is not " + f"within {rtol * 100}% of {expected_acceptance_length}" + ) + + print( + f"Test passed! Expected AL: " + f"{expected_acceptance_length}, got {acceptance_length}" + ) diff --git a/examples/offline_inference/structured_outputs.py b/examples/offline_inference/structured_outputs.py index f46064931..6b6099f71 100644 --- a/examples/offline_inference/structured_outputs.py +++ b/examples/offline_inference/structured_outputs.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -This file demonstrates the example usage of guided decoding -to generate structured outputs using vLLM. It shows how to apply -different guided decoding techniques such as Choice, Regex, JSON schema, -and Grammar to produce structured and formatted results -based on specific prompts. +This file demonstrates the example usage of structured outputs +in vLLM. It shows how to apply different constraints such as choice, +regex, json schema, and grammar to produce structured and formatted +results based on specific prompts. 
""" from enum import Enum @@ -13,19 +12,23 @@ from pydantic import BaseModel from vllm import LLM, SamplingParams -from vllm.sampling_params import GuidedDecodingParams +from vllm.sampling_params import StructuredOutputsParams MAX_TOKENS = 50 -# Guided decoding by Choice (list of possible options) -guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"]) -sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice) +# Structured outputs by Choice (list of possible options) +structured_outputs_params_choice = StructuredOutputsParams( + choice=["Positive", "Negative"] +) +sampling_params_choice = SamplingParams( + structured_outputs=structured_outputs_params_choice +) prompt_choice = "Classify this sentiment: vLLM is wonderful!" -# Guided decoding by Regex -guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n") +# Structured outputs by Regex +structured_outputs_params_regex = StructuredOutputsParams(regex=r"\w+@\w+\.com\n") sampling_params_regex = SamplingParams( - guided_decoding=guided_decoding_params_regex, + structured_outputs=structured_outputs_params_regex, stop=["\n"], max_tokens=MAX_TOKENS, ) @@ -36,7 +39,7 @@ ) -# Guided decoding by JSON using Pydantic schema +# Structured outputs by JSON using Pydantic schema class CarType(str, Enum): sedan = "sedan" suv = "SUV" @@ -51,17 +54,16 @@ class CarDescription(BaseModel): json_schema = CarDescription.model_json_schema() -guided_decoding_params_json = GuidedDecodingParams(json=json_schema) +structured_outputs_params_json = StructuredOutputsParams(json=json_schema) sampling_params_json = SamplingParams( - guided_decoding=guided_decoding_params_json, - max_tokens=MAX_TOKENS, + structured_outputs=structured_outputs_params_json, max_tokens=MAX_TOKENS ) prompt_json = ( - "Generate a JSON with the brand, model and car_type of" + "Generate a JSON with the brand, model and car_type of " "the most iconic car from the 90's" ) -# Guided decoding by Grammar +# Structured outputs by Grammar simplified_sql_grammar = """ root ::= select_statement select_statement ::= "SELECT " column " from " table " where " condition @@ -70,13 +72,15 @@ class CarDescription(BaseModel): condition ::= column "= " number number ::= "1 " | "2 " """ -guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar) +structured_outputs_params_grammar = StructuredOutputsParams( + grammar=simplified_sql_grammar +) sampling_params_grammar = SamplingParams( - guided_decoding=guided_decoding_params_grammar, + structured_outputs=structured_outputs_params_grammar, max_tokens=MAX_TOKENS, ) prompt_grammar = ( - "Generate an SQL query to show the 'username' and 'email'from the 'users' table." + "Generate an SQL query to show the 'username' and 'email' from the 'users' table." 
) @@ -85,7 +89,7 @@ def format_output(title: str, output: str): def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM): - outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + outputs = llm.generate(prompt, sampling_params=sampling_params) return outputs[0].outputs[0].text @@ -93,16 +97,16 @@ def main(): llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100) choice_output = generate_output(prompt_choice, sampling_params_choice, llm) - format_output("Guided decoding by Choice", choice_output) + format_output("Structured outputs by Choice", choice_output) regex_output = generate_output(prompt_regex, sampling_params_regex, llm) - format_output("Guided decoding by Regex", regex_output) + format_output("Structured outputs by Regex", regex_output) json_output = generate_output(prompt_json, sampling_params_json, llm) - format_output("Guided decoding by JSON", json_output) + format_output("Structured outputs by JSON", json_output) grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm) - format_output("Guided decoding by Grammar", grammar_output) + format_output("Structured outputs by Grammar", grammar_output) if __name__ == "__main__": diff --git a/examples/offline_inference/torchrun_dp_example.py b/examples/offline_inference/torchrun_dp_example.py new file mode 100644 index 000000000..eb7ed969e --- /dev/null +++ b/examples/offline_inference/torchrun_dp_example.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +experimental support for data-parallel inference with torchrun +Note the data load balancing and distribution is done out of the vllm engine, +no internal lb supported in external_launcher mode. 
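+Each torchrun process creates its own LLM instance and generates only the
+prompts assigned to its data-parallel rank.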
+ +To run this example: +```bash +$ torchrun --nproc-per-node=2 examples/offline_inference/torchrun_dp_example.py +``` + +With custom parallelism settings: +```bash +$ torchrun --nproc-per-node=8 examples/offline_inference/torchrun_dp_example.py \ + --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep +``` +""" + +import argparse + +from vllm import LLM, SamplingParams + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Data-parallel inference with torchrun" + ) + parser.add_argument( + "--tp-size", + type=int, + default=1, + help="Tensor parallel size (default: 1)", + ) + parser.add_argument( + "--pp-size", + type=int, + default=1, + help="Pipeline parallel size (default: 1)", + ) + parser.add_argument( + "--dp-size", + type=int, + default=2, + help="Data parallel size (default: 2)", + ) + parser.add_argument( + "--enable-ep", + action="store_true", + help="Enable expert parallel (default: False)", + ) + parser.add_argument( + "--model", + type=str, + default="microsoft/Phi-mini-MoE-instruct", + help="Model name or path (default: microsoft/Phi-mini-MoE-instruct)", + ) + parser.add_argument( + "--max-model-len", + type=int, + default=4096, + help="Maximum model length (default: 4096)", + ) + parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=0.6, + help="GPU memory utilization (default: 0.6)", + ) + parser.add_argument( + "--seed", + type=int, + default=1, + help="Random seed (default: 1)", + ) + return parser.parse_args() + + +args = parse_args() + + +# Create prompts, the same across all ranks +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create sampling parameters, the same across all ranks +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Use `distributed_executor_backend="external_launcher"` so that +# this llm engine/instance only creates one worker. +# it is important to set an explicit seed to make sure that +# all ranks have the same random seed, so that sampling can be +# deterministic across ranks. +llm = LLM( + model=args.model, + tensor_parallel_size=args.tp_size, + data_parallel_size=args.dp_size, + pipeline_parallel_size=args.pp_size, + enable_expert_parallel=args.enable_ep, + distributed_executor_backend="external_launcher", + max_model_len=args.max_model_len, + gpu_memory_utilization=args.gpu_memory_utilization, + seed=args.seed, +) + +dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank +dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size + +prompts = [ + f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank +] + +outputs = llm.generate(prompts, sampling_params) + +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print( + f"DP Rank: {dp_rank} Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n" + ) + +""" +Further tips: + +1. to communicate control messages across all ranks, use the cpu group, +a PyTorch ProcessGroup with GLOO backend. + +```python +from vllm.distributed.parallel_state import get_world_group +cpu_group = get_world_group().cpu_group +torch_rank = dist.get_rank(group=cpu_group) +if torch_rank == 0: + # do something for rank 0, e.g. saving the results to disk. +``` + +2. to communicate data across all ranks, use the model's device group, +a PyTorch ProcessGroup with NCCL backend. 
+```python +from vllm.distributed.parallel_state import get_world_group +device_group = get_world_group().device_group +``` + +3. to access the model directly in every rank, use the following code: +```python +llm.llm_engine.model_executor.driver_worker.worker.model_runner.model +``` +""" diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py index 9776f4fe3..0093b63b0 100644 --- a/examples/offline_inference/tpu.py +++ b/examples/offline_inference/tpu.py @@ -42,7 +42,7 @@ def main(): llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct" # Set `enforce_eager=True` to avoid ahead-of-time compilation. - # In real workloads, `enforace_eager` should be `False`. + # In real workloads, `enforce_eager` should be `False`. llm = LLM(**llm_args) outputs = llm.generate(prompts, sampling_params) print("-" * 50) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 988ad35cd..c1ea95f8d 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -12,7 +12,7 @@ import random from contextlib import contextmanager from dataclasses import asdict -from typing import NamedTuple, Optional +from typing import NamedTuple from huggingface_hub import snapshot_download from transformers import AutoTokenizer @@ -22,14 +22,15 @@ from vllm.assets.video import VideoAsset from vllm.lora.request import LoRARequest from vllm.multimodal.image import convert_image_mode -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser class ModelRequestData(NamedTuple): engine_args: EngineArgs prompts: list[str] - stop_token_ids: Optional[list[int]] = None - lora_requests: Optional[list[LoRARequest]] = None + stop_token_ids: list[int] | None = None + lora_requests: list[LoRARequest] | None = None + sampling_params: list[SamplingParams] | None = None # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on @@ -90,6 +91,33 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData: ) +# Bee-8B +def run_bee(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + model_name = "Open-Bee/Bee-8B-RL" + + prompts = [ + ( + f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n\n{question}<|im_end|>" + f"<|im_start|>assistant\n\n" + ) + for question in questions + ] + + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # BLIP-2 def run_blip2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -173,21 +201,90 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ) -# Florence2 -def run_florence2(questions: list[str], modality: str) -> ModelRequestData: +def run_deepseek_ocr(questions: list[str], modality: str) -> ModelRequestData: + from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor + assert modality == "image" + model_name = "deepseek-ai/DeepSeek-OCR" + engine_args = EngineArgs( - model="microsoft/Florence-2-large", - tokenizer="Isotr0py/Florence-2-tokenizer", - max_model_len=4096, - max_num_seqs=2, + model=model_name, + limit_mm_per_prompt={modality: 1}, + logits_processors=[NGramPerReqLogitsProcessor], + ) + + # deepseek-ocr use plain prompt template + prompts = [f"\n{question}" for question in 
questions] + + # The following sampling params config is taken from + # the official Deepseek-OCR inference example. + # (IMPORTANT) Use the custom logits processor and avoid skipping + # special tokens for this model for the optimal OCR performance. + sampling_params = [ + SamplingParams( + temperature=0.0, + max_tokens=8192, + # ngram logit processor args + extra_args=dict( + ngram_size=30, + window_size=90, + # whitelist: , + whitelist_token_ids={128821, 128822}, + ), + skip_special_tokens=False, + ) + for _ in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + sampling_params=sampling_params, + ) + + +# Dots-OCR +def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions] + engine_args = EngineArgs( + model="rednote-hilab/dots.ocr", + limit_mm_per_prompt={modality: 1}, trust_remote_code=True, - dtype="bfloat16", + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + +# Ernie4.5-VL +def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData: + model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, ) - prompts = ["" for _ in questions] + if modality == "image": + placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + elif modality == "video": + placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + prompts = [ + ( + f"<|begin_of_sentence|>User: {question}{placeholder}\n" + "Assistant: " + ) + for question in questions + ] return ModelRequestData( engine_args=engine_args, @@ -283,8 +380,10 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: ) prompts = [ - f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ - {question}<|assistant|>" + ( + "<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>" + f"{question}<|assistant|>" + ) for question in questions ] @@ -333,6 +432,80 @@ def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData: ) +# GLM-4.5V +def run_glm4_5v(questions: list[str], modality: str) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={ + "size": {"shortest_edge": 12544, "longest_edge": 47040000}, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + tensor_parallel_size=4, + ) + + if modality == "image": + placeholder = "<|begin_of_image|><|image|><|end_of_image|>" + elif modality == "video": + placeholder = "<|begin_of_video|><|video|><|end_of_video|>" + + prompts = [ + ( + "[gMASK]<|system|>\nYou are a helpful assistant.<|user|>\n" + f"{placeholder}" + f"{question}<|assistant|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + +# GLM-4.5V-FP8 +def run_glm4_5v_fp8(questions: list[str], modality: str) -> ModelRequestData: + model_name = "zai-org/GLM-4.5V-FP8" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + mm_processor_kwargs={ + "size": {"shortest_edge": 12544, "longest_edge": 47040000}, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + enforce_eager=True, + tensor_parallel_size=4, + ) + + if modality == "image": + placeholder = 
"<|begin_of_image|><|image|><|end_of_image|>" + elif modality == "video": + placeholder = "<|begin_of_video|><|video|><|end_of_video|>" + + prompts = [ + ( + "[gMASK]<|system|>\nYou are a helpful assistant.<|user|>\n" + f"{placeholder}" + f"{question}<|assistant|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # H2OVL-Mississippi def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -383,8 +556,8 @@ def run_hyperclovax_seed_vision( for question in questions: if modality == "image": """ - ocr: List the words in the image in raster order. - Even if the word order feels unnatural for reading, + ocr: List the words in the image in raster order. + Even if the word order feels unnatural for reading, the model will handle it as long as it follows raster order. e.g. "Naver, CLOVA, bigshane" lens_keywords: List the entity names in the image. @@ -474,7 +647,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData: # Intern-S1 def run_interns1(questions: list[str], modality: str) -> ModelRequestData: - model_name = "internlm/Intern-S1" + model_name = "internlm/Intern-S1-mini" engine_args = EngineArgs( model=model_name, @@ -576,6 +749,37 @@ def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData: ) +# Keye-VL-1.5 +def run_keye_vl1_5(questions: list[str], modality: str) -> ModelRequestData: + model_name = "Kwai-Keye/Keye-VL-1.5-8B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + trust_remote_code=True, + limit_mm_per_prompt={modality: 1}, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompts = [ + ( + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Kimi-VL def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -600,6 +804,26 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: ) +# LightOnOCR +def run_lightonocr(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + prompts = [ + "<|im_start|>system<|im_end|>\n<|im_start|>user\n<|image_pad|><|im_end|>\n<|im_start|>assistant\n" + for _ in questions + ] + + engine_args = EngineArgs( + model="lightonai/LightOnOCR-1B", + limit_mm_per_prompt={modality: 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + def run_llama4(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -693,15 +917,13 @@ def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestDat def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData: if modality == "video": prompts = [ - f"<|im_start|>user '}} + {%- endif %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_first = false %} + {%- set ns.is_tool = false -%} + {%- for tool in message['tool_calls'] %} + {%- if not ns.is_first %} + {%- if message['content'] is none %} + {{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- else %} + {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + 
tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- endif %} + {%- set ns.is_first = true -%} + {%- else %} + {{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- endif %} + {%- endfor %} + {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif %} + {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %} + {%- if ns.is_last_user %} + {{'<|Assistant|>'}} + {%- if message['prefix'] is defined and message['prefix'] and thinking %} + {{''}} + {%- else %} + {{''}} + {%- endif %} + {%- endif %} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{message['content'] + '<|end▁of▁sentence|>'}} + {%- set ns.is_tool = false -%} + {%- else %} + {%- set content = message['content'] -%} + {%- if '' in content %} + {%- set content = content.split('', 1)[1] -%} + {%- endif %} + {{content + '<|end▁of▁sentence|>'}} + {%- endif %} + {%- endif %} + {%- if message['role'] == 'tool' %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_tool = true -%} + {{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif %} +{%- endfor -%} +{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %} + {{'<|Assistant|>'}} + {%- if not thinking %} + {{''}} + {%- else %} + {{''}} + {%- endif %} +{% endif %} diff --git a/examples/tool_chat_template_gemma3_pythonic.jinja b/examples/tool_chat_template_gemma3_pythonic.jinja new file mode 100644 index 000000000..5a20b0191 --- /dev/null +++ b/examples/tool_chat_template_gemma3_pythonic.jinja @@ -0,0 +1,123 @@ +{#- Begin-of-sequence token to start the model prompt -#} +{{ bos_token }} +{#- Extracts the system message. Gemma does not support system messages so it will be prepended to first user message. -#} +{%- if messages[0]['role'] == 'system' -%} + {%- if messages[0]['content'] is string -%} + {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%} + {%- else -%} + {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {%- set first_user_prefix = "" -%} + {%- set loop_messages = messages -%} +{%- endif -%} +{#- Set tools to none if not defined for this ChatCompletion request (helps avoid errors later) -#} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{#- Validate alternating user/assistant messages (excluding 'tool' messages and ones with tool_calls) -#} +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | selectattr("tool_calls", "undefined") -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif -%} +{%- endfor -%} + +{#- Main loop over all messages in the conversation history -#} +{%- for message in loop_messages -%} + {#- Normalize roles for model prompt formatting -#} + {%- if (message['role'] == 'assistant') -%} + {%- set role = "model" -%} + {%- elif (message['role'] == 'tool') -%} + {%- set role = "user" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {#- Mark the start of a message block with the appropriate role -#} + {{ '' + role + '\n' -}} + + {#- Insert system message content (if present) at the beginning of the first message. 
-#} + {%- if loop.first -%} + {{ first_user_prefix }} + {#- Append system message with tool information if using tools in message request. -#} + {%- if tools is not none -%} + {{- "Tools (functions) are available. If you decide to invoke one or more of the tools, you must respond with a python list of the function calls.\n" -}} + {{- "Example Format: [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] \n" -}} + {{- "Do not use variables. DO NOT USE MARKDOWN SYNTAX. You SHOULD NOT include any other text in the response if you call a function. If none of the functions can be used, point it out. If you lack the parameters required by the function, also point it out.\n" -}} + {{- "Here is a list of functions in JSON format that you can invoke.\n" -}} + {{- tools | tojson(indent=4) -}} + {{- "\n\n" -}} + {%- endif -%} + {%- endif -%} + + {#- Format model tool calls (turns where model indicates they want to call a tool) -#} + {%- if 'tool_calls' in message -%} + {#- Opening bracket for tool call list. -#} + {{- '[' -}} + {#- For each tool call -#} + {%- for tool_call in message.tool_calls -%} + {#- Get tool call function. -#} + {%- if tool_call.function is defined -%} + {%- set tool_call = tool_call.function -%} + {%- endif -%} + {#- Function name & opening parenthesis. -#} + {{- tool_call.name + '(' -}} + + {#-- Handle arguments as list (positional) or dict (named) --#} + {#-- Named arguments (dict) --#} + {%- if tool_call.arguments is iterable and tool_call.arguments is mapping -%} + {%- set first = true -%} + {%- for key, val in tool_call.arguments.items() -%} + {%- if not first %}, {% endif -%} + {{ key }}={{ val | tojson }} + {%- set first = false -%} + {%- endfor -%} + {#-- Positional arguments (list) --#} + {%- elif tool_call.arguments is iterable -%} + {{- tool_call.arguments | map('tojson') | join(', ') -}} + {#-- Fallback: single positional value --#} + {%- else -%} + {{- tool_call.arguments | tojson -}} + {#-- Closing parenthesis. --#} + {%- endif -%} + {{- ')' -}} + {#-- If more than one tool call, place comma and move to formatting next tool call --#} + {%- if not loop.last -%}, {% endif -%} + {%- endfor -%} + {#- Closing bracket for tool call list. 
-#} + {{- ']' -}} + {%- endif -%} + + {#- Tool response start tag (for messages from a tool) -#} + {%- if (message['role'] == 'tool') -%} + {{ '\n' -}} + {%- endif -%} + + {#- Render the message content: handle plain string or multimodal content like image/text -#} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is iterable -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type") }} + {%- endif -%} + + {#- Tool response end tag -#} + {%- if (message['role'] == 'tool') -%} + {{ '' -}} + {%- endif -%} + + {#- Mark end of a single turn -#} + {{ '\n' }} +{%- endfor -%} + +{#- If generation is to be triggered, add model prompt prefix -#} +{%- if add_generation_prompt -%} + {{'model\n'}} +{%- endif -%} \ No newline at end of file diff --git a/examples/tool_chat_template_phi4_mini.jinja b/examples/tool_chat_template_phi4_mini.jinja index 36423b6c4..6f40c38c2 100644 --- a/examples/tool_chat_template_phi4_mini.jinja +++ b/examples/tool_chat_template_phi4_mini.jinja @@ -1,11 +1,15 @@ +{%- if messages and messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant." %} +{%- endif %} + {%- if messages %} - {%- if system_message or tools %} <|system|> - -{%- if system_message %} {{ system_message }} -{%- endif %} -In addition to plain text responses, you can chose to call one or more of the provided functions. +{%- if tools %} +In addition to plain text responses, you can choose to call one or more of the provided functions. Use the following rule to decide when to call a function: * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so @@ -15,17 +19,15 @@ If you decide to call functions: * prefix function calls with functools marker (no closing marker required) * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples - * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0 + * respect the argument type formatting. 
E.g., if the type is number and format is float, write value 7 as 7.0 * make sure you pick the right functions that match the user intent -{%- if tools %} {%- for t in tools %} {{- t | tojson(indent=4) }} {{- "\n\n" }} {%- endfor %} {%- endif %}<|end|> - {%- endif %} {%- for message in messages %} {%- if message.role != "system" %} diff --git a/examples/tool_chat_template_qwen3coder.jinja b/examples/tool_chat_template_qwen3coder.jinja new file mode 100644 index 000000000..49b0e8d0e --- /dev/null +++ b/examples/tool_chat_template_qwen3coder.jinja @@ -0,0 +1,117 @@ +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '' }} + {%- else %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{% endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n\n" ~ tool.name ~ "" }} + {%- if tool.description is defined %} + {{- '\n' ~ (tool.description | trim) ~ '' }} + {%- endif %} + {{- '\n' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n' }} + {{- '\n' ~ param_name ~ '' }} + {%- if param_fields.type is defined %} + {{- '\n' ~ (param_fields.type | string) ~ '' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n' ~ (param_fields.description | trim) ~ '' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {% set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current 
knowledge and do not tell the user about function calls\n' }} +{%- endif %} +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in loop_messages %} + {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>' + message.role }} + {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} + {{- '\n' + message.content | trim + '\n' }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n\n' }} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user\n' }} + {%- endif %} + {{- '\n' }} + {{- message.content }} + {{- '\n\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/find_cuda_init.py b/find_cuda_init.py deleted file mode 100644 index 0d13b2f86..000000000 --- a/find_cuda_init.py +++ /dev/null @@ -1,35 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import importlib -import traceback -from typing import Callable -from unittest.mock import patch - - -def find_cuda_init(fn: Callable[[], object]) -> None: - """ - Helper function to debug CUDA re-initialization errors. - - If `fn` initializes CUDA, prints the stack trace of how this happens. - """ - from torch.cuda import _lazy_init - - stack = None - - def wrapper(): - nonlocal stack - stack = traceback.extract_stack() - return _lazy_init() - - with patch("torch.cuda._lazy_init", wrapper): - fn() - - if stack is not None: - print("==== CUDA Initialized ====") - print("".join(traceback.format_list(stack)).strip()) - print("==========================") - - -if __name__ == "__main__": - find_cuda_init( - lambda: importlib.import_module("vllm.model_executor.models.llava")) diff --git a/pyproject.toml b/pyproject.toml index 458fbf133..c06a21331 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,28 +51,23 @@ metax_enhanced_model = "vllm_metax:register_model" [tool.setuptools_scm] # no extra settings needed, presence enables setuptools-scm -local_scheme = "node-and-date" [tool.setuptools.packages.find] where = ["."] include = ["vllm_metax*"] -[tool.yapfignore] -ignore_patterns = [ - ".buildkite/**", - "benchmarks/**", - "build/**", - "examples/**" -] - [tool.ruff] -# Allow lines to be as long as 80. 
-line-length = 80 +# Note: all these are maintained by vllm +exclude = [ +"examples", +"tests", +"vllm_metax/third_party", +"vllm_metax/models" +] [tool.ruff.lint.per-file-ignores] -"vllm_metax/third_party/**" = ["ALL"] "vllm_metax/platform.py" = ["ALL"] -"vllm_metax/version.py" = ["F401"] +"vllm_metax/version.py" = ["ALL"] "vllm_metax/_version.py" = ["ALL"] "vllm_metax/v1/**" = ["ALL"] "vllm_metax/csrc/**" = ["ALL"] @@ -108,53 +103,25 @@ ignore = [ "B007", # f-string format "UP032", - # Can remove once 3.10+ is the minimum Python version - "UP007", - "UP006", - "UP035" ] +[tool.ruff.format] +docstring-code-format = true + [tool.mypy] plugins = ['pydantic.mypy'] ignore_missing_imports = true check_untyped_defs = true follow_imports = "silent" -# After fixing type errors resulting from follow_imports: "skip" -> "silent", -# move the directory here and remove it from tools/mypy.sh -files = [ - "vllm_metax/*.py", -] -# TODO(woosuk): Include the code from Megatron and HuggingFace. -exclude = [ - 'vllm_metax/distributed/.*\.py$', - # Ignore triton kernels in ops. - 'vllm_metax/attention/ops/.*\.py$' -] - -[[tool.mypy.overrides]] -module = [ - "vllm.*", -] -ignore_missing_imports = true - -[tool.isort] -skip_glob = [ - ".buildkite/*", - "benchmarks/*", - "examples/*", - "tests/*", - "vllm_metax/*" -] -use_parentheses = true -skip_gitignore = true - [tool.pytest.ini_options] markers = [ + "slow_test", "skip_global_cleanup", "core_model: enable this model test in each PR instead of only nightly", "hybrid_model: models that contain mamba layers (including pure SSM and hybrid architectures)", "cpu_model: enable this model test in CPU tests", + "cpu_test: mark test as CPU-only test", "split: run this test as part of a split", "distributed: run this test only in distributed GPU tests", "skip_v1: do not run this test with v1", @@ -234,6 +201,8 @@ fo = "fo" ba = "ba" [tool.typos.type.py.extend-words] +ba = "ba" +nd = "nd" [tool.typos.type.cpp] extend-glob = ["*.cu"] @@ -351,22 +320,5 @@ windo = "windo" [tool.typos.type.vimscript.extend-words] -[tool.uv.pip] -index-url = "https://pypi.tuna.tsinghua.edu.cn/simple" -extra-index-url = ["https://pypi.org/simple"] - -[tool.uv] -no-build-isolation-package = [ - "torch", "torchaudio", "torchvision", - "xformers", - "lm_eval", - "causal_conv1d", - "dropout_layer_norm", - "flash_attn", - "flash-linear-attention", - "flash_mla", - "flashinfer", - "fused_dense_lib", - "mamba_ssm", - "rotary_emb", -] \ No newline at end of file +[tool.uv] +no-build-isolation-package = ["torch"] \ No newline at end of file diff --git a/requirements/common.txt b/requirements/common.txt index d386c778c..05ee52f07 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -7,13 +7,13 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.55.2 +transformers >= 4.56.0 tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. 
aiohttp openai >= 1.99.1 # For Responses API with reasoning content -pydantic >= 2.11.7 +pydantic >= 2.12.0 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 @@ -31,15 +31,14 @@ partial-json-parser # used for parsing partial JSON outputs pyzmq >= 25.0.0 msgspec gguf >= 0.13.0 -importlib_metadata; python_version < '3.10' -mistral_common[image,audio] >= 1.8.2 +mistral_common[image,audio] >= 1.8.5 opencv-python-headless >= 4.11.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. -compressed-tensors == 0.11.0 # required for compressed-tensors -depyf==0.19.0 # required for profiling and debugging with compilation config +compressed-tensors == 0.12.2 # required for compressed-tensors +depyf==0.20.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/others/logging_configuration.md @@ -49,3 +48,4 @@ pybase64 # fast base64 implementation cbor2 # Required for cross-language serialization of hashable objects setproctitle # Used to set process names for better debugging and monitoring openai-harmony >= 0.0.3 # Required for gpt-oss +anthropic == 0.71.0 diff --git a/requirements/test.in b/requirements/test.in index b897fae6e..802c340f4 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -53,4 +53,3 @@ runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 decord==0.6.0 -terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test diff --git a/requirements/test.txt b/requirements/test.txt index 5f3f1df21..8b9dd1ef8 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -6,10 +6,6 @@ accelerate==1.0.1 # via # lm-eval # peft -aenum==3.1.16 - # via lightly -affine==2.4.0 - # via rasterio aiohappyeyeballs==2.4.3 # via aiohttp aiohttp==3.10.11 @@ -23,18 +19,8 @@ aiohttp-cors==0.8.1 # via ray aiosignal==1.3.1 # via aiohttp -albucore==0.0.16 - # via terratorch -albumentations==1.4.6 - # via terratorch -alembic==1.16.4 - # via mlflow annotated-types==0.7.0 # via pydantic -antlr4-python3-runtime==4.9.3 - # via - # hydra-core - # omegaconf anyio==4.6.2.post1 # via # httpx @@ -50,12 +36,10 @@ async-timeout==5.0.1 attrs==24.2.0 # via # aiohttp - # fiona # hypothesis # jsonlines # jsonschema # pytest-subtests - # rasterio # referencing audioread==3.0.1 # via librosa @@ -64,13 +48,9 @@ backoff==2.2.1 # -r requirements/test.in # schemathesis bitsandbytes==0.46.1 - # via - # -r requirements/test.in - # lightning + # via -r requirements/test.in black==24.10.0 # via datamodel-code-generator -blinker==1.9.0 - # via flask blobfile==3.0.0 # via -r requirements/test.in bm25s==0.2.13 @@ -86,18 +66,11 @@ bounded-pool-executor==0.0.3 buildkite-test-collector==0.1.9 # via -r requirements/test.in cachetools==5.5.2 - # via - # google-auth - # mlflow-skinny + # via google-auth certifi==2024.8.30 # via - # fiona # httpcore # httpx - # lightly - # pyogrio - # pyproj - # rasterio # requests cffi==1.17.1 # via 
soundfile @@ -108,28 +81,11 @@ charset-normalizer==3.4.0 click==8.1.7 # via # black - # click-plugins - # cligj - # fiona - # flask # jiwer - # mlflow-skinny # nltk - # rasterio # ray # schemathesis # typer - # uvicorn -click-plugins==1.1.1.2 - # via - # fiona - # rasterio -cligj==0.7.2 - # via - # fiona - # rasterio -cloudpickle==3.1.1 - # via mlflow-skinny colorama==0.4.6 # via # sacrebleu @@ -145,8 +101,6 @@ cupy-cuda12x==13.6.0 # via ray cycler==0.12.1 # via matplotlib -databricks-sdk==0.59.0 - # via mlflow-skinny datamodel-code-generator==0.26.3 # via -r requirements/test.in dataproperty==1.0.1 @@ -172,20 +126,12 @@ distlib==0.3.9 # via virtualenv dnspython==2.7.0 # via email-validator -docker==7.1.0 - # via mlflow docopt==0.6.2 # via num2words -docstring-parser==0.17.0 - # via jsonargparse -efficientnet-pytorch==0.7.1 - # via segmentation-models-pytorch einops==0.8.1 # via # -r requirements/test.in # encodec - # terratorch - # torchgeo # vector-quantize-pytorch # vocos einx==0.3.0 @@ -203,8 +149,6 @@ exceptiongroup==1.3.0 # anyio # hypothesis # pytest -fastapi==0.116.1 - # via mlflow-skinny fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -220,10 +164,6 @@ filelock==3.16.1 # torch # transformers # virtualenv -fiona==1.10.1 - # via torchgeo -flask==3.1.1 - # via mlflow fonttools==4.55.0 # via matplotlib fqdn==1.5.1 @@ -240,8 +180,6 @@ fsspec==2024.9.0 # evaluate # fastparquet # huggingface-hub - # lightning - # pytorch-lightning # torch ftfy==6.3.1 # via open-clip-torch @@ -249,41 +187,18 @@ genai-perf==0.0.8 # via -r requirements/test.in genson==1.3.0 # via datamodel-code-generator -geopandas==1.0.1 - # via terratorch -gitdb==4.0.12 - # via gitpython -gitpython==3.1.44 - # via mlflow-skinny google-api-core==2.24.2 # via opencensus google-auth==2.40.2 - # via - # databricks-sdk - # google-api-core + # via google-api-core googleapis-common-protos==1.70.0 # via google-api-core -graphene==3.4.3 - # via mlflow graphql-core==3.2.6 - # via - # graphene - # graphql-relay - # hypothesis-graphql -graphql-relay==3.2.0 - # via graphene -greenlet==3.2.3 - # via sqlalchemy + # via hypothesis-graphql grpcio==1.71.0 # via ray -gunicorn==23.0.0 - # via mlflow h11==0.14.0 - # via - # httpcore - # uvicorn -h5py==3.13.0 - # via terratorch + # via httpcore harfile==0.3.0 # via schemathesis hf-xet==1.1.7 @@ -303,19 +218,13 @@ huggingface-hub==0.34.3 # evaluate # open-clip-torch # peft - # segmentation-models-pytorch # sentence-transformers - # terratorch # timm # tokenizers # transformers # vocos humanize==4.11.0 # via runai-model-streamer -hydra-core==1.3.2 - # via - # lightly - # lightning hypothesis==6.131.0 # via # hypothesis-graphql @@ -333,14 +242,8 @@ idna==3.10 # jsonschema # requests # yarl -imageio==2.37.0 - # via scikit-image importlib-metadata==8.7.0 - # via - # mlflow-skinny - # opentelemetry-api -importlib-resources==6.5.2 - # via typeshed-client + # via opentelemetry-api inflect==5.6.2 # via datamodel-code-generator iniconfig==2.0.0 @@ -349,13 +252,9 @@ isoduration==20.11.0 # via jsonschema isort==5.13.2 # via datamodel-code-generator -itsdangerous==2.2.0 - # via flask jinja2==3.1.6 # via # datamodel-code-generator - # flask - # mlflow # torch jiwer==3.0.5 # via -r requirements/test.in @@ -368,10 +267,6 @@ joblib==1.4.2 # librosa # nltk # scikit-learn -jsonargparse==4.35.0 - # via - # lightning - # terratorch jsonlines==4.0.0 # via lm-eval jsonpointer==3.0.0 @@ -390,33 +285,12 @@ kaleido==0.2.1 # via genai-perf kiwisolver==1.4.7 # via matplotlib -kornia==0.8.1 - # via torchgeo 
-kornia-rs==0.1.9 - # via kornia lazy-loader==0.4 - # via - # librosa - # scikit-image + # via librosa libnacl==2.1.0 # via tensorizer librosa==0.10.2.post1 # via -r requirements/test.in -lightly==1.5.20 - # via - # terratorch - # torchgeo -lightly-utils==0.0.2 - # via lightly -lightning==2.5.1.post0 - # via - # terratorch - # torchgeo -lightning-utilities==0.14.3 - # via - # lightning - # pytorch-lightning - # torchmetrics llvmlite==0.44.0 # via numba lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d @@ -425,25 +299,14 @@ lxml==5.3.0 # via # blobfile # sacrebleu -mako==1.3.10 - # via alembic -markdown==3.8.2 - # via mlflow markdown-it-py==3.0.0 # via rich markupsafe==3.0.1 # via - # flask # jinja2 - # mako # werkzeug matplotlib==3.9.2 - # via - # -r requirements/test.in - # lightning - # mlflow - # pycocotools - # torchgeo + # via -r requirements/test.in mbstrdecoder==1.1.3 # via # dataproperty @@ -453,10 +316,6 @@ mdurl==0.1.2 # via markdown-it-py mistral-common==1.8.2 # via -r requirements/test.in -mlflow==2.22.0 - # via terratorch -mlflow-skinny==2.22.0 - # via mlflow more-itertools==10.5.0 # via lm-eval mpmath==1.3.0 @@ -475,14 +334,10 @@ multiprocess==0.70.16 # via # datasets # evaluate -munch==4.0.0 - # via pretrainedmodels mypy-extensions==1.0.0 # via black networkx==3.2.1 - # via - # scikit-image - # torch + # via torch nltk==3.9.1 # via rouge-score num2words==0.5.14 @@ -497,8 +352,6 @@ numpy==1.26.4 # via # -r requirements/test.in # accelerate - # albucore - # albumentations # bitsandbytes # bm25s # contourpy @@ -510,15 +363,9 @@ numpy==1.26.4 # evaluate # fastparquet # genai-perf - # geopandas - # h5py - # imageio # librosa - # lightly - # lightly-utils # matplotlib # mistral-common - # mlflow # mteb # numba # numexpr @@ -526,34 +373,18 @@ numpy==1.26.4 # pandas # patsy # peft - # pycocotools - # pyogrio - # rasterio - # rioxarray # rouge-score # runai-model-streamer # sacrebleu - # scikit-image # scikit-learn # scipy - # segmentation-models-pytorch - # shapely # soxr # statsmodels - # tensorboardx # tensorizer - # tifffile - # torchgeo - # torchmetrics # torchvision # transformers # tritonclient # vocos - # xarray -omegaconf==2.3.0 - # via - # hydra-core - # lightning open-clip-torch==2.32.0 # via -r requirements/test.in opencensus==0.11.4 @@ -563,12 +394,9 @@ opencensus-context==0.1.3 opencv-python-headless==4.11.0.86 # via # -r requirements/test.in - # albucore - # albumentations # mistral-common opentelemetry-api==1.35.0 # via - # mlflow-skinny # opentelemetry-exporter-prometheus # opentelemetry-sdk # opentelemetry-semantic-conventions @@ -578,7 +406,6 @@ opentelemetry-proto==1.36.0 # via ray opentelemetry-sdk==1.35.0 # via - # mlflow-skinny # opentelemetry-exporter-prometheus # ray opentelemetry-semantic-conventions==0.56b0 @@ -591,43 +418,25 @@ packaging==24.2 # datasets # evaluate # fastparquet - # geopandas - # gunicorn # huggingface-hub - # hydra-core - # kornia # lazy-loader - # lightning - # lightning-utilities # matplotlib - # mlflow-skinny # peft # plotly # pooch - # pyogrio # pytest # pytest-rerunfailures - # pytorch-lightning # ray - # rioxarray - # scikit-image # statsmodels - # tensorboardx - # torchmetrics # transformers # typepy - # xarray pandas==2.2.3 # via # datasets # evaluate # fastparquet # genai-perf - # geopandas - # mlflow # statsmodels - # torchgeo - # xarray pathspec==0.12.1 # via black pathvalidate==3.2.1 @@ -641,14 +450,9 @@ peft==0.16.0 pillow==10.4.0 # via # genai-perf - # imageio - # 
lightly-utils # matplotlib # mistral-common - # scikit-image - # segmentation-models-pytorch # sentence-transformers - # torchgeo # torchvision platformdirs==4.3.6 # via @@ -667,8 +471,6 @@ portalocker==2.10.1 # via sacrebleu pqdm==0.2.0 # via -r requirements/test.in -pretrainedmodels==0.7.4 - # via segmentation-models-pytorch prometheus-client==0.22.0 # via # opentelemetry-exporter-prometheus @@ -681,11 +483,9 @@ protobuf==5.28.3 # via # google-api-core # googleapis-common-protos - # mlflow-skinny # opentelemetry-proto # proto-plus # ray - # tensorboardx # tensorizer psutil==6.1.0 # via @@ -700,7 +500,6 @@ pyarrow==18.0.0 # via # datasets # genai-perf - # mlflow pyasn1==0.6.1 # via # pyasn1-modules @@ -709,8 +508,6 @@ pyasn1-modules==0.4.2 # via google-auth pybind11==2.13.6 # via lm-eval -pycocotools==2.0.8 - # via terratorch pycountry==24.6.1 # via pydantic-extra-types pycparser==2.22 @@ -720,12 +517,8 @@ pycryptodomex==3.22.0 pydantic==2.11.7 # via # -r requirements/test.in - # albumentations # datamodel-code-generator - # fastapi - # lightly # mistral-common - # mlflow-skinny # mteb # pydantic-extra-types # ray @@ -735,17 +528,8 @@ pydantic-extra-types==2.10.5 # via mistral-common pygments==2.18.0 # via rich -pyogrio==0.11.0 - # via geopandas pyparsing==3.2.0 - # via - # matplotlib - # rasterio -pyproj==3.7.1 - # via - # geopandas - # rioxarray - # torchgeo + # via matplotlib pyrate-limiter==3.7.0 # via schemathesis pystemmer==3.0.0 @@ -765,7 +549,6 @@ pytest==8.3.5 # pytest-subtests # pytest-timeout # schemathesis - # terratorch pytest-asyncio==0.24.0 # via -r requirements/test.in pytest-forked==1.6.0 @@ -780,23 +563,15 @@ pytest-subtests==0.14.1 # via schemathesis pytest-timeout==2.3.1 # via -r requirements/test.in -python-box==7.3.2 - # via terratorch python-dateutil==2.9.0.post0 # via # arrow # botocore - # graphene - # lightly # matplotlib # pandas # typepy python-rapidjson==1.20 # via tritonclient -pytorch-lightning==2.5.2 - # via - # lightly - # lightning pytrec-eval-terrier==0.5.7 # via mteb pytz==2024.2 @@ -806,17 +581,11 @@ pytz==2024.2 pyyaml==6.0.2 # via # accelerate - # albumentations # datamodel-code-generator # datasets # genai-perf # huggingface-hub - # jsonargparse - # lightning - # mlflow-skinny - # omegaconf # peft - # pytorch-lightning # ray # responses # schemathesis @@ -825,11 +594,6 @@ pyyaml==6.0.2 # vocos rapidfuzz==3.12.1 # via jiwer -rasterio==1.4.3 - # via - # rioxarray - # terratorch - # torchgeo ray==2.48.0 # via -r requirements/test.in redis==5.2.0 @@ -848,16 +612,12 @@ regex==2024.9.11 requests==2.32.3 # via # buildkite-test-collector - # databricks-sdk # datasets - # docker # evaluate # google-api-core # huggingface-hub - # lightly # lm-eval # mistral-common - # mlflow-skinny # mteb # pooch # ray @@ -875,11 +635,8 @@ rfc3987==1.3.8 rich==13.9.4 # via # genai-perf - # lightning # mteb # typer -rioxarray==0.19.0 - # via terratorch rouge-score==0.1.2 # via lm-eval rpds-py==0.20.1 @@ -888,8 +645,6 @@ rpds-py==0.20.1 # referencing rsa==4.9.1 # via google-auth -rtree==1.4.0 - # via torchgeo runai-model-streamer==0.11.0 # via -r requirements/test.in runai-model-streamer-s3==0.11.0 @@ -907,32 +662,21 @@ safetensors==0.4.5 # transformers schemathesis==3.39.15 # via -r requirements/test.in -scikit-image==0.25.2 - # via albumentations scikit-learn==1.5.2 # via - # albumentations # librosa # lm-eval - # mlflow # mteb # sentence-transformers scipy==1.13.1 # via - # albumentations # bm25s # librosa - # mlflow # mteb - # scikit-image # scikit-learn # 
sentence-transformers # statsmodels # vocos -segmentation-models-pytorch==0.4.0 - # via - # terratorch - # torchgeo sentence-transformers==3.2.1 # via # -r requirements/test.in @@ -940,28 +684,18 @@ sentence-transformers==3.2.1 sentencepiece==0.2.0 # via mistral-common setuptools==77.0.3 - # via - # lightning-utilities - # pytablewriter -shapely==2.1.1 - # via - # geopandas - # torchgeo + # via pytablewriter shellingham==1.5.4 # via typer six==1.16.0 # via # junit-xml - # lightly # opencensus # python-dateutil # rfc3339-validator # rouge-score - # segmentation-models-pytorch smart-open==7.1.0 # via ray -smmap==5.0.2 - # via gitdb sniffio==1.3.1 # via # anyio @@ -977,17 +711,10 @@ soxr==0.5.0.post1 # via # librosa # mistral-common -sqlalchemy==2.0.41 - # via - # alembic - # mlflow sqlitedict==2.1.0 # via lm-eval -sqlparse==0.5.3 - # via mlflow-skinny starlette==0.46.2 # via - # fastapi # schemathesis # starlette-testclient starlette-testclient==0.4.1 @@ -1010,18 +737,10 @@ tenacity==9.0.0 # via # lm-eval # plotly -tensorboardx==2.6.4 - # via lightning tensorizer==2.10.1 # via -r requirements/test.in -terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e - # via -r requirements/test.in threadpoolctl==3.5.0 # via scikit-learn -tifffile==2025.3.30 - # via - # scikit-image - # terratorch tiktoken==0.7.0 # via # lm-eval @@ -1030,9 +749,6 @@ timm==1.0.17 # via # -r requirements/test.in # open-clip-torch - # segmentation-models-pytorch - # terratorch - # torchgeo tokenizers==0.21.1 # via # -r requirements/test.in @@ -1041,7 +757,6 @@ toml==0.10.2 # via datamodel-code-generator tomli==2.2.1 # via - # alembic # black # pytest # schemathesis @@ -1052,27 +767,17 @@ torch==2.6.0+cpu # -r requirements/test.in # accelerate # bitsandbytes - # efficientnet-pytorch # encodec # fastsafetensors - # kornia - # lightly - # lightning # lm-eval # mteb # open-clip-torch # peft - # pretrainedmodels - # pytorch-lightning # runai-model-streamer - # segmentation-models-pytorch # sentence-transformers # tensorizer - # terratorch # timm # torchaudio - # torchgeo - # torchmetrics # torchvision # vector-quantize-pytorch # vocos @@ -1080,39 +785,21 @@ torchaudio==2.6.0+cpu # via # encodec # vocos -torchgeo==0.6.2 - # via terratorch -torchmetrics==1.7.4 - # via - # lightning - # pytorch-lightning - # terratorch - # torchgeo torchvision==0.21.0+cpu # via - # lightly # open-clip-torch - # pretrainedmodels - # segmentation-models-pytorch - # terratorch # timm - # torchgeo tqdm==4.66.6 # via # datasets # evaluate # huggingface-hub - # lightly - # lightning # lm-eval # mteb # nltk # open-clip-torch # peft # pqdm - # pretrainedmodels - # pytorch-lightning - # segmentation-models-pytorch # sentence-transformers # tqdm-multiprocess # transformers @@ -1141,23 +828,14 @@ typer==0.15.2 # via fastsafetensors types-python-dateutil==2.9.0.20241206 # via arrow -typeshed-client==2.8.2 - # via jsonargparse typing-extensions==4.12.2 # via - # albumentations - # alembic # anyio # black # exceptiongroup - # fastapi - # graphene # huggingface-hub # librosa - # lightning - # lightning-utilities # mistral-common - # mlflow-skinny # mteb # multidict # opentelemetry-api @@ -1167,14 +845,10 @@ typing-extensions==4.12.2 # pydantic # pydantic-core # pydantic-extra-types - # pytorch-lightning # rich - # sqlalchemy # torch # typer - # typeshed-client # typing-inspection - # uvicorn typing-inspection==0.4.1 # via pydantic tzdata==2024.2 @@ -1185,13 +859,9 @@ urllib3==2.2.3 # via # blobfile # botocore - # docker - 
# lightly # requests # responses # tritonclient -uvicorn==0.35.0 - # via mlflow-skinny vector-quantize-pytorch==1.21.2 # via -r requirements/test.in virtualenv==20.31.2 @@ -1203,15 +873,11 @@ wcwidth==0.2.13 webcolors==24.11.1 # via jsonschema werkzeug==3.1.3 - # via - # flask - # schemathesis + # via schemathesis word2number==1.1 # via lm-eval wrapt==1.17.2 # via smart-open -xarray==2025.6.1 - # via rioxarray xxhash==3.5.0 # via # datasets diff --git a/setup.py b/setup.py index 5a2b9ebf8..05e437f6f 100755 --- a/setup.py +++ b/setup.py @@ -22,12 +22,13 @@ try: from torch.utils.cpp_extension import MACA_HOME + USE_MACA = True except ImportError: MACA_HOME = None USE_MACA = False -CMAKE_EXECUTABLE = 'cmake' if not USE_MACA else 'cmake_maca' +CMAKE_EXECUTABLE = "cmake" if not USE_MACA else "cmake_maca" def load_module_from_path(module_name, path): @@ -43,12 +44,10 @@ def load_module_from_path(module_name, path): # cannot import envs directly because it depends on vllm, # which is not installed yet -envs = load_module_from_path('envs', - os.path.join(ROOT_DIR, 'vllm_metax', 'envs.py')) +envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "vllm_metax", "envs.py")) try: - vllm_dist_path = importlib.metadata.distribution("vllm").locate_file( - "vllm") + vllm_dist_path = importlib.metadata.distribution("vllm").locate_file("vllm") logger.info("detected vllm distribution path: %s", vllm_dist_path) except importlib.metadata.PackageNotFoundError: vllm_dist_path = None @@ -59,8 +58,11 @@ def load_module_from_path(module_name, path): VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE -if not (sys.platform.startswith("linux") or torch.version.cuda is None - or os.getenv("VLLM_TARGET_DEVICE") != "cuda"): +if not ( + sys.platform.startswith("linux") + or torch.version.cuda is None + or os.getenv("VLLM_TARGET_DEVICE") != "cuda" +): # if cuda or hip is not available and VLLM_TARGET_DEVICE is not set, # fallback to cpu raise AssertionError("Plugin only support cuda on linux platform. 
") @@ -93,8 +95,7 @@ def is_url_available(url: str) -> bool: class CMakeExtension(Extension): - - def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None: + def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None: super().__init__(name, sources=[], py_limited_api=True, **kwa) self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) @@ -142,36 +143,36 @@ def configure(self, ext: CMakeExtension) -> None: cfg = envs.CMAKE_BUILD_TYPE or default_cfg cmake_args = [ - '-DCMAKE_BUILD_TYPE={}'.format(cfg), - '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), + "-DCMAKE_BUILD_TYPE={}".format(cfg), + "-DVLLM_TARGET_DEVICE={}".format(VLLM_TARGET_DEVICE), ] verbose = envs.VERBOSE if verbose: - cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON'] + cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"] if is_sccache_available(): cmake_args += [ - '-DCMAKE_C_COMPILER_LAUNCHER=sccache', - '-DCMAKE_CXX_COMPILER_LAUNCHER=sccache', - '-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache', - '-DCMAKE_HIP_COMPILER_LAUNCHER=sccache', + "-DCMAKE_C_COMPILER_LAUNCHER=sccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache", + "-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache", + "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache", ] elif is_ccache_available(): cmake_args += [ - '-DCMAKE_C_COMPILER_LAUNCHER=ccache', - '-DCMAKE_CXX_COMPILER_LAUNCHER=ccache', - '-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache', - '-DCMAKE_HIP_COMPILER_LAUNCHER=ccache', + "-DCMAKE_C_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache", + "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache", ] # Pass the python executable to cmake so it can find an exact # match. - cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)] + cmake_args += ["-DVLLM_PYTHON_EXECUTABLE={}".format(sys.executable)] # Pass the python path to cmake so it can reuse the build dependencies # on subsequent calls to python. - cmake_args += ['-DVLLM_PYTHON_PATH={}'.format(":".join(sys.path))] + cmake_args += ["-DVLLM_PYTHON_PATH={}".format(":".join(sys.path))] # Override the base directory for FetchContent downloads to $ROOT/.deps # This allows sharing dependencies between profiles, @@ -179,7 +180,7 @@ def configure(self, ext: CMakeExtension) -> None: # To override this, set the FETCHCONTENT_BASE_DIR environment variable. fc_base_dir = os.path.join(ROOT_DIR, ".deps") fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir) - cmake_args += ['-DFETCHCONTENT_BASE_DIR={}'.format(fc_base_dir)] + cmake_args += ["-DFETCHCONTENT_BASE_DIR={}".format(fc_base_dir)] # # Setup parallelism and build tool @@ -187,13 +188,13 @@ def configure(self, ext: CMakeExtension) -> None: num_jobs, nvcc_threads = self.compute_num_jobs() if nvcc_threads: - cmake_args += ['-DNVCC_THREADS={}'.format(nvcc_threads)] + cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)] if is_ninja_available(): - build_tool = ['-G', 'Ninja'] + build_tool = ["-G", "Ninja"] cmake_args += [ - '-DCMAKE_JOB_POOL_COMPILE:STRING=compile', - '-DCMAKE_JOB_POOLS:STRING=compile={}'.format(num_jobs), + "-DCMAKE_JOB_POOL_COMPILE:STRING=compile", + "-DCMAKE_JOB_POOLS:STRING=compile={}".format(num_jobs), ] else: # Default build tool to whatever cmake picks. 
@@ -201,19 +202,20 @@ def configure(self, ext: CMakeExtension) -> None: # Make sure we use the nvcc from CUDA_HOME if _is_cuda() and not USE_MACA: - cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc'] + cmake_args += [f"-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc"] if USE_MACA: - cmake_args += ['-DUSE_MACA=1'] + cmake_args += ["-DUSE_MACA=1"] subprocess.check_call( [CMAKE_EXECUTABLE, ext.cmake_lists_dir, *build_tool, *cmake_args], - cwd=self.build_temp) + cwd=self.build_temp, + ) def build_extensions(self) -> None: # Ensure that CMake is present and working try: - subprocess.check_output([CMAKE_EXECUTABLE, '--version']) + subprocess.check_output([CMAKE_EXECUTABLE, "--version"]) except OSError as e: - raise RuntimeError('Cannot find CMake executable') from e + raise RuntimeError("Cannot find CMake executable") from e # Create build directory if it does not exist. if not os.path.exists(self.build_temp): @@ -222,8 +224,7 @@ def build_extensions(self) -> None: targets = [] def target_name(s: str) -> str: - return s.removeprefix("vllm_metax.").removeprefix( - "vllm_flash_attn.") + return s.removeprefix("vllm_metax.").removeprefix("vllm_flash_attn.") # Build all the extensions for ext in self.extensions: @@ -239,8 +240,7 @@ def target_name(s: str) -> str: *[f"--target={name}" for name in targets], ] - subprocess.check_call([CMAKE_EXECUTABLE, *build_args], - cwd=self.build_temp) + subprocess.check_call([CMAKE_EXECUTABLE, *build_args], cwd=self.build_temp) # Install the libraries for ext in self.extensions: @@ -254,14 +254,18 @@ def target_name(s: str) -> str: # CMake appends the extension prefix to the install path, # and outdir already contains that prefix, so we need to remove it. prefix = outdir - for _ in range(ext.name.count('.')): + for _ in range(ext.name.count(".")): prefix = prefix.parent # prefix here should actually be the same for all components install_args = [ - CMAKE_EXECUTABLE, "--install", ".", "--prefix", prefix, + CMAKE_EXECUTABLE, + "--install", + ".", + "--prefix", + prefix, "--component", - target_name(ext.name) + target_name(ext.name), ] subprocess.check_call(install_args, cwd=self.build_temp) @@ -280,34 +284,44 @@ def get_base_commit_in_main_branch(self) -> str: try: # Get the latest commit hash of the upstream main branch. - resp_json = subprocess.check_output([ - "curl", "-s", - "https://api.github.com/repos/vllm-project/vllm/commits/main" - ]).decode("utf-8") + resp_json = subprocess.check_output( + [ + "curl", + "-s", + "https://api.github.com/repos/vllm-project/vllm/commits/main", + ] + ).decode("utf-8") upstream_main_commit = json.loads(resp_json)["sha"] # Check if the upstream_main_commit exists in the local repo try: subprocess.check_output( - ["git", "cat-file", "-e", f"{upstream_main_commit}"]) + ["git", "cat-file", "-e", f"{upstream_main_commit}"] + ) except subprocess.CalledProcessError: # If not present, fetch it from the remote repository. # Note that this does not update any local branches, # but ensures that this commit ref and its history are # available in our local repo. - subprocess.check_call([ - "git", "fetch", "https://github.com/vllm-project/vllm", - "main" - ]) + subprocess.check_call( + ["git", "fetch", "https://github.com/vllm-project/vllm", "main"] + ) # Then get the commit hash of the current branch that is the same as # the upstream main commit. 
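            # For readability: the two subprocess calls below (old and new formatting
            # are equivalent) boil down to these shell steps, with hypothetical SHAs:
            #   git branch --show-current                      -> "my-feature"
            #   git merge-base <upstream_main_sha> my-feature  -> "89abcdef..."
            # The resulting merge-base is what the precompiled-wheel lookup is keyed on.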
- current_branch = subprocess.check_output( - ["git", "branch", "--show-current"]).decode("utf-8").strip() + current_branch = ( + subprocess.check_output(["git", "branch", "--show-current"]) + .decode("utf-8") + .strip() + ) - base_commit = subprocess.check_output([ - "git", "merge-base", f"{upstream_main_commit}", current_branch - ]).decode("utf-8").strip() + base_commit = ( + subprocess.check_output( + ["git", "merge-base", f"{upstream_main_commit}", current_branch] + ) + .decode("utf-8") + .strip() + ) return base_commit except ValueError as err: raise ValueError(err) from None @@ -315,12 +329,13 @@ def get_base_commit_in_main_branch(self) -> str: logger.warning( "Failed to get the base commit in the main branch. " "Using the nightly wheel. The libraries in this " - "wheel may not be compatible with your dev branch: %s", err) + "wheel may not be compatible with your dev branch: %s", + err, + ) return "nightly" def run(self) -> None: - assert _is_cuda( - ), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" + assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None) if wheel_location is None: @@ -357,7 +372,8 @@ def run(self) -> None: from setuptools.errors import SetupError raise SetupError( - f"Failed to get vLLM wheel from {wheel_location}") from e + f"Failed to get vLLM wheel from {wheel_location}" + ) from e with zipfile.ZipFile(wheel_path) as wheel: files_to_copy = [ @@ -368,20 +384,21 @@ def run(self) -> None: ] file_members = list( - filter(lambda x: x.filename in files_to_copy, wheel.filelist)) + filter(lambda x: x.filename in files_to_copy, wheel.filelist) + ) # vllm_flash_attn python code: # Regex from # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)` compiled_regex = re.compile( - r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") + r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py" + ) file_members += list( - filter(lambda x: compiled_regex.match(x.filename), - wheel.filelist)) + filter(lambda x: compiled_regex.match(x.filename), wheel.filelist) + ) for file in file_members: - print(f"Extracting and including {file.filename} " - "from existing wheel") + print(f"Extracting and including {file.filename} from existing wheel") package_name = os.path.dirname(file.filename).replace("/", ".") file_name = os.path.basename(file.filename) @@ -398,27 +415,28 @@ def run(self) -> None: def _is_cuda() -> bool: has_cuda = torch.version.cuda is not None - return (VLLM_TARGET_DEVICE == "cuda" and has_cuda) + return VLLM_TARGET_DEVICE == "cuda" and has_cuda def _build_custom_ops() -> bool: return _is_cuda() + def get_maca_version() -> Version: """ Returns the MACA SDK Version """ - file_full_path = os.path.join(os.getenv('MACA_PATH'), 'Version.txt') + file_full_path = os.path.join(os.getenv("MACA_PATH"), "Version.txt") if not os.path.isfile(file_full_path): return None - with open(file_full_path, encoding='utf-8') as file: + with open(file_full_path, encoding="utf-8") as file: first_line = file.readline().strip() return parse(first_line.split(":")[-1]) def fixed_version_scheme(version: ScmVersion) -> str: - return "0.11.0" + return "0.11.1" def always_hash(version: ScmVersion) -> str: @@ -426,17 +444,20 @@ def always_hash(version: ScmVersion) -> str: Always include short commit hash and current date (YYYYMMDD) """ from datetime import datetime + date_str = datetime.now().strftime("%Y%m%d") if version.node is not None: short_hash = version.node[:7] # short commit id - return 
f"g{short_hash}.d{date_str}" + return f"{short_hash}.d{date_str}" return f"unknown.{date_str}" def get_vllm_version() -> str: - version = get_version(version_scheme=fixed_version_scheme, - local_scheme=always_hash, - write_to="vllm_metax/_version.py") + version = get_version( + version_scheme=fixed_version_scheme, + local_scheme=always_hash, + write_to="vllm_metax/_version.py", + ) sep = "+" if "+" not in version else "." # dev versions might contain + if _is_cuda(): @@ -464,8 +485,11 @@ def _read_requirements(filename: str) -> list[str]: for line in requirements: if line.startswith("-r "): resolved_requirements += _read_requirements(line.split()[1]) - elif not line.startswith("--") and not line.startswith( - "#") and line.strip() != "": + elif ( + not line.startswith("--") + and not line.startswith("#") + and line.strip() != "" + ): resolved_requirements.append(line) return resolved_requirements @@ -474,7 +498,7 @@ def _read_requirements(filename: str) -> list[str]: cuda_major, cuda_minor = torch.version.cuda.split(".") modified_requirements = [] for req in requirements: - if ("vllm-flash-attn" in req and cuda_major != "12"): + if "vllm-flash-attn" in req and cuda_major != "12": # vllm-flash-attn is built only for CUDA 12.x. # Skip for other versions. continue @@ -482,8 +506,8 @@ def _read_requirements(filename: str) -> list[str]: requirements = modified_requirements else: raise ValueError( - "Unsupported platform, please use CUDA, ROCm, Neuron, HPU, " - "or CPU.") + "Unsupported platform, please use CUDA, ROCm, Neuron, HPU, or CPU." + ) return requirements @@ -507,7 +531,6 @@ def _read_requirements(filename: str) -> list[str]: class custom_install(install): - def _copy_with_backup(self, src_path: Path, dest_path: Path): """ Copy a file or directory from src_path to dest_path. 
@@ -525,15 +548,11 @@ def _copy_with_backup(self, src_path: Path, dest_path: Path): # Backup if target path already exists (file or dir) if os.path.exists(dest_full_path): - backup_path = dest_full_path.parent / (dest_full_path.name + - ".bak") - logger.debug( - f"{dest_full_path} exists, backing it up to {backup_path}") + backup_path = dest_full_path.parent / (dest_full_path.name + ".bak") + logger.debug(f"{dest_full_path} exists, backing it up to {backup_path}") if os.path.exists(backup_path): - logger.debug( - f"Backup path {backup_path} already exists, removing it.") - if os.path.isdir( - backup_path) and not os.path.islink(backup_path): + logger.debug(f"Backup path {backup_path} already exists, removing it.") + if os.path.isdir(backup_path) and not os.path.islink(backup_path): shutil.rmtree(backup_path) else: os.remove(backup_path) @@ -560,15 +579,9 @@ def run(self): return files_to_copy = { - "vllm_metax/_C.abi3.so": - vllm_dist_path, - "vllm_metax/_moe_C.abi3.so": - vllm_dist_path, - "vllm_metax/cumem_allocator.abi3.so": - vllm_dist_path, # for get_available_device: set cuda - "vllm_metax/patch/vllm_substitution/utils.py": - vllm_dist_path / "model_executor/layers/fla/ops/utils.py", + "vllm_metax/patch/vllm_substitution/utils.py": vllm_dist_path + / "model_executor/layers/fla/ops/utils.py", } for src_path, dest_path in files_to_copy.items(): @@ -580,9 +593,8 @@ def run(self): cmdclass = {} else: cmdclass = { - "build_ext": - repackage_wheel if envs.VLLM_USE_PRECOMPILED else cmake_build_ext, - "install": custom_install + "build_ext": repackage_wheel if envs.VLLM_USE_PRECOMPILED else cmake_build_ext, + # "install": custom_install, } setup( @@ -596,7 +608,7 @@ def run(self): "fastsafetensors": ["fastsafetensors >= 0.1.10"], "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"], "audio": ["librosa", "soundfile"], # Required for audio processing - "video": [] # Kept for backwards compatibility + "video": [], # Kept for backwards compatibility }, cmdclass=cmdclass, package_data=package_data, diff --git a/tools/check_spdx_header.py b/tools/check_spdx_header.py index 5b242b8b2..d109a8b9b 100644 --- a/tools/check_spdx_header.py +++ b/tools/check_spdx_header.py @@ -7,6 +7,7 @@ class SPDXStatus(Enum): """SPDX header status enumeration""" + EMPTY = "empty" # empty __init__.py COMPLETE = "complete" MISSING_LICENSE = "missing_license" # Only has copyright line @@ -16,7 +17,8 @@ class SPDXStatus(Enum): FULL_SPDX_HEADER = ( "# SPDX-License-Identifier: Apache-2.0\n" - "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project") + "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project" +) LICENSE_LINE = "# SPDX-License-Identifier: Apache-2.0" COPYRIGHT_LINE = "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project" # noqa: E501 @@ -58,8 +60,7 @@ def check_spdx_header_status(file_path): # else: # # Completely missing both lines # return SPDXStatus.MISSING_BOTH - return (SPDXStatus.COMPLETE - if has_license else SPDXStatus.MISSING_LICENSE) + return SPDXStatus.COMPLETE if has_license else SPDXStatus.MISSING_LICENSE def add_header(file_path, status): @@ -128,8 +129,9 @@ def main(): continue # Collect all files that need fixing - all_files_to_fix = (files_missing_both + files_missing_copyright + - files_missing_license) + all_files_to_fix = ( + files_missing_both + files_missing_copyright + files_missing_license + ) if all_files_to_fix: print("The following files are missing the SPDX header:") if files_missing_both: diff --git 
a/tools/check_triton_import.py b/tools/check_triton_import.py deleted file mode 100644 index 6d7144bb5..000000000 --- a/tools/check_triton_import.py +++ /dev/null @@ -1,86 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import subprocess -import sys - -import regex as re - -FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)") - -# the way allowed to import triton -ALLOWED_LINES = { - "from vllm.triton_utils import triton", "from vllm.triton_utils import tl", - "from vllm.triton_utils import tl, triton", "from from triton.testing" -} - -ALLOWED_FILES = {"vllm/triton_utils/importing.py"} - - -def is_allowed_file(current_file: str) -> bool: - return current_file in ALLOWED_FILES - - -def is_forbidden_import(line: str) -> bool: - stripped = line.strip() - return bool( - FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES - - -def parse_diff(diff: str) -> list[str]: - violations = [] - current_file = None - current_lineno = None - skip_allowed_file = False - - for line in diff.splitlines(): - if line.startswith("+++ b/"): - current_file = line[6:] - skip_allowed_file = is_allowed_file(current_file) - elif skip_allowed_file: - continue - elif line.startswith("@@"): - match = re.search(r"\+(\d+)", line) - if match: - current_lineno = int( - match.group(1)) - 1 # next "+ line" is here - elif line.startswith("+") and not line.startswith("++"): - current_lineno += 1 - code_line = line[1:] - if is_forbidden_import(code_line): - violations.append( - f"{current_file}:{current_lineno}: {code_line.strip()}") - return violations - - -def get_diff(diff_type: str) -> str: - if diff_type == "staged": - return subprocess.check_output( - ["git", "diff", "--cached", "--unified=0"], text=True) - elif diff_type == "unstaged": - return subprocess.check_output(["git", "diff", "--unified=0"], - text=True) - else: - raise ValueError(f"Unknown diff_type: {diff_type}") - - -def main(): - all_violations = [] - for diff_type in ["staged", "unstaged"]: - try: - diff_output = get_diff(diff_type) - violations = parse_diff(diff_output) - all_violations.extend(violations) - except subprocess.CalledProcessError as e: - print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr) - - if all_violations: - print("❌ Forbidden direct `import triton` detected." - " ➤ Use `from vllm.triton_utils import triton` instead.\n") - for v in all_violations: - print(f"❌ {v}") - return 1 - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/tools/generate_cmake_presets.py b/tools/generate_cmake_presets.py index 5f92f2f58..85847c2c0 100644 --- a/tools/generate_cmake_presets.py +++ b/tools/generate_cmake_presets.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse import json import multiprocessing import os @@ -11,8 +12,7 @@ # most reliable source of truth for vLLM's build. from torch.utils.cpp_extension import CUDA_HOME except ImportError: - print("Warning: PyTorch not found. " - "Falling back to CUDA_HOME environment variable.") + print("Warning: PyTorch not found. 
Falling back to CUDA_HOME environment variable.") CUDA_HOME = os.environ.get("CUDA_HOME") @@ -26,7 +26,7 @@ def get_cpu_cores(): return multiprocessing.cpu_count() -def generate_presets(output_path="CMakeUserPresets.json"): +def generate_presets(output_path="CMakeUserPresets.json", force_overwrite=False): """Generates the CMakeUserPresets.json file.""" print("Attempting to detect your system configuration...") @@ -37,8 +37,7 @@ def generate_presets(output_path="CMakeUserPresets.json"): prospective_path = os.path.join(CUDA_HOME, "bin", "nvcc") if os.path.exists(prospective_path): nvcc_path = prospective_path - print("Found nvcc via torch.utils.cpp_extension.CUDA_HOME: " - f"{nvcc_path}") + print(f"Found nvcc via torch.utils.cpp_extension.CUDA_HOME: {nvcc_path}") if not nvcc_path: nvcc_path = which("nvcc") @@ -48,7 +47,8 @@ def generate_presets(output_path="CMakeUserPresets.json"): if not nvcc_path: nvcc_path_input = input( "Could not automatically find 'nvcc'. Please provide the full " - "path to nvcc (e.g., /usr/local/cuda/bin/nvcc): ") + "path to nvcc (e.g., /usr/local/cuda/bin/nvcc): " + ) nvcc_path = nvcc_path_input.strip() print(f"Using NVCC path: {nvcc_path}") @@ -61,12 +61,13 @@ def generate_presets(output_path="CMakeUserPresets.json"): "Could not automatically find Python executable. Please provide " "the full path to your Python executable for vLLM development " "(typically from your virtual environment, e.g., " - "/home/user/venvs/vllm/bin/python): ") + "/home/user/venvs/vllm/bin/python): " + ) python_executable = input(python_executable_prompt).strip() if not python_executable: raise ValueError( - "Could not determine Python executable. Please provide it " - "manually.") + "Could not determine Python executable. Please provide it manually." + ) print(f"Using Python executable: {python_executable}") @@ -74,20 +75,23 @@ def generate_presets(output_path="CMakeUserPresets.json"): cpu_cores = get_cpu_cores() nvcc_threads = min(4, cpu_cores) cmake_jobs = max(1, cpu_cores // nvcc_threads) - print(f"Detected {cpu_cores} CPU cores. " - f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}.") + print( + f"Detected {cpu_cores} CPU cores. " + f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}." + ) # Get vLLM project root (assuming this script is in vllm/tools/) - project_root = os.path.abspath( - os.path.join(os.path.dirname(__file__), "..")) + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) print(f"VLLM project root detected as: {project_root}") # Ensure python_executable path is absolute or resolvable if not os.path.isabs(python_executable) and which(python_executable): python_executable = os.path.abspath(which(python_executable)) elif not os.path.isabs(python_executable): - print(f"Warning: Python executable '{python_executable}' is not an " - "absolute path and not found in PATH. CMake might not find it.") + print( + f"Warning: Python executable '{python_executable}' is not an " + "absolute path and not found in PATH. CMake might not find it." + ) cache_variables = { "CMAKE_CUDA_COMPILER": nvcc_path, @@ -120,50 +124,57 @@ def generate_presets(output_path="CMakeUserPresets.json"): configure_preset["generator"] = "Ninja" cache_variables["CMAKE_JOB_POOLS"] = f"compile={cmake_jobs}" else: - print("Ninja not found, using default generator. " - "Build may be slower.") + print("Ninja not found, using default generator. 
Build may be slower.") presets = { - "version": - 6, + "version": 6, # Keep in sync with CMakeLists.txt and requirements/build.txt - "cmakeMinimumRequired": { - "major": 3, - "minor": 26, - "patch": 1 - }, + "cmakeMinimumRequired": {"major": 3, "minor": 26, "patch": 1}, "configurePresets": [configure_preset], - "buildPresets": [{ - "name": "release", - "configurePreset": "release", - "jobs": cmake_jobs, - }], + "buildPresets": [ + { + "name": "release", + "configurePreset": "release", + "jobs": cmake_jobs, + } + ], } output_file_path = os.path.join(project_root, output_path) if os.path.exists(output_file_path): - overwrite = input( - f"'{output_file_path}' already exists. Overwrite? (y/N): ").strip( - ).lower() - if overwrite != 'y': - print("Generation cancelled.") - return + if force_overwrite: + print(f"Overwriting existing file '{output_file_path}'") + else: + overwrite = ( + input(f"'{output_file_path}' already exists. Overwrite? (y/N): ") + .strip() + .lower() + ) + if overwrite != "y": + print("Generation cancelled.") + return try: with open(output_file_path, "w") as f: json.dump(presets, f, indent=4) print(f"Successfully generated '{output_file_path}'") print("\nTo use this preset:") - print( - f"1. Ensure you are in the vLLM root directory: cd {project_root}") + print(f"1. Ensure you are in the vLLM root directory: cd {project_root}") print("2. Initialize CMake: cmake --preset release") - print("3. Build+install: cmake --build --preset release " - "--target install") + print("3. Build+install: cmake --build --preset release --target install") except OSError as e: print(f"Error writing file: {e}") if __name__ == "__main__": - generate_presets() + parser = argparse.ArgumentParser() + parser.add_argument( + "--force-overwrite", + action="store_true", + help="Force overwrite existing CMakeUserPresets.json without prompting", + ) + + args = parser.parse_args() + generate_presets(force_overwrite=args.force_overwrite) diff --git a/tools/install_nixl.sh b/tools/install_nixl.sh deleted file mode 100755 index 56717cfb7..000000000 --- a/tools/install_nixl.sh +++ /dev/null @@ -1,109 +0,0 @@ -#!/bin/bash -# Usage: ./install_nixl.sh [--force] - -FORCE=false -if [ "$1" == "--force" ]; then - FORCE=true -fi - -SUDO=false -if command -v sudo >/dev/null 2>&1 && sudo -n true 2>/dev/null; then - SUDO=true -fi - -ARCH=$(uname -m) - -ROOT_DIR="/usr/local" -mkdir -p "$ROOT_DIR" -GDR_HOME="$ROOT_DIR/gdrcopy" -UCX_HOME="$ROOT_DIR/ucx" -NIXL_HOME="$ROOT_DIR/nixl" -CUDA_HOME=/usr/local/cuda - -export PATH="$GDR_HOME/bin:$UCX_HOME/bin:$NIXL_HOME/bin:$PATH" -export LD_LIBRARY_PATH="$GDR_HOME/lib:$UCX_HOME/lib:$NIXL_HOME/lib/$ARCH-linux-gnu:$LD_LIBRARY_PATH" - -TEMP_DIR="nixl_installer" -mkdir -p "$TEMP_DIR" -cd "$TEMP_DIR" - -pip install meson ninja pybind11 - -if [ ! -e "/dev/gdrdrv" ] || [ "$FORCE" = true ]; then - echo "Installing gdrcopy\n" - wget https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v2.5.tar.gz - tar xzf v2.5.tar.gz; rm v2.5.tar.gz - cd gdrcopy-2.5 - make prefix=$GDR_HOME CUDA=$CUDA_HOME all install - - if $SUDO; then - echo "Running insmod.sh with sudo" - sudo ./insmod.sh - else - echo "Skipping insmod.sh - sudo not available" - echo "Please run 'sudo ./gdrcopy-2.5/insmod.sh' manually if needed" - fi - - cd .. -else - echo "Found /dev/gdrdrv. Skipping gdrcopy installation" -fi - -if ! 
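For a quick sanity check after generation, the emitted file can be read back directly; a small sketch assuming the default CMakeUserPresets.json at the repository root:

```python
# Inspect the generated presets file; "configurePresets" / "buildPresets"
# mirror the dict assembled above.
import json

with open("CMakeUserPresets.json") as f:
    presets = json.load(f)

print(json.dumps(presets["configurePresets"][0], indent=2))
print("build jobs:", presets["buildPresets"][0]["jobs"])
```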
command -v ucx_info &> /dev/null || [ "$FORCE" = true ]; then - echo "Installing UCX" - wget https://github.com/openucx/ucx/releases/download/v1.18.0/ucx-1.18.0.tar.gz - tar xzf ucx-1.18.0.tar.gz; rm ucx-1.18.0.tar.gz - cd ucx-1.18.0 - - # Checking Mellanox NICs - MLX_OPTS="" - if lspci | grep -i mellanox > /dev/null || command -v ibstat > /dev/null; then - echo "Mellanox NIC detected, adding Mellanox-specific options" - MLX_OPTS="--with-rdmacm \ - --with-mlx5-dv \ - --with-ib-hw-tm" - fi - - ./configure --prefix=$UCX_HOME \ - --enable-shared \ - --disable-static \ - --disable-doxygen-doc \ - --enable-optimizations \ - --enable-cma \ - --enable-devel-headers \ - --with-cuda=$CUDA_HOME \ - --with-dm \ - --with-gdrcopy=$GDR_HOME \ - --with-verbs \ - --enable-mt \ - $MLX_OPTS - make -j - make -j install-strip - - if $SUDO; then - echo "Running ldconfig with sudo" - sudo ldconfig - else - echo "Skipping ldconfig - sudo not available" - echo "Please run 'sudo ldconfig' manually if needed" - fi - - cd .. -else - echo "Found existing UCX. Skipping UCX installation" -fi - -if ! command -v nixl_test &> /dev/null || [ "$FORCE" = true ]; then - echo "Installing NIXL" - wget https://github.com/ai-dynamo/nixl/archive/refs/tags/0.2.0.tar.gz - tar xzf 0.2.0.tar.gz; rm 0.2.0.tar.gz - cd nixl-0.2.0 - meson setup build --prefix=$NIXL_HOME -Ducx_path=$UCX_HOME - cd build - ninja - ninja install - - cd ../.. -else - echo "Found existing NIXL. Skipping NIXL installation" -fi diff --git a/tools/mypy.sh b/tools/mypy.sh deleted file mode 100755 index e0f92e55c..000000000 --- a/tools/mypy.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash - -CI=${1:-0} -PYTHON_VERSION=${2:-local} - -if [ "$CI" -eq 1 ]; then - set -e -fi - -if [ $PYTHON_VERSION == "local" ]; then - PYTHON_VERSION=$(python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') -fi - -run_mypy() { - echo "Running mypy on $1" - if [ "$CI" -eq 1 ] && [ -z "$1" ]; then - mypy --python-version "${PYTHON_VERSION}" "$@" - return - fi - mypy --follow-imports skip --python-version "${PYTHON_VERSION}" "$@" -} - -run_mypy # Note that this is less strict than CI -# run_mypy tests -# run_mypy examples - diff --git a/tools/enforce_regex_import.py b/tools/pre_commit/enforce_regex_import.py similarity index 73% rename from tools/enforce_regex_import.py rename to tools/pre_commit/enforce_regex_import.py index 63ceee582..a29952e92 100644 --- a/tools/enforce_regex_import.py +++ b/tools/pre_commit/enforce_regex_import.py @@ -1,30 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import subprocess from pathlib import Path import regex as re -FORBIDDEN_PATTERNS = re.compile( - r'^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)') +FORBIDDEN_PATTERNS = re.compile(r"^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)") ALLOWED_PATTERNS = [ - re.compile(r'^\s*import\s+regex\s+as\s+re\s*$'), - re.compile(r'^\s*import\s+regex\s*$'), + re.compile(r"^\s*import\s+regex\s+as\s+re\s*$"), + re.compile(r"^\s*import\s+regex\s*$"), ] def get_staged_python_files() -> list[str]: try: result = subprocess.run( - ['git', 'diff', '--cached', '--name-only', '--diff-filter=AM'], + ["git", "diff", "--cached", "--name-only", "--diff-filter=AM"], capture_output=True, text=True, - check=True) - files = result.stdout.strip().split( - '\n') if result.stdout.strip() else [] - return [f for f in files if f.endswith('.py')] + check=True, + ) + files = result.stdout.strip().split("\n") 
if result.stdout.strip() else [] + return [f for f in files if f.endswith(".py")] except subprocess.CalledProcessError: return [] @@ -33,13 +30,14 @@ def is_forbidden_import(line: str) -> bool: line = line.strip() return bool( FORBIDDEN_PATTERNS.match(line) - and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS)) + and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS) + ) def check_file(filepath: str) -> list[tuple[int, str]]: violations = [] try: - with open(filepath, encoding='utf-8') as f: + with open(filepath, encoding="utf-8") as f: for line_num, line in enumerate(f, 1): if is_forbidden_import(line): violations.append((line_num, line.strip())) @@ -72,9 +70,7 @@ def main() -> int: if total_violations > 0: print(f"\n💡 Found {total_violations} violation(s).") print("❌ Please replace 'import re' with 'import regex as re'") - print( - " Also replace 'from re import ...' with 'from regex import ...'" - ) # noqa: E501 + print(" Also replace 'from re import ...' with 'from regex import ...'") # noqa: E501 print("✅ Allowed imports:") print(" - import regex as re") print(" - import regex") # noqa: E501 diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py new file mode 100644 index 000000000..6e4d92ad2 --- /dev/null +++ b/tools/pre_commit/mypy.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Run mypy on changed files. + +This script is designed to be used as a pre-commit hook. It runs mypy +on files that have been changed. It groups files into different mypy calls +based on their directory to avoid import following issues. + +Usage: + python tools/pre_commit/mypy.py + +Args: + ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to + "silent" for the main group of files. + python_version: Python version to use (e.g., "3.10") or "local" to use + the local Python version. + changed_files: List of changed files to check. +""" + +import subprocess +import sys + +import regex as re + +FILES = [ + "vllm_metax/*.py", + "vllm_metax/assets", + "vllm_metax/distributed", + "vllm_metax/entrypoints", + "vllm_metax/executor", + "vllm_metax/inputs", + "vllm_metax/logging_utils", + "vllm_metax/multimodal", + "vllm_metax/platforms", + "vllm_metax/transformers_utils", + "vllm_metax/triton_utils", + "vllm_metax/usage", + "vllm_metax/v1/core", + "vllm_metax/v1/engine", +] + +# After fixing errors resulting from changing follow_imports +# from "skip" to "silent", move the following directories to FILES +SEPARATE_GROUPS = [ + "tests", + # v0 related + "vllm_metax/attention", + "vllm_metax/compilation", + "vllm_metax/engine", + "vllm_metax/inputs", + "vllm_metax/lora", + "vllm_metax/model_executor", + "vllm_metax/plugins", + "vllm_metax/worker", + # v1 related + "vllm_metax/v1/attention", + "vllm_metax/v1/executor", + "vllm_metax/v1/kv_offload", + "vllm_metax/v1/metrics", + "vllm_metax/v1/pool", + "vllm_metax/v1/sample", + "vllm_metax/v1/spec_decode", + "vllm_metax/v1/structured_output", + "vllm_metax/v1/worker", +] + +# TODO(woosuk): Include the code from Megatron and HuggingFace. +EXCLUDE = [ + "vllm_metax/model_executor/parallel_utils", + "vllm_metax/model_executor/models", + "vllm_metax/model_executor/layers/fla/ops", + # Ignore triton kernels in ops. + "vllm_metax/attention/ops", +] + + +def group_files(changed_files: list[str]) -> dict[str, list[str]]: + """ + Group changed files into different mypy calls. + + Args: + changed_files: List of changed files. 
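Given the argument layout documented above, a manual run equivalent to what the hook does looks like this (file paths are hypothetical; "0" means not-CI, "local" means use the interpreter's own version):

```python
# Invoke the hook script directly from the repository root.
import subprocess
import sys

changed = ["vllm_metax/platforms/maca.py", "vllm_metax/v1/worker/gpu_worker.py"]
result = subprocess.run(
    [sys.executable, "tools/pre_commit/mypy.py", "0", "local", *changed],
    check=False,
)
print("mypy hook exit code:", result.returncode)
```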
+ + Returns: + A dictionary mapping file group names to lists of changed files. + """ + exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*") + files_pattern = re.compile(f"^({'|'.join(FILES)}).*") + file_groups = {"": []} + file_groups.update({k: [] for k in SEPARATE_GROUPS}) + for changed_file in changed_files: + # Skip files which should be ignored completely + if exclude_pattern.match(changed_file): + continue + # Group files by mypy call + if files_pattern.match(changed_file): + file_groups[""].append(changed_file) + continue + else: + for directory in SEPARATE_GROUPS: + if re.match(f"^{directory}.*", changed_file): + file_groups[directory].append(changed_file) + break + return file_groups + + +def mypy( + targets: list[str], + python_version: str | None, + follow_imports: str | None, + file_group: str, +) -> int: + """ + Run mypy on the given targets. + + Args: + targets: List of files or directories to check. + python_version: Python version to use (e.g., "3.10") or None to use + the default mypy version. + follow_imports: Value for the --follow-imports option or None to use + the default mypy behavior. + file_group: The file group name for logging purposes. + + Returns: + The return code from mypy. + """ + args = ["mypy"] + if python_version is not None: + args += ["--python-version", python_version] + if follow_imports is not None: + args += ["--follow-imports", follow_imports] + print(f"$ {' '.join(args)} {file_group}") + return subprocess.run(args + targets, check=False).returncode + + +def main(): + ci = sys.argv[1] == "1" + python_version = sys.argv[2] + file_groups = group_files(sys.argv[3:]) + + if python_version == "local": + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + + returncode = 0 + for file_group, changed_files in file_groups.items(): + follow_imports = None if ci and file_group == "" else "skip" + if changed_files: + returncode |= mypy( + changed_files, python_version, follow_imports, file_group + ) + return returncode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tools/shellcheck.sh b/tools/pre_commit/shellcheck.sh similarity index 100% rename from tools/shellcheck.sh rename to tools/pre_commit/shellcheck.sh diff --git a/tools/profiler/print_layerwise_table.py b/tools/profiler/print_layerwise_table.py index 209c3a576..d7a24a598 100644 --- a/tools/profiler/print_layerwise_table.py +++ b/tools/profiler/print_layerwise_table.py @@ -29,48 +29,50 @@ def get_entries(node, curr_depth=0): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--json-trace", - type=str, - required=True, - help="json trace file output by " - "examples/offline_inference/profiling.py") - parser.add_argument("--phase", - type=str, - required=True, - help="The phase to print the table for. This is either" - "prefill or decode_n, where n is the decode step " - "number") - parser.add_argument("--table", - type=str, - choices=["summary", "model"], - default="summary", - help="Which table to print, the summary table or the " - "layerwise model table") + parser.add_argument( + "--json-trace", + type=str, + required=True, + help="json trace file output by examples/offline_inference/profiling.py", + ) + parser.add_argument( + "--phase", + type=str, + required=True, + help="The phase to print the table for. 
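To make the grouping concrete, here is what group_files returns for a few hypothetical paths (a sketch; it assumes the repository root is on sys.path so tools.pre_commit resolves as a namespace package):

```python
from tools.pre_commit.mypy import group_files

groups = group_files([
    "vllm_metax/platforms/maca.py",             # matches FILES           -> "" group
    "vllm_metax/v1/worker/gpu_worker.py",       # matches SEPARATE_GROUPS -> own call
    "vllm_metax/model_executor/models/foo.py",  # matches EXCLUDE         -> dropped
])
assert groups[""] == ["vllm_metax/platforms/maca.py"]
assert groups["vllm_metax/v1/worker"] == ["vllm_metax/v1/worker/gpu_worker.py"]
```

Only the main ("") group changes behaviour in CI (per the module docstring above); every SEPARATE_GROUPS directory is always checked in its own mypy call with --follow-imports skip.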
This is either" + "prefill or decode_n, where n is the decode step " + "number", + ) + parser.add_argument( + "--table", + type=str, + choices=["summary", "model"], + default="summary", + help="Which table to print, the summary table or the layerwise model table", + ) args = parser.parse_args() with open(args.json_trace) as f: profile_data = json.load(f) - assert args.phase in profile_data, \ - (f"Cannot find phase {args.phase} in profile data. Choose one among" - f'{[x for x in profile_data.keys() if "prefill" in x or "decode" in x]}') #noqa + assert args.phase in profile_data, ( + f"Cannot find phase {args.phase} in profile data. Choose one among" + f"{[x for x in profile_data if 'prefill' in x or 'decode' in x]}" + ) # noqa if args.table == "summary": entries_and_depths = flatten_entries( - SummaryStatsEntry, profile_data[args.phase]["summary_stats"]) - column_widths = dict(name=80, - cuda_time_us=12, - pct_cuda_time=12, - invocations=15) + SummaryStatsEntry, profile_data[args.phase]["summary_stats"] + ) + column_widths = dict(name=80, cuda_time_us=12, pct_cuda_time=12, invocations=15) elif args.table == "model": entries_and_depths = flatten_entries( - ModelStatsEntry, profile_data[args.phase]["model_stats"]) - column_widths = dict(name=60, - cpu_time_us=12, - cuda_time_us=12, - pct_cuda_time=12, - trace=60) + ModelStatsEntry, profile_data[args.phase]["model_stats"] + ) + column_widths = dict( + name=60, cpu_time_us=12, cuda_time_us=12, pct_cuda_time=12, trace=60 + ) # indent entry names based on the depth entries = [] @@ -78,7 +80,8 @@ def get_entries(node, curr_depth=0): entry.name = indent_string( entry.name, indent=depth, - indent_style=lambda indent: "|" + "-" * indent + " ") + indent_style=lambda indent: "|" + "-" * indent + " ", + ) entries.append(entry) TablePrinter(type(entries[0]), column_widths).print_table(entries) diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py index 038d3c44f..ed4bf0beb 100644 --- a/tools/profiler/visualize_layerwise_profile.py +++ b/tools/profiler/visualize_layerwise_profile.py @@ -7,7 +7,7 @@ import math import os from pathlib import Path -from typing import Any, Optional +from typing import Any import matplotlib.pyplot as plt import pandas as pd @@ -18,17 +18,18 @@ def largest_dist_from_leaf(node: dict, depth: int = 0): if len(node["children"]) == 0: return depth - return max([ - largest_dist_from_leaf(child, depth=depth + 1) - for child in node["children"] - ]) - - -def get_entries_at_depth(depth: int, - entries_and_traces: list[tuple[Any, Any]], - node: dict, - curr_depth: int = 0, - trace=()): + return max( + [largest_dist_from_leaf(child, depth=depth + 1) for child in node["children"]] + ) + + +def get_entries_at_depth( + depth: int, + entries_and_traces: list[tuple[Any, Any]], + node: dict, + curr_depth: int = 0, + trace=(), +): # assert that the query is at kernel or module level assert depth == -1 or depth == -2 @@ -40,21 +41,18 @@ def get_entries_at_depth(depth: int, if largest_dist_from_leaf(node) == (abs(depth) - 1): entries_and_traces.append((node["entry"], trace)) - trace = (node["entry"]["name"], ) + trace + trace = (node["entry"]["name"],) + trace for child in node["children"]: - get_entries_at_depth(depth, - entries_and_traces, - child, - curr_depth=curr_depth + 1, - trace=trace) + get_entries_at_depth( + depth, entries_and_traces, child, curr_depth=curr_depth + 1, trace=trace + ) def fold_nodes(root: dict, nodes_to_fold: list[str]): - stack: list[dict] = [root] while len(stack) 
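Because --phase must name a key that actually exists in the trace, listing the valid values first can save a round trip; a short sketch (the trace path is hypothetical):

```python
# Phases recorded by examples/offline_inference/profiling.py are the
# valid values for --phase ("prefill" or "decode_n").
import json

with open("profile.json") as f:
    profile_data = json.load(f)

print([k for k in profile_data if "prefill" in k or "decode" in k])
# e.g. ['prefill', 'decode_1', 'decode_2', ...]
```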
!= 0: node = stack.pop() - if node['entry']['name'] in nodes_to_fold: + if node["entry"]["name"] in nodes_to_fold: node["children"] = [] continue for child in node["children"]: @@ -76,9 +74,7 @@ def trim_string_back(string: str, width: int) -> str: def shorten_plot_legend_strings(legend, max_char_len: int): for t in legend.get_texts(): - t.set_text( - trim_string_back(abbreviate_known_names(t.get_text()), - max_char_len)) + t.set_text(trim_string_back(abbreviate_known_names(t.get_text()), max_char_len)) def abbreviate_known_names(name: str) -> str: @@ -108,50 +104,54 @@ def all_the_same(items) -> bool: names.add(entry["name"]) for name in non_unique_names: - entries_and_traces_with_name = [(entry, trace) - for entry, trace in entries_and_traces - if entry["name"] == name] + entries_and_traces_with_name = [ + (entry, trace) + for entry, trace in entries_and_traces + if entry["name"] == name + ] - zipped_traces = list( - zip(*[trace for _, trace in entries_and_traces_with_name])) + zipped_traces = list(zip(*[trace for _, trace in entries_and_traces_with_name])) first_trace_difference = next( - (i for i, trace_eles in enumerate(zipped_traces) - if not all_the_same(trace_eles)), None) + ( + i + for i, trace_eles in enumerate(zipped_traces) + if not all_the_same(trace_eles) + ), + None, + ) if first_trace_difference is None: - # can't create a unique name, leave them names as the + # can't create a unique name, leave the names as they # are they will get aggregated by the pivot_table call continue for entry, trace in entries_and_traces_with_name: - entry["name"] = " <- ".join((entry["name"], ) + - trace[:first_trace_difference + 1]) + entry["name"] = " <- ".join( + (entry["name"],) + trace[: first_trace_difference + 1] + ) ## Operation grouping utils #### -''' +""" Group operations in the given dataframe by some high-level ops like, - gemms - attention - rms_norm etc. 
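The de-duplication above is easier to see outside the diff: kernels sharing a name are suffixed with their call trace up to the first element that differs. A standalone illustration with made-up module names:

```python
name = "elementwise_kernel"
traces = [("LlamaMLP", "SiluAndMul"), ("LlamaAttention", "RotaryEmbedding")]

zipped = list(zip(*traces))  # columns of trace elements, position by position
first_diff = next(i for i, elems in enumerate(zipped) if len(set(elems)) > 1)
for trace in traces:
    print(" <- ".join((name,) + trace[: first_diff + 1]))
# elementwise_kernel <- LlamaMLP
# elementwise_kernel <- LlamaAttention
```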
-''' - +""" -def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: +def group_trace_by_operations(trace_df: "pd.DataFrame") -> "pd.DataFrame": def is_rms_norm(op_name: str): if "rms_norm_kernel" in op_name: return True def is_attention_block(op_name: str): - if "flash_fwd" in op_name or \ - "reshape_and_cache_flash_kernel" in op_name: + if "flash_fwd" in op_name or "reshape_and_cache_flash_kernel" in op_name: return True def is_quant(op_name: str): - if "scaled_fp8_quant" in op_name or \ - "scaled_int8_quant" in op_name: + if "scaled_fp8_quant" in op_name or "scaled_int8_quant" in op_name: return True # LoRA ops @@ -168,24 +168,27 @@ def is_bgmv_expand(op_name: str): return "bgmv_expand" in op_name def is_cutlass_gemm_op(op_name: str): - return "void cutlass::Kernel" in op_name or \ - "void cutlass::device_kernel" in op_name + return ( + "void cutlass::Kernel" in op_name + or "void cutlass::device_kernel" in op_name + ) def is_gemm_op(op_name: str): if is_quant(op_name): return False - return is_cutlass_gemm_op(op_name) or \ - "xmma_gemm" in op_name or \ - "gemv2T_kernel" in op_name or \ - "splitKreduce" in op_name or \ - "s16816gemm" in op_name + return ( + is_cutlass_gemm_op(op_name) + or "xmma_gemm" in op_name + or "gemv2T_kernel" in op_name + or "splitKreduce" in op_name + or "s16816gemm" in op_name + ) def is_elementwise_op(op_name: str): return "elementwise_kernel" in op_name def is_mem_op(op_name: str): - return "memcpy" in op_name.lower() or \ - "memset" in op_name.lower() + return "memcpy" in op_name.lower() or "memset" in op_name.lower() def is_vocab_embedding_op(op_name: str): return "vocabparallelembed" in op_name.lower() @@ -195,17 +198,15 @@ def is_nccl_op(op_name: str): return "nccl" in op_name.lower() def is_nccl_all_reduce(op_name: str): - return is_nccl_op(op_name) and \ - ("all_reduce" in op_name.lower() or \ - "allreduce" in op_name.lower()) + return is_nccl_op(op_name) and ( + "all_reduce" in op_name.lower() or "allreduce" in op_name.lower() + ) def is_nccl_gather(op_name: str): - return is_nccl_op(op_name) and \ - "gather" in op_name.lower() + return is_nccl_op(op_name) and "gather" in op_name.lower() def is_nccl_broadcast(op_name: str): - return is_nccl_op(op_name) and \ - "broadcast" in op_name.lower() + return is_nccl_op(op_name) and "broadcast" in op_name.lower() # Reduce ops types def is_cross_device_reduce_1stage(op_name: str): @@ -269,114 +270,122 @@ def is_reduce_kernel(op_name: str): ops = list(filter(lambda x: x not in nccl_other_ops, ops)) cross_device_reduce_1stage_ops = list( - filter(lambda x: is_cross_device_reduce_1stage(x), ops)) + filter(lambda x: is_cross_device_reduce_1stage(x), ops) + ) ops = list(filter(lambda x: x not in cross_device_reduce_1stage_ops, ops)) cross_device_reduce_2stage_ops = list( - filter(lambda x: is_cross_device_reduce_2stage(x), ops)) + filter(lambda x: is_cross_device_reduce_2stage(x), ops) + ) ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops)) - custom_ar_all_reduce_ops = list( - filter(lambda x: is_custom_ar_all_reduce(x), ops)) + custom_ar_all_reduce_ops = list(filter(lambda x: is_custom_ar_all_reduce(x), ops)) ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops)) reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops)) ops = list(filter(lambda x: x not in reduce_kernel_ops, ops)) if len(attention_ops): - trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1) + trace_df["attention"] = trace_df[attention_ops].agg("sum", axis=1) if 
len(quant_ops): - trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1) + trace_df["quant_ops"] = trace_df[quant_ops].agg("sum", axis=1) if len(sgmv_shrink_ops): - trace_df['sgmv_shrink_ops'] = trace_df[sgmv_shrink_ops].agg("sum", - axis=1) + trace_df["sgmv_shrink_ops"] = trace_df[sgmv_shrink_ops].agg("sum", axis=1) if len(sgmv_expand_ops): - trace_df['sgmv_expand_ops'] = trace_df[sgmv_expand_ops].agg("sum", - axis=1) + trace_df["sgmv_expand_ops"] = trace_df[sgmv_expand_ops].agg("sum", axis=1) if len(bgmv_shrink_ops): - trace_df['bgmv_shrink_ops'] = trace_df[bgmv_shrink_ops].agg("sum", - axis=1) + trace_df["bgmv_shrink_ops"] = trace_df[bgmv_shrink_ops].agg("sum", axis=1) if len(bgmv_expand_ops): - trace_df['bgmv_expand_ops'] = trace_df[bgmv_expand_ops].agg("sum", - axis=1) + trace_df["bgmv_expand_ops"] = trace_df[bgmv_expand_ops].agg("sum", axis=1) if len(cutlass_gemm_ops): - trace_df['cutlass_gemm_ops'] = trace_df[cutlass_gemm_ops].agg("sum", - axis=1) + trace_df["cutlass_gemm_ops"] = trace_df[cutlass_gemm_ops].agg("sum", axis=1) if len(gemm_ops): - trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1) + trace_df["gemm_ops"] = trace_df[gemm_ops].agg("sum", axis=1) if len(rms_norm_ops): - trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1) + trace_df["rms_norm_ops"] = trace_df[rms_norm_ops].agg("sum", axis=1) if len(vocab_embed_ops): - trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum", - axis=1) + trace_df["vocab_embed_ops"] = trace_df[vocab_embed_ops].agg("sum", axis=1) if len(mem_ops): - trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1) + trace_df["mem_ops"] = trace_df[mem_ops].agg("sum", axis=1) if len(elementwise_ops): - trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum", - axis=1) + trace_df["elementwise_ops"] = trace_df[elementwise_ops].agg("sum", axis=1) if len(nccl_all_reduce_ops): - trace_df['nccl_all_reduce_ops'] = trace_df[nccl_all_reduce_ops].agg( - "sum", axis=1) + trace_df["nccl_all_reduce_ops"] = trace_df[nccl_all_reduce_ops].agg( + "sum", axis=1 + ) if len(nccl_gather_ops): - trace_df['nccl_gather_ops'] = trace_df[nccl_gather_ops].agg("sum", - axis=1) + trace_df["nccl_gather_ops"] = trace_df[nccl_gather_ops].agg("sum", axis=1) if len(nccl_broadcast_ops): - trace_df['nccl_broadcast_ops'] = trace_df[nccl_broadcast_ops].agg( - "sum", axis=1) + trace_df["nccl_broadcast_ops"] = trace_df[nccl_broadcast_ops].agg("sum", axis=1) if len(nccl_other_ops): - trace_df['nccl_other_ops'] = trace_df[nccl_other_ops].agg("sum", - axis=1) + trace_df["nccl_other_ops"] = trace_df[nccl_other_ops].agg("sum", axis=1) if len(cross_device_reduce_1stage_ops): - trace_df['cross_device_reduce_1stage_ops'] = trace_df[ - cross_device_reduce_1stage_ops].agg("sum", axis=1) + trace_df["cross_device_reduce_1stage_ops"] = trace_df[ + cross_device_reduce_1stage_ops + ].agg("sum", axis=1) if len(cross_device_reduce_2stage_ops): - trace_df['cross_device_reduce_2stage_ops'] = trace_df[ - cross_device_reduce_2stage_ops].agg("sum", axis=1) + trace_df["cross_device_reduce_2stage_ops"] = trace_df[ + cross_device_reduce_2stage_ops + ].agg("sum", axis=1) if len(custom_ar_all_reduce_ops): - trace_df['custom_ar_all_reduce_ops'] = trace_df[ - custom_ar_all_reduce_ops].agg("sum", axis=1) + trace_df["custom_ar_all_reduce_ops"] = trace_df[custom_ar_all_reduce_ops].agg( + "sum", axis=1 + ) if len(reduce_kernel_ops): - trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum", - axis=1) - - trace_df.drop(attention_ops + quant_ops + 
sgmv_shrink_ops + - sgmv_expand_ops + bgmv_shrink_ops + bgmv_expand_ops + - cutlass_gemm_ops + gemm_ops + rms_norm_ops + - vocab_embed_ops + mem_ops + elementwise_ops + - nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops + - nccl_other_ops + cross_device_reduce_1stage_ops + - cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops + - reduce_kernel_ops, - axis=1, - inplace=True) + trace_df["reduce_kernel_ops"] = trace_df[reduce_kernel_ops].agg("sum", axis=1) + + trace_df.drop( + attention_ops + + quant_ops + + sgmv_shrink_ops + + sgmv_expand_ops + + bgmv_shrink_ops + + bgmv_expand_ops + + cutlass_gemm_ops + + gemm_ops + + rms_norm_ops + + vocab_embed_ops + + mem_ops + + elementwise_ops + + nccl_all_reduce_ops + + nccl_gather_ops + + nccl_broadcast_ops + + nccl_other_ops + + cross_device_reduce_1stage_ops + + cross_device_reduce_2stage_ops + + custom_ar_all_reduce_ops + + reduce_kernel_ops, + axis=1, + inplace=True, + ) return trace_df ## Data plotting utils #### -def plot_trace_df(traces_df: pd.DataFrame, - plot_metric: str, - plot_title: str, - output: Optional[Path] = None): - - def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str: +def plot_trace_df( + traces_df: "pd.DataFrame", + plot_metric: str, + plot_title: str, + output: Path | None = None, +): + def get_phase_description(traces_df: "pd.DataFrame", phase: str) -> str: phase_df = traces_df.query(f'phase == "{phase}"') - descs = phase_df['phase_desc'].to_list() + descs = phase_df["phase_desc"].to_list() assert all([desc == descs[0] for desc in descs]) return descs[0] - phases = traces_df['phase'].unique() + phases = traces_df["phase"].unique() phase_descs = [get_phase_description(traces_df, p) for p in phases] - traces_df = traces_df.pivot_table(index="phase", - columns="name", - values=plot_metric, - aggfunc="sum") + traces_df = traces_df.pivot_table( + index="phase", columns="name", values=plot_metric, aggfunc="sum" + ) traces_df = group_trace_by_operations(traces_df) @@ -396,20 +405,19 @@ def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str: # Write the values as text on the bars for bar in ax.patches: if bar.get_height() != 0: - ax.text(bar.get_x() + bar.get_width() / 2, - bar.get_height() / 2 + bar.get_y(), - f"{round(bar.get_height(), 2)}", - ha='center', - color='w', - weight='bold', - size=5) + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() / 2 + bar.get_y(), + f"{round(bar.get_height(), 2)}", + ha="center", + color="w", + weight="bold", + size=5, + ) # Setup legend handles, labels = plt.gca().get_legend_handles_labels() - legend = fig.legend(handles, - labels, - loc='center left', - bbox_to_anchor=(1, 1)) + legend = fig.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 1)) shorten_plot_legend_strings(legend, 50) # Setup labels and title @@ -417,21 +425,20 @@ def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str: ax.set_ylabel(plot_metric) plt.suptitle(plot_title) - plt.savefig(output, bbox_inches='tight') + plt.savefig(output, bbox_inches="tight") print("Created: ", output) def main( - json_trace: Path, - output_directory: Path, - depth: int, # Fetch/Plot operations at this depth of the Json tree - plot_metric: str, - make_names_unique: bool, - top_k: int, - json_nodes_to_fold: list[str]): - - def prepare_data(profile_json: dict, step_keys: list[str]) -> pd.DataFrame: - + json_trace: Path, + output_directory: Path, + depth: int, # Fetch/Plot operations at this depth of the Json tree + plot_metric: str, + make_names_unique: bool, + 
top_k: int, + json_nodes_to_fold: list[str], +): + def prepare_data(profile_json: dict, step_keys: list[str]) -> "pd.DataFrame": def get_entries_and_traces(key: str): entries_and_traces: list[tuple[Any, Any]] = [] for root in profile_json[key]["summary_stats"]: @@ -441,16 +448,14 @@ def get_entries_and_traces(key: str): get_entries_at_depth(depth, entries_and_traces, root) return entries_and_traces - def keep_only_top_entries(df: pd.DataFrame, - metric: str, - top_k: int = 9) -> pd.DataFrame: - df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, - ["name"]] = "others" + def keep_only_top_entries( + df: "pd.DataFrame", metric: str, top_k: int = 9 + ) -> "pd.DataFrame": + df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others" return df def get_phase_description(key: str) -> str: - num_running_seqs = profile_json[key]['metadata'][ - 'num_running_seqs'] + num_running_seqs = profile_json[key]["metadata"]["num_running_seqs"] if num_running_seqs is not None: return f"{key}-seqs-{num_running_seqs}" else: @@ -466,20 +471,24 @@ def get_phase_description(key: str) -> str: # To pandas dataframe trace_dfs = list( - map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), - traces)) + map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), traces) + ) # Respect top_k if top_k: trace_dfs = list( map( lambda trace_df: keep_only_top_entries( - trace_df, "cuda_time_us", top_k), trace_dfs)) + trace_df, "cuda_time_us", top_k + ), + trace_dfs, + ) + ) # Fill in information about the step-keys for trace_df, step_key in zip(trace_dfs, step_keys): - trace_df['phase'] = step_key - trace_df['phase_desc'] = get_phase_description(step_key) + trace_df["phase"] = step_key + trace_df["phase_desc"] = get_phase_description(step_key) # Combine all data frames so they can be put in a single plot traces_df = pd.concat(trace_dfs) @@ -492,17 +501,23 @@ def get_phase_description(key: str) -> str: def make_plot_title_suffix(profile_json: dict) -> str: context = profile_json["context"] - sparsity = context.get('sparsity', None) - run_type = \ - f'Run {context["num_steps"]} steps' if context['num_steps'] else \ - (f'Complete {context["complete_num_requests_per_step"]} per ' - f'step; Run till completion') - return (f"{context['engine_args']['model']}\n" - f"Batch={context['batch_size']}, " - f"PromptLen={context['prompt_len']}, " - f"NumGpus={context['engine_args']['tensor_parallel_size']}" - f"{', Sparsity ' + sparsity if sparsity else ''}\n" - f"Run Type: {run_type}") + sparsity = context.get("sparsity", None) + run_type = ( + f"Run {context['num_steps']} steps" + if context["num_steps"] + else ( + f"Complete {context['complete_num_requests_per_step']} per " + f"step; Run till completion" + ) + ) + return ( + f"{context['engine_args']['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"NumGpus={context['engine_args']['tensor_parallel_size']}" + f"{', Sparsity ' + sparsity if sparsity else ''}\n" + f"Run Type: {run_type}" + ) profile_json = None with open(json_trace) as f: @@ -511,14 +526,14 @@ def make_plot_title_suffix(profile_json: dict) -> str: # Get all `llm.generate.step()` profile step_traces = list(profile_json.keys()) - assert (step_traces[0] == 'context') + assert step_traces[0] == "context" step_traces = step_traces[1:] # have only prefill and decodes prefills = list(filter(lambda x: "prefill" in x, step_traces)) all_decodes = list(filter(lambda x: "decode" in x, step_traces)) assert len(prefills) + len(all_decodes) == 
len(step_traces) assert len(prefills) == 1 - decodes = all_decodes[::args.step_plot_interval] + decodes = all_decodes[:: args.step_plot_interval] if decodes[-1] != all_decodes[-1]: # Always have the last decode decodes.append(all_decodes[-1]) @@ -528,48 +543,63 @@ def make_plot_title_suffix(profile_json: dict) -> str: plot_title_suffix = make_plot_title_suffix(profile_json) - plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix, - output_directory / Path("prefill.png")) - plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix, - output_directory / Path("decode_steps.png")) + plot_trace_df( + prefill_traces, + plot_metric, + "prefill " + plot_title_suffix, + output_directory / Path("prefill.png"), + ) + plot_trace_df( + decode_traces, + plot_metric, + "decodes " + plot_title_suffix, + output_directory / Path("decode_steps.png"), + ) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--json-trace", - type=str, - required=True, - help="json trace file output by \ - examples/offline_inference/profiling.py") - parser.add_argument("--output-directory", - type=str, - required=False, - help="Directory to output plots") - parser.add_argument("--level", - type=str, - default="module", - choices=["module", "kernel"]) - parser.add_argument("--top-k", - type=int, - default=12, - help="Only graph the top `top_k` entries by time.") - parser.add_argument("--fold-json-node", - nargs='+', - default=['Sampler', 'LogitsProcessor'], - help='Do not plot the children of these nodes. Let, \ + parser.add_argument( + "--json-trace", + type=str, + required=True, + help="json trace file output by \ + examples/offline_inference/profiling.py", + ) + parser.add_argument( + "--output-directory", type=str, required=False, help="Directory to output plots" + ) + parser.add_argument( + "--level", type=str, default="module", choices=["module", "kernel"] + ) + parser.add_argument( + "--top-k", + type=int, + default=12, + help="Only graph the top `top_k` entries by time.", + ) + parser.add_argument( + "--fold-json-node", + nargs="+", + default=["Sampler", "LogitsProcessor"], + help="Do not plot the children of these nodes. Let, \ the node represent the aggregate of all its \ - children') - parser.add_argument("--plot-metric", - type=str, - default="cuda_time_ms", - help='Metric to plot. some options are cuda_time_ms, \ - pct_cuda_time') + children", + ) + parser.add_argument( + "--plot-metric", + type=str, + default="cuda_time_ms", + help="Metric to plot. 
some options are cuda_time_ms, \ + pct_cuda_time", + ) parser.add_argument( "--step-plot-interval", type=int, default=4, - help="For every `step_plot_interval` steps, plot 1 step") + help="For every `step_plot_interval` steps, plot 1 step", + ) args = parser.parse_args() @@ -583,11 +613,19 @@ def make_plot_title_suffix(profile_json: dict) -> str: else: raise Exception(f"Unexpected level value ({args.level})") - output_directory = args.output_directory if args.output_directory else Path( - args.json_trace).parent + output_directory = ( + args.output_directory if args.output_directory else Path(args.json_trace).parent + ) if not os.path.exists(output_directory): os.makedirs(output_directory) - main(Path(args.json_trace), output_directory, depth, args.plot_metric, - make_names_unique, args.top_k, args.fold_json_node) + main( + Path(args.json_trace), + output_directory, + depth, + args.plot_metric, + make_names_unique, + args.top_k, + args.fold_json_node, + ) diff --git a/tools/report_build_time_ninja.py b/tools/report_build_time_ninja.py index 7386cdd9f..fe3f352fe 100644 --- a/tools/report_build_time_ninja.py +++ b/tools/report_build_time_ninja.py @@ -83,9 +83,9 @@ def WeightedDuration(self): """ # Allow for modest floating-point errors epsilon = 0.000002 - if (self.weighted_duration > self.Duration() + epsilon): - print('{} > {}?'.format(self.weighted_duration, self.Duration())) - assert (self.weighted_duration <= self.Duration() + epsilon) + if self.weighted_duration > self.Duration() + epsilon: + print("{} > {}?".format(self.weighted_duration, self.Duration())) + assert self.weighted_duration <= self.Duration() + epsilon return self.weighted_duration def DescribeTargets(self): @@ -93,10 +93,10 @@ def DescribeTargets(self): # Some build steps generate dozens of outputs - handle them sanely. # The max_length was chosen so that it can fit most of the long # single-target names, while minimizing word wrapping. - result = ', '.join(self.targets) + result = ", ".join(self.targets) max_length = 65 if len(result) > max_length: - result = result[:max_length] + '...' + result = result[:max_length] + "..." return result @@ -106,12 +106,13 @@ def ReadTargets(log, show_all): The result is a list of Target objects.""" header = log.readline() - assert header == '# ninja log v5\n', \ - 'unrecognized ninja log version {!r}'.format(header) + assert header == "# ninja log v5\n", "unrecognized ninja log version {!r}".format( + header + ) targets_dict = {} last_end_seen = 0.0 for line in log: - parts = line.strip().split('\t') + parts = line.strip().split("\t") if len(parts) != 5: # If ninja.exe is rudely halted then the .ninja_log file may be # corrupt. Silently continue. @@ -150,17 +151,17 @@ def ReadTargets(log, show_all): def GetExtension(target, extra_patterns): """Return the file extension that best represents a target. - For targets that generate multiple outputs it is important to return a - consistent 'canonical' extension. Ultimately the goal is to group build steps - by type.""" + For targets that generate multiple outputs it is important to return a + consistent 'canonical' extension. Ultimately the goal is to group build steps + by type.""" for output in target.targets: if extra_patterns: - for fn_pattern in extra_patterns.split(';'): - if fnmatch.fnmatch(output, '*' + fn_pattern + '*'): + for fn_pattern in extra_patterns.split(";"): + if fnmatch.fnmatch(output, "*" + fn_pattern + "*"): return fn_pattern # Not a true extension, but a good grouping. 
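The --step-types grouping described above is plain fnmatch over each build output; a tiny standalone illustration (the output path is made up):

```python
import fnmatch

extra_patterns = "cu.o;cpp.o"  # semicolon-separated, as passed to --step-types
output = "CMakeFiles/_C.dir/csrc/cache_kernels.cu.o"

for fn_pattern in extra_patterns.split(";"):
    if fnmatch.fnmatch(output, "*" + fn_pattern + "*"):
        print("grouped as:", fn_pattern)  # grouped as: cu.o
        break
```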
- if output.endswith('type_mappings'): - extension = 'type_mappings' + if output.endswith("type_mappings"): + extension = "type_mappings" break # Capture two extensions if present. For example: file.javac.jar should @@ -170,26 +171,26 @@ def GetExtension(target, extra_patterns): extension = ext2 + ext1 # Preserve the order in the file name. if len(extension) == 0: - extension = '(no extension found)' + extension = "(no extension found)" - if ext1 in ['.pdb', '.dll', '.exe']: - extension = 'PEFile (linking)' + if ext1 in [".pdb", ".dll", ".exe"]: + extension = "PEFile (linking)" # Make sure that .dll and .exe are grouped together and that the # .dll.lib files don't cause these to be listed as libraries break - if ext1 in ['.so', '.TOC']: - extension = '.so (linking)' + if ext1 in [".so", ".TOC"]: + extension = ".so (linking)" # Attempt to identify linking, avoid identifying as '.TOC' break # Make sure .obj files don't get categorized as mojo files - if ext1 in ['.obj', '.o']: + if ext1 in [".obj", ".o"]: break # Jars are the canonical output of java targets. - if ext1 == '.jar': + if ext1 == ".jar": break # Normalize all mojo related outputs to 'mojo'. - if output.count('.mojom') > 0: - extension = 'mojo' + if output.count(".mojom") > 0: + extension = "mojo" break return extension @@ -214,8 +215,8 @@ def SummarizeEntries(entries, extra_step_types): if target.end > latest: latest = target.end total_cpu_time += target.Duration() - task_start_stop_times.append((target.start, 'start', target)) - task_start_stop_times.append((target.end, 'stop', target)) + task_start_stop_times.append((target.start, "start", target)) + task_start_stop_times.append((target.end, "stop", target)) length = latest - earliest weighted_total = 0.0 @@ -241,10 +242,10 @@ def SummarizeEntries(entries, extra_step_types): if num_running > 0: # Update the total weighted time up to this moment. last_weighted_time += (time - last_time) / float(num_running) - if action_name == 'start': + if action_name == "start": # Record the total weighted task time when this task starts. running_tasks[target] = last_weighted_time - if action_name == 'stop': + if action_name == "stop": # Record the change in the total weighted task time while this task # ran. weighted_duration = last_weighted_time - running_tasks[target] @@ -252,13 +253,16 @@ def SummarizeEntries(entries, extra_step_types): weighted_total += weighted_duration del running_tasks[target] last_time = time - assert (len(running_tasks) == 0) + assert len(running_tasks) == 0 # Warn if the sum of weighted times is off by more than half a second. if abs(length - weighted_total) > 500: - print('Warning: Possible corrupt ninja log, results may be ' - 'untrustworthy. Length = {:.3f}, weighted total = {:.3f}'.format( - length, weighted_total)) + print( + "Warning: Possible corrupt ninja log, results may be " + "untrustworthy. Length = {:.3f}, weighted total = {:.3f}".format( + length, weighted_total + ) + ) entries_by_ext = defaultdict(list) for target in entries: @@ -266,32 +270,38 @@ def SummarizeEntries(entries, extra_step_types): entries_by_ext[extension].append(target) for key, values in entries_by_ext.items(): - print(' Longest build steps for {}:'.format(key)) + print(" Longest build steps for {}:".format(key)) values.sort(key=lambda x: x.WeightedDuration()) for target in values[-long_count:]: print( - ' {:8.1f} weighted s to build {} ({:.1f} s elapsed time)'. 
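The event sweep above computes a "weighted duration": each unit of wall-clock time is split evenly across the steps running during it, so a step that built alone weighs more than one that shared the machine. A brute-force standalone equivalent over two made-up tasks:

```python
tasks = {"a": (0, 100), "b": (50, 150)}  # name -> (start_ms, end_ms)

def weighted(name: str) -> float:
    start, end = tasks[name]
    total = 0.0
    for t in range(start, end):
        running = sum(1 for s, e in tasks.values() if s <= t < e)
        total += 1.0 / running
    return total

print(weighted("a"), weighted("b"))  # 75.0 75.0 -- the 50 ms overlap is shared
```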
- format(target.WeightedDuration(), target.DescribeTargets(), - target.Duration())) - - print(' {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x ' - 'parallelism)'.format(length, total_cpu_time, - total_cpu_time * 1.0 / length)) - print(' {} build steps completed, average of {:1.2f}/s'.format( - len(entries), - len(entries) / (length))) + " {:8.1f} weighted s to build {} ({:.1f} s elapsed time)".format( + target.WeightedDuration(), + target.DescribeTargets(), + target.Duration(), + ) + ) + + print( + " {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x " + "parallelism)".format(length, total_cpu_time, total_cpu_time * 1.0 / length) + ) + print( + " {} build steps completed, average of {:1.2f}/s".format( + len(entries), len(entries) / (length) + ) + ) def main(): - log_file = '.ninja_log' + log_file = ".ninja_log" parser = argparse.ArgumentParser() - parser.add_argument('-C', dest='build_directory', help='Build directory.') + parser.add_argument("-C", dest="build_directory", help="Build directory.") parser.add_argument( - '-s', - '--step-types', - help='semicolon separated fnmatch patterns for build-step grouping') - parser.add_argument('--log-file', - help="specific ninja log file to analyze.") + "-s", + "--step-types", + help="semicolon separated fnmatch patterns for build-step grouping", + ) + parser.add_argument("--log-file", help="specific ninja log file to analyze.") args, _extra_args = parser.parse_known_args() if args.build_directory: log_file = os.path.join(args.build_directory, log_file) @@ -300,17 +310,16 @@ def main(): if args.step_types: # Make room for the extra build types. global long_ext_count - long_ext_count += len(args.step_types.split(';')) + long_ext_count += len(args.step_types.split(";")) try: with open(log_file) as log: entries = ReadTargets(log, False) SummarizeEntries(entries, args.step_types) except OSError: - print('Log file {!r} not found, no build summary created.'.format( - log_file)) + print("Log file {!r} not found, no build summary created.".format(log_file)) return errno.ENOENT -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/tools/validate_config.py b/tools/validate_config.py deleted file mode 100644 index 8b1e955c6..000000000 --- a/tools/validate_config.py +++ /dev/null @@ -1,158 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Ensures all fields in a config dataclass have default values -and that each field has a docstring. -""" - -import ast -import inspect -import sys - - -def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]: - """ - Get any docstrings placed after attribute assignments in a class body. - - Adapted from https://davidism.com/attribute-docstrings/ - https://davidism.com/mit-license/ - """ - - def pairwise(iterable): - """ - Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise - - Can be removed when Python 3.9 support is dropped. - """ - iterator = iter(iterable) - a = next(iterator, None) - - for b in iterator: - yield a, b - a = b - - out = {} - - # Consider each pair of nodes. - for a, b in pairwise(cls_node.body): - # Must be an assignment then a constant string. 
- if (not isinstance(a, (ast.Assign, ast.AnnAssign)) - or not isinstance(b, ast.Expr) - or not isinstance(b.value, ast.Constant) - or not isinstance(b.value.value, str)): - continue - - doc = inspect.cleandoc(b.value.value) - - # An assignment can have multiple targets (a = b = v), but an - # annotated assignment only has one target. - targets = a.targets if isinstance(a, ast.Assign) else [a.target] - - for target in targets: - # Must be assigning to a plain name. - if not isinstance(target, ast.Name): - continue - - out[target.id] = doc - - return out - - -class ConfigValidator(ast.NodeVisitor): - - def __init__(self): - ... - - def visit_ClassDef(self, node): - # Validate class with both @config and @dataclass decorators - decorators = [ - id for d in node.decorator_list if (isinstance(d, ast.Name) and ( - (id := d.id) == 'config' or id == 'dataclass')) or - (isinstance(d, ast.Call) and (isinstance(d.func, ast.Name) and - (id := d.func.id) == 'dataclass')) - ] - - if set(decorators) == {'config', 'dataclass'}: - validate_class(node) - elif set(decorators) == {'config'}: - fail( - f"Class {node.name} with config decorator must be a dataclass.", - node) - - self.generic_visit(node) - - -def validate_class(class_node: ast.ClassDef): - attr_docs = get_attr_docs(class_node) - - for stmt in class_node.body: - # A field is defined as a class variable that has a type annotation. - if isinstance(stmt, ast.AnnAssign): - # Skip ClassVar - # see https://docs.python.org/3/library/dataclasses.html#class-variables - if isinstance(stmt.annotation, ast.Subscript) and isinstance( - stmt.annotation.value, - ast.Name) and stmt.annotation.value.id == "ClassVar": - continue - - if isinstance(stmt.target, ast.Name): - field_name = stmt.target.id - if stmt.value is None: - fail( - f"Field '{field_name}' in {class_node.name} must have " - "a default value.", stmt) - - if field_name not in attr_docs: - fail( - f"Field '{field_name}' in {class_node.name} must have " - "a docstring.", stmt) - - if isinstance(stmt.annotation, ast.Subscript) and \ - isinstance(stmt.annotation.value, ast.Name) \ - and stmt.annotation.value.id == "Union" and \ - isinstance(stmt.annotation.slice, ast.Tuple): - args = stmt.annotation.slice.elts - literal_args = [ - arg for arg in args - if isinstance(arg, ast.Subscript) and isinstance( - arg.value, ast.Name) and arg.value.id == "Literal" - ] - if len(literal_args) > 1: - fail( - f"Field '{field_name}' in {class_node.name} must " - "use a single " - "Literal type. 
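For reference, this is the shape the removed validator used to enforce: every field of a config dataclass carries a default value and an attribute docstring placed directly after the assignment. A hypothetical example (plain @dataclass, shown without vLLM's @config decorator):

```python
from dataclasses import dataclass

@dataclass
class ExampleCacheConfig:
    block_size: int = 16
    """Token block size used when allocating the paged KV cache."""

    swap_space_gb: float = 4.0
    """CPU swap space per GPU, in GiB."""
```

Dropping either the default or the docstring on a field is what used to trigger the "must have a default value" / "must have a docstring" failures above.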
Please use 'Literal[Literal1, " - "Literal2]' instead of 'Union[Literal1, Literal2]'" - ".", stmt) - - -def validate_ast(tree: ast.stmt): - ConfigValidator().visit(tree) - - -def validate_file(file_path: str): - try: - print(f"validating {file_path} config dataclasses ", end="") - with open(file_path, encoding="utf-8") as f: - source = f.read() - - tree = ast.parse(source, filename=file_path) - validate_ast(tree) - except ValueError as e: - print(e) - SystemExit(2) - else: - print("✅") - - -def fail(message: str, node: ast.stmt): - raise ValueError(f"❌ line({node.lineno}): {message}") - - -def main(): - for filename in sys.argv[1:]: - validate_file(filename) - - -if __name__ == "__main__": - main() diff --git a/use_existing_metax.py b/use_existing_metax.py index 00ac18128..06124613b 100644 --- a/use_existing_metax.py +++ b/use_existing_metax.py @@ -3,7 +3,7 @@ import glob -requires_files = glob.glob('requirements/*.txt') +requires_files = glob.glob("requirements/*.txt") requires_files += ["pyproject.toml"] for file in requires_files: print(f">>> cleaning {file}") @@ -11,11 +11,11 @@ lines = f.readlines() if "+metax" in "".join(lines).lower(): print("removed:") - with open(file, 'w') as f: + with open(file, "w") as f: for line in lines: - if '+metax' not in line.lower(): + if "+metax" not in line.lower(): f.write(line) else: print(line.strip()) print(f"<<< done cleaning {file}") - print() \ No newline at end of file + print() diff --git a/use_existing_torch.py b/use_existing_torch.py index 76480f3e5..fd4caa69e 100644 --- a/use_existing_torch.py +++ b/use_existing_torch.py @@ -3,7 +3,7 @@ import glob -requires_files = glob.glob('requirements/*.txt') +requires_files = glob.glob("requirements/*.txt") requires_files += ["pyproject.toml"] for file in requires_files: print(f">>> cleaning {file}") @@ -11,11 +11,11 @@ lines = f.readlines() if "torch" in "".join(lines).lower(): print("removed:") - with open(file, 'w') as f: + with open(file, "w") as f: for line in lines: - if 'torch' not in line.lower(): + if "torch" not in line.lower(): f.write(line) else: print(line.strip()) print(f"<<< done cleaning {file}") - print() \ No newline at end of file + print() diff --git a/vllm_metax/__init__.py b/vllm_metax/__init__.py index 7280dafa5..f5f4cf8b3 100644 --- a/vllm_metax/__init__.py +++ b/vllm_metax/__init__.py @@ -45,29 +45,23 @@ def post_installation(): # Get the path to the vllm distribution vllm_dist_path = Path( - str(importlib.metadata.distribution("vllm").locate_file("vllm"))) + str(importlib.metadata.distribution("vllm").locate_file("vllm")) + ) plugin_dist_path = Path( - str( - importlib.metadata.distribution("vllm_metax").locate_file( - "vllm_metax"))) + str(importlib.metadata.distribution("vllm_metax").locate_file("vllm_metax")) + ) - assert (os.path.exists(vllm_dist_path)) - assert (os.path.exists(plugin_dist_path)) + assert os.path.exists(vllm_dist_path) + assert os.path.exists(plugin_dist_path) print(f"vLLM Dist Location: [{vllm_dist_path}]") print(f"vLLM_plugin Dist Location: [{plugin_dist_path}]") files_to_copy = { - "_C.abi3.so": - vllm_dist_path, - "_moe_C.abi3.so": - vllm_dist_path, - "cumem_allocator.abi3.so": - vllm_dist_path, # workaround for Qwen3-Next # for get_available_device: set cuda - "patch/vllm_substitution/utils.py": - vllm_dist_path / "model_executor/layers/fla/ops/utils.py", + "patch/vllm_substitution/utils.py": vllm_dist_path + / "model_executor/layers/fla/ops/utils.py", } for src_path, dest_path in files_to_copy.items(): @@ -84,6 +78,7 @@ def post_installation(): def 
collect_env() -> None: from vllm_metax.collect_env import main as collect_env_main + collect_env_main() @@ -105,17 +100,22 @@ def register_ops(): def register_model(): from .models import register_model + register_model() def register_quant_configs(): from vllm_metax.quant_config.awq import MacaAWQConfig # noqa: F401 from vllm_metax.quant_config.awq_marlin import ( # noqa: F401 - MacaAWQMarlinConfig) + MacaAWQMarlinConfig, + ) from vllm_metax.quant_config.gptq import MacaGPTQConfig # noqa: F401 from vllm_metax.quant_config.gptq_marlin import ( # noqa: F401 - MacaGPTQMarlinConfig) + MacaGPTQMarlinConfig, + ) from vllm_metax.quant_config.moe_wna16 import ( # noqa: F401 - MacaMoeWNA16Config) + MacaMoeWNA16Config, + ) from vllm_metax.quant_config.compressed_tensors import ( # noqa: F401 - MacaCompressedTensorsConfig) + MacaCompressedTensorsConfig, + ) diff --git a/vllm_metax/_custom_ops.py b/vllm_metax/_custom_ops.py index e65939f79..2537fbb02 100644 --- a/vllm_metax/_custom_ops.py +++ b/vllm_metax/_custom_ops.py @@ -4,15 +4,22 @@ import vllm.envs as envs -def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, - scales: torch.Tensor, split_k_iters: int, - temp_space: torch.Tensor, dtype_bf16: bool) -> torch.Tensor: +def awq_gemm( + input: torch.Tensor, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + split_k_iters: int, + temp_space: torch.Tensor, + dtype_bf16: bool, +) -> torch.Tensor: if envs.VLLM_USE_TRITON_AWQ: - from vllm.model_executor.layers.quantization.awq_triton import ( - awq_gemm_triton) + from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton + return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters) - return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters, - temp_space, dtype_bf16) + return torch.ops._C.awq_gemm( + input, qweight, scales, qzeros, split_k_iters, temp_space, dtype_bf16 + ) # awq to gptq 4bit conversion @@ -23,34 +30,129 @@ def awq_to_gptq_4bit(qweight: torch.Tensor) -> torch.Tensor: # gptq -def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, - b_g_idx: torch.Tensor, use_exllama: bool, bit: int, - group_size: int, perm_space: torch.Tensor, - temp_space: torch.Tensor, dtype_bf16: bool) -> torch.Tensor: - return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, - b_g_idx, use_exllama, bit, group_size, - perm_space, temp_space, dtype_bf16) - - -def fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, top_k: int, - tileConfig: int) -> None: - torch.ops._moe_C.fused_moe_kernel(A, B, C, topk_weights, topk_ids, - sorted_token_ids, expert_ids, - num_tokens_post_padded, - mul_routed_weight, top_k, tileConfig) - - -def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - quant_block_size: int, - kv_cache_dtype: str) -> None: +def gptq_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, + use_exllama: bool, + bit: int, + group_size: int, + perm_space: torch.Tensor, + temp_space: torch.Tensor, + dtype_bf16: bool, +) -> torch.Tensor: + return torch.ops._C.gptq_gemm( + a, + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + b_g_idx, + use_exllama, + bit, + group_size, + 
perm_space, + temp_space, + dtype_bf16, + ) + + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: + torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) + + +def fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + tileConfig: int, +) -> None: + torch.ops._moe_C.fused_moe_kernel( + A, + B, + C, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + mul_routed_weight, + top_k, + tileConfig, + ) + + +def indexer_k_quant_and_cache( + k: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + quant_block_size: int, + kv_cache_dtype: str, +) -> None: if k.dtype in (torch.bfloat16, torch.float16): torch.ops._C_cache_ops.indexer_k_cache(k, kv_cache, slot_mapping) else: torch.ops._C_cache_ops.indexer_k_quant_and_cache( - k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype) + k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype + ) + + +def cp_gather_indexer_k_quant_cache( + kv_cache: torch.Tensor, + dst_k: torch.Tensor, + dst_scale: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, +) -> None: + if dst_k.dtype in (torch.bfloat16, torch.float16) or dst_scale is None: + torch.ops._C_cache_ops.cp_gather_indexer_k_cache( + kv_cache, dst_k, block_table, cu_seq_lens + ) + else: + torch.ops._C_cache_ops.cp_gather_indexer_k_quant_cache( + kv_cache, dst_k, dst_scale, block_table, cu_seq_lens + ) + + +def top_k_per_row( + logits: torch.Tensor, + row_starts: torch.Tensor, + row_ends: torch.Tensor, + topk_indices: torch.Tensor, + num_rows: int, +) -> None: + torch.ops._C.top_k_per_row( + logits, + row_starts, + row_ends, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + ) + + +def top_k_per_row_decode( + logits: torch.Tensor, + next_n: int, + seq_lens: torch.Tensor, + topk_indices: torch.Tensor, + num_rows: int, +) -> None: + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + ) diff --git a/vllm_metax/attention/ops/flashmla.py b/vllm_metax/attention/ops/flashmla.py index 6394a8df6..640388802 100644 --- a/vllm_metax/attention/ops/flashmla.py +++ b/vllm_metax/attention/ops/flashmla.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py -from typing import Optional, Tuple import torch @@ -14,6 +13,7 @@ if current_platform.is_out_of_tree(): try: import flash_mla # noqa: F401 + _flashmla_AVAILABLE = True except ImportError: _flashmla_AVAILABLE = False @@ -22,7 +22,7 @@ # \------------------------ Metax Modification -------------------------/ -def is_flashmla_supported() -> Tuple[bool, Optional[str]]: +def _is_flashmla_available() -> tuple[bool, str | None]: """ Return: is_supported_flag, unsupported_reason (optional). """ @@ -31,35 +31,57 @@ def is_flashmla_supported() -> Tuple[bool, Optional[str]]: return True, None +def is_flashmla_dense_supported() -> tuple[bool, str | None]: + """ + Return: is_supported_flag, unsupported_reason (optional). 
+ """ + is_availble, maybe_reason = _is_flashmla_available() + if not is_availble: + return False, maybe_reason + return True, None + + +def is_flashmla_sparse_supported() -> tuple[bool, str | None]: + """ + Return: is_supported_flag, unsupported_reason (optional). + """ + is_available, maybe_reason = _is_flashmla_available() + if not is_available: + return False, maybe_reason + return True, None + + def get_mla_metadata( - cache_seqlens: torch.Tensor, - num_q_tokens_per_head_k: int, - num_heads_k: int, - num_heads_q: Optional[int] = None, - is_fp8_kvcache: bool = False, - topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + cache_seqlens: torch.Tensor, + num_q_tokens_per_head_k: int, + num_heads_k: int, + num_heads_q: int | None = None, + is_fp8_kvcache: bool = False, + topk: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: - cache_seqlens: (batch_size), dtype torch.int32. - - num_q_tokens_per_head_k: + - num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k. - num_heads_k: The number of k heads. - - num_heads_q: - The number of q heads. + - num_heads_q: + The number of q heads. This argument is optional when sparse attention is not enabled - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. - - topk: If not None, sparse attention will be enabled, - and only tokens in the `indices` array + - topk: If not None, sparse attention will be enabled, + and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to. Returns: - - tile_scheduler_metadata: + - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. - num_splits: (batch_size + 1), dtype torch.int32. """ # /------------------------ Metax Modification -------------------------\ return flash_mla.flash_mla_interface.get_mla_metadata( - cache_seqlens, num_q_tokens_per_head_k, num_heads_k) + cache_seqlens, num_q_tokens_per_head_k, num_heads_k + ) # \------------------------- Metax Modification -------------------------/ @@ -71,13 +93,13 @@ def flash_mla_with_kvcache( head_dim_v: int, tile_scheduler_metadata: torch.Tensor, num_splits: torch.Tensor, - softmax_scale: Optional[float] = None, + softmax_scale: float | None = None, causal: bool = False, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, + descale_q: torch.Tensor | None = None, + descale_k: torch.Tensor | None = None, is_fp8_kvcache: bool = False, - indices: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: + indices: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: - q: (batch_size, seq_len_q, num_heads_q, head_dim). @@ -85,26 +107,26 @@ def flash_mla_with_kvcache( - block_table: (batch_size, max_num_blocks_per_seq), torch.int32. - cache_seqlens: (batch_size), torch.int32. - head_dim_v: Head dimension of v. - - tile_scheduler_metadata: - (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, + - tile_scheduler_metadata: + (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. - - num_splits: + - num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. - - softmax_scale: float. - The scale of QK^T before applying softmax. + - softmax_scale: float. + The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). - causal: bool. Whether to apply causal attention mask. - - descale_q: (batch_size), + - descale_q: (batch_size), torch.float32. 
Descaling factors for Q, used for fp8 quantization. - - descale_k: (batch_size), + - descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization. - - is_fp8_kvcache: bool. - Whether the k_cache and v_cache are in fp8 format. + - is_fp8_kvcache: bool. + Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md - - indices: (batch_size, seq_len_q, topk), torch.int32. - If not None, sparse attention will be enabled, - and only tokens in the `indices` array will be attended to. - Invalid indices should be set to -1 or numbers >= total_seq_len_kv. + - indices: (batch_size, seq_len_q, topk), torch.int32. + If not None, sparse attention will be enabled, + and only tokens in the `indices` array will be attended to. + Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md. Returns: @@ -112,20 +134,19 @@ def flash_mla_with_kvcache( - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. """ if softmax_scale is None: - softmax_scale = q.shape[-1]**(-0.5) + softmax_scale = q.shape[-1] ** (-0.5) if indices is not None: # NOTE (zyongye): sparse attention is also causal # since it only attend to the tokens before # but here `causal` should not be specified - assert not causal, \ - "causal must be `false` if sparse attention is enabled." - assert (descale_q is None) == ( - descale_k is None - ), "descale_q and descale_k should be both None or both not None" + assert not causal, "causal must be `false` if sparse attention is enabled." + assert (descale_q is None) == (descale_k is None), ( + "descale_q and descale_k should be both None or both not None" + ) + # /------------------------ Metax Modification -------------------------\ if indices is None and q.element_size() == 1: - raise NotImplementedError( - "flash_mla_with_kvcache does not support fp8 input. ") + raise NotImplementedError("flash_mla_with_kvcache does not support fp8 input. 
") else: out, softmax_lse = flash_mla.flash_mla_interface.flash_mla_with_kvcache( q, @@ -145,8 +166,8 @@ def flash_mla_with_kvcache( # Metax: torch_ref def torch_flash_mla_sparse_prefill( - q: torch.Tensor, kv: torch.Tensor, indices: torch.Tensor, - sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q: torch.Tensor, kv: torch.Tensor, indices: torch.Tensor, sm_scale: float +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: import math def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor: @@ -164,12 +185,10 @@ def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor: _, topk = indices.shape kvs = torch.index_select( - kvs, 0, - indices.masked_fill(invalid_indices_mask, - 0).flatten()).view(s_q, topk, - d_qk) # [s_q, topk, d_qk] + kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten() + ).view(s_q, topk, d_qk) # [s_q, topk, d_qk] attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk] - attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float('-inf')) + attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf")) attn_score *= sm_scale * math.log2(math.e) max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q] lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q] @@ -185,21 +204,21 @@ def flash_mla_sparse_prefill( indices: torch.Tensor, sm_scale: float, d_v: int = 512, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Sparse attention prefill kernel Args: - q: [s_q, h_q, d_qk], bfloat16 - kv: [s_kv, h_kv, d_qk], bfloat16 - - indices: [s_q, h_kv, topk], int32. + - indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv - sm_scale: float - d_v: The dimension of value vectors. Can only be 512 Returns: - (output, max_logits, lse) - About the definition of output, + About the definition of output, max_logits and lse, please refer to README.md - output: [s_q, h_q, d_v], bfloat16 - max_logits: [s_q, h_q], float @@ -210,7 +229,8 @@ def flash_mla_sparse_prefill( min_seq_len = -1 if (indices == -1).any() else 2049 results = flash_mla.flash_mla_interface.flash_mla_sparse_fwd( - q, kv, indices, sm_scale, d_v, min_seq_len) + q, kv, indices, sm_scale, d_v, min_seq_len + ) # \------------------------- Metax Modification -------------------------/ return results diff --git a/vllm_metax/attention/ops/merge_attn_states.py b/vllm_metax/attention/ops/merge_attn_states.py index d4aa2eed8..a81ff0948 100644 --- a/vllm_metax/attention/ops/merge_attn_states.py +++ b/vllm_metax/attention/ops/merge_attn_states.py @@ -1,22 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch from vllm.platforms import current_platform -# yapf: disable def merge_attn_states( output: torch.Tensor, prefix_output: torch.Tensor, prefix_lse: torch.Tensor, suffix_output: torch.Tensor, suffix_lse: torch.Tensor, - output_lse: Optional[torch.Tensor] = None, + output_lse: torch.Tensor | None = None, ) -> None: - # NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel # is not support for FP8 dtype, fallback to use Triton kernel. 
def supported_dtypes(o: torch.Tensor) -> bool: @@ -33,14 +30,20 @@ def supported_headdim(o: torch.Tensor) -> bool: return headdim % 8 == 0 # /------------------------ Metax Modification -------------------------\ - if (current_platform.is_out_of_tree() and supported_dtypes(output) - and supported_headdim(output)): + if ( + current_platform.is_out_of_tree() + and supported_dtypes(output) + and supported_headdim(output) + ): # \------------------------ Metax Modification -------------------------/ from vllm._custom_ops import merge_attn_states - return merge_attn_states(output, prefix_output, prefix_lse, - suffix_output, suffix_lse, output_lse) + + return merge_attn_states( + output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse + ) else: - from vllm.attention.ops.triton_merge_attn_states import ( - merge_attn_states) - return merge_attn_states(output, prefix_output, prefix_lse, - suffix_output, suffix_lse, output_lse) + from vllm.attention.ops.triton_merge_attn_states import merge_attn_states + + return merge_attn_states( + output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse + ) diff --git a/vllm_metax/attention/ops/triton_decode_attention.py b/vllm_metax/attention/ops/triton_decode_attention.py index 8f1fd588c..39ba4a573 100644 --- a/vllm_metax/attention/ops/triton_decode_attention.py +++ b/vllm_metax/attention/ops/triton_decode_attention.py @@ -32,6 +32,7 @@ import logging from packaging import version + from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -44,10 +45,11 @@ # Only print the following warnings when triton version < 3.2.0. # The issue won't affect performance or accuracy. -if version.parse(triton.__version__) < version.parse('3.2.0'): +if version.parse(triton.__version__) < version.parse("3.2.0"): logger.warning( "The following error message 'operation scheduled before its operands' " - "can be ignored.") + "can be ignored." 
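# Editor's note (illustrative sketch, not part of the diff): the
# `merge_attn_states` wrapper in the previous file combines partial attention
# outputs computed over a prefix and a suffix of the KV sequence, weighting
# each by its share of the merged softmax denominator via the log-sum-exp
# values. A minimal pure-PyTorch reference of that combination is sketched
# below; the tensor layouts used here ([tokens, heads, head_size] outputs and
# [tokens, heads] LSEs) are assumptions for illustration and may not match the
# kernel's actual layout.
import torch


def merge_attn_states_reference(
    prefix_output: torch.Tensor,  # [tokens, heads, head_size]
    prefix_lse: torch.Tensor,  # [tokens, heads]
    suffix_output: torch.Tensor,  # [tokens, heads, head_size]
    suffix_lse: torch.Tensor,  # [tokens, heads]
) -> torch.Tensor:
    # Merged normalizer: log-sum-exp of the two partial softmax denominators.
    merged_lse = torch.logaddexp(prefix_lse, suffix_lse)
    # Rescale each partial output by exp(partial_lse - merged_lse) and add.
    w_prefix = torch.exp(prefix_lse - merged_lse).unsqueeze(-1)
    w_suffix = torch.exp(suffix_lse - merged_lse).unsqueeze(-1)
    return w_prefix * prefix_output + w_suffix * suffix_output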
+ ) @triton.jit @@ -103,8 +105,7 @@ def _fwd_kernel_stage1( kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, - cur_batch_seq_len) + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) e_max = -float("inf") e_sum = 0.0 @@ -114,14 +115,18 @@ def _fwd_kernel_stage1( for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_page_number = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + - offs_n // PAGE_SIZE, + Req_to_tokens + + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, mask=offs_n < split_kv_end, other=0, ) kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE - offs_buf_k = (kv_loc[:, None] * stride_buf_kbs + - cur_kv_head * stride_buf_kh + offs_d[None, :]) + offs_buf_k = ( + kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[None, :] + ) k = tl.load( K_Buffer + offs_buf_k, mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), @@ -135,8 +140,11 @@ def _fwd_kernel_stage1( qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) - offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + - cur_kv_head * stride_buf_vh + offs_dv[None, :]) + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) v = tl.load( V_Buffer + offs_buf_v, mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), @@ -152,8 +160,12 @@ def _fwd_kernel_stage1( e_sum = e_sum * re_scale + tl.sum(p, 0) e_max = n_e_max - offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + - split_kv_id * stride_mid_os + offs_dv) + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) tl.store( Att_Out + offs_mid_o, @@ -161,8 +173,12 @@ def _fwd_kernel_stage1( mask=(mask_dv), ) - offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + - split_kv_id * stride_mid_os + Lv) + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv + ) tl.store( Att_Out + offs_mid_o_1, @@ -288,25 +304,22 @@ def _fwd_grouped_kernel_stage1( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_req_idx = cur_batch - offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[ - None, :] - q = tl.load(Q + offs_q, - mask=(mask_h[:, None]) & (mask_d[None, :]), - other=0.0) + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) mask_dpe = offs_dpe < Lk - off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + - offs_dpe[None, :]) - qpe = tl.load(Q + off_qpe, - mask=(mask_h[:, None]) & (mask_dpe[None, :]), - other=0.0) + off_qpe = ( + cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] + ) + qpe = tl.load( + Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 + ) kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, - cur_batch_seq_len) + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) @@ -316,14 +329,18 @@ def 
_fwd_grouped_kernel_stage1( for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_page_number = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + - offs_n // PAGE_SIZE, + Req_to_tokens + + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, mask=offs_n < split_kv_end, other=0, ) kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE - offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + - cur_kv_head * stride_buf_kh + offs_d[:, None]) + offs_buf_k = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) k = tl.load( K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), @@ -331,13 +348,14 @@ def _fwd_grouped_kernel_stage1( ) qk = tl.dot(q, k.to(q.dtype)) if BLOCK_DPE > 0: - offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + - cur_kv_head * stride_buf_kh + - offs_dpe[:, None]) + offs_buf_kpe = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) kpe = tl.load( K_Buffer + offs_buf_kpe, - mask=(offs_n[None, :] < split_kv_end) & - (mask_dpe[:, None]), + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), other=0.0, ) qk += tl.dot(qpe, kpe.to(qpe.dtype)) @@ -346,11 +364,15 @@ def _fwd_grouped_kernel_stage1( if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) - qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), - qk, float("-inf")) + qk = tl.where( + mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf") + ) - offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + - cur_kv_head * stride_buf_vh + offs_dv[None, :]) + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) v = tl.load( V_Buffer + offs_buf_v, mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), @@ -366,9 +388,12 @@ def _fwd_grouped_kernel_stage1( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_mid_o = (cur_batch * stride_mid_ob + - cur_head[:, None] * stride_mid_oh + - split_kv_id * stride_mid_os + offs_dv[None, :]) + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv[None, :] + ) tl.store( Att_Out + offs_mid_o, @@ -376,8 +401,12 @@ def _fwd_grouped_kernel_stage1( mask=(mask_h[:, None]) & (mask_dv[None, :]), ) - offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + - split_kv_id * stride_mid_os + Lv) + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv + ) tl.store( Att_Out + offs_mid_o_1, @@ -508,13 +537,12 @@ def _fwd_kernel_stage2( for split_kv_id in range(0, NUM_KV_SPLITS): kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, - cur_batch_seq_len) + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) if split_kv_end > split_kv_start: - tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, - mask=mask_d, - other=0.0) + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) n_e_max = tl.maximum(tlogic, e_max) @@ -557,11 +585,7 @@ def _decode_softmax_reducev_fwd( if is_hip_: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # 
https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py - extra_kargs = { - "waves_per_eu": 4, - "matrix_instr_nonkdim": 16, - "kpack": 2 - } + extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} grid = (batch, head_num) _fwd_kernel_stage2[grid]( @@ -610,8 +634,9 @@ def decode_attention_fwd_normal( page_size, logit_cap, ) - _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len, - num_kv_splits) + _decode_softmax_reducev_fwd( + attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits + ) def decode_attention_fwd_grouped( @@ -640,8 +665,9 @@ def decode_attention_fwd_grouped( page_size, logit_cap, ) - _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len, - num_kv_splits) + _decode_softmax_reducev_fwd( + attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits + ) def decode_attention_fwd( diff --git a/vllm_metax/attention/ops/triton_unified_attention.py b/vllm_metax/attention/ops/triton_unified_attention.py index a53781465..2b7242ba2 100644 --- a/vllm_metax/attention/ops/triton_unified_attention.py +++ b/vllm_metax/attention/ops/triton_unified_attention.py @@ -31,9 +31,14 @@ def apply_softcap(S, x): @triton.jit -def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, - BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr): - left = 0 +def find_seq_idx( + query_start_len_ptr, + target_idx, + num_seqs, + BLOCK_Q: tl.constexpr, + use_q_block_mode: tl.constexpr, +): + left = 0 # Metax Modification right = num_seqs while left < right: mid = (left + right) // 2 @@ -100,19 +105,18 @@ def kernel_unified_attention_2d( q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) - seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, - BLOCK_Q, True) + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) - q_block_start_idx = tl.load(query_start_len_ptr + - seq_idx) // BLOCK_Q + seq_idx + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx q_block_local_idx = q_block_global_idx - q_block_start_idx cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - cur_batch_query_len = cur_batch_in_all_stop_index \ - - cur_batch_in_all_start_index + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: return @@ -123,10 +127,12 @@ def kernel_unified_attention_2d( query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos - query_offset_1 = kv_head_idx * num_queries_per_kv + \ - offs_m % num_queries_per_kv - query_offset = (query_offset_0[:, None] * query_stride_0 + - query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) @@ -161,19 +167,24 @@ def kernel_unified_attention_2d( # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, - mask=query_mask_1, - other=0.0) + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 + ) # query-query attention bias if 
USE_QQ_BIAS: - qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 - ) # shape: [BLOCK_M] + qq_bias_row_ptrs = ( + qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] # compute the length of the longest sequence prefix spanned by any # query token in the current q_block (q_block_local_idx) - max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( - BLOCK_M - 1) // num_queries_per_kv + 1 + max_seq_prefix_len = ( + context_len + + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + + 1 + ) # adjust for potential padding in the last q_block by considering the # actual sequence length @@ -211,23 +222,30 @@ def kernel_unified_attention_2d( seq_offset = j * TILE_SIZE + offs_t tile_mask = seq_offset < max_seq_prefix_len - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + - seq_offset // BLOCK_SIZE).to(tl.int64) + physical_block_idx = tl.load( + block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE + ).to(tl.int64) - v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + - kv_head_idx * stride_v_cache_2 + - offs_d[None, :] * stride_v_cache_3 + - (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) + v_offset = ( + physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) - k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + - kv_head_idx * stride_k_cache_2 + - offs_d[:, None] * stride_k_cache_3 + - (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) + k_offset = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) # K : (HEAD_SIZE, TILE_SIZE) - K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None] & tile_mask[None, :], - other=0.0) + K_load = tl.load( + key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0, + ) if K_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -238,9 +256,11 @@ def kernel_unified_attention_2d( K = K_load # V : (TILE_SIZE, HEAD_SIZE) - V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :] & tile_mask[:, None], - other=0.0) + V_load = tl.load( + value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0, + ) if V_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -260,12 +280,16 @@ def kernel_unified_attention_2d( if USE_SOFTCAP: S = apply_softcap(S, softcap) - S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, - S, float("-inf")) + S = tl.where( + query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") + ) if SLIDING_WINDOW > 0: - S = tl.where((context_len + query_pos[:, None] - seq_offset) - < SLIDING_WINDOW, S, float("-inf")) + S = tl.where( + (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) @@ -315,9 +339,11 @@ def kernel_unified_attention_2d( acc = acc * tl.load(out_scale) acc = tl.clamp(acc, FP8_MIN, FP8_MAX) - output_offset = (query_offset_0[:, None] * output_stride_0 + - query_offset_1[:, None] * output_stride_1 + - offs_d[None, :]) + output_offset = ( + query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :] + ) tl.store( output_ptr + output_offset, @@ -328,68 +354,67 @@ def kernel_unified_attention_2d( 
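# Editor's note (illustrative sketch, not part of the diff): `find_seq_idx`
# above maps a global q-block (or token) index back to the sequence that owns
# it by binary-searching the cumulative query-start array. The pure-Python
# sketch below reproduces that idea; the comparison value is inferred from how
# `q_block_start_idx` is computed in the kernel and should be read as an
# approximation, not the authoritative implementation.
def find_seq_idx_reference(
    cu_seqlens_q: list[int],  # [num_seqs + 1] cumulative query starts
    target_idx: int,  # global q-block index (or token index)
    block_q: int,
    use_q_block_mode: bool = True,
) -> int:
    left, right = 0, len(cu_seqlens_q) - 1  # right == num_seqs
    while left < right:
        mid = (left + right) // 2
        start = cu_seqlens_q[mid]
        mid_val = start // block_q + mid if use_q_block_mode else start
        if mid_val <= target_idx:
            left = mid + 1
        else:
            right = mid
    return left - 1


# Example: two sequences with 5 and 7 query tokens, BLOCK_Q = 4.
assert find_seq_idx_reference([0, 5, 12], 1, 4) == 0  # q-block 1 -> sequence 0
assert find_seq_idx_reference([0, 5, 12], 2, 4) == 1  # q-block 2 -> sequence 1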
@triton.jit def kernel_unified_attention_3d( - segm_output_ptr, - # [num_tokens, num_query_heads, num_segments, head_size] - segm_max_ptr, # [num_tokens, num_query_heads, num_segments] - segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] - value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] - sink_ptr, # [num_query_heads] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - qq_bias_ptr, # [num_query_tokens, num_query_tokens] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - softcap, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - qq_bias_stride_0: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - TILE_SIZE: tl.constexpr, # int, must be power of 2 - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - USE_QQ_BIAS: tl.constexpr, # bool - USE_SOFTCAP: tl.constexpr, # bool - USE_SINKS: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.constexpr, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.constexpr, # int - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, - BLOCK_M: tl.constexpr, # int - NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + segm_output_ptr, + # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + 
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) segm_idx = tl.program_id(2) - seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, - BLOCK_Q, True) + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) - q_block_start_idx = tl.load(query_start_len_ptr + - seq_idx) // BLOCK_Q + seq_idx + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx q_block_local_idx = q_block_global_idx - q_block_start_idx cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - cur_batch_query_len = cur_batch_in_all_stop_index \ - - cur_batch_in_all_start_index + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: return @@ -410,10 +435,12 @@ def kernel_unified_attention_3d( query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos - query_offset_1 = kv_head_idx * num_queries_per_kv + \ - offs_m % num_queries_per_kv - query_offset = (query_offset_0[:, None] * query_stride_0 + - query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) @@ -448,19 +475,24 @@ def kernel_unified_attention_3d( # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, - mask=query_mask_1, - other=0.0) + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 + ) # query-query attention bias if USE_QQ_BIAS: - qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 - ) # shape: [BLOCK_M] + qq_bias_row_ptrs = ( + qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] # compute the length of the longest sequence prefix spanned by any # query token in the current q_block (q_block_local_idx) - max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( - BLOCK_M - 1) // num_queries_per_kv + 1 + max_seq_prefix_len = ( + context_len + + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + + 1 + ) # adjust for potential padding in the last q_block by considering the # actual sequence length @@ -473,29 +505,36 @@ def kernel_unified_attention_3d( # iterate through tiles within current segment for j in range( - segm_idx * tiles_per_segment, - min((segm_idx + 1) * tiles_per_segment, num_tiles), + segm_idx * tiles_per_segment, + min((segm_idx + 1) * tiles_per_segment, num_tiles), ): seq_offset = j * TILE_SIZE + offs_t tile_mask = seq_offset < max_seq_prefix_len - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + - seq_offset // BLOCK_SIZE).to(tl.int64) + physical_block_idx = tl.load( + block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE + ).to(tl.int64) - v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + - kv_head_idx * stride_v_cache_2 + - offs_d[None, :] * stride_v_cache_3 + - (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) + v_offset = ( + physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + 
offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) - k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + - kv_head_idx * stride_k_cache_2 + - offs_d[:, None] * stride_k_cache_3 + - (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) + k_offset = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) # K : (HEAD_SIZE, TILE_SIZE) - K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None] & tile_mask[None, :], - other=0.0) + K_load = tl.load( + key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0, + ) if K_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -506,9 +545,11 @@ def kernel_unified_attention_3d( K = K_load # V : (TILE_SIZE, HEAD_SIZE) - V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :] & tile_mask[:, None], - other=0.0) + V_load = tl.load( + value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0, + ) if V_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -527,12 +568,16 @@ def kernel_unified_attention_3d( if USE_SOFTCAP: S = apply_softcap(S, softcap) - S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, - S, float("-inf")) + S = tl.where( + query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") + ) if SLIDING_WINDOW > 0: - S = tl.where((context_len + query_pos[:, None] - seq_offset) - < SLIDING_WINDOW, S, float("-inf")) + S = tl.where( + (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) @@ -577,29 +622,31 @@ def kernel_unified_attention_3d( acc += tl.dot(P.to(V.dtype), V) segm_output_offset = ( - query_offset_0[:, None].to(tl.int64) * - (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + query_offset_0[:, None].to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + segm_idx * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) tl.store( segm_output_ptr + segm_output_offset, acc, mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], ) - segm_offset = (query_offset_0.to(tl.int64) * - (num_query_heads * NUM_SEGMENTS_PER_SEQ) + - query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx) + segm_offset = ( + query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + + segm_idx + ) tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) - tl.store(segm_expsum_ptr + segm_offset, - L, - mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, L, mask=query_mask_0 & query_mask_1) @triton.jit def reduce_segments( output_ptr, # [num_tokens, num_query_heads, head_size] segm_output_ptr, - #[num_tokens, num_query_heads, max_num_segments, head_size] + # [num_tokens, num_query_heads, max_num_segments, head_size] segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] seq_lens_ptr, # [num_seqs] @@ -622,8 +669,9 @@ def reduce_segments( query_token_idx = tl.program_id(0) query_head_idx = tl.program_id(1) - seq_idx = 
find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs, - BLOCK_Q, False) + seq_idx = find_seq_idx( + query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False + ) # sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) @@ -635,34 +683,32 @@ def reduce_segments( # create masks for subsequent loads act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE) segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( - [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) - dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, - 0).to(tl.int1) + [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32 + ) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1) # load segment maxima - segm_offset = (query_token_idx.to(tl.int64) * - (num_query_heads * NUM_SEGMENTS_PER_SEQ) + - query_head_idx * NUM_SEGMENTS_PER_SEQ + - tl.arange(0, NUM_SEGMENTS_PER_SEQ)) - segm_max = tl.load(segm_max_ptr + segm_offset, - mask=segm_mask, - other=float("-inf")) + segm_offset = ( + query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_head_idx * NUM_SEGMENTS_PER_SEQ + + tl.arange(0, NUM_SEGMENTS_PER_SEQ) + ) + segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf")) overall_max = tl.max(segm_max) # load and rescale segment exp sums - segm_expsum = tl.load(segm_expsum_ptr + segm_offset, - mask=segm_mask, - other=0.0) + segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0) segm_expsum = segm_expsum * tl.exp(segm_max - overall_max) overall_expsum = tl.sum(segm_expsum) # load, rescale, and add segment attention outputs segm_output_offset = ( - query_token_idx.to(tl.int64) * - (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + - tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + query_token_idx.to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) segm_output = tl.load( segm_output_ptr + segm_output_offset, mask=segm_mask[:, None] & dim_mask[None, :], @@ -678,9 +724,11 @@ def reduce_segments( acc = tl.clamp(acc, FP8_MIN, FP8_MAX) # write result - output_offset = (query_token_idx * output_stride_0 + - query_head_idx * output_stride_1 + - tl.arange(0, HEAD_SIZE_PADDED)) + output_offset = ( + query_token_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE_PADDED) + ) tl.store(output_ptr + output_offset, acc, mask=dim_mask) @@ -707,13 +755,11 @@ def unified_attention( # Optional tensor for sinks sinks=None, ): - assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" if sinks is not None: - assert sinks.shape[0] == q.shape[1], \ - "Sinks must be num_query_heads size" + assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size" use_alibi_slopes = alibi_slopes is not None use_qq_bias = qq_bias is not None @@ -725,8 +771,9 @@ def unified_attention( num_queries_per_kv = num_query_heads // num_kv_heads head_size = q.shape[2] - BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2( - num_queries_per_kv) + BLOCK_M = ( + 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) + ) BLOCK_Q = BLOCK_M // 
num_queries_per_kv # Ideally we would launch with kernel with: @@ -748,10 +795,12 @@ def unified_attention( # if batch contains a prefill if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: - kernel_unified_attention_2d[( - total_num_q_blocks, - num_kv_heads, - )]( + kernel_unified_attention_2d[ + ( + total_num_q_blocks, + num_kv_heads, + ) + ]( output_ptr=out, query_ptr=q, key_cache_ptr=k, @@ -825,52 +874,51 @@ def unified_attention( device=q.device, ) - kernel_unified_attention_3d[( - total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, - query_ptr=q, - key_cache_ptr=k, - value_cache_ptr=v, - sink_ptr=sinks, - block_tables_ptr=block_table, - seq_lens_ptr=seqused_k, - alibi_slopes_ptr=alibi_slopes, - qq_bias_ptr=qq_bias, - scale=softmax_scale, - k_scale=k_descale, - v_scale=v_descale, - softcap=softcap, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - block_table_stride=block_table.stride(0), - query_stride_0=q.stride(0), - query_stride_1=q.stride(1), - qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, - BLOCK_SIZE=block_size, - TILE_SIZE=TILE_SIZE_DECODE, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - USE_ALIBI_SLOPES=use_alibi_slopes, - USE_QQ_BIAS=use_qq_bias, - USE_SOFTCAP=(softcap > 0), - USE_SINKS=(sinks is not None), - SLIDING_WINDOW=(1 + window_size[0]), - stride_k_cache_0=k.stride(0), - stride_k_cache_1=k.stride(1), - stride_k_cache_2=k.stride(2), - stride_k_cache_3=k.stride(3), - stride_v_cache_0=v.stride(0), - stride_v_cache_1=v.stride(1), - stride_v_cache_2=v.stride(2), - stride_v_cache_3=v.stride(3), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, - num_seqs=num_seqs, - BLOCK_M=BLOCK_M, - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, - ) + kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) reduce_segments[(q.shape[0], num_query_heads)]( output_ptr=out, segm_output_ptr=segm_output, @@ -879,8 +927,7 @@ def unified_attention( seq_lens_ptr=seqused_k, num_seqs=num_seqs, num_query_heads=num_query_heads, - out_scale_inv=1 / - output_scale if output_scale is not None else 1.0, + out_scale_inv=1 / output_scale if output_scale is not None 
else 1.0, output_stride_0=out.stride(0), output_stride_1=out.stride(1), block_table_stride=block_table.stride(0), diff --git a/vllm_metax/attention/utils/fa_utils.py b/vllm_metax/attention/utils/fa_utils.py index c0fd8ddc8..28942164b 100644 --- a/vllm_metax/attention/utils/fa_utils.py +++ b/vllm_metax/attention/utils/fa_utils.py @@ -1,21 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional -from vllm import _custom_ops as ops from vllm.attention.utils.fa_utils import logger from vllm.platforms import current_platform -get_scheduler_metadata = None if current_platform.is_out_of_tree(): from vllm import _custom_ops as ops + + get_scheduler_metadata = None reshape_and_cache_flash = ops.reshape_and_cache_flash from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache # noqa: F401 + get_scheduler_metadata = None -def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: +def get_flash_attn_version(requires_alibi: bool = False) -> int | None: logger.info_once( "Using Maca version of flash attention, which only supports version 2." ) @@ -24,7 +24,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: def flash_attn_supports_fp8() -> bool: logger.info_once( - "Using Maca version of flash attention, which does not support FP8") + "Using Maca version of flash attention, which does not support FP8" + ) + return False + + +def flash_attn_supports_mla(): return False diff --git a/vllm_metax/collect_env.py b/vllm_metax/collect_env.py index d45ebb0ed..c9e563c50 100644 --- a/vllm_metax/collect_env.py +++ b/vllm_metax/collect_env.py @@ -9,6 +9,7 @@ import os import subprocess import sys + # Unlike the rest of the PyTorch this file must be python2 compliant. 
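# Editor's note (illustrative sketch, not part of the diff): the fa_utils shims
# above report the flash-attention capabilities of the Maca backend. A
# caller-side sketch is shown below; it assumes the vllm_metax package and its
# flash_attn build are importable, and the fallback choice of "auto" for the
# KV-cache dtype is hypothetical.
from vllm_metax.attention.utils.fa_utils import (
    flash_attn_supports_fp8,
    get_flash_attn_version,
)

fa_version = get_flash_attn_version()  # logs that only FA2 is supported on Maca
kv_cache_dtype = "fp8" if flash_attn_supports_fp8() else "auto"  # False today
print(f"flash-attn version: {fa_version}, kv_cache_dtype: {kv_cache_dtype}")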
# This script outputs relevant system environment info # Run it with `python collect_env.py` or `python -m torch.utils.collect_env` @@ -18,48 +19,50 @@ try: import torch + TORCH_AVAILABLE = True except (ImportError, NameError, AttributeError, OSError): TORCH_AVAILABLE = False # System Environment Information SystemEnv = namedtuple( - 'SystemEnv', + "SystemEnv", [ - 'torch_version', - 'is_debug_build', - 'cuda_compiled_version', - 'gcc_version', - 'clang_version', - 'cmake_version', - 'os', - 'libc_version', - 'python_version', - 'python_platform', - 'is_cuda_available', - 'cuda_runtime_version', - 'maca_runtime_version', - 'bios_version', - 'cuda_module_loading', - 'nvidia_driver_version', - 'nvidia_gpu_models', - 'cudnn_version', - 'pip_version', # 'pip' or 'pip3' - 'pip_packages', - 'conda_packages', - 'hip_compiled_version', - 'hip_runtime_version', - 'miopen_runtime_version', - 'caching_allocator_config', - 'is_xnnpack_available', - 'cpu_info', - 'rocm_version', # vllm specific field - 'neuron_sdk_version', # vllm specific field - 'vllm_version', # vllm specific field - 'vllm_build_flags', # vllm specific field - 'gpu_topo', # vllm specific field - 'env_vars', - ]) + "torch_version", + "is_debug_build", + "cuda_compiled_version", + "gcc_version", + "clang_version", + "cmake_version", + "os", + "libc_version", + "python_version", + "python_platform", + "is_cuda_available", + "cuda_runtime_version", + "maca_runtime_version", + "bios_version", + "cuda_module_loading", + "nvidia_driver_version", + "nvidia_gpu_models", + "cudnn_version", + "pip_version", # 'pip' or 'pip3' + "pip_packages", + "conda_packages", + "hip_compiled_version", + "hip_runtime_version", + "miopen_runtime_version", + "caching_allocator_config", + "is_xnnpack_available", + "cpu_info", + "rocm_version", # vllm specific field + "neuron_sdk_version", # vllm specific field + "vllm_version", # vllm specific field + "vllm_build_flags", # vllm specific field + "gpu_topo", # vllm specific field + "env_vars", + ], +) DEFAULT_CONDA_PATTERNS = { "torch", @@ -97,18 +100,17 @@ def run(command): """Return (return-code, stdout, stderr).""" shell = True if type(command) is str else False try: - p = subprocess.Popen(command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=shell) + p = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell + ) raw_output, raw_err = p.communicate() rc = p.returncode - if get_platform() == 'win32': - enc = 'oem' + if get_platform() == "win32": + enc = "oem" else: enc = locale.getpreferredencoding() output = raw_output.decode(enc) - if command == 'mx-smi topo -m': + if command == "mx-smi topo -m": # don't remove the leading whitespace of `mx-smi topo -m` # because they are meaningful output = output.rstrip() @@ -119,7 +121,7 @@ def run(command): except FileNotFoundError: cmd_str = command if isinstance(command, str) else command[0] - return 127, '', f"Command not found: {cmd_str}" + return 127, "", f"Command not found: {cmd_str}" def run_and_read_all(run_lambda, command): @@ -146,49 +148,54 @@ def run_and_return_first_line(run_lambda, command): rc, out, _ = run_lambda(command) if rc != 0: return None - return out.split('\n')[0] + return out.split("\n")[0] def get_conda_packages(run_lambda, patterns=None): if patterns is None: patterns = DEFAULT_CONDA_PATTERNS - conda = os.environ.get('CONDA_EXE', 'conda') - out = run_and_read_all(run_lambda, [conda, 'list']) + conda = os.environ.get("CONDA_EXE", "conda") + out = run_and_read_all(run_lambda, [conda, "list"]) if 
out is None: return out - return "\n".join(line for line in out.splitlines() - if not line.startswith("#") and any(name in line - for name in patterns)) + return "\n".join( + line + for line in out.splitlines() + if not line.startswith("#") and any(name in line for name in patterns) + ) def get_gcc_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)") def get_clang_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'clang --version', - r'clang version (.*)') + return run_and_parse_first_match( + run_lambda, "clang --version", r"clang version (.*)" + ) def get_cmake_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'cmake --version', - r'cmake (.*)') + return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)") def get_nvidia_driver_version(run_lambda): - if get_platform() == 'darwin': - cmd = 'kextstat | grep -i cuda' - return run_and_parse_first_match(run_lambda, cmd, - r'com[.]nvidia[.]CUDA [(](.*?)[)]') + if get_platform() == "darwin": + cmd = "kextstat | grep -i cuda" + return run_and_parse_first_match( + run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]" + ) smi = get_nvidia_smi() - return run_and_parse_first_match(run_lambda, smi, - r'Driver Version: (.*?) ') + return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ") def get_gpu_info(run_lambda): - if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr( - torch.version, 'hip') and torch.version.hip is not None): + if get_platform() == "darwin" or ( + TORCH_AVAILABLE + and hasattr(torch.version, "hip") + and torch.version.hip is not None + ): if TORCH_AVAILABLE and torch.cuda.is_available(): if torch.version.hip is not None: prop = torch.cuda.get_device_properties(0) @@ -201,53 +208,50 @@ def get_gpu_info(run_lambda): return torch.cuda.get_device_name(None) + gcnArch return None smi = get_nvidia_smi() - uuid_regex = re.compile(r' \(UUID: .+?\)') - rc, out, _ = run_lambda(smi + ' -L') + uuid_regex = re.compile(r" \(UUID: .+?\)") + rc, out, _ = run_lambda(smi + " -L") if rc != 0: return None # Anonymize GPUs by removing their UUID - return re.sub(uuid_regex, '', out) + return re.sub(uuid_regex, "", out) def get_running_cuda_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'nvcc --version', - r'release .+ V(.*)') + return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)") def get_running_maca_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'mx-smi', - r'MACA Version:\s*([^\s]+)') + return run_and_parse_first_match(run_lambda, "mx-smi", r"MACA Version:\s*([^\s]+)") def get_bios_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'mx-smi', - r'BIOS Version:\s*([^\s]+)') + return run_and_parse_first_match(run_lambda, "mx-smi", r"BIOS Version:\s*([^\s]+)") def get_cudnn_version(run_lambda): """Return a list of libcudnn.so; it's hard to tell which one is being used.""" - if get_platform() == 'win32': - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") - where_cmd = os.path.join(system_root, 'System32', 'where') + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%") + where_cmd = os.path.join(system_root, "System32", "where") cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) - elif 
get_platform() == 'darwin': + elif get_platform() == "darwin": # CUDA libraries and drivers can be found in /usr/local/cuda/. See # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. - cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*" else: cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' rc, out, _ = run_lambda(cudnn_cmd) # find will return 1 if there are permission errors or if not found if len(out) == 0 or (rc != 1 and rc != 0): - l = os.environ.get('CUDNN_LIBRARY') + l = os.environ.get("CUDNN_LIBRARY") if l is not None and os.path.isfile(l): return os.path.realpath(l) return None files_set = set() - for fn in out.split('\n'): + for fn in out.split("\n"): fn = os.path.realpath(fn) # eliminate symbolic links if os.path.isfile(fn): files_set.add(fn) @@ -257,20 +261,20 @@ def get_cudnn_version(run_lambda): files = sorted(files_set) if len(files) == 1: return files[0] - result = '\n'.join(files) - return 'Probably one of the following:\n{}'.format(result) + result = "\n".join(files) + return "Probably one of the following:\n{}".format(result) def get_nvidia_smi(): # Note: mx-smi is currently available only on Windows and Linux - smi = 'mx-smi' - if get_platform() == 'win32': - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - program_files_root = os.environ.get('PROGRAMFILES', - 'C:\\Program Files') - legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', - 'NVSMI', smi) - new_path = os.path.join(system_root, 'System32', smi) + smi = "mx-smi" + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files") + legacy_path = os.path.join( + program_files_root, "NVIDIA Corporation", "NVSMI", smi + ) + new_path = os.path.join(system_root, "System32", smi) smis = [new_path, legacy_path] for candidate_smi in smis: if os.path.exists(candidate_smi): @@ -281,17 +285,18 @@ def get_nvidia_smi(): def get_rocm_version(run_lambda): """Returns the ROCm version if available, otherwise 'N/A'.""" - return run_and_parse_first_match(run_lambda, 'hipcc --version', - r'HIP version: (\S+)') + return run_and_parse_first_match( + run_lambda, "hipcc --version", r"HIP version: (\S+)" + ) def get_neuron_sdk_version(run_lambda): # Adapted from your install script try: result = run_lambda(["neuron-ls"]) - return result if result[0] == 0 else 'N/A' + return result if result[0] == 0 else "N/A" except Exception: - return 'N/A' + return "N/A" def get_vllm_version(): @@ -304,12 +309,12 @@ def get_vllm_version(): if __version__ == "dev": return "N/A (dev)" version_str = __version_tuple__[-1] - if isinstance(version_str, str) and version_str.startswith('g'): + if isinstance(version_str, str) and version_str.startswith("g"): # it's a dev build - if '.' in version_str: + if "." 
in version_str: # it's a dev build containing local changes - git_sha = version_str.split('.')[0][1:] - date = version_str.split('.')[-1][1:] + git_sha = version_str.split(".")[0][1:] + date = version_str.split(".")[-1][1:] return f"{__version__} (git sha: {git_sha}, date: {date})" else: # it's a dev build without local changes @@ -320,20 +325,20 @@ def get_vllm_version(): def summarize_vllm_build_flags(): # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. - return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format( - os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'), - 'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled', - 'Enabled' if os.environ.get('NEURON_CORES') else 'Disabled', + return "CUDA Archs: {}; ROCm: {}; Neuron: {}".format( + os.environ.get("TORCH_CUDA_ARCH_LIST", "Not Set"), + "Enabled" if os.environ.get("ROCM_HOME") else "Disabled", + "Enabled" if os.environ.get("NEURON_CORES") else "Disabled", ) def get_gpu_topo(run_lambda): output = None - if get_platform() == 'linux': - output = run_and_read_all(run_lambda, 'mx-smi topo -m') + if get_platform() == "linux": + output = run_and_read_all(run_lambda, "mx-smi topo -m") if output is None: - output = run_and_read_all(run_lambda, 'rocm-smi --showtopo') + output = run_and_read_all(run_lambda, "rocm-smi --showtopo") return output @@ -415,17 +420,17 @@ def get_gpu_topo(run_lambda): def get_cpu_info(run_lambda): - rc, out, err = 0, '', '' - if get_platform() == 'linux': - rc, out, err = run_lambda('lscpu') - elif get_platform() == 'win32': + rc, out, err = 0, "", "" + if get_platform() == "linux": + rc, out, err = run_lambda("lscpu") + elif get_platform() == "win32": rc, out, err = run_lambda( - 'wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ - CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE' + "wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ + CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE" ) - elif get_platform() == 'darwin': + elif get_platform() == "darwin": rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") - cpu_info = 'None' + cpu_info = "None" if rc == 0: cpu_info = out else: @@ -434,67 +439,69 @@ def get_cpu_info(run_lambda): def get_platform(): - if sys.platform.startswith('linux'): - return 'linux' - elif sys.platform.startswith('win32'): - return 'win32' - elif sys.platform.startswith('cygwin'): - return 'cygwin' - elif sys.platform.startswith('darwin'): - return 'darwin' + if sys.platform.startswith("linux"): + return "linux" + elif sys.platform.startswith("win32"): + return "win32" + elif sys.platform.startswith("cygwin"): + return "cygwin" + elif sys.platform.startswith("darwin"): + return "darwin" else: return sys.platform def get_mac_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', - r'(.*)') + return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)") def get_windows_version(run_lambda): - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic') - findstr_cmd = os.path.join(system_root, 'System32', 'findstr') + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + wmic_cmd = os.path.join(system_root, "System32", "Wbem", "wmic") + findstr_cmd = os.path.join(system_root, "System32", "findstr") return run_and_read_all( - run_lambda, - '{} os get Caption | {} /v Caption'.format(wmic_cmd, 
findstr_cmd)) + run_lambda, "{} os get Caption | {} /v Caption".format(wmic_cmd, findstr_cmd) + ) def get_lsb_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'lsb_release -a', - r'Description:\t(.*)') + return run_and_parse_first_match( + run_lambda, "lsb_release -a", r"Description:\t(.*)" + ) def check_release_file(run_lambda): - return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', - r'PRETTY_NAME="(.*)"') + return run_and_parse_first_match( + run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"' + ) def get_os(run_lambda): from platform import machine + platform = get_platform() - if platform == 'win32' or platform == 'cygwin': + if platform == "win32" or platform == "cygwin": return get_windows_version(run_lambda) - if platform == 'darwin': + if platform == "darwin": version = get_mac_version(run_lambda) if version is None: return None - return 'macOS {} ({})'.format(version, machine()) + return "macOS {} ({})".format(version, machine()) - if platform == 'linux': + if platform == "linux": # Ubuntu/Debian based desc = get_lsb_version(run_lambda) if desc is not None: - return '{} ({})'.format(desc, machine()) + return "{} ({})".format(desc, machine()) # Try reading /etc/*-release desc = check_release_file(run_lambda) if desc is not None: - return '{} ({})'.format(desc, machine()) + return "{} ({})".format(desc, machine()) - return '{} ({})'.format(platform, machine()) + return "{} ({})".format(platform, machine()) # Unknown platform return platform @@ -502,14 +509,16 @@ def get_os(run_lambda): def get_python_platform(): import platform + return platform.platform() def get_libc_version(): import platform - if get_platform() != 'linux': - return 'N/A' - return '-'.join(platform.libc_ver()) + + if get_platform() != "linux": + return "N/A" + return "-".join(platform.libc_ver()) def get_pip_packages(run_lambda, patterns=None): @@ -520,13 +529,14 @@ def get_pip_packages(run_lambda, patterns=None): def run_with_pip(): try: import importlib.util - pip_spec = importlib.util.find_spec('pip') + + pip_spec = importlib.util.find_spec("pip") pip_available = pip_spec is not None except ImportError: pip_available = False if pip_available: - cmd = [sys.executable, '-mpip', 'list', '--format=freeze'] + cmd = [sys.executable, "-mpip", "list", "--format=freeze"] elif os.environ.get("UV") is not None: print("uv is set") cmd = ["uv", "pip", "list", "--format=freeze"] @@ -536,23 +546,24 @@ def run_with_pip(): ) out = run_and_read_all(run_lambda, cmd) - return "\n".join(line for line in out.splitlines() - if any(name in line for name in patterns)) + return "\n".join( + line for line in out.splitlines() if any(name in line for name in patterns) + ) - pip_version = 'pip3' if sys.version[0] == '3' else 'pip' + pip_version = "pip3" if sys.version[0] == "3" else "pip" out = run_with_pip() return pip_version, out def get_cachingallocator_config(): - ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') + ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") return ca_config def get_cuda_module_loading_config(): if TORCH_AVAILABLE and torch.cuda.is_available(): torch.cuda.init() - config = os.environ.get('CUDA_MODULE_LOADING', '') + config = os.environ.get("CUDA_MODULE_LOADING", "") return config else: return "N/A" @@ -561,8 +572,8 @@ def get_cuda_module_loading_config(): def is_xnnpack_available(): if TORCH_AVAILABLE: import torch.backends.xnnpack - return str( - torch.backends.xnnpack.enabled) # type: ignore[attr-defined] + + return str(torch.backends.xnnpack.enabled) # 
type: ignore[attr-defined] else: return "N/A" @@ -582,10 +593,19 @@ def get_env_vars(): all_envs = vllm_envs | plugin_envs - env_vars = '' - secret_terms = ('secret', 'token', 'api', 'access', 'password') - report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN", - "OMP_", "MKL_", "NVIDIA") + env_vars = "" + secret_terms = ("secret", "token", "api", "access", "password") + report_prefix = ( + "TORCH", + "NCCL", + "PYTORCH", + "CUDA", + "CUBLAS", + "CUDNN", + "OMP_", + "MKL_", + "NVIDIA", + ) for k, v in os.environ.items(): if any(term in k.lower() for term in secret_terms): continue @@ -606,23 +626,24 @@ def get_env_info(): debug_mode_str = str(torch.version.debug) cuda_available_str = str(torch.cuda.is_available()) cuda_version_str = torch.version.cuda - if not hasattr(torch.version, - 'hip') or torch.version.hip is None: # cuda version - hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + if ( + not hasattr(torch.version, "hip") or torch.version.hip is None + ): # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" else: # HIP version def get_version_or_na(cfg, prefix): _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] - return _lst[0] if _lst else 'N/A' + return _lst[0] if _lst else "N/A" - cfg = torch._C._show_config().split('\n') - hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') - miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') - cuda_version_str = 'N/A' + cfg = torch._C._show_config().split("\n") + hip_runtime_version = get_version_or_na(cfg, "HIP Runtime") + miopen_runtime_version = get_version_or_na(cfg, "MIOpen") + cuda_version_str = "N/A" hip_compiled_version = torch.version.hip else: - version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A' - hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + version_str = debug_mode_str = cuda_available_str = cuda_version_str = "N/A" + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" sys_version = sys.version.replace("\n", " ") @@ -637,9 +658,9 @@ def get_version_or_na(cfg, prefix): return SystemEnv( torch_version=version_str, is_debug_build=debug_mode_str, - python_version='{} ({}-bit runtime)'.format( - sys_version, - sys.maxsize.bit_length() + 1), + python_version="{} ({}-bit runtime)".format( + sys_version, sys.maxsize.bit_length() + 1 + ), python_platform=get_python_platform(), is_cuda_available=cuda_available_str, cuda_compiled_version=cuda_version_str, @@ -749,15 +770,14 @@ def get_version_or_na(cfg, prefix): def pretty_str(envinfo): - - def replace_nones(dct, replacement='Could not collect'): + def replace_nones(dct, replacement="Could not collect"): for key in dct.keys(): if dct[key] is not None: continue dct[key] = replacement return dct - def replace_bools(dct, true='Yes', false='No'): + def replace_bools(dct, true="Yes", false="No"): for key in dct.keys(): if dct[key] is True: dct[key] = true @@ -765,43 +785,48 @@ def replace_bools(dct, true='Yes', false='No'): dct[key] = false return dct - def prepend(text, tag='[prepend]'): - lines = text.split('\n') + def prepend(text, tag="[prepend]"): + lines = text.split("\n") updated_lines = [tag + line for line in lines] - return '\n'.join(updated_lines) + return "\n".join(updated_lines) - def replace_if_empty(text, replacement='No relevant packages'): + def replace_if_empty(text, replacement="No relevant packages"): if text is not None and len(text) == 0: return replacement return text def 
maybe_start_on_next_line(string): # If `string` is multiline, prepend a \n to it. - if string is not None and len(string.split('\n')) > 1: - return '\n{}\n'.format(string) + if string is not None and len(string.split("\n")) > 1: + return "\n{}\n".format(string) return string mutable_dict = envinfo._asdict() # If nvidia_gpu_models is multiline, start on the next line - mutable_dict['nvidia_gpu_models'] = \ - maybe_start_on_next_line(envinfo.nvidia_gpu_models) + mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line( + envinfo.nvidia_gpu_models + ) # If the machine doesn't have CUDA, report some fields as 'No CUDA' dynamic_cuda_fields = [ - 'cuda_runtime_version', - 'nvidia_gpu_models', - 'nvidia_driver_version', + "cuda_runtime_version", + "nvidia_gpu_models", + "nvidia_driver_version", ] - all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] - all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None - for field in dynamic_cuda_fields) - if TORCH_AVAILABLE and not torch.cuda.is_available( - ) and all_dynamic_cuda_fields_missing: + all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"] + all_dynamic_cuda_fields_missing = all( + mutable_dict[field] is None for field in dynamic_cuda_fields + ) + if ( + TORCH_AVAILABLE + and not torch.cuda.is_available() + and all_dynamic_cuda_fields_missing + ): for field in all_cuda_fields: - mutable_dict[field] = 'No CUDA' + mutable_dict[field] = "No CUDA" if envinfo.cuda_compiled_version is None: - mutable_dict['cuda_compiled_version'] = 'None' + mutable_dict["cuda_compiled_version"] = "None" # Replace True with Yes, False with No mutable_dict = replace_bools(mutable_dict) @@ -810,20 +835,20 @@ def maybe_start_on_next_line(string): mutable_dict = replace_nones(mutable_dict) # If either of these are '', replace with 'No relevant packages' - mutable_dict['pip_packages'] = replace_if_empty( - mutable_dict['pip_packages']) - mutable_dict['conda_packages'] = replace_if_empty( - mutable_dict['conda_packages']) + mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"]) + mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"]) # Tag conda and pip packages with a prefix # If they were previously None, they'll show up as ie '[conda] Could not collect' - if mutable_dict['pip_packages']: - mutable_dict['pip_packages'] = prepend( - mutable_dict['pip_packages'], '[{}] '.format(envinfo.pip_version)) - if mutable_dict['conda_packages']: - mutable_dict['conda_packages'] = prepend( - mutable_dict['conda_packages'], '[conda] ') - mutable_dict['cpu_info'] = envinfo.cpu_info + if mutable_dict["pip_packages"]: + mutable_dict["pip_packages"] = prepend( + mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version) + ) + if mutable_dict["conda_packages"]: + mutable_dict["conda_packages"] = prepend( + mutable_dict["conda_packages"], "[conda] " + ) + mutable_dict["cpu_info"] = envinfo.cpu_info return env_info_fmt.format(**mutable_dict) @@ -836,22 +861,29 @@ def main(): output = get_pretty_env_info() print(output) - if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr( - torch.utils, '_crash_handler'): + if ( + TORCH_AVAILABLE + and hasattr(torch, "utils") + and hasattr(torch.utils, "_crash_handler") + ): minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR if sys.platform == "linux" and os.path.exists(minidump_dir): dumps = [ - os.path.join(minidump_dir, dump) - for dump in os.listdir(minidump_dir) + os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir) ] latest = max(dumps, 
key=os.path.getctime) ctime = os.path.getctime(latest) creation_time = datetime.datetime.fromtimestamp(ctime).strftime( - '%Y-%m-%d %H:%M:%S') - msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ - "if this is related to your bug please include it when you file a report ***" + "%Y-%m-%d %H:%M:%S" + ) + msg = ( + "\n*** Detected a minidump at {} created on {}, ".format( + latest, creation_time + ) + + "if this is related to your bug please include it when you file a report ***" + ) print(msg, file=sys.stderr) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/vllm_metax/device_allocator/__init__.py b/vllm_metax/device_allocator/__init__.py index 35e1ee895..988131360 100644 --- a/vllm_metax/device_allocator/__init__.py +++ b/vllm_metax/device_allocator/__init__.py @@ -1 +1 @@ -# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file +# SPDX-License-Identifier: Apache-2.0 diff --git a/vllm_metax/device_allocator/cumem.py b/vllm_metax/device_allocator/cumem.py index 847d8deb3..9ab8571d5 100644 --- a/vllm_metax/device_allocator/cumem.py +++ b/vllm_metax/device_allocator/cumem.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # cumem-based pytorch pluggable allocator to implement sleep mode. # other approaches tried but failed: @@ -10,20 +11,25 @@ import dataclasses import gc import os +from collections.abc import Callable from contextlib import contextmanager -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any import torch -from vllm.utils import is_pin_memory_available +from vllm.logger import init_logger +from vllm.utils.platform_utils import is_pin_memory_available -def find_loaded_library(lib_name) -> Optional[str]: +logger = init_logger(__name__) + + +def find_loaded_library(lib_name) -> str | None: """ According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, the file `/proc/self/maps` contains the memory maps of the process, which includes the shared libraries loaded by the process. We can use this file to find the path of the a loaded library. 
- """ # noqa + """ # noqa found_line = None with open("/proc/self/maps") as f: for line in f: @@ -38,17 +44,22 @@ def find_loaded_library(lib_name) -> Optional[str]: start = found_line.index("/") path = found_line[start:].strip() filename = path.split("/")[-1] - assert filename.rpartition(".so")[0].startswith(lib_name), \ + assert filename.rpartition(".so")[0].startswith(lib_name), ( f"Unexpected filename: {filename} for library {lib_name}" + ) return path cumem_available = False +# /------------------------ Metax Modifications -------------------------\ try: - from vllm_metax.cumem_allocator import (init_module, python_create_and_map, - python_unmap_and_release) - from vllm_metax.distributed.device_communicators.cuda_wrapper import ( - CudaRTLibrary) + from vllm_metax.cumem_allocator import ( + init_module, + python_create_and_map, + python_unmap_and_release, + ) + from vllm_metax.patch.distributed.cuda_wrapper import CudaRTLibrary + lib_name = find_loaded_library("cumem_allocator") libcudart = CudaRTLibrary() cumem_available = True @@ -60,16 +71,17 @@ def find_loaded_library(lib_name) -> Optional[str]: CudaRTLibrary = None lib_name = None libcudart = None +# \------------------------ Metax Modifications -------------------------/ # py_device, py_alignedSize, py_d_mem, py_p_memHandle -HandleType = Tuple[int, int, int, int] +HandleType = tuple[int, int, int, int] @dataclasses.dataclass class AllocationData: handle: HandleType tag: str - cpu_backup_tensor: Optional[torch.Tensor] = None + cpu_backup_tensor: torch.Tensor | None = None def create_and_map(allocation_handle: HandleType) -> None: @@ -81,20 +93,19 @@ def unmap_and_release(allocation_handle: HandleType) -> None: def get_pluggable_allocator( - python_malloc_fn: Callable[[int], - int], python_free_func: Callable[[int, int], - None] + python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None] ) -> torch.cuda.memory.CUDAPluggableAllocator: init_module(python_malloc_fn, python_free_func) new_alloc = torch.cuda.memory.CUDAPluggableAllocator( - lib_name, 'my_malloc', 'my_free') + lib_name, "my_malloc", "my_free" + ) return new_alloc @contextmanager def use_memory_pool_with_allocator( - python_malloc_fn: Callable[[int], int], - python_free_func: Callable[[int, int], None]) -> None: + python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None] +) -> None: new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator) with torch.cuda.memory.use_mem_pool(mem_pool): @@ -125,6 +136,7 @@ class CuMemAllocator: the global variable will be overwritten and the free callback will not work as expected. """ + instance: "CuMemAllocator" = None default_tag: str = "default" @@ -142,37 +154,53 @@ def get_instance() -> "CuMemAllocator": def __init__(self): conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") - assert "expandable_segments:True" not in conf, \ - ("Expandable segments are not compatible with memory pool. " + assert "expandable_segments:True" not in conf, ( + "Expandable segments are not compatible with memory pool. " "Please track https://github.com/pytorch/pytorch/issues/147851 " - "for the latest updates.") + "for the latest updates." 
+ ) - self.pointer_to_data: Dict[int, AllocationData] = {} + self.pointer_to_data: dict[int, AllocationData] = {} self.current_tag: str = CuMemAllocator.default_tag - self.allocator_and_pools: Dict[str, Any] = {} - - def python_malloc_callback(self, allocation_handle: HandleType) -> None: + self.allocator_and_pools: dict[str, Any] = {} + # Creating strong references to the two callbacks here to prevent + # these ephemeral bound-method objects being garbage collected. + # See discussions in https://github.com/vllm-project/vllm/pull/22724 + self.python_malloc_callback = self._python_malloc_callback + self.python_free_callback = self._python_free_callback + + def _python_malloc_callback(self, allocation_handle: HandleType) -> None: """ Internal method to store the allocation data when memory is allocated in the memory pool.""" py_d_mem = allocation_handle[2] self.pointer_to_data[py_d_mem] = AllocationData( - allocation_handle, self.current_tag) + allocation_handle, self.current_tag + ) + logger.debug( + "Allocated %s bytes for %s with address %s from cumem allocator", + allocation_handle[1], + self.current_tag, + py_d_mem, + ) return - def python_free_callback(self, ptr: int) -> HandleType: + def _python_free_callback(self, ptr: int) -> HandleType: """ Internal method to look up the allocation data when memory is freed in the memory pool.""" data = self.pointer_to_data.pop(ptr) if data.cpu_backup_tensor is not None: data.cpu_backup_tensor = None + logger.debug( + "Freed %s bytes for %s with address %s from cumem allocator", + data.handle[1], + data.tag, + ptr, + ) return data.handle - def sleep( - self, - offload_tags: Optional[Union[Tuple[str, ...], - str]] = None) -> None: + def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None: """ Put the allocator in sleep mode. All data in the memory allocation with the specified tag will be @@ -184,35 +212,50 @@ def sleep( if offload_tags is None: # by default, allocated tensors are offloaded # when the allocator sleeps - offload_tags = (CuMemAllocator.default_tag, ) + offload_tags = (CuMemAllocator.default_tag,) elif isinstance(offload_tags, str): - offload_tags = (offload_tags, ) + offload_tags = (offload_tags,) assert isinstance(offload_tags, tuple) + total_bytes = 0 + backup_bytes = 0 + for ptr, data in self.pointer_to_data.items(): handle = data.handle + total_bytes += handle[1] if data.tag in offload_tags: + backup_bytes += handle[1] size_in_bytes = handle[1] cpu_backup_tensor = torch.empty( size_in_bytes, dtype=torch.uint8, - device='cpu', - pin_memory=is_pin_memory_available()) + device="cpu", + pin_memory=is_pin_memory_available(), + ) cpu_ptr = cpu_backup_tensor.data_ptr() libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) data.cpu_backup_tensor = cpu_backup_tensor unmap_and_release(handle) + logger.info( + "CuMemAllocator: sleep freed %.2f GiB memory in total, of which " + "%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded " + "directly.", + total_bytes / 1024**3, + backup_bytes / 1024**3, + (total_bytes - backup_bytes) / 1024**3, + ) + gc.collect() torch.cuda.empty_cache() - def wake_up(self, tags: Optional[list[str]] = None) -> None: + def wake_up(self, tags: list[str] | None = None) -> None: """ Wake up the allocator from sleep mode. - All data that is previously offloaded will be loaded back to GPU + All data that is previously offloaded will be loaded back to GPU memory, and the rest of the data will have empty memory. 
- + :param tags: The tags of the memory allocation that will be loaded back to GPU memory. If None, all memory allocation will be loaded back to GPU memory. @@ -224,14 +267,15 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None: if data.cpu_backup_tensor is not None: cpu_backup_tensor = data.cpu_backup_tensor if cpu_backup_tensor is not None: - size_in_bytes = cpu_backup_tensor.numel( - ) * cpu_backup_tensor.element_size() + size_in_bytes = ( + cpu_backup_tensor.numel() * cpu_backup_tensor.element_size() + ) cpu_ptr = cpu_backup_tensor.data_ptr() libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) data.cpu_backup_tensor = None @contextmanager - def use_memory_pool(self, tag: Optional[str] = None): + def use_memory_pool(self, tag: str | None = None): """ A context manager to use the memory pool. All memory allocation created inside the context will be allocated @@ -247,8 +291,9 @@ def use_memory_pool(self, tag: Optional[str] = None): old_tag = self.current_tag self.current_tag = tag - with use_memory_pool_with_allocator(self.python_malloc_callback, - self.python_free_callback) as data: + with use_memory_pool_with_allocator( + self.python_malloc_callback, self.python_free_callback + ) as data: # start to hit another PyTorch bug in PyTorch 2.6, # possibly because of gc-related issue w.r.t. the allocator and # the memory pool. @@ -260,12 +305,17 @@ def use_memory_pool(self, tag: Optional[str] = None): # when using pluggable allocator, see # https://github.com/pytorch/pytorch/issues/145168 . # if we have some memory allocated and then freed, - # the memory will not be released. - # right now it is fine, because we only use this allocator - # during weight loading and kv cache creation, where we only - # allocate memory. - # TODO: we need to find a way to release the memory, - # i.e. calling torch.cuda.empty_cache() + # the memory will not be released, e.g. in online quantization, + # where the model is created in higher precision, and then + # quantized in lower precision. + # Find all unused allocations and manually release them. + # TODO: we should expose `empty_cache` method in the memory pool. + # TODO: ask for help from PyTorch team to expose this method. 
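+            # Each entry returned by snapshot() is assumed to be a dict that
+            # exposes at least "address" and "allocated_size"; an entry whose
+            # allocated_size is 0 is a cached block that is no longer in use,
+            # so its handle is looked up through the allocator's free
+            # callback and passed to unmap_and_release to free the backing
+            # memory.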
+ allocations = data[0].snapshot() + for allocation in allocations: + if allocation["allocated_size"] == 0: + handle = self._python_free_callback(allocation["address"]) + unmap_and_release(handle) self.current_tag = old_tag def get_current_usage(self) -> int: diff --git a/vllm_metax/distributed/device_communicators/__init__.py b/vllm_metax/distributed/device_communicators/__init__.py deleted file mode 100644 index 35e1ee895..000000000 --- a/vllm_metax/distributed/device_communicators/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file diff --git a/vllm_metax/distributed/device_communicators/cuda_communicator.py b/vllm_metax/distributed/device_communicators/cuda_communicator.py deleted file mode 100644 index 0feb95bce..000000000 --- a/vllm_metax/distributed/device_communicators/cuda_communicator.py +++ /dev/null @@ -1,131 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional - -import torch -from torch.distributed import ProcessGroup -from vllm.distributed.device_communicators.base_device_communicator import ( - DeviceCommunicatorBase) - - -class CudaCommunicator(DeviceCommunicatorBase): - - def __init__(self, - cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): - super().__init__(cpu_group, device, device_group, unique_name) - if "tp" not in unique_name: - # only tp uses custom allreduce - use_custom_allreduce = False - else: - from vllm_metax.distributed.parallel_state import ( - _ENABLE_CUSTOM_ALL_REDUCE) - use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE - use_pynccl = True - - self.use_pynccl = use_pynccl - self.use_custom_allreduce = use_custom_allreduce - - # lazy import to avoid documentation build error - from vllm_metax.distributed.device_communicators.custom_all_reduce import ( - CustomAllreduce) - from vllm_metax.distributed.device_communicators.pynccl import ( - PyNcclCommunicator) - - self.pynccl_comm: Optional[PyNcclCommunicator] = None - if use_pynccl and self.world_size > 1: - self.pynccl_comm = PyNcclCommunicator( - group=self.cpu_group, - device=self.device, - ) - - self.ca_comm: Optional[CustomAllreduce] = None - if use_custom_allreduce and self.world_size > 1: - # Initialize a custom fast all-reduce implementation. - self.ca_comm = CustomAllreduce( - group=self.cpu_group, - device=self.device, - ) - - def all_reduce(self, input_): - # always try custom allreduce first, - # and then pynccl. - ca_comm = self.ca_comm - if ca_comm is not None and not ca_comm.disabled and \ - ca_comm.should_custom_ar(input_): - out = ca_comm.custom_all_reduce(input_) - assert out is not None - return out - pynccl_comm = self.pynccl_comm - assert pynccl_comm is not None - out = pynccl_comm.all_reduce(input_) - if out is None: - # fall back to the default all-reduce using PyTorch. - # this usually happens during testing. - # when we run the model, allreduce only happens for the TP - # group, where we always have either custom allreduce or pynccl. - out = input_.clone() - torch.distributed.all_reduce(out, group=self.device_group) - return out - - def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): - world_size = self.world_size - pynccl_comm = self.pynccl_comm - assert pynccl_comm is not None - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - - # Note: This will produce an incorrect answer if we don't make - # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? 
- input_tensor = input_.movedim(0, dim).contiguous() - - assert input_tensor.shape[0] % world_size == 0 - chunk_size = input_tensor.shape[0] // world_size - output_shape = (chunk_size, ) + input_tensor.shape[1:] - - output = torch.empty(output_shape, - dtype=input_tensor.dtype, - device=input_tensor.device) - - pynccl_comm.reduce_scatter(output, input_) - - # Reshape before returning - return output.movedim(0, dim).contiguous() - - def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: - """Sends a tensor to the destination rank in a non-blocking way""" - """NOTE: `dst` is the local rank of the destination rank.""" - if dst is None: - dst = (self.rank_in_group + 1) % self.world_size - - pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.send(tensor, dst) - else: - torch.distributed.send(tensor, self.ranks[dst], self.device_group) - - def recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: - """Receives a tensor from the source rank.""" - """NOTE: `src` is the local rank of the source rank.""" - if src is None: - src = (self.rank_in_group - 1) % self.world_size - - tensor = torch.empty(size, dtype=dtype, device=self.device) - pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.recv(tensor, src) - else: - torch.distributed.recv(tensor, self.ranks[src], self.device_group) - return tensor - - def destroy(self): - if self.pynccl_comm is not None: - self.pynccl_comm = None - if self.ca_comm is not None: - self.ca_comm = None diff --git a/vllm_metax/distributed/device_communicators/cuda_wrapper.py b/vllm_metax/distributed/device_communicators/cuda_wrapper.py deleted file mode 100644 index 82dc7bfb2..000000000 --- a/vllm_metax/distributed/device_communicators/cuda_wrapper.py +++ /dev/null @@ -1,178 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""This file is a pure Python wrapper for the cudart library. -It avoids the need to compile a separate shared library, and is -convenient for use when we just need to call a few functions. -""" - -import ctypes -from dataclasses import dataclass -from typing import Any, Dict, List, Optional - -# this line makes it possible to directly load `libcudart.so` using `ctypes` -import torch # noqa -import vllm.envs as envs -from vllm.logger import init_logger - -logger = init_logger(__name__) - -# === export types and functions from cudart to Python === -# for the original cudart definition, please check -# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html - -cudaError_t = ctypes.c_int -cudaMemcpyKind = ctypes.c_int - - -class cudaIpcMemHandle_t(ctypes.Structure): - _fields_ = [("internal", ctypes.c_byte * 128)] - - -@dataclass -class Function: - name: str - restype: Any - argtypes: List[Any] - - -def find_loaded_library(lib_name) -> Optional[str]: - """ - According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, - the file `/proc/self/maps` contains the memory maps of the process, which includes the - shared libraries loaded by the process. We can use this file to find the path of the - a loaded library. 
- """ # noqa - found = False - with open("/proc/self/maps") as f: - for line in f: - if lib_name in line: - found = True - break - if not found: - # the library is not loaded in the current process - return None - # if lib_name is libcudart, we need to match a line with: - # address /path/to/libcudart-hash.so.11.0 - start = line.index("/") - path = line[start:].strip() - filename = path.split("/")[-1] - assert filename.rpartition(".so")[0].startswith(lib_name), \ - f"Unexpected filename: {filename} for library {lib_name}" - return path - - -class CudaRTLibrary: - exported_functions = [ - # ​cudaError_t cudaSetDevice ( int device ) - Function("mcSetDevice", cudaError_t, [ctypes.c_int]), - # cudaError_t cudaDeviceSynchronize ( void ) - Function("mcDeviceSynchronize", cudaError_t, []), - # ​cudaError_t cudaDeviceReset ( void ) - Function("mcDeviceReset", cudaError_t, []), - - # const char* cudaGetErrorString ( cudaError_t error ) - Function("mcGetErrorString", ctypes.c_char_p, [cudaError_t]), - - # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) - Function("mcMalloc", cudaError_t, - [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), - # ​cudaError_t cudaFree ( void* devPtr ) - Function("mcFree", cudaError_t, [ctypes.c_void_p]), - # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) - Function("mcMemset", cudaError_t, - [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), - # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa - Function("mcMemcpy", cudaError_t, [ - ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind - ]), - - # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa - Function("mcIpcGetMemHandle", cudaError_t, - [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), - # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa - Function("mcIpcOpenMemHandle", cudaError_t, [ - ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint - ]), - ] - - # class attribute to store the mapping from the path to the library - # to avoid loading the same library multiple times - path_to_library_cache: Dict[str, Any] = {} - - # class attribute to store the mapping from library path - # to the corresponding dictionary - path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} - - def __init__(self, so_file: Optional[str] = None): - if so_file is None: - so_file = find_loaded_library("libmcruntime") - if so_file is None: - so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var - assert so_file is not None, \ - ( - "libcudart is not loaded in the current process, " - "try setting VLLM_CUDART_SO_PATH" - ) - if so_file not in CudaRTLibrary.path_to_library_cache: - lib = ctypes.CDLL(so_file) - CudaRTLibrary.path_to_library_cache[so_file] = lib - self.lib = CudaRTLibrary.path_to_library_cache[so_file] - - if so_file not in CudaRTLibrary.path_to_dict_mapping: - _funcs = {} - for func in CudaRTLibrary.exported_functions: - f = getattr(self.lib, func.name) - f.restype = func.restype - f.argtypes = func.argtypes - _funcs[func.name] = f - CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs - self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] - - def CUDART_CHECK(self, result: cudaError_t) -> None: - if result != 0: - error_str = self.cudaGetErrorString(result) - raise RuntimeError(f"CUDART error: {error_str}") - - def cudaGetErrorString(self, error: cudaError_t) -> str: - return self.funcs["mcGetErrorString"](error).decode("utf-8") 
- - def cudaSetDevice(self, device: int) -> None: - self.CUDART_CHECK(self.funcs["mcSetDevice"](device)) - - def cudaDeviceSynchronize(self) -> None: - self.CUDART_CHECK(self.funcs["mcDeviceSynchronize"]()) - - def cudaDeviceReset(self) -> None: - self.CUDART_CHECK(self.funcs["mcDeviceReset"]()) - - def cudaMalloc(self, size: int) -> ctypes.c_void_p: - devPtr = ctypes.c_void_p() - self.CUDART_CHECK(self.funcs["mcMalloc"](ctypes.byref(devPtr), size)) - return devPtr - - def cudaFree(self, devPtr: ctypes.c_void_p) -> None: - self.CUDART_CHECK(self.funcs["mcFree"](devPtr)) - - def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, - count: int) -> None: - self.CUDART_CHECK(self.funcs["mcMemset"](devPtr, value, count)) - - def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, - count: int) -> None: - cudaMemcpyDefault = 4 - kind = cudaMemcpyDefault - self.CUDART_CHECK(self.funcs["mcMemcpy"](dst, src, count, kind)) - - def cudaIpcGetMemHandle(self, - devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: - handle = cudaIpcMemHandle_t() - self.CUDART_CHECK(self.funcs["mcIpcGetMemHandle"](ctypes.byref(handle), - devPtr)) - return handle - - def cudaIpcOpenMemHandle(self, - handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: - cudaIpcMemLazyEnablePeerAccess = 1 - devPtr = ctypes.c_void_p() - self.CUDART_CHECK(self.funcs["mcIpcOpenMemHandle"]( - ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) - return devPtr diff --git a/vllm_metax/distributed/device_communicators/pynccl.py b/vllm_metax/distributed/device_communicators/pynccl.py deleted file mode 100644 index c15a06509..000000000 --- a/vllm_metax/distributed/device_communicators/pynccl.py +++ /dev/null @@ -1,217 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional, Union - -# ===================== import region ===================== -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup, ReduceOp -from vllm.distributed.utils import StatelessProcessGroup -from vllm.logger import init_logger -from vllm.utils import current_stream - -from vllm_metax.distributed.device_communicators.pynccl_wrapper import ( - NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, - ncclRedOpTypeEnum, ncclUniqueId) - -logger = init_logger(__name__) - - -class PyNcclCommunicator: - - def __init__( - self, - group: Union[ProcessGroup, StatelessProcessGroup], - device: Union[int, str, torch.device], - library_path: Optional[str] = None, - ): - """ - Args: - group: the process group to work on. If None, it will use the - default process group. - device: the device to bind the PyNcclCommunicator to. If None, - it will be bind to f"cuda:{local_rank}". - library_path: the path to the NCCL library. If None, it will - use the default library path. - It is the caller's responsibility to make sure each communicator - is bind to a unique device. 
- """ - if not isinstance(group, StatelessProcessGroup): - assert dist.is_initialized() - assert dist.get_backend(group) != dist.Backend.NCCL, ( - "PyNcclCommunicator should be attached to a non-NCCL group.") - # note: this rank is the rank in the group - self.rank = dist.get_rank(group) - self.world_size = dist.get_world_size(group) - else: - self.rank = group.rank - self.world_size = group.world_size - - self.group = group - - # if world_size == 1, no need to create communicator - if self.world_size == 1: - self.available = False - self.disabled = True - return - try: - self.nccl = NCCLLibrary(library_path) - except Exception: - # disable because of missing NCCL library - # e.g. in a non-GPU environment - self.available = False - self.disabled = True - return - - self.available = True - self.disabled = False - - logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) - - if self.rank == 0: - # get the unique id from NCCL - self.unique_id = self.nccl.ncclGetUniqueId() - else: - # construct an empty unique id - self.unique_id = ncclUniqueId() - - if not isinstance(group, StatelessProcessGroup): - tensor = torch.ByteTensor(list(self.unique_id.internal)) - ranks = dist.get_process_group_ranks(group) - # arg `src` in `broadcast` is the global rank - dist.broadcast(tensor, src=ranks[0], group=group) - byte_list = tensor.tolist() - for i, byte in enumerate(byte_list): - self.unique_id.internal[i] = byte - else: - self.unique_id = group.broadcast_obj(self.unique_id, src=0) - if isinstance(device, int): - device = torch.device(f"cuda:{device}") - elif isinstance(device, str): - device = torch.device(device) - # now `device` is a `torch.device` object - assert isinstance(device, torch.device) - self.device = device - # nccl communicator and stream will use this device - # `torch.cuda.device` is a context manager that changes the - # current cuda device to the specified one - with torch.cuda.device(device): - self.comm: ncclComm_t = self.nccl.ncclCommInitRank( - self.world_size, self.unique_id, self.rank) - - stream = current_stream() - # A small all_reduce for warmup. 
- data = torch.zeros(1, device=device) - self.all_reduce(data) - stream.synchronize() - del data - - def all_reduce(self, - in_tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - stream=None) -> torch.Tensor: - if self.disabled: - return None - # nccl communicator created on a specific device - # will only work on tensors on the same device - # otherwise it will cause "illegal memory access" - assert in_tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {in_tensor.device}") - - out_tensor = torch.empty_like(in_tensor) - - if stream is None: - stream = current_stream() - self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), - buffer_type(out_tensor.data_ptr()), - in_tensor.numel(), - ncclDataTypeEnum.from_torch(in_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - cudaStream_t(stream.cuda_stream)) - return out_tensor - - def all_gather(self, - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - stream=None): - if self.disabled: - return - # nccl communicator created on a specific device - # will only work on tensors on the same device - # otherwise it will cause "illegal memory access" - assert input_tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") - if stream is None: - stream = current_stream() - self.nccl.ncclAllGather( - buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), input_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, - cudaStream_t(stream.cuda_stream)) - - def reduce_scatter(self, - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - stream=None): - if self.disabled: - return - # nccl communicator created on a specific device - # will only work on tensors on the same device - # otherwise it will cause "illegal memory access" - assert input_tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") - if stream is None: - stream = current_stream() - self.nccl.ncclReduceScatter( - buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), output_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - cudaStream_t(stream.cuda_stream)) - - def send(self, tensor: torch.Tensor, dst: int, stream=None): - if self.disabled: - return - assert tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") - if stream is None: - stream = current_stream() - self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), dst, - self.comm, cudaStream_t(stream.cuda_stream)) - - def recv(self, tensor: torch.Tensor, src: int, stream=None): - if self.disabled: - return - assert tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") - if stream is None: - stream = current_stream() - self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, - self.comm, cudaStream_t(stream.cuda_stream)) - - def broadcast(self, tensor: torch.Tensor, src: int, stream=None): - if self.disabled: - return - assert tensor.device == self.device, ( - f"this nccl 
communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") - if stream is None: - stream = current_stream() - if src == self.rank: - sendbuff = buffer_type(tensor.data_ptr()) - # NCCL requires the sender also to have a receive buffer - recvbuff = buffer_type(tensor.data_ptr()) - else: - sendbuff = buffer_type() - recvbuff = buffer_type(tensor.data_ptr()) - self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, - self.comm, cudaStream_t(stream.cuda_stream)) diff --git a/vllm_metax/distributed/device_communicators/pynccl_wrapper.py b/vllm_metax/distributed/device_communicators/pynccl_wrapper.py deleted file mode 100644 index 38c929fdb..000000000 --- a/vllm_metax/distributed/device_communicators/pynccl_wrapper.py +++ /dev/null @@ -1,339 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# This file is a pure Python wrapper for the NCCL library. -# The main purpose is to use NCCL combined with CUDA graph. -# Before writing this script, we tried the following approach: -# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself -# often gets stuck when initializing the NCCL communicator. -# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` -# contains many other potential cuda APIs, that are not allowed during -# capturing the CUDA graph. For further details, please check -# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . -# -# Another rejected idea is to write a C/C++ binding for NCCL. It is usually -# doable, but we often encounter issues related with nccl versions, and need -# to switch between different versions of NCCL. See -# https://github.com/NVIDIA/nccl/issues/1234 for more details. -# A C/C++ binding is not flexible enough to handle this. It requires -# recompilation of the code every time we want to switch between different -# versions. This current implementation, with a **pure** Python wrapper, is -# more flexible. We can easily switch between different versions of NCCL by -# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` -# variable in the code. 
- -import ctypes -import platform -from dataclasses import dataclass -from typing import Any, Dict, List, Optional - -import torch -from torch.distributed import ReduceOp -from vllm.logger import init_logger -from vllm.utils import find_nccl_library - -logger = init_logger(__name__) - -# === export types and functions from nccl to Python === -# for the original nccl definition, please check -# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in - -ncclResult_t = ctypes.c_int -ncclComm_t = ctypes.c_void_p - - -class ncclUniqueId(ctypes.Structure): - _fields_ = [("internal", ctypes.c_byte * 128)] - - -cudaStream_t = ctypes.c_void_p -buffer_type = ctypes.c_void_p - -ncclDataType_t = ctypes.c_int - - -class ncclDataTypeEnum: - ncclInt8 = 0 - ncclChar = 0 - ncclUint8 = 1 - ncclInt32 = 2 - ncclInt = 2 - ncclUint32 = 3 - ncclInt64 = 4 - ncclUint64 = 5 - ncclFloat16 = 6 - ncclHalf = 6 - ncclFloat32 = 7 - ncclFloat = 7 - ncclFloat64 = 8 - ncclDouble = 8 - ncclBfloat16 = 9 - ncclNumTypes = 10 - - @classmethod - def from_torch(cls, dtype: torch.dtype) -> int: - if dtype == torch.int8: - return cls.ncclInt8 - if dtype == torch.uint8: - return cls.ncclUint8 - if dtype == torch.int32: - return cls.ncclInt32 - if dtype == torch.int64: - return cls.ncclInt64 - if dtype == torch.float16: - return cls.ncclFloat16 - if dtype == torch.float32: - return cls.ncclFloat32 - if dtype == torch.float64: - return cls.ncclFloat64 - if dtype == torch.bfloat16: - return cls.ncclBfloat16 - raise ValueError(f"Unsupported dtype: {dtype}") - - -ncclRedOp_t = ctypes.c_int - - -class ncclRedOpTypeEnum: - ncclSum = 0 - ncclProd = 1 - ncclMax = 2 - ncclMin = 3 - ncclAvg = 4 - ncclNumOps = 5 - - @classmethod - def from_torch(cls, op: ReduceOp) -> int: - if op == ReduceOp.SUM: - return cls.ncclSum - if op == ReduceOp.PRODUCT: - return cls.ncclProd - if op == ReduceOp.MAX: - return cls.ncclMax - if op == ReduceOp.MIN: - return cls.ncclMin - if op == ReduceOp.AVG: - return cls.ncclAvg - raise ValueError(f"Unsupported op: {op}") - - -@dataclass -class Function: - name: str - restype: Any - argtypes: List[Any] - - -class NCCLLibrary: - exported_functions = [ - # const char* ncclGetErrorString(ncclResult_t result) - Function("mcclGetErrorString", ctypes.c_char_p, [ncclResult_t]), - # ncclResult_t ncclGetVersion(int *version); - Function("mcclGetVersion", ncclResult_t, - [ctypes.POINTER(ctypes.c_int)]), - # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); - Function("mcclGetUniqueId", ncclResult_t, - [ctypes.POINTER(ncclUniqueId)]), - # ncclResult_t ncclCommInitRank( - # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); - # note that ncclComm_t is a pointer type, so the first argument - # is a pointer to a pointer - Function("mcclCommInitRank", ncclResult_t, [ - ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, - ctypes.c_int - ]), - # ncclResult_t ncclAllReduce( - # const void* sendbuff, void* recvbuff, size_t count, - # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, - # cudaStream_t stream); - # note that cudaStream_t is a pointer type, so the last argument - # is a pointer - Function("mcclAllReduce", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ncclComm_t, cudaStream_t - ]), - - # ncclResult_t ncclAllGather( - # const void* sendbuff, void* recvbuff, size_t count, - # ncclDataType_t datatype, ncclComm_t comm, - # cudaStream_t stream); - # note that cudaStream_t is a pointer type, so the last argument - # is a pointer - Function("mcclAllGather", 
ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclComm_t, cudaStream_t - ]), - - # ncclResult_t ncclReduceScatter( - # const void* sendbuff, void* recvbuff, size_t count, - # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, - # cudaStream_t stream); - # note that cudaStream_t is a pointer type, so the last argument - # is a pointer - Function("mcclReduceScatter", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ncclComm_t, cudaStream_t - ]), - - # ncclResult_t ncclSend( - # const void* sendbuff, size_t count, ncclDataType_t datatype, - # int dest, ncclComm_t comm, cudaStream_t stream); - Function("mcclSend", ncclResult_t, [ - buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, - ncclComm_t, cudaStream_t - ]), - - # ncclResult_t ncclRecv( - # void* recvbuff, size_t count, ncclDataType_t datatype, - # int src, ncclComm_t comm, cudaStream_t stream); - Function("mcclRecv", ncclResult_t, [ - buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, - ncclComm_t, cudaStream_t - ]), - - # ncclResult_t ncclBroadcast( - # const void* sendbuff, void* recvbuff, size_t count, - # ncclDataType_t datatype, int root, ncclComm_t comm, - # cudaStream_t stream); - Function("mcclBroadcast", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ctypes.c_int, ncclComm_t, cudaStream_t - ]), - - # be cautious! this is a collective call, it will block until all - # processes in the communicator have called this function. - # because Python object destruction can happen in random order, - # it is better not to call it at all. - # ncclResult_t ncclCommDestroy(ncclComm_t comm); - Function("mcclCommDestroy", ncclResult_t, [ncclComm_t]), - ] - - # class attribute to store the mapping from the path to the library - # to avoid loading the same library multiple times - path_to_library_cache: Dict[str, Any] = {} - - # class attribute to store the mapping from library path - # to the corresponding dictionary - path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} - - def __init__(self, so_file: Optional[str] = None): - - so_file = so_file or find_nccl_library() - - try: - if so_file not in NCCLLibrary.path_to_dict_mapping: - lib = ctypes.CDLL(so_file) - NCCLLibrary.path_to_library_cache[so_file] = lib - self.lib = NCCLLibrary.path_to_library_cache[so_file] - except Exception as e: - logger.error( - "Failed to load NCCL library from %s. " - "It is expected if you are not running on NVIDIA/AMD GPUs." - "Otherwise, the nccl library might not exist, be corrupted " - "or it does not support the current platform %s. 
" - "If you already have the library, please set the " - "environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path.", so_file, - platform.platform()) - raise e - - if so_file not in NCCLLibrary.path_to_dict_mapping: - _funcs: Dict[str, Any] = {} - for func in NCCLLibrary.exported_functions: - f = getattr(self.lib, func.name) - f.restype = func.restype - f.argtypes = func.argtypes - _funcs[func.name] = f - NCCLLibrary.path_to_dict_mapping[so_file] = _funcs - self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] - - def ncclGetErrorString(self, result: ncclResult_t) -> str: - return self._funcs["mcclGetErrorString"](result).decode("utf-8") - - def NCCL_CHECK(self, result: ncclResult_t) -> None: - if result != 0: - error_str = self.ncclGetErrorString(result) - raise RuntimeError(f"NCCL error: {error_str}") - - def ncclGetVersion(self) -> str: - version = ctypes.c_int() - self.NCCL_CHECK(self._funcs["mcclGetVersion"](ctypes.byref(version))) - version_str = str(version.value) - # something like 21903 --> "2.19.3" - major = version_str[0].lstrip("0") - minor = version_str[1:3].lstrip("0") - patch = version_str[3:].lstrip("0") - return f"{major}.{minor}.{patch}" - - def ncclGetUniqueId(self) -> ncclUniqueId: - unique_id = ncclUniqueId() - self.NCCL_CHECK(self._funcs["mcclGetUniqueId"]( - ctypes.byref(unique_id))) - return unique_id - - def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, - rank: int) -> ncclComm_t: - comm = ncclComm_t() - self.NCCL_CHECK(self._funcs["mcclCommInitRank"](ctypes.byref(comm), - world_size, unique_id, - rank)) - return comm - - def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: - # `datatype` actually should be `ncclDataType_t` - # and `op` should be `ncclRedOp_t` - # both are aliases of `ctypes.c_int` - # when we pass int to a function, it will be converted to `ctypes.c_int` - # by ctypes automatically - self.NCCL_CHECK(self._funcs["mcclAllReduce"](sendbuff, recvbuff, count, - datatype, op, comm, - stream)) - - def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: - # `datatype` actually should be `ncclDataType_t` - # and `op` should be `ncclRedOp_t` - # both are aliases of `ctypes.c_int` - # when we pass int to a function, it will be converted to `ctypes.c_int` - # by ctypes automatically - self.NCCL_CHECK(self._funcs["mcclReduceScatter"](sendbuff, recvbuff, - count, datatype, op, - comm, stream)) - - def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: - # `datatype` actually should be `ncclDataType_t` - # which is an aliases of `ctypes.c_int` - # when we pass int to a function, it will be converted to `ctypes.c_int` - # by ctypes automatically - self.NCCL_CHECK(self._funcs["mcclAllGather"](sendbuff, recvbuff, count, - datatype, comm, stream)) - - def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, - dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["mcclSend"](sendbuff, count, datatype, - dest, comm, stream)) - - def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, - src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["mcclRecv"](recvbuff, count, datatype, src, - comm, stream)) - - def 
ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, root: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["mcclBroadcast"](sendbuff, recvbuff, count, - datatype, root, comm, - stream)) - - def ncclCommDestroy(self, comm: ncclComm_t) -> None: - self.NCCL_CHECK(self._funcs["mcclCommDestroy"](comm)) - - -__all__ = [ - "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", - "ncclComm_t", "cudaStream_t", "buffer_type" -] diff --git a/vllm_metax/distributed/parallel_state.py b/vllm_metax/distributed/parallel_state.py deleted file mode 100644 index bfceb5709..000000000 --- a/vllm_metax/distributed/parallel_state.py +++ /dev/null @@ -1,1210 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Copyright 2023 The vLLM team. -# Adapted from -# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -"""vLLM distributed state. -It takes over the control of the distributed environment from PyTorch. -The typical workflow is: - -- call `init_distributed_environment` to initialize the distributed environment. -- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to - initialize the model parallel groups. - -- any code dealing with the distributed stuff - -- call `destroy_model_parallel` to destroy the model parallel groups. -- call `destroy_distributed_environment` to destroy the distributed environment. - -If you only need to use the distributed environment without model/pipeline - parallelism, you can skip the model parallel initialization and destruction - steps. -""" -import contextlib -import gc -import pickle -import weakref -from collections import namedtuple -from contextlib import contextmanager, nullcontext -from dataclasses import dataclass -from multiprocessing import shared_memory -from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from unittest.mock import patch - -import torch -import torch.distributed -import vllm.envs as envs -from torch.distributed import Backend, ProcessGroup -from vllm.distributed.device_communicators.base_device_communicator import ( - DeviceCommunicatorBase) -from vllm.distributed.utils import StatelessProcessGroup -from vllm.logger import init_logger -from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, - supports_custom_op) - - -@dataclass -class GraphCaptureContext: - stream: torch.cuda.Stream - - -TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) - - -def _split_tensor_dict( - tensor_dict: Dict[str, Union[torch.Tensor, Any]] -) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: - """Split the tensor dictionary into two parts: - 1. A list of (key, value) pairs. If the value is a tensor, it is replaced - by its metadata. - 2. A list of tensors. - """ - metadata_list: List[Tuple[str, Any]] = [] - tensor_list: List[torch.Tensor] = [] - for key, value in tensor_dict.items(): - if isinstance(value, torch.Tensor): - # Note: we cannot use `value.device` here, - # because it contains not only the device type but also the device - # index (e.g. "cuda:0"). We only need the device type. - # receiving side will set the device index. 
- device = value.device.type - metadata_list.append( - (key, TensorMetadata(device, value.dtype, value.size()))) - tensor_list.append(value) - else: - metadata_list.append((key, value)) - return metadata_list, tensor_list - - -_group_name_counter: Dict[str, int] = {} - - -def _get_unique_name(name: str) -> str: - """Get a unique name for the group. - Example: - _get_unique_name("tp") -> "tp:0" - _get_unique_name("tp") -> "tp:1" - """ - if name not in _group_name_counter: - _group_name_counter[name] = 0 - newname = f"{name}:{_group_name_counter[name]}" - _group_name_counter[name] += 1 - return newname - - -_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} - - -def _register_group(group: "GroupCoordinator") -> None: - _groups[group.unique_name] = weakref.ref(group) - - -def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - return group._all_reduce_out_place(tensor) - - -def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: - return torch.empty_like(tensor) - - -def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, - group_name: str) -> torch.Tensor: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - return group.reduce_scatter(tensor, dim) - - -def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int, - group_name: str) -> torch.Tensor: - new_shape = list(tensor.shape) - new_shape[dim] = tensor.shape[dim] // world_size - return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) - - -def all_gather(tensor: torch.Tensor, dim: int, world_size: int, - group_name: str) -> torch.Tensor: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - return group.all_gather(tensor, dim) - - -def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, - group_name: str) -> torch.Tensor: - new_shape = list(tensor.shape) - new_shape[dim] = tensor.shape[dim] * world_size - return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) - - -if supports_custom_op(): - from vllm.platforms import current_platform - direct_register_custom_op( - op_name="mx_all_reduce", - op_func=all_reduce, - mutates_args=[], - fake_impl=all_reduce_fake, - dispatch_key=current_platform.dispatch_key, - ) - - direct_register_custom_op( - op_name="mx_reduce_scatter", - op_func=reduce_scatter, - mutates_args=[], - fake_impl=reduce_scatter_fake, - ) - - direct_register_custom_op( - op_name="mx_all_gather", - op_func=all_gather, - mutates_args=[], - fake_impl=all_gather_fake, - ) - - -class GroupCoordinator: - """ - PyTorch ProcessGroup wrapper for a group of processes. - PyTorch ProcessGroup is bound to one specific communication backend, - e.g. NCCL, Gloo, MPI, etc. - GroupCoordinator takes charge of all the communication operations among - the processes in the group. It manages both CPU and device - communication. 
- """ - - # available attributes: - rank: int # global rank - ranks: List[int] # global ranks in the group - world_size: int # size of the group - # difference between `local_rank` and `rank_in_group`: - # if we have a group of size 4 across two nodes: - # Process | Node | Rank | Local Rank | Rank in Group - # 0 | 0 | 0 | 0 | 0 - # 1 | 0 | 1 | 1 | 1 - # 2 | 1 | 2 | 0 | 2 - # 3 | 1 | 3 | 1 | 3 - local_rank: int # local rank used to assign devices - rank_in_group: int # rank inside the group - cpu_group: ProcessGroup # group for CPU communication - device_group: ProcessGroup # group for device communication - use_device_communicator: bool # whether to use device communicator - device_communicator: DeviceCommunicatorBase # device communicator - mq_broadcaster: Optional[Any] # shared memory broadcaster - - def __init__( - self, - group_ranks: List[List[int]], - local_rank: int, - torch_distributed_backend: Union[str, Backend], - use_device_communicator: bool, - use_message_queue_broadcaster: bool = False, - group_name: Optional[str] = None, - ): - group_name = group_name or "anonymous" - self.unique_name = _get_unique_name(group_name) - _register_group(self) - - self.rank = torch.distributed.get_rank() - self.local_rank = local_rank - self.device_group = None - self.cpu_group = None - - for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend) - # a group with `gloo` backend, to allow direct coordination between - # processes through the CPU. - cpu_group = torch.distributed.new_group(ranks, backend="gloo") - if self.rank in ranks: - self.ranks = ranks - self.world_size = len(ranks) - self.rank_in_group = ranks.index(self.rank) - self.device_group = device_group - self.cpu_group = cpu_group - - assert self.cpu_group is not None - assert self.device_group is not None - - from vllm.platforms import current_platform - - if current_platform.is_cuda_alike(): - self.device = torch.device(f"cuda:{local_rank}") - elif current_platform.is_out_of_tree(): - self.device = torch.device( - f"{current_platform.device_name}:{local_rank}") - else: - self.device = torch.device("cpu") - - self.use_device_communicator = use_device_communicator - - self.device_communicator: DeviceCommunicatorBase = None # type: ignore - if use_device_communicator and self.world_size > 1: - device_comm_cls = resolve_obj_by_qualname( - current_platform.get_device_communicator_cls()) - self.device_communicator = device_comm_cls( - cpu_group=self.cpu_group, - device=self.device, - device_group=self.device_group, - unique_name=self.unique_name, - ) - - from vllm.distributed.device_communicators.shm_broadcast import ( - MessageQueue) - self.mq_broadcaster: Optional[MessageQueue] = None - if use_message_queue_broadcaster and self.world_size > 1: - self.mq_broadcaster = MessageQueue.create_from_process_group( - self.cpu_group, 1 << 22, 6) - - from vllm.platforms import current_platform - self.use_custom_op_call = (current_platform.is_cuda_alike() - or current_platform.is_tpu()) - - @property - def first_rank(self): - """Return the global rank of the first process in the group""" - return self.ranks[0] - - @property - def last_rank(self): - """Return the global rank of the last process in the group""" - return self.ranks[-1] - - @property - def is_first_rank(self): - """Return whether the caller is the first process in the group""" - return self.rank == self.first_rank - - @property - def is_last_rank(self): - """Return whether the caller is the last process in the group""" - return 
self.rank == self.last_rank - - @property - def next_rank(self): - """Return the global rank of the process that follows the caller""" - rank_in_group = self.rank_in_group - world_size = self.world_size - return self.ranks[(rank_in_group + 1) % world_size] - - @property - def prev_rank(self): - """Return the global rank of the process that precedes the caller""" - rank_in_group = self.rank_in_group - world_size = self.world_size - return self.ranks[(rank_in_group - 1) % world_size] - - @contextmanager - def graph_capture( - self, graph_capture_context: Optional[GraphCaptureContext] = None): - if graph_capture_context is None: - stream = torch.cuda.Stream() - graph_capture_context = GraphCaptureContext(stream) - else: - stream = graph_capture_context.stream - - # only cuda uses this function, - # so we don't abstract it into the base class - maybe_ca_context = nullcontext() - from vllm_metax.distributed.device_communicators.cuda_communicator import ( - CudaCommunicator) - if self.device_communicator is not None: - assert isinstance(self.device_communicator, CudaCommunicator) - ca_comm = self.device_communicator.ca_comm - if ca_comm is not None: - maybe_ca_context = ca_comm.capture() # type: ignore - - # ensure all initialization operations complete before attempting to - # capture the graph on another stream - curr_stream = torch.cuda.current_stream() - if curr_stream != stream: - stream.wait_stream(curr_stream) - - with torch.cuda.stream(stream), maybe_ca_context: - yield graph_capture_context - - def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: - """ - User-facing all-reduce function before we actually call the - all-reduce operation. - - We need this because Dynamo does not support passing an arbitrary - object (`self` in this case) to a custom op. We need to pass the - group name as a string, and then look up the group coordinator from - the group name, dispatch the all-reduce operation to the group - coordinator. - - In addition, PyTorch custom ops do not support mutation or returning - a new tensor in the same op. So we always make the all-reduce operation - out-of-place. - """ - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return input_ - - if self.use_custom_op_call: - return torch.ops.vllm.all_reduce(input_, - group_name=self.unique_name) - else: - return self._all_reduce_out_place(input_) - - def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: - return self.device_communicator.all_reduce(input_) - - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - - return self.device_communicator.all_gather(input_, dim) - - def reduce_scatter(self, - input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - - return self.device_communicator.reduce_scatter(input_, dim) - - def gather(self, - input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> Optional[torch.Tensor]: - """ - NOTE: We assume that the input tensor is on the same device across - all the ranks. - NOTE: `dst` is the local rank of the destination rank. 
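The `all_reduce` docstring above explains why dispatch goes through the group *name*: Dynamo-compiled custom ops cannot carry `self`, so only a string crosses the op boundary and the coordinator is looked up from a registry. A stripped-down sketch of that pattern (hypothetical names, plain functions instead of `torch.ops`):

# Weak-reference registry keyed by group name, mirroring the dispatch described above.
import weakref

import torch

_registry: dict[str, weakref.ReferenceType] = {}

class FakeGroup:
    def __init__(self, name: str) -> None:
        self.unique_name = name
        _registry[name] = weakref.ref(self)

    def _all_reduce_out_place(self, x: torch.Tensor) -> torch.Tensor:
        return x.clone()  # stands in for the device-communicator call

def all_reduce_by_name(x: torch.Tensor, group_name: str) -> torch.Tensor:
    group = _registry[group_name]()  # only the string crossed the "op" boundary
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._all_reduce_out_place(x)

g = FakeGroup("tp:0")
out = all_reduce_by_name(torch.ones(2), "tp:0")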
- """ - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - return self.device_communicator.gather(input_, dst, dim) - - def broadcast(self, input_: torch.Tensor, src: int = 0): - """Broadcast the input tensor. - NOTE: `src` is the local rank of the source rank. - """ - assert src < self.world_size, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return input_ - # Broadcast. - torch.distributed.broadcast(input_, - src=self.ranks[src], - group=self.device_group) - return input_ - - def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): - """Broadcast the input object. - NOTE: `src` is the local rank of the source rank. - """ - assert src < self.world_size, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return obj - if self.mq_broadcaster is not None: - assert src == 0, "Message queue broadcaster only supports src=0" - return self.mq_broadcaster.broadcast_object(obj) - if self.rank_in_group == src: - torch.distributed.broadcast_object_list([obj], - src=self.ranks[src], - group=self.cpu_group) - return obj - else: - recv = [None] - torch.distributed.broadcast_object_list(recv, - src=self.ranks[src], - group=self.cpu_group) - return recv[0] - - def broadcast_object_list(self, - obj_list: List[Any], - src: int = 0, - group: Optional[ProcessGroup] = None): - """Broadcast the input object list. - NOTE: `src` is the local rank of the source rank. - """ - assert src < self.world_size, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return obj_list - # Broadcast. - torch.distributed.broadcast_object_list(obj_list, - src=self.ranks[src], - group=self.device_group) - return obj_list - - def send_object(self, obj: Any, dst: int) -> None: - """Send the input object list to the destination rank.""" - """NOTE: `dst` is the local rank of the destination rank.""" - - assert dst < self.world_size, f"Invalid dst rank ({dst})" - - assert dst != self.rank_in_group, ( - "Invalid destination rank. Destination rank is the same " - "as the current rank.") - - # Serialize object to tensor and get the size as well - object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) - - size_tensor = torch.tensor([object_tensor.numel()], - dtype=torch.long, - device="cpu") - - # Send object size - - torch.distributed.send(size_tensor, - dst=self.ranks[dst], - group=self.cpu_group) - - # Send object - torch.distributed.send(object_tensor, - dst=self.ranks[dst], - group=self.cpu_group) - - return None - - def recv_object(self, src: int) -> Any: - """Receive the input object list from the source rank.""" - """NOTE: `src` is the local rank of the source rank.""" - - assert src < self.world_size, f"Invalid src rank ({src})" - - assert src != self.rank_in_group, ( - "Invalid source rank. Source rank is the same as the current rank." - ) - - size_tensor = torch.empty(1, dtype=torch.long, device="cpu") - - # Receive object size - rank_size = torch.distributed.recv(size_tensor, - src=self.ranks[src], - group=self.cpu_group) - - # Tensor to receive serialized objects into. 
- object_tensor = torch.empty( # type: ignore[call-overload] - size_tensor.item(), # type: ignore[arg-type] - dtype=torch.uint8, - device="cpu") - - rank_object = torch.distributed.recv(object_tensor, - src=self.ranks[src], - group=self.cpu_group) - - assert rank_object == rank_size, ( - "Received object sender rank does not match the size sender rank.") - - obj = pickle.loads(object_tensor.numpy().tobytes()) - - return obj - - def broadcast_tensor_dict( - self, - tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, - src: int = 0, - group: Optional[ProcessGroup] = None, - metadata_group: Optional[ProcessGroup] = None - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: - """Broadcast the input tensor dictionary. - NOTE: `src` is the local rank of the source rank. - """ - # Bypass the function if we are using only 1 GPU. - if (not torch.distributed.is_initialized() or self.world_size == 1): - return tensor_dict - - group = self.device_group - metadata_group = self.cpu_group - assert src < self.world_size, f"Invalid src rank ({src})" - - rank_in_group = self.rank_in_group - if rank_in_group == src: - metadata_list: List[Tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, - dict), (f"Expecting a dictionary, got {type(tensor_dict)}") - metadata_list, tensor_list = _split_tensor_dict(tensor_dict) - # `metadata_list` lives in CPU memory. - # `broadcast_object_list` has serialization & deserialization, - # all happening on CPU. Therefore, we can use the CPU group. - self.broadcast_object(metadata_list, src=src) - async_handles = [] - for tensor in tensor_list: - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - handle = torch.distributed.broadcast(tensor, - src=self.ranks[src], - group=metadata_group, - async_op=True) - else: - # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, - src=self.ranks[src], - group=group, - async_op=True) - async_handles.append(handle) - for async_handle in async_handles: - async_handle.wait() - - else: - metadata_list = self.broadcast_object(None, src=src) - tensor_dict = {} - async_handles = [] - for key, value in metadata_list: - if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, - dtype=value.dtype, - device=value.device) - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - tensor_dict[key] = tensor - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - handle = torch.distributed.broadcast( - tensor, - src=self.ranks[src], - group=metadata_group, - async_op=True) - else: - # use group for GPU tensors - handle = torch.distributed.broadcast( - tensor, - src=self.ranks[src], - group=group, - async_op=True) - async_handles.append(handle) - tensor_dict[key] = tensor - else: - tensor_dict[key] = value - for async_handle in async_handles: - async_handle.wait() - return tensor_dict - - def send_tensor_dict( - self, - tensor_dict: Dict[str, Union[torch.Tensor, Any]], - dst: Optional[int] = None, - all_gather_group: Optional["GroupCoordinator"] = None, - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: - """Send the input tensor dictionary. - NOTE: `dst` is the local rank of the source rank. - """ - # Bypass the function if we are using only 1 GPU. 
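`broadcast_tensor_dict` above leans on the metadata/tensor split: keys plus tensor metadata travel over the CPU group, while tensor payloads are broadcast asynchronously over the device group. What that split produces for a tiny dict (standalone re-implementation for illustration):

# Rebuilds the metadata/tensor split for a small dict; only the device *type*
# is recorded, since the receiving side fills in its own device index.
import torch

tensor_dict = {"hidden": torch.zeros(2, 3), "step": 7}
metadata_list, tensor_list = [], []
for key, value in tensor_dict.items():
    if isinstance(value, torch.Tensor):
        metadata_list.append((key, (value.device.type, value.dtype, value.size())))
        tensor_list.append(value)
    else:
        metadata_list.append((key, value))

# metadata_list -> [('hidden', ('cpu', torch.float32, torch.Size([2, 3]))), ('step', 7)]
# tensor_list   -> [the (2, 3) tensor]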
- if not torch.distributed.is_initialized() or self.world_size == 1: - return tensor_dict - - all_gather_size = (1 if all_gather_group is None else - all_gather_group.world_size) - all_gather_rank = (0 if all_gather_group is None else - all_gather_group.rank_in_group) - - group = self.device_group - metadata_group = self.cpu_group - - if dst is None: - dst = (self.rank_in_group + 1) % self.world_size - assert dst < self.world_size, f"Invalid dst rank ({dst})" - - metadata_list: List[Tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, - dict), f"Expecting a dictionary, got {type(tensor_dict)}" - metadata_list, tensor_list = _split_tensor_dict(tensor_dict) - # `metadata_list` lives in CPU memory. - # `send_object_list` has serialization & deserialization, - # all happening on CPU. Therefore, we can use the CPU group. - self.send_object(metadata_list, dst=dst) - for tensor in tensor_list: - if tensor.numel() == 0: - # Skip sending empty tensors. - continue - - # send-allgather: send only a slice, then do allgather. - if (all_gather_group is not None - and tensor.numel() % all_gather_size == 0): - tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] - - if tensor.is_cpu: - # use metadata_group for CPU tensors - torch.distributed.send(tensor, - dst=self.ranks[dst], - group=metadata_group) - else: - # use group for GPU tensors - torch.distributed.send(tensor, - dst=self.ranks[dst], - group=group) - return None - - def recv_tensor_dict( - self, - src: Optional[int] = None, - all_gather_group: Optional["GroupCoordinator"] = None, - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: - """Recv the input tensor dictionary. - NOTE: `src` is the local rank of the source rank. - """ - # Bypass the function if we are using only 1 GPU. - if not torch.distributed.is_initialized() or self.world_size == 1: - return None - - all_gather_size = (1 if all_gather_group is None else - all_gather_group.world_size) - all_gather_rank = (0 if all_gather_group is None else - all_gather_group.rank_in_group) - - group = self.device_group - metadata_group = self.cpu_group - - if src is None: - src = (self.rank_in_group - 1) % self.world_size - assert src < self.world_size, f"Invalid src rank ({src})" - - recv_metadata_list = self.recv_object(src=src) - tensor_dict: Dict[str, Any] = {} - for key, value in recv_metadata_list: - if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, - dtype=value.dtype, - device=value.device) - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - tensor_dict[key] = tensor - continue - - # send-allgather: send only a slice, then do allgather. - use_all_gather = (all_gather_group is not None - and tensor.numel() % all_gather_size == 0) - - if use_all_gather: - orig_shape = tensor.shape - tensor = tensor.reshape(all_gather_size, - -1)[all_gather_rank] - - if tensor.is_cpu: - # use metadata_group for CPU tensors - torch.distributed.recv(tensor, - src=self.ranks[src], - group=metadata_group) - else: - # use group for GPU tensors - torch.distributed.recv(tensor, - src=self.ranks[src], - group=group) - if use_all_gather: - # do the allgather - tensor = all_gather_group.all_gather( # type: ignore - tensor, dim=0) - tensor = tensor.reshape(orig_shape) - - tensor_dict[key] = tensor - else: - tensor_dict[key] = value - return tensor_dict - - def barrier(self): - """Barrier synchronization among the group. - NOTE: don't use `device_group` here! `barrier` in NCCL is - terrible because it is internally a broadcast operation with - secretly created GPU tensors. 
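The send-allgather path above transmits only a 1/world_size slice of each eligible tensor; `recv_tensor_dict` later reassembles it with an all-gather. The slicing arithmetic on its own, with assumed sizes:

# Each rank keeps row [rank] of the (world_size, -1) view, so only
# numel // world_size elements cross the wire.
import torch

world_size, rank = 4, 1
t = torch.arange(16.0)

if t.numel() % world_size == 0:                    # the eligibility check used above
    send_slice = t.reshape(world_size, -1)[rank]   # shape (4,): a quarter of the data
    assert send_slice.numel() == t.numel() // world_size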
It is easy to mess up the current - device. Use the CPU group instead. - """ - torch.distributed.barrier(group=self.cpu_group) - - def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: - """Sends a tensor to the destination rank in a non-blocking way""" - """NOTE: `dst` is the local rank of the destination rank.""" - self.device_communicator.send(tensor, dst) - - def recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: - """Receives a tensor from the source rank.""" - """NOTE: `src` is the local rank of the source rank.""" - return self.device_communicator.recv(size, dtype, src) - - def destroy(self): - if self.device_group is not None: - torch.distributed.destroy_process_group(self.device_group) - self.device_group = None - if self.cpu_group is not None: - torch.distributed.destroy_process_group(self.cpu_group) - self.cpu_group = None - if self.device_communicator is not None: - self.device_communicator.destroy() - if self.mq_broadcaster is not None: - self.mq_broadcaster = None - - -_WORLD: Optional[GroupCoordinator] = None - - -def get_world_group() -> GroupCoordinator: - assert _WORLD is not None, ("world group is not initialized") - return _WORLD - - -def init_world_group(ranks: List[int], local_rank: int, - backend: str) -> GroupCoordinator: - return GroupCoordinator( - group_ranks=[ranks], - local_rank=local_rank, - torch_distributed_backend=backend, - use_device_communicator=False, - group_name="world", - ) - - -def init_model_parallel_group( - group_ranks: List[List[int]], - local_rank: int, - backend: str, - use_message_queue_broadcaster: bool = False, - group_name: Optional[str] = None, -) -> GroupCoordinator: - - return GroupCoordinator( - group_ranks=group_ranks, - local_rank=local_rank, - torch_distributed_backend=backend, - use_device_communicator=True, - use_message_queue_broadcaster=use_message_queue_broadcaster, - group_name=group_name, - ) - - -_TP: Optional[GroupCoordinator] = None - - -def get_tp_group() -> GroupCoordinator: - assert _TP is not None, ("tensor model parallel group is not initialized") - return _TP - - -# kept for backward compatibility -get_tensor_model_parallel_group = get_tp_group - -_PP: Optional[GroupCoordinator] = None - -_DP: Optional[GroupCoordinator] = None - - -def get_dp_group() -> GroupCoordinator: - assert _DP is not None, ("data parallel group is not initialized") - return _DP - - -def get_pp_group() -> GroupCoordinator: - assert _PP is not None, ( - "pipeline model parallel group is not initialized") - return _PP - - -# kept for backward compatibility -get_pipeline_model_parallel_group = get_pp_group - - -@contextmanager -def graph_capture(device: torch.device): - """ - `graph_capture` is a context manager which should surround the code that - is capturing the CUDA graph. Its main purpose is to ensure that the - some operations will be run after the graph is captured, before the graph - is replayed. It returns a `GraphCaptureContext` object which contains the - necessary data for the graph capture. Currently, it only contains the - stream that the graph capture is running on. This stream is set to the - current CUDA stream when the context manager is entered and reset to the - default stream when the context manager is exited. This is to ensure that - the graph capture is running on a separate stream from the default stream, - in order to explicitly distinguish the kernels to capture - from other kernels possibly launched on background in the default stream. 
- """ - context = GraphCaptureContext(torch.cuda.Stream(device=device)) - with get_tp_group().graph_capture(context), get_pp_group().graph_capture( - context): - yield context - - -logger = init_logger(__name__) - -_ENABLE_CUSTOM_ALL_REDUCE = True - - -def set_custom_all_reduce(enable: bool): - global _ENABLE_CUSTOM_ALL_REDUCE - _ENABLE_CUSTOM_ALL_REDUCE = enable - - -def init_distributed_environment( - world_size: int = -1, - rank: int = -1, - distributed_init_method: str = "env://", - local_rank: int = -1, - backend: str = "nccl", -): - logger.debug( - "world_size=%d rank=%d local_rank=%d " - "distributed_init_method=%s backend=%s", world_size, rank, local_rank, - distributed_init_method, backend) - from vllm.config import get_current_vllm_config - config = get_current_vllm_config() - if config is not None and config.parallel_config.data_parallel_size > 1: - parallel_config = config.parallel_config - # adjust to take into account data parallelism - # offset the rank by the data parallel rank - rank = parallel_config.data_parallel_rank * world_size + rank - # adjust the world size to take into account data parallelism - world_size = parallel_config.world_size_across_dp - ip = parallel_config.data_parallel_master_ip - port = parallel_config.get_next_dp_init_port() - distributed_init_method = f"tcp://{ip}:{port}" # noqa - logger.info( - "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", - world_size, rank, distributed_init_method) - if not torch.distributed.is_initialized(): - assert distributed_init_method is not None, ( - "distributed_init_method must be provided when initializing " - "distributed environment") - # this backend is used for WORLD - torch.distributed.init_process_group( - backend=backend, - init_method=distributed_init_method, - world_size=world_size, - rank=rank) - # set the local rank - # local_rank is not available in torch ProcessGroup, - # see https://github.com/pytorch/pytorch/issues/122816 - if local_rank == -1: - # local rank not set, this usually happens in single-node - # setting, where we can use rank as local rank - if distributed_init_method == "env://": - local_rank = envs.LOCAL_RANK - else: - local_rank = rank - global _WORLD - if _WORLD is None: - ranks = list(range(torch.distributed.get_world_size())) - _WORLD = init_world_group(ranks, local_rank, backend) - else: - assert _WORLD.world_size == torch.distributed.get_world_size(), ( - "world group already initialized with a different world size") - - -def initialize_model_parallel( - tensor_model_parallel_size: int = 1, - pipeline_model_parallel_size: int = 1, - backend: Optional[str] = None, -) -> None: - """ - Initialize model parallel groups. - - Arguments: - tensor_model_parallel_size: number of GPUs used for tensor model - parallelism. - pipeline_model_parallel_size: number of GPUs used for pipeline model - parallelism. - - Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we - use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize - the model pipeline. The present function will - create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: - 4 tensor model-parallel groups: - [g0, g1], [g2, g3], [g4, g5], [g6, g7] - 2 pipeline model-parallel groups: - [g0, g2, g4, g6], [g1, g3, g5, g7] - Note that for efficiency, the caller should make sure adjacent ranks - are on the same DGX box. For example if we are using 2 DGX-1 boxes - with a total of 16 GPUs, rank 0 to 7 belong to the first box and - ranks 8 to 15 belong to the second box. 
- """ - # Get world size and rank. Ensure some consistencies. - assert torch.distributed.is_initialized() - world_size: int = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - backend = backend or torch.distributed.get_backend( - get_world_group().device_group) - - data_parallel_size = 1 - from vllm.config import get_current_vllm_config - config = get_current_vllm_config() - if config is not None: - data_parallel_size = config.parallel_config.data_parallel_size - - # the layout order is: ExternalDP x DP x PP x TP - # ExternalDP is the data parallel group that is not part of the model, - # every dp rank can generate independently (in verl integration). - # DP is the data parallel group that is part of the model, - # all the ranks in the same DP group should generate simultaneously, - # i.e. the `generate` call in the same DP group should be called together, - # otherwise it will cause deadlock. - # to get group_ranks for each dimension, transpose that dimension to the - # last dimension, then reshape to 2D, then unbind the last dimension - all_ranks = torch.arange(world_size).reshape( - -1, data_parallel_size, pipeline_model_parallel_size, - tensor_model_parallel_size) # noqa - - # Build the tensor model-parallel groups. - global _TP - assert _TP is None, ("tensor model parallel group is already initialized") - group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) - group_ranks = [x.tolist() for x in group_ranks] - - # message queue broadcaster is only used in tensor model parallel group - _TP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - use_message_queue_broadcaster=True, - group_name="tp") - - # Build the pipeline model-parallel groups. - global _PP - assert _PP is None, ( - "pipeline model parallel group is already initialized") - group_ranks = all_ranks.transpose(2, 3).reshape( - -1, pipeline_model_parallel_size).unbind(0) - group_ranks = [x.tolist() for x in group_ranks] - _PP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="pp") - - global _DP - assert _DP is None, ("data parallel group is already initialized") - group_ranks = all_ranks.transpose(1, - 3).reshape(-1, - data_parallel_size).unbind(0) - group_ranks = [x.tolist() for x in group_ranks] - _DP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="dp") - - logger.info( - "rank %s in world size %s is assigned as " - "DP rank %s, PP rank %s, TP rank %s", rank, world_size, - _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) - - -def ensure_model_parallel_initialized( - tensor_model_parallel_size: int, - pipeline_model_parallel_size: int, - backend: Optional[str] = None, -) -> None: - """Helper to initialize model parallel groups if they are not initialized, - or ensure tensor-parallel and pipeline-parallel sizes are equal to expected - values if the model parallel groups are initialized. - """ - backend = backend or torch.distributed.get_backend( - get_world_group().device_group) - if not model_parallel_is_initialized(): - initialize_model_parallel(tensor_model_parallel_size, - pipeline_model_parallel_size, backend) - return - - assert ( - get_tensor_model_parallel_world_size() == tensor_model_parallel_size - ), ("tensor parallel group already initialized, but of unexpected size: " - f"{get_tensor_model_parallel_world_size()=} vs. 
" - f"{tensor_model_parallel_size=}") - pp_world_size = get_pp_group().world_size - assert (pp_world_size == pipeline_model_parallel_size), ( - "pipeline parallel group already initialized, but of unexpected size: " - f"{pp_world_size=} vs. " - f"{pipeline_model_parallel_size=}") - - -def model_parallel_is_initialized(): - """Check if tensor and pipeline parallel groups are initialized.""" - return (_TP is not None and _PP is not None) - - -_TP_STATE_PATCHED = False - - -@contextmanager -def patch_tensor_parallel_group(tp_group: GroupCoordinator): - """Patch the tp group temporarily until this function ends. - - This method is for draft workers of speculative decoding to run draft model - with different tp degree from that of target model workers. - - Args: - tp_group (GroupCoordinator): the tp group coordinator - """ - global _TP_STATE_PATCHED - assert not _TP_STATE_PATCHED, "Should not call when it's already patched" - - _TP_STATE_PATCHED = True - old_tp_group = get_tp_group() - global _TP - _TP = tp_group - try: - yield - finally: - # restore the original state - _TP_STATE_PATCHED = False - _TP = old_tp_group - - -def get_tensor_model_parallel_world_size(): - """Return world size for the tensor model parallel group.""" - return get_tp_group().world_size - - -def get_tensor_model_parallel_rank(): - """Return my rank for the tensor model parallel group.""" - return get_tp_group().rank_in_group - - -def destroy_model_parallel(): - """Set the groups to none and destroy them.""" - global _TP - if _TP: - _TP.destroy() - _TP = None - - global _PP - if _PP: - _PP.destroy() - _PP = None - - global _DP - if _DP: - _DP.destroy() - _DP = None - - -def destroy_distributed_environment(): - global _WORLD - if _WORLD: - _WORLD.destroy() - _WORLD = None - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - - -def cleanup_dist_env_and_memory(shutdown_ray: bool = False): - destroy_model_parallel() - destroy_distributed_environment() - with contextlib.suppress(AssertionError): - torch.distributed.destroy_process_group() - if shutdown_ray: - import ray # Lazy import Ray - ray.shutdown() - gc.collect() - from vllm.platforms import current_platform - if not current_platform.is_cpu(): - torch.cuda.empty_cache() - """ - try: - torch._C._host_emptyCache() - except AttributeError: - logger.warning( - "torch._C._host_emptyCache() only available in Pytorch >=2.5") - """ - - -def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], - source_rank: int = 0) -> List[bool]: - """ - This is a collective operation that returns if each rank is in the same node - as the source rank. It tests if processes are attached to the same - memory system (shared access to shared memory). 
- """ - if isinstance(pg, ProcessGroup): - assert torch.distributed.get_backend( - pg) != torch.distributed.Backend.NCCL, ( - "in_the_same_node_as should be tested with a non-NCCL group.") - # local rank inside the group - rank = torch.distributed.get_rank(group=pg) - world_size = torch.distributed.get_world_size(group=pg) - - # global ranks of the processes in the group - ranks = torch.distributed.get_process_group_ranks(pg) - else: - rank = pg.rank - world_size = pg.world_size - ranks = list(range(world_size)) - - # local tensor in each process to store the result - is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) - - magic_message = b"magic_message" - shm = None - - try: - with contextlib.suppress(OSError): - if rank == source_rank: - # create a shared memory segment - shm = shared_memory.SharedMemory(create=True, size=128) - shm.buf[:len(magic_message)] = magic_message - if isinstance(pg, ProcessGroup): - torch.distributed.broadcast_object_list( - [shm.name], src=ranks[source_rank], group=pg) - else: - pg.broadcast_obj(shm.name, src=source_rank) - is_in_the_same_node[rank] = 1 - else: - # try to open the shared memory segment - if isinstance(pg, ProcessGroup): - recv = [None] - torch.distributed.broadcast_object_list( - recv, src=ranks[source_rank], group=pg) - name = recv[0] - else: - name = pg.broadcast_obj(None, src=source_rank) - # fix to https://stackoverflow.com/q/62748654/9191338 - # Python incorrectly tracks shared memory even if it is not - # created by the process. The following patch is a workaround. - with patch("multiprocessing.resource_tracker.register", - lambda *args, **kwargs: None): - shm = shared_memory.SharedMemory(name=name) - if shm.buf[:len(magic_message)] == magic_message: - is_in_the_same_node[rank] = 1 - except Exception as e: - logger.error("Error ignored in is_in_the_same_node: %s", e) - finally: - if shm: - shm.close() - - if isinstance(pg, ProcessGroup): - torch.distributed.barrier(group=pg) - else: - pg.barrier() - - # clean up the shared memory segment - with contextlib.suppress(OSError): - if rank == source_rank and shm: - shm.unlink() - - if isinstance(pg, ProcessGroup): - torch.distributed.all_reduce(is_in_the_same_node, group=pg) - aggregated_data = is_in_the_same_node - else: - aggregated_data = torch.zeros_like(is_in_the_same_node) - for i in range(world_size): - rank_data = pg.broadcast_obj(is_in_the_same_node, src=i) - aggregated_data += rank_data - - return [x == 1 for x in aggregated_data.tolist()] diff --git a/vllm_metax/envs.py b/vllm_metax/envs.py index 08b54a833..98a697bb0 100644 --- a/vllm_metax/envs.py +++ b/vllm_metax/envs.py @@ -14,67 +14,44 @@ MACA_VLLM_USE_TN_2_NN: bool = True environment_variables: dict[str, Callable[[], Any]] = { - # ================== Installation Time Env Vars ================== - # Target device of vLLM, supporting [cuda (by default), # rocm, neuron, cpu] - "VLLM_TARGET_DEVICE": - lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"), - + "VLLM_TARGET_DEVICE": lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"), # Maximum number of compilation jobs to run in parallel. # By default this is the number of CPUs - "MAX_JOBS": - lambda: os.getenv("MAX_JOBS", None), - + "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None), # Number of threads to use for nvcc # By default this is 1. # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU. 
- "NVCC_THREADS": - lambda: os.getenv("NVCC_THREADS", None), - + "NVCC_THREADS": lambda: os.getenv("NVCC_THREADS", None), # If set, vllm will use precompiled binaries (*.so) - "VLLM_USE_PRECOMPILED": - lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")) or bool( - os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), - + "VLLM_USE_PRECOMPILED": lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")) + or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), # CMake build type # If not set, defaults to "Debug" or "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo" - "CMAKE_BUILD_TYPE": - lambda: os.getenv("CMAKE_BUILD_TYPE"), - + "CMAKE_BUILD_TYPE": lambda: os.getenv("CMAKE_BUILD_TYPE"), # If set, vllm will print verbose logs during installation - "VERBOSE": - lambda: bool(int(os.getenv('VERBOSE', '0'))), - + "VERBOSE": lambda: bool(int(os.getenv("VERBOSE", "0"))), # path to cudatoolkit home directory, under which should be bin, include, # and lib directories. - "CUDA_HOME": - lambda: os.environ.get("CUDA_HOME", None), - + "CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None), # Path to the NCCL library file. It is needed because nccl>=2.19 brought # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234 - "VLLM_NCCL_SO_PATH": - lambda: os.environ.get("VLLM_NCCL_SO_PATH", None), - + "VLLM_NCCL_SO_PATH": lambda: os.environ.get("VLLM_NCCL_SO_PATH", None), # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl # library file in the locations specified by `LD_LIBRARY_PATH` - "LD_LIBRARY_PATH": - lambda: os.environ.get("LD_LIBRARY_PATH", None), - + "LD_LIBRARY_PATH": lambda: os.environ.get("LD_LIBRARY_PATH", None), # ================== Runtime Env Vars ================== - # When installing vllm from source, the version of vllm set by setuptool_scm # will be different from the version of vllm installed by pip. # (e.g. 
install vllm from source with tag v0.9.1 will cause the version set # as 0.9.2) - "VLLM_OFFICIAL_VERSION": - lambda: os.getenv("VLLM_OFFICIAL_VERSION", None), - + "VLLM_OFFICIAL_VERSION": lambda: os.getenv("VLLM_OFFICIAL_VERSION", None), # if set, enable loading weight by transpose - "MACA_VLLM_USE_TN_2_NN": - lambda: os.environ.get("MACA_VLLM_USE_TN_2_NN", "0") == "1", + "MACA_VLLM_USE_TN_2_NN": lambda: os.environ.get("MACA_VLLM_USE_TN_2_NN", "0") + == "1", } # end-env-vars-definition @@ -95,4 +72,4 @@ def is_set(name: str): """Check if an environment variable is explicitly set.""" if name in environment_variables: return name in os.environ - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") \ No newline at end of file + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm_metax/distributed/__init__.py b/vllm_metax/hotfix/__init__.py similarity index 100% rename from vllm_metax/distributed/__init__.py rename to vllm_metax/hotfix/__init__.py diff --git a/vllm_metax/patch/hotfix/support_ds_32.patch b/vllm_metax/hotfix/support_ds_32.patch similarity index 100% rename from vllm_metax/patch/hotfix/support_ds_32.patch rename to vllm_metax/hotfix/support_ds_32.patch diff --git a/vllm_metax/model_executor/layers/fused_moe/__init__.py b/vllm_metax/model_executor/layers/fused_moe/__init__.py index 75fdb6607..430aade3d 100644 --- a/vllm_metax/model_executor/layers/fused_moe/__init__.py +++ b/vllm_metax/model_executor/layers/fused_moe/__init__.py @@ -13,6 +13,6 @@ def get_config() -> Optional[dict[str, Any]]: if HAS_TRITON: # import to register the custom ops - from vllm_metax.model_executor.layers.fused_moe.fused_moe import ( - fused_experts) + from vllm_metax.model_executor.layers.fused_moe.fused_moe import fused_experts + __all__ = ["fused_experts"] diff --git a/vllm_metax/model_executor/layers/fused_moe/fused_moe.py b/vllm_metax/model_executor/layers/fused_moe/fused_moe.py index 5d9ccf353..4434deeb1 100644 --- a/vllm_metax/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm_metax/model_executor/layers/fused_moe/fused_moe.py @@ -1,14 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused MoE Triton kernels.""" + import functools import json import math import os + # torch.compile needs typing.List. 
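The `environment_variables` mapping in the `envs.py` hunk above stores zero-argument callables that are evaluated on access, so values always reflect the current environment. A minimal sketch of that pattern; the module-level `__getattr__` is an assumption here, since only `is_set` is visible in the diff:

# Lazy env-var lookup: each entry is a callable, evaluated when the attribute is read.
import os
from collections.abc import Callable
from typing import Any

environment_variables: dict[str, Callable[[], Any]] = {
    "MACA_VLLM_USE_TN_2_NN": lambda: os.environ.get("MACA_VLLM_USE_TN_2_NN", "0") == "1",
}

def __getattr__(name: str) -> Any:
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")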
It will fail torch.library.infer_schema # otherwise -from typing import List # noqa: UP035 -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, cast import torch import torch.nn.functional as F @@ -16,100 +17,124 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops -from vllm.logger import init_logger -# yapf: disable +from vllm._aiter_ops import rocm_aiter_ops +from vllm.logger import logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.model_executor.layers.fused_moe.config import ( - FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig) + FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( _valid_cutlass_block_scaled_grouped_gemm, - run_cutlass_block_scaled_fused_experts) + run_cutlass_block_scaled_fused_experts, +) + # yapf: enable from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm, deep_gemm_moe_fp8) + _valid_deep_gemm, + deep_gemm_moe_fp8, +) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) + moe_align_block_size, +) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP) + TopKWeightAndReduceNoOP, +) from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, activation_without_mul, moe_kernel_quantize_input) -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - dequant_mxfp4) + _resize_cache, + activation_without_mul, + disable_inplace, + moe_kernel_quantize_input, +) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 +from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 +from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme +from vllm.model_executor.utils import maybe_disable_graph_partition from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled from vllm_metax import _custom_ops as mx_ops -logger = init_logger(__name__) - @triton.jit -def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, - token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, - compute_type): +def write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, +): accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) @triton.jit def fused_moe_kernel_gptq_awq( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - b_scale_ptr, - b_zp_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions 
- N: tl.constexpr, - K: tl.constexpr, - EM, - num_valid_tokens, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_bse, - stride_bsk, - stride_bsn, - stride_bze, - stride_bzk, - stride_bzn, - block_k_diviable: tl.constexpr, - group_size: tl.constexpr, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - # ┌------------------------ Metax Modification -------------------------┐ - SPLIT_K: tl.constexpr, - ACCF32: tl.constexpr, - # └------------------------- Metax Modification -------------------------┘ - MUL_ROUTED_WEIGHT: tl.constexpr, - top_k: tl.constexpr, - compute_type: tl.constexpr, - has_zp: tl.constexpr, - use_int4_w4a16: tl.constexpr, - use_int8_w8a16: tl.constexpr): + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + # ┌------------------------ Metax Modification -------------------------┐ + SPLIT_K: tl.constexpr, + ACCF32: tl.constexpr, + # └------------------------- Metax Modification -------------------------┘ + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): """ Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. @@ -158,8 +183,7 @@ def fused_moe_kernel_gptq_awq( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( - tl.int64) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens @@ -168,25 +192,41 @@ def fused_moe_kernel_gptq_awq( # ----------------------------------------------------------- # Write back zeros to the output when the expert is not # in the current expert parallel rank. 
- write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, - offs_token, token_mask, BLOCK_SIZE_M, - BLOCK_SIZE_N, compute_type) + write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, + ) return - offs_bn = (pid_n * BLOCK_SIZE_N + - tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + - offs_k[None, :] * stride_ak) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) if use_int4_w4a16: - b_ptrs = b_ptr + off_experts * stride_be + \ - (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \ - stride_bn + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) b_shifter = (offs_k[:, None] % 2) * 4 elif use_int8_w8a16: - b_ptrs = b_ptr + off_experts * stride_be + \ - offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) if not has_zp and use_int4_w4a16: b_zp_num = 8 @@ -212,34 +252,43 @@ def fused_moe_kernel_gptq_awq( k_mask = None k_other = None - a = tl.load(a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0) + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) b = tl.load(b_ptrs) if use_int4_w4a16: b = (b >> b_shifter) & 0xF - b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ - offs_bn[None, :] * stride_bsn + \ - ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \ - stride_bsk + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) b_scale = b_scale.to(tl.float32) if has_zp and use_int4_w4a16: offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size - b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ - (offs_bn[None, :] // 2) * stride_bzn + \ - offs_k_true * stride_bzk + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) - b_zp = ((b_zp >> b_zp_shifter) & 0xF) + b_zp = (b_zp >> b_zp_shifter) & 0xF b_zp = b_zp.to(tl.float32) elif has_zp and use_int8_w8a16: offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size - b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ - offs_bn[None, :] * stride_bzn + \ - offs_k_true * stride_bzk + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) b_zp = b_zp.to(tl.float32) @@ -258,40 +307,46 @@ def fused_moe_kernel_gptq_awq( b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, - mask=token_mask, - other=0) + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + 
stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) # ┌------------------------ Metax Modification -------------------------┐ -@triton.heuristics({ - "UPGRADE": - lambda args: math.ceil( - (args["EM"] * args["N"]) / - (args["BLOCK_SIZE_M"] * args["BLOCK_SIZE_N"])).bit_length() > 31, -}) -@triton.heuristics({ - "UPGRADE_A_OFFS": - lambda args: - (args["num_valid_tokens"] // args["top_k"] * args["stride_am"] + args[ - "BLOCK_SIZE_K"] * args["stride_ak"]).bit_length() > 31, -}) -@triton.heuristics({ - "UPGRADE_B_OFFS": - lambda args: ((args["E"] - 1) * args["stride_be"] + - (args["N"] - 1) * args["stride_bn"] + - (args["K"] - 1) * args["stride_bk"]).bit_length() > 31, -}) +@triton.heuristics( + { + "UPGRADE": lambda args: math.ceil( + (args["EM"] * args["N"]) / (args["BLOCK_SIZE_M"] * args["BLOCK_SIZE_N"]) + ).bit_length() + > 31, + } +) +@triton.heuristics( + { + "UPGRADE_A_OFFS": lambda args: ( + args["num_valid_tokens"] // args["top_k"] * args["stride_am"] + + args["BLOCK_SIZE_K"] * args["stride_ak"] + ).bit_length() + > 31, + } +) +@triton.heuristics( + { + "UPGRADE_B_OFFS": lambda args: ( + (args["E"] - 1) * args["stride_be"] + + (args["N"] - 1) * args["stride_bn"] + + (args["K"] - 1) * args["stride_bk"] + ).bit_length() + > 31, + } +) # └------------------------- Metax Modification -------------------------┘ @triton.jit def fused_moe_kernel( @@ -356,7 +411,7 @@ def fused_moe_kernel( UPGRADE: tl.constexpr, UPGRADE_A_OFFS: tl.constexpr, UPGRADE_B_OFFS: tl.constexpr, - FAST_F32_TO_BF16: tl.constexpr + FAST_F32_TO_BF16: tl.constexpr, # └------------------------- Metax Modification -------------------------┘ ): """ @@ -396,7 +451,7 @@ def fused_moe_kernel( else: pid = tl.program_id(axis=0) pid_z = tl.program_id(axis=1) -# └------------------------- Metax Modification -------------------------┘ + # └------------------------- Metax Modification -------------------------┘ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n @@ -415,8 +470,7 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( - tl.int64) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens @@ -428,20 +482,28 @@ def fused_moe_kernel( if UPGRADE_A_OFFS: offs_token = offs_token.to(tl.int64) -# └------------------------- Metax Modification -------------------------┘ + # └------------------------- Metax Modification -------------------------┘ if off_experts == -1: # ----------------------------------------------------------- # Write back zeros to the output when the expert is not # in the current expert parallel rank. 
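The `triton.heuristics` block above switches the kernel to 64-bit indexing (`UPGRADE*`) whenever the largest possible flat offset would overflow a signed 32-bit integer. The same check evaluated on the host, with assumed shapes:

# Host-side version of the UPGRADE heuristic: use int64 offsets once the tile
# count needs more than 31 bits (the shapes below are assumptions).
import math

EM, N = 1 << 22, 14_336              # padded token count and output features
BLOCK_SIZE_M, BLOCK_SIZE_N = 16, 64

num_tiles = math.ceil((EM * N) / (BLOCK_SIZE_M * BLOCK_SIZE_N))
UPGRADE = num_tiles.bit_length() > 31
# Here num_tiles is ~5.9e7, so UPGRADE stays False; larger EM/N or smaller
# blocks push the count past 2**31 and force 64-bit offsets.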
- write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, - offs_token, token_mask, BLOCK_SIZE_M, - BLOCK_SIZE_N, compute_type) + write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, + ) return -# ┌------------------------ Metax Modification -------------------------┐ + # ┌------------------------ Metax Modification -------------------------┐ if UPGRADE_B_OFFS: - offs_bn = (pid_n * BLOCK_SIZE_N + - tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N else: offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N @@ -449,43 +511,51 @@ def fused_moe_kernel( offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) # └------------------------- Metax Modification -------------------------┘ - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + - offs_k[None, :] * stride_ak) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) - b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + - offs_bn[None, :] * stride_bn) + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) if use_int8_w8a16: - b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ - None, :] * stride_bsn + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) b_scale = tl.load(b_scale_ptrs) -# ┌------------------------ Metax Modification -------------------------┐ + # ┌------------------------ Metax Modification -------------------------┐ if use_int8_w8a8: - a_scale = tl.load(a_scale_ptr + - (offs_token[:, None] // top_k * stride_asm), - mask=token_mask[:, None], - other=0.0) - b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ - None, :] * stride_bsn + a_scale = tl.load( + a_scale_ptr + (offs_token[:, None] // top_k * stride_asm), + mask=token_mask[:, None], + other=0.0, + ) + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) b_scale = tl.load(b_scale_ptrs) -# └------------------------- Metax Modification -------------------------┘ + # └------------------------- Metax Modification -------------------------┘ if use_fp8_w8a8: # block-wise if group_k > 0 and group_n > 0: a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm offs_bsn = offs_bn // group_n - b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse + - offs_bsn * stride_bsn) + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) # channel-wise elif per_channel_quant: - b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ - None, :] * stride_bsn + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) b_scale = tl.load(b_scale_ptrs) # Load per-token scale for activations a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm - a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, - None] + a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None] # tensor-wise else: a_scale = tl.load(a_scale_ptr) @@ -502,21 +572,24 @@ def fused_moe_kernel( # ┌------------------------ Metax Modification -------------------------┐ # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), - dtype=tl.int32 if use_int8_w8a8 else tl.float32) + accumulator = tl.zeros( + (BLOCK_SIZE_M, 
BLOCK_SIZE_N), dtype=tl.int32 if use_int8_w8a8 else tl.float32 + ) # └------------------------- Metax Modification -------------------------┘ # ┌------------------------ Metax Modification -------------------------┐ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. - a = tl.load(a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K * SPLIT_K), - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K * SPLIT_K, - other=0.0) + a = tl.load( + a_ptrs, + mask=token_mask[:, None] + & (offs_k[None, :] < K - k * BLOCK_SIZE_K * SPLIT_K), + other=0.0, + ) + b = tl.load( + b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K * SPLIT_K, other=0.0 + ) # We accumulate along the K dimension. if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) @@ -527,13 +600,12 @@ def fused_moe_kernel( if group_k > 0 and group_n > 0: k_start = k * BLOCK_SIZE_K * SPLIT_K offs_ks = k_start // group_k - a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, - mask=token_mask, - other=0.0) + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) - accumulator += tl.dot(a, b) * a_scale[:, - None] * b_scale[None, :] + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: if use_fp8_w8a8: # acc used to enable fp8_fast_accum @@ -545,25 +617,23 @@ def fused_moe_kernel( # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak * SPLIT_K b_ptrs += BLOCK_SIZE_K * stride_bk * SPLIT_K -# └------------------------- Metax Modification -------------------------┘ + # └------------------------- Metax Modification -------------------------┘ if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, - mask=token_mask, - other=0) + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) -# ┌------------------------ Metax Modification -------------------------┐ + # ┌------------------------ Metax Modification -------------------------┐ elif use_int8_w8a8: accumulator = accumulator.to(tl.float32) - accumulator = (accumulator * a_scale * b_scale) + accumulator = accumulator * a_scale * b_scale if not ACCF32: if FAST_F32_TO_BF16: accumulator = accumulator.to(compute_type, "rtne_no_nan") else: accumulator = accumulator.to(compute_type) -# └------------------------- Metax Modification -------------------------┘ + # └------------------------- Metax Modification -------------------------┘ elif use_fp8_w8a8: if group_k > 0 and group_n > 0: accumulator = accumulator.to(compute_type) @@ -575,27 +645,24 @@ def fused_moe_kernel( # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) # ┌------------------------ Metax Modification -------------------------┐ if SPLIT_K == 1: tl.store(c_ptrs, accumulator, mask=c_mask) else: tl.atomic_add(c_ptrs, accumulator, mask=c_mask) - - -# └------------------------- Metax Modification -------------------------┘ + # └------------------------- Metax Modification 
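The Metax `SPLIT_K` path above lets several programs each reduce a slice of the K dimension; when `SPLIT_K > 1` their partial sums are combined in the output via `tl.atomic_add`. The same decomposition expressed with plain tensors:

# Split-K sketch: partial products over K slices add up to the full matmul.
# The kernel combines them with tl.atomic_add; here we simply accumulate in a loop.
import torch

M, K, N, SPLIT_K = 4, 8, 3, 2
a, b = torch.randn(M, K), torch.randn(K, N)

c = torch.zeros(M, N)
for kz in range(SPLIT_K):
    ks = slice(kz * K // SPLIT_K, (kz + 1) * K // SPLIT_K)
    c += a[:, ks] @ b[ks, :]

torch.testing.assert_close(c, a @ b)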
-------------------------┘ # yapf: disable def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - B_zp: Optional[torch.Tensor], - topk_weights: Optional[torch.Tensor], + A_scale: torch.Tensor | None, + B_scale: torch.Tensor | None, + B_zp: torch.Tensor | None, + topk_weights: torch.Tensor | None, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, @@ -611,8 +678,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, orig_acc_dtype: torch.dtype, # └------------------------- Metax Modification -------------------------┘ per_channel_quant: bool, - block_shape: Optional[list[int]] = None, - B_bias: Optional[torch.Tensor] = None) -> None: + block_shape: list[int] | None = None, + B_bias: torch.Tensor | None = None) -> None: assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -782,22 +849,22 @@ def invoke_fused_moe_kernel(A: torch.Tensor, # └------------------------- Metax Modification -------------------------┘ **config, ) -# ┌------------------------ Metax Modification -------------------------┐ + # ┌------------------------ Metax Modification -------------------------┐ if config["ACCF32"]: C = C.to(orig_acc_dtype) -# └------------------------- Metax Modification -------------------------┘ + # └------------------------- Metax Modification -------------------------┘ # yapf: enable # Adapted from: https://github.com/sgl-project/sglang/pull/2628 -def get_config_file_name(E: int, - N: int, - dtype: Optional[str], - block_shape: Optional[list[int]] = None) -> str: +def get_config_file_name( + E: int, N: int, dtype: str | None, block_shape: list[int] | None = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - block_shape_selector = ("" if not block_shape or not all(block_shape) else - f",block_shape={block_shape}").replace(" ", "") + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ).replace(" ", "") return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 @@ -806,13 +873,13 @@ def get_config_file_name(E: int, def get_moe_configs( E: int, N: int, - dtype: Optional[str], - block_n: Optional[int] = None, - block_k: Optional[int] = None, + dtype: str | None, + block_n: int | None = None, + block_k: int | None = None, # ┌------------------------ Metax Modification -------------------------┐ H: int = 0, # └------------------------- Metax Modification -------------------------┘ -) -> Optional[dict[int, Any]]: +) -> dict[int, Any] | None: """ Return optimized configurations for the fused MoE kernel. @@ -822,6 +889,10 @@ def get_moe_configs( be picked and the associated configuration chosen to invoke the kernel. """ + # Avoid optimizing for the batch invariant case. 
Use default config + if vllm_is_batch_invariant(): + return None + # First look up if an optimized configuration is available in the configs # directory block_shape = [block_n, block_k] if block_n and block_k else None @@ -833,18 +904,21 @@ def get_moe_configs( user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER if user_defined_config_folder is not None: user_defined_config_file_path = os.path.join( - user_defined_config_folder, json_file_name) + user_defined_config_folder, json_file_name + ) config_file_paths.append(user_defined_config_file_path) default_config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) config_file_paths.append(default_config_file_path) for config_file_path in config_file_paths: if os.path.exists(config_file_path): with open(config_file_path) as f: - logger.info("Using configuration from %s for MoE layer.", - config_file_path) + logger.info( + "Using configuration from %s for MoE layer.", config_file_path + ) # If a configuration has been found, return it tuned_config = json.load(f) # Delete triton_version from tuned_config @@ -854,16 +928,26 @@ def get_moe_configs( # If no optimized configuration is available, we will use the default # configuration logger.warning( - ("Using default MoE config. Performance might be sub-optimal! " - "Config file not found at %s"), config_file_paths) + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_paths, + ) return None -def get_moe_wna16_block_config(config: dict[str, - int], use_moe_wna16_cuda: bool, - num_valid_tokens: int, size_k: int, size_n: int, - num_experts: int, group_size: int, - real_top_k: int, block_size_m: int): +def get_moe_wna16_block_config( + config: dict[str, int], + use_moe_wna16_cuda: bool, + num_valid_tokens: int, + size_k: int, + size_n: int, + num_experts: int, + group_size: int, + real_top_k: int, + block_size_m: int, +): if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config: # optimal block config is set return {} @@ -885,20 +969,24 @@ def get_moe_wna16_block_config(config: dict[str, num_n_blocks = size_k // block_size_k num_k_blocks = size_n // block_size_k - num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \ - num_experts + num_m_blocks = ( + num_valid_tokens + block_size_m - 1 + ) / block_size_m + num_experts if num_valid_tokens // real_top_k <= block_size_m: num_m_blocks = min(num_m_blocks, num_valid_tokens) num_blocks = num_m_blocks * num_n_blocks * num_k_blocks - if size_k % 256 == 0 and num_blocks >= 256 and \ - block_size_k < 256: + if size_k % 256 == 0 and num_blocks >= 256 and block_size_k < 256: block_size_k = 256 num_blocks = num_blocks // (256 // block_size_k) - if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \ - size_k % (block_size_k * 2) == 0 and block_size_k <= 512 and \ - num_blocks >= 512: + if ( + num_m_blocks <= 16 + and size_k % (block_size_k * 2) == 0 + and size_k % (block_size_k * 2) == 0 + and block_size_k <= 512 + and num_blocks >= 512 + ): block_size_k = block_size_k * 2 num_blocks = num_blocks // 2 @@ -917,10 +1005,15 @@ def get_moe_wna16_block_config(config: dict[str, return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k} -def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int, - num_experts: int, bit: int): - return current_platform.is_cuda() and bit == 4 and \ - group_size in [32, 64, 128] and 
num_valid_tokens / num_experts <= 6 +def should_moe_wna16_use_cuda( + num_valid_tokens: int, group_size: int, num_experts: int, bit: int +): + return ( + current_platform.is_cuda() + and bit == 4 + and group_size in [32, 64, 128] + and num_valid_tokens / num_experts <= 6 + ) def get_default_config( @@ -929,9 +1022,19 @@ def get_default_config( N: int, K: int, topk: int, - dtype: Optional[str], - block_shape: Optional[list[int]] = None, + dtype: str | None, + block_shape: list[int] | None = None, ) -> dict[str, int]: + if vllm_is_batch_invariant(): + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + } + return config + if dtype == "fp8_w8a8" and block_shape is not None: # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # BLOCK_SIZE_K must be divisible by block_shape[1] @@ -942,6 +1045,7 @@ def get_default_config( "BLOCK_SIZE_N": block_shape[0], "BLOCK_SIZE_K": block_shape[1], "GROUP_SIZE_M": 32, + "SPLIT_K": 1, "num_warps": 4, "num_stages": 3 if not current_platform.is_rocm() else 2, } @@ -969,6 +1073,7 @@ def get_default_config( "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, + "SPLIT_K": 1, } else: config = { @@ -976,6 +1081,7 @@ def get_default_config( "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, + "SPLIT_K": 1, } return config @@ -984,14 +1090,15 @@ def try_get_optimal_moe_config( w1_shape: tuple[int, ...], w2_shape: tuple[int, ...], top_k: int, - dtype: Optional[str], + dtype: str | None, M: int, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, # ┌------------------------ Metax Modification -------------------------┐ H: int = 0, # └------------------------- Metax Modification -------------------------┘ ) -> dict[str, int]: from vllm.model_executor.layers.fused_moe import get_config + override_config = get_config() if override_config: config = override_config @@ -1010,31 +1117,33 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, - block_shape) + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape) return config -def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool) -> tuple[torch.Tensor, ...]: +def vllm_topk_softmax( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> tuple[torch.Tensor, ...]: ops.topk_softmax( topk_weights, topk_indices, token_expert_indices, gating_output, + renormalize, ) - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_indices -def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: - if is_rocm_aiter_moe_enabled(): - from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax - return rocm_aiter_topk_softmax +def dispatch_topk_func( + use_rocm_aiter: bool = False, +) -> Callable[..., tuple[torch.Tensor, ...]]: + if use_rocm_aiter: + return rocm_aiter_ops.topk_softmax return vllm_topk_softmax @@ -1043,39 +1152,39 @@ def fused_topk( gating_output: torch.Tensor, topk: int, renormalize: bool, - indices_type: Optional[torch.dtype] = None, + indices_type: torch.dtype | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert hidden_states.size(0) == 
gating_output.size(0), ( - "Number of tokens mismatch") + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" M, _ = hidden_states.size() - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) topk_ids = torch.empty( M, topk, dtype=torch.int32 if indices_type is None else indices_type, - device=hidden_states.device) - token_expert_indices = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) - - gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. + device=hidden_states.device, + ) + token_expert_indices = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) - topk_func = dispatch_topk_func() - topk_weights, topk_ids = topk_func(topk_weights, topk_ids, - token_expert_indices, - gating_output_float, renormalize) + topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()) + topk_weights, topk_ids = topk_func( + topk_weights, topk_ids, token_expert_indices, gating_output, renormalize + ) return topk_weights, topk_ids, token_expert_indices # This is used by the Deepseek-V2 and Deepseek-V3 model -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +@torch.compile( + dynamic=True, + backend=current_platform.simple_compile_backend, + options=maybe_disable_graph_partition(current_platform.simple_compile_backend), +) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -1085,12 +1194,15 @@ def grouped_topk( topk_group: int = 0, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - if envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and \ - current_platform.is_cuda() and \ - num_expert_group <= 32 and topk <= 32 and \ - e_score_correction_bias is not None: + if ( + envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK + and current_platform.is_cuda() + and num_expert_group <= 32 + and topk <= 32 + and e_score_correction_bias is not None + ): return fused_grouped_topk( hidden_states=hidden_states, gating_output=gating_output, @@ -1100,10 +1212,10 @@ def grouped_topk( num_expert_group=num_expert_group, topk_group=topk_group, scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor) + routed_scaling_factor=routed_scaling_factor, + ) - assert hidden_states.size(0) == gating_output.size(0), ( - "Number of tokens mismatch") + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" if scoring_func == "softmax": scores = torch.softmax(gating_output, dim=-1) @@ -1118,30 +1230,36 @@ def grouped_topk( # scores for expert selection but original scores for routing weights original_scores = scores scores = scores + e_score_correction_bias.unsqueeze(0) - group_scores = (scores.view(num_token, num_expert_group, - -1).topk(2, dim=-1)[0].sum(dim=-1)) + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) else: - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, - sorted=False)[1] # [n, top_k_group] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + + # For batch invariance, use sorted=True to ensure deterministic expert selection 
+ use_sorted = vllm_is_batch_invariant() + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[ + 1 + ] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), - float("-inf")) # [n, e] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] # Use original unbiased scores for the routing weights topk_weights = original_scores.gather(1, topk_ids) else: - topk_weights, topk_ids = torch.topk(tmp_scores, - k=topk, - dim=-1, - sorted=False) + topk_weights, topk_ids = torch.topk( + tmp_scores, k=topk, dim=-1, sorted=use_sorted + ) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -1153,12 +1271,13 @@ def grouped_topk( @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def eplb_map_to_physical_and_record( - topk_ids: torch.Tensor, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - indices_type: Optional[torch.dtype] = None) -> torch.Tensor: - ''' + topk_ids: torch.Tensor, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + indices_type: torch.dtype | None = None, +) -> torch.Tensor: + """ Map the logical expert ids to physical expert ids and record the expert load metrics. @@ -1174,7 +1293,7 @@ def eplb_map_to_physical_and_record( Returns: The physical expert ids. - ''' + """ # 1. 
Convert the logical expert ids to physical expert ids # Directly select a random replica for each logical expert @@ -1186,13 +1305,14 @@ def eplb_map_to_physical_and_record( # to deterministically choose a replica replica_count = logical_replica_count[topk_ids_long] # Flatten-position based index, reshaped back to `topk_ids` shape - pos_indices = torch.arange(topk_ids.numel(), - device=topk_ids.device, - dtype=torch.long).reshape_as(topk_ids) + pos_indices = torch.arange( + topk_ids.numel(), device=topk_ids.device, dtype=torch.long + ).reshape_as(topk_ids) # Compute pseudo-random indices by modulo replica_indices = (pos_indices % replica_count).unsqueeze(-1) - physical_ids = logical_to_physical_map[topk_ids_long].gather( - -1, replica_indices).squeeze(-1) + physical_ids = ( + logical_to_physical_map[topk_ids_long].gather(-1, replica_indices).squeeze(-1) + ) topk_ids = physical_ids @@ -1217,7 +1337,8 @@ def eplb_map_to_physical_and_record( expert_load_view.scatter_add_( dim=0, index=topk_ids_flatten.long(), - src=torch.ones_like(topk_ids_flatten).to(expert_load_view)) + src=torch.ones_like(topk_ids_flatten).to(expert_load_view), + ) if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) @@ -1235,8 +1356,7 @@ def fused_grouped_topk( scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.size(0) == gating_output.size(0), ( - "Number of tokens mismatch") + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" if scoring_func == "softmax": scores = torch.softmax(gating_output, dim=-1) @@ -1247,8 +1367,14 @@ def fused_grouped_topk( scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) topk_values, topk_indices = ops.grouped_topk( - scores, scores_with_bias.to(scores.dtype), num_expert_group, - topk_group, topk, renormalize, routed_scaling_factor) + scores, + scores_with_bias.to(scores.dtype), + num_expert_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) return topk_values.to(torch.float32), topk_indices.to(torch.int32) @@ -1264,26 +1390,47 @@ def inplace_fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, + ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, #noqa: UP006 - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, ) -> None: - fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - activation, apply_router_weight_on_input, use_fp8_w8a8, - use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, - use_mxfp4_w4a4, per_channel_quant, global_num_experts, - expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, - a2_scale, block_shape, w1_bias, w2_bias) 
+ fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + activation, + apply_router_weight_on_input, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + ocp_mx_scheme, + per_channel_quant, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + w1_bias, + w2_bias, + ) def inplace_fused_experts_fake( @@ -1298,19 +1445,19 @@ def inplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, + ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, #noqa: UP006 - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, ) -> None: pass @@ -1320,8 +1467,11 @@ def inplace_fused_experts_fake( op_func=inplace_fused_experts, mutates_args=["hidden_states"], fake_impl=inplace_fused_experts_fake, - tags=(() if is_torch_equal_or_newer("2.7.0") else - (torch.Tag.needs_fixed_stride_order, )), + tags=( + () + if is_torch_equal_or_newer("2.7.0") + else (torch.Tag.needs_fixed_stride_order,) + ), ) @@ -1337,26 +1487,47 @@ def outplace_fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, + ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, #noqa: UP006 - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, ) -> torch.Tensor: return fused_experts_impl( - hidden_states, w1, w2, topk_weights, topk_ids, False, activation, - apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, - use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, - global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, - a1_scale, a2_scale, block_shape, w1_bias, w2_bias) + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + activation, + apply_router_weight_on_input, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + 
ocp_mx_scheme, + per_channel_quant, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + w1_bias, + w2_bias, + ) def outplace_fused_experts_fake( @@ -1370,19 +1541,19 @@ def outplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, + ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1392,14 +1563,17 @@ def outplace_fused_experts_fake( op_func=outplace_fused_experts, mutates_args=[], fake_impl=outplace_fused_experts_fake, - tags=(() if is_torch_equal_or_newer("2.7.0") else - (torch.Tag.needs_fixed_stride_order, )), + tags=( + () + if is_torch_equal_or_newer("2.7.0") + else (torch.Tag.needs_fixed_stride_order,) + ), ) def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor: torch.ops.vllm.maca_inplace_fused_experts(**kwargs) - hidden_states = kwargs['hidden_states'] + hidden_states = kwargs["hidden_states"] return hidden_states @@ -1408,7 +1582,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: - if inplace: + if inplace and not disable_inplace(): return torch_vllm_inplace_fused_experts return torch_vllm_outplace_fused_experts @@ -1425,12 +1599,11 @@ def fused_experts( activation: str = "silu", apply_router_weight_on_input: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - quant_config: Optional[FusedMoEQuantConfig] = None, + expert_map: torch.Tensor | None = None, + quant_config: FusedMoEQuantConfig | None = None, allow_deep_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False, ) -> torch.Tensor: - if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG use_fp8_w8a8 = quant_config.use_fp8_w8a8 @@ -1441,8 +1614,11 @@ def fused_experts( # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. 
- if (allow_deep_gemm and quant_config.use_fp8_w8a8 and - (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))): + if ( + allow_deep_gemm + and quant_config.use_fp8_w8a8 + and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2)) + ): assert quant_config is not None assert apply_router_weight_on_input is False return deep_gemm_moe_fp8( @@ -1461,10 +1637,13 @@ def fused_experts( a2_scale=quant_config.a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) - elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 - and _valid_cutlass_block_scaled_grouped_gemm( - w1, w2, inplace, activation, apply_router_weight_on_input, - expert_map)): + elif ( + allow_cutlass_block_scaled_grouped_gemm + and use_fp8_w8a8 + and _valid_cutlass_block_scaled_grouped_gemm( + w1, w2, inplace, activation, apply_router_weight_on_input, expert_map + ) + ): assert quant_config is not None return run_cutlass_block_scaled_fused_experts( a=hidden_states, @@ -1473,7 +1652,8 @@ def fused_experts( w1_scale=quant_config.w1_scale, w2_scale=quant_config.w2_scale, topk_weights=topk_weights, - topk_ids=topk_ids) + topk_ids=topk_ids, + ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, @@ -1487,7 +1667,7 @@ def fused_experts( use_int8_w8a8=quant_config.use_int8_w8a8, use_int8_w8a16=quant_config.use_int8_w8a16, use_int4_w4a16=quant_config.use_int4_w4a16, - use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4, + ocp_mx_scheme=quant_config.ocp_mx_scheme, per_channel_quant=quant_config.per_act_token_quant, global_num_experts=global_num_experts, expert_map=expert_map, @@ -1499,18 +1679,20 @@ def fused_experts( a2_scale=quant_config.a2_scale, block_shape=quant_config.block_shape, w1_bias=quant_config.w1_bias, - w2_bias=quant_config.w2_bias) + w2_bias=quant_config.w2_bias, + ) SILU_NO_MUL: str = activation_without_mul("silu") GELU_NO_MUL: str = activation_without_mul("gelu") +RELU2_NO_MUL: str = activation_without_mul("relu2") def _get_config_quant_dtype( use_fp8_w8a8: bool, use_int8_w8a8: bool, - use_mxfp4_w4a4: bool, -) -> Union[None, torch.dtype, str]: + ocp_mx_scheme: str | None, +) -> None | torch.dtype | str: """ Get the quantization type based on the quantization strategy flags. We don't have a quant_config at this point so we need to work backwards. 
@@ -1522,8 +1704,12 @@ def _get_config_quant_dtype( return torch.float8_e4m3fn elif use_int8_w8a8: return torch.int8 - elif use_mxfp4_w4a4: + elif ocp_mx_scheme == "w_mxfp4_a_mxfp4": return "mxfp4" + elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}: + return "mxfp6_e3m2" + elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}: + return "mxfp6_e2m3" return None @@ -1540,38 +1726,50 @@ def fused_experts_impl( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, + ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, ) -> torch.Tensor: # Check constraints. if use_int4_w4a16: - assert hidden_states.size(1) // 2 == w1.size(2), ( - "Hidden size mismatch") - elif use_mxfp4_w4a4: - # 16bit activation and fp4x2 packed weight - assert hidden_states.size(1) // 2 == w1.size(2), "hidden size mismatch" + assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch" + elif ocp_mx_scheme is not None: + if ocp_mx_scheme in { + "w_mxfp4_a_mxfp4", + "w_mxfp4_a_mxfp6_e3m2", + "w_mxfp4_a_mxfp6_e2m3", + }: + # 16bit activation and fp4x2 packed weight + assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch" + elif ocp_mx_scheme in { + "w_mxfp6_e3m2_a_mxfp6_e3m2", + "w_mxfp6_e2m3_a_mxfp6_e2m3", + }: + assert hidden_states.size(1) == (w1.size(2) * 4) // 3, ( + "hidden size mismatch" + ) + else: + raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") else: assert hidden_states.size(1) == w1.size(2), ( - f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") + f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}" + ) assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] # ┌------------------------ Metax Modification -------------------------┐ H = hidden_states.shape[-1] # └------------------------- Metax Modification -------------------------┘ @@ -1592,14 +1790,17 @@ def fused_experts_impl( # └------------------------- Metax Modification -------------------------┘ use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - dtype=hidden_states.dtype) + ocp_mx_scheme=ocp_mx_scheme, + dtype=hidden_states.dtype, + ) # Note: for use_int8_w8a16 or use_int4_w4a16, the activations are # 
quantized prior to calling fused_experts. - quant_dtype = _get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_mxfp4_w4a4=use_mxfp4_w4a4) + quant_dtype = _get_config_quant_dtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + ocp_mx_scheme=ocp_mx_scheme, + ) get_config_func = functools.partial( try_get_optimal_moe_config, @@ -1618,50 +1819,50 @@ def fused_experts_impl( # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 # ┌------------------------ Metax Modification -------------------------┐ - stage1_config = config.get("stage1", config) - stage2_config = config.get("stage2", config) - - if 'ACCF32' not in stage1_config: - stage1_config['ACCF32'] = False - if 'ACCF32' not in stage2_config: - stage2_config['ACCF32'] = False - if 'SPLIT_K' not in stage1_config: - stage1_config['SPLIT_K'] = 1 - if 'SPLIT_K' not in stage2_config: - stage2_config['SPLIT_K'] = 1 - - if stage1_config['ACCF32']: - acc_type1 = torch.float32 - else: - acc_type1 = hidden_states.dtype - if stage2_config['ACCF32']: - acc_type2 = torch.float32 + stage1_config: dict[str, int] = cast(dict[str, int], config.get("stage1", config)) + stage2_config: dict[str, int] = cast(dict[str, int], config.get("stage2", config)) + + if "ACCF32" not in stage1_config: + stage1_config["ACCF32"] = False + if "ACCF32" not in stage2_config: + stage2_config["ACCF32"] = False + if "SPLIT_K" not in stage1_config: + stage1_config["SPLIT_K"] = 1 + if "SPLIT_K" not in stage2_config: + stage2_config["SPLIT_K"] = 1 + + acc_type1 = torch.float32 if stage1_config["ACCF32"] else hidden_states.dtype + acc_type2 = torch.float32 if stage2_config["ACCF32"] else hidden_states.dtype + + if stage1_config["SPLIT_K"] > 1: + intermediate_cache1 = torch.zeros( + (M, topk_ids.shape[1], N), device=hidden_states.device, dtype=acc_type1 + ) else: - acc_type2 = hidden_states.dtype + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) - if stage1_config['SPLIT_K'] > 1: - intermediate_cache1 = torch.zeros((M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=acc_type1) - else: - intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - - if stage2_config['SPLIT_K'] > 1: - intermediate_cache3 = torch.zeros((M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=acc_type2) + if stage2_config["SPLIT_K"] > 1: + intermediate_cache3 = torch.zeros( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=acc_type2, + ) else: - intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) # └------------------------- Metax Modification -------------------------┘ # This needs separate memory since it's used concurrently with cache1 - intermediate_cache2 = torch.empty((M * top_k_num, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty( + (M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype + ) if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 @@ -1672,22 +1873,51 @@ def fused_experts_impl( else: raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") - if inplace: + if inplace 
and not disable_inplace(): out_hidden_states = hidden_states else: out_hidden_states = torch.empty_like(hidden_states) - if use_mxfp4_w4a4: - # Weight has to be dequantized for mxfp4 emulation. - w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) - w1_scale = None - w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) - w2_scale = None + if ocp_mx_scheme is not None: + # TODO: On platforms for which `current_platform.supports_mx()` is True + # and for which we have a native OCP mx fused MOE kernel, + # this dequantization step should not be done. + if ocp_mx_scheme in { + OCP_MX_Scheme.w_mxfp4_a_mxfp4, + OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2, + OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3, + }: + # Weight has to be dequantized for mxfp4 emulation. + w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) + w1_scale = None + w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) + w2_scale = None + elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2: + w1 = dequant_mxfp6( + w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype + ) + w1_scale = None + w2 = dequant_mxfp6( + w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype + ) + w2_scale = None + elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3: + w1 = dequant_mxfp6( + w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype + ) + w1_scale = None + w2 = dequant_mxfp6( + w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype + ) + w2_scale = None + else: + raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") for chunk in range((num_tokens // CHUNK_SIZE) + 1): - begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, - num_tokens)) + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.size() @@ -1700,8 +1930,9 @@ def fused_experts_impl( # so the cache size and config are already set correctly and # do not need to be adjusted. 
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] - intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * - topk_ids.size(1)] + intermediate_cache2 = intermediate_cache2[ + : tokens_in_chunk * topk_ids.size(1) + ] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) @@ -1712,129 +1943,173 @@ def fused_experts_impl( A_scale=a1_scale, quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, - block_shape=block_shape) + block_shape=block_shape, + ) - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, stage1_config['BLOCK_SIZE_M'], - global_num_experts, expert_map)) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, stage1_config["BLOCK_SIZE_M"], global_num_experts, expert_map + ) # ┌------------------------ Metax Modification -------------------------┐ - if (stage1_config['BLOCK_SIZE_M'] == 128 and not use_int8_w8a8 - and (topk_ids.shape[1] == 1 or topk_ids.shape[1] == 2) - and (curr_hidden_states.dtype == torch.bfloat16 - or curr_hidden_states.dtype == torch.float16) - and w1.shape[1] % 4 == 0 and w1.shape[2] % 8 == 0): - mx_ops.fused_moe_kernel(curr_hidden_states, w1, - intermediate_cache1, curr_topk_weights, - curr_topk_ids, sorted_token_ids, - expert_ids, num_tokens_post_padded, False, - topk_ids.shape[1], 0) + if ( + stage1_config["BLOCK_SIZE_M"] == 128 + and not use_int8_w8a8 + and (topk_ids.shape[1] == 1 or topk_ids.shape[1] == 2) + and ( + curr_hidden_states.dtype == torch.bfloat16 + or curr_hidden_states.dtype == torch.float16 + ) + and w1.shape[1] % 4 == 0 + and w1.shape[2] % 8 == 0 + ): + mx_ops.fused_moe_kernel( + curr_hidden_states, + w1, + intermediate_cache1, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + 0, + ) else: qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, A_scale=a1_scale, quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, - block_shape=block_shape) - - invoke_fused_moe_kernel(qcurr_hidden_states, - w1, - intermediate_cache1, - a1q_scale, - w1_scale, - w1_zp, - curr_topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - apply_router_weight_on_input, - top_k_num, - stage1_config, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - orig_acc_dtype=hidden_states.dtype, - per_channel_quant=per_channel_quant, - block_shape=block_shape, - B_bias=w1_bias) + block_shape=block_shape, + ) + + invoke_fused_moe_kernel( + qcurr_hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + apply_router_weight_on_input, + top_k_num, + stage1_config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + orig_acc_dtype=hidden_states.dtype, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + B_bias=w1_bias, + ) # └------------------------- Metax Modification -------------------------┘ # Activation function with multiplication if activation == "silu": - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) + torch.ops._C.silu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) elif activation == "gelu": - 
torch.ops._C.gelu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) + torch.ops._C.gelu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) elif activation == "swigluoai": # alpha = 1.702, limit = 7.0 - torch.ops._C.swigluoai_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) + torch.ops._C.swigluoai_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) # Activation function without multiplication elif activation == SILU_NO_MUL: intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) elif activation == GELU_NO_MUL: intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) - + elif activation == RELU2_NO_MUL: + intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N))) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}.") # ┌------------------------ Metax Modification -------------------------┐ - if (stage2_config['BLOCK_SIZE_M'] == 128 and not use_int8_w8a8 - and w2.shape[1] % 4 == 0 and w2.shape[2] % 8 == 0 - and (hidden_states.dtype == torch.bfloat16 - or hidden_states.dtype == torch.float16)): - mx_ops.fused_moe_kernel(intermediate_cache2, w2, - intermediate_cache3, curr_topk_weights, - curr_topk_ids, sorted_token_ids, - expert_ids, num_tokens_post_padded, True, - 1, 0) + if ( + stage2_config["BLOCK_SIZE_M"] == 128 + and not use_int8_w8a8 + and w2.shape[1] % 4 == 0 + and w2.shape[2] % 8 == 0 + and ( + hidden_states.dtype == torch.bfloat16 + or hidden_states.dtype == torch.float16 + ) + ): + mx_ops.fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + 0, + ) else: qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, A_scale=a2_scale, quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, - block_shape=block_shape) + block_shape=block_shape, + ) - if stage2_config['BLOCK_SIZE_M'] != stage1_config['BLOCK_SIZE_M']: + if stage2_config["BLOCK_SIZE_M"] != stage1_config["BLOCK_SIZE_M"]: sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, - stage2_config['BLOCK_SIZE_M'], - global_num_experts, expert_map)) - - invoke_fused_moe_kernel(qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - w2_scale, - w2_zp, - curr_topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - not apply_router_weight_on_input, - 1, - stage2_config, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - orig_acc_dtype=hidden_states.dtype, - per_channel_quant=per_channel_quant, - block_shape=block_shape, - B_bias=w2_bias) + moe_align_block_size( + curr_topk_ids, + stage2_config["BLOCK_SIZE_M"], + global_num_experts, + expert_map, + ) + ) + + invoke_fused_moe_kernel( + qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + stage2_config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + orig_acc_dtype=hidden_states.dtype, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + B_bias=w2_bias, + ) # └------------------------- Metax Modification -------------------------┘ - 
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), - out_hidden_states[begin_chunk_idx:end_chunk_idx]) + ops.moe_sum( + intermediate_cache3.view(*intermediate_cache3.size()), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + ) return out_hidden_states class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, quant_config: FusedMoEQuantConfig, @@ -1843,10 +2118,12 @@ def __init__( @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_chunking(self) -> bool: return True @@ -1859,20 +2136,18 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1 = (M, topk, max(N // 2, K)) workspace2 = (M, topk, max(N, K)) output = (M, K) - return (workspace1, workspace2, output, a.dtype) + return (workspace1, workspace2, output) def apply( self, @@ -1884,50 +2159,45 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): # Check constraints. 
if self.quant_config.use_int4_w4a16: - assert hidden_states.size(-1) // 2 == w1.size(2), ( - "Hidden size mismatch") + assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch" else: - assert hidden_states.size(-1) == w1.size(2), \ - (f"Hidden size mismatch {hidden_states.size(-1)} " - f"!= {w1.size(2)}") + assert hidden_states.size(-1) == w1.size(2), ( + f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}" + ) - assert hidden_states.is_contiguous( - ), "Hidden_states must be contiguous" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.dim() == 2 assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + torch.float32, + torch.float16, + torch.bfloat16, + torch.float8_e4m3fn, ] - E, num_tokens, N, K, top_k_num = mk._moe_problem_size( - hidden_states, w1, w2, topk_ids) + E, num_tokens, N, K, top_k_num = self.moe_problem_size( + hidden_states, w1, w2, topk_ids + ) if global_num_experts == -1: global_num_experts = E - config_dtype = get_config_dtype_str( - use_fp8_w8a8=self.quant_config.use_fp8_w8a8, - use_int8_w8a16=self.quant_config.use_int8_w8a16, - use_int4_w4a16=self.quant_config.use_int4_w4a16, - use_mxfp4_w4a4=self.quant_config.use_mxfp4_w4a4, - dtype=hidden_states.dtype) - config = try_get_optimal_moe_config( w1.size(), w2.size(), top_k_num, - config_dtype, + self.quant_config.config_name(hidden_states.dtype), num_tokens, block_shape=self.block_shape, ) @@ -1941,20 +2211,18 @@ def apply( elif hidden_states.dtype == torch.float8_e4m3fn: compute_type = tl.bfloat16 else: - raise ValueError( - f"Unsupported compute_type: {hidden_states.dtype}") + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") # Note that the output tensor might be in workspace1 - intermediate_cache1 = _resize_cache(workspace2, - (num_tokens, top_k_num, N)) - intermediate_cache2 = _resize_cache(workspace13, - (num_tokens * top_k_num, N // 2)) - intermediate_cache3 = _resize_cache(workspace2, - (num_tokens, top_k_num, K)) + intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N)) + intermediate_cache2 = _resize_cache( + workspace13, (num_tokens * top_k_num, N // 2) + ) + intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K)) - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], - global_num_experts, expert_map)) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map + ) invoke_fused_moe_kernel( hidden_states, @@ -1975,19 +2243,25 @@ def apply( use_int8_w8a8=self.quant_config.use_int8_w8a8, use_int8_w8a16=self.quant_config.use_int8_w8a16, use_int4_w4a16=self.quant_config.use_int4_w4a16, + orig_acc_dtype=hidden_states.dtype, per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, B_bias=self.w1_bias, ) - self.activation(activation, intermediate_cache2, - intermediate_cache1.view(-1, N)) + self.activation( + activation, intermediate_cache2, intermediate_cache1.view(-1, N) + ) - a2q_scale: Optional[torch.Tensor] = None + a2q_scale: torch.Tensor | None = None qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - intermediate_cache2, a2_scale, self.quant_dtype, - self.per_act_token_quant, self.block_shape) + intermediate_cache2, + a2_scale, + 
self.quant_dtype, + self.per_act_token_quant, + self.block_shape, + ) invoke_fused_moe_kernel( qintermediate_cache2, @@ -2008,31 +2282,39 @@ def apply( use_int8_w8a8=self.quant_config.use_int8_w8a8, use_int8_w8a16=self.quant_config.use_int8_w8a16, use_int4_w4a16=self.quant_config.use_int4_w4a16, + orig_acc_dtype=hidden_states.dtype, per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, B_bias=self.w2_bias, ) - ops.moe_sum(intermediate_cache3, output) + # separate function is required for MoE + LoRA + self.moe_sum(intermediate_cache3, output) + + def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: + ops.moe_sum(input, output) def modular_triton_fused_moe( - quant_config: FusedMoEQuantConfig) -> mk.FusedMoEModularKernel: + quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None +) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), TritonExperts(quant_config), + shared_experts, ) def get_config_dtype_str( - dtype: torch.dtype, - use_int4_w4a16: Optional[bool] = False, - # ┌------------------------ Metax Modification -------------------------┐ - use_int8_w8a8: Optional[bool] = False, - # └------------------------- Metax Modification -------------------------┘ - use_int8_w8a16: Optional[bool] = False, - use_fp8_w8a8: Optional[bool] = False, - use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]: + dtype: torch.dtype, + use_int4_w4a16: bool | None = False, + # ┌------------------------ Metax Modification -------------------------┐ + use_int8_w8a8: bool | None = False, + # └------------------------- Metax Modification -------------------------┘ + use_int8_w8a16: bool | None = False, + use_fp8_w8a8: bool | None = False, + ocp_mx_scheme: str | None = None, +) -> str | None: if use_fp8_w8a8: return "fp8_w8a8" # ┌------------------------ Metax Modification -------------------------┐ @@ -2043,8 +2325,11 @@ def get_config_dtype_str( return "int8_w8a16" elif use_int4_w4a16: return "int4_w4a16" - elif use_mxfp4_w4a4: - return "mxfp4_w4a4" + elif ocp_mx_scheme is not None: + # The output of this function is passed to `try_get_optimal_moe_config`, + # and as we only simulate OCP MX execution in fused_moe for now, + # we will NOT look for `*,dtype=w_mxfp4_a_mxfp4.json` for now. 
+ return None elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs diff --git a/vllm_metax/models/__init__.py b/vllm_metax/models/__init__.py index 1560803f8..e6ceb7f76 100644 --- a/vllm_metax/models/__init__.py +++ b/vllm_metax/models/__init__.py @@ -4,37 +4,41 @@ def register_model(): - ModelRegistry.register_model( - "BaichuanForCausalLM", - "vllm_metax.models.baichuan:BaichuanForCausalLM") + "BaichuanForCausalLM", "vllm_metax.models.baichuan:BaichuanForCausalLM" + ) ModelRegistry.register_model( "BaiChuanMoEForCausalLM", - "vllm_metax.models.baichuan_moe:BaiChuanMoEForCausalLM") + "vllm_metax.models.baichuan_moe:BaiChuanMoEForCausalLM", + ) ModelRegistry.register_model( "Qwen2VLForConditionalGeneration", - "vllm_metax.models.qwen2_vl:Qwen2VLForConditionalGeneration") + "vllm_metax.models.qwen2_vl:Qwen2VLForConditionalGeneration", + ) ModelRegistry.register_model( "Qwen2_5_VLForConditionalGeneration", - "vllm_metax.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration") + "vllm_metax.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration", + ) ModelRegistry.register_model( "Qwen3VLForConditionalGeneration", - "vllm_metax.models.qwen3_vl:Qwen3VLForConditionalGeneration") + "vllm_metax.models.qwen3_vl:Qwen3VLForConditionalGeneration", + ) - ModelRegistry.register_model("DeepSeekMTPModel", - "vllm_metax.models.deepseek_mtp:DeepSeekMTP") + ModelRegistry.register_model( + "DeepSeekMTPModel", "vllm_metax.models.deepseek_mtp:DeepSeekMTP" + ) ModelRegistry.register_model( - "DeepseekV2ForCausalLM", - "vllm_metax.models.deepseek_v2:DeepseekV2ForCausalLM") + "DeepseekV2ForCausalLM", "vllm_metax.models.deepseek_v2:DeepseekV2ForCausalLM" + ) ModelRegistry.register_model( - "DeepseekV3ForCausalLM", - "vllm_metax.models.deepseek_v2:DeepseekV3ForCausalLM") + "DeepseekV3ForCausalLM", "vllm_metax.models.deepseek_v2:DeepseekV3ForCausalLM" + ) ModelRegistry.register_model( - "DeepseekV32ForCausalLM", - "vllm_metax.models.deepseek_v2:DeepseekV3ForCausalLM") + "DeepseekV32ForCausalLM", "vllm_metax.models.deepseek_v2:DeepseekV3ForCausalLM" + ) diff --git a/vllm_metax/models/deepseek_v2.py b/vllm_metax/models/deepseek_v2.py index 634819d16..0b8533f2f 100644 --- a/vllm_metax/models/deepseek_v2.py +++ b/vllm_metax/models/deepseek_v2.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only DeepseekV2/DeepseekV3 model.""" + import typing from collections.abc import Callable, Iterable from itertools import islice @@ -32,51 +33,70 @@ from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CacheConfig, ParallelConfig, VllmConfig, - get_current_vllm_config) -from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) + per_token_group_quant_fp8, +) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import cdiv, direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm_metax.utils.deep_gemm import bf16_mqa_logits, bf16_paged_mqa_logits from vllm_metax.v1.attention.backends.mla.indexer import ( MacaDeepseekV32IndexerBackend as DeepseekV32IndexerBackend, - DeepseekV32IndexerMetadata) + DeepseekV32IndexerMetadata, +) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec -from vllm.model_executor.models.interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP +from vllm.model_executor.models.interfaces import ( + MixtureOfExperts, + SupportsLoRA, + SupportsPP, +) from vllm.model_executor.models.utils import ( - 
PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) if current_platform.is_cuda_alike(): from vllm import _custom_ops as ops @@ -87,14 +107,99 @@ logger = init_logger(__name__) -class DeepseekV2MLP(nn.Module): +class DeepseekAttention(nn.Module): + """Normal MHA implementation used by Deepseek v1.""" + def __init__( + self, + vllm_config: VllmConfig, + config: DeepseekV2Config | DeepseekV3Config, + hidden_size: int, + num_heads: int, + rope_theta: float = 10000, + rope_scaling: dict[str, Any] | None = None, + max_position_embeddings: int = 8192, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekV2MLP(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, is_sequence_parallel=False, prefix: str = "", @@ -106,21 +211,26 @@ def __init__( # replicated and no collective ops are needed. # Otherwise we use standard TP with an allreduce at the end. 
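Setting aside the tensor/sequence-parallel plumbing discussed in the comment above, the computation DeepseekV2MLP performs with the projections defined just below is the standard gated-SiLU block; a minimal single-device sketch with illustrative shapes:

import torch
import torch.nn.functional as F

# Illustrative shapes only.
num_tokens, hidden_size, intermediate_size = 4, 16, 32
x = torch.randn(num_tokens, hidden_size)

# MergedColumnParallelLinear packs gate_proj and up_proj into one weight.
w_gate_up = torch.randn(2 * intermediate_size, hidden_size)
w_down = torch.randn(hidden_size, intermediate_size)

gate, up = (x @ w_gate_up.t()).chunk(2, dim=-1)
y = F.silu(gate) * up    # SiluAndMul
out = y @ w_down.t()     # down_proj
assert out.shape == (num_tokens, hidden_size)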
self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, disable_tp=is_sequence_parallel, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - disable_tp=is_sequence_parallel, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -131,22 +241,21 @@ def forward(self, x): class DeepseekV2MoE(nn.Module): - def __init__( self, - config: Union[DeepseekV2Config, DeepseekV3Config], + config: DeepseekV2Config | DeepseekV3Config, parallel_config: ParallelConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() - self.routed_scaling_factor = config.routed_scaling_factor + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) self.ep_group = get_ep_group().device_group - self.ep_rank = self.ep_group.rank() + self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() self.n_routed_experts: int = config.n_routed_experts self.n_shared_experts: int = config.n_shared_experts @@ -154,17 +263,22 @@ def __init__( self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") - - self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") - if config.topk_method == "noaux_tc": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." 
+ ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + if getattr(config, "topk_method", None) == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts, dtype=torch.float32)) + torch.empty(config.n_routed_experts, dtype=torch.float32) + ) else: self.gate.e_score_correction_bias = None @@ -174,40 +288,19 @@ def __init__( self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts - self.n_physical_experts = (self.n_logical_experts + - self.n_redundant_experts) + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) - if config.n_shared_experts is None: - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - # we do scaling outside, set factor to 1.0 to avoid double mul - routed_scaling_factor=1.0, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - is_sequence_parallel=self.is_sequence_parallel, - ) + self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled: self.shared_experts = None else: - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) + intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP( hidden_size=config.hidden_size, @@ -219,27 +312,34 @@ def __init__( prefix=f"{prefix}.shared_experts", ) - self.experts = SharedFusedMoE( - shared_experts=self.shared_experts, - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - # we do scaling outside, set factor to 1.0 to avoid double mul - routed_scaling_factor=1.0, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - is_sequence_parallel=self.is_sequence_parallel, - ) + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + gate=self.gate, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + 
num_expert_group=getattr(config, "n_group", 1), + topk_group=getattr(config, "topk_group", 1), + prefix=f"{prefix}.experts", + scoring_func=getattr(config, "scoring_func", "softmax"), + # we do scaling outside, set factor to 1.0 to avoid double mul + # aiter applies routed_scaling_factor internally + routed_scaling_factor=1.0 + if not self.is_rocm_aiter_moe_enabled + else self.routed_scaling_factor, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + n_shared_experts=config.n_shared_experts + if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + else None, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -252,25 +352,30 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - fused_moe_out = self.experts(hidden_states=hidden_states, - router_logits=router_logits) - - if self.shared_experts is not None: - shared_output, final_hidden_states = fused_moe_out + if self.experts.is_internal_router: + # In this case, the gate/router runs inside the FusedMoE class + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=hidden_states + ) else: - shared_output = None - final_hidden_states = fused_moe_out + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + shared_output, final_hidden_states = fused_moe_out + if self.shared_experts is None: + assert shared_output is None # Fix FP16 overflow # See DeepseekV2DecoderLayer for more details. if hidden_states.dtype != torch.float16: - final_hidden_states *= self.routed_scaling_factor + if not self.is_rocm_aiter_moe_enabled: + final_hidden_states *= self.routed_scaling_factor elif self.shared_experts is not None: assert shared_output is not None - shared_output *= (1. 
/ self.routed_scaling_factor) + shared_output *= 1.0 / self.routed_scaling_factor if self.shared_experts is not None: assert shared_output is not None @@ -278,29 +383,30 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( - final_hidden_states, 0) + final_hidden_states, 0 + ) final_hidden_states = final_hidden_states[:num_tokens] elif self.tp_size > 1: - final_hidden_states = ( - self.experts.maybe_all_reduce_tensor_model_parallel( - final_hidden_states)) + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) return final_hidden_states.view(num_tokens, hidden_dim) def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: import math + if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 class DeepseekV2Attention(nn.Module): - def __init__( self, vllm_config: VllmConfig, - config: Union[DeepseekV2Config, DeepseekV3Config], + config: DeepseekV2Config | DeepseekV3Config, hidden_size: int, num_heads: int, qk_nope_head_dim: int, @@ -309,11 +415,11 @@ def __init__( q_lora_rank: int, kv_lora_rank: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - topk_indices_buffer: Optional[torch.Tensor] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + topk_indices_buffer: torch.Tensor | None = None, prefix: str = "", ) -> None: super().__init__() @@ -331,60 +437,70 @@ def __init__( self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - assert topk_indices_buffer is None, "topk_indices_buffer is not \ + assert topk_indices_buffer is None, ( + "topk_indices_buffer is not \ supported for DeepseekV2Attention" + ) if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj") - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, 
eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") + prefix=f"{prefix}.kv_b_proj", + ) # O projection. - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' + rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, + ) if rope_scaling: mscale_all_dim = rope_scaling.get("mscale_all_dim", False) @@ -392,13 +508,15 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - self.attn = Attention(self.num_local_heads, - self.qk_head_dim, - self.scaling, - num_kv_heads=self.num_local_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -408,47 +526,43 @@ def forward( if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, - self.qk_head_dim) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) else: - q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, - self.qk_head_dim) - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - kv_a, _ = latent_cache.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) kv_a = self.kv_a_layernorm(kv_a) kv = self.kv_b_proj(kv_a)[0] - kv = kv.view(-1, self.num_local_heads, - self.qk_nope_head_dim + self.v_head_dim) + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) - k[..., :self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe # padding value to qk_head_dim for alignment v = torch.nn.functional.pad( - v, [0, self.qk_head_dim - self.v_head_dim], - value=0).view(-1, self.num_local_heads * self.qk_head_dim) + v, [0, 
self.qk_head_dim - self.v_head_dim], value=0 + ).view(-1, self.num_local_heads * self.qk_head_dim) attn_output = self.attn(q, k, v) - attn_output = attn_output.view( - -1, self.num_local_heads, - self.qk_head_dim)[..., :self.v_head_dim].reshape( - -1, self.num_local_heads * self.v_head_dim) + attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[ + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase): - - def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str, - cache_config: CacheConfig): + def __init__( + self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig + ): super().__init__() self.kv_cache = [torch.tensor([])] self.head_dim = head_dim @@ -460,7 +574,7 @@ def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str, raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: return MLAAttentionSpec( # Only has one vector instead of K + V block_size=self.cache_config.block_size, num_kv_heads=1, @@ -468,55 +582,12 @@ def get_kv_cache_spec(self) -> KVCacheSpec: dtype=self.dtype, ) - def forward(self): - ... + def forward(self): ... def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend -@torch.inference_mode() -def cp_gather_indexer_k_quant_cache( - kv_cache, # [num_blocks, block_size, head_dim + 1] - dst_value, # [cu_seq_lens[-1], head_dim] - block_table, # [batch_size, num_blocks] - cu_seq_lens, # [batch_size + 1, ] - batch_size, -): - num_blocks, block_size, _ = kv_cache.shape - head_dim = dst_value.shape[-1] - kv_cache = kv_cache.view(num_blocks, -1) - - expected_value = [] - for b in range(batch_size): - s = cu_seq_lens[b + 1] - cu_seq_lens[b] - if s == 0: - continue - tot = cdiv(s, block_size) - blocks = block_table[b, :tot] - - value = [] - full_block = torch.arange(tot - 1, - device=kv_cache.device, - dtype=torch.int32) - non_remaining_value = kv_cache[blocks[full_block], :block_size * - head_dim].view(-1, head_dim) - - remaining = s - (tot - 1) * block_size - - value = torch.cat([ - non_remaining_value, - kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim) - ], - dim=0) - - expected_value.append(value) - - gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim) - gather_value = gather_value.view(torch.bfloat16) - dst_value.copy_(gather_value) - - def sparse_attn_indexer( hidden_states: torch.Tensor, k_cache_prefix: str, @@ -525,14 +596,13 @@ def sparse_attn_indexer( k_bf16: torch.Tensor, weights: torch.Tensor, quant_block_size: int, - scale_fmt: Optional[str], + scale_fmt: str | None, topk_tokens: int, head_dim: int, max_model_len: int, total_seq_lens: int, - topk_indices_buffer: Optional[torch.Tensor], + topk_indices_buffer: torch.Tensor | None, ) -> torch.Tensor: - # careful! 
this will be None in dummy run attn_metadata = get_forward_context().attn_metadata # assert isinstance(attn_metadata, dict) @@ -567,42 +637,42 @@ def sparse_attn_indexer( scale_fmt, ) - topk_indices_buffer[:hidden_states.shape[0]] = -1 + topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: prefill_metadata = attn_metadata.prefill for chunk in prefill_metadata.chunks: - _k_bf16 = torch.empty([chunk.total_seq_lens, head_dim], - device=k_bf16.device, - dtype=torch.bfloat16) - cp_gather_indexer_k_quant_cache( + _k_bf16 = torch.empty( + [chunk.total_seq_lens, head_dim], + device=k_bf16.device, + dtype=torch.bfloat16, + ) + k_scale = None + mx_ops.cp_gather_indexer_k_quant_cache( kv_cache, _k_bf16, + k_scale, chunk.block_table, chunk.cu_seq_lens, - chunk.num_reqs, ) logits = bf16_mqa_logits( - q_bf16[chunk.token_start:chunk.token_end], + q_bf16[chunk.token_start : chunk.token_end], _k_bf16, - weights[chunk.token_start:chunk.token_end], + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + ) + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + mx_ops.top_k_per_row( + logits, chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, + topk_indices, + num_rows, ) - topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), - dim=-1)[1] - topk_indices -= chunk.cu_seqlen_ks[:, None] - mask_lo = topk_indices >= 0 - mask_hi = topk_indices - (chunk.cu_seqlen_ke - - chunk.cu_seqlen_ks)[:, None] < 0 - mask = torch.full_like(topk_indices, - False, - dtype=torch.bool, - device=topk_indices.device) - mask = mask_lo & mask_hi - topk_indices = topk_indices.masked_fill(~mask, -1) - topk_indices_buffer[ - chunk.token_start:chunk.token_end, :topk_indices. 
- shape[-1]] = topk_indices.to(dtype=torch.int32) if has_decode: decode_metadata = attn_metadata.decode @@ -616,10 +686,12 @@ def sparse_attn_indexer( # prefill and decode by decode_threshold # (currently set to 1 + speculative tokens) padded_q_bf16_decode_tokens = pack_seq_triton( - q_bf16[:num_decode_tokens], decode_lens) + q_bf16[:num_decode_tokens], decode_lens + ) else: padded_q_bf16_decode_tokens = q_bf16[:num_decode_tokens].reshape( - decode_lens.shape[0], -1, *q_bf16.shape[1:]) + decode_lens.shape[0], -1, *q_bf16.shape[1:] + ) # TODO: move and optimize below logic with triton kernels batch_size = padded_q_bf16_decode_tokens.shape[0] next_n = padded_q_bf16_decode_tokens.shape[1] @@ -632,39 +704,29 @@ def sparse_attn_indexer( decode_metadata.seq_lens, decode_metadata.block_table, decode_metadata.schedule_metadata, - max_context_len=max_model_len, + max_model_len, + ) + num_rows = logits.shape[0] + assert topk_tokens == 2048, "top_k_per_row assumes size 2048" + topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] + + mx_ops.top_k_per_row_decode( + logits, + next_n, + decode_metadata.seq_lens, + topk_indices, + num_rows, ) - # padded query len - current_device = padded_q_bf16_decode_tokens.device - padded_num_tokens = batch_size * next_n - positions = torch.arange(max_model_len, - device=current_device).unsqueeze(0).expand( - batch_size * next_n, -1) - row_indices = torch.arange(padded_num_tokens, - device=current_device) // next_n - next_n_offset = torch.arange( - padded_num_tokens, - device=padded_q_bf16_decode_tokens.device) % next_n - index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n + - next_n_offset).unsqueeze(1) - # index_end_pos: [B * N, 1] - mask = positions <= index_end_pos - # mask: [B * N, L] - logits = logits.masked_fill(~mask, float('-inf')) - topk_indices = logits.topk(topk_tokens, - dim=-1)[1].to(torch.int32) # [B * N, K] - # ensure we don't set indices for the top k - # that is out of range(masked already) - # this will happen if context length is shorter than K - topk_indices[topk_indices > index_end_pos] = -1 if decode_metadata.requires_padding: # if padded, we need to unpack # the topk indices removing padded tokens topk_indices = unpack_seq_triton( topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), - decode_lens) - topk_indices_buffer[:num_decode_tokens, :topk_indices. - shape[-1]] = topk_indices.to(dtype=torch.int32) + decode_lens, + ) + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices.to(dtype=torch.int32) + ) return topk_indices_buffer @@ -677,19 +739,19 @@ def sparse_attn_indexer_fake( k: torch.Tensor, weights: torch.Tensor, quant_block_size: int, - scale_fmt: Optional[str], + scale_fmt: str | None, topk_tokens: int, head_dim: int, max_model_len: int, total_seq_lens: int, - topk_indices_buffer: Optional[torch.Tensor], + topk_indices_buffer: torch.Tensor | None, ) -> torch.Tensor: # profile run # NOTE(Chen): create the max possible flattened_kv. So that # profile_run can get correct memory usage. 
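The fake implementation being defined here exists only for profiling and graph capture: it allocates the worst-case intermediates so the profile run observes peak memory, but it never executes the real kernel. vLLM's direct_register_custom_op (imported at the top of this file) is a thin wrapper over torch.library; a self-contained sketch of the real-op/fake-op pairing using plain torch.library, with a made-up op name and shapes (the _flattened_kv allocation just below is the concrete instance of this idea):

import torch


@torch.library.custom_op("demo::sparse_indexer", mutates_args=())
def sparse_indexer(q: torch.Tensor, total_seq_lens: int) -> torch.Tensor:
    # Real path: would run the indexing kernels and fill the top-k buffer.
    return torch.full((q.shape[0], 2048), -1, dtype=torch.int32, device=q.device)


@sparse_indexer.register_fake
def _(q: torch.Tensor, total_seq_lens: int) -> torch.Tensor:
    # Fake path: allocate the maximum possible scratch (deliberately unused)
    # so that profile_run / torch.compile see worst-case memory and shapes.
    _scratch = torch.empty(
        total_seq_lens, 128, dtype=torch.bfloat16, device=q.device
    )
    return torch.empty(q.shape[0], 2048, dtype=torch.int32, device=q.device)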
- _flattened_kv = torch.empty([total_seq_lens, head_dim], - device=k.device, - dtype=torch.bfloat16) + _flattened_kv = torch.empty( + [total_seq_lens, head_dim], device=k.device, dtype=torch.bfloat16 + ) _k = _flattened_kv[..., :head_dim].view(torch.bfloat16).contiguous() _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() return topk_indices_buffer @@ -705,16 +767,17 @@ def sparse_attn_indexer_fake( class Indexer(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - config: Union[DeepseekV2Config, DeepseekV3Config], - hidden_size: int, - q_lora_rank: int, - quant_config: Optional[QuantizationConfig], - cache_config: Optional[CacheConfig], - topk_indices_buffer: Optional[torch.Tensor], - prefix: str = ""): + def __init__( + self, + vllm_config: VllmConfig, + config: DeepseekV2Config | DeepseekV3Config, + hidden_size: int, + q_lora_rank: int, + quant_config: QuantizationConfig | None, + cache_config: CacheConfig | None, + topk_indices_buffer: torch.Tensor | None, + prefix: str = "", + ): super().__init__() self.vllm_config = vllm_config self.config = config @@ -725,22 +788,28 @@ def __init__(self, self.rope_dim = config.qk_rope_head_dim # 64 self.q_lora_rank = q_lora_rank # 1536 # no tensor parallel, just replicated - self.wq_b = ReplicatedLinear(self.q_lora_rank, - self.head_dim * self.n_head, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wq_b") - self.wk = ReplicatedLinear(hidden_size, - self.head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wk") + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.head_dim * self.n_head, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b", + ) + self.wk = ReplicatedLinear( + hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wk", + ) self.k_norm = LayerNorm(self.head_dim, eps=1e-6) - self.weights_proj = ReplicatedLinear(hidden_size, - self.n_head, - bias=False, - quant_config=None, - prefix=f"{prefix}.weights_proj") + self.weights_proj = ReplicatedLinear( + hidden_size, + self.n_head, + bias=False, + quant_config=None, + prefix=f"{prefix}.weights_proj", + ) self.softmax_scale = self.head_dim**-0.5 # MetaX use bfloat16 @@ -751,27 +820,34 @@ def __init__(self, # NOTE: (zyongye) we use fp8 naive cache, # where we store value in fp8 and scale in fp32 # per self.quant_block_size element - self.k_cache = DeepseekV32IndexerCache(head_dim=self.head_dim, - dtype=torch.bfloat16, - prefix=f"{prefix}.k_cache", - cache_config=cache_config) + self.k_cache = DeepseekV32IndexerCache( + head_dim=self.head_dim, + dtype=torch.bfloat16, + prefix=f"{prefix}.k_cache", + cache_config=cache_config, + ) self.max_model_len = vllm_config.model_config.max_model_len self.prefix = prefix from vllm_metax.v1.attention.backends.mla.indexer import ( - get_max_prefill_buffer_size) + get_max_prefill_buffer_size, + ) + self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) - def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, - rotary_emb) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb + ) -> torch.Tensor: q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) q_pe, q_nope = torch.split( - q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) k, _ = self.wk(hidden_states) k = self.k_norm(k) k_pe, k_nope = torch.split( - k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + k, 
[self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) q = torch.cat([q_pe, q_nope], dim=-1) @@ -790,12 +866,11 @@ def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, q = q.view(-1, self.n_head, self.head_dim) weights, _ = self.weights_proj(hidden_states) - weights = weights.unsqueeze( - -1) * self.softmax_scale * self.n_head**-0.5 + weights = weights.unsqueeze(-1) * self.softmax_scale * self.n_head**-0.5 weights = weights.squeeze(-1) - assert (q.dtype == torch.bfloat16) - assert (k.dtype == torch.bfloat16) + assert q.dtype == torch.bfloat16 + assert k.dtype == torch.bfloat16 return torch.ops.vllm.mx_sparse_attn_indexer( hidden_states, self.k_cache.prefix, @@ -817,7 +892,7 @@ class DeepseekV2MLAAttention(nn.Module): """ Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). - + For more info see MLACommonImpl in: vllm/v1/attention/backends/mla/utils.py """ @@ -825,21 +900,21 @@ class DeepseekV2MLAAttention(nn.Module): def __init__( self, vllm_config: VllmConfig, - config: Union[DeepseekV2Config, DeepseekV3Config], + config: DeepseekV2Config | DeepseekV3Config, hidden_size: int, num_heads: int, qk_nope_head_dim: int, qk_rope_head_dim: int, v_head_dim: int, - q_lora_rank: Optional[int], + q_lora_rank: int | None, kv_lora_rank: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", - topk_indices_buffer: Optional[torch.Tensor] = None, + topk_indices_buffer: torch.Tensor | None = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -867,53 +942,60 @@ def __init__( bias=False, quant_config=quant_config, prefix=f"{prefix}.fused_qkv_a_proj", - disable_tp=True) + disable_tp=True, + ) else: self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) if self.q_lora_rank is not None: - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(self.q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + self.q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, - 
prefix=f"{prefix}.kv_b_proj") - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + prefix=f"{prefix}.kv_b_proj", + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) + rope_scaling["rope_type"] = "deepseek_yarn" + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, + ) if rope_scaling: mscale_all_dim = rope_scaling.get("mscale_all_dim", False) scaling_factor = rope_scaling["factor"] @@ -923,9 +1005,16 @@ def __init__( self.is_v32 = hasattr(config, "index_topk") if self.is_v32: - self.indexer = Indexer(vllm_config, config, hidden_size, - q_lora_rank, quant_config, cache_config, - topk_indices_buffer, f"{prefix}.indexer") + self.indexer = Indexer( + vllm_config, + config, + hidden_size, + q_lora_rank, + quant_config, + cache_config, + topk_indices_buffer, + f"{prefix}.indexer", + ) else: self.indexer = None @@ -935,11 +1024,12 @@ def __init__( rotary_emb=self.rotary_emb, o_proj=self.o_proj, fused_qkv_a_proj=self.fused_qkv_a_proj - if self.q_lora_rank is not None else None, + if self.q_lora_rank is not None + else None, kv_a_proj_with_mqa=self.kv_a_proj_with_mqa - if self.q_lora_rank is None else None, - q_a_layernorm=self.q_a_layernorm - if self.q_lora_rank is not None else None, + if self.q_lora_rank is None + else None, + q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None, q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, q_proj=self.q_proj if self.q_lora_rank is None else None, indexer=self.indexer, @@ -947,7 +1037,7 @@ def __init__( topk_indices_buffer=topk_indices_buffer, ) - self.mla_attn = MultiHeadLatentAttention( + self.mla_attn = MultiHeadLatentAttentionWrapper( self.hidden_size, self.num_local_heads, self.scaling, @@ -971,14 +1061,17 @@ def forward( class DeepseekV2DecoderLayer(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - prefix: str, - topk_indices_buffer: Optional[torch.Tensor] = None) -> None: + def __init__( + self, + vllm_config: VllmConfig, + prefix: str, + config: DeepseekV2Config | None = None, + topk_indices_buffer: torch.Tensor | None = None, + ) -> None: super().__init__() - config = vllm_config.model_config.hf_config + if config is None: + config = vllm_config.model_config.hf_config model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config @@ -987,13 +1080,25 @@ def __init__(self, self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + moe_layer_freq = getattr(config, "moe_layer_freq", 1) # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. 
- layer_idx = int(prefix.split(sep='.')[-1]) + layer_idx = int(prefix.split(sep=".")[-1]) self.layer_idx = layer_idx - if model_config.use_mla: + + # verify MLA attention specific fields + qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) + qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) + v_head_dim = getattr(config, "v_head_dim", 0) + kv_lora_rank = getattr(config, "kv_lora_rank", 0) + use_mha = config.model_type == "deepseek" or all( + dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) + ) + + if use_mha: + attn_cls = DeepseekAttention + elif model_config.use_mla: attn_cls = DeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention @@ -1002,12 +1107,11 @@ def __init__(self, config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=config.kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=kv_lora_rank, rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -1017,9 +1121,11 @@ def __init__(self, topk_indices_buffer=topk_indices_buffer, ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % moe_layer_freq == 0 + ): self.mlp = DeepseekV2MoE( config=config, parallel_config=parallel_config, @@ -1034,60 +1140,59 @@ def __init__(self, quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.routed_scaling_factor = config.routed_scaling_factor + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: # Self Attention if residual is None: residual = hidden_states.clone() hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) - if hidden_states.dtype == torch.float16: + if ( + not isinstance(self.self_attn, DeepseekAttention) + and hidden_states.dtype == torch.float16 + ): # Fix FP16 overflow # We scale both hidden_states and residual before # rmsnorm, and rmsnorm result would not affect by scale. - hidden_states *= 1. / self.routed_scaling_factor + hidden_states *= 1.0 / self.routed_scaling_factor if self.layer_idx == 0: # The residual is shared by all layers, we only scale it on # first layer. - residual *= 1. 
/ self.routed_scaling_factor + residual *= 1.0 / self.routed_scaling_factor # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) - if isinstance(self.mlp, - DeepseekV2MLP) and hidden_states.dtype == torch.float16: + if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: # Fix FP16 overflow # Scaling the DeepseekV2MLP output, it is the input of # input_layernorm of next decoder layer. # The scaling of DeepseekV2MOE output would be done in the forward # of DeepseekV2MOE - hidden_states *= 1. / self.routed_scaling_factor + hidden_states *= 1.0 / self.routed_scaling_factor return hidden_states, residual @support_torch_compile class DeepseekV2Model(nn.Module): - fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -1096,6 +1201,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config + self.device = current_platform.device_type self.vocab_size = config.vocab_size self.is_v32 = hasattr(config, "index_topk") @@ -1105,7 +1211,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.scheduler_config.max_num_batched_tokens, topk_tokens, dtype=torch.int32, - device="cuda") + device=self.device, + ) else: topk_indices_buffer = None @@ -1114,23 +1221,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) else: self.embed_tokens = PPMissingLayer() - self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, - topk_indices_buffer), - prefix=f"{prefix}.layers") + lambda prefix: DeepseekV2DecoderLayer( + vllm_config, prefix, topk_indices_buffer=topk_indices_buffer + ), + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -1139,9 +1248,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -1157,17 +1266,58 @@ def forward( hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class 
DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, - SupportsLoRA): +class DeepseekV2MixtureOfExperts(MixtureOfExperts): + moe_mlp_layers: list[DeepseekV2MoE] + """ + List of MoE MLP layers in the model. + """ + + def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None): + if example_moe is None: + self.num_moe_layers = 0 + self.num_expert_groups = 0 + self.num_logical_experts = 0 + self.num_physical_experts = 0 + self.num_local_physical_experts = 0 + self.num_routed_experts = 0 + self.num_shared_experts = 0 + self.num_redundant_experts = 0 + logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.") + else: + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for moe in self.moe_mlp_layers: + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + +class DeepseekV2ForCausalLM( + nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA +): packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], } @@ -1179,20 +1329,31 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config + qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) + qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) + self.use_mha = config.model_type == "deepseek" or all( + dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) + ) + + if self.use_mha: + self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] + # `packed_modules_mapping` needs to be modified before # initializing DeepseekV2Model, as it is passed inplace to # quantization config init and may be used to select the # quant_method for relevant layers during initialization. 
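Concretely, once the conditionals around this point have run, the mapping tells the loader which checkpoint shards feed each fused parameter; it ends up roughly as below (which entries are present depends on use_mha and on whether q_lora_rank is set):

# Sketch of the resulting mapping (entries are conditional):
packed_modules_mapping = {
    # always present
    "gate_up_proj": ["gate_proj", "up_proj"],
    # only on the MHA path (DeepSeek v1-style checkpoints)
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    # only when q_lora_rank is not None (fused q/kv down-projection for MLA)
    "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
}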
- self.fuse_qkv_a_proj = hasattr( - config, "q_lora_rank") and config.q_lora_rank is not None + self.fuse_qkv_a_proj = ( + hasattr(config, "q_lora_rank") and config.q_lora_rank is not None + ) if self.fuse_qkv_a_proj: self.packed_modules_mapping["fused_qkv_a_proj"] = [ "q_a_proj", "kv_a_proj_with_mqa", ] - self.model = DeepseekV2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = DeepseekV2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( config.vocab_size, @@ -1204,15 +1365,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) + # Set MoE hyperparameters + self.num_moe_layers = ( + self.config.num_hidden_layers - self.config.first_k_dense_replace + ) + self.set_moe_parameters() + + def set_moe_parameters(self): self.expert_weights = [] - # Set MoE hyperparameters - self.num_moe_layers = (config.num_hidden_layers - - config.first_k_dense_replace) - self.num_expert_groups = config.n_group + self.num_expert_groups = getattr(self.config, "n_group", 1) - self.moe_layers: list[FusedMoE] = [] + self.moe_layers = [] + self.moe_mlp_layers = [] example_moe = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): @@ -1222,51 +1389,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if isinstance(layer.mlp, DeepseekV2MoE): # Pick last one layer since the first ones may be dense layers. example_moe = layer.mlp + self.moe_mlp_layers.append(layer.mlp) self.moe_layers.append(layer.mlp.experts) - if example_moe is None: - raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") - - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - self.num_shared_experts = example_moe.n_shared_experts - self.num_redundant_experts = example_moe.n_redundant_experts - - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. 
- self.expert_weights.append(layer.get_expert_weights()) - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - - def update_physical_experts_metadata( - self, - num_physical_experts: int, - num_local_physical_experts: int, - ) -> None: - assert self.num_local_physical_experts == num_local_physical_experts - self.num_physical_experts = num_physical_experts - self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = (num_physical_experts - - self.num_logical_experts) - for layer in self.model.layers: - if isinstance(layer.mlp, DeepseekV2MoE): - moe = layer.mlp - moe.n_local_physical_experts = num_local_physical_experts - moe.n_physical_experts = num_physical_experts - moe.n_redundant_experts = self.num_redundant_experts - moe.experts.update_expert_map() + self.extract_moe_parameters(example_moe) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -1275,38 +1401,69 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + num_redundant_experts=0, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + rocm_aiter_moe_shared_expert_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), + ] + mla_params_mapping = [ ("fused_qkv_a_proj", "q_a_proj", 0), ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ] + mha_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + if self.use_mha: + stacked_params_mapping.extend(mha_params_mapping) + else: + stacked_params_mapping.extend(mla_params_mapping) # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( + expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts, - num_redundant_experts=self.num_redundant_experts) + num_experts=self.config.n_routed_experts + 
+ ( + self.config.n_shared_experts + if rocm_aiter_moe_shared_expert_enabled + else 0 + ), + num_redundant_experts=self.num_redundant_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -1318,7 +1475,11 @@ def load_weights(self, weights: Iterable[tuple[str, if spec_layer is not None: continue # skip spec decode layers for main model - for (param_name, weight_name, shard_id) in stacked_params_mapping: + is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and ( + "mlp.shared_experts" in name + ) + + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -1328,15 +1489,18 @@ def load_weights(self, weights: Iterable[tuple[str, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: + continue + if is_fuse_shared_experts_layer: continue name_mapped = name.replace(weight_name, param_name) # QKV fusion is optional, fall back to normal # weight loading if it's not enabled # if go with fusion option, then update name - if ((param_name == "fused_qkv_a_proj") - and name_mapped not in params_dict): + if ( + param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: continue else: name = name_mapped @@ -1353,78 +1517,138 @@ def load_weights(self, weights: Iterable[tuple[str, break else: is_expert_weight = False - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - - # Anyway, this is an expert weight and should not be - # attempted to load as other weights later - is_expert_weight = True - - # Do not modify `name` since the loop may continue here - # Instead, create a new variable - name_mapped = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name_mapped, self): - continue - - param = params_dict[name_mapped] - # We should ask the weight loader to return success or not - # here since otherwise we may skip experts with other - # available replicas. - weight_loader = typing.cast(Callable[..., bool], - param.weight_loader) - success = weight_loader(param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True) - if success: - name = name_mapped - break - else: - if is_expert_weight: - # We've checked that this is an expert weight - # However it's not mapped locally to this rank - # So we simply skip it - continue - - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) + + # Special handling: when AITER fusion_shared_experts is enabled, + # checkpoints may provide a single widened shared_experts tensor + # without explicit expert indices + # (e.g. ...mlp.shared_experts.gate_proj.weight). 
+ # For models with multiple shared experts, split that tensor + # evenly into per-shared-expert slices and load them into + # appended expert slots mlp.experts.{n_routed_experts + j}.* + # accordingly. + num_chunks = 1 + if is_fuse_shared_experts_layer: + num_chunks = getattr(self.config, "n_shared_experts", 1) or 1 + # Determine split axis based on op type + # gate/up: ColumnParallel → split along dim 0 + # down: RowParallel → split along dim 1 + split_dim = 1 if "down_proj.weight" in name else 0 + total = loaded_weight.shape[split_dim] + assert total % num_chunks == 0, ( + f"Shared expert weight dim {total} " + f"not divisible by num_chunks {num_chunks}" + ) + chunk_size = total // num_chunks + + for j in range(num_chunks): + chunk_name = name + weight_to_load = loaded_weight + + if is_fuse_shared_experts_layer: + if split_dim == 0: + weight_to_load = loaded_weight[ + j * chunk_size : (j + 1) * chunk_size, : + ] + else: + weight_to_load = loaded_weight[ + :, j * chunk_size : (j + 1) * chunk_size + ] + # Synthesize an expert-style name so expert mapping + # can route it + chunk_name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts + j}", + ) + + # Use expert_params_mapping to locate the destination + # param and delegate to its expert-aware weight_loader + # with expert_id. + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in chunk_name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = chunk_name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or + # not here since otherwise we may skip experts with + # other available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + weight_to_load, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + if not is_fuse_shared_experts_layer: + name = name_mapped + else: + loaded_params.add(name_mapped) + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. 
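To make the shared-expert splitting above concrete: the widened shared_experts tensor is sliced along the output dimension for gate/up projections (ColumnParallel) and along the input dimension for down projections (RowParallel), and slice j is then loaded as expert index n_routed_experts + j. A small stand-alone sketch with illustrative sizes:

import torch

# Illustrative sizes.
n_shared_experts, intermediate_size, hidden_size = 2, 4, 8

# Widened checkpoint tensors covering all shared experts at once.
gate_proj = torch.randn(n_shared_experts * intermediate_size, hidden_size)
down_proj = torch.randn(hidden_size, n_shared_experts * intermediate_size)

# gate/up: ColumnParallel -> split along dim 0; down: RowParallel -> dim 1.
gate_slices = gate_proj.chunk(n_shared_experts, dim=0)
down_slices = down_proj.chunk(n_shared_experts, dim=1)

assert gate_slices[0].shape == (intermediate_size, hidden_size)
assert down_slices[1].shape == (hidden_size, intermediate_size)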
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + if not is_fuse_shared_experts_layer: + loaded_params.add(name) return loaded_params +class DeepseekForCausalLM(DeepseekV2ForCausalLM): + pass + + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass # Compatibility with # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py -def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config, - DeepseekV3Config], - weight_name: str) -> Optional[int]: - if (hasattr(config, "num_nextn_predict_layers") - and config.num_nextn_predict_layers > 0): +def get_spec_layer_idx_from_weight_name( + config: DeepseekV2Config | DeepseekV3Config, weight_name: str +) -> int | None: + if ( + hasattr(config, "num_nextn_predict_layers") + and config.num_nextn_predict_layers > 0 + ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): - if weight_name.startswith(f"model.layers.{layer_idx+i}."): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): return layer_idx + i return None diff --git a/vllm_metax/models/qwen2_5_vl.py b/vllm_metax/models/qwen2_5_vl.py index aa9704cdf..99d1b0472 100644 --- a/vllm_metax/models/qwen2_5_vl.py +++ b/vllm_metax/models/qwen2_5_vl.py @@ -25,65 +25,93 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" -from collections.abc import Iterable, Mapping, Sequence + +import math +from collections.abc import Callable, Iterable, Mapping, Sequence from functools import lru_cache, partial -from typing import Annotated, Any, Callable, Literal, Optional, Union +from typing import Annotated, Any, Literal, TypeAlias +import einops import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange -from transformers import BatchFeature +from transformers import BatchFeature, PretrainedConfig from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( - Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) - -from vllm.attention.layer import check_upstream_fa_availability + Qwen2_5_VLConfig, + Qwen2_5_VLVisionConfig, +) + +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import maybe_get_vit_flash_attn_backend +from vllm.attention.ops.vit_attn_wrappers import ( + vit_flash_attn_wrapper, + vit_torch_sdpa_wrapper, + vit_xformers_attn_wrapper, +) +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm -# yapf: disable -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -# yapf: enable +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import 
QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.vision import should_torch_compile_mm_vit from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.evs import (compute_mrope_for_media, - compute_retained_tokens_count, - compute_retention_mask, - recompute_mrope_positions) +from vllm.multimodal.evs import ( + compute_mrope_for_media, + compute_retained_tokens_count, + compute_retention_mask, + recompute_mrope_positions, +) from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import PromptReplacement, PromptUpdate -from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.config import uses_mrope -from vllm.utils import is_pin_memory_available +from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.tensor_schema import TensorSchema, TensorShape -from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, - SupportsLoRA, - SupportsMultiModal, - SupportsMultiModalPruning, - SupportsPP, SupportsQuant) -from vllm.model_executor.models.qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder -from vllm.model_executor.models.qwen2_vl import (Qwen2VLMultiModalProcessor, - Qwen2VLProcessingInfo, - apply_rotary_pos_emb_vision) -from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, - cast_overflow_tensors, - init_vllm_registered_model, - maybe_prefix, - merge_multimodal_embeddings) - -from vllm.model_executor.models.vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsEagle3, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsMultiModalPruning, + SupportsPP, + SupportsQuant, +) +from vllm.model_executor.models.qwen2_vl import ( + Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder, +) +from vllm.model_executor.models.qwen2_vl import ( + Qwen2VLMultiModalProcessor, + Qwen2VLProcessingInfo, + apply_rotary_pos_emb_vision, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + cast_overflow_tensors, + init_vllm_registered_model, + maybe_prefix, +) +from vllm.model_executor.models.vision import ( + conv3d_to_linear_weight, + get_vit_attn_backend, + run_dp_sharded_mrope_vision_model, +) logger = init_logger(__name__) @@ -101,8 +129,9 @@ class Qwen2_5_VLImagePixelInputs(TensorSchema): - pixel_values shape: (num_patches, num_channels * patch_size * patch_size) - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) - formatnum_channels * patch_size * patch_size + format. 
""" + type: Literal["pixel_values"] pixel_values: Annotated[ @@ -131,6 +160,7 @@ class Qwen2_5_VLImageEmbeddingInputs(TensorSchema): - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["image_embeds"] image_embeds: Annotated[ @@ -144,8 +174,9 @@ class Qwen2_5_VLImageEmbeddingInputs(TensorSchema): ] -Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs, - Qwen2_5_VLImageEmbeddingInputs] +Qwen2_5_VLImageInputs: TypeAlias = ( + Qwen2_5_VLImagePixelInputs | Qwen2_5_VLImageEmbeddingInputs +) class Qwen2_5_VLVideoPixelInputs(TensorSchema): @@ -165,6 +196,7 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): grid along the temporal dimension in the 3D position IDs. Returned when `videos` is not `None`. """ + type: Literal["pixel_values_videos"] pixel_values_videos: Annotated[ @@ -178,7 +210,7 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): ] second_per_grid_ts: Annotated[ - Optional[torch.Tensor], + torch.Tensor | None, TensorShape("nv"), ] @@ -198,6 +230,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["video_embeds"] video_embeds: Annotated[ @@ -211,22 +244,24 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): ] -Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs, - Qwen2_5_VLVideoEmbeddingInputs] +Qwen2_5_VLVideoInputs: TypeAlias = ( + Qwen2_5_VLVideoPixelInputs | Qwen2_5_VLVideoEmbeddingInputs +) # === Vision Encoder === # class Qwen2_5_VisionMLP(nn.Module): - - def __init__(self, - in_features: int, - hidden_features: int, - bias: bool = False, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( input_size=in_features, @@ -234,14 +269,17 @@ def __init__(self, bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", - disable_tp=use_data_parallel) - - self.down_proj = RowParallelLinear(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj", - disable_tp=use_data_parallel) + disable_tp=use_data_parallel, + ) + + self.down_proj = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, + ) self.act_fn = act_fn def forward(self, x: torch.Tensor): @@ -254,14 +292,14 @@ def forward(self, x: torch.Tensor): def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): """All-gather the input tensor interleavely across model parallel group.""" import torch.distributed as dist + gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] - dist.all_gather(gathered_tensors, - local_tensor, - group=parallel_state.get_tp_group().device_group) + dist.all_gather( + gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group + ) gathered_tensors_split = [ - torch.split(tensor, hidden_size // tp_size, -1) - for tensor in gathered_tensors + torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors ] ordered_tensors = [ tensor for pair in zip(*gathered_tensors_split) for tensor in pair 
@@ -271,27 +309,32 @@ def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): class Qwen2_5_VisionAttention(nn.Module): - def __init__( self, embed_dim: int, num_heads: int, projection_size: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, attn_backend: _Backend = _Backend.TORCH_SDPA, use_upstream_fa: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() # Per attention head and per partition values. - self.tp_size = (1 if use_data_parallel else - parallel_state.get_tensor_model_parallel_world_size()) + self.tp_size = ( + 1 + if use_data_parallel + else parallel_state.get_tensor_model_parallel_world_size() + ) self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) + projection_size, num_heads + ) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, self.tp_size) + num_heads, self.tp_size + ) self.qkv = QKVParallelLinear( hidden_size=embed_dim, @@ -301,50 +344,66 @@ def __init__( bias=True, quant_config=quant_config, prefix=f"{prefix}.qkv", - disable_tp=use_data_parallel) + disable_tp=use_data_parallel, + ) - self.proj = RowParallelLinear(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj", - disable_tp=use_data_parallel) + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) self.attn_backend = attn_backend self.use_upstream_fa = use_upstream_fa + + # /--------------- Metax Modification ---------------\ + self.use_upstream_fa = True + from flash_attn import flash_attn_varlen_func + + self.flash_attn_varlen_func = flash_attn_varlen_func + # \--------------- Metax Modification ---------------/ + self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape if self.tp_size > 1: - qkv = all_gather_interleave(qkv, self.qkv.hidden_size, - self.tp_size) + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, self.tp_size) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) # 3 * [s, b, head * head_dim] if self.tp_size > 1: - splitter = partial(dist_utils.split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial( + dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size + ) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] v = splitter(v)[self.tp_rank] # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] - new_shape = (seq_len, bs, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_shape = ( + seq_len, + bs, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) q, k, v = (x.view(*new_shape) for x in (q, k, v)) return q, k, v def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: torch.Tensor, # Only used for 
Flash Attention + seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -353,8 +412,7 @@ def forward( q, k, v = self.split_qkv(x) batch_size = q.shape[1] - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() - for x in (q, k, v)) + q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v)) if rotary_pos_emb is not None: # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) @@ -362,79 +420,63 @@ def forward( q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - raise AssertionError( - "ROCM AITER Flash Attention is not supported on maca.") - else: - # metax modification - from flash_attn import flash_attn_varlen_func - - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) - - context_layer = rearrange(output, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() + context_layer = vit_flash_attn_wrapper( + q, + k, + v, + cu_seqlens, + max_seqlen, + batch_size, + self.attn_backend == _Backend.ROCM_AITER_FA, + self.use_upstream_fa, + ) elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. - outputs = [] - for i in range(1, len(cu_seqlens)): - start_idx = cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - q_i = q[:, start_idx:end_idx] - k_i = k[:, start_idx:end_idx] - v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) - output_i = rearrange(output_i, "b h s d -> b s h d ") - outputs.append(output_i) - context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + from vllm.platforms import current_platform + + # Never remove the next contiguous logic + # Without it, hallucinations occur with the backend + if current_platform.is_rocm(): + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + context_layer = vit_torch_sdpa_wrapper( + q, + k, + v, + cu_seqlens, + ) elif self.attn_backend == _Backend.XFORMERS: - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask - - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) - - context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) output, _ = self.proj(context_layer) return output +@support_torch_compile( + dynamic_arg_dims={ + "x": 0, + "cu_seqlens": 0, + "rotary_pos_emb": 0, + "seqlens": 0, + }, + mark_unbacked_dims={"seqlens": 0}, + enable_if=should_torch_compile_mm_vit, +) class Qwen2_5_VisionBlock(nn.Module): - def __init__( self, dim: int, num_heads: int, mlp_hidden_dim: int, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, + norm_layer: Callable[[int], nn.Module] | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, attn_backend: _Backend = 
_Backend.TORCH_SDPA, use_upstream_fa: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -449,35 +491,46 @@ def __init__( prefix=f"{prefix}.attn", use_data_parallel=use_data_parallel, attn_backend=attn_backend, - use_upstream_fa=use_upstream_fa) - self.mlp = Qwen2_5_VisionMLP(dim, - mlp_hidden_dim, - act_fn=act_fn, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) + use_upstream_fa=use_upstream_fa, + attn_backend_override=attn_backend_override, + ) + self.mlp = Qwen2_5_VisionMLP( + dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: torch.Tensor, # Only used for Flash Attention + seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: - x_attn = self.attn(self.norm1(x), - cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, - max_seqlen=max_seqlen, - seqlens=seqlens) + x_attn = self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) x_fused_norm, residual = self.norm2(x, residual=x_attn) x = residual + self.mlp(x_fused_norm) return x +@support_torch_compile( + dynamic_arg_dims={ + "x": 0, + }, + enable_if=should_torch_compile_mm_vit, +) class Qwen2_5_VisionPatchEmbed(nn.Module): - def __init__( self, patch_size: int = 14, @@ -491,29 +544,32 @@ def __init__( self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d(in_channels, - hidden_size, - kernel_size=kernel_size, - stride=kernel_size, - bias=False) + self.proj = ReplicatedLinear( + in_channels * math.prod(kernel_size), + hidden_size, + bias=False, + return_bias=False, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: - L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, - self.patch_size) - x = self.proj(x).view(L, self.hidden_size) + x = self.proj(x) return x +@support_torch_compile( + dynamic_arg_dims={ + "x": 0, + }, + enable_if=should_torch_compile_mm_vit, +) class Qwen2_5_VisionPatchMerger(nn.Module): - def __init__( self, d_model: int, context_dim: int, - norm_layer: Optional[Callable[[int], nn.Module]] = None, + norm_layer: Callable[[int], nn.Module] | None = None, spatial_merge_size: int = 2, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ) -> None: @@ -553,13 +609,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen2_5_VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta - inv_freq = 1.0 / (theta**( - torch.arange(0, dim, 2, dtype=torch.float, device='cpu') / dim)) + inv_freq = 1.0 / ( + theta ** (torch.arange(0, dim, 2, dtype=torch.float, device="cpu") / dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -568,12 +624,18 @@ def update_freqs_cache(self, seqlen: int) -> None: if seqlen > self._seq_len_cached: seqlen 
*= 2 self._seq_len_cached = seqlen - self.inv_freq = 1.0 / (self.theta**(torch.arange( - 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) - / self.dim)) - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + self.inv_freq = 1.0 / ( + self.theta + ** ( + torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device + ) + / self.dim + ) + ) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) self._freqs_cached = freqs @@ -583,14 +645,14 @@ def forward(self, seqlen: int) -> torch.Tensor: class Qwen2_5_VisionTransformer(nn.Module): - def __init__( self, vision_config: Qwen2_5_VLVisionConfig, norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() @@ -609,13 +671,18 @@ def __init__( self.spatial_merge_size = vision_config.spatial_merge_size self.fullatt_block_indexes = vision_config.fullatt_block_indexes self.spatial_merge_unit = self.spatial_merge_size**2 - - self.patch_embed = Qwen2_5_VisionPatchEmbed( - patch_size=patch_size, - temporal_patch_size=temporal_patch_size, - in_channels=in_channels, - hidden_size=self.hidden_size, - ) + # TODO[@lucaskabela]: Investigate fixing this usage + # see https://github.com/vllm-project/vllm/issues/27044 + # DO NOT MOVE THIS IMPORT + from vllm.compilation.backends import set_model_tag + + with set_model_tag("Qwen2_5_VisionPatchEmbed"): + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + hidden_size=self.hidden_size, + ) norm_layer = partial(RMSNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads @@ -623,43 +690,65 @@ def __init__( use_upstream_fa = False self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype()) - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability( - torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - use_upstream_fa = True + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, + ) + + # /--------------- Metax Modification ---------------\ + use_upstream_fa = True + from flash_attn import flash_attn_varlen_func + + self.flash_attn_varlen_func = flash_attn_varlen_func + # self.attn_backend, self.flash_attn_varlen_func = ( + # maybe_get_vit_flash_attn_backend( + # self.attn_backend, + # use_upstream_fa, + # attn_backend_override=attn_backend_override, + # ) + # ) + # \--------------------------------------------------/ if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen2.5-VL does not support {self.attn_backend} backend now." 
) - self.blocks = nn.ModuleList([ - Qwen2_5_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=get_act_and_mul_fn(vision_config.hidden_act), + with set_model_tag("Qwen2_5_VisionBlock"): + self.blocks = nn.ModuleList( + [ + Qwen2_5_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=get_act_and_mul_fn(vision_config.hidden_act), + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend=self.attn_backend, + use_upstream_fa=use_upstream_fa, + attn_backend_override=attn_backend_override, + ) + for layer_idx in range(depth) + ] + ) + + with set_model_tag("Qwen2_5_VisionPatchMerger"): + self.merger = Qwen2_5_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", + prefix=f"{prefix}.merger", use_data_parallel=use_data_parallel, - attn_backend=self.attn_backend, - use_upstream_fa=use_upstream_fa) for layer_idx in range(depth) - ]) - self.merger = Qwen2_5_VisionPatchMerger( - d_model=vision_config.out_hidden_size, - context_dim=self.hidden_size, - norm_layer=norm_layer, - spatial_merge_size=self.spatial_merge_size, - quant_config=quant_config, - prefix=f"{prefix}.merger", - use_data_parallel=use_data_parallel, - ) + ) @property def dtype(self) -> torch.dtype: @@ -672,48 +761,66 @@ def device(self) -> torch.device: def rotary_pos_emb_thw(self, t, h, w): hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() + hpos_ids = ( + hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + wpos_ids = ( + wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) max_size = max(h, w) rotary_pos_emb_full = self.rotary_pos_emb(max_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb.reshape( rotary_pos_emb.shape[0] // self.spatial_merge_unit, - self.spatial_merge_unit, -1) + self.spatial_merge_unit, + -1, + ) return rotary_pos_emb def get_window_index_thw(self, grid_t, grid_h, grid_w): - vit_merger_window_size = (self.window_size // - self.spatial_merge_size // self.patch_size) + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) llm_grid_h = grid_h // self.spatial_merge_size llm_grid_w = grid_w // self.spatial_merge_size index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w) + grid_t, llm_grid_h, llm_grid_w + ) pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size num_windows_h = 
(llm_grid_h + pad_h) // vit_merger_window_size num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) - index_padded = index_padded.reshape(grid_t, num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size) + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, num_windows_h * num_windows_w, vit_merger_window_size, - vit_merger_window_size) + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) index_padded = index_padded.reshape(-1) index_new = index_padded[index_padded != -100] @@ -725,35 +832,37 @@ def get_window_index_thw(self, grid_t, grid_h, grid_w): @lru_cache(maxsize=1024) # noqa: B019 def get_rope_by_thw(self, t, h, w): - window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw( - t, h, w) + window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(t, h, w) rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w) rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :] rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1) cu_seqlens_thw = torch.repeat_interleave( - torch.tensor([h * w], dtype=torch.int32), t) - return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw, - cu_seqlens_thw) + torch.tensor([h * w], dtype=torch.int32), t + ) + return ( + rotary_pos_emb_thw, + window_index_thw, + cu_seqlens_window_thw, + cu_seqlens_thw, + ) def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, - ) -> tuple[Optional[int], Optional[list[int]]]: - max_seqlen, seqlens = None, None - if (self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA): - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + ) -> tuple[torch.Tensor, torch.Tensor]: + max_seqlen = torch.zeros([], device=cu_seqlens.device) + seqlens = torch.zeros(1, device=cu_seqlens.device) + if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() elif self.attn_backend == _Backend.XFORMERS: - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens @staticmethod def invert_permutation(perm: torch.Tensor) -> torch.Tensor: # building the inverse permutation in O(n) time inv = torch.empty_like(perm, pin_memory=is_pin_memory_available()) - inv[perm] = torch.arange(perm.numel(), - device=perm.device, - dtype=perm.dtype) + inv[perm] = torch.arange(perm.numel(), device=perm.device, dtype=perm.dtype) return inv def forward( @@ -786,10 +895,9 @@ def forward( ) = self.get_rope_by_thw(t, h, w) window_index.append(window_index_thw + window_index_id) - window_index_id += (t * llm_h * llm_w) + window_index_id += t * llm_h * llm_w - cu_seqlens_window_thw = (cu_seqlens_window_thw + - cu_window_seqlens_last) + cu_seqlens_window_thw = cu_seqlens_window_thw + cu_window_seqlens_last cu_window_seqlens_last = cu_seqlens_window_thw[-1] cu_window_seqlens.append(cu_seqlens_window_thw) @@ -809,23 +917,22 @@ def forward( # transformers # pre-compute seqlens for window/full attn to reduce cuMemcpy operations - max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen( - cu_seqlens) + max_seqlen_full, 
seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens) max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen( - cu_window_seqlens) + cu_window_seqlens + ) cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True) - cu_window_seqlens = cu_window_seqlens.to(device=self.device, - non_blocking=True) - rotary_pos_emb = rotary_pos_emb.to(device=self.device, - non_blocking=True) - window_index = window_index.to(device=hidden_states.device, - non_blocking=True) - reverse_indices = reverse_indices.to(device=hidden_states.device, - non_blocking=True) + cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True) + rotary_pos_emb = rotary_pos_emb.to(device=self.device, non_blocking=True) + window_index = window_index.to(device=hidden_states.device, non_blocking=True) + reverse_indices = reverse_indices.to( + device=hidden_states.device, non_blocking=True + ) hidden_states = hidden_states.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) @@ -859,8 +966,7 @@ def forward( hidden_states = hidden_states[reverse_indices, :] return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("attn.qkv.", "attn.q.", "q"), @@ -873,7 +979,10 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) + + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -883,15 +992,13 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Qwen2_5_VLConfig) @@ -904,7 +1011,6 @@ def get_hf_processor(self, **kwargs: object) -> Qwen2_5_VLProcessor: class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): - def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -922,8 +1028,7 @@ def _get_prompt_updates( out_mm_kwargs: MultiModalKwargs, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -942,13 +1047,19 @@ def get_replacement_qwen2vl(item_idx: int, modality: str): num_tokens = int(grid_thw.prod()) // merge_length # EVS-specific code - video_pruning_rate = self.info.ctx.get_mm_config( - ).video_pruning_rate - if (modality == "video" and video_pruning_rate is not None - and video_pruning_rate > 0.0): + video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate + if ( + modality == "video" + and video_pruning_rate is not None 
+ and video_pruning_rate > 0.0 + ): + T, H, W = map(int, grid_thw) + tokens_per_frame = (H // image_processor.merge_size) * ( + W // image_processor.merge_size + ) num_tokens = compute_retained_tokens_count( - grid_thw, - image_processor.merge_size, + tokens_per_frame, + T, video_pruning_rate, ) # End of EVS-specific code @@ -959,20 +1070,29 @@ def get_replacement_qwen2vl(item_idx: int, modality: str): PromptReplacement( modality=modality, target=[placeholder[modality]], - replacement=partial(get_replacement_qwen2vl, - modality=modality), - ) for modality in ("image", "video") + replacement=partial(get_replacement_qwen2vl, modality=modality), + ) + for modality in ("image", "video") ] @MULTIMODAL_REGISTRY.register_processor( Qwen2_5_VLMultiModalProcessor, info=Qwen2_5_VLProcessingInfo, - dummy_inputs=Qwen2_5_VLDummyInputsBuilder) -class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP, - SupportsQuant, - SupportsMultiModalPruning): + dummy_inputs=Qwen2_5_VLDummyInputsBuilder, +) +class Qwen2_5_VLForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsLoRA, + SupportsPP, + SupportsQuant, + SupportsEagle3, + SupportsMultiModalPruning, + SupportsMRoPE, +): + merge_by_field_config = True + multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"} packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], @@ -988,12 +1108,135 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, # mapping for original checkpoint "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", - }) + } + ) supports_encoder_tp_data = True + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + second_per_grid_ts: list[float], + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 + else: + ed_image = len(input_tokens) + 1 + if remain_videos > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + 
video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + * video_second_per_grid_t + * tokens_per_second + ) + .long() + .flatten() + ) + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + + return llm_positions, mrope_position_delta + @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|vision_start|><|image_pad|><|vision_end|>" if modality.startswith("video"): @@ -1008,19 +1251,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config + self.vllm_config = vllm_config self.multimodal_config = multimodal_config self.video_pruning_rate = multimodal_config.video_pruning_rate self.is_multimodal_pruning_enabled = ( - multimodal_config.is_multimodal_pruning_enabled()) + multimodal_config.is_multimodal_pruning_enabled() + ) - if multimodal_config.get_limit_per_prompt("image") or \ - multimodal_config.get_limit_per_prompt("video"): + if multimodal_config.get_limit_per_prompt( + "image" + ) or multimodal_config.get_limit_per_prompt("video"): + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.visual = Qwen2_5_VisionTransformer( - config.vision_config, + vision_config=config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=self.quant_config, prefix=maybe_prefix(prefix, "visual"), use_data_parallel=self.use_data_parallel, + attn_backend_override=attn_backend_override, ) else: self.visual = None @@ -1032,26 +1284,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) - - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if mm_input.ndim == 2: - return mm_input - if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") - return mm_input.reshape(-1, mm_input.shape[-1]) - else: - return torch.concat(mm_input) + self.language_model.make_empty_intermediate_tensors + ) + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.language_model.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.language_model.model.layers) + return (2, num_layers // 2, num_layers - 3) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]: + self, **kwargs: object + ) -> Qwen2_5_VLImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1060,28 +1305,22 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") - - return Qwen2_5_VLImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) + return Qwen2_5_VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) if image_embeds is not None: - image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") - return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, - image_grid_thw=image_grid_thw) + image_grid_thw=image_grid_thw, + ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]: + self, **kwargs: object + ) -> Qwen2_5_VLVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1091,12 +1330,6 @@ def _parse_and_validate_video_input( return None if pixel_values_videos is not None: - pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2: - second_per_grid_ts = second_per_grid_ts.squeeze(-1) return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, @@ -1105,20 +1338,15 @@ def _parse_and_validate_video_input( ) if video_embeds is not None: - video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, - video_grid_thw=video_grid_thw) + video_grid_thw=video_grid_thw, + ) def _process_image_input( - self, - image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: - + self, image_input: Qwen2_5_VLImageInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() @@ -1127,27 +1355,24 @@ def _process_image_input( image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"] - - if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - 
pixel_values, - grid_thw_list, - rope_type="rope_3d") - else: - image_embeds = self.visual(pixel_values, - grid_thw=grid_thw_list) + with set_forward_context(None, self.vllm_config): + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. - # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync merge_size = self.visual.spatial_merge_size - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() - + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return image_embeds.split(sizes) def _postprocess_image_embeds_evs( - self, image_embeds_split: tuple[torch.Tensor, ...], - image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + self, + image_embeds_split: tuple[torch.Tensor, ...], + image_input: Qwen2_5_VLImageInputs, + ) -> tuple[torch.Tensor, ...]: """ Append mrope positions for each for images. This is necessary to recover correct mrope @@ -1168,17 +1393,15 @@ def _postprocess_image_embeds_evs( grid_thw_list = grid_thw.tolist() image_embeds_out = [] for emb, size in zip(image_embeds_split, grid_thw_list): - positions = compute_mrope_for_media(size, - merge_size).to(emb.device) + positions = compute_mrope_for_media(size, merge_size).to(emb.device) emb = torch.cat([emb, positions], dim=1) image_embeds_out.append(emb) image_embeds_split = image_embeds_out return tuple(image_embeds_split) def _process_video_input( - self, - video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: - + self, video_input: Qwen2_5_VLVideoInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() @@ -1187,26 +1410,29 @@ def _process_video_input( video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"] - if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values_videos, - grid_thw_list, - rope_type="rope_3d") - else: - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw_list) + with set_forward_context(None, self.vllm_config): + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values_videos, + grid_thw_list, + rope_type="rope_3d", + ) + else: + video_embeds = self.visual( + pixel_values_videos, grid_thw=grid_thw_list + ) # Split concatenated embeddings for each video item. 
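The per-item split sizes used for both the image and video embeddings here are plain integer arithmetic on the (t, h, w) grids. A quick sanity check with hypothetical grid values and a hypothetical hidden size of 1280:

```python
import torch

# Hypothetical (t, h, w) patch grids for two items and a merge size of 2.
grid_thw = torch.tensor([[1, 36, 24], [1, 16, 16]])
merge_size = 2  # spatial_merge_size

# Each item contributes t*h*w patches, merged merge_size x merge_size per token.
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
print(sizes)  # [216, 64]  (864 // 4 and 256 // 4)

# The concatenated embeddings are then split back into per-item tensors.
embeds = torch.randn(sum(sizes), 1280)
per_item = embeds.split(sizes)
assert [e.shape[0] for e in per_item] == sizes
```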
merge_size = self.visual.spatial_merge_size - # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() - + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return video_embeds.split(sizes) def _postprocess_video_embeds_evs( - self, video_embeds_split: tuple[torch.Tensor, ...], - video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + self, + video_embeds_split: tuple[torch.Tensor, ...], + video_input: Qwen2_5_VLVideoInputs, + ) -> tuple[torch.Tensor, ...]: """ Prunes video embeddings via Efficient Video Sampling (EVS) and then appends mrope positions for each retained embeddings @@ -1231,9 +1457,9 @@ def _postprocess_video_embeds_evs( tokens_per_second = self.config.vision_config.tokens_per_second video_embeds_out = [] - for emb, size, video_second_per_grid_t in zip(video_embeds_split, - grid_thw_list, - second_per_grid_ts): + for emb, size, video_second_per_grid_t in zip( + video_embeds_split, grid_thw_list, second_per_grid_ts + ): # For each video, we compute retention mask using EVS retention_mask = compute_retention_mask( emb, @@ -1285,20 +1511,19 @@ def recompute_mrope_positions( vision_start_token_id = self.config.vision_start_token_id # Device - device = (multimodal_embeddings[0].device - if len(multimodal_embeddings) else mrope_positions.device) + device = ( + multimodal_embeddings[0].device + if len(multimodal_embeddings) + else mrope_positions.device + ) # Tensors - input_ids_t = torch.as_tensor(input_ids, - device=device, - dtype=torch.long) + input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) - # fmt: off - mm_embeddings_out = [mm[:, :-4] for mm in - multimodal_embeddings] - mm_embeddings_pos = [mm[:, -4:].permute(1, 0).long() for mm in - multimodal_embeddings] - # fmt: in + mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings] + mm_embeddings_pos = [ + mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings + ] positions, mrope_positions_delta = recompute_mrope_positions( input_ids_t, @@ -1318,24 +1543,27 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
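The EVS path relies on a small packing convention visible above: four extra columns carrying per-token mrope metadata are appended to each embedding, and recompute_mrope_positions strips them off again. A sketch of just that packing and unpacking, with hypothetical sizes:

```python
import torch

# Hypothetical sizes: 216 retained tokens, hidden size 1280.
num_tokens, hidden_size = 216, 1280
features = torch.randn(num_tokens, hidden_size)
positions = torch.zeros(num_tokens, 4)  # stand-in for compute_mrope_for_media()

# Appended on the way out of the vision post-processing...
packed = torch.cat([features, positions], dim=1)          # (216, 1284)

# ...and split off again by recompute_mrope_positions above.
unpacked_features = packed[:, :-4]                        # (216, 1280)
unpacked_positions = packed[:, -4:].permute(1, 0).long()  # (4, 216)
assert unpacked_features.shape == features.shape
assert unpacked_positions.shape == (4, num_tokens)
```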
for input_key in kwargs: - if input_key in ("pixel_values", "image_embeds" - ) and "image" not in mm_input_by_modality: - mm_input_by_modality[ - "image"] = self._parse_and_validate_image_input(**kwargs) - if input_key in ("pixel_values_videos", "video_embeds" - ) and "video" not in mm_input_by_modality: - mm_input_by_modality[ - "video"] = self._parse_and_validate_video_input(**kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) return mm_input_by_modality def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return [] @@ -1348,76 +1576,29 @@ def get_multimodal_embeddings(self, for modality in mm_input_by_modality: multimodal_input = mm_input_by_modality[modality] if modality == "image": - vision_embeddings = self._process_image_input(multimodal_input) + image_embeddings = self._process_image_input(multimodal_input) if self.is_multimodal_pruning_enabled: - vision_embeddings = self._postprocess_image_embeds_evs( - vision_embeddings, multimodal_input + image_embeddings = self._postprocess_image_embeds_evs( + image_embeddings, multimodal_input ) - multimodal_embeddings += vision_embeddings + multimodal_embeddings += tuple(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) if self.is_multimodal_pruning_enabled: video_embeddings = self._postprocess_video_embeds_evs( video_embeddings, multimodal_input ) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None \ - and len(multimodal_embeddings) != 0: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) - return inputs_embeds - - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - image_input: Optional[Qwen2_5_VLImageInputs] = None, - video_input: Optional[Qwen2_5_VLVideoInputs] = None, - ) -> torch.Tensor: - inputs_embeds = self.get_input_embeddings(input_ids) - if image_input is not None: - image_embeds = self._process_image_input(image_input) - if self.is_multimodal_pruning_enabled: - image_embeds = self._postprocess_image_embeds_evs( - image_embeds, image_input - ) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, - ) - - if video_input is not None: - video_embeds = self._process_video_input(video_input) - if self.is_multimodal_pruning_enabled: - video_embeds = self._postprocess_video_embeds_evs( - video_embeds, video_input - ) - inputs_embeds = 
merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for Qwen2.5-VL. Args: @@ -1432,26 +1613,6 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner from - # `get_multimodal_embeddings` and `get_input_embeddings`, this - # condition is only for v0 compatibility. - elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if image_input is None and video_input is None: - inputs_embeds = None - else: - if uses_mrope(self.config): - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - inputs_embeds = self.get_input_embeddings_v0( - input_ids, - image_input=image_input, - video_input=video_input) - input_ids = None - hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, @@ -1463,12 +1624,10 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.visual is None: skip_prefixes.extend(["visual."]) diff --git a/vllm_metax/models/qwen2_vl.py b/vllm_metax/models/qwen2_vl.py index 65df94dd8..7577e14b9 100644 --- a/vllm_metax/models/qwen2_vl.py +++ b/vllm_metax/models/qwen2_vl.py @@ -24,167 +24,250 @@ # See the License for the specific language governing permissions and # limitations under the License. 
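Both vision towers in this patch replace the Conv3d patch projection with a plain linear layer over already-flattened patches (the conv3d_to_linear_weight conversion is also imported for Qwen2-VL below). The two are numerically equivalent when the conv stride equals its kernel; a standalone check with hypothetical sizes, assuming conv3d_to_linear_weight simply flattens the (out, in, T, P, P) kernel into an (out, in*T*P*P) linear weight:

```python
import torch
import torch.nn as nn

# A Conv3d whose stride equals its kernel, applied to pre-patchified input,
# is just a matmul with the flattened conv weight. Hypothetical sizes below.
in_channels, hidden_size = 3, 16
temporal_patch_size, patch_size = 2, 14
kernel = (temporal_patch_size, patch_size, patch_size)

conv = nn.Conv3d(in_channels, hidden_size, kernel_size=kernel, stride=kernel, bias=False)

num_patches = 5
x = torch.randn(num_patches, in_channels * temporal_patch_size * patch_size * patch_size)

# Old path: reshape each row back into a (C, T, P, P) patch and run the conv.
y_conv = conv(x.view(num_patches, in_channels, *kernel)).view(num_patches, hidden_size)

# New path: use the flattened conv kernel as a linear weight (what the
# conv3d_to_linear_weight conversion is assumed to produce).
w = conv.weight.view(hidden_size, -1)
y_linear = x @ w.t()

torch.testing.assert_close(y_conv, y_linear)
```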
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" -from collections.abc import Iterable, Mapping, Sequence + +import math +from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Annotated, Any, Literal, TypeAlias import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from transformers import BatchFeature -from transformers.models.qwen2_vl import (Qwen2VLImageProcessor, - Qwen2VLProcessor) +from transformers import BatchFeature, PretrainedConfig +from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor from transformers.models.qwen2_vl.configuration_qwen2_vl import ( - Qwen2VLConfig, Qwen2VLVisionConfig) + Qwen2VLConfig, + Qwen2VLVisionConfig, +) from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize +from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor + +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import ( + check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend, +) from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) +from vllm.model_executor.layers.rotary_embedding.common import ( + dispatch_rotary_emb_function, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, - SupportsLoRA, - SupportsMultiModal, - SupportsPP) from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, - maybe_prefix, - merge_multimodal_embeddings) -from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (ImageItem, ModalityData, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, VideoItem) -from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, - ModalityDataItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.inputs import ( + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ImageSize, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend, current_platform 
from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.config import uses_mrope -from vllm.transformers_utils.processor import ( - cached_image_processor_from_config) +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from vllm.model_executor.models.vision import ( + conv3d_to_linear_weight, + get_vit_attn_backend, + run_dp_sharded_mrope_vision_model, +) logger = init_logger(__name__) # For profile run -_MAX_FRAMES_PER_VIDEO = 16 +_MAX_FRAMES_PER_VIDEO = 14 # === Vision Inputs === # -class Qwen2VLImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor - """Shape: - `(num_patches, num_channels * patch_size * patch_size)` +class Qwen2VLImagePixelInputs(TensorSchema): """ - - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + Dimensions: + - np: The total number of patches over each image over each prompt in + the batch + - ni: Number of images + - cps: Number of channels * patch_size * patch_size + + Historical context: + - pixel_values shape: (num_patches, num_channels * patch_size * + patch_size) + - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format """ + type: Literal["pixel_values"] -class Qwen2VLImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - image_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all images' features. - Each tensor holds an image's features. - - `torch.Tensor`: A tensor holding all images' features - (concatenation of all images' feature tensors). - - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the images. - - `hidden_size` must match the hidden size of language model backbone. - """ + pixel_values: Annotated[ + torch.Tensor, + TensorShape("np", "cps"), + ] - image_grid_thw: torch.Tensor - """Shape: `(num_images, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] + + +class Qwen2VLImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images + + Historical context: + - image_embeds shape: (num_image_features, hidden_size) + - num_image_features varies based on the number and resolution of the + images. + - hidden_size must match the hidden size of language model backbone. 
+ - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) + format """ + type: Literal["image_embeds"] -Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, - Qwen2VLImageEmbeddingInputs] + image_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] + image_grid_thw: Annotated[ + torch.Tensor, + TensorShape("ni", 3), + ] -class Qwen2VLVideoPixelInputs(TypedDict): - type: Literal["pixel_values_videos"] - pixel_values_videos: torch.Tensor - """Shape: - `(num_patches, - num_channels * temporal_patch_size * patch_size * patch_size)` - """ - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` +Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs + - This should be in `(grid_t, grid_h, grid_w)` format. +class Qwen2VLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - np: The total number of patches over each video over each prompt in + the batch + - ctps: Number of channels * temporal_patch_size * patch_size * + patch_size + - nv: Number of videos + + Historical context: + - pixel_values_videos shape: (num_patches, num_channels * + temporal_patch_size * patch_size * patch_size) + - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format """ + type: Literal["pixel_values_videos"] -class Qwen2VLVideoEmbeddingInputs(TypedDict): - type: Literal["video_embeds"] - video_embeds: torch.Tensor - """Supported types: - - list[`torch.Tensor`]: A list of tensors holding all videos' features. - Each tensor holds an video's features. - - `torch.Tensor`: A tensor holding all videos' features - (concatenation of all videos' feature tensors). - - Tensor shape: `(num_image_features, hidden_size)` - - `num_image_features` varies based on - the number and resolution of the videos. - - `hidden_size` must match the hidden size of language model backbone. - """ + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("np", "ctps"), + ] - video_grid_thw: torch.Tensor - """Shape: `(num_videos, 3)` - This should be in `(grid_t, grid_h, grid_w)` format. + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] + + +class Qwen2VLVideoEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of video features + - hs: Hidden size + - nv: Number of videos + + Historical context: + - video_embeds shape: (num_video_features, hidden_size) + - num_video_features varies based on the number and resolution of the + videos. + - hidden_size must match the hidden size of language model backbone. 
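# A minimal sketch (toy sizes, standalone torch; not part of the patch's API) of
# the shape contract the TensorSchema classes document: pixel_values packs every
# patch of every item into one 2D tensor, and the grid tensor records how many
# patches each item contributed, so np == sum(t * h * w), cps == C * P * P and
# ctps == C * T * P * P.
import torch

C, P, T = 3, 14, 2                                      # channels, patch, temporal patch
image_grid_thw = torch.tensor([[1, 4, 6], [1, 2, 2]])   # (ni=2, 3) as (t, h, w)
num_patches = int(image_grid_thw.prod(-1).sum())        # 24 + 4 = 28 patches in total
pixel_values = torch.randn(num_patches, C * P * P)      # (np, cps)

video_grid_thw = torch.tensor([[2, 4, 4]])              # (nv=1, 3)
pixel_values_videos = torch.randn(
    int(video_grid_thw.prod(-1).sum()), C * T * P * P   # (np, ctps)
)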
+ - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) + format """ + type: Literal["video_embeds"] + + video_embeds: Annotated[ + torch.Tensor, + TensorShape("nf", "hs"), + ] + + video_grid_thw: Annotated[ + torch.Tensor, + TensorShape("nv", 3), + ] + -Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs, - Qwen2VLVideoEmbeddingInputs] +Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs # === Vision Encoder === # class Qwen2VisionMLP(nn.Module): - def __init__( self, in_features: int, hidden_features: int, act_layer: type[nn.Module] = QuickGELU, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() - self.fc1 = ColumnParallelLinear(in_features, - hidden_features, - quant_config=quant_config, - prefix=f"{prefix}.fc1") + self.fc1 = ColumnParallelLinear( + in_features, + hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, + ) self.act = act_layer() - self.fc2 = RowParallelLinear(hidden_features, - in_features, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + self.fc2 = RowParallelLinear( + hidden_features, + in_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x_parallel, _ = self.fc1(x) @@ -199,15 +282,14 @@ def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), - "... d two -> ... (d two)", - two=2) + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) -def apply_rotary_emb_torch(x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False) -> torch.Tensor: +def apply_rotary_emb_torch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False +) -> torch.Tensor: """ x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) @@ -215,68 +297,99 @@ def apply_rotary_emb_torch(x: torch.Tensor, ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] cos = repeat( - cos, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) sin = repeat( - sin, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" + ) return torch.cat( [ - x[..., :ro_dim] * cos + - rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], ], dim=-1, ) -def apply_rotary_pos_emb_vision(t: torch.Tensor, - freqs: torch.Tensor) -> torch.Tensor: +def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch) t_ = t.float() cos = freqs.cos() sin = freqs.sin() - apply_rotary_emb = apply_rotary_emb_torch - if current_platform.is_out_of_tree(): - from flash_attn.layers.rotary import apply_rotary_emb - output = apply_rotary_emb(t_, cos, sin).type_as(t) + output = rotary_emb_function(t_, cos, sin).type_as(t) return output class Qwen2VisionAttention(nn.Module): - def __init__( self, embed_dim: int, num_heads: int, projection_size: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.tp_size = world_size + self.tp_size = ( + 1 + if use_data_parallel + else parallel_state.get_tensor_model_parallel_world_size() + ) self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) + projection_size, num_heads + ) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, world_size) + num_heads, self.tp_size + ) - self.qkv = ColumnParallelLinear(input_size=embed_dim, - output_size=3 * projection_size, - quant_config=quant_config, - prefix=f"{prefix}.qkv") - self.proj = RowParallelLinear(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj") + self.qkv = ColumnParallelLinear( + input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel, + ) + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, + ) + self.use_upstream_fa = False + + # /--------------- Metax Modification ---------------\ + self.use_upstream_fa = True + from flash_attn import flash_attn_varlen_func + + self.flash_attn_varlen_func = flash_attn_varlen_func + # \--------------- Metax Modification ---------------/ + if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( - f"Qwen2-VL does not support {self.attn_backend} backend now.") + f"Qwen2-VL does not support {self.attn_backend} backend now." 
+ ) + + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, + } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] @@ -289,27 +402,31 @@ def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # 3 * [s, b, head * head_dim] if self.tp_size > 1: - splitter = partial(dist_utils.split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial( + dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size + ) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] v = splitter(v)[self.tp_rank] # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] - new_shape = (seq_len, bs, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_shape = ( + seq_len, + bs, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) q, k, v = (x.view(*new_shape) for x in (q, k, v)) return q, k, v def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: - # [s, b, c] --> [s, b, 3 * head * head_dim] x, _ = self.qkv(x) @@ -317,34 +434,39 @@ def forward( q, k, v = self.split_qkv(x) batch_size = q.shape[1] - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() - for x in (q, k, v)) + q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v)) if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) - - if self.attn_backend == _Backend.FLASH_ATTN: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) - from flash_attn import flash_attn_varlen_func + # [2 * b, s, heads, head_dim] + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) + if self.is_flash_attn_backend: q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0, - causal=False) - - context_layer = rearrange(output, - "(b s) ... -> b s ...", - b=batch_size) + output = self.flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False, + ) + + context_layer = rearrange( + output, "(b s) h d -> s b (h d)", b=batch_size + ).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
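# Sketch of the varlen bookkeeping consumed by the flash-attention call above
# (standalone torch, toy grids): cu_seqlens holds cumulative patch counts per
# image/frame so several items can be packed into one flat sequence, and
# max_seqlen is simply the largest individual segment.
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 4, 6], [2, 4, 4]])      # (t, h, w) per item
seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seq_lens.cumsum(dim=0, dtype=torch.int32), (1, 0))
assert cu_seqlens.tolist() == [0, 24, 40, 56]        # segment boundaries
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
assert max_seqlen == 24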
+ from vllm.platforms import current_platform + + if current_platform.is_rocm(): + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() outputs = [] for i in range(1, len(cu_seqlens)): start_idx = cu_seqlens[i - 1] @@ -352,43 +474,47 @@ def forward( q_i = q[:, start_idx:end_idx] k_i = k[:, start_idx:end_idx] v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + q_i, k_i, v_i = ( + rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] + ) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() output, _ = self.proj(context_layer) return output class Qwen2VisionBlock(nn.Module): - def __init__( self, dim: int, num_heads: int, mlp_ratio: float, act_layer: type[nn.Module] = QuickGELU, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, + norm_layer: Callable[[int], nn.Module] | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -397,24 +523,31 @@ def __init__( self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.attn = Qwen2VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - self.mlp = Qwen2VisionMLP(dim, - mlp_hidden_dim, - act_layer=act_layer, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.attn = Qwen2VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, + ) + self.mlp = Qwen2VisionMLP( + dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: x = x + self.attn( self.norm1(x), @@ -429,7 +562,6 @@ def forward( class Qwen2VisionPatchEmbed(nn.Module): - def __init__( self, patch_size: int = 14, @@ -443,49 +575,55 @@ def 
__init__( self.embed_dim = embed_dim kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d(in_channels, - embed_dim, - kernel_size=kernel_size, - stride=kernel_size, - bias=False) + self.proj = ReplicatedLinear( + in_channels * math.prod(kernel_size), + embed_dim, + bias=False, + return_bias=False, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: - L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, - self.patch_size) - x = self.proj(x).view(L, self.embed_dim) + x = self.proj(x) return x class Qwen2VisionPatchMerger(nn.Module): - def __init__( self, d_model: int, context_dim: int, - norm_layer: Optional[Callable[[int], nn.Module]] = None, + norm_layer: Callable[[int], nn.Module] | None = None, spatial_merge_size: int = 2, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) - self.mlp = nn.ModuleList([ - ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.0"), - nn.GELU(), - RowParallelLinear(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2"), - ]) + self.mlp = nn.ModuleList( + [ + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0", + disable_tp=use_data_parallel, + ), + nn.GELU(), + RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2", + disable_tp=use_data_parallel, + ), + ] + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.ln_q(x) @@ -499,13 +637,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen2VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta - inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -514,12 +650,18 @@ def update_freqs_cache(self, seqlen: int) -> None: if seqlen > self._seq_len_cached: seqlen *= 2 self._seq_len_cached = seqlen - self.inv_freq = 1.0 / (self.theta**(torch.arange( - 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) - / self.dim)) - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + self.inv_freq = 1.0 / ( + self.theta + ** ( + torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device + ) + / self.dim + ) + ) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) self._freqs_cached = freqs @@ -529,13 +671,14 @@ def forward(self, seqlen: int) -> torch.Tensor: class Qwen2VisionTransformer(nn.Module): - def __init__( self, vision_config: Qwen2VLVisionConfig, norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() @@ -549,6 +692,9 @@ def __init__( 
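# Why the patch embedding above can use a plain Linear over the pre-flattened
# patches: a Conv3d whose stride equals its kernel size sees exactly one patch
# per output position, so its kernel, flattened to (embed_dim, C*T*P*P), is the
# weight of an equivalent Linear. Toy sizes; this illustrates the idea behind
# conv3d_to_linear_weight, it is not the helper itself.
import torch
import torch.nn as nn

C, T, P, D = 3, 2, 14, 32
conv = nn.Conv3d(C, D, kernel_size=(T, P, P), stride=(T, P, P), bias=False)

x = torch.randn(5, C * T * P * P)                    # 5 flattened patches
conv_out = conv(x.view(-1, C, T, P, P)).view(-1, D)  # unflatten, convolve

linear = nn.Linear(C * T * P * P, D, bias=False)
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(D, -1))
lin_out = linear(x)

assert torch.allclose(conv_out, lin_out, rtol=1e-4, atol=1e-4)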
num_heads = vision_config.num_heads mlp_ratio = vision_config.mlp_ratio + self.use_data_parallel = use_data_parallel + self.out_hidden_size = vision_config.hidden_size + self.spatial_merge_size = spatial_merge_size self.num_heads = num_heads self.embed_dim = embed_dim @@ -564,23 +710,38 @@ def __init__( head_dim = embed_dim // num_heads self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList([ - Qwen2VisionBlock(dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(depth) - ]) + self.blocks = nn.ModuleList( + [ + Qwen2VisionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, + ) + for layer_idx in range(depth) + ] + ) self.merger = Qwen2VisionPatchMerger( d_model=hidden_size, context_dim=embed_dim, norm_layer=norm_layer, quant_config=quant_config, prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, ) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, + ) + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -590,36 +751,44 @@ def dtype(self) -> torch.dtype: def device(self) -> torch.device: return self.patch_embed.proj.weight.device - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: pos_ids = [] + max_grid_size = 0 for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + hpos_ids = ( + hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + wpos_ids = ( + wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + max_grid_size = max(max_grid_size, h, w) pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb def compute_attn_mask_seqlen( - self, cu_seqlens: torch.Tensor - ) -> tuple[Optional[int], Optional[list[int]]]: + self, cu_seqlens: torch.Tensor + ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None - if self.attn_backend == _Backend.FLASH_ATTN: + if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: 
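# Sketch of the (h, w) position ids built by rot_pos_emb above (standalone
# torch): the reshape/permute groups patch positions window by window, so every
# consecutive block of spatial_merge_size**2 tokens belongs to one merge window
# (a 2x2 block here) that the merger later collapses.
import torch

h, w, merge = 4, 4, 2
hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos = torch.arange(w).unsqueeze(0).expand(h, -1)

def window_order(ids: torch.Tensor) -> torch.Tensor:
    return (
        ids.reshape(h // merge, merge, w // merge, merge)
        .permute(0, 2, 1, 3)
        .flatten()
    )

pos = torch.stack([window_order(hpos), window_order(wpos)], dim=-1)
# The first merge window is the top-left 2x2 block of the image grid:
assert pos[:4].tolist() == [[0, 0], [0, 1], [1, 0], [1, 1]]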
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() @@ -628,20 +797,27 @@ def compute_attn_mask_seqlen( def forward( self, x: torch.Tensor, - grid_thw: torch.Tensor, + grid_thw: torch.Tensor | list[list[int]], ) -> torch.Tensor: # patchify x = x.to(device=self.device, dtype=self.dtype) x = self.patch_embed(x) + if isinstance(grid_thw, list): + grid_thw_list = grid_thw + grid_thw = torch.tensor(grid_thw, dtype=torch.int32) + else: + grid_thw_list = grid_thw.tolist() + # compute position embedding - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = self.rot_pos_emb(grid_thw_list) # compute cu_seqlens - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32) - cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) # transformers x = x.unsqueeze(1) @@ -662,8 +838,7 @@ def forward( return x - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -674,7 +849,10 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) + + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -685,139 +863,110 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): - image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) - image_grid_sizes = image_grid_thw.prod(-1) - - video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) - video_grid_sizes = video_grid_thw.prod(-1) - - return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_grid_thw=MultiModalFieldConfig.batched("video"), - ) +def _create_qwen2vl_field_factory( + spatial_merge_size: int, +) -> Callable[ + [Mapping[str, torch.Tensor]], + Mapping[str, MultiModalFieldConfig], +]: + def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_pixel_grid_sizes = image_grid_thw.prod(-1) + image_embed_grid_sizes = ( + image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size + ) + + video_grid_thw = 
hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + video_embed_grid_sizes = ( + video_grid_sizes // spatial_merge_size // spatial_merge_size + ) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_pixel_grid_sizes + ), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_embed_grid_sizes + ), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes + ), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_embed_grid_sizes + ), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + return _qwen2vl_field_config class Qwen2VLMultiModalDataParser(MultiModalDataParser): + def __init__(self, spatial_merge_size: int, *args, **kwargs): + self._spatial_merge_size = spatial_merge_size + super().__init__(*args, **kwargs) def _parse_image_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], - ) -> Optional[ModalityDataItems[Any, Any]]: + data: dict[str, torch.Tensor] | ModalityData[ImageItem], + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, modality="image", required_fields={"image_embeds", "image_grid_thw"}, - fields_factory=_qwen2vl_field_config, + fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size), ) return super()._parse_image_data(data) def _parse_video_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], - ) -> Optional[ModalityDataItems[Any, Any]]: + data: dict[str, torch.Tensor] | ModalityData[VideoItem], + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, modality="video", required_fields={"video_embeds", "video_grid_thw"}, - fields_factory=_qwen2vl_field_config, + fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size), ) return super()._parse_video_data(data) class Qwen2VLProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Qwen2VLConfig) - def get_hf_processor( - self, - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, - size: Optional[dict[str, int]] = None, - **kwargs: object, - ) -> Qwen2VLProcessor: + def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor: return self.ctx.get_hf_processor( Qwen2VLProcessor, - image_processor=self.get_image_processor( - min_pixels=min_pixels, - max_pixels=max_pixels, - size=size, - use_fast=kwargs.get("use_fast")), + use_fast=kwargs.pop("use_fast", True), **kwargs, ) - def _get_image_processor_kwargs( - self, - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, - size: Optional[dict[str, int]] = None, - **kwargs: object, - ): - mm_config = self.ctx.model_config.get_multimodal_config() - if mm_config.mm_processor_kwargs: - kwargs.update(mm_config.mm_processor_kwargs) - - if min_pixels is not None: - kwargs["min_pixels"] = min_pixels - - if size is None: - size = {"shortest_edge": min_pixels} - else: - size["shortest_edge"] = min_pixels - - if max_pixels is not None: - kwargs["max_pixels"] = max_pixels - - if size is None: - size = {"longest_edge": max_pixels} - else: - size["longest_edge"] = max_pixels + def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor: + return self.get_hf_processor(**kwargs).image_processor - if size is not None: - kwargs["size"] = size - - return kwargs + def get_supported_mm_limits(self) -> Mapping[str, int | 
None]: + return {"image": None, "video": None} - def get_image_processor( + def get_mm_max_tokens_per_item( self, - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, - size: Optional[dict[str, int]] = None, - **kwargs: object, - ) -> Qwen2VLImageProcessor: - return cached_image_processor_from_config( - self.ctx.model_config, - **self._get_image_processor_kwargs(min_pixels=min_pixels, - max_pixels=max_pixels, - size=size, - **kwargs), - ) - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None, "video": None} + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + max_image_tokens = self.get_max_image_tokens() + max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts) + return {"image": max_image_tokens, "video": max_video_tokens} def _get_vision_info( self, @@ -826,7 +975,7 @@ def _get_vision_info( image_height: int, num_frames: int = 1, do_resize: bool = True, - image_processor: Optional[Qwen2VLImageProcessor], + image_processor: Qwen2VLImageProcessor | None, ) -> tuple[ImageSize, int]: if image_processor is None: image_processor = self.get_image_processor() @@ -845,11 +994,9 @@ def _get_vision_info( min_pixels=image_processor.min_pixels, max_pixels=image_processor.max_pixels, ) - preprocessed_size = ImageSize(width=resized_width, - height=resized_height) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: - preprocessed_size = ImageSize(width=image_width, - height=image_height) + preprocessed_size = ImageSize(width=image_width, height=image_height) # NOTE: Frames are padded to be divisible by `temporal_patch_size` # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 @@ -869,11 +1016,12 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - image_processor: Optional[Qwen2VLImageProcessor], + image_processor: Qwen2VLImageProcessor | None, ) -> int: _, num_image_tokens = self._get_vision_info( image_width=image_width, image_height=image_height, + num_frames=1, image_processor=image_processor, ) return num_image_tokens @@ -884,7 +1032,7 @@ def get_num_video_tokens( image_width: int, image_height: int, num_frames: int, - image_processor: Optional[Qwen2VLImageProcessor], + image_processor: Qwen2VLImageProcessor | None, ) -> int: _, num_video_tokens = self._get_vision_info( image_width=image_width, @@ -898,6 +1046,7 @@ def get_image_size_with_most_features(self) -> ImageSize: max_image_size, _ = self._get_vision_info( image_width=9999999, image_height=9999999, + num_frames=1, image_processor=None, ) return max_image_size @@ -911,10 +1060,10 @@ def get_max_image_tokens(self) -> int: image_processor=None, ) - def _get_max_video_frames(self, max_tokens: int) -> int: + def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int: target_width, target_height = self.get_image_size_with_most_features() - num_frames = 0 + num_frames = start_num_frames while True: next_num_frames = num_frames + 1 @@ -936,15 +1085,14 @@ def get_num_frames_with_most_features( self, seq_len: int, mm_counts: Mapping[str, int], + max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO, ) -> int: - max_images = mm_counts.get("image", 0) max_videos = mm_counts.get("video", 0) - max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) - max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) 
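# Rough sketch of the dummy-data frame budget (plain Python, toy numbers):
# whatever token budget is available for video is divided across the requested
# videos, capped at _MAX_FRAMES_PER_VIDEO, and never drops below one frame.
def frames_per_video(max_total_frames: int, num_videos: int, cap: int = 14) -> int:
    return max(min(max_total_frames // max(num_videos, 1), cap), 1)

assert frames_per_video(100, 4) == 14   # capped by _MAX_FRAMES_PER_VIDEO
assert frames_per_video(20, 4) == 5     # budget-limited
assert frames_per_video(0, 2) == 1      # always at least one frame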
+ max_total_frames = self._get_max_video_frames(seq_len) + max_frames_per_video = min( + max_total_frames // max(max_videos, 1), max_frames_per_video + ) return max(max_frames_per_video, 1) @@ -958,14 +1106,12 @@ def get_max_video_tokens( return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features( - seq_len, mm_counts), + num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), image_processor=None, ) class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -980,58 +1126,50 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) + + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), - "video": - self._get_dummy_videos( + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, - ) + overrides=video_overrides, + ), } -class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] - ): - +class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: - return Qwen2VLMultiModalDataParser() - - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - return self.info.ctx.call_hf_processor( - self.info.get_hf_processor(**mm_kwargs), - dict(text=prompt, **mm_data), - self.info._get_image_processor_kwargs(**mm_kwargs, **tok_kwargs), + return Qwen2VLMultiModalDataParser( + self.info.get_hf_config().vision_config.spatial_merge_size ) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -1043,7 +1181,8 @@ def _get_prompt_updates( merge_length = image_processor.merge_size**2 def get_replacement_qwen2vl(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) num_tokens = int(grid_thw.prod()) // merge_length @@ 
-1053,9 +1192,9 @@ def get_replacement_qwen2vl(item_idx: int, modality: str): PromptReplacement( modality=modality, target=[placeholder[modality]], - replacement=partial(get_replacement_qwen2vl, - modality=modality), - ) for modality in ("image", "video") + replacement=partial(get_replacement_qwen2vl, modality=modality), + ) + for modality in ("image", "video") ] def _get_mm_fields_config( @@ -1063,14 +1202,21 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return _qwen2vl_field_config(hf_inputs) - - -@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, - info=Qwen2VLProcessingInfo, - dummy_inputs=Qwen2VLDummyInputsBuilder) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP): + return _create_qwen2vl_field_factory( + self.info.get_hf_config().vision_config.spatial_merge_size + )(hf_inputs) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2VLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder, +) +class Qwen2VLForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE +): + merge_by_field_config = True + multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"} # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( @@ -1081,7 +1227,147 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, # mapping for original checkpoint "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", - }) + } + ) + + supports_encoder_tp_data = True + + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor | None, + video_grid_thw: list[list[int]] | torch.Tensor | None, + second_per_grid_ts: list[float] | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get M-RoPE input positions for Qwen2-VL model.""" + if image_grid_thw is None: + image_grid_thw = [] + if video_grid_thw is None: + video_grid_thw = [] + if second_per_grid_ts is None: + second_per_grid_ts = [] + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 + else: + ed_image = len(input_tokens) + 1 + if remain_videos > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + 
image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + * video_second_per_grid_t + * tokens_per_second + ) + .long() + .flatten() + ) + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + + return llm_positions, mrope_position_delta + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return "<|vision_start|><|image_pad|><|vision_end|>" + if modality.startswith("video"): + return "<|vision_start|><|video_pad|><|vision_end|>" + + raise ValueError("Only image or video modality is supported") def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -1089,15 +1375,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.config = config self.multimodal_config = multimodal_config - self.visual = Qwen2VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=self._maybe_ignore_quant_config(quant_config), - prefix=maybe_prefix(prefix, "visual"), - ) + if multimodal_config.get_limit_per_prompt( + "image" + ) or multimodal_config.get_limit_per_prompt("video"): + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) + self.visual = Qwen2VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, + attn_backend_override=attn_backend_override, + ) + else: + self.visual = None self.language_model = init_vllm_registered_model( vllm_config=vllm_config, @@ -1106,34 +1405,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) - - def _maybe_ignore_quant_config(self, quant_config: 
QuantizationConfig): - # GPTQ configs do not have a list of ignored modules, however AutoGPTQ - # seems to avoid vision encoder sections for some models. - # See: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4 - if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): - return None - return quant_config - - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if mm_input.ndim == 2: - return mm_input - if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) - else: - return torch.concat(mm_input) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Qwen2VLImageInputs]: + self, **kwargs: object + ) -> Qwen2VLImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1142,34 +1419,22 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") - - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}") - - return Qwen2VLImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) + return Qwen2VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) if image_embeds is not None: - image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") - - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. 
" - f"Got type: {type(image_embeds)}") - return Qwen2VLImageEmbeddingInputs(type="image_embeds", - image_embeds=image_embeds, - image_grid_thw=image_grid_thw) + return Qwen2VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]: + self, **kwargs: object + ) -> Qwen2VLVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1178,11 +1443,6 @@ def _parse_and_validate_video_input( return None if pixel_values_videos is not None: - pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - return Qwen2VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, @@ -1190,54 +1450,57 @@ def _parse_and_validate_video_input( ) if video_embeds is not None: - video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - - if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") - return Qwen2VLVideoEmbeddingInputs(type="video_embeds", - video_embeds=video_embeds, - video_grid_thw=video_grid_thw) + return Qwen2VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + ) def _process_image_input( - self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]: - + self, image_input: Qwen2VLImageInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 if image_input["type"] == "image_embeds": - image_embeds = image_input["image_embeds"].type(self.visual.dtype) + image_embeds = image_input["image_embeds"] else: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + pixel_values = image_input["pixel_values"] + + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d" + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) # Split concatenated embeddings for each image item. 
merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - - return image_embeds.split(sizes.tolist()) + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() + return image_embeds.split(sizes) def _process_video_input( - self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]: - + self, video_input: Qwen2VLVideoInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 if video_input["type"] == "video_embeds": - video_embeds = video_input["video_embeds"].type(self.visual.dtype) + video_embeds = video_input["video_embeds"] else: - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + pixel_values_videos = video_input["pixel_values_videos"] + if self.use_data_parallel: + grid_thw_list = grid_thw.tolist() + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" + ) + else: + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size - sizes = grid_thw.prod(-1) // merge_size // merge_size - - return video_embeds.split(sizes.tolist()) + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() + return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: modalities = {} @@ -1245,26 +1508,26 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("pixel_values", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("pixel_values_videos", - "video_embeds") and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "videos" not in modalities + ): + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: - return None + return [] # The result multimodal_embeddings is tuple of tensors, with each # tensor correspoending to a multimodal data item (image or video). 
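# Sketch of the per-item split used above (standalone torch): after the spatial
# merge, each image or video contributes t*h*w / merge_size**2 embeddings, so
# the concatenated encoder output can be split back into one tensor per item.
import torch

merge_size = 2
grid_thw = torch.tensor([[1, 4, 6], [2, 4, 4]])                    # two items
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()   # [6, 8]
embeds = torch.randn(sum(sizes), 8)                                # toy hidden_size=8
per_item = embeds.split(sizes)
assert [e.shape[0] for e in per_item] == sizes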
@@ -1275,61 +1538,23 @@ def get_multimodal_embeddings( for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_video_input(video_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) - return inputs_embeds - - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - image_input: Optional[Qwen2VLImagePixelInputs] = None, - video_input: Optional[Qwen2VLVideoPixelInputs] = None, - ) -> torch.Tensor: - inputs_embeds = self.get_input_embeddings(input_ids) - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, - ) - - if video_input is not None: - video_embeds = self._process_video_input(video_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) - return inputs_embeds - def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for Qwen2-VL. Args: @@ -1339,40 +1564,14 @@ def forward( batch. **NOTE**: If mrope is enabled (default setting for Qwen2-VL opensource models), the shape will be `(3, seq_len)`, - otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. + otherwise it will be `(seq_len,)`. + intermediate_tensors: Intermediate tensors from prior forward pass. + inputs_embeds: Optional tensor of input embeddings. """ if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner from - # `get_multimodal_embeddings` and `get_input_embeddings`, this - # condition is only for v0 compatibility. 
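# Toy illustration (standalone torch) of the (3, seq_len) M-RoPE layout that the
# forward docstring refers to and that get_mrope_input_positions builds: text
# tokens advance the t/h/w rows together, while image tokens enumerate a
# t x h x w grid (h, w already divided by the spatial merge size), offset by the
# text length. Simplified: no vision_start bookkeeping, no per-video time scaling.
import torch

text_len, (t, h, w) = 2, (1, 2, 2)

text_pos = torch.arange(text_len).view(1, -1).expand(3, -1)
t_idx = torch.arange(t).view(-1, 1).expand(-1, h * w).flatten()
h_idx = torch.arange(h).view(1, -1, 1).expand(t, -1, w).flatten()
w_idx = torch.arange(w).view(1, 1, -1).expand(t, h, -1).flatten()
image_pos = torch.stack([t_idx, h_idx, w_idx]) + text_len

positions = torch.cat([text_pos, image_pos], dim=1)
assert positions.shape == (3, text_len + t * h * w)
assert positions[:, :2].tolist() == [[0, 1], [0, 1], [0, 1]]   # text part
assert positions[:, 2].tolist() == [2, 2, 2]                   # first image token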
- elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if image_input is None and video_input is None: - inputs_embeds = None - else: - if uses_mrope(self.config): - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - inputs_embeds = self.get_input_embeddings_v0( - input_ids, - image_input=image_input, - video_input=video_input) - input_ids = None - hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, @@ -1384,13 +1583,14 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - - loader = AutoWeightsLoader(self) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + skip_prefixes = [] + if self.visual is None: + skip_prefixes.extend(["visual."]) + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: @@ -1402,3 +1602,88 @@ def get_mm_mapping(self) -> MultiModelKeys: connector="visual.merger.", tower_model="visual.", ) + + +class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor): + pass + + +class Tarsier2ImageProcessor(Qwen2VLImageProcessor): + def __init__( + self, + size: dict[str, int] | None = None, + **kwargs, + ) -> None: + if size is not None and "min_pixels" in size and "max_pixels" in size: + # Remap if Tarsier2-specific format is provided + remapped_size = { + "shortest_edge": size["min_pixels"], + "longest_edge": size["max_pixels"], + } + super().__init__(size=remapped_size, **kwargs) + else: + super().__init__(size=size, **kwargs) + + +class Tarsier2Processor(Qwen2VLProcessor): + def __init__( + self, + vision_config: dict, + tokenizer: AnyTokenizer, + **kwargs, + ): + self.image_processor = Tarsier2ImageProcessor(**vision_config) + super().__init__( + image_processor=self.image_processor, + tokenizer=tokenizer, + video_processor=Qwen2VLVideoProcessor(**vision_config), + chat_template=None, + **kwargs, + ) + + +class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo): + def get_hf_config(self) -> Qwen2VLConfig: + model_path = self.ctx.model_config.model + correct_config = Qwen2VLConfig.from_pretrained(model_path) + + return correct_config + + def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor: + return Tarsier2Processor( + vision_config=self.ctx.get_hf_image_processor_config(), + tokenizer=self.get_tokenizer(), + **kwargs, + ) + + def get_image_processor(self) -> Tarsier2ImageProcessor: + return Tarsier2ImageProcessor(**self.ctx.get_hf_image_processor_config()) + + +@MULTIMODAL_REGISTRY.register_processor( + Tarsier2MultiModalProcessor, + info=Tarsier2ProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder, +) +class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "vision_tower.": "visual.", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig + # as text_config, we need to reconstruct Qwen2VLConfig from LlavaConfig. 
+ config = vllm_config.model_config.hf_config + qwen2vl_config = config.text_config + qwen2vl_config.architectures = config.architectures + vllm_config.model_config.hf_config = qwen2vl_config + super().__init__(vllm_config=vllm_config, prefix=prefix) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + skip_prefixes = [] + if self.visual is None: + skip_prefixes.extend(["visual."]) + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm_metax/models/qwen3_vl.py b/vllm_metax/models/qwen3_vl.py index 9fe673771..b4d0551b3 100644 --- a/vllm_metax/models/qwen3_vl.py +++ b/vllm_metax/models/qwen3_vl.py @@ -23,71 +23,100 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen3VL model compatible with HuggingFace weights.""" -from collections.abc import Iterable, Mapping, Sequence + +import math +from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial -from typing import Any, Callable, Optional, Union +from itertools import islice +from typing import Any import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BatchFeature +from transformers import BatchFeature, PretrainedConfig from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( - smart_resize as image_smart_resize) -from transformers.models.qwen3_vl import (Qwen3VLProcessor, - Qwen3VLVideoProcessor) + smart_resize as image_smart_resize, +) +from transformers.models.qwen3_vl import Qwen3VLProcessor, Qwen3VLVideoProcessor from transformers.models.qwen3_vl.configuration_qwen3_vl import ( - Qwen3VLConfig, Qwen3VLVisionConfig) + Qwen3VLConfig, + Qwen3VLVisionConfig, +) from transformers.models.qwen3_vl.video_processing_qwen3_vl import ( - smart_resize as video_smart_resize) + smart_resize as video_smart_resize, +) from transformers.video_utils import VideoMetadata +from vllm.attention.backends.registry import _Backend from vllm.attention.layer import check_upstream_fa_availability from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItem, - MultiModalKwargsItems, VideoItem) -from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from 
vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItem, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.config import uses_mrope -from vllm.utils import is_list_of - -from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, - SupportsLoRA, - SupportsMultiModal, - SupportsPP) -from .qwen2_5_vl import (Qwen2_5_VisionAttention, - Qwen2_5_VisionRotaryEmbedding, - Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs, - Qwen2_5_VLImagePixelInputs, - Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs, - Qwen2_5_VLVideoPixelInputs) +from vllm.utils.collection_utils import is_list_of + +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from .qwen2_5_vl import ( + Qwen2_5_VisionAttention, + Qwen2_5_VisionRotaryEmbedding, + Qwen2_5_VLImageEmbeddingInputs, + Qwen2_5_VLImageInputs, + Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLVideoEmbeddingInputs, + Qwen2_5_VLVideoInputs, + Qwen2_5_VLVideoPixelInputs, +) from vllm.model_executor.models.qwen2_vl import Qwen2VLProcessingInfo from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM, Qwen3Model -from vllm.model_executor.models.utils import (AutoWeightsLoader, - PPMissingLayer, WeightsMapper, - maybe_prefix, - merge_multimodal_embeddings) -from vllm.model_executor.models.vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) +from vllm.model_executor.models.vision import ( + conv3d_to_linear_weight, + get_vit_attn_backend, + run_dp_sharded_mrope_vision_model, +) logger = init_logger(__name__) @@ -96,7 +125,6 @@ class Qwen3_VisionPatchEmbed(nn.Module): - def __init__( self, patch_size: int = 14, @@ -110,45 +138,48 @@ def __init__( self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d(in_channels, - hidden_size, - kernel_size=kernel_size, - stride=kernel_size, - bias=True) + self.proj = ReplicatedLinear( + in_channels * math.prod(kernel_size), + hidden_size, + bias=True, + return_bias=False, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: - L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, - self.patch_size) - x = self.proj(x).view(L, self.hidden_size) + x = self.proj(x) return x class Qwen3_VisionMLP(nn.Module): - - def __init__(self, - in_features: int, - hidden_features: int, - bias: bool = False, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() - self.linear_fc1 = ColumnParallelLinear(in_features, - hidden_features, - bias=bias, - quant_config=quant_config, - 
return_bias=False, - prefix=f"{prefix}.linear_fc1", - disable_tp=use_data_parallel) - self.linear_fc2 = RowParallelLinear(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.linear_fc2", - disable_tp=use_data_parallel) + self.linear_fc1 = ColumnParallelLinear( + in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc1", + disable_tp=use_data_parallel, + ) + self.linear_fc2 = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc2", + disable_tp=use_data_parallel, + ) self.act_fn = act_fn def forward(self, x: torch.Tensor): @@ -157,15 +188,14 @@ def forward(self, x: torch.Tensor): class Qwen3_VisionBlock(nn.Module): - def __init__( self, dim: int, num_heads: int, mlp_hidden_dim: int, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, + norm_layer: Callable[[int], nn.Module] | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, attn_backend: _Backend = _Backend.TORCH_SDPA, @@ -184,43 +214,47 @@ def __init__( prefix=f"{prefix}.attn", use_data_parallel=use_data_parallel, attn_backend=attn_backend, - use_upstream_fa=use_upstream_fa) - self.mlp = Qwen3_VisionMLP(dim, - mlp_hidden_dim, - act_fn=act_fn, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) + use_upstream_fa=use_upstream_fa, + ) + self.mlp = Qwen3_VisionMLP( + dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: torch.Tensor, # Only used for Flash Attention + seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: - x = x + self.attn(self.norm1(x), - cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, - max_seqlen=max_seqlen, - seqlens=seqlens) + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) x = x + self.mlp(self.norm2(x)) return x class Qwen3_VisionPatchMerger(nn.Module): - def __init__( self, d_model: int, context_dim: int, - norm_layer: Optional[Callable[[int], nn.Module]] = None, + norm_layer: Callable[[int], nn.Module] | None = None, spatial_merge_size: int = 2, use_postshuffle_norm: bool = False, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ) -> None: @@ -234,19 +268,23 @@ def __init__( if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.norm = norm_layer(context_dim) - self.linear_fc1 = ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.linear_fc1", - disable_tp=use_data_parallel) + self.linear_fc1 = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc1", + 
disable_tp=use_data_parallel, + ) self.act_fn = nn.GELU() - self.linear_fc2 = RowParallelLinear(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.linear_fc2", - disable_tp=use_data_parallel) + self.linear_fc2 = RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc2", + disable_tp=use_data_parallel, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_postshuffle_norm: @@ -261,14 +299,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen3_VisionTransformer(nn.Module): - def __init__( self, vision_config: Qwen3VLVisionConfig, norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -284,8 +322,9 @@ def __init__( # NOTE: This is used for creating empty tensor for all_gather for # DP ViT. Here out_hidden_size is enlarged due to deepstack - self.out_hidden_size = (vision_config.out_hidden_size * - (1 + len(self.deepstack_visual_indexes))) + self.out_hidden_size = vision_config.out_hidden_size * ( + 1 + len(self.deepstack_visual_indexes) + ) self.patch_embed = Qwen3_VisionPatchEmbed( patch_size=self.patch_size, @@ -294,8 +333,7 @@ def __init__( hidden_size=self.hidden_size, ) - self.pos_embed = nn.Embedding(self.num_position_embeddings, - self.hidden_size) + self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads @@ -311,49 +349,57 @@ def __init__( use_data_parallel=use_data_parallel, ) - self.deepstack_merger_list = nn.ModuleList([ - Qwen3_VisionPatchMerger( - d_model=vision_config.out_hidden_size, - context_dim=self.hidden_size, - spatial_merge_size=self.spatial_merge_size, - use_postshuffle_norm=True, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", - use_data_parallel=use_data_parallel) - for layer_idx in range(len(self.deepstack_visual_indexes)) - ]) + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(len(self.deepstack_visual_indexes)) + ] + ) self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype()) - use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability( - torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - use_upstream_fa = True + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, + ) + # /--------------- Metax Modification ---------------\ + use_upstream_fa = True + # \--------------- Metax Modification ---------------/ if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( - f"Qwen3-VL does not support {self.attn_backend} backend now.") - - 
self.blocks = nn.ModuleList([ - Qwen3_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel, - attn_backend=self.attn_backend, - use_upstream_fa=use_upstream_fa) - for layer_idx in range(vision_config.depth) - ]) + f"Qwen3-VL does not support {self.attn_backend} backend now." + ) + self.blocks = nn.ModuleList( + [ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend=self.attn_backend, + use_upstream_fa=use_upstream_fa, + ) + for layer_idx in range(vision_config.depth) + ] + ) @property def dtype(self) -> torch.dtype: @@ -363,16 +409,10 @@ def dtype(self) -> torch.dtype: def device(self) -> torch.device: return self.patch_embed.proj.weight.device - def rot_pos_emb(self, grid_thw): + def rot_pos_emb(self, grid_thw: list[list[int]]): pos_ids = [] - # Support both Tensor and list inputs for DP path - if isinstance(grid_thw, list): - grid_list = grid_thw - max_grid_size = max(max(h, w) for _, h, w in grid_list) - else: - grid_list = grid_thw.tolist() - max_grid_size = int(grid_thw[:, 1:].max().item()) - for t, h, w in grid_list: + max_grid_size = max(max(h, w) for _, h, w in grid_thw) + for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) hpos_ids = hpos_ids.reshape( h // self.spatial_merge_size, @@ -392,32 +432,25 @@ def rot_pos_emb(self, grid_thw): ) wpos_ids = wpos_ids.permute(0, 2, 1, 3) wpos_ids = wpos_ids.flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb - def fast_pos_embed_interpolate(self, - grid_thw: list[list[int]]) -> torch.Tensor: - + def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: num_grid_per_side = self.num_grid_per_side m_size = self.spatial_merge_size hidden_dim = self.pos_embed.embedding_dim outputs = [] for t, h, w in grid_thw: - h_idxs = torch.linspace(0, - num_grid_per_side - 1, - h, - dtype=torch.float32, - device=self.device) - w_idxs = torch.linspace(0, - num_grid_per_side - 1, - w, - dtype=torch.float32, - device=self.device) + h_idxs = torch.linspace( + 0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device + ) + w_idxs = torch.linspace( + 0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device + ) h_floor = h_idxs.to(torch.long) w_floor = w_idxs.to(torch.long) @@ -428,15 +461,9 @@ def fast_pos_embed_interpolate(self, dw = w_idxs - w_floor # Create meshgrid view for all h, w vars - dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij') - h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, - w_floor, - indexing='ij') - h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, - w_ceil, - indexing='ij') - h_floor_grid_idx = h_floor_grid * num_grid_per_side - h_ceil_grid_idx = h_ceil_grid * num_grid_per_side + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij") + h_floor_grid, w_floor_grid = 
torch.meshgrid(h_floor, w_floor, indexing="ij") + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij") # original computation of weights # w00 = (1 - dh_grid) * (1 - dw_grid) @@ -448,30 +475,25 @@ def fast_pos_embed_interpolate(self, w11 = dh_grid * dw_grid w10 = dh_grid - w11 w01 = dw_grid - w11 - w00 = 1 - dh_grid - dw_grid + w11 + w00 = 1 - dh_grid - w01 - idx00 = h_floor_grid_idx + w_floor_grid - idx01 = h_floor_grid_idx + w_ceil_grid - idx10 = h_ceil_grid_idx + w_floor_grid - idx11 = h_ceil_grid_idx + w_ceil_grid + h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid]) + w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid]) + h_grid_idx = h_grid * num_grid_per_side - indices = torch.stack([idx00, idx01, idx10, idx11], - dim=0).reshape(4, -1) - weights = torch.stack([w00, w01, w10, w11], - dim=0).reshape(4, -1, 1) - weights = weights.to(dtype=self.dtype, device=self.device) + indices = (h_grid_idx + w_grid).reshape(4, -1) + weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) + weights = weights.to(dtype=self.dtype) embeds = self.pos_embed(indices) weighted_embeds = embeds * weights - p0, p1, p2, p3 = weighted_embeds.unbind(dim=0) - combined = p0 + p1 + p2 + p3 - - combined = combined.view(h * w, hidden_dim) - repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous() - repeated = repeated.view(t, h // m_size, m_size, w // m_size, - m_size, hidden_dim) - repeated = repeated.permute(0, 1, 3, 2, 4, - 5).reshape(-1, hidden_dim) + combined = weighted_embeds.sum(dim=0) + + combined = combined.reshape( + h // m_size, m_size, w // m_size, m_size, hidden_dim + ) + combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim) + repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim) outputs.append(repeated) return torch.cat(outputs, dim=0) @@ -479,64 +501,68 @@ def fast_pos_embed_interpolate(self, def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, - ) -> tuple[Optional[int], Optional[list[int]]]: - max_seqlen, seqlens = None, None - if self.attn_backend == _Backend.FLASH_ATTN: - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + ) -> tuple[torch.Tensor, torch.Tensor]: + max_seqlen = torch.zeros([], device=cu_seqlens.device) + seqlens = torch.zeros(1, device=cu_seqlens.device) + if ( + self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() elif self.attn_backend == _Backend.XFORMERS: - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens def forward( self, x: torch.Tensor, - grid_thw: list[list[int]], + grid_thw: torch.Tensor | list[list[int]], ) -> torch.Tensor: - hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True) hidden_states = self.patch_embed(hidden_states) - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) + if isinstance(grid_thw, list): + grid_thw_list = grid_thw + grid_thw = torch.tensor(grid_thw, dtype=torch.int32) + else: + grid_thw_list = grid_thw.tolist() - grid_thw_tensor = torch.tensor(grid_thw, - device=self.device, - dtype=torch.int32) + pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw_list) + rotary_pos_emb = 
rotary_pos_emb.to(hidden_states.device, non_blocking=True) cu_seqlens = torch.repeat_interleave( - grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], - grid_thw_tensor[:, 0]).cumsum( - dim=0, - dtype=grid_thw_tensor.dtype - if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) + cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) hidden_states = hidden_states.unsqueeze(1) - rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): - hidden_states = blk(hidden_states, - cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, - max_seqlen=max_seqlen, - seqlens=seqlens) + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) if layer_num in self.deepstack_visual_indexes: - deepstack_merger_idx = self.deepstack_visual_indexes.index( - layer_num) - deepstack_feature = self.deepstack_merger_list[ - deepstack_merger_idx](hidden_states) + deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num) + deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx]( + hidden_states + ) deepstack_feature_lists.append(deepstack_feature) hidden_states = self.merger(hidden_states) hidden_states = torch.cat( - [hidden_states] + deepstack_feature_lists, - dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + [hidden_states] + deepstack_feature_lists, dim=1 + ) # [seq_len, hidden_size * (1 + depth_of_deepstack)] return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("attn.qkv.", "attn.q.", "q"), @@ -547,7 +573,10 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) + + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -558,15 +587,13 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Qwen3VLConfig) @@ -580,8 +607,7 @@ def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor: def get_tokenizer(self): return self.ctx.tokenizer - def get_image_processor(self, - **kwargs: object) -> Qwen2VLImageProcessorFast: + def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast: return self.get_hf_processor(**kwargs).image_processor def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor: @@ -594,8 +620,7 @@ def _get_vision_info( image_height: int, num_frames: int = 2, do_resize: bool = 
True, - image_processor: Optional[Union[Qwen2VLImageProcessorFast, - Qwen3VLVideoProcessor]], + image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor | None, ) -> tuple[ImageSize, int]: if image_processor is None and num_frames > 1: image_processor = self.get_video_processor() @@ -615,7 +640,7 @@ def _get_vision_info( smart_resize = video_smart_resize extra_kwargs = { "num_frames": num_frames, - "temporal_factor": temporal_patch_size + "temporal_factor": temporal_patch_size, } else: smart_resize = image_smart_resize @@ -628,11 +653,9 @@ def _get_vision_info( max_pixels=image_processor.size["longest_edge"], **extra_kwargs, ) - preprocessed_size = ImageSize(width=resized_width, - height=resized_height) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: - preprocessed_size = ImageSize(width=image_width, - height=image_height) + preprocessed_size = ImageSize(width=image_width, height=image_height) padded_num_frames = num_frames + num_frames % temporal_patch_size @@ -645,11 +668,10 @@ def _get_vision_info( return preprocessed_size, num_vision_tokens - def _get_max_video_frames(self, - max_tokens: int, - start_num_frames: int = 2) -> int: - return super()._get_max_video_frames(max_tokens, - start_num_frames=start_num_frames) + def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 2) -> int: + return super()._get_max_video_frames( + max_tokens, start_num_frames=start_num_frames + ) def get_num_frames_with_most_features( self, @@ -657,7 +679,8 @@ def get_num_frames_with_most_features( mm_counts: Mapping[str, int], ) -> int: return super().get_num_frames_with_most_features( - seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO) + seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO + ) def get_max_video_tokens( self, @@ -668,8 +691,7 @@ def get_max_video_tokens( video_soft_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features( - seq_len, mm_counts), + num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), image_processor=None, ) @@ -678,25 +700,28 @@ def get_max_video_tokens( formatted_video_soft_tokens = video_soft_tokens * 12.5 return int(formatted_video_soft_tokens) - def _calculate_timestamps(self, indices: list[int] | torch.Tensor, - video_fps: float, merge_size: int): + def _calculate_timestamps( + self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int + ): if not isinstance(indices, list): indices = indices.tolist() if len(indices) % merge_size != 0: # don't update metadata's frames_indices directly - indices = indices + [indices[-1] - ] * (merge_size - len(indices) % merge_size) + indices = indices + [indices[-1]] * (merge_size - len(indices) % merge_size) timestamps = [idx / video_fps for idx in indices] - timestamps = [(timestamps[i] + timestamps[i + merge_size - 1]) / 2 - for i in range(0, len(timestamps), merge_size)] + timestamps = [ + (timestamps[i] + timestamps[i + merge_size - 1]) / 2 + for i in range(0, len(timestamps), merge_size) + ] return timestamps def _get_video_second_idx( - self, - metadata: dict[str, Any], - out_item: MultiModalKwargsItem, - do_sample_frames: Optional[bool] = None, - sampled_fps: Optional[float] = None) -> list[int]: + self, + metadata: dict[str, Any], + out_item: MultiModalKwargsItem, + do_sample_frames: bool | None = None, + sampled_fps: float | None = None, + ) -> list[int]: video_processor = self.get_video_processor() merge_size = 
video_processor.merge_size indices = metadata["frames_indices"] @@ -712,20 +737,27 @@ def _get_video_second_idx( if do_sample_frames: # here video_fps is the fps of the sampled video, and # metadata["fps"] refers to the fps of the original video. - video_fps = sampled_fps if sampled_fps else video_processor.fps + sampled_fps = sampled_fps if sampled_fps else video_processor.fps total_num_frames = metadata["total_num_frames"] - num_frames = int(total_num_frames / metadata["fps"] * video_fps) + num_frames = int(total_num_frames / metadata["fps"] * sampled_fps) num_frames = min( - min(max(num_frames, video_processor.min_frames), - video_processor.max_frames), total_num_frames) - indices = np.linspace(0, total_num_frames - 1, - num_frames).round().astype(int).tolist() + min( + max(num_frames, video_processor.min_frames), + video_processor.max_frames, + ), + total_num_frames, + ) + indices = ( + np.linspace(0, total_num_frames - 1, num_frames) + .round() + .astype(int) + .tolist() + ) timestamps = self._calculate_timestamps(indices, video_fps, merge_size) return timestamps class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -739,29 +771,80 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None - target_width, target_height = ( - self.info.get_image_size_with_most_features()) + target_width, target_height = self.info.get_image_size_with_most_features() target_num_frames = self.info.get_num_frames_with_most_features( - seq_len, mm_counts) + seq_len, mm_counts + ) + + if video_overrides: + assert isinstance(video_overrides, VideoDummyOptions) + num_frames_override = video_overrides.num_frames + if num_frames_override: + if num_frames_override > target_num_frames: + logger.warning( + "video.num_frames override (%d) exceeds model's " + "maximum number of frames (%d), will be ignored", + num_frames_override, + target_num_frames, + ) + if num_frames_override < 2: + logger.warning( + "video.num_frames override (%d) cannot be less " + "than 2, will be ignored", + num_frames_override, + ) + target_num_frames = min(target_num_frames, num_frames_override) + target_num_frames = max(target_num_frames, 2) + target_video_size, _ = self.info._get_vision_info( image_width=target_width, image_height=target_height, num_frames=target_num_frames, image_processor=self.info.get_video_processor(), ) + # NOTE: we need to do this check here since Qwen3-VL resizes video + # frames depending on how many frames there are. 
+ width, height = target_video_size.width, target_video_size.height + if video_overrides: + assert isinstance(video_overrides, VideoDummyOptions) + width_override = video_overrides.width + if width_override: + if width_override > width: + logger.warning( + "video.width override (%d) exceeds model's " + "maximum width (%d), will be ignored", + width_override, + width, + ) + width = min(width, width_override) + height_override = video_overrides.height + if height_override: + if height_override > height: + logger.warning( + "video.height override (%d) exceeds model's " + "maximum height (%d), will be ignored", + height_override, + height, + ) + height = min(height, height_override) + return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), - "video": - self._get_dummy_videos( - width=target_video_size.width, - height=target_video_size.height, + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( + width=width, + height=height, num_frames=target_num_frames, num_videos=num_videos, ), @@ -775,7 +858,6 @@ def _get_dummy_videos( num_frames: int, num_videos: int, ) -> list[VideoItem]: - num_frames = max(num_frames, 2) video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) video_items = [] for i in range(num_videos): @@ -791,22 +873,8 @@ def _get_dummy_videos( video_items.append(video_item) return video_items - def get_dummy_processor_inputs(self, seq_len, mm_counts): - processor_inputs = super().get_dummy_processor_inputs( - seq_len, mm_counts) - # HACK(Isotr0py): We set do_resize to False here to reuse Qwen2-VL's - # profiling logic, which will be problematic for configurable mm - # profiling. - # TODO(Isotr0py): Switch to the implementation in - # https://github.com/vllm-project/vllm/pull/25557 - # after supporting configurable mm profiling. - processor_inputs.hf_processor_mm_kwargs = {"do_resize": False} - return processor_inputs - - -class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo] - ): +class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: return MultiModalDataParser(video_needs_metadata=True) @@ -821,13 +889,12 @@ def _call_hf_processor( processor = self.info.get_hf_processor(**mm_kwargs) # Separate video processing from image processing. Because the videos - # are processed into serval image patches - if ("videos" in mm_data and isinstance(mm_data["videos"], list) - and len(mm_data["videos"]) > 0): + # are processed into several image patches + if videos := mm_data.pop("videos", []): video_grid_thw_lst = [] pixel_values_videos_lst = [] - for item_idx, item in enumerate(mm_data.pop("videos", [])): + for item in videos: video_array, metadata = item # NOTE: @JJJYmmm new attr metadata.frames_indices indicates @@ -842,12 +909,12 @@ def _call_hf_processor( # qwen_vl_utils already has "do_sample_frames" in # mm_kwargs, don't overwrite it. 
video_mm_kwargs["do_sample_frames"] = metadata.get( - "do_sample_frames", False) + "do_sample_frames", False + ) - metadata = VideoMetadata(**{ - k: metadata[k] - for k in metadata if k != "do_sample_frames" - }) + metadata = VideoMetadata( + **{k: metadata[k] for k in metadata if k != "do_sample_frames"} + ) video_mm_data = dict() video_mm_data["videos"] = [[video_array]] @@ -860,8 +927,7 @@ def _call_hf_processor( tok_kwargs=tok_kwargs, ) input_ids = video_outputs.pop("input_ids") - video_placeholder = processor.tokenizer.batch_decode( - input_ids)[0] + video_placeholder = processor.tokenizer.batch_decode(input_ids)[0] prompt = prompt.replace( "<|vision_start|><|video_pad|><|vision_end|>", video_placeholder, @@ -869,8 +935,7 @@ def _call_hf_processor( ) video_grid_thw_lst.append(video_outputs["video_grid_thw"]) - pixel_values_videos_lst.append( - video_outputs["pixel_values_videos"]) + pixel_values_videos_lst.append(video_outputs["pixel_values_videos"]) video_outputs = dict( pixel_values_videos=torch.cat(pixel_values_videos_lst), video_grid_thw=torch.cat(video_grid_thw_lst), @@ -903,14 +968,18 @@ def _get_mm_fields_config( return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), + "image", image_grid_sizes + ), image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), + "image", image_grid_sizes + ), image_grid_thw=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), + "video", video_grid_sizes + ), video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), + "video", video_grid_sizes + ), video_grid_thw=MultiModalFieldConfig.batched("video"), ) @@ -921,8 +990,7 @@ def _get_prompt_updates( out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() hf_config = self.info.get_hf_config() @@ -951,26 +1019,28 @@ def get_video_replacement_qwen3vl(item_idx: int): if is_list_of(sampled_fps, float): sampled_fps = sampled_fps[item_idx] timestamps = self.info._get_video_second_idx( - metadata, out_item, do_sample_frames, sampled_fps) + metadata, out_item, do_sample_frames, sampled_fps + ) assert len(timestamps) == grid_thw[0], ( f"The timestamps length({len(timestamps)}) should be equal " - f"video length ({grid_thw[0]}).") + f"video length ({grid_thw[0]})." 
+ ) frames_idx_token = [ - tokenizer.encode(f"<{curr_time:.1f} seconds>", - add_special_tokens=False) + tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False) for curr_time in timestamps ] num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length placeholder = [] for frame_idx in frames_idx_token: placeholder.extend(frame_idx) - placeholder.extend([vision_start_token_id] + - [video_token_id] * num_tokens_per_frame + - [vision_end_token_id]) - return PromptUpdateDetails.select_token_id(placeholder, - video_token_id) + placeholder.extend( + [vision_start_token_id] + + [video_token_id] * num_tokens_per_frame + + [vision_end_token_id] + ) + return PromptUpdateDetails.select_token_id(placeholder, video_token_id) return [ PromptReplacement( @@ -978,7 +1048,6 @@ def get_video_replacement_qwen3vl(item_idx: int): target=hf_processor.image_token, replacement=get_image_replacement_qwen3vl, ), - # NOTE: We match string on purpose since searching sequence of # token ids takes more time. PromptReplacement( @@ -998,28 +1067,29 @@ def get_video_replacement_qwen3vl(item_idx: int): "intermediate_tensors": 0, "inputs_embeds": 0, # the same shape as input_embeds - "deepstack_input_embeds": 0 - }) + "deepstack_input_embeds": 0, + } +) class Qwen3LLMModel(Qwen3Model): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) if not get_pp_group().is_first_rank: assert self.start_layer >= len( - vllm_config.model_config.hf_config.vision_config. - deepstack_visual_indexes), ( - "start_layer should be greater than or equal to " - "len(deepstack_visual_indexes)") + vllm_config.model_config.hf_config.vision_config.deepstack_visual_indexes + ), ( + "start_layer should be greater than or equal to " + "len(deepstack_visual_indexes)" + ) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, # args for deepstack - deepstack_input_embeds: Optional[IntermediateTensors] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + deepstack_input_embeds: IntermediateTensors | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -1030,32 +1100,32 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer_idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): - layer_idx = layer_idx + self.start_layer - + for layer_idx, layer in islice( + enumerate(self.layers), self.start_layer, self.end_layer + ): hidden_states, residual = layer( positions, hidden_states, residual, ) - if deepstack_input_embeds is not None and \ - layer_idx in range(0, len(deepstack_input_embeds)): - hidden_states = hidden_states + deepstack_input_embeds[ - f"deepstack_input_embeds_{layer_idx}"] + if deepstack_input_embeds is not None and layer_idx in range( + 0, len(deepstack_input_embeds) + ): + hidden_states = ( + hidden_states + + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"] + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) 
hidden_states, _ = self.norm(hidden_states, residual) return hidden_states class Qwen3LLMForCausalLM(Qwen3ForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super(Qwen3ForCausalLM, self).__init__() config = vllm_config.model_config.hf_config.text_config @@ -1072,24 +1142,33 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix="lm_head") + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix="lm_head", + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3VLMultiModalProcessor, + info=Qwen3VLProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder, +) +class Qwen3VLForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE +): + merge_by_field_config = True + multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"} -@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, - info=Qwen3VLProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder) -class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1110,10 +1189,11 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, "model.visual.": "visual.", "lm_head.": "language_model.lm_head.", "model.language_model.": "language_model.model.", - }) + } + ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|vision_start|><|image_pad|><|vision_end|>" if modality.startswith("video"): @@ -1130,37 +1210,46 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - if not multimodal_config.get_limit_per_prompt("image") and \ - not multimodal_config.get_limit_per_prompt("video"): + if not multimodal_config.get_limit_per_prompt( + "image" + ) and not multimodal_config.get_limit_per_prompt("video"): self.visual = None else: + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.visual = Qwen3_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), use_data_parallel=self.use_data_parallel, + attn_backend_override=attn_backend_override, ) - self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, - "language_model")) + self.language_model = Qwen3LLMForCausalLM( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - self.use_deepstack = hasattr(config.vision_config, - 'deepstack_visual_indexes') - self.deepstack_num_level = len( - config.vision_config.deepstack_visual_indexes - ) if self.use_deepstack 
else 0 + self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes") + self.deepstack_num_level = ( + len(config.vision_config.deepstack_visual_indexes) + if self.use_deepstack + else 0 + ) # register buffer for deepstack if self.use_deepstack and self.visual is not None: self.deepstack_input_embeds = [ torch.zeros( vllm_config.scheduler_config.max_num_batched_tokens, - config.text_config.hidden_size) + config.text_config.hidden_size, + ) for _ in range(self.deepstack_num_level) ] else: @@ -1168,30 +1257,34 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.visual_dim = config.vision_config.out_hidden_size self.multiscale_dim = self.visual_dim * self.deepstack_num_level - def _get_deepstack_input_embeds(self, - num_tokens: int) -> IntermediateTensors: + def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors: # get deepstack_input_embeds from buffer, and clear the buffer - return IntermediateTensors({ - f"deepstack_input_embeds_{idx}": - self.deepstack_input_embeds[idx][:num_tokens] - for idx in range(self.deepstack_num_level) - }) - - def _set_deepstack_input_embeds( - self, deepstack_input_embeds: torch.Tensor) -> None: + return IntermediateTensors( + { + f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][ + :num_tokens + ] + for idx in range(self.deepstack_num_level) + } + ) + + def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None: # set deepstack_input_embeds to buffer num_tokens = deepstack_input_embeds.size(1) if num_tokens > self.deepstack_input_embeds[0].size(0): self.deepstack_input_embeds = [ - torch.zeros(num_tokens, - self.config.text_config.hidden_size, - device=self.deepstack_input_embeds[0].device, - dtype=self.deepstack_input_embeds[0].dtype) + torch.zeros( + num_tokens, + self.config.text_config.hidden_size, + device=self.deepstack_input_embeds[0].device, + dtype=self.deepstack_input_embeds[0].dtype, + ) for _ in range(self.deepstack_num_level) ] for idx in range(self.deepstack_num_level): self.deepstack_input_embeds[idx][:num_tokens].copy_( - deepstack_input_embeds[idx]) + deepstack_input_embeds[idx] + ) def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: # clear deepstack_input_embeds in buffer @@ -1199,24 +1292,9 @@ def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: for idx in range(self.deepstack_num_level): self.deepstack_input_embeds[idx][:num_tokens].zero_() - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if mm_input.ndim == 2: - return mm_input - if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") - return torch.concat(list(mm_input)) - else: - return torch.concat(mm_input) - def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]: + self, **kwargs: object + ) -> Qwen2_5_VLImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1225,35 +1303,22 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") - - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}") - - return Qwen2_5_VLImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) + return Qwen2_5_VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) if image_embeds is not None: - image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") - - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, - image_grid_thw=image_grid_thw) + image_grid_thw=image_grid_thw, + ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]: + self, **kwargs: object + ) -> Qwen2_5_VLVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1263,11 +1328,6 @@ def _parse_and_validate_video_input( return None if pixel_values_videos is not None: - pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, @@ -1276,97 +1336,192 @@ def _parse_and_validate_video_input( ) if video_embeds is not None: - video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - - if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. 
" - f"Got type: {type(video_embeds)}") return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, - video_grid_thw=video_grid_thw) + video_grid_thw=video_grid_thw, + ) def _process_image_input( - self, - image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: - + self, image_input: Qwen2_5_VLImageInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 - grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values, - grid_thw_list, - rope_type="rope_3d") + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d" + ) else: - image_embeds = self.visual(pixel_values, - grid_thw=grid_thw_list) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) # Split concatenated embeddings for each image item. - # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync merge_size = self.visual.spatial_merge_size - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return image_embeds.split(sizes) def _process_video_input( - self, - video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: - + self, video_input: Qwen2_5_VLVideoInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 - grid_thw_list = grid_thw.tolist() if video_input["type"] == "video_embeds": video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) + self.visual.dtype + ) if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values_videos, - grid_thw_list, - rope_type="rope_3d") + grid_thw_list = grid_thw.tolist() + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" + ) else: - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw_list) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) # Split concatenated embeddings for each video item. 
- # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync merge_size = self.visual.spatial_merge_size - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} for input_key in kwargs: - if input_key in ("pixel_values", "image_embeds" - ) and "image" not in mm_input_by_modality: - mm_input_by_modality[ - "image"] = self._parse_and_validate_image_input(**kwargs) - if input_key in ("pixel_values_videos", "video_embeds" - ) and "video" not in mm_input_by_modality: - mm_input_by_modality[ - "video"] = self._parse_and_validate_video_input(**kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) return mm_input_by_modality + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + second_per_grid_ts: list[float] | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)] + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + 
) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + + return llm_positions, mrope_position_delta + def get_language_model(self) -> torch.nn.Module: return self.language_model def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + self, **kwargs: object + ) -> MultiModalEmbeddings | None: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return None @@ -1379,150 +1534,110 @@ def get_multimodal_embeddings( for modality in mm_input_by_modality: multimodal_input = mm_input_by_modality[modality] if modality == "image": - vision_embeddings = self._process_image_input(multimodal_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings def _compute_deepstack_embeds( - self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor, - multimodal_embeddings: MultiModalEmbeddings) -> torch.Tensor: - visual_lens = [ - x.shape[0] if isinstance(x, torch.Tensor) else len(x) - for x in multimodal_embeddings - ] + self, + inputs_embeds: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + is_multimodal: torch.Tensor, + ) -> tuple[torch.Tensor, MultiModalEmbeddings]: + visual_lens = [len(x) for x in multimodal_embeddings] multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0) - multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split( # noqa:E501 - multimodal_embeddings_cat, [self.visual_dim, self.multiscale_dim], - dim=-1) + ( + multimodal_embeddings_main, + multimodal_embeddings_multiscale, + ) = torch.split( + multimodal_embeddings_cat, + [self.visual_dim, self.multiscale_dim], + dim=-1, + ) - multimodal_embeddings = torch.split(multimodal_embeddings_main, - visual_lens, - dim=0) + multimodal_embeddings = torch.split( + multimodal_embeddings_main, visual_lens, dim=0 + ) multimodal_embeddings_multiscale = torch.split( - multimodal_embeddings_multiscale, visual_lens, dim=0) + multimodal_embeddings_multiscale, visual_lens, dim=0 + ) deepstack_input_embeds = inputs_embeds.new_zeros( - inputs_embeds.size(0), - self.deepstack_num_level * inputs_embeds.size(1)) + inputs_embeds.size(0), self.deepstack_num_level * inputs_embeds.size(1) + ) - deepstack_input_embeds = merge_multimodal_embeddings( - input_ids, - deepstack_input_embeds, - multimodal_embeddings_multiscale, - placeholder_token_id=[ - self.config.image_token_id, self.config.video_token_id - ], + deepstack_input_embeds = 
_merge_multimodal_embeddings( + inputs_embeds=deepstack_input_embeds, + multimodal_embeddings=multimodal_embeddings_multiscale, + is_multimodal=is_multimodal, ) deepstack_input_embeds = deepstack_input_embeds.view( - inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim) + inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim + ) deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2) + return deepstack_input_embeds, multimodal_embeddings def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, ) -> torch.Tensor: - deepstack_input_embeds = None - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - if self.use_deepstack: - deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501 - input_ids, inputs_embeds, multimodal_embeddings) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - [self.config.image_token_id, self.config.video_token_id]) - - if self.use_deepstack: - if deepstack_input_embeds is None: - deepstack_input_embeds = torch.zeros_like( - inputs_embeds).unsqueeze(0).repeat( - self.deepstack_num_level, 1, 1).contiguous() - self._set_deepstack_input_embeds(deepstack_input_embeds) + inputs_embeds = self._get_text_embeddings( + input_ids, + self.language_model.get_input_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) - return inputs_embeds + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - image_input: Optional[Qwen2_5_VLImageInputs] = None, - video_input: Optional[Qwen2_5_VLVideoInputs] = None, - ) -> torch.Tensor: - inputs_embeds = self.get_input_embeddings(input_ids) + if is_multimodal is None: + raise ValueError( + "`get_input_embeddings` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229." 
+ ) if self.use_deepstack: - visual_dim = inputs_embeds.shape[-1] - deepstack_input_embeds = None - if image_input is not None or video_input is not None: - deepstack_input_embeds = torch.zeros_like( - inputs_embeds).unsqueeze(1).repeat( - 1, self.deepstack_num_level, 1).flatten(1) - - if image_input is not None: - image_embeds = self._process_image_input(image_input) - if self.use_deepstack: - image_embeds = torch.cat(image_embeds) - - image_embeds, image_embeds_multiscale = image_embeds.split( - [visual_dim, visual_dim * self.deepstack_num_level], - dim=-1) - - deepstack_input_embeds = merge_multimodal_embeddings( - input_ids, - deepstack_input_embeds, - image_embeds_multiscale, - placeholder_token_id=self.config.image_token_id, - ) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, + ( + deepstack_input_embeds, + multimodal_embeddings, + ) = self._compute_deepstack_embeds( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, ) + else: + deepstack_input_embeds = None - if video_input is not None: - video_embeds = self._process_video_input(video_input) - if self.use_deepstack: - video_embeds = torch.cat(video_embeds) - - video_embeds, video_embeds_multiscale = video_embeds.split( - [visual_dim, visual_dim * self.deepstack_num_level], - dim=-1) - - deepstack_input_embeds = merge_multimodal_embeddings( - input_ids, - deepstack_input_embeds, - video_embeds_multiscale, - placeholder_token_id=self.config.video_token_id, - ) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) + inputs_embeds = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) - if self.use_deepstack and deepstack_input_embeds is not None: - deepstack_input_embeds = deepstack_input_embeds.view( - inputs_embeds.shape[0], self.deepstack_num_level, - visual_dim).permute(1, 0, 2).contiguous() + if deepstack_input_embeds is not None: self._set_deepstack_input_embeds(deepstack_input_embeds) + return inputs_embeds def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for Qwen3VL. Args: @@ -1550,30 +1665,14 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - # NOTE: In v1, inputs_embeds is always generated at model runner from - # `get_multimodal_embeddings` and `get_input_embeddings`, this - # condition is only for v0 compatibility. 
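Note on the `get_mrope_input_positions` helper added above: it returns a `(3, seq_len)` tensor of temporal/height/width position indices plus a scalar delta used to offset positions during decoding. The sketch below reproduces the core index construction for a single image that follows a text prefix. It is a simplified illustration only; the grid shape, `spatial_merge_size`, and text length are assumed values, not taken from this patch.

```python
import torch

spatial_merge_size = 2                       # assumed merge factor
t, h, w = 1, 4, 4                            # assumed image grid_thw before merging
llm_t, llm_h, llm_w = t, h // spatial_merge_size, w // spatial_merge_size

text_len = 5                                 # assumed number of text tokens before the image
# Text tokens: all three rows share the same monotonically increasing index.
text_pos = torch.arange(text_len).view(1, -1).expand(3, -1)

# Vision tokens: enumerate the merged grid, offset by the preceding text length.
t_idx = torch.arange(llm_t).view(-1, 1).expand(-1, llm_h * llm_w).flatten()
h_idx = torch.arange(llm_h).view(1, -1, 1).expand(llm_t, -1, llm_w).flatten()
w_idx = torch.arange(llm_w).view(1, 1, -1).expand(llm_t, llm_h, -1).flatten()
image_pos = torch.stack([t_idx, h_idx, w_idx]) + text_len

positions = torch.cat([text_pos, image_pos], dim=1)   # shape (3, seq_len)
seq_len = text_len + llm_t * llm_h * llm_w
mrope_delta = (positions.max() + 1 - seq_len).item()  # offset applied at decode time
print(positions.shape, mrope_delta)                   # torch.Size([3, 9]) -2
```

Text tokens advance all three rows together, vision tokens enumerate the merged `(t, h, w)` grid, and the delta records how far the largest position runs ahead of (or behind) the raw token count.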
- elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if image_input is None and video_input is None: - inputs_embeds = None - else: - if uses_mrope(self.config): - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - inputs_embeds = self.get_input_embeddings_v0( - input_ids, - image_input=image_input, - video_input=video_input) - input_ids = None - - if self.use_deepstack and inputs_embeds is not None and get_pp_group( - ).is_first_rank: + if ( + self.use_deepstack + and inputs_embeds is not None + and get_pp_group().is_first_rank + ): deepstack_input_embeds = self._get_deepstack_input_embeds( - inputs_embeds.size(0)) + inputs_embeds.size(0) + ) else: deepstack_input_embeds = None @@ -1594,12 +1693,10 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.visual is None: skip_prefixes.extend(["visual."]) @@ -1612,6 +1709,6 @@ def get_mm_mapping(self) -> MultiModelKeys: """ return MultiModelKeys.from_string_field( language_model="language_model", - connector="model.visual.merger", - tower_model="model.visual.", + connector="visual.merger", + tower_model="visual.", ) diff --git a/vllm_metax/ops/activation.py b/vllm_metax/ops/activation.py index d35a932d7..5ece9cd65 100644 --- a/vllm_metax/ops/activation.py +++ b/vllm_metax/ops/activation.py @@ -1,61 +1,59 @@ # SPDX-License-Identifier: Apache-2.0 -from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, - GeluAndMul, MulAndSilu, - NewGELU, QuickGELU, - SiluAndMul, SwigluOAIAndMul) +from vllm.model_executor.layers.activation import ( + FastGELU, + FatreluAndMul, + GeluAndMul, + MulAndSilu, + NewGELU, + QuickGELU, + SiluAndMul, + SwigluOAIAndMul, +) @FatreluAndMul.register_oot class MacaFatreluAndMul(FatreluAndMul): - def forward_oot(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) @SiluAndMul.register_oot class MacaSiluAndMul(SiluAndMul): - def forward_oot(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) @MulAndSilu.register_oot class MacaMulAndSilu(MulAndSilu): - def forward_oot(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) @GeluAndMul.register_oot class MacaGeluAndMul(GeluAndMul): - def forward_oot(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) @SwigluOAIAndMul.register_oot class MacaSwigluOAIAndMul(SwigluOAIAndMul): - def forward_oot(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) @NewGELU.register_oot class MacaNewGELU(NewGELU): - def forward_oot(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) @FastGELU.register_oot class MacaFastGELU(FastGELU): - def forward_oot(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) @QuickGELU.register_oot class MacaQuickGELU(QuickGELU): - def forward_oot(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) diff --git a/vllm_metax/ops/fused_moe.py b/vllm_metax/ops/fused_moe.py index 1181607f1..2e8dead3d 100644 --- a/vllm_metax/ops/fused_moe.py +++ b/vllm_metax/ops/fused_moe.py @@ -1,13 +1,85 @@ # 
SPDX-License-Identifier: Apache-2.0 from vllm.model_executor.layers.fused_moe.layer import ( - UnquantizedFusedMoEMethod) + UnquantizedFusedMoEMethod, + FusedMoE, +) + +from typing import Callable +import torch from vllm_metax.model_executor.layers.fused_moe.fused_moe import fused_experts @UnquantizedFusedMoEMethod.register_oot class MacaUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): + def forward_oot( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) + + topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + global_num_experts=global_num_experts, + zero_expert_num=zero_expert_num, + zero_expert_type=zero_expert_type, + num_fused_shared_experts=layer.num_fused_shared_experts, + ) + + result = fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + quant_config=self.moe_quant_config, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) - def __init__(self, moe): - super().__init__(moe) - self.fused_experts = fused_experts # type: ignore + if zero_expert_num != 0 and zero_expert_type is not None: + assert not isinstance(result, tuple), ( + "Shared + zero experts are mutually exclusive not yet supported" + ) + return result, zero_expert_result + else: + return result diff --git a/vllm_metax/ops/layernorm.py b/vllm_metax/ops/layernorm.py index e1158ec75..dcaa11c73 100644 --- a/vllm_metax/ops/layernorm.py +++ b/vllm_metax/ops/layernorm.py @@ -4,13 +4,11 @@ @RMSNorm.register_oot class MacaRMSNorm(RMSNorm): - def forward_oot(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) @GemmaRMSNorm.register_oot class MacaGemmaRMSNorm(GemmaRMSNorm): - def forward_oot(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) diff --git a/vllm_metax/ops/rotary_embedding.py b/vllm_metax/ops/rotary_embedding.py index ca124cf3c..ed1675e6d 100644 --- a/vllm_metax/ops/rotary_embedding.py +++ b/vllm_metax/ops/rotary_embedding.py @@ -4,6 +4,5 @@ 
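For the `MacaUnquantizedFusedMoEMethod.forward_oot` override above, `FusedMoE.select_experts` produces per-token top-k expert ids and routing weights that the Metax `fused_experts` kernel then consumes. The plain-PyTorch sketch below shows what softmax top-k routing with `renormalize=True` computes; the sizes are made up, it skips grouped top-k, EPLB, and zero-expert handling, and it is not the fused kernel path.

```python
import torch

num_tokens, num_experts, top_k = 4, 8, 2            # assumed sizes
router_logits = torch.randn(num_tokens, num_experts)

# Softmax scoring followed by per-token top-k selection.
scores = router_logits.softmax(dim=-1)
topk_weights, topk_ids = scores.topk(top_k, dim=-1)

# renormalize=True rescales the selected weights to sum to 1 per token.
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

print(topk_ids.shape, topk_weights.sum(dim=-1))      # (4, 2), all ones
```

In the real layer the combine step (the weighted sum of each token's selected expert outputs) happens inside `fused_experts` rather than in eager PyTorch.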
@RotaryEmbedding.register_oot class MacaRotaryEmbedding(RotaryEmbedding): - def forward_oot(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) diff --git a/vllm_metax/patch/__init__.py b/vllm_metax/patch/__init__.py index 27a76f597..3c81ff74c 100644 --- a/vllm_metax/patch/__init__.py +++ b/vllm_metax/patch/__init__.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # isort: skip_file -from . import hotfix from . import maca_visible_device from . import distributed from . import device_allocator from . import model_executor from . import oot +from . import sample diff --git a/vllm_metax/patch/device_allocator/device_allocator.py b/vllm_metax/patch/device_allocator/device_allocator.py index a1d7b0ad5..16b7e2a66 100644 --- a/vllm_metax/patch/device_allocator/device_allocator.py +++ b/vllm_metax/patch/device_allocator/device_allocator.py @@ -4,9 +4,84 @@ logger = init_logger(__name__) -import vllm.device_allocator.cumem +from contextlib import AbstractContextManager, nullcontext +from vllm.utils.mem_constants import GiB_bytes -from vllm_metax.device_allocator.cumem import (CuMemAllocator as - mx_CuMemAllocator) +import torch +from vllm.v1.worker import worker_base +from vllm.v1.kv_cache_interface import KVCacheConfig -vllm.device_allocator.cumem.CuMemAllocator = mx_CuMemAllocator + +def sleep(self, level: int = 1) -> None: + from vllm_metax.device_allocator.cumem import CuMemAllocator + + free_bytes_before_sleep = torch.cuda.mem_get_info()[0] + + # Save the buffers before level 2 sleep + if level == 2: + model = self.model_runner.model + self._sleep_saved_buffers = { + name: buffer.cpu().clone() for name, buffer in model.named_buffers() + } + + allocator = CuMemAllocator.get_instance() + allocator.sleep(offload_tags=("weights",) if level == 1 else tuple()) + free_bytes_after_sleep, total = torch.cuda.mem_get_info() + freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep + used_bytes = total - free_bytes_after_sleep + assert freed_bytes >= 0, "Memory usage increased after sleeping." + logger.info( + "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.", + freed_bytes / GiB_bytes, + used_bytes / GiB_bytes, + ) + + +def wake_up(self, tags: list[str] | None = None) -> None: + from vllm_metax.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + allocator.wake_up(tags) + + # Restore the buffers after level 2 sleep + if len(self._sleep_saved_buffers): + model = self.model_runner.model + for name, buffer in model.named_buffers(): + if name in self._sleep_saved_buffers: + buffer.data.copy_(self._sleep_saved_buffers[name].data) + self._sleep_saved_buffers = {} + + +def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: + if self.vllm_config.model_config.enable_sleep_mode: + from vllm_metax.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + if tag == "weights": + assert allocator.get_current_usage() == 0, ( + "Sleep mode can only be used for one instance per process." 
+ ) + context = allocator.use_memory_pool(tag=tag) + else: + context = nullcontext() + return context + + +def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: + """Allocate GPU KV cache with the specified kv_cache_config.""" + + if self.vllm_config.model_config.enable_sleep_mode: + from vllm_metax.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + context = allocator.use_memory_pool(tag="kv_cache") + else: + context = nullcontext() + with context: + self.model_runner.initialize_kv_cache(kv_cache_config) + + +worker_base.sleep = sleep +worker_base.wake_up = wake_up +worker_base._maybe_get_memory_pool_context = _maybe_get_memory_pool_context +worker_base.initialize_from_config = initialize_from_config diff --git a/vllm_metax/patch/distributed/__init__.py b/vllm_metax/patch/distributed/__init__.py index 230fbf397..6c0ba04c2 100644 --- a/vllm_metax/patch/distributed/__init__.py +++ b/vllm_metax/patch/distributed/__init__.py @@ -1,3 +1,3 @@ # SPDX-License-Identifier: Apache-2.0 -from . import (cuda_wrapper, pynccl_wrapper, utils_patch) +from . import cuda_wrapper, pynccl_wrapper, utils_patch diff --git a/vllm_metax/patch/distributed/cuda_wrapper.py b/vllm_metax/patch/distributed/cuda_wrapper.py index 7071e5a9f..22090ebca 100644 --- a/vllm_metax/patch/distributed/cuda_wrapper.py +++ b/vllm_metax/patch/distributed/cuda_wrapper.py @@ -42,9 +42,8 @@ def find_loaded_library(lib_name) -> Optional[str]: the file `/proc/self/maps` contains the memory maps of the process, which includes the shared libraries loaded by the process. We can use this file to find the path of the a loaded library. - """ # noqa - logger.info( - f"[Plugin] Hooked find_loaded_library -> {find_loaded_library}") + """ # noqa + logger.info(f"[Plugin] Hooked find_loaded_library -> {find_loaded_library}") found = False with open("/proc/self/maps") as f: @@ -60,8 +59,9 @@ def find_loaded_library(lib_name) -> Optional[str]: start = line.index("/") path = line[start:].strip() filename = path.split("/")[-1] - assert filename.rpartition(".so")[0].startswith(lib_name), \ + assert filename.rpartition(".so")[0].startswith(lib_name), ( f"Unexpected filename: {filename} for library {lib_name}" + ) return path @@ -73,30 +73,36 @@ class CudaRTLibrary: Function("mcDeviceSynchronize", cudaError_t, []), # ​cudaError_t cudaDeviceReset ( void ) Function("mcDeviceReset", cudaError_t, []), - # const char* cudaGetErrorString ( cudaError_t error ) Function("mcGetErrorString", ctypes.c_char_p, [cudaError_t]), - # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) - Function("mcMalloc", cudaError_t, - [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), + Function( + "mcMalloc", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t] + ), # ​cudaError_t cudaFree ( void* devPtr ) Function("mcFree", cudaError_t, [ctypes.c_void_p]), # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) - Function("mcMemset", cudaError_t, - [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), + Function( + "mcMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t] + ), # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa - Function("mcMemcpy", cudaError_t, [ - ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind - ]), - + Function( + "mcMemcpy", + cudaError_t, + [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind], + ), # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # 
noqa - Function("mcIpcGetMemHandle", cudaError_t, - [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), + Function( + "mcIpcGetMemHandle", + cudaError_t, + [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p], + ), # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa - Function("mcIpcOpenMemHandle", cudaError_t, [ - ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint - ]), + Function( + "mcIpcOpenMemHandle", + cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint], + ), ] # class attribute to store the mapping from the path to the library @@ -112,11 +118,10 @@ def __init__(self, so_file: Optional[str] = None): so_file = find_loaded_library("libmcruntime") if so_file is None: so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var - assert so_file is not None, \ - ( - "libcudart is not loaded in the current process, " - "try setting VLLM_CUDART_SO_PATH" - ) + assert so_file is not None, ( + "libcudart is not loaded in the current process, " + "try setting VLLM_CUDART_SO_PATH" + ) if so_file not in CudaRTLibrary.path_to_library_cache: lib = ctypes.CDLL(so_file) CudaRTLibrary.path_to_library_cache[so_file] = lib @@ -157,29 +162,29 @@ def cudaMalloc(self, size: int) -> ctypes.c_void_p: def cudaFree(self, devPtr: ctypes.c_void_p) -> None: self.CUDART_CHECK(self.funcs["mcFree"](devPtr)) - def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, - count: int) -> None: + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None: self.CUDART_CHECK(self.funcs["mcMemset"](devPtr, value, count)) - def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, - count: int) -> None: + def cudaMemcpy( + self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int + ) -> None: cudaMemcpyDefault = 4 kind = cudaMemcpyDefault self.CUDART_CHECK(self.funcs["mcMemcpy"](dst, src, count, kind)) - def cudaIpcGetMemHandle(self, - devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: + def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: handle = cudaIpcMemHandle_t() - self.CUDART_CHECK(self.funcs["mcIpcGetMemHandle"](ctypes.byref(handle), - devPtr)) + self.CUDART_CHECK(self.funcs["mcIpcGetMemHandle"](ctypes.byref(handle), devPtr)) return handle - def cudaIpcOpenMemHandle(self, - handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: + def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: cudaIpcMemLazyEnablePeerAccess = 1 devPtr = ctypes.c_void_p() - self.CUDART_CHECK(self.funcs["mcIpcOpenMemHandle"]( - ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) + self.CUDART_CHECK( + self.funcs["mcIpcOpenMemHandle"]( + ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess + ) + ) return devPtr diff --git a/vllm_metax/patch/distributed/pynccl_wrapper.py b/vllm_metax/patch/distributed/pynccl_wrapper.py index 059cf4b56..b7d01fc42 100644 --- a/vllm_metax/patch/distributed/pynccl_wrapper.py +++ b/vllm_metax/patch/distributed/pynccl_wrapper.py @@ -6,8 +6,18 @@ import vllm from vllm.distributed.device_communicators.pynccl_wrapper import ( - Function, NCCLLibrary, buffer_type, cudaStream_t, logger, ncclComm_t, - ncclDataType_t, ncclRedOp_t, ncclResult_t, ncclWindow_t, ncclUniqueId) + Function, + NCCLLibrary, + buffer_type, + cudaStream_t, + logger, + ncclComm_t, + ncclDataType_t, + ncclRedOp_t, + ncclResult_t, + ncclWindow_t, + ncclUniqueId, +) from vllm.logger import init_logger from vllm_metax.utils import find_mccl_library @@ -20,88 +30,141 @@ 
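The `Function` tables being reformatted here (in `cuda_wrapper.py` above and in `pynccl_wrapper.py`, whose hunk continues below) are plain metadata: each entry names a C symbol together with its return and argument types, and the wrapper binds them through ctypes when the library is loaded. A minimal, generic sketch of that binding pattern follows; it deliberately uses libc's `abs` so it runs on a typical POSIX setup, and none of the names are actual Metax or MCCL symbols.

```python
import ctypes
import ctypes.util

# Bind one function from a shared library according to a
# (name, restype, argtypes) spec -- the same shape of metadata
# that each Function entry in the wrappers carries.
libc = ctypes.CDLL(ctypes.util.find_library("c") or None)

name, restype, argtypes = "abs", ctypes.c_int, [ctypes.c_int]
fn = getattr(libc, name)
fn.restype = restype
fn.argtypes = argtypes

assert fn(-7) == 7
```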
class NCCLLibrary: # const char* ncclGetErrorString(ncclResult_t result) Function("mcclGetErrorString", ctypes.c_char_p, [ncclResult_t]), # ncclResult_t ncclGetVersion(int *version); - Function("mcclGetVersion", ncclResult_t, - [ctypes.POINTER(ctypes.c_int)]), + Function("mcclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]), # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); - Function("mcclGetUniqueId", ncclResult_t, - [ctypes.POINTER(ncclUniqueId)]), + Function("mcclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), # ncclResult_t ncclCommInitRank( # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); # note that ncclComm_t is a pointer type, so the first argument # is a pointer to a pointer - Function("mcclCommInitRank", ncclResult_t, [ - ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, - ctypes.c_int - ]), + Function( + "mcclCommInitRank", + ncclResult_t, + [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int], + ), # ncclResult_t ncclAllReduce( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("mcclAllReduce", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ncclComm_t, cudaStream_t - ]), - + Function( + "mcclAllReduce", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclReduce( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, int root, # ncclComm_t comm, cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("mcclReduce", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t - ]), - + Function( + "mcclReduce", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclAllGather( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("mcclAllGather", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclComm_t, cudaStream_t - ]), - + Function( + "mcclAllGather", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclReduceScatter( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("mcclReduceScatter", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ncclComm_t, cudaStream_t - ]), - + Function( + "mcclReduceScatter", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclSend( # const void* sendbuff, size_t count, ncclDataType_t datatype, # int dest, ncclComm_t comm, cudaStream_t stream); - Function("mcclSend", ncclResult_t, [ - buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, - ncclComm_t, cudaStream_t 
- ]), - + Function( + "mcclSend", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclRecv( # void* recvbuff, size_t count, ncclDataType_t datatype, # int src, ncclComm_t comm, cudaStream_t stream); - Function("mcclRecv", ncclResult_t, [ - buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, - ncclComm_t, cudaStream_t - ]), - + Function( + "mcclRecv", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclBroadcast( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, int root, ncclComm_t comm, # cudaStream_t stream); - Function("mcclBroadcast", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ctypes.c_int, ncclComm_t, cudaStream_t - ]), - + Function( + "mcclBroadcast", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. # because Python object destruction can happen in random order, @@ -141,7 +204,6 @@ class NCCLLibrary: path_to_dict_mapping: dict[str, dict[str, Any]] = {} def __init__(self, so_file: Optional[str] = None): - so_file = so_file or find_mccl_library() try: @@ -157,8 +219,10 @@ def __init__(self, so_file: Optional[str] = None): "or it does not support the current platform %s. " "If you already have the library, please set the " "environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path.", so_file, - platform.platform()) + " to point to the correct nccl library path.", + so_file, + platform.platform(), + ) raise e if so_file not in NCCLLibrary.path_to_dict_mapping: @@ -195,88 +259,153 @@ def ncclGetVersion(self) -> str: def ncclGetUniqueId(self) -> ncclUniqueId: unique_id = ncclUniqueId() - self.NCCL_CHECK(self._funcs["mcclGetUniqueId"]( - ctypes.byref(unique_id))) + self.NCCL_CHECK(self._funcs["mcclGetUniqueId"](ctypes.byref(unique_id))) return unique_id def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId: if len(data) != 128: raise ValueError( - f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes") + f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes" + ) unique_id = ncclUniqueId() ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128) return unique_id - def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, - rank: int) -> ncclComm_t: + def ncclCommInitRank( + self, world_size: int, unique_id: ncclUniqueId, rank: int + ) -> ncclComm_t: comm = ncclComm_t() - self.NCCL_CHECK(self._funcs["mcclCommInitRank"](ctypes.byref(comm), - world_size, unique_id, - rank)) + self.NCCL_CHECK( + self._funcs["mcclCommInitRank"]( + ctypes.byref(comm), world_size, unique_id, rank + ) + ) return comm - def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + def ncclAllReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - 
self.NCCL_CHECK(self._funcs["mcclAllReduce"](sendbuff, recvbuff, count, - datatype, op, comm, - stream)) - - def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, root: int, - comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK( + self._funcs["mcclAllReduce"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) + + def ncclReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["mcclReduce"](sendbuff, recvbuff, count, - datatype, op, root, comm, - stream)) - - def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + self.NCCL_CHECK( + self._funcs["mcclReduce"]( + sendbuff, recvbuff, count, datatype, op, root, comm, stream + ) + ) + + def ncclReduceScatter( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["mcclReduceScatter"](sendbuff, recvbuff, - count, datatype, op, - comm, stream)) - - def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + self.NCCL_CHECK( + self._funcs["mcclReduceScatter"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) + + def ncclAllGather( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # which is an aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["mcclAllGather"](sendbuff, recvbuff, count, - datatype, comm, stream)) - - def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, - dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["mcclSend"](sendbuff, count, datatype, - dest, comm, stream)) - - def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, - src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["mcclRecv"](recvbuff, count, datatype, src, - comm, stream)) - - def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, root: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["mcclBroadcast"](sendbuff, recvbuff, count, - datatype, root, comm, - stream)) + self.NCCL_CHECK( + self._funcs["mcclAllGather"]( + sendbuff, recvbuff, count, datatype, comm, stream + ) + ) + + def ncclSend( + self, + sendbuff: buffer_type, + count: int, + datatype: int, + dest: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["mcclSend"](sendbuff, count, datatype, dest, comm, stream) + ) + + def ncclRecv( + self, + 
recvbuff: buffer_type, + count: int, + datatype: int, + src: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["mcclRecv"](recvbuff, count, datatype, src, comm, stream) + ) + + def ncclBroadcast( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["mcclBroadcast"]( + sendbuff, recvbuff, count, datatype, root, comm, stream + ) + ) def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["mcclCommDestroy"](comm)) @@ -287,15 +416,15 @@ def ncclGroupStart(self) -> None: def ncclGroupEnd(self) -> None: self.NCCL_CHECK(self._funcs["mcclGroupEnd"]()) - def ncclCommWindowRegister(self, comm: ncclComm_t, buff: buffer_type, - size: int, win_flags: int) -> ncclWindow_t: + def ncclCommWindowRegister( + self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int + ) -> ncclWindow_t: window = ncclWindow_t() # self.NCCL_CHECK(self._funcs["mcclCommWindowRegister"]( # comm, buff, size, ctypes.byref(window), win_flags)) return window - def ncclCommWindowDeregister(self, comm: ncclComm_t, - window: ncclWindow_t) -> None: + def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None: # self.NCCL_CHECK(self._funcs["mcclCommWindowDeregister"](comm, window)) return diff --git a/vllm_metax/patch/hotfix/patch_utils.py b/vllm_metax/patch/hotfix/patch_utils.py deleted file mode 100644 index 80eac1596..000000000 --- a/vllm_metax/patch/hotfix/patch_utils.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional - -import torch -from torch.library import Library -from vllm import utils -from vllm.utils import vllm_lib, supports_custom_op - - -def maca_direct_register_custom_op( - op_name: str, - op_func: Callable, - mutates_args: Optional[list[str]] = None, - fake_impl: Optional[Callable] = None, - target_lib: Optional[Library] = None, - dispatch_key: Optional[str] = None, - tags: tuple[torch.Tag, ...] = (), -): - """ - `torch.library.custom_op` can have significant overhead because it - needs to consider complicated dispatching logic. This function - directly registers a custom op and dispatches it to the CUDA backend. - See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 - for more details. - - By default, the custom op is registered to the vLLM library. If you - want to register it to a different library, you can pass the library - object to the `target_lib` argument. - - IMPORTANT: the lifetime of the operator is tied to the lifetime of the - library object. If you want to bind the operator to a different library, - make sure the library object is alive when the operator is used. - """ - if not supports_custom_op(): - from vllm.platforms import current_platform - assert not current_platform.is_cuda_alike(), ( - "cuda platform needs torch>=2.4 to support custom op, " - "chances are you are using an old version of pytorch " - "or a custom build of pytorch. 
It is recommended to " - "use vLLM in a fresh new environment and let it install " - "the required dependencies.") - return - - if mutates_args is None: - mutates_args = [] - - if dispatch_key is None: - from vllm.platforms import current_platform - dispatch_key = current_platform.dispatch_key - - for k, v in op_func.__annotations__.items(): - if v == list[int]: - op_func.__annotations__[k] = List[int] - if v == Optional[list[int]]: - op_func.__annotations__[k] = Optional[List[int]] - # TODO: add more type convert here if needed. - import torch.library - if hasattr(torch.library, "infer_schema"): - schema_str = torch.library.infer_schema(op_func, - mutates_args=mutates_args) - else: - # for pytorch 2.4 - import torch._custom_op.impl - schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) - my_lib = target_lib or vllm_lib - my_lib.define(op_name + schema_str, tags=tags) - my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) - if fake_impl is not None: - my_lib._register_fake(op_name, fake_impl) - - -utils.direct_register_custom_op = maca_direct_register_custom_op diff --git a/vllm_metax/patch/maca_visible_device.py b/vllm_metax/patch/maca_visible_device.py index 5a09b339f..8c8a59ce4 100644 --- a/vllm_metax/patch/maca_visible_device.py +++ b/vllm_metax/patch/maca_visible_device.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from vllm.worker.worker_base import WorkerWrapperBase, logger +from vllm.v1.worker.worker_base import WorkerWrapperBase, logger import contextlib from typing import List, Dict, Iterator import os -from vllm.utils import update_environment_variables +from vllm.utils.system_utils import update_environment_variables from vllm.platforms import current_platform from vllm.v1.engine.utils import get_device_indices from unittest.mock import patch @@ -12,11 +12,12 @@ def update_environment_variables_with_maca( - self, envs_list: List[Dict[str, str]]) -> None: + self, envs_list: List[Dict[str, str]] +) -> None: envs = envs_list[self.rpc_rank] - key = 'CUDA_VISIBLE_DEVICES' + key = "CUDA_VISIBLE_DEVICES" # sync `MACA_VISIBLE_DEVICES`` with `CUDA_VISIBLE_DEVICES` - envs['MACA_VISIBLE_DEVICES'] = envs.get(key, '') + envs["MACA_VISIBLE_DEVICES"] = envs.get(key, "") if key in envs and key in os.environ: # overwriting CUDA_VISIBLE_DEVICES is desired behavior # suppress the warning in `update_environment_variables` @@ -25,8 +26,9 @@ def update_environment_variables_with_maca( @contextlib.contextmanager -def set_device_control_env_var_with_maca(vllm_config: VllmConfig, - local_dp_rank: int) -> Iterator[None]: +def set_device_control_env_var_with_maca( + vllm_config: VllmConfig, local_dp_rank: int +) -> Iterator[None]: """ Temporarily set CUDA_VISIBLE_DEVICES or equivalent for engine subprocess. 
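`set_device_control_env_var_with_maca` above scopes the environment change with `unittest.mock.patch.dict` and mirrors the chosen device list into `MACA_VISIBLE_DEVICES` for the MACA runtime. A small self-contained sketch of that behaviour is shown below; the helper name and the device string are made up for illustration.

```python
import os
from contextlib import contextmanager
from unittest.mock import patch


@contextmanager
def scoped_visible_devices(value: str):
    # patch.dict snapshots os.environ and restores it when the block exits.
    with patch.dict(os.environ, values=(("CUDA_VISIBLE_DEVICES", value),)):
        os.environ["MACA_VISIBLE_DEVICES"] = value  # kept in sync for the MACA runtime
        yield


with scoped_visible_devices("0,1"):
    assert os.environ["CUDA_VISIBLE_DEVICES"] == "0,1"
    assert os.environ["MACA_VISIBLE_DEVICES"] == "0,1"
# Both variables revert here: patch.dict restores the snapshot, which also
# removes keys that were only added inside the block.
```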
@@ -35,8 +37,8 @@ def set_device_control_env_var_with_maca(vllm_config: VllmConfig, evar = current_platform.device_control_env_var value = get_device_indices(evar, local_dp_rank, world_size) - with patch.dict(os.environ, values=((evar, value), )): - os.environ['MACA_VISIBLE_DEVICES'] = value + with patch.dict(os.environ, values=((evar, value),)): + os.environ["MACA_VISIBLE_DEVICES"] = value yield diff --git a/vllm_metax/patch/model_executor/chunk_delta_h.py b/vllm_metax/patch/model_executor/chunk_delta_h.py index af62b4ea9..4426a8c25 100644 --- a/vllm_metax/patch/model_executor/chunk_delta_h.py +++ b/vllm_metax/patch/model_executor/chunk_delta_h.py @@ -13,22 +13,26 @@ from vllm.model_executor.layers.fla.ops.utils import use_cuda_graph -@triton.heuristics({ - 'USE_G': lambda args: args['g'] is not None, - 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, - 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'SAVE_NEW_VALUE': lambda args: args['v_new'] is not None, - 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, -}) +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) @triton.autotune( configs=[ - triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4] for num_stages in [1] for BV in [32, 64] + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [1] + for BV in [32, 64] ], - key=['H', 'K', 'V', 'BT', 'USE_G'], + key=["H", "K", "V", "BT", "USE_G"], use_cuda_graph=use_cuda_graph, ) -@triton.jit(do_not_specialize=['T']) +@triton.jit(do_not_specialize=["T"]) def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( k, v, @@ -56,8 +60,10 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( i_v, i_nh = tl.program_id(0), tl.program_id(1) i_n, i_h = i_nh // H, i_nh % H if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos NT = tl.cdiv(T, BT) boh = tl.load(chunk_offsets + i_n).to(tl.int32) @@ -93,87 +99,98 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( # load initial state if USE_INITIAL_STATE: - p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), - (1, 0)) + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) if K > 64: - p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), - (64, BV), (1, 0)) + p_h0_2 = tl.make_block_ptr( + h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) + ) b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) if K > 128: - p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), - (64, BV), (1, 0)) + p_h0_3 = tl.make_block_ptr( + h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) + ) b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) if K > 192: - p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), - (64, BV), (1, 0)) + p_h0_4 = tl.make_block_ptr( + h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + ) b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) # main recurrence for i_t in range(NT): - p_h1 = tl.make_block_ptr(h + i_t * 
stride_h, (K, V), (V, 1), - (0, i_v * BV), (64, BV), (1, 0)) + p_h1 = tl.make_block_ptr( + h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0) + ) tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) if K > 64: - p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), - (64, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h2, - b_h2.to(p_h2.dtype.element_ty), - boundary_check=(0, 1)) + p_h2 = tl.make_block_ptr( + h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) if K > 128: - p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), - (128, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h3, - b_h3.to(p_h3.dtype.element_ty), - boundary_check=(0, 1)) + p_h3 = tl.make_block_ptr( + h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) if K > 192: - p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), - (192, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h4, - b_h4.to(p_h4.dtype.element_ty), - boundary_check=(0, 1)) + p_h4 = tl.make_block_ptr( + h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) - p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), - (BT, BV), (1, 0)) - p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), - (i_t * BT, i_v * BV), (BT, BV), - (1, 0)) if SAVE_NEW_VALUE else None + p_v = tl.make_block_ptr( + v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + p_v_new = ( + tl.make_block_ptr( + v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + if SAVE_NEW_VALUE + else None + ) b_v_new = tl.zeros([BT, BV], dtype=tl.float32) - p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), - (BT, 64), (1, 0)) + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0) + ) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) if K > 64: - p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), - (BT, 64), (1, 0)) + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0) + ) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) if K > 128: - p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), - (BT, 64), (1, 0)) + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0) + ) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) if K > 192: - p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), - (BT, 64), (1, 0)) + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0) + ) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) if SAVE_NEW_VALUE: - p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), - (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - tl.store(p_v_new, - b_v_new.to(p_v_new.dtype.element_ty), - boundary_check=(0, 1)) + p_v_new = tl.make_block_ptr( + v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + tl.store( + p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1) + ) if USE_G: m_t = (i_t * BT + tl.arange(0, BT)) < T last_idx = min((i_t + 1) * BT, T) - 1 b_g_last = tl.load(g + bos * H + last_idx * H + i_h) - p_g = 
tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ), - (i_t * BT, ), (BT, ), (0, )) - b_g = tl.load(p_g, boundary_check=(0, )) + p_g = tl.make_block_ptr( + g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) + ) + b_g = tl.load(p_g, boundary_check=(0,)) b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] b_g_last = exp(b_g_last) b_h1 = b_h1 * b_g_last @@ -184,49 +201,49 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( if K > 192: b_h4 = b_h4 * b_g_last b_v_new = b_v_new.to(k.dtype.element_ty) - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), - (64, BT), (0, 1)) + p_k = tl.make_block_ptr( + k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1) + ) b_k = tl.load(p_k, boundary_check=(0, 1)) b_h1 += tl.dot(b_k, b_v_new) if K > 64: - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), - (64, BT), (0, 1)) + p_k = tl.make_block_ptr( + k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1) + ) b_k = tl.load(p_k, boundary_check=(0, 1)) b_h2 += tl.dot(b_k, b_v_new) if K > 128: - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), - (64, BT), (0, 1)) + p_k = tl.make_block_ptr( + k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1) + ) b_k = tl.load(p_k, boundary_check=(0, 1)) b_h3 += tl.dot(b_k, b_v_new) if K > 192: - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), - (64, BT), (0, 1)) + p_k = tl.make_block_ptr( + k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1) + ) b_k = tl.load(p_k, boundary_check=(0, 1)) b_h4 += tl.dot(b_k, b_v_new) # epilogue if STORE_FINAL_STATE: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), - (1, 0)) + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) if K > 64: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), - (64, BV), (1, 0)) - tl.store(p_ht, - b_h2.to(p_ht.dtype.element_ty), - boundary_check=(0, 1)) + p_ht = tl.make_block_ptr( + ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) if K > 128: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), - (64, BV), (1, 0)) - tl.store(p_ht, - b_h3.to(p_ht.dtype.element_ty), - boundary_check=(0, 1)) + p_ht = tl.make_block_ptr( + ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) if K > 192: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), - (64, BV), (1, 0)) - tl.store(p_ht, - b_h4.to(p_ht.dtype.element_ty), - boundary_check=(0, 1)) + p_ht = tl.make_block_ptr( + ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) import vllm.model_executor.layers.fla.ops.chunk_delta_h diff --git a/vllm_metax/patch/model_executor/hook_register.py b/vllm_metax/patch/model_executor/hook_register.py index 4e663bff6..af1ae86a0 100644 --- a/vllm_metax/patch/model_executor/hook_register.py +++ b/vllm_metax/patch/model_executor/hook_register.py @@ -1,8 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 +from asyncio.log import logger from vllm.model_executor.layers.quantization import ( - _CUSTOMIZED_METHOD_TO_QUANT_CONFIG, QUANTIZATION_METHODS, - QuantizationConfig, register_quantization_config) + _CUSTOMIZED_METHOD_TO_QUANT_CONFIG, + QUANTIZATION_METHODS, + QuantizationConfig, + register_quantization_config, +) def register_quantization_config(quantization: 
str): @@ -15,9 +19,13 @@ def register_quantization_config(quantization: str): quantization (str): The quantization method name. Examples: - >>> from vllm.model_executor.layers.quantization import register_quantization_config + >>> from vllm.model_executor.layers.quantization import ( + ... register_quantization_config, + ... ) >>> from vllm.model_executor.layers.quantization import get_quantization_config - >>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + >>> from vllm.model_executor.layers.quantization.base_config import ( + ... QuantizationConfig, + ... ) >>> >>> @register_quantization_config("my_quant") ... class MyQuantConfig(QuantizationConfig): @@ -28,9 +36,17 @@ def register_quantization_config(quantization: str): """ # noqa: E501 def _wrapper(quant_config_cls): + if quantization in QUANTIZATION_METHODS: + logger.warning( + "The quantization method `%s` is already exists." + " and will be overwritten by the quantization config %s.", + quantization, + quant_config_cls, + ) if not issubclass(quant_config_cls, QuantizationConfig): - raise ValueError("The quantization config must be a subclass of " - "`QuantizationConfig`.") + raise ValueError( + "The quantization config must be a subclass of `QuantizationConfig`." + ) _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls QUANTIZATION_METHODS.append(quantization) return quant_config_cls @@ -40,4 +56,6 @@ def _wrapper(quant_config_cls): import vllm.model_executor.layers.quantization -vllm.model_executor.layers.quantization.register_quantization_config = register_quantization_config +vllm.model_executor.layers.quantization.register_quantization_config = ( + register_quantization_config +) diff --git a/vllm_metax/patch/model_executor/rotary_embedding.py b/vllm_metax/patch/model_executor/rotary_embedding.py index af426ea75..9f92df3f3 100644 --- a/vllm_metax/patch/model_executor/rotary_embedding.py +++ b/vllm_metax/patch/model_executor/rotary_embedding.py @@ -8,9 +8,9 @@ import torch -def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool) -> torch.Tensor: +def apply_rotary_emb_dispatch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool +) -> torch.Tensor: """ Args: x: [num_tokens, num_heads, head_size] @@ -20,10 +20,12 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor, positional embeddings. 
""" from flash_attn.layers.rotary import apply_rotary_emb - return apply_rotary_emb(x.unsqueeze(0), cos, sin, - not is_neox_style).squeeze(0) + + return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0) import vllm.model_executor.layers.rotary_embedding.common -vllm.model_executor.layers.rotary_embedding.common.apply_rotary_emb_dispatch = apply_rotary_emb_dispatch +vllm.model_executor.layers.rotary_embedding.common.apply_rotary_emb_dispatch = ( + apply_rotary_emb_dispatch +) diff --git a/vllm_metax/patch/model_executor/utils.py b/vllm_metax/patch/model_executor/utils.py index 0a9f28c2d..628d94351 100644 --- a/vllm_metax/patch/model_executor/utils.py +++ b/vllm_metax/patch/model_executor/utils.py @@ -5,8 +5,9 @@ import vllm from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.int8_utils import ( - per_token_group_quant_int8) -from vllm.utils import cdiv + per_token_group_quant_int8, +) +from vllm.utils.math_utils import cdiv from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe import utils @@ -29,8 +30,7 @@ def _int8_quantize( # activations apply per-token quantization. Otherwise, assume # activation tensor-wise fp8/int8 quantization, dynamic or static if block_shape is None: - assert per_act_token, \ - "int8 quantization only supports block or channel-wise" + assert per_act_token, "int8 quantization only supports block or channel-wise" # ┌------------------------ Metax Modification -------------------------┐ # A, A_scale = per_token_quant_int8(A) diff --git a/vllm_metax/patch/oot/__init__.py b/vllm_metax/patch/oot/__init__.py index 4fe2915ef..27ff2b671 100644 --- a/vllm_metax/patch/oot/__init__.py +++ b/vllm_metax/patch/oot/__init__.py @@ -1,2 +1,2 @@ # SPDX-License-Identifier: Apache-2.0 -from . import scaled_mm \ No newline at end of file +from . import scaled_mm diff --git a/vllm_metax/patch/oot/scaled_mm.py b/vllm_metax/patch/oot/scaled_mm.py index 54780dfb1..1e103882c 100644 --- a/vllm_metax/patch/oot/scaled_mm.py +++ b/vllm_metax/patch/oot/scaled_mm.py @@ -2,17 +2,18 @@ from typing import Optional from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( - CutlassScaledMMLinearKernel) + CutlassScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearKernel, ScaledMMLinearLayerConfig) + ScaledMMLinearKernel, + ScaledMMLinearLayerConfig, +) from vllm.platforms import PlatformEnum class MctlassScaledMMLinearKernel(CutlassScaledMMLinearKernel): - @classmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: return True, None @@ -22,4 +23,6 @@ def can_implement( import vllm.model_executor.layers.quantization.kernels.scaled_mm -vllm.model_executor.layers.quantization.kernels.scaled_mm._POSSIBLE_KERNELS = _POSSIBLE_KERNELS +vllm.model_executor.layers.quantization.kernels.scaled_mm._POSSIBLE_KERNELS = ( + _POSSIBLE_KERNELS +) diff --git a/vllm_metax/patch/hotfix/__init__.py b/vllm_metax/patch/sample/__init__.py similarity index 53% rename from vllm_metax/patch/hotfix/__init__.py rename to vllm_metax/patch/sample/__init__.py index 599f421e4..bc96e2874 100644 --- a/vllm_metax/patch/hotfix/__init__.py +++ b/vllm_metax/patch/sample/__init__.py @@ -1,2 +1,3 @@ # SPDX-License-Identifier: Apache-2.0 -from . import patch_utils \ No newline at end of file + +from . 
import rejection_sampler diff --git a/vllm_metax/patch/sample/rejection_sampler.py b/vllm_metax/patch/sample/rejection_sampler.py new file mode 100644 index 000000000..0a1dcb873 --- /dev/null +++ b/vllm_metax/patch/sample/rejection_sampler.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm.triton_utils import tl, triton + +import vllm.v1.sample.rejection_sampler + +# SPDX-License-Identifier: Apache-2.0 + + +@triton.jit +def sample_recovered_tokens_kernel( + output_token_ids_ptr, # [num_tokens] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + q_ptr, # [batch_size, vocab_size] + vocab_size, + PADDED_VOCAB_SIZE: tl.constexpr, + NO_DRAFT_PROBS: tl.constexpr, + BLOCK_SIZE: tl.constexpr = 1024, +): + req_idx = tl.program_id(0) + if req_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + # Early exit for out-of-range positions. + pos = tl.program_id(1) + if pos >= num_draft_tokens: + return + + max_prob = -float("inf") + best_token_id = 0 + + for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE): + block_end = min(block_start + BLOCK_SIZE, vocab_size) + + vocab_offset = tl.arange(0, BLOCK_SIZE) + mask = vocab_offset < block_end - block_start + + if NO_DRAFT_PROBS: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + prob = tl.load( + target_probs_ptr + + (start_idx + pos) * vocab_size + + block_start + + vocab_offset, + mask=(mask & (vocab_offset + block_start != draft_token_id)), + other=0, + ) + + else: + draft_prob = tl.load( + draft_probs_ptr + + (start_idx + pos) * vocab_size + + block_start + + vocab_offset, + mask=mask, + other=0, + ) + target_prob = tl.load( + target_probs_ptr + + (start_idx + pos) * vocab_size + + block_start + + vocab_offset, + mask=mask, + other=0, + ) + prob = tl.maximum(target_prob - draft_prob, 0) + + # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because + # `tl.argmax` will select the maximum value. + + q = tl.load( + q_ptr + req_idx * vocab_size + block_start + vocab_offset, + mask=mask, + other=float("-inf"), + ) + + # recovered_id = tl.argmax(prob / q, axis=-1) + # calc block prob and token ID + block_prob = prob / q + block_max_prob = tl.max(block_prob, axis=-1) + block_best_token_id = tl.argmax(block_prob, axis=-1) + block_start + + # update token ID + max_prob = tl.maximum(max_prob, block_max_prob) + best_token_id = tl.where( + block_max_prob >= max_prob, block_best_token_id, best_token_id + ) + + tl.store(output_token_ids_ptr + start_idx + pos, best_token_id) + + +vllm.v1.sample.rejection_sampler.sample_recovered_tokens_kernel = ( + sample_recovered_tokens_kernel +) diff --git a/vllm_metax/patch/vllm_substitution/utils.py b/vllm_metax/patch/vllm_substitution/utils.py deleted file mode 100644 index 773d08d4c..000000000 --- a/vllm_metax/patch/vllm_substitution/utils.py +++ /dev/null @@ -1,177 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. 
-# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -# ruff: noqa: E501 -import contextlib -import functools -import logging -import os -from enum import Enum -from typing import Any, Callable, Literal, Optional - -import torch - -from vllm.triton_utils import triton - -logger = logging.getLogger(__name__) - -COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" -FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" -FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1" - -SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) - - -def tensor_cache( - fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: - """ - A decorator that caches the most recent results of a function with tensor inputs. - - This decorator will store the output of the decorated function for the most recent set of input tensors. - The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed. - - Args: - fn (Callable[..., torch.Tensor]): - The function to be decorated. It should take tensor inputs and return tensor outputs. - - Returns: - Callable[..., torch.Tensor]: - A wrapped version of the input function with single-entry caching. - """ - - cache_entries: tuple[Optional[tuple], Optional[dict], Any] = [] - cache_size = 4 - - @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> Any: - nonlocal cache_entries, cache_size - for i, entry in enumerate(cache_entries): - last_args, last_kwargs, last_result = entry - if len(args) == len(last_args) and len(kwargs) == len(last_kwargs) \ - and all(a is b for a, b in zip(args, last_args)) \ - and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()): - cache_entries = cache_entries[:i] + cache_entries[i + 1:] + [ - (args, kwargs, last_result) - ] - return last_result - - result = fn(*args, **kwargs) - - if len(cache_entries) >= cache_size: - cache_entries = cache_entries[1:] - cache_entries.append((args, kwargs, result)) - return result - - return wrapper - - -def input_guard( - fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: - """ - A decorator to make sure all input tensors are contiguous and set the device based on input tensors. - """ - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - contiguous_args = (i if not isinstance(i, torch.Tensor) else - i.contiguous() for i in args) - contiguous_kwargs = { - k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) - for k, v in kwargs.items() - } - - tensor = None - for arg in args: - if isinstance(arg, torch.Tensor): - tensor = arg - break - if tensor is None: - for value in kwargs.values(): - if isinstance(value, torch.Tensor): - tensor = value - break - - if tensor is not None: - ctx = torch.cuda.device(tensor.device.index) - else: - ctx = contextlib.nullcontext() - - with ctx: - return fn(*contiguous_args, **contiguous_kwargs) - - return wrapper - - -@functools.cache -def get_available_device() -> str: - return 'cuda' - - -@functools.cache -def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']: - device = get_available_device() - mapping = { - "cuda": "nvidia", - "hip": "amd", - "xpu": "intel", - } - # return the mapped value, or the original if not found - return mapping.get(device, device) - - -# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. -# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. 
-# Therefore, we need to check the triton backend to determine the actual GPU vendor. -device = get_available_device() if get_available_device() != 'hip' else 'cuda' -device_torch_lib = getattr(torch, device) -device_platform = _check_platform() - -is_amd = (device_platform == 'amd') -is_intel = (device_platform == 'intel') -is_nvidia = (device_platform == 'nvidia') -is_intel_alchemist = (is_intel - and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0)) -is_nvidia_hopper = (is_nvidia - and ('NVIDIA H' in torch.cuda.get_device_name(0) - or torch.cuda.get_device_capability()[0] >= 9)) -use_cuda_graph = (is_nvidia - and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1') - - -def get_all_max_shared_mem(): - try: - return [ - triton.runtime.driver.active.utils.get_device_properties(i) - ['max_shared_mem'] for i in range(device_torch_lib.device_count()) - ] - except BaseException: - return [-1] - - -class Backend(Enum): - ADA = 101376 # RTX 4090 - AMPERE = 166912 # A100 - HOPPER = 232448 # H100 - DEFAULT = 102400 # Default - - @classmethod - def get_shared_memory(cls, arch: str) -> int: - try: - return cls[arch.upper()].value - except KeyError: - return cls.DEFAULT.value - - -@functools.cache -def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: - try: - device_shared_mem_list = get_all_max_shared_mem() - max_shared_memory = device_shared_mem_list[tensor_idx] - return max_shared_memory >= Backend.get_shared_memory(arch) - except Exception: - return False diff --git a/vllm_metax/platform.py b/vllm_metax/platform.py index dd352d2c8..33ee69d3c 100644 --- a/vllm_metax/platform.py +++ b/vllm_metax/platform.py @@ -3,26 +3,26 @@ pynvml. However, it should not initialize cuda context. """ +import contextlib import os -from datetime import timedelta -from functools import wraps -from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar, Union +from collections.abc import Callable +from functools import cache, wraps +from typing import TYPE_CHECKING, TypeVar import torch -# import custom ops, trigger op registration -import vllm.envs as envs -from torch.distributed import PrefixStore, ProcessGroup -from torch.distributed.distributed_c10d import is_nccl_available from typing_extensions import ParamSpec -from vllm.logger import logger -from vllm.platforms.interface import (DeviceCapability, FlexibleArgumentParser, - Platform, PlatformEnum, _Backend) -from vllm.utils import cuda_device_count_stateless +import vllm.envs as envs +from vllm.logger import logger from vllm_metax.utils import import_pymxml +from vllm.utils.torch_utils import cuda_device_count_stateless + +from vllm.platforms.interface import DeviceCapability, Platform, PlatformEnum +from vllm.utils.argparse_utils import FlexibleArgumentParser if TYPE_CHECKING: - from vllm.config import ModelConfig, VllmConfig + from vllm.attention.backends.registry import _Backend + from vllm.config import VllmConfig _P = ParamSpec("_P") _R = TypeVar("_R") @@ -36,7 +36,6 @@ def with_mxml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: - @wraps(fn) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: pymxml.nvmlInit() @@ -58,8 +57,12 @@ class MacaPlatformBase(Platform): device_control_env_var: str = "CUDA_VISIBLE_DEVICES" supported_quantization: list[str] = [ - "awq", "gptq", "compressed-tensors", "compressed_tensors", "moe_wna16", - "gguf" + "awq", + "gptq", + "compressed-tensors", + "compressed_tensors", + "moe_wna16", + "gguf", ] @classmethod @@ -74,9 +77,7 @@ def set_device(cls, device: torch.device) -> None: _ = 
torch.zeros(1, device=device) @classmethod - def get_device_capability(cls, - device_id: int = 0 - ) -> Optional[DeviceCapability]: + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: raise NotImplementedError @classmethod @@ -103,6 +104,16 @@ def is_fully_connected(cls, device_ids: list[int]) -> bool: def log_warnings(cls): pass + @classmethod + def import_kernels(cls) -> None: + """Import any platform-specific C kernels.""" + try: + import vllm_metax._C # noqa: F401 + except ImportError as e: + logger.warning("Failed to import from vllm_metax._C: %r", e) + with contextlib.suppress(ImportError): + import vllm_metax._moe_C # noqa: F401 + @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # Env Override @@ -114,17 +125,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: model_config = vllm_config.model_config if parallel_config.worker_cls == "auto": - if vllm_config.speculative_config: - if not envs.VLLM_USE_V1: - raise NotImplementedError( - "Speculative decoding is not supported on vLLM V0.") - parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" - else: - if envs.VLLM_USE_V1: - parallel_config.worker_cls = \ - "vllm.v1.worker.gpu_worker.Worker" - else: - parallel_config.worker_cls = "vllm.worker.worker.Worker" + parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" cache_config = vllm_config.cache_config if cache_config and cache_config.block_size is None: @@ -132,9 +133,12 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # TODO(lucas): handle this more gracefully # Note: model_config may be None during testing - if model_config is not None and model_config.use_mla: - use_sparse = hasattr(vllm_config.model_config.hf_config, - "index_topk") + if ( + model_config is not None + and model_config.use_mla + and cache_config.block_size is not None + ): + use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, # else we default to CutlassMLA. For each case, we force the @@ -157,34 +161,51 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: use_flashmla = True else: # Forced case - use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA") - use_cutlass_mla = ( - envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA") + use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" + use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" + use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA" + + from vllm_metax.attention.ops.flashmla import is_flashmla_dense_supported - from vllm_metax.attention.ops.flashmla import is_flashmla_supported - if use_flashmla and is_flashmla_supported()[0] \ - and cache_config.block_size != 64: + if ( + use_flashmla + and is_flashmla_dense_supported()[0] + and cache_config.block_size % 64 != 0 + ): cache_config.block_size = 64 - logger.info( - "Forcing kv cache block size to 64 for FlashMLA backend.") + logger.info("Forcing kv cache block size to 64 for FlashMLA backend.") if use_cutlass_mla and cache_config.block_size != 128: cache_config.block_size = 128 - logger.info("Forcing kv cache block size to 128 for " - "CUTLASS_MLA backend.") + logger.info( + "Forcing kv cache block size to 128 for CUTLASS_MLA backend." 
+ ) + + if ( + use_flashinfer_mla + and cache_config.block_size != 32 + and cache_config.block_size % 64 != 0 + ): + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashInferMLA backend." + ) + # TODO(Chen): remove this hacky code if use_sparse and cache_config.block_size != 64: cache_config.block_size = 64 logger.info( - "Forcing kv cache block size to 64 for FlashMLASparse " - "backend.") + "Forcing kv cache block size to 64 for FlashMLASparse backend." + ) # lazy import to avoid circular import from vllm.config import CUDAGraphMode compilation_config = vllm_config.compilation_config - if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" - and parallel_config.data_parallel_size > 1 - and compilation_config.cudagraph_mode != CUDAGraphMode.NONE): + if ( + envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" + and parallel_config.data_parallel_size > 1 + and compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): # TODO: Piecewise Cuda graph might be enabled # if torch compile cache key issue fixed # See https://github.com/vllm-project/vllm/pull/25093 @@ -194,111 +215,154 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: "CUDA Graphs. " "In order to use CUDA Graphs for decode-optimized workloads, " "set VLLM_ALL2ALL_BACKEND to another option, such as " - "deepep_low_latency, pplx, or allgather_reducescatter.") + "deepep_low_latency, pplx, or allgather_reducescatter." + ) compilation_config.cudagraph_mode = CUDAGraphMode.NONE + # Reduce the cudagraph capture sizes on Maca to avoid OOM issues + compilation_config.max_cudagraph_capture_size = 256 + compilation_config.cudagraph_capture_sizes = [ + size + for size in compilation_config.cudagraph_capture_sizes + if size <= compilation_config.max_cudagraph_capture_size + ] + compilation_config.compile_sizes = [ + size + for size in compilation_config.compile_sizes + if size <= compilation_config.max_cudagraph_capture_size + ] + compilation_config.bs_to_padded_graph_size = [ + size + for size in compilation_config.bs_to_padded_graph_size + if size <= compilation_config.max_cudagraph_capture_size + ] - if vllm_config.model_config is not None and \ - not vllm_config.model_config.enforce_eager and \ - compilation_config.cudagraph_capture_sizes is not None: - batch_size_capture_list = [ - size for size in compilation_config.cudagraph_capture_sizes - if size < 257 - ] - compilation_config.cudagraph_capture_sizes = None - compilation_config.init_with_cudagraph_sizes( - batch_size_capture_list) - + # Disable cascade attention for Maca platform currently if vllm_config.model_config is not None: model_config.disable_cascade_attn = True @classmethod - def get_current_memory_usage(cls, - device: Optional[torch.types.Device] = None - ) -> float: + def get_current_memory_usage( + cls, device: torch.types.Device | None = None + ) -> float: torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats(device) return torch.cuda.max_memory_allocated(device) @classmethod - def get_vit_attn_backend(cls, head_size: int, - dtype: torch.dtype) -> _Backend: - if dtype not in (torch.float16, torch.bfloat16): - return _Backend.XFORMERS + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + from vllm.attention.backends.registry import _Backend - FLASH_ATTN_V1 = "vllm_metax.v1.attention.backends.flash_attn.MacaFlashAttentionBackend" # noqa: E501 + # TODO(Hank) Need to check which is better between + # TORCH_SDPA or FLASH_ATTN on Maca platform + FLASH_ATTN_V1 = ( + 
"vllm_metax.v1.attention.backends.flash_attn.MacaFlashAttentionBackend" # noqa: E501 + ) from vllm.attention.selector import is_attn_backend_supported - is_default_fa_supported = is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, allow_import_error=False) - if is_default_fa_supported: + + if is_default_fa_supported := is_attn_backend_supported( + FLASH_ATTN_V1, head_size, dtype, allow_import_error=False + ): return _Backend.FLASH_ATTN else: - # Fallback to XFORMERS - return _Backend.XFORMERS + use_sdpa_attention_reason = {} + if not is_default_fa_supported.head_size: + use_sdpa_attention_reason["head_size"] = head_size + if not is_default_fa_supported.dtype: + use_sdpa_attention_reason["dtype"] = dtype + logger.warning( + "Fallback to Backend TORCH_SDPA as vit_attn_backend since %s is " + "not supported on FLASH_ATTN.", + ", ".join(f"{k}={v}" for k, v in use_sdpa_attention_reason.items()), + ) + return _Backend.TORCH_SDPA @classmethod - def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1, use_mla, - has_sink, use_sparse) -> str: + def get_attn_backend_cls( + cls, + selected_backend, + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + has_sink, + use_sparse, + ) -> str: + from vllm.attention.backends.registry import _Backend + if use_mla: if not use_v1: raise RuntimeError( "MLA attention backends require the V1 engine. " - "Set VLLM_USE_V1=1 to enable them.") + "Set VLLM_USE_V1=1 to enable them." + ) + + from vllm_metax.attention.ops.flashmla import is_flashmla_dense_supported + from vllm_metax.attention.utils.fa_utils import flash_attn_supports_mla if use_sparse: logger.info_once("Using Sparse MLA backend on V1 engine.") - return ("vllm_metax.v1.attention.backends.mla.flashmla_sparse." - "MacaFlashMLASparseBackend") - - from vllm_metax.attention.ops.flashmla import is_flashmla_supported + return ( + "vllm_metax.v1.attention.backends.mla.flashmla_sparse." 
+ "MacaFlashMLASparseBackend" + ) - use_cutlassmla = _Backend.CUTLASS_MLA or ( - cls.is_device_capability(100) and selected_backend is None - and block_size == 128) + use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( + selected_backend is None and block_size % 128 == 0 + ) + use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( + selected_backend is None and (block_size == 32 or block_size % 64 == 0) + ) use_flashmla = selected_backend == _Backend.FLASHMLA or ( - selected_backend is None and is_flashmla_supported()[0]) + selected_backend is None and is_flashmla_dense_supported()[0] + ) + use_flashattn_mla = selected_backend == _Backend.FLASH_ATTN_MLA or ( + selected_backend is None and flash_attn_supports_mla() + ) use_triton = selected_backend == _Backend.TRITON_MLA or ( - selected_backend is None) - - def _get_version(name, import_suffix) -> str: - if use_v1: - logger.info_once(f"Using {name} backend on V1 engine.") - return f"vllm_metax.v1.attention.backends.mla.{import_suffix}" - else: - raise AssertionError( - f"{name} backend is only supported on V1 engine") + selected_backend is None + ) if use_flashmla: - if block_size != 64: + if block_size % 64 != 0: logger.warning( "FlashMLA backend is not supported for block size %d" " (currently only supports block size 64).", - block_size) + block_size, + ) else: - return _get_version("Maca FlashMLA", - "flashmla.MacaFlashMLABackend") + logger.info_once("Using FlashMLA backend on V1 engine.") + return "vllm_metax.v1.attention.backends.mla.flashmla.MacaFlashMLABackend" # noqa: E501 if use_triton: - return _get_version("Maca Triton MLA", - "triton_mla.MacaTritonMLABackend") + logger.info_once("Using Triton MLA backend on V1 engine.") + return "vllm_metax.v1.attention.backends.mla.triton_mla.MacaTritonMLABackend" # noqa: E501 # default mla logger.warning( "Selected MLA backend is not valid, falling back to Triton MLA." 
) - return _get_version("Maca Triton MLA", - "triton_mla.MacaTritonMLABackend") + return ( + "vllm_metax.v1.attention.backends.mla.triton_mla.MacaTritonMLABackend" # noqa: E501 + ) if use_v1: assert not use_mla - FLASHINFER_V1 = "vllm_metax.v1.attention.backends.flashinfer.MacaFlashInferBackend" # noqa: E501 - FLEX_ATTENTION_V1 = "vllm_metax.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 + FLASHINFER_V1 = ( + "vllm_metax.v1.attention.backends.flashinfer.MacaFlashInferBackend" # noqa: E501 + ) + FLEX_ATTENTION_V1 = "vllm_metax.v1.attention.backends.flex_attention.MacaFlexAttentionBackend" # noqa: E501 TRITON_ATTN = "vllm_metax.v1.attention.backends.triton_attn.MacaTritonAttentionBackend" # noqa: E501 - FLASH_ATTN_V1 = "vllm_metax.v1.attention.backends.flash_attn.MacaFlashAttentionBackend" # noqa: E501 - TREE_ATTN_V1 = "vllm_metax.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 + FLASH_ATTN_V1 = ( + "vllm_metax.v1.attention.backends.flash_attn.MacaFlashAttentionBackend" # noqa: E501 + ) + TREE_ATTN_V1 = ( + "vllm_metax.v1.attention.backends.tree_attn.MacaTreeAttentionBackend" # noqa: E501 + ) if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") - from vllm.v1.attention.backends.utils import ( - set_kv_cache_layout) + from vllm.v1.attention.backends.utils import set_kv_cache_layout + set_kv_cache_layout("HND") return FLASHINFER_V1 elif selected_backend == _Backend.FLEX_ATTENTION: @@ -319,18 +383,19 @@ def _get_version(name, import_suffix) -> str: # Default backends for V1 engine # FlashAttention is the default for MetaX GPUs if is_default_backend_supported := is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, allow_import_error=False): - logger.info_once("Using Flash Attention backend on " - "V1 engine.") + FLASH_ATTN_V1, head_size, dtype, allow_import_error=False + ): + logger.info_once("Using Flash Attention backend on V1 engine.") return FLASH_ATTN_V1 if is_default_backend_supported := is_attn_backend_supported( - FLASHINFER_V1, head_size, dtype): - from vllm.v1.attention.backends.utils import ( - set_kv_cache_layout) + FLASHINFER_V1, head_size, dtype + ): + from vllm.v1.attention.backends.utils import set_kv_cache_layout logger.info_once( "Using FlashInfer backend with HND KV cache layout on " - "V1 engine by default for MetaX GPUs.") + "V1 engine by default for MetaX GPUs." + ) set_kv_cache_layout("HND") return FLASHINFER_V1 @@ -346,14 +411,14 @@ def _get_version(name, import_suffix) -> str: logger.info_once( "Using FlexAttention backend for %s on V1 engine.", - ", ".join(f"{k}={v}" - for k, v in use_flex_attention_reason.items()), + ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()), ) return FLEX_ATTENTION_V1 raise RuntimeError( "V0 attention backends have been removed. Set VLLM_USE_V1=1 " - "to select a supported backend.") + "to select a supported backend." 
+ ) @classmethod def get_punica_wrapper(cls) -> str: @@ -361,7 +426,9 @@ def get_punica_wrapper(cls) -> str: @classmethod def get_device_communicator_cls(cls) -> str: - return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + return ( + "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + ) @classmethod def supports_fp8(cls) -> bool: @@ -379,52 +446,39 @@ def opaque_attention_op(cls) -> bool: def get_static_graph_wrapper_cls(cls) -> str: return "vllm.compilation.cuda_graph.CUDAGraphWrapper" - @classmethod - def stateless_init_device_torch_dist_pg( - cls, - backend: str, - prefix_store: PrefixStore, - group_rank: int, - group_size: int, - timeout: timedelta, - ) -> ProcessGroup: - assert is_nccl_available() - pg: ProcessGroup = ProcessGroup( - prefix_store, - group_rank, - group_size, - ) - from torch.distributed.distributed_c10d import ProcessGroupNCCL - - backend_options = ProcessGroupNCCL.Options() - backend_options._timeout = timeout - - backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, - backend_options) - backend_type = ProcessGroup.BackendType.NCCL - device = torch.device("cuda") - pg._set_default_backend(backend_type) - backend_class._set_sequence_number_for_group() - - pg._register_backend(device, backend_type, backend_class) - return pg - @classmethod def device_count(cls) -> int: return cuda_device_count_stateless() - @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, - model_config: "ModelConfig") -> bool: - fp8_attention = kv_cache_dtype.startswith("fp8") - return (not fp8_attention) - @classmethod def check_if_supports_dtype(cls, torch_dtype: torch.dtype): - if torch_dtype == torch.float8_e4m3fn or \ - torch_dtype == torch.float8_e5m2: # noqa + if torch_dtype == torch.float8_e4m3fn or torch_dtype == torch.float8_e5m2: # noqa raise ValueError("FP8 is not supported on GPUs ") + @classmethod + def insert_blocks_to_device( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from src_cache to dst_cache on GPU.""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device) + + @classmethod + def swap_out_blocks_to_host( + cls, + src_cache: torch.Tensor, + dst_cache: torch.Tensor, + src_block_indices: torch.Tensor, + dst_block_indices: torch.Tensor, + ) -> None: + """Copy blocks from GPU to host (CPU).""" + _src_cache = src_cache[:, src_block_indices] + dst_cache[:, dst_block_indices] = _src_cache.cpu() + @classmethod def support_hybrid_kv_cache(cls) -> bool: return True @@ -434,11 +488,11 @@ def support_static_graph_mode(cls) -> bool: return True @classmethod - def pre_register_and_update(cls, - parser: Optional[FlexibleArgumentParser] = None - ) -> None: - logger.info("[hook] platform:pre_register_and_update...") - import vllm_metax.patch # noqa: F401 + def pre_register_and_update( + cls, parser: FlexibleArgumentParser | None = None + ) -> None: + # TODO(m01016): update cudagraph max capture size here + logger.info("Pre-registering and updating Maca platform.") # NVML utils @@ -446,12 +500,10 @@ def pre_register_and_update(cls, # all the related functions work on real physical device ids. 
# the major benefit of using NVML is that it will not initialize CUDA class MxmlPlatform(MacaPlatformBase): - @classmethod + @cache @with_mxml_context - def get_device_capability(cls, - device_id: int = 0 - ) -> Optional[DeviceCapability]: + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: try: physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = pymxml.nvmlDeviceGetHandleByIndex(physical_device_id) @@ -464,7 +516,7 @@ def get_device_capability(cls, @with_mxml_context def has_device_capability( cls, - capability: Union[tuple[int, int], int], + capability: tuple[int, int] | int, device_id: int = 0, ) -> bool: try: @@ -497,9 +549,7 @@ def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: """ query if the set of gpus are fully connected by nvlink (1 hop) """ - handles = [ - pymxml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids - ] + handles = [pymxml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] for i, handle in enumerate(handles): for j, peer_handle in enumerate(handles): if i < j: @@ -514,7 +564,8 @@ def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: except pymxml.NVMLError: logger.exception( "NVLink detection failed. This is normal if" - " your machine has no NVLink equipped.") + " your machine has no NVLink equipped." + ) return False return True @@ -529,11 +580,11 @@ def _get_physical_device_name(cls, device_id: int = 0) -> str: def log_warnings(cls): device_ids: int = pymxml.nvmlDeviceGetCount() if device_ids > 1: - device_names = [ - cls._get_physical_device_name(i) for i in range(device_ids) - ] - if (len(set(device_names)) > 1 - and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"): + device_names = [cls._get_physical_device_name(i) for i in range(device_ids)] + if ( + len(set(device_names)) > 1 + and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID" + ): logger.warning( "Detected different devices in the system: %s. Please" " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to " @@ -543,8 +594,8 @@ def log_warnings(cls): class NonMxmlMetaxPlatform(MacaPlatformBase): - @classmethod + @cache def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: major, minor = torch.cuda.get_device_capability(device_id) return DeviceCapability(major=major, minor=minor) @@ -559,10 +610,11 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: return device_props.total_memory @classmethod - def is_fully_connected(cls, physical_device_ids: List[int]) -> bool: + def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: logger.exception( "MetaXLink detection not possible, as context support was" - " not found. Assuming no MetaXLink available.") + " not found. Assuming no MetaXLink available." 
+ ) return False diff --git a/vllm_metax/quant_config/__init__.py b/vllm_metax/quant_config/__init__.py index 35e1ee895..988131360 100644 --- a/vllm_metax/quant_config/__init__.py +++ b/vllm_metax/quant_config/__init__.py @@ -1 +1 @@ -# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file +# SPDX-License-Identifier: Apache-2.0 diff --git a/vllm_metax/quant_config/awq.py b/vllm_metax/quant_config/awq.py index fde8ca917..fdefb66ca 100644 --- a/vllm_metax/quant_config/awq.py +++ b/vllm_metax/quant_config/awq.py @@ -1,30 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 -# TODO: hotfix for subprocess while unpickle, remove after torch2.8 is released -from vllm_metax.patch.hotfix import patch_utils # noqa: F401 from typing import Optional, Union import torch from vllm.model_executor.layers.fused_moe.layer import FusedMoE -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod as - vllm_AWQLinearMethod) -from vllm.model_executor.layers.quantization.awq import (is_layer_skipped_awq, - logger) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizeMethodBase) -from vllm.utils import direct_register_custom_op +from vllm.model_executor.layers.quantization.awq import ( + AWQLinearMethod as vllm_AWQLinearMethod, +) +from vllm.model_executor.layers.quantization.awq import is_layer_skipped, logger +from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase +from vllm.utils.torch_utils import direct_register_custom_op from vllm_metax import _custom_ops as ops -from vllm_metax.patch.model_executor.hook_register import ( - register_quantization_config) +from vllm_metax.patch.model_executor.hook_register import register_quantization_config @register_quantization_config("awq") class MacaAWQConfig(AWQConfig): - def get_supported_act_dtypes(self): return [torch.half, torch.bfloat16] @@ -32,15 +30,22 @@ def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]: if isinstance(layer, LinearBase): - if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + if is_layer_skipped( + prefix, + self.modules_to_not_convert, + self.packed_modules_mapping, + skip_with_substr=True, + ): return UnquantizedLinearMethod() return AWQLinearMethod(self) elif isinstance(layer, FusedMoE): # Lazy import to avoid circular import. from vllm_metax.quant_config.moe_wna16 import MacaMoeWNA16Config + logger.warning_once( f"Layer '{prefix}' is not supported by AWQMoeMarlin. " - "Falling back to Moe WNA16 kernels.") + "Falling back to Moe WNA16 kernels." 
+ ) config = { "quant_method": "awq", "bits": self.weight_bits, @@ -49,7 +54,8 @@ def get_quant_method( "lm_head": False, } return MacaMoeWNA16Config.from_config(config).get_quant_method( - layer, prefix) + layer, prefix + ) return None @@ -61,12 +67,9 @@ class AWQLinearMethod(vllm_AWQLinearMethod): """ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - layer.qweight = torch.nn.Parameter(layer.qweight.data, - requires_grad=False) - layer.qzeros = torch.nn.Parameter(layer.qzeros.data, - requires_grad=False) - layer.scales = torch.nn.Parameter(layer.scales.data, - requires_grad=False) + layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) # ┌------------------------ Metax Modification -------------------------┐ # warmup if self.quant_config.group_size % 32: @@ -76,11 +79,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.qweight = torch.nn.Parameter(qweight, requires_grad=False) # └------------------------- Metax Modification -------------------------┘ - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: qweight = layer.qweight scales = layer.scales qzeros = layer.qzeros @@ -88,27 +92,38 @@ def apply(self, # ┌------------------------ Metax Modification -------------------------┐ group_size = self.quant_config.group_size - return torch.ops.vllm._apply_awq(x, qweight, scales, qzeros, bias, - pack_factor, group_size) + return torch.ops.vllm._apply_awq( + x, qweight, scales, qzeros, bias, pack_factor, group_size + ) # └------------------------- Metax Modification -------------------------┘ -def _apply_awq_fake(x: torch.Tensor, qweight: torch.Tensor, - scales: torch.Tensor, qzeros: torch.Tensor, - bias: torch.Tensor, pack_factor: int, - group_size: int) -> torch.Tensor: +def _apply_awq_fake( + x: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + bias: torch.Tensor, + pack_factor: int, + group_size: int, +) -> torch.Tensor: out_shape = () if group_size % 32: - out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) + out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) else: - out_shape = (x.shape[:-1] + (qweight.shape[0], )) + out_shape = x.shape[:-1] + (qweight.shape[0],) return torch.empty(out_shape, dtype=x.dtype, device=x.device) -def _apply_awq(x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, - qzeros: torch.Tensor, bias: torch.Tensor, pack_factor: int, - group_size: int) -> torch.Tensor: - +def _apply_awq( + x: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + bias: torch.Tensor, + pack_factor: int, + group_size: int, +) -> torch.Tensor: out_shape = () reshaped_x = x.reshape(-1, x.shape[-1]) out = torch.empty(0) @@ -116,20 +131,29 @@ def _apply_awq(x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256 # noqa: F841 # if (FP16_MATMUL_HEURISTIC_CONDITION and reshaped_x.dtype == torch.half) or self.quant_config.group_size != 128: if group_size % 32: - out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) + out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) out = 
ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) out = torch.matmul(reshaped_x, out) else: num_out_channel = qweight.shape[0] - out_shape = (x.shape[:-1] + (num_out_channel, )) + out_shape = x.shape[:-1] + (num_out_channel,) temp_space = torch.empty(0, dtype=torch.float32, device=x.device) if reshaped_x.dtype == torch.bfloat16: - temp_space = torch.zeros(reshaped_x.shape[0], - num_out_channel, - dtype=torch.float32, - device=x.device) - out = ops.awq_gemm(reshaped_x, qweight, qzeros, scales, pack_factor, - temp_space, reshaped_x.dtype == torch.bfloat16) + temp_space = torch.zeros( + reshaped_x.shape[0], + num_out_channel, + dtype=torch.float32, + device=x.device, + ) + out = ops.awq_gemm( + reshaped_x, + qweight, + qzeros, + scales, + pack_factor, + temp_space, + reshaped_x.dtype == torch.bfloat16, + ) if bias is not None: out.add_(bias) return out.reshape(out_shape) @@ -140,5 +164,5 @@ def _apply_awq(x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, op_func=_apply_awq, mutates_args=[], fake_impl=_apply_awq_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=(torch.Tag.needs_fixed_stride_order,), ) diff --git a/vllm_metax/quant_config/awq_marlin.py b/vllm_metax/quant_config/awq_marlin.py index 1a2f801d0..68b5f96dd 100644 --- a/vllm_metax/quant_config/awq_marlin.py +++ b/vllm_metax/quant_config/awq_marlin.py @@ -1,18 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional +from typing import Optional, TYPE_CHECKING -from vllm.model_executor.layers.quantization.awq_marlin import ( - AWQMarlinConfig, QuantizationMethods) +from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig -from vllm_metax.patch.model_executor.hook_register import ( - register_quantization_config) +from vllm_metax.patch.model_executor.hook_register import register_quantization_config + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationMethods @register_quantization_config("awq_marlin") class MacaAWQMarlinConfig(AWQMarlinConfig): - @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional["QuantizationMethods"]: return None diff --git a/vllm_metax/quant_config/compressed_tensors.py b/vllm_metax/quant_config/compressed_tensors.py index 261c460c3..b3af428e0 100644 --- a/vllm_metax/quant_config/compressed_tensors.py +++ b/vllm_metax/quant_config/compressed_tensors.py @@ -1,35 +1,36 @@ # SPDX-License-Identifier: Apache-2.0 -# TODO: hotfix for subprocess while unpickle, remove after torch2.8 is released -from vllm_metax.patch.hotfix import patch_utils # noqa: F401 - from typing import Optional import torch from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 - QuantizeMethodBase) + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 - CompressedTensorsMoEMethod) + CompressedTensorsMoEMethod, +) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( - CompressedTensorsConfig) + CompressedTensorsConfig, +) -from vllm_metax.quant_config.compressed_tensors_moe import MacaCompressedTensorsMoEMethod +from vllm_metax.quant_config.compressed_tensors_moe import ( + MacaCompressedTensorsMoEMethod, +) from vllm_metax.patch.model_executor.hook_register import register_quantization_config @register_quantization_config("compressed-tensors") class MacaCompressedTensorsConfig(CompressedTensorsConfig): 
- def get_quant_method( self, layer: torch.nn.Module, prefix: str, ) -> Optional["QuantizeMethodBase"]: - origin_quant_method = super().get_quant_method(layer, prefix) if isinstance(origin_quant_method, CompressedTensorsMoEMethod): origin_quant_method = MacaCompressedTensorsMoEMethod.get_moe_method( - self, layer) + self, layer + ) return origin_quant_method diff --git a/vllm_metax/quant_config/compressed_tensors_moe.py b/vllm_metax/quant_config/compressed_tensors_moe.py index 0e011dee6..2a1cbeacc 100644 --- a/vllm_metax/quant_config/compressed_tensors_moe.py +++ b/vllm_metax/quant_config/compressed_tensors_moe.py @@ -3,31 +3,31 @@ from typing import Callable, Optional, Union from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( - CompressedTensorsMoEMethod, CompressedTensorsW8A8Int8MoEMethod, - CompressedTensorsWNA16MoEMethod) + CompressedTensorsMoEMethod, + CompressedTensorsW8A8Int8MoEMethod, + CompressedTensorsWNA16MoEMethod, +) class MacaCompressedTensorsMoEMethod(CompressedTensorsMoEMethod): - @staticmethod def get_moe_method( quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 - layer: torch.nn.Module + layer: torch.nn.Module, ) -> "CompressedTensorsMoEMethod": - moe_method = CompressedTensorsMoEMethod.get_moe_method( - quant_config, layer) + moe_method = CompressedTensorsMoEMethod.get_moe_method(quant_config, layer) if isinstance(moe_method, CompressedTensorsWNA16MoEMethod): moe_method = MacaCompressedTensorsWNA16MoEMethod( - quant_config, layer.moe_config) + quant_config, layer.moe_config + ) elif isinstance(moe_method, CompressedTensorsW8A8Int8MoEMethod): moe_method = MacaCompressedTensorsW8A8Int8MoEMethod( - quant_config, layer.moe_config) + quant_config, layer.moe_config + ) return moe_method -class MacaCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod - ): - +class MacaCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod): def apply( self, layer: torch.nn.Module, @@ -55,8 +55,8 @@ def apply( if enable_eplb: raise NotImplementedError( - "EPLB not supported for " - "`CompressedTensorsW8A8Int8MoEMethod` yet.") + "EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet." + ) from vllm_metax.model_executor.layers.fused_moe import fused_experts @@ -72,7 +72,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return fused_experts( hidden_states=x, @@ -90,7 +91,6 @@ def apply( class MacaCompressedTensorsWNA16MoEMethod(CompressedTensorsWNA16MoEMethod): - def apply( self, layer: torch.nn.Module, @@ -117,8 +117,9 @@ def apply( assert self.fused_experts is None if enable_eplb: - raise NotImplementedError("EPLB not supported for " - "`CompressedTensorsWNA16MoEMethod` yet.") + raise NotImplementedError( + "EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet." 
+ ) from vllm_metax.model_executor.layers.fused_moe import fused_experts @@ -134,7 +135,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return fused_experts( x, diff --git a/vllm_metax/quant_config/gptq.py b/vllm_metax/quant_config/gptq.py index 6721696c6..ede0b442c 100644 --- a/vllm_metax/quant_config/gptq.py +++ b/vllm_metax/quant_config/gptq.py @@ -1,31 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 -# TODO: hotfix for subprocess while unpickle, remove after torch2.8 is released -from vllm_metax.patch.hotfix import patch_utils # noqa: F401 from typing import Optional, Union import torch from torch.nn.parameter import Parameter from vllm.model_executor.layers.fused_moe.layer import FusedMoE -from vllm.model_executor.layers.quantization.base_config import ( - QuantizeMethodBase) -from vllm.model_executor.layers.quantization.gptq import (ExllamaState, - GPTQConfig) -from vllm.model_executor.layers.quantization.gptq import (GPTQLinearMethod as - vllm_GPTQLinearMethod - ) +from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase +from vllm.model_executor.layers.quantization.gptq import ExllamaState, GPTQConfig +from vllm.model_executor.layers.quantization.gptq import ( + GPTQLinearMethod as vllm_GPTQLinearMethod, +) from vllm.model_executor.layers.quantization.utils.gptq_utils import ( - get_linear_quant_method) -from vllm.utils import direct_register_custom_op + get_linear_quant_method, +) +from vllm.utils.torch_utils import direct_register_custom_op from vllm_metax import _custom_ops as ops -from vllm_metax.patch.model_executor.hook_register import ( - register_quantization_config) +from vllm_metax.patch.model_executor.hook_register import register_quantization_config @register_quantization_config("gptq") class MacaGPTQConfig(GPTQConfig): - def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half, torch.bfloat16] @@ -44,13 +39,13 @@ def get_quant_method( "lm_head": False, } return MacaMoeWNA16Config.from_config(config).get_quant_method( - layer, prefix) + layer, prefix + ) return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) class GPTQLinearMethod(vllm_GPTQLinearMethod): - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # for torch.compile layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) @@ -67,12 +62,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.quant_config.desc_act: layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) else: - layer.g_idx.data = torch.empty((0, ), - dtype=torch.int, - device=layer.g_idx.device) + layer.g_idx.data = torch.empty( + (0,), dtype=torch.int, device=layer.g_idx.device + ) layer.exllama_state = ExllamaState.READY - ops.gptq_shuffle(layer.qweight, layer.g_idx, - self.quant_config.weight_bits) + ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) # ┌------------------------ Metax Modification -------------------------┐ if layer.scales.dtype != torch.bfloat16: @@ -80,45 +74,66 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: temp_space = torch.empty(0) if self.quant_config.weight_bits == 4: # warmup - reshaped_x = torch.randn(1, - layer.qweight.shape[0] * 8, - dtype=layer.scales.dtype, - device="cuda") + reshaped_x = torch.randn( + 1, + layer.qweight.shape[0] * 8, + dtype=layer.scales.dtype, + 
device="cuda", + ) _ = ops.gptq_gemm( - reshaped_x, layer.qweight, layer.qzeros, layer.scales, - layer.g_idx, layer.exllama_state == ExllamaState.READY, + reshaped_x, + layer.qweight, + layer.qzeros, + layer.scales, + layer.g_idx, + layer.exllama_state == ExllamaState.READY, self.quant_config.weight_bits, - self.quant_config.group_size, perm_space, temp_space, - False) + self.quant_config.group_size, + perm_space, + temp_space, + False, + ) if self.quant_config.weight_bits == 8: # warmup - reshaped_x = torch.randn(1, - layer.qweight.shape[0] * 4, - dtype=layer.scales.dtype, - device="cuda") + reshaped_x = torch.randn( + 1, + layer.qweight.shape[0] * 4, + dtype=layer.scales.dtype, + device="cuda", + ) _ = ops.gptq_gemm( - reshaped_x, layer.qweight, layer.qzeros, layer.scales, - layer.g_idx, layer.exllama_state == ExllamaState.READY, + reshaped_x, + layer.qweight, + layer.qzeros, + layer.scales, + layer.g_idx, + layer.exllama_state == ExllamaState.READY, self.quant_config.weight_bits, - self.quant_config.group_size, perm_space, temp_space, - False) + self.quant_config.group_size, + perm_space, + temp_space, + False, + ) # └------------------------- Metax Modification -------------------------┘ else: if layer.exllama_state == ExllamaState.UNINITIALIZED: if self.quant_config.desc_act: layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) else: - layer.g_idx.data = torch.empty((0, ), - dtype=torch.int, - device=layer.g_idx.device) + layer.g_idx.data = torch.empty( + (0,), dtype=torch.int, device=layer.g_idx.device + ) layer.exllama_state = ExllamaState.READY - ops.gptq_shuffle(layer.qweight, layer.g_idx, - self.quant_config.weight_bits) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + ops.gptq_shuffle( + layer.qweight, layer.g_idx, self.quant_config.weight_bits + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: # ┌------------------------ Metax Modification -------------------------┐ qweight = layer.qweight scales = layer.scales @@ -130,47 +145,84 @@ def apply(self, desc_act = self.quant_config.desc_act use_exllama = exllama_state == ExllamaState.READY - return torch.ops.vllm._apply_gptq(x, qweight, scales, qzeros, bias, - g_idx, use_exllama, weight_bits, - group_size, desc_act) + return torch.ops.vllm._apply_gptq( + x, + qweight, + scales, + qzeros, + bias, + g_idx, + use_exllama, + weight_bits, + group_size, + desc_act, + ) # └------------------------- Metax Modification -------------------------┘ -def _apply_gptq_fake(x: torch.Tensor, qweight: torch.Tensor, - scales: torch.Tensor, qzeros: torch.Tensor, - bias: torch.Tensor, g_idx: torch.Tensor, - use_exllama: bool, weight_bits: int, group_size: int, - desc_act: bool) -> torch.Tensor: - out_shape = x.shape[:-1] + (qweight.shape[-1], ) +def _apply_gptq_fake( + x: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + bias: torch.Tensor, + g_idx: torch.Tensor, + use_exllama: bool, + weight_bits: int, + group_size: int, + desc_act: bool, +) -> torch.Tensor: + out_shape = x.shape[:-1] + (qweight.shape[-1],) return torch.empty(out_shape, dtype=x.dtype, device=x.device) -def _apply_gptq(x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, - qzeros: torch.Tensor, bias: torch.Tensor, g_idx: torch.Tensor, - use_exllama: bool, weight_bits: int, group_size: int, - desc_act: bool) -> torch.Tensor: - +def _apply_gptq( + x: torch.Tensor, + qweight: 
torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + bias: torch.Tensor, + g_idx: torch.Tensor, + use_exllama: bool, + weight_bits: int, + group_size: int, + desc_act: bool, +) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (qweight.shape[-1], ) + out_shape = x.shape[:-1] + (qweight.shape[-1],) perm_space = torch.empty(0) temp_space = torch.empty(0) if weight_bits == 4 or weight_bits == 8 or group_size == 128 or group_size == 64: if desc_act: - perm_space = torch.empty(reshaped_x.shape[0], - reshaped_x.shape[1], - dtype=torch.float16, - device=x.device) + perm_space = torch.empty( + reshaped_x.shape[0], + reshaped_x.shape[1], + dtype=torch.float16, + device=x.device, + ) if reshaped_x.dtype == torch.bfloat16: - temp_space = torch.zeros(reshaped_x.shape[0], - qweight.shape[1], - dtype=torch.float32, - device=x.device) - - output = ops.gptq_gemm(reshaped_x, qweight, qzeros, scales, g_idx, - use_exllama, weight_bits, group_size, perm_space, - temp_space, reshaped_x.dtype == torch.bfloat16) + temp_space = torch.zeros( + reshaped_x.shape[0], + qweight.shape[1], + dtype=torch.float32, + device=x.device, + ) + + output = ops.gptq_gemm( + reshaped_x, + qweight, + qzeros, + scales, + g_idx, + use_exllama, + weight_bits, + group_size, + perm_space, + temp_space, + reshaped_x.dtype == torch.bfloat16, + ) if bias is not None: output.add_(bias) @@ -182,5 +234,5 @@ def _apply_gptq(x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, op_func=_apply_gptq, mutates_args=[], fake_impl=_apply_gptq_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=(torch.Tag.needs_fixed_stride_order,), ) diff --git a/vllm_metax/quant_config/gptq_marlin.py b/vllm_metax/quant_config/gptq_marlin.py index 4c410c988..b11a43ed7 100644 --- a/vllm_metax/quant_config/gptq_marlin.py +++ b/vllm_metax/quant_config/gptq_marlin.py @@ -1,18 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional +from typing import Optional, TYPE_CHECKING -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig, QuantizationMethods) +from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig -from vllm_metax.patch.model_executor.hook_register import ( - register_quantization_config) +from vllm_metax.patch.model_executor.hook_register import register_quantization_config + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationMethods @register_quantization_config("gptq_marlin") class MacaGPTQMarlinConfig(GPTQMarlinConfig): - @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional["QuantizationMethods"]: return None diff --git a/vllm_metax/quant_config/moe_wna16.py b/vllm_metax/quant_config/moe_wna16.py index ffe63568f..a82f840cf 100644 --- a/vllm_metax/quant_config/moe_wna16.py +++ b/vllm_metax/quant_config/moe_wna16.py @@ -4,16 +4,16 @@ from typing import Optional import torch -from vllm.model_executor.layers.fused_moe.layer import (FusedMoE) -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizeMethodBase) +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase from 
vllm.model_executor.layers.quantization.moe_wna16 import ( - MoeWNA16Config, is_layer_skipped_quant, MoeWNA16Method) + MoeWNA16Config, + is_layer_skipped_quant, + MoeWNA16Method, +) -from vllm_metax.patch.model_executor.hook_register import ( - register_quantization_config) +from vllm_metax.patch.model_executor.hook_register import register_quantization_config # Remove configs of marlin @@ -25,20 +25,24 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.use_marlin = False - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if is_layer_skipped_quant(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() elif isinstance(layer, LinearBase): # Avoid circular import from vllm_metax.quant_config.awq import MacaAWQConfig from vllm_metax.quant_config.gptq import MacaGPTQConfig + if self.linear_quant_method == "gptq": - return MacaGPTQConfig.from_config( - self.full_config).get_quant_method(layer, prefix) + return MacaGPTQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) elif self.linear_quant_method == "awq": - return MacaAWQConfig.from_config( - self.full_config).get_quant_method(layer, prefix) + return MacaAWQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) else: raise ValueError("moe_wna16 only support gptq and awq.") elif isinstance(layer, FusedMoE): diff --git a/vllm_metax/utils/__init__.py b/vllm_metax/utils/__init__.py index b35eb9671..60792ae63 100644 --- a/vllm_metax/utils/__init__.py +++ b/vllm_metax/utils/__init__.py @@ -40,6 +40,7 @@ def import_pymxml(): module to our codebase, and use it directly. """ import vllm_metax.third_party.pymxml as pymxml + return pymxml @@ -55,8 +56,8 @@ def find_mccl_library() -> str: # manually load the nccl library if so_file: logger.info( - "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", - so_file) + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file + ) else: if torch.version.cuda is not None: so_file = "libmccl.so" @@ -73,4 +74,5 @@ def vllm_version(): return mx_envs.VLLM_OFFICIAL_VERSION else: import vllm + return vllm.__version__ diff --git a/vllm_metax/utils/deep_gemm.py b/vllm_metax/utils/deep_gemm.py index 266a39bcc..7e3256278 100644 --- a/vllm_metax/utils/deep_gemm.py +++ b/vllm_metax/utils/deep_gemm.py @@ -4,6 +4,7 @@ Users of vLLM should always import **only** these wrappers. """ + from __future__ import annotations import importlib @@ -13,14 +14,15 @@ import torch import vllm.envs as envs -from vllm.utils import has_deep_gemm +from vllm.utils.import_utils import has_deep_gemm def _missing(*_: Any, **__: Any) -> NoReturn: """Placeholder for unavailable DeepGEMM backend.""" raise RuntimeError( "DeepGEMM backend is not available. Please install the `deep_gemm` " - "package to enable BF16 kernels.") + "package to enable BF16 kernels." 
+ ) _bf16_mqa_logits_impl: Callable[..., Any] | None = None @@ -35,10 +37,11 @@ def _lazy_init() -> None: return # Set up deep_gemm cache path - DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR' + DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR" if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None): os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join( - envs.VLLM_CACHE_ROOT, "deep_gemm") + envs.VLLM_CACHE_ROOT, "deep_gemm" + ) _dg = importlib.import_module("deep_gemm") @@ -109,14 +112,16 @@ def bf16_paged_mqa_logits( _lazy_init() if _bf16_paged_mqa_logits_impl is None: return _missing() - return _bf16_paged_mqa_logits_impl(q_bf16, - kv_cache_bf16, - weights, - context_lens, - block_tables, - schedule_metadata, - max_model_len, - clean_logits=True) + return _bf16_paged_mqa_logits_impl( + q_bf16, + kv_cache_bf16, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True, + ) __all__ = [ diff --git a/vllm_metax/v1/__init__.py b/vllm_metax/v1/__init__.py index 35e1ee895..988131360 100644 --- a/vllm_metax/v1/__init__.py +++ b/vllm_metax/v1/__init__.py @@ -1 +1 @@ -# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file +# SPDX-License-Identifier: Apache-2.0 diff --git a/vllm_metax/v1/attention/backends/flash_attn.py b/vllm_metax/v1/attention/backends/flash_attn.py index 2dc4f3038..cfaa0c0a3 100644 --- a/vllm_metax/v1/attention/backends/flash_attn.py +++ b/vllm_metax/v1/attention/backends/flash_attn.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" + from dataclasses import dataclass from typing import Optional @@ -8,41 +9,51 @@ import torch from vllm import envs -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + MultipleOf, + is_quantized_kv_cache, +) from vllm.attention.layer import Attention +from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm_metax.attention.utils.fa_utils import ( - flash_attn_supports_fp8, get_flash_attn_version, - is_flash_attn_varlen_func_available) + flash_attn_supports_fp8, + get_flash_attn_version, + is_flash_attn_varlen_func_available, +) if is_flash_attn_varlen_func_available(): - from vllm_metax.attention.utils.fa_utils import (flash_attn_varlen_func, - get_scheduler_metadata, - reshape_and_cache_flash, - flash_attn_with_kvcache) - + from vllm_metax.attention.utils.fa_utils import ( + flash_attn_varlen_func, + get_scheduler_metadata, + reshape_and_cache_flash, + flash_attn_with_kvcache, + ) from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger -from vllm.utils import cdiv -# yapf: disable -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - get_kv_cache_layout, - reorder_batch_to_split_decodes_and_prefills, - split_decodes_and_prefills) +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.utils.math_utils import cdiv +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + 
get_kv_cache_layout, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) class MacaFlashAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True - supports_quant_query_input: bool = True @classmethod def get_supported_dtypes(cls) -> list[torch.dtype]: @@ -50,7 +61,11 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: @classmethod def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256] + return [32, 64, 72, 80, 96, 112, 128, 160, 192, 224, 256] + + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + return [MultipleOf(16)] @classmethod def validate_head_size(cls, head_size: int) -> None: @@ -61,7 +76,8 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -106,6 +122,9 @@ def get_kv_cache_stride_order() -> tuple[int, ...]: @staticmethod def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: + raise NotImplementedError( + "FP8 dtype is not supported for FlashAttention on Maca." + ) if kv_cache_dtype in ("fp8", "fp8_e4m3"): return torch.float8_e4m3fn else: @@ -130,13 +149,6 @@ class FlashAttentionMetadata: block_table: torch.Tensor slot_mapping: torch.Tensor - # For cascade attention. - use_cascade: bool - common_prefix_len: int - cu_prefix_query_lens: Optional[torch.Tensor] - prefix_kv_lens: Optional[torch.Tensor] - suffix_kv_lens: Optional[torch.Tensor] - # /------------------------ Metax Modification -------------------------\ # For handling prefill decode split num_decodes: int @@ -154,18 +166,30 @@ class FlashAttentionMetadata: prefill_block_table: torch.Tensor # \------------------------- Metax Modification -------------------------/ + # For cascade attention. + use_cascade: bool + common_prefix_len: int + cu_prefix_query_lens: torch.Tensor | None + prefix_kv_lens: torch.Tensor | None + suffix_kv_lens: torch.Tensor | None + + # For GQA DCP + max_dcp_context_kv_len: int | None = None + dcp_context_kv_lens: torch.Tensor | None = None + # Optional aot scheduling - scheduler_metadata: Optional[torch.Tensor] = None - prefix_scheduler_metadata: Optional[torch.Tensor] = None + scheduler_metadata: torch.Tensor | None = None + prefix_scheduler_metadata: torch.Tensor | None = None max_num_splits: int = 0 causal: bool = True def _get_sliding_window_configs( - vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: + vllm_config: VllmConfig, +) -> set[tuple[int, int] | None]: """Get the set of all sliding window configs used in the model.""" - sliding_window_configs: set[Optional[tuple[int, int]]] = set() + sliding_window_configs: set[tuple[int, int] | None] = set() layers = get_layers_from_vllm_config(vllm_config, Attention) for layer in layers.values(): assert isinstance(layer.impl, FlashAttentionImpl) @@ -173,8 +197,7 @@ def _get_sliding_window_configs( return sliding_window_configs -class FlashAttentionMetadataBuilder( - AttentionMetadataBuilder[FlashAttentionMetadata]): +class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]): # FA3: # Supports full cudagraphs for all cases. # @@ -193,11 +216,19 @@ class FlashAttentionMetadataBuilder( # to FULL_AND_PIECEWISE. 
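_get_sliding_window_configs exists to answer a single question for ahead-of-time (AOT) scheduling: do all attention layers in the model agree on one sliding-window setting? A minimal sketch of that decision rule, using a made-up FakeLayer structure in place of vLLM's real layer objects:

from dataclasses import dataclass


@dataclass(frozen=True)
class FakeLayer:
    # None means the layer uses full attention; otherwise a (left, right) window.
    sliding_window: tuple[int, int] | None


def decide_aot_window(layers: list[FakeLayer]) -> tuple[bool, tuple[int, int] | None]:
    """Return (aot_schedule_allowed, window_to_bake_in)."""
    configs = {layer.sliding_window for layer in layers}
    if len(configs) > 1:
        # Mixed windows across layers: fall back to non-AOT scheduling.
        return False, None
    # Uniform model (possibly "no window" everywhere).
    return True, configs.pop() if configs else None


uniform = [FakeLayer((4095, 0))] * 4
mixed = [FakeLayer((4095, 0)), FakeLayer(None)]
assert decide_aot_window(uniform) == (True, (4095, 0))
assert decide_aot_window(mixed) == (False, None)

The builder only bakes a window into the AOT schedule when the set has exactly one entry; otherwise the build() path below disables AOT scheduling.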
# TODO(luka, lucas): audit FA2 as part of: # https://github.com/vllm-project/vllm/issues/22945 - cudagraph_support = AttentionCGSupport.ALWAYS \ - if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH + cudagraph_support = ( + AttentionCGSupport.ALWAYS + if get_flash_attn_version() == 3 + else AttentionCGSupport.UNIFORM_BATCH + ) - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config @@ -205,27 +236,38 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.compilation_config = vllm_config.compilation_config self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config) - self.num_heads_kv = self.model_config.get_num_kv_heads( - self.parallel_config) + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) self.kv_cache_dtype = kv_cache_spec.dtype self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.max_num_splits = 0 # No upper bound on the number of splits. - self.aot_schedule = (get_flash_attn_version() == 3) + self.aot_schedule = get_flash_attn_version() == 3 - self.use_full_cuda_graph = \ + try: + from vllm.distributed.parallel_state import get_dcp_group + + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() - self.max_cudagraph_size = self.compilation_config.max_capture_size + ) + self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size if self.use_full_cuda_graph and self.aot_schedule: if self.max_cudagraph_size > 992: # This condition derives from FA3's internal heuristic. # TODO(woosuk): Support larger cudagraph sizes. raise ValueError( - "Capture size larger than 992 is not supported for " - "full cuda graph.") + "Capture size larger than 992 is not supported for full cuda graph." + ) self.scheduler_metadata = torch.zeros( vllm_config.scheduler_config.max_num_seqs + 1, @@ -235,19 +277,20 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = ( - envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH) + self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. 
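The scheduler_metadata tensor allocated above (sized max_num_seqs + 1) follows the usual pattern for full CUDA graph capture: a replayed graph can only read fixed device addresses, so each build() copies the fresh values into the front of one long-lived buffer and clears the stale tail. A rough sketch of that reuse, with an invented capacity and CPU tensors for brevity:

import torch

MAX_NUM_SEQS = 8
persistent = torch.zeros(MAX_NUM_SEQS + 1, dtype=torch.int32)  # allocated once in __init__


def publish(fresh: torch.Tensor) -> torch.Tensor:
    """Copy this step's scheduler metadata into the persistent buffer.

    Zeroing the tail matters: entries left over from a larger previous
    batch must not leak into the replayed kernel.
    """
    n = fresh.shape[0]
    persistent[:n].copy_(fresh)
    persistent[n:].zero_()
    # Always hand out a view of the persistent storage, never `fresh` itself.
    return persistent[:n]


step1 = publish(torch.tensor([0, 3, 7, 9], dtype=torch.int32))
step2 = publish(torch.tensor([0, 2], dtype=torch.int32))
assert step2.data_ptr() == step1.data_ptr()  # same underlying address every step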
- self.aot_sliding_window: Optional[tuple[int, int]] = None + self.aot_sliding_window: tuple[int, int] | None = None - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashAttentionMetadata: """ - fast_build disables AOT scheduling, used when there will be few + fast_build disables AOT scheduling, used when there will be few iterations i.e. spec-decode """ num_reqs = common_attn_metadata.num_reqs @@ -261,8 +304,9 @@ def build(self, slot_mapping = common_attn_metadata.slot_mapping causal = common_attn_metadata.causal - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills(common_attn_metadata) + ) # /------------------------ Metax Modification -------------------------\ assert num_decode_tokens + num_prefill_tokens == num_actual_tokens @@ -279,8 +323,7 @@ def build(self, # build() call so the layers are constructed (cannot populate) # in __init__. if aot_schedule: - sliding_window_configs = _get_sliding_window_configs( - self.vllm_config) + sliding_window_configs = _get_sliding_window_configs(self.vllm_config) if len(sliding_window_configs) == 1: sliding_window_config = sliding_window_configs.pop() if sliding_window_config is not None: @@ -290,8 +333,7 @@ def build(self, aot_schedule = False max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible - if self.use_full_cuda_graph and \ - num_actual_tokens <= self.max_cudagraph_size: + if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size: # NOTE(woosuk): Setting num_splits > 1 may increase the memory # usage, because the intermediate buffers of size [num_splits, # num_heads, num_tokens, head_size] are allocated. 
Therefore, @@ -302,13 +344,15 @@ def build(self, # For handling prefill decode split if num_decodes > 0: decode_max_seq_len = int( - common_attn_metadata.seq_lens_cpu[:num_decodes].max()) - decode_query_start_loc = common_attn_metadata.query_start_loc[: - num_decodes - + 1] + common_attn_metadata.seq_lens_cpu[:num_decodes].max() + ) + decode_query_start_loc = common_attn_metadata.query_start_loc[ + : num_decodes + 1 + ] decode_seq_lens = common_attn_metadata.seq_lens[:num_decodes] - decode_block_table_tensor = common_attn_metadata.block_table_tensor[: - num_decodes] + decode_block_table_tensor = common_attn_metadata.block_table_tensor[ + :num_decodes + ] else: decode_max_seq_len = 0 decode_query_start_loc = None @@ -317,14 +361,16 @@ def build(self, if num_prefills > 0: prefill_max_seq_len = int( - common_attn_metadata.seq_lens_cpu[num_decodes:num_reqs].max()) + common_attn_metadata.seq_lens_cpu[num_decodes:num_reqs].max() + ) prefill_query_start_loc = ( - common_attn_metadata.query_start_loc[num_decodes:num_reqs + 1] - - common_attn_metadata.query_start_loc[num_decodes]) - prefill_seq_lens = common_attn_metadata.seq_lens[ - num_decodes:num_reqs] + common_attn_metadata.query_start_loc[num_decodes : num_reqs + 1] + - common_attn_metadata.query_start_loc[num_decodes] + ) + prefill_seq_lens = common_attn_metadata.seq_lens[num_decodes:num_reqs] prefill_block_table_tensor = common_attn_metadata.block_table_tensor[ - num_decodes:num_reqs] + num_decodes:num_reqs + ] else: prefill_max_seq_len = 0 prefill_query_start_loc = None @@ -332,12 +378,17 @@ def build(self, prefill_block_table_tensor = None # \------------------------- Metax Modification -------------------------/ - def schedule(batch_size, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): + if vllm_is_batch_invariant(): + max_num_splits = 1 + + def schedule( + batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): cache_dtype = self.cache_config.cache_dtype if cache_dtype.startswith("fp8"): qkv_dtype = MacaFlashAttentionBackend.get_fp8_dtype_for_flashattn( - cache_dtype) + cache_dtype + ) else: qkv_dtype = self.kv_cache_dtype if aot_schedule: @@ -345,7 +396,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, batch_size=batch_size, max_seqlen_q=max_query_len, max_seqlen_k=max_seq_len, - num_heads_q=self.num_heads_q, + num_heads_q=self.num_heads_q * self.dcp_world_size, num_heads_kv=self.num_heads_kv, headdim=self.headdim, cache_seqlens=seqlens, @@ -359,41 +410,73 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, return None use_cascade = common_prefix_len > 0 - - if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) + max_dcp_context_kv_len = 0 + dcp_context_kv_lens = None + + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + + if self.dcp_world_size > 1: + query_kv_lens_cpu = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) + dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu + dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + ( + self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size + ) + dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) + max_dcp_context_kv_len = dcp_context_kv_lens.max().item() + + scheduler_metadata = schedule( + 
batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=dcp_context_kv_lens, + max_seq_len=max_dcp_context_kv_len, + causal=False, + ) + elif use_cascade: + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) prefix_scheduler_metadata = schedule( batch_size=1, cu_query_lens=cu_prefix_query_lens, max_query_len=num_actual_tokens, seqlens=prefix_kv_lens, max_seq_len=common_prefix_len, - causal=False) - scheduler_metadata = schedule(batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=suffix_kv_lens, - max_seq_len=max_seq_len - - common_prefix_len, - causal=True) + causal=False, + ) + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - common_prefix_len, + causal=True, + ) else: cu_prefix_query_lens = None prefix_kv_lens = None suffix_kv_lens = None prefix_scheduler_metadata = None - scheduler_metadata = schedule(batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=seq_lens, - max_seq_len=max_seq_len, - causal=causal) + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=causal, + ) # For FA3 + full cudagraph if self.use_full_cuda_graph and scheduler_metadata is not None: n = scheduler_metadata.shape[0] @@ -428,6 +511,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, # \------------------------- Metax Modification -------------------------/ block_table=block_table_tensor, slot_mapping=slot_mapping, + max_dcp_context_kv_len=max_dcp_context_kv_len, + dcp_context_kv_lens=dcp_context_kv_lens, use_cascade=use_cascade, common_prefix_len=common_prefix_len, scheduler_metadata=scheduler_metadata, @@ -436,7 +521,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, - causal=causal) + causal=causal, + ) return attn_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: @@ -444,6 +530,7 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: class FlashAttentionImpl(AttentionImpl): + can_return_lse_for_decode: bool = True def __init__( self, @@ -451,13 +538,13 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - sinks: Optional[torch.Tensor] = None, + kv_sharing_target_layer_name: str | None = None, + sinks: torch.Tensor | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -485,18 +572,26 @@ def __init__( self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() - if is_quantized_kv_cache(self.kv_cache_dtype) \ - and not flash_attn_supports_fp8(): + # Cache the batch invariant result for use 
in forward passes + self.batch_invariant_enabled = vllm_is_batch_invariant() + + if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): raise NotImplementedError( - "FlashAttention does not support fp8 kv-cache on this device.") + "FlashAttention does not support fp8 kv-cache on this device." + ) self.sinks = sinks if self.sinks is not None: assert self.vllm_flash_attn_version == 3, ( - "Sinks are only supported in FlashAttention 3") + "Sinks are only supported in FlashAttention 3" + ) assert self.sinks.shape[0] == num_heads, ( "Sinks must have the same number of heads as the number of " - "heads in the layer") + "heads in the layer" + ) + + def supports_quant_query_input(self) -> bool: + return False def forward( self, @@ -506,9 +601,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -529,12 +624,12 @@ def forward( if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") + "fused output quantization is not yet supported for FlashAttentionImpl" + ) if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) attn_type = self.attn_type @@ -553,11 +648,14 @@ def forward( if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): # For encoder attention, # we use direct Q, K, V tensors without caching - return self._forward_encoder_attention(query[:num_actual_tokens], - key[:num_actual_tokens], - value[:num_actual_tokens], - output[:num_actual_tokens], - attn_metadata, layer) + return self._forward_encoder_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, + layer, + ) # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) @@ -565,8 +663,11 @@ def forward( # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached # in KV cache. - if (self.kv_sharing_target_layer_name is None and key is not None - and value is not None): + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. 
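The cache write happening here (reshape_and_cache_flash, imported above) scatters each token's freshly computed K/V into the paged cache at the flat position given by slot_mapping, which already combines the block id and the in-block offset. A pure-PyTorch sketch of the same indexing, with toy sizes; the real op is a fused kernel that also applies kv-cache quantization scales:

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
value_cache = torch.zeros_like(key_cache)


def write_kv(key, value, key_cache, value_cache, slot_mapping):
    """key/value: [num_tokens, num_kv_heads, head_size];
    slot_mapping[i] = block_id * block_size + offset_in_block for token i."""
    block_size = key_cache.shape[1]
    blk = slot_mapping // block_size
    off = slot_mapping % block_size
    key_cache[blk, off] = key
    value_cache[blk, off] = value


num_tokens = 5
key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)
slot_mapping = torch.tensor([3, 17, 18, 40, 33])  # flat slots for the 5 tokens
write_kv(key, value, key_cache, value_cache, slot_mapping)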
# NOTE(woosuk): Here, key and value are padded while slot_mapping is @@ -588,52 +689,79 @@ def forward( if self.kv_cache_dtype.startswith("fp8"): # queries are quantized in the attention layer dtype = MacaFlashAttentionBackend.get_fp8_dtype_for_flashattn( - self.kv_cache_dtype) + self.kv_cache_dtype + ) key_cache = key_cache.view(dtype) value_cache = value_cache.view(dtype) - # ┌------------------------ Metax Modification -------------------------┐ - # For handling prefill decode split if not attn_metadata.use_cascade: - num_decode_tokens = attn_metadata.num_decode_tokens - if attn_metadata.num_prefills > 0: - cu_prefix_kv_lens = torch.tensor( - [0] + attn_metadata.prefill_seq_lens.tolist(), - device=attn_metadata.prefill_seq_lens.device, - dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) - output[num_decode_tokens: - num_actual_tokens] = flash_attn_varlen_func( - q=query[num_decode_tokens:num_actual_tokens], - k=key_cache, - v=value_cache, - block_table=attn_metadata.prefill_block_table, - cu_seqlens_q=attn_metadata.prefill_query_start_loc, - cu_seqlens_k=cu_prefix_kv_lens, - max_seqlen_q=attn_metadata.max_query_len, - max_seqlen_k=attn_metadata.prefill_max_seq_len, - softmax_scale=self.scale, - causal=attn_metadata.causal, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ) - if attn_metadata.num_decodes > 0: - # Use flash_attn_with_kvcache for normal decoding. - decode_query = query[:num_decode_tokens] - output[:num_decode_tokens] = flash_attn_with_kvcache( - q=decode_query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - block_table=attn_metadata.decode_block_table, - cache_seqlens=attn_metadata.decode_seq_lens, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ).squeeze(1) - return output - # └------------------------- Metax Modification -------------------------┘ + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + scheduler_metadata = attn_metadata.scheduler_metadata + + descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) + + if self.dcp_world_size > 1: + self._forward_with_dcp( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + output[:num_actual_tokens], + attn_metadata, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + return output + else: + # ┌------------------------ Metax Modification -------------------------┐ + # For handling prefill decode split + num_decode_tokens = attn_metadata.num_decode_tokens + if attn_metadata.num_prefills > 0: + cu_prefix_kv_lens = torch.tensor( + [0] + attn_metadata.prefill_seq_lens.tolist(), + device=attn_metadata.prefill_seq_lens.device, + dtype=torch.int32, + ).cumsum(dim=0, dtype=torch.int32) + output[num_decode_tokens:num_actual_tokens] = ( + flash_attn_varlen_func( + q=query[num_decode_tokens:num_actual_tokens], + k=key_cache, + v=value_cache, + block_table=attn_metadata.prefill_block_table, + cu_seqlens_q=attn_metadata.prefill_query_start_loc, + cu_seqlens_k=cu_prefix_kv_lens, + max_seqlen_q=attn_metadata.max_query_len, + max_seqlen_k=attn_metadata.prefill_max_seq_len, + softmax_scale=self.scale, + causal=attn_metadata.causal, + 
window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + ) + if attn_metadata.num_decodes > 0: + # Use flash_attn_with_kvcache for normal decoding. + decode_query = query[:num_decode_tokens] + output[:num_decode_tokens] = flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + block_table=attn_metadata.decode_block_table, + cache_seqlens=attn_metadata.decode_seq_lens, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ).squeeze(1) + return output + # └------------------------- Metax Modification -------------------------┘ # Cascade attention (rare case). cascade_attention( @@ -659,9 +787,90 @@ def forward( q_descale=layer._q_scale, k_descale=layer._k_scale, v_descale=layer._v_scale, + s_aux=self.sinks, ) return output + def _forward_with_dcp( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + q_descale: torch.Tensor | None = None, + k_descale: torch.Tensor | None = None, + v_descale: torch.Tensor | None = None, + ) -> torch.Tensor: + cu_seqlens_q = attn_metadata.query_start_loc + max_seqlen_q = attn_metadata.max_query_len + block_table = attn_metadata.block_table + + query = query.contiguous() + query_across_dcp = get_dcp_group().all_gather(query, dim=1) + context_attn_out, context_lse = flash_attn_varlen_func( + q=query_across_dcp, + k=key_cache, + v=value_cache, + out=None, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=attn_metadata.dcp_context_kv_lens, + max_seqlen_k=attn_metadata.max_dcp_context_kv_len, + softmax_scale=self.scale, + causal=False, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + return_softmax_lse=True, + scheduler_metadata=attn_metadata.scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ] + context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs( + context_attn_out, + context_lse.transpose(0, 1), + get_dcp_group(), + return_lse=True, + ) + context_lse_cor = context_lse_cor.transpose(0, 1).contiguous() + + query_attn_out, query_lse = flash_attn_varlen_func( + q=query, + k=key, + v=value, + out=None, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_k=max_seqlen_q, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, + return_softmax_lse=True, + fa_version=self.vllm_flash_attn_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + assert context_attn_out_cor.shape == query_attn_out.shape + assert context_lse_cor.shape == query_lse.shape + merge_attn_states( + output, + context_attn_out_cor, + context_lse_cor, + query_attn_out, + query_lse, + ) + def _forward_encoder_attention( self, query: torch.Tensor, @@ -684,7 +893,8 @@ def _forward_encoder_attention( # For encoder attention, process FP8 quantization if needed if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError( - "quantization is not supported for encoder attention") + "quantization is not supported for encoder attention" + ) # Use 
encoder-specific metadata for sequence information cu_seqlens_q = attn_metadata.query_start_loc @@ -694,7 +904,8 @@ def _forward_encoder_attention( descale_shape = ( cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] - self.num_kv_heads) + self.num_kv_heads, + ) # Call flash attention directly on Q, K, V tensors flash_attn_varlen_func( @@ -715,6 +926,7 @@ def _forward_encoder_attention( q_descale=layer._q_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + num_splits=1 if self.batch_invariant_enabled else 0, ) return output @@ -729,6 +941,7 @@ def use_cascade_attention( use_sliding_window: bool, use_local_attention: bool, num_sms: int, + dcp_world_size: int, ) -> bool: """Decide whether to use cascade attention. @@ -750,6 +963,9 @@ def use_cascade_attention( num_reqs = len(query_lens) if num_reqs < 8: return False + # disable cascade attention for DCP + if dcp_world_size > 1: + return False # Heuristics to decide whether using cascade attention is beneficial. # 1. When FlashDecoding is not used for normal attention, cascade attention @@ -757,8 +973,12 @@ def use_cascade_attention( num_queries_per_kv = num_query_heads // num_kv_heads # The criteria for using FlashDecoding can be found in the following link: # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 - use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window - and not use_alibi and np.all(query_lens == 1)) + use_flash_decoding = ( + num_queries_per_kv > 1 + and not use_sliding_window + and not use_alibi + and np.all(query_lens == 1) + ) if not use_flash_decoding: # Use cascade attention. return True @@ -780,8 +1000,9 @@ def use_cascade_attention( cascade_waves = cdiv(cascade_ctas, num_sms) cascade_time = cascade_waves * num_prefix_tiles - flash_decoding_ctas = (num_reqs * num_kv_heads * - cdiv(num_queries_per_kv, q_tile_size)) + flash_decoding_ctas = ( + num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size) + ) flash_decoding_ctas *= num_prefix_tiles flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) @@ -801,22 +1022,24 @@ def cascade_attention( suffix_kv_lens: torch.Tensor, max_kv_len: int, softmax_scale: float, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, sliding_window: tuple[int, int], logits_soft_cap: float, block_table: torch.Tensor, common_prefix_len: int, fa_version: int, - prefix_scheduler_metadata: Optional[torch.Tensor] = None, - suffix_scheduler_metadata: Optional[torch.Tensor] = None, - q_descale: Optional[torch.Tensor] = None, - k_descale: Optional[torch.Tensor] = None, - v_descale: Optional[torch.Tensor] = None, + prefix_scheduler_metadata: torch.Tensor | None = None, + suffix_scheduler_metadata: torch.Tensor | None = None, + q_descale: torch.Tensor | None = None, + k_descale: torch.Tensor | None = None, + v_descale: torch.Tensor | None = None, + s_aux: torch.Tensor | None = None, ) -> torch.Tensor: - assert alibi_slopes is None, ("Cascade attention does not support ALiBi.") + assert alibi_slopes is None, "Cascade attention does not support ALiBi." # TODO: Support sliding window. assert sliding_window == (-1, -1), ( - "Cascade attention does not support sliding window.") + "Cascade attention does not support sliding window." 
+ ) num_tokens = query.shape[0] block_size = key_cache.shape[-3] @@ -826,10 +1049,9 @@ def cascade_attention( descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) # /------------------------ Metax Modification -------------------------\ - cu_prefix_kv_lens = torch.tensor([0] + prefix_kv_lens.tolist(), - device=prefix_kv_lens.device, - dtype=torch.int32).cumsum( - dim=0, dtype=torch.int32) + cu_prefix_kv_lens = torch.tensor( + [0] + prefix_kv_lens.tolist(), device=prefix_kv_lens.device, dtype=torch.int32 + ).cumsum(dim=0, dtype=torch.int32) # \------------------------ Metax Modification -------------------------/ # Process shared prefix. @@ -850,10 +1072,9 @@ def cascade_attention( descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) # /------------------------ Metax Modification -------------------------\ - cu_suffix_kv_lens = torch.tensor([0] + suffix_kv_lens.tolist(), - device=suffix_kv_lens.device, - dtype=torch.int32).cumsum( - dim=0, dtype=torch.int32) + cu_suffix_kv_lens = torch.tensor( + [0] + suffix_kv_lens.tolist(), device=suffix_kv_lens.device, dtype=torch.int32 + ).cumsum(dim=0, dtype=torch.int32) # \------------------------ Metax Modification -------------------------/ # Process suffix per query. @@ -874,5 +1095,4 @@ def cascade_attention( ) # Merge prefix and suffix outputs, and store the result in output. - merge_attn_states(output, prefix_output, prefix_lse, suffix_output, - suffix_lse) + merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) diff --git a/vllm_metax/v1/attention/backends/flashinfer.py b/vllm_metax/v1/attention/backends/flashinfer.py index 15de1805d..03e38e1f7 100644 --- a/vllm_metax/v1/attention/backends/flashinfer.py +++ b/vllm_metax/v1/attention/backends/flashinfer.py @@ -1,31 +1,45 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashInfer.""" -from __future__ import annotations from dataclasses import dataclass -from typing import ClassVar, Optional, Union +from typing import Any, ClassVar import numpy as np import torch -from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper) +from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, +) from flashinfer.decode import _get_range_buf from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionType, + MultipleOf, +) from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym, kNvfp4Quant) + QuantKey, + kFp8StaticTensorSym, + kNvfp4Quant, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import (can_use_trtllm_attention, - flashinfer_disable_q_quantization, - supports_trtllm_attention, - use_trtllm_attention) +from vllm.utils.math_utils import cdiv, is_pin_memory_available +from vllm.utils.flashinfer import ( + can_use_trtllm_attention, + flashinfer_disable_q_quantization, + supports_trtllm_attention, + use_trtllm_attention, +) + # yapf conflicts with isort for this block # yapf: 
disable from vllm.v1.attention.backends.utils import (AttentionCGSupport, @@ -49,6 +63,7 @@ #from flashinfer.prefill import trtllm_batch_context_with_kv_cache FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 @@ -62,7 +77,8 @@ def _get_trtllm_gen_workspace_buffer(): global trtllm_gen_workspace_buffer if trtllm_gen_workspace_buffer is None: trtllm_gen_workspace_buffer = torch.zeros( - FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda') + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda" + ) return trtllm_gen_workspace_buffer @@ -79,9 +95,9 @@ def _trtllm_prefill_attn_kvfp8_dequant( ): batch_idx = tl.program_id(0).to(tl.int64) mock_block_table_idx = tl.program_id(1).to(tl.int64) - orig_page_num = tl.load(block_tables_prefill_ptr + - batch_idx * block_table_stride + - mock_block_table_idx).to(tl.int64) + orig_page_num = tl.load( + block_tables_prefill_ptr + batch_idx * block_table_stride + mock_block_table_idx + ).to(tl.int64) if orig_page_num <= 0: return dequant_dtype = mock_kv_cache_ptr.dtype.element_ty @@ -91,20 +107,24 @@ def _trtllm_prefill_attn_kvfp8_dequant( offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) fp8_vals = tl.load(kv_cache_ptr + offset) dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val - mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx - + 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + mock_cache_offset = ( + batch_idx * block_table_stride + mock_block_table_idx + 1 + ) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) dequantized_vals = dequantized_vals.to(dequant_dtype) tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) # Dequantize V v_scale_val = tl.load(v_scale_ptr) - offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + - tl.arange(0, K_CACHE_STRIDE)) + offset = ( + orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + ) fp8_vals = tl.load(kv_cache_ptr + offset) dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val mock_cache_offset = ( - (batch_idx * block_table_stride + mock_block_table_idx + 1) * - KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)) + (batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE + + K_CACHE_STRIDE + + tl.arange(0, K_CACHE_STRIDE) + ) dequantized_vals = dequantized_vals.to(dequant_dtype) tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) @@ -124,9 +144,7 @@ def trtllm_prefill_attn_kvfp8_dequant( kv_cache_stride = k_cache_stride * s[1] new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4]) # mock kv cache contains just the pages needed by this prefill - mock_kv_cache = torch.empty(new_s, - dtype=dequant_dtype, - device=kv_cache.device) + mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device) # we simply sequentially index the pages needed by this prefill mock_block_table = torch.arange( start=1, @@ -160,6 +178,11 @@ def get_supported_head_sizes(cls) -> list[int]: # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 return [64, 128, 256] + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + # Note: Not sure for all platforms, + return [16, 32, 64] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = 
cls.get_supported_head_sizes() @@ -169,22 +192,23 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: return "FLASHINFER" @staticmethod - def get_impl_cls() -> type[FlashInferImpl]: + def get_impl_cls() -> type["FlashInferImpl"]: return FlashInferImpl @staticmethod - def get_metadata_cls() -> type[FlashInferMetadata]: + def get_metadata_cls() -> type["FlashInferMetadata"]: return FlashInferMetadata @staticmethod - def get_builder_cls() -> type[FlashInferMetadataBuilder]: + def get_builder_cls() -> type["FlashInferMetadataBuilder"]: return FlashInferMetadataBuilder @staticmethod @@ -212,6 +236,7 @@ def get_kv_cache_stride_order() -> tuple[int, ...]: @staticmethod def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + raise NotImplementedError("Maca does not support FP8 FlashInfer.") if kv_cache_dtype in ("fp8", "fp8_e4m3"): return torch.float8_e4m3fn elif kv_cache_dtype == "fp8_e5m2": @@ -247,24 +272,30 @@ class FlashInferMetadata: # For cascade attention (CPU for planning). use_cascade: bool - prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None - decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None + decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None # /------------------------ Metax Modification -------------------------\ - #cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None + cascade_wrapper: Any | None = None # \------------------------- Metax Modification -------------------------/ - qo_indptr_gpu: Optional[torch.Tensor] = None - paged_kv_indptr_gpu: Optional[torch.Tensor] = None + qo_indptr_gpu: torch.Tensor | None = None + paged_kv_indptr_gpu: torch.Tensor | None = None class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) reorder_batch_threshold: int = 1 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.cache_config = vllm_config.cache_config self.model_config = vllm_config.model_config @@ -272,23 +303,44 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) + if vllm_is_batch_invariant(): + self.decode_fixed_split_size = 2048 + self.prefill_fixed_split_size = 4096 + self.disable_split_kv = True + else: + self.decode_fixed_split_size = -1 + self.prefill_fixed_split_size = -1 + self.disable_split_kv = False + self.compilation_config = vllm_config.compilation_config - max_num_pages_per_req = cdiv(self.model_config.max_model_len, - self.kv_cache_spec.block_size) + max_num_pages_per_req = cdiv( + self.model_config.max_model_len, self.kv_cache_spec.block_size + ) max_num_reqs = vllm_config.scheduler_config.max_num_seqs 
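Two derived quantities drive the buffer sizing in this constructor: the largest number of KV pages a single request can need, obtained by ceil-dividing the model's maximum length by the page size, and the total page capacity across the largest possible batch. A tiny sketch of that arithmetic; the concrete numbers are illustrative only, not taken from any real configuration:

def cdiv(a: int, b: int) -> int:
    # Ceiling division without floats.
    return -(-a // b)


max_model_len = 8192   # example values only
page_size = 16
max_num_reqs = 256

max_num_pages_per_req = cdiv(max_model_len, page_size)   # 512
max_num_pages = max_num_reqs * max_num_pages_per_req     # 131072 entries in paged_kv_indices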
max_num_pages = max_num_reqs * max_num_pages_per_req - self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\ - decode_mode() == CUDAGraphMode.FULL) + speculative_config = vllm_config.speculative_config + num_spec_tokens = ( + speculative_config.num_speculative_tokens + if speculative_config is not None + else 0 + ) + self.enable_cuda_graph = ( + self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + ) if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. self._decode_wrappers_cudagraph: dict[ - int, BatchDecodeWithPagedKVCacheWrapper] = {} + int, BatchDecodeWithPagedKVCacheWrapper + ] = {} self._decode_cudagraph_max_bs = min( - max_num_reqs, self.compilation_config.max_capture_size) + (1 + num_spec_tokens) * max_num_reqs, + self.compilation_config.max_cudagraph_capture_size, + ) self.num_qo_heads = self.model_config.get_num_attention_heads( - self.vllm_config.parallel_config) + self.vllm_config.parallel_config + ) self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.head_dim = self.kv_cache_spec.head_size MacaFlashInferBackend.validate_head_size(self.head_dim) @@ -307,91 +359,88 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to # use fp8 q if kv cache is fp8, and will fall back to model dtype # if TRTLLM attention kernel is not used when building attn metadata - if supports_trtllm_attention() and \ - not flashinfer_disable_q_quantization(): + can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) + if can_use_trtllm and not flashinfer_disable_q_quantization(): self.q_data_type = self.kv_cache_dtype else: self.q_data_type = self.model_config.dtype - supports_spec_as_decode = \ - can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) - self._init_reorder_batch_threshold(1, supports_spec_as_decode) + self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm) self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers # TODO: discard this for trtllm-gen backend self.global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)) + get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl) + ) self.sm_scale = self.global_hyperparameters.sm_scale self.window_left = self.global_hyperparameters.window_left self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap self.has_sinks = self.global_hyperparameters.has_sinks - if self.has_sinks and not supports_trtllm_attention(): + if self.has_sinks and not can_use_trtllm: raise NotImplementedError( "FlashInfer backend currently does not support attention " "sinks, please use trtllm on blackwell or flash attention on " - "earlier GPUs.") + "earlier GPUs." 
+ ) # Preparing persistent buffers (device-side) - self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, - dtype=torch.int32, - device=self.device) + self.paged_kv_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=self.device + ) self.paged_kv_indices = torch.zeros( max_num_pages, # max num pages possible dtype=torch.int32, - device=self.device) - self.paged_kv_last_page_len = torch.zeros(max_num_reqs, - dtype=torch.int32, - device=self.device) + device=self.device, + ) + self.paged_kv_last_page_len = torch.zeros( + max_num_reqs, dtype=torch.int32, device=self.device + ) # host-side buffer pin_memory = is_pin_memory_available() - self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1, - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) + self.paged_kv_indptr_cpu = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy() self.paged_kv_indptr_buffer = torch.zeros_like( - self.paged_kv_indptr_cpu, pin_memory=pin_memory) - self.paged_kv_indices_cpu = torch.zeros(max_num_pages, - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs, - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.paged_kv_last_page_len_np = ( - self.paged_kv_last_page_len_cpu.numpy()) + self.paged_kv_indptr_cpu, pin_memory=pin_memory + ) + self.paged_kv_indices_cpu = torch.zeros( + max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_last_page_len_cpu = torch.zeros( + max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy() def _get_workspace_buffer(self): if self._workspace_buffer is None: + buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE + if vllm_is_batch_invariant(): + buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT self._workspace_buffer = torch.zeros( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) + buffer_size, dtype=torch.uint8, device=self.device + ) return self._workspace_buffer def _get_prefill_wrapper(self): if self._prefill_wrapper is None: self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), get_kv_cache_layout()) + self._get_workspace_buffer(), get_kv_cache_layout() + ) return self._prefill_wrapper - def _get_decode_wrapper(self, - batch_size: int, - use_cudagraph: bool = False): + def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): if use_cudagraph: - decode_wrapper = self._decode_wrappers_cudagraph.get( - batch_size, None) + decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None) else: decode_wrapper = self._decode_wrapper if decode_wrapper is None: if use_cudagraph: - paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1] + paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1] paged_kv_indices = self.paged_kv_indices - paged_kv_last_page_len = self.paged_kv_last_page_len[: - batch_size] + paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size] else: paged_kv_indptr = None paged_kv_indices = None @@ -421,21 +470,25 @@ def _get_decode_wrapper(self, # Note: MetaX won't support cascade attention for now # \------------------------ Metax Modification -------------------------/ def _get_cascade_wrapper(self): - if self._cascade_wrapper is None: - self._cascade_wrapper = MultiLevelCascadeAttentionWrapper( - 2, self._get_workspace_buffer(), 
get_kv_cache_layout()) - return self._cascade_wrapper - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashInferMetadata: + raise NotImplementedError( + "Maca FlashInfer backend does not support cascade attention." + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashInferMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=self.reorder_batch_threshold, - require_uniform=True) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=True, + ) + ) page_size = self.page_size max_q_len = common_attn_metadata.max_query_len @@ -447,24 +500,24 @@ def build(self, num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size - use_cascade = common_prefix_len > 0 + # Metax does not support cascade attention for now + use_cascade = False and common_prefix_len > 0 if use_cascade: # Grab the blocks of the shared prefix from the first request. assert common_prefix_len % page_size == 0 num_common_kv_blocks = common_prefix_len // page_size # Create CPU versions directly for cascade (no GPU versions needed) - shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device='cpu') - shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks], - dtype=torch.int32, - device='cpu') - shared_kv_page_indices_cpu = block_table_tensor[ - 0, :num_common_kv_blocks] - shared_kv_last_page_len_cpu = torch.tensor([page_size], - dtype=torch.int32, - device='cpu') + shared_qo_indptr_cpu = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device="cpu" + ) + shared_kv_page_indptr_cpu = torch.tensor( + [0, num_common_kv_blocks], dtype=torch.int32, device="cpu" + ) + shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks] + shared_kv_last_page_len_cpu = torch.tensor( + [page_size], dtype=torch.int32, device="cpu" + ) # Remove the blocks of the shared prefix from all requests. block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] @@ -479,22 +532,23 @@ def build(self, np.cumsum( num_blocks_np, dtype=np.int32, - out=self.paged_kv_indptr_np[1:num_reqs + 1], + out=self.paged_kv_indptr_np[1 : num_reqs + 1], ) # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified # after this line (e.g., for cuda graphs), we need to copy the data to # self.paged_kv_indptr_buffer to avoid race condition. 
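All three host-side arrays this builder maintains (paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len) can be derived from the block table and the per-request sequence lengths alone. A NumPy sketch of that math with toy inputs; the real build() writes into the pre-allocated pinned buffers shown above and fills the page indices with the _copy_page_indices_kernel Triton kernel rather than concatenating:

import numpy as np

page_size = 16
seq_lens = np.array([5, 16, 33])              # tokens per request
block_table = np.array([                      # page ids per request (padded rows)
    [7, 0, 0],
    [2, 0, 0],
    [4, 9, 1],
])

num_blocks = (seq_lens + page_size - 1) // page_size      # [1, 1, 3]

# indptr: exclusive prefix sum over per-request page counts -> [0, 1, 2, 5]
paged_kv_indptr = np.zeros(len(seq_lens) + 1, dtype=np.int32)
np.cumsum(num_blocks, dtype=np.int32, out=paged_kv_indptr[1:])

# indices: the used prefix of each block-table row, concatenated -> [7, 2, 4, 9, 1]
paged_kv_indices = np.concatenate(
    [block_table[i, : num_blocks[i]] for i in range(len(seq_lens))]
)

# last_page_len: tokens used in the final page of each request; a remainder of 0
# means the final page is completely full, mapped here to page_size -> [5, 16, 1]
last = seq_lens % page_size
paged_kv_last_page_len = np.where(last == 0, page_size, last)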
- self.paged_kv_indptr_buffer[:num_reqs + - 1] = (self.paged_kv_indptr_cpu[:num_reqs + - 1]) - paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1] - paged_kv_indptr.copy_(self.paged_kv_indptr_buffer[:num_reqs + 1], - non_blocking=True) + self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[ + : num_reqs + 1 + ] + paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1] + paged_kv_indptr.copy_( + self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True + ) # write self.paged_kv_indices inplace num_actual_pages = self.paged_kv_indptr_np[num_reqs] paged_kv_indices = self.paged_kv_indices[:num_actual_pages] - _copy_page_indices_kernel[(num_reqs, )]( + _copy_page_indices_kernel[(num_reqs,)]( paged_kv_indices, block_table_tensor, block_table_tensor.stride(0), @@ -511,33 +565,52 @@ def build(self, ) uses_spec_reorder = self.reorder_batch_threshold > 1 - prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads, - self.num_kv_heads, - num_prefill_tokens, - max_seq_len, - self.cache_dtype, - self.q_data_type, - is_prefill=True, - has_sinks=self.has_sinks, - has_spec=uses_spec_reorder) - decode_use_trtllm = use_trtllm_attention(self.num_qo_heads, - self.num_kv_heads, - num_decode_tokens, - max_seq_len, - self.cache_dtype, - self.q_data_type, - is_prefill=False, - has_sinks=self.has_sinks, - has_spec=uses_spec_reorder) - if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm): - raise NotImplementedError( - "FlashInfer backend currently does not support attention " - "sinks, please use trtllm on blackwell or flash attention on " - "earlier GPUs.") + prefill_use_trtllm = use_trtllm_attention( + self.num_qo_heads, + self.num_kv_heads, + num_prefill_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=True, + has_sinks=self.has_sinks, + has_spec=uses_spec_reorder, + ) + decode_use_trtllm = use_trtllm_attention( + self.num_qo_heads, + self.num_kv_heads, + num_decode_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=False, + has_sinks=self.has_sinks, + has_spec=uses_spec_reorder, + ) - # If TRTLLM attention is not used, the q quantization is not supported. - # Fall back to use model dtype. if not (prefill_use_trtllm and decode_use_trtllm): + if self.has_sinks: + raise NotImplementedError( + "FlashInfer backend currently does not support attention " + "sinks, please use trtllm on blackwell or flash attention " + "on earlier GPUs." + ) + + if not self.global_hyperparameters.has_same_window_lefts: + raise ValueError( + "Window left is not the same for all layers. " + "One potential fix is to set disable_sliding_window=True" + ) + + assert self.global_hyperparameters.has_same_all_params, ( + "FlashInfer backend currently only supports models in which " + "all layers share the same values for the following " + "hyperparameters: `window_left`, `logits_soft_cap`, " + "`sm_scale`." + ) + + # The q quantization is not supported for non-trtllm attention, + # fall back to model dtype. 
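The NOTE about copying into paged_kv_indptr_buffer reflects a general rule: a non_blocking host-to-device copy from pinned memory only enqueues the transfer, so the CPU must not overwrite the source tensor until the copy has actually executed. Staging the values in a buffer that nothing else mutates avoids the race. A hedged sketch of the pattern (names and sizes are illustrative; the CUDA branch only runs when a GPU is present):

import torch


def h2d_async_safe(src_cpu: torch.Tensor,
                   staging_cpu: torch.Tensor,
                   dst_gpu: torch.Tensor) -> None:
    """Copy src_cpu -> dst_gpu asynchronously without racing later writes to src_cpu.

    staging_cpu must be pinned and reserved for this purpose; src_cpu may be
    reused or overwritten by the caller immediately after this returns.
    """
    n = src_cpu.shape[0]
    staging_cpu[:n].copy_(src_cpu)                          # synchronous CPU-side snapshot
    dst_gpu[:n].copy_(staging_cpu[:n], non_blocking=True)   # async H2D from the snapshot


if torch.cuda.is_available():
    staging = torch.zeros(1024, dtype=torch.int32, pin_memory=True)
    device_buf = torch.zeros(1024, dtype=torch.int32, device="cuda")
    indptr_cpu = torch.arange(17, dtype=torch.int32)
    h2d_async_safe(indptr_cpu, staging, device_buf)
    indptr_cpu.fill_(-1)  # safe: the pending transfer reads from `staging`, not from here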
self.q_data_type = self.model_config.dtype attn_metadata = FlashInferMetadata( @@ -559,48 +632,33 @@ def build(self, ) qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu - paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs] + paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] if attn_metadata.use_cascade: - attn_metadata.cascade_wrapper = self._get_cascade_wrapper() - attn_metadata.cascade_wrapper.plan( - [shared_qo_indptr_cpu, qo_indptr_cpu], - [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu], - [shared_kv_page_indices_cpu, paged_kv_indices], - [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - sm_scale=self.sm_scale, - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.kv_cache_dtype, + raise NotImplementedError( + "Maca FlashInfer backend does not support cascade attention." ) else: # Regular attention (common case). - # Decodes are at the front and prefills are at the back, - # according to reorder_batch() + # Decodes are at the front and prefills are at the back. num_prefills = attn_metadata.num_prefills num_decodes = attn_metadata.num_decodes if num_prefills > 0: # Decodes are first so prefills start after the last decode prefill_start = num_decodes attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - assert qo_indptr_cpu[prefill_start:].shape[ - 0] == num_prefills + 1 - assert paged_kv_indptr_cpu[prefill_start:].shape[ - 0] == num_prefills + 1 - assert paged_kv_last_page_len_cpu[prefill_start:].shape[ - 0] == num_prefills + assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 + assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 + assert ( + paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills + ) # Since prefill_wrapper.run() will be called with # query[num_decode_tokens:] we need to adjust the qo_indptr # to be relative to the start of the prefill queries. 
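Because decodes sit at the front of the reordered batch and prefill_wrapper.run() is later fed only query[num_decode_tokens:], the qo_indptr handed to the wrapper must be rebased so that its first entry is 0. The adjustment is simply a slice minus its own first element; a tiny worked example:

import torch

# 2 decode requests (1 token each) followed by 2 prefill requests (4 and 3 tokens).
qo_indptr_cpu = torch.tensor([0, 1, 2, 6, 9], dtype=torch.int32)
num_decodes = 2
prefill_start = num_decodes

prefill_qo_indptr = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
# tensor([0, 4, 7]): offsets into query[num_decode_tokens:] for the two prefills.
print(prefill_qo_indptr)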
- qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[ - prefill_start] + qo_indptr_cpu = ( + qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start] + ) paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] # Recompute max_q_len for the slice of requests we are using @@ -608,8 +666,7 @@ def build(self, # we have a non-uniform batch with some short decodes offloaded # to the prefill pathway query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] - attn_metadata.max_q_len_prefill = \ - int(query_lens_prefill.max().item()) + attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) if not attn_metadata.prefill_use_trtllm: attn_metadata.prefill_wrapper.plan( @@ -630,42 +687,50 @@ def build(self, ) else: attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) if num_decodes > 0: pure_decode = num_prefills == 0 # possible required padding for cudagraph replay - use_cudagraph = (self.enable_cuda_graph and pure_decode and - num_decodes <= self._decode_cudagraph_max_bs) + use_cudagraph = ( + self.enable_cuda_graph + and pure_decode + and num_decode_tokens <= self._decode_cudagraph_max_bs + ) if use_cudagraph: - num_input_tokens = ( - self.vllm_config.pad_for_cudagraph(num_decode_tokens)) + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_decode_tokens + ) # Carefully fulfill the padding region with reasonable value # on cpu. # Make sure paged_kv_indptr_cpu is not decreasing - self.paged_kv_indptr_cpu[1 + num_decodes:1 + - num_input_tokens].fill_( - paged_kv_indptr_cpu[-1]) + self.paged_kv_indptr_cpu[ + 1 + num_decodes : 1 + num_input_tokens + ].fill_(paged_kv_indptr_cpu[-1]) # Fill the remaining paged_kv_last_page_len_cpu with 1. # This is because flashinfer treats 0 as a full page # instead of empty. - self.paged_kv_last_page_len_cpu[ - num_decodes:num_input_tokens].fill_(1) + self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_( + 1 + ) else: num_input_tokens = num_decode_tokens attn_metadata.decode_wrapper = self._get_decode_wrapper( - num_input_tokens, use_cudagraph) + num_input_tokens, use_cudagraph + ) if not attn_metadata.decode_use_trtllm: # Use the persistent buffer with padding length, # instead of the same address but chunked version # in atten_metadata when using cudagraph. 
fast_plan_decode( attn_metadata.decode_wrapper, - self.paged_kv_indptr_cpu[:num_input_tokens + 1], + self.paged_kv_indptr_cpu[: num_input_tokens + 1], paged_kv_indices, self.paged_kv_last_page_len_cpu[:num_input_tokens], seq_lens_cpu[:num_input_tokens], @@ -701,13 +766,13 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[int] = None, - sinks: Optional[torch.Tensor] = None, + kv_sharing_target_layer_name: int | None = None, + sinks: torch.Tensor | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -720,8 +785,9 @@ def __init__( self.sliding_window = (-1, -1) else: self.sliding_window = (sliding_window - 1, 0) - self.window_left = (self.sliding_window[0] - if self.sliding_window is not None else -1) + self.window_left = ( + self.sliding_window[0] if self.sliding_window is not None else -1 + ) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -729,12 +795,14 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashInferImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl" + ) - # self.sinks: Optional[torch.Tensor] = None + self.sinks: torch.Tensor | None = None # if sinks is not None: # if sinks.shape[0] != num_heads: # raise ValueError( @@ -743,16 +811,28 @@ def __init__( # f"{sinks.shape[0]}." 
# ) # self.sinks = sinks - self.support_trtllm_attn = (supports_trtllm_attention() - and num_heads % num_kv_heads == 0) - self.bmm1_scale: Optional[float] = None - self.bmm2_scale: Optional[float] = None - self.o_sf_scale: Optional[float] = None + self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads) + self.bmm1_scale: float | None = None + self.bmm2_scale: float | None = None + self.o_sf_scale: float | None = None def fused_output_quant_supported(self, quant_key: QuantKey): - return (self.support_trtllm_attn - and self.kv_cache_dtype.startswith("fp8") - and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)) + return ( + self.support_trtllm_attn + and self.kv_cache_dtype.startswith("fp8") + and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) + ) + + def supports_quant_query_input(self) -> bool: + if flashinfer_disable_q_quantization(): + return False + + return self.support_trtllm_attn + + # FlashInfer requires attention sinks to be float32 + def process_weights_after_loading(self, act_dtype: torch.dtype): + if self.sinks is not None and self.sinks.dtype != torch.float32: + self.sinks = self.sinks.to(torch.float32) def forward( self, @@ -762,9 +842,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FlashInfer. @@ -783,31 +863,41 @@ def forward( if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) + + # Ensure query dtype matches the expected dtype from attention metadata + assert attn_metadata.q_data_type == query.dtype, ( + f"Query dtype mismatch: expected {attn_metadata.q_data_type}, " + f"got {query.dtype}" + ) if self.bmm1_scale is None: - self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * - self.scale) + self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale if self.bmm2_scale is None: self.bmm2_scale = layer._v_scale_float # The attn+quant fusion happens when output_scale is provided. if output_scale is None: - assert output_block_scale is None, "output_block_scale "\ - "is not supported when fusion has not happened" + assert output_block_scale is None, ( + "output_block_scale is not supported when fusion has not happened" + ) else: - assert attn_metadata.q_data_type == FP8_DTYPE, \ + assert attn_metadata.q_data_type == FP8_DTYPE, ( "Query must be FP8 when attn+quant fusion happened." 
- assert (attn_metadata.prefill_use_trtllm and - attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" + ) + assert ( + attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm + ), "Must use TRT-LLM attn" if output.dtype == FP8_DTYPE: - assert output_block_scale is None, \ + assert output_block_scale is None, ( "output_block_scale should not be provided for fp8 output" + ) elif output.dtype == FP4_DTYPE: - assert output_block_scale is not None, \ + assert output_block_scale is not None, ( "output_block_scale is required for nvfp4 output" + ) else: raise ValueError(f"Unsupported output dtype: {output.dtype}") @@ -821,15 +911,6 @@ def forward( elif output.dtype == FP4_DTYPE: self.o_sf_scale = layer._o_scale_float - # Insert FP8 quant for query - if attn_metadata.q_data_type == FP8_DTYPE: - num_tokens, num_heads, head_size = query.shape - query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) - query = query.reshape((num_tokens, num_heads, head_size)) - # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead @@ -864,7 +945,8 @@ def forward( # to process the cache when the kv_cache_dtype is fp8 if self.kv_cache_dtype.startswith("fp8"): torch_dtype = MacaFlashInferBackend.get_fp8_dtype_for_flashinfer( - self.kv_cache_dtype) + self.kv_cache_dtype + ) kv_cache = kv_cache.view(torch_dtype) # Inputs and outputs may be padded for CUDA graphs @@ -887,8 +969,7 @@ def forward( stride_order = MacaFlashInferBackend.get_kv_cache_stride_order() kv_cache_permute = kv_cache.permute(*stride_order) # Regular attention (common case). - # Decodes are at the front and prefills are at the back, - # according to reorder_batch() + # Decodes are at the front and prefills are at the back. if num_prefill_tokens > 0: prefill_wrapper = attn_metadata.prefill_wrapper prefill_query = query[num_decode_tokens:] @@ -898,8 +979,7 @@ def forward( if not attn_metadata.prefill_use_trtllm: assert prefill_wrapper._causal assert prefill_wrapper._window_left == self.window_left - assert prefill_wrapper._logits_soft_cap == ( - self.logits_soft_cap or 0.0) + assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert prefill_wrapper._sm_scale == self.scale prefill_wrapper.run( prefill_query, @@ -910,65 +990,8 @@ def forward( ) else: # prefill_query may be non-contiguous - prefill_query = prefill_query.contiguous() - workspace_buffer = _get_trtllm_gen_workspace_buffer() - block_tables_prefill = attn_metadata.block_table_tensor[ - num_decodes:] - seq_lens_prefill = attn_metadata.seq_lens[num_decodes:] - - # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND - assert get_kv_cache_layout() == "HND" - assert prefill_query.is_contiguous() - assert kv_cache_permute.is_contiguous() - assert workspace_buffer.is_contiguous() - assert block_tables_prefill.is_contiguous() - assert seq_lens_prefill.is_contiguous() - - if output.dtype == FP4_DTYPE: - assert self.o_sf_scale is not None - out = FP4Tensor(data=output[num_decode_tokens:], - scale=output_block_scale, - scale_start_index=num_decode_tokens, - original_shape=prefill_query.shape) - else: - assert self.o_sf_scale is None - out = output[num_decode_tokens:] - - if attn_metadata.q_data_type != FP8_DTYPE \ - and self.kv_cache_dtype.startswith("fp8"): - # TRTLLM prefill attention does not support BF16 Q - # and fp8 kv cache. 
So to enable prefill attention
-                    # with fp8 kv cache, we can construct a mock block
-                    # and mock kv cache with BF16 KV involved in the prefill
-                    mock_kv_cache, mock_block_table = (
-                        trtllm_prefill_attn_kvfp8_dequant(
-                            kv_cache_permute,
-                            block_tables_prefill,
-                            layer._k_scale,
-                            layer._v_scale,
-                            attn_metadata.q_data_type,
-                        ))
-                else:
-                    mock_kv_cache = kv_cache_permute
-                    mock_block_table = block_tables_prefill
-
-                trtllm_batch_context_with_kv_cache(
-                    query=prefill_query,
-                    kv_cache=mock_kv_cache,
-                    workspace_buffer=workspace_buffer,
-                    block_tables=mock_block_table,
-                    seq_lens=seq_lens_prefill,
-                    max_q_len=attn_metadata.max_q_len_prefill,
-                    max_kv_len=attn_metadata.max_seq_len,
-                    bmm1_scale=self.bmm1_scale,
-                    bmm2_scale=self.bmm2_scale,
-                    batch_size=attn_metadata.num_prefills,
-                    cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
-                    cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
-                    window_left=self.window_left,
-                    sinks=self.sinks,
-                    o_sf_scale=self.o_sf_scale,
-                    out=out,
+                raise NotImplementedError(
+                    "prefill_use_trtllm prefill attention is not implemented yet."
                 )
 
         if num_decode_tokens > 0:
@@ -979,8 +1002,7 @@ def forward(
 
             if not attn_metadata.decode_use_trtllm:
                 assert decode_wrapper._window_left == self.window_left
-                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
-                                                           or 0.0)
+                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
                 assert decode_wrapper._sm_scale == self.scale
                 decode_wrapper.run(
                     decode_query,
@@ -990,53 +1012,9 @@ def forward(
                     out=output[:num_decode_tokens],
                 )
             else:
-                # decode_query may be non-contiguous
-                decode_query = decode_query.contiguous()
-                workspace_buffer = _get_trtllm_gen_workspace_buffer()
-                block_tables_decode = attn_metadata.\
-                    block_table_tensor[:num_decode_tokens]
-                seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
-
-                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
-                assert get_kv_cache_layout() == "HND"
-                assert decode_query.is_contiguous()
-                assert kv_cache_permute.is_contiguous()
-                assert workspace_buffer.is_contiguous()
-                assert block_tables_decode.is_contiguous()
-                assert seq_lens_decode.is_contiguous()
-
-                if output.dtype == FP4_DTYPE:
-                    assert self.o_sf_scale is not None
-                    out = FP4Tensor(data=output[:num_decode_tokens],
-                                    scale=output_block_scale,
-                                    scale_start_index=0,
-                                    original_shape=decode_query.shape)
-                else:
-                    assert self.o_sf_scale is None
-                    out = output[:num_decode_tokens]
-
-                if num_decode_tokens % attn_metadata.num_decodes != 0:
-                    # This gets triggered when the dummy_run forces
-                    # attention to be initialized with q_len = 0
-                    q_len_per_req = 1
-                else:
-                    q_len_per_req = \
-                        num_decode_tokens // attn_metadata.num_decodes
-
-                trtllm_batch_decode_with_kv_cache(
-                    query=decode_query,
-                    kv_cache=kv_cache_permute,
-                    workspace_buffer=workspace_buffer,
-                    block_tables=block_tables_decode,
-                    seq_lens=seq_lens_decode,
-                    max_seq_len=attn_metadata.max_seq_len,
-                    bmm1_scale=self.bmm1_scale,
-                    bmm2_scale=self.bmm2_scale,
-                    window_left=self.window_left,
-                    sinks=self.sinks,
-                    o_sf_scale=self.o_sf_scale,
-                    out=out,
-                    q_len_per_req=q_len_per_req)
+                raise NotImplementedError(
+                    "decode_use_trtllm decode attention is not implemented yet."
+ ) return output_padded @@ -1052,13 +1030,13 @@ def fast_plan_decode( page_size: int, pos_encoding_mode: str = "NONE", window_left: int = -1, - logits_soft_cap: Optional[float] = None, - q_data_type: Optional[Union[str, torch.dtype]] = "float16", - kv_data_type: Optional[Union[str, torch.dtype]] = None, - data_type: Optional[Union[str, torch.dtype]] = None, - sm_scale: Optional[float] = None, - rope_scale: Optional[float] = None, - rope_theta: Optional[float] = None, + logits_soft_cap: float | None = None, + q_data_type: str | torch.dtype | None = "float16", + kv_data_type: str | torch.dtype | None = None, + data_type: str | torch.dtype | None = None, + sm_scale: float | None = None, + rope_scale: float | None = None, + rope_theta: float | None = None, non_blocking: bool = True, ) -> None: """ @@ -1077,8 +1055,7 @@ def fast_plan_decode( # Warm up with the original plan if it is first call, and always run the # original plan if we run for dynamic shape. For fixed shape (cudagraph), # this warm up is to generate the _cached_module for the decode wrapper. - if not self.is_cuda_graph_enabled or \ - getattr(self, "vllm_first_call", True): + if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True): self.plan( indptr_cpu, indices, @@ -1118,26 +1095,28 @@ def fast_plan_decode( if kv_data_type is None: kv_data_type = q_data_type - q_data_type = getattr(torch, q_data_type) if isinstance( - q_data_type, str) else q_data_type - kv_data_type = getattr(torch, kv_data_type) if isinstance( - kv_data_type, str) else kv_data_type + q_data_type = ( + getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type + ) + kv_data_type = ( + getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type + ) if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime " "batch size {} mismatches the batch size set during " - "initialization {}".format(batch_size, self._fixed_batch_size)) + "initialization {}".format(batch_size, self._fixed_batch_size) + ) if len(indices) > len(self._paged_kv_indices_buf): raise ValueError( - "The size of indices should be less than or equal to the " - "allocated buffer") + "The size of indices should be less than or equal to the allocated buffer" + ) # host-to-device copy for the indptr buffer self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True) # host-to-device copy for the last_page_len buffer - self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, - non_blocking=True) + self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True) qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") @@ -1190,6 +1169,8 @@ def _copy_page_indices_kernel( offset = tl.arange(0, BLOCK_SIZE) for i in tl.range(0, num_blocks, BLOCK_SIZE): block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks) - tl.store(page_indices + start_idx + i + offset, - block_ids, - mask=i + offset < num_blocks) + tl.store( + page_indices + start_idx + i + offset, + block_ids, + mask=i + offset < num_blocks, + ) diff --git a/vllm_metax/v1/attention/backends/flex_attention.py b/vllm_metax/v1/attention/backends/flex_attention.py index 7fe55a9b0..c3702e929 100644 --- a/vllm_metax/v1/attention/backends/flex_attention.py +++ b/vllm_metax/v1/attention/backends/flex_attention.py @@ -3,35 +3,44 @@ """Attention layer with FlexAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union import torch import torch._dynamo.decorators import 
torch.nn.functional as F -from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, - _score_mod_signature, and_masks, - create_block_mask, - flex_attention) - -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) +from torch.nn.attention.flex_attention import ( + BlockMask, + _mask_mod_signature, + _score_mod_signature, + and_masks, + create_block_mask, + flex_attention, +) + +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + is_quantized_kv_cache, +) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import cdiv, is_torch_equal_or_newer -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import is_torch_equal_or_newer +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - -create_block_mask_compiled = torch.compile(create_block_mask, - fullgraph=True, - mode="reduce-overhead") +create_block_mask_compiled = torch.compile( + create_block_mask, fullgraph=True, mode="reduce-overhead" +) flex_attention_compiled = torch.compile(flex_attention, fullgraph=True) @@ -39,7 +48,8 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: device = offsets.device counts = offsets[1:] - offsets[:-1] return torch.repeat_interleave( - torch.arange(len(counts), device=device, dtype=torch.int32), counts) + torch.arange(len(counts), device=device, dtype=torch.int32), counts + ) def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): @@ -59,7 +69,7 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): return F.pad(x, pad_list, mode="constant", value=0) -class FlexAttentionBackend(AttentionBackend): +class MacaFlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @classmethod @@ -101,10 +111,13 @@ def use_cascade_attention(*args, **kwargs) -> bool: return False -#@torch.compile(fullgraph=True, mode="reduce-overhead") -def physical_to_logical_mapping(block_table: torch.Tensor, - seq_lens: torch.Tensor, block_size: int, - total_blocks: int) -> torch.Tensor: +# @torch.compile(fullgraph=True, mode="reduce-overhead") +def physical_to_logical_mapping( + block_table: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + total_blocks: int, +) -> torch.Tensor: """ Creates an inverse mapping from physical block locations to logical indices. 
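
Aside: a minimal sketch of the inverse block mapping that `physical_to_logical_mapping` builds, using made-up toy values (the block table, sequence length, and sizes below are assumptions for illustration; this snippet is not part of the diff):

import torch

# Toy setup: one request, block_size=4, 8 physical blocks in the cache.
# Logical blocks 0..2 of the request live in physical blocks 3, 5 and 2.
block_table = torch.tensor([[3, 5, 2, 0]])
seq_lens = torch.tensor([10])  # 10 tokens -> 3 valid logical blocks
block_size, total_blocks = 4, 8

num_blocks_per_seq = (seq_lens + block_size - 1) // block_size  # cdiv
physical_to_logical = torch.full((1, total_blocks), -1, dtype=torch.long)

logical_idx = torch.arange(block_table.shape[1])
mask = logical_idx[None, :] < num_blocks_per_seq[:, None]
valid_block_table = torch.where(mask, block_table, 0)
valid_logical_indices = torch.where(mask, logical_idx[None, :], 0)

# Scatter logical indices into their physical slots; block 0 stays unmapped.
physical_to_logical.scatter_(-1, valid_block_table.to(torch.int64), valid_logical_indices)
physical_to_logical[:, 0] = -1

print(physical_to_logical)  # tensor([[-1, -1,  2,  0, -1,  1, -1, -1]])

The lookup direction is then physical block id -> logical block index, which is what the mask mods above need when they are handed physical kv indices.
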
@@ -174,35 +187,37 @@ def physical_to_logical_mapping(block_table: torch.Tensor, max_reqs, max_num_blocks = block_table.shape device = block_table.device - physical_to_logical = torch.full((max_reqs, total_blocks), - -1, - dtype=torch.long, - device=device) + physical_to_logical = torch.full( + (max_reqs, total_blocks), -1, dtype=torch.long, device=device + ) # Only process valid blocks to avoid garbage values num_blocks_per_seq = cdiv(seq_lens, block_size) - mask = torch.arange(max_num_blocks, - device=device)[None, :] < num_blocks_per_seq[:, None] + mask = ( + torch.arange(max_num_blocks, device=device)[None, :] + < num_blocks_per_seq[:, None] + ) valid_block_table = torch.where(mask, block_table, 0) valid_logical_indices = torch.where( - mask, - torch.arange(max_num_blocks, device=device)[None, :], 0) + mask, torch.arange(max_num_blocks, device=device)[None, :], 0 + ) - physical_to_logical.scatter_(-1, valid_block_table.to(torch.int64), - valid_logical_indices) + physical_to_logical.scatter_( + -1, valid_block_table.to(torch.int64), valid_logical_indices + ) # NB - Seems like block 0 is always empty so we reset it manually physical_to_logical[:, 0] = -1 return physical_to_logical def unique_static_unsorted( - x: torch.Tensor, - *, - M: int, # maximum positive value (0 is “skip me”) - dim: int = -1, # axis along which to deduplicate - ignored_val: int = 0, # value to ignore - pad_val: int = -1, # sentinel for unused slots + x: torch.Tensor, + *, + M: int, # maximum positive value (0 is “skip me”) + dim: int = -1, # axis along which to deduplicate + ignored_val: int = 0, # value to ignore + pad_val: int = -1, # sentinel for unused slots ) -> torch.Tensor: """ - Keeps the first occurrence of each non-zero value while preserving order, @@ -234,8 +249,7 @@ def unique_static_unsorted( first_idx.scatter_reduce_(1, x_flat, idx, reduce="amin") # ── keep mask: first occurrence *and* value ≠ 0 ───────────────────── - keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat) - ) # [B, N] + keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat)) # [B, N] # ── left-pack uniques into a fresh tensor ─────────────────────────── dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go @@ -249,8 +263,9 @@ def unique_static_unsorted( return packed -def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, - kv_idx: torch.Tensor): +def causal_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor +): return q_idx >= kv_idx @@ -267,9 +282,9 @@ class FlexAttentionMetadata: use_cascade: bool common_prefix_len: int - cu_prefix_query_lens: Optional[torch.Tensor] - prefix_kv_lens: Optional[torch.Tensor] - suffix_kv_lens: Optional[torch.Tensor] + cu_prefix_query_lens: torch.Tensor | None + prefix_kv_lens: torch.Tensor | None + suffix_kv_lens: torch.Tensor | None # Block info total_cache_tokens: int @@ -285,15 +300,15 @@ class FlexAttentionMetadata: # Flex Metadata num_blocks = 0 - block_mask: Optional[BlockMask] = None - score_mod: Optional[_score_mod_signature] = None + block_mask: BlockMask | None = None + score_mod: _score_mod_signature | None = None logical_mask_mod: _mask_mod_signature = causal_mask_mod - doc_ids: Optional[torch.Tensor] = None + doc_ids: torch.Tensor | None = None direct_build: bool = True q_block_size: int = 16 kv_block_size: int = 16 - transformed_score_mod: Optional[_score_mod_signature] = None - sliding_window: Optional[int] = None + transformed_score_mod: _score_mod_signature | None = None + 
sliding_window: int | None = None def _convert_physical_to_logical( self, @@ -315,8 +330,7 @@ def _convert_physical_to_logical( physical_kv_block = physical_kv_idx // self.block_size physical_kv_offset = physical_kv_idx % self.block_size logical_block_idx = self.physical_to_logical[q_req, physical_kv_block] - logical_kv_idx = (logical_block_idx * self.block_size + - physical_kv_offset) + logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset # Determine valid kv indices live_block = logical_block_idx >= 0 @@ -350,9 +364,9 @@ def final_mask_mod( q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - (is_valid, logical_q_idx, - logical_kv_idx) = self._convert_physical_to_logical( - self.doc_ids, q_idx, physical_kv_idx) + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) + ) # Apply mask modification only for valid indices return torch.where( is_valid, @@ -390,11 +404,11 @@ def get_sliding_window_mask_mod(self) -> _mask_mod_signature: """ if self.sliding_window is None: - raise ValueError( - "sliding_window must be set for sliding window attention") + raise ValueError("sliding_window must be set for sliding window attention") - def sliding_window_mask_mod(b: torch.Tensor, h: torch.Tensor, - q_idx: torch.Tensor, kv_idx: torch.Tensor): + def sliding_window_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ): return torch.abs(q_idx - kv_idx) < self.sliding_window def final_mask_mod( @@ -403,9 +417,9 @@ def final_mask_mod( q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - (is_valid, logical_q_idx, - logical_kv_idx) = self._convert_physical_to_logical( - self.doc_ids, q_idx, physical_kv_idx) + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) + ) return torch.where( is_valid, sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx), @@ -429,7 +443,7 @@ def get_mask_mod(self): mask_mod = and_masks(mask_mod, sliding_window_mask_mod) return mask_mod - def get_transformed_score_mod(self) -> Optional[_score_mod_signature]: + def get_transformed_score_mod(self) -> _score_mod_signature | None: """Creates the transformed score_mod function for FlexAttention. This function wraps the user's score_mod to handle physical-to-logical @@ -449,18 +463,19 @@ def transformed_score_mod( q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - (is_valid, logical_q_idx, - logical_kv_idx) = self._convert_physical_to_logical( - request_lookup, q_idx, physical_kv_idx) + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical( + request_lookup, q_idx, physical_kv_idx + ) + ) return torch.where( is_valid, - user_score_mod(score, - b, - h, - logical_q_idx, - logical_kv_idx, - physical_q=q_idx), -float('inf')) + user_score_mod( + score, b, h, logical_q_idx, logical_kv_idx, physical_q=q_idx + ), + -float("inf"), + ) return transformed_score_mod @@ -491,18 +506,22 @@ def _build_block_mask_direct(self) -> BlockMask: f"FlexAttention currently requires the cache block size " f"({self.block_size}) to be equal to the kv_block_size " f"({self.kv_block_size}). Please check your model's " - f"configuration.") + f"configuration." 
+ ) used_pages = self.block_table[ - self.doc_ids, :cdiv(self.max_seq_len, self.block_size)] - used_pages_padded = pad_to_multiple(used_pages, - multiple=self.q_block_size, - dim=0) + self.doc_ids, : cdiv(self.max_seq_len, self.block_size) + ] + used_pages_padded = pad_to_multiple( + used_pages, multiple=self.q_block_size, dim=0 + ) used_pages_padded = used_pages_padded.reshape( - used_pages_padded.shape[0] // self.q_block_size, -1) + used_pages_padded.shape[0] // self.q_block_size, -1 + ) used_pages_padded = used_pages_padded // page_to_block_ratio - kv_indices = unique_static_unsorted((used_pages_padded.long()), - M=self.num_blocks).to(torch.int32) + kv_indices = unique_static_unsorted( + (used_pages_padded.long()), M=self.num_blocks + ).to(torch.int32) kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32) block_mask_kwargs = { @@ -522,8 +541,7 @@ def _build_block_mask_direct(self) -> BlockMask: def build_block_mask(self) -> BlockMask: mask_mod = self.get_mask_mod() - kv_len = (self.total_cache_tokens - if self.causal else self.num_actual_tokens) + kv_len = self.total_cache_tokens if self.causal else self.num_actual_tokens return create_block_mask_compiled( mask_mod, None, @@ -553,11 +571,14 @@ def __post_init__(self): self.block_mask = self.build_block_mask() -class FlexAttentionMetadataBuilder( - AttentionMetadataBuilder[FlexAttentionMetadata]): - - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): +class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.model_config = vllm_config.model_config @@ -565,26 +586,22 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.cache_config = vllm_config.cache_config self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config) - self.num_heads_kv = self.model_config.get_num_kv_heads( - self.parallel_config) + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0") - self.q_block_size: int = 16 if is_torch_equal_or_newer( - "2.9.0.dev0") else 128 - self.kv_block_size: int = 16 if is_torch_equal_or_newer( - "2.9.0.dev0") else 128 + self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128 + self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128 - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - return False - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlexAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlexAttentionMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -607,15 +624,18 @@ def build(self, max_possible_seq_len = self.model_config.max_model_len num_gpu_blocks = self.cache_config.num_gpu_blocks - assert num_gpu_blocks is not None, \ + assert num_gpu_blocks is 
not None, ( "FlexAttention requires num_gpu_blocks to be set" - total_cache_tokens = (num_gpu_blocks * block_size) + ) + total_cache_tokens = num_gpu_blocks * block_size inverse_block_table = physical_to_logical_mapping( - block_table_tensor, seq_lens, block_size, num_gpu_blocks) + block_table_tensor, seq_lens, block_size, num_gpu_blocks + ) offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) out = FlexAttentionMetadata( causal=common_attn_metadata.causal, @@ -638,7 +658,10 @@ def build(self, total_cache_tokens=total_cache_tokens, decode_offset=offset_tensor, num_blocks_per_seq=num_blocks_per_seq, - direct_build=self.direct_build, + # FIXME(Isotr0py): direct build has issue to build bidirectional + # attention block mask for encoder-only models, disable it temporarily. + # see: https://github.com/vllm-project/vllm/pull/27329#issuecomment-3431484053 + direct_build=(self.direct_build and common_attn_metadata.causal), q_block_size=self.q_block_size, kv_block_size=self.kv_block_size, ) @@ -649,9 +672,9 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: class FlexAttentionImpl(AttentionImpl): - sliding_window: Optional[int] - alibi_slopes: Optional[torch.Tensor] - logits_soft_cap: Optional[float] + sliding_window: int | None + alibi_slopes: torch.Tensor | None + logits_soft_cap: float | None def __init__( self, @@ -659,12 +682,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, + kv_sharing_target_layer_name: str | None = None, **kwargs, ) -> None: self.num_heads = num_heads @@ -673,14 +696,15 @@ def __init__( self.num_kv_heads = num_kv_heads self.attn_type = attn_type - if attn_type not in (AttentionType.ENCODER_ONLY, - AttentionType.DECODER): + if attn_type not in (AttentionType.ENCODER_ONLY, AttentionType.DECODER): raise NotImplementedError( - f"FlexAttention does not support {attn_type} attention") + f"FlexAttention does not support {attn_type} attention" + ) if alibi_slopes is not None: raise NotImplementedError( - "FlexAttention does not support alibi slopes yet.") + "FlexAttention does not support alibi slopes yet." + ) else: self.alibi_slopes = None @@ -690,19 +714,20 @@ def __init__( self.logits_soft_cap = logits_soft_cap if self.logits_soft_cap is not None: raise NotImplementedError( - "FlexAttention does not support logits soft cap yet.") + "FlexAttention does not support logits soft cap yet." + ) assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if kv_sharing_target_layer_name is not None: - raise NotImplementedError( - "FlexAttention does not support kv sharing yet.") + raise NotImplementedError("FlexAttention does not support kv sharing yet.") - FlexAttentionBackend.validate_head_size(head_size) + MacaFlexAttentionBackend.validate_head_size(head_size) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( - "FlexAttention does not support quantized kv-cache. Yet") + "FlexAttention does not support quantized kv-cache. 
Yet" + ) @staticmethod def view_as_4d(tensor: torch.Tensor) -> torch.Tensor: @@ -720,9 +745,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlexAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FLexAttention. @@ -739,14 +764,14 @@ def forward( assert output is not None, "Output tensor must be provided." if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlexAttentionImpl") + "fused output quantization is not yet supported for FlexAttentionImpl" + ) enable_gqa = self.num_kv_heads != self.num_heads if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) # query = self.view_as_4d(query).permute(0, 2, 1, 3) # return torch.empty_like(query) @@ -759,11 +784,11 @@ def forward( # in direct block mask building code path. logger.warning_once( "Using direct block mask building with sliding window, " - "which is suboptimal now. Performance may be degraded.") + "which is suboptimal now. Performance may be degraded." + ) # update mask mod in attention metadata attn_metadata.mask_mod = attn_metadata.get_mask_mod() - attn_metadata.block_mask = ( - attn_metadata._build_block_mask_direct()) + attn_metadata.block_mask = attn_metadata._build_block_mask_direct() else: attn_metadata.block_mask = attn_metadata.build_block_mask() @@ -776,8 +801,9 @@ def forward( ) query = query[:, :, :num_actual_tokens, :] - if ((key_tensor.size(-2) > num_actual_tokens) - or (value_tensor.size(-2) > num_actual_tokens)): + if (key_tensor.size(-2) > num_actual_tokens) or ( + value_tensor.size(-2) > num_actual_tokens + ): # In the encoder-only model with torch.compile, # qkv might be padded, which might cause exception. 
# see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290 @@ -801,8 +827,7 @@ def forward( # View out the block_size dim key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size) - value_cache = value_cache.view(-1, self.num_kv_heads, - self.head_size) + value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size) query, key_tensor, value_tensor = map( lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), (query, key_cache, value_cache), @@ -816,8 +841,9 @@ def forward( assert attn_metadata.block_mask is not None block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE - kernel_options = get_kernel_options(query, block_m, block_n, - attn_metadata.direct_build) + kernel_options = get_kernel_options( + query, block_m, block_n, attn_metadata.direct_build + ) out = flex_attention_compiled( query, key_tensor, @@ -835,11 +861,17 @@ def forward( return output -def get_kernel_options(query, block_m, block_n, - use_direct_build: bool) -> dict[str, Union[int, bool]]: - kernel_options: dict[str, Union[int, bool]] = { +def get_kernel_options( + query, block_m, block_n, use_direct_build: bool +) -> dict[str, int | bool]: + kernel_options: dict[str, int | bool] = { "FORCE_USE_FLEX_ATTENTION": True, } + if vllm_is_batch_invariant(): + kernel_options["BLOCK_M"] = 16 + kernel_options["BLOCK_N"] = 16 + kernel_options["IS_DIVISIBLE"] = False + return kernel_options if use_direct_build: kernel_options["BLOCK_M"] = block_m kernel_options["BLOCK_N"] = block_n diff --git a/vllm_metax/v1/attention/backends/mla/common.py b/vllm_metax/v1/attention/backends/mla/common.py index 20add8af6..faaef48e7 100644 --- a/vllm_metax/v1/attention/backends/mla/common.py +++ b/vllm_metax/v1/attention/backends/mla/common.py @@ -190,82 +190,104 @@ import functools from abc import abstractmethod from dataclasses import dataclass, field -from typing import Generic, Optional, TypeVar, Union +from enum import Enum +from typing import ClassVar, Generic, TypeVar import torch from tqdm import tqdm import vllm.envs as envs -from flash_attn import flash_attn_varlen_func from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, - AttentionMetadata, - MLAAttentionImpl) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionLayer, + AttentionMetadata, + MLAAttentionImpl, +) from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm_metax.attention.utils.fa_utils import get_flash_attn_version from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearBase, + UnquantizedLinearMethod, +) from vllm.platforms import current_platform -from vllm.utils import cdiv, round_down +from vllm.utils.math_utils import cdiv, round_down from vllm.utils.flashinfer import has_nvidia_artifactory -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - get_per_layer_parameters, - infer_global_hyperparameters, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + 
AttentionMetadataBuilder, + CommonAttentionMetadata, + get_per_layer_parameters, + infer_global_hyperparameters, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.attention.backends.mla.common import logger -from vllm_metax.attention.utils.fa_utils import get_flash_attn_version -is_vllm_fa = False +class QueryLenSupport(Enum): + """Defines the level of query length support for an attention backend's + decode pipeline. + + - SINGLE_ONLY: Decode pipeline only supports single-token queries + (query_len=1) + - UNIFORM: Decode pipeline supports uniform multi-token queries + (all requests must have same query_len > 1) + - VARLEN: Decode pipeline supports variable-length queries + (mixed query lengths in same batch) + """ + + SINGLE_ONLY = "single_only" + UNIFORM = "uniform" + VARLEN = "varlen" + + +try: + from vllm.vllm_flash_attn import flash_attn_varlen_func + + is_vllm_fa = True +except ImportError: + # For rocm use upstream flash attention + if current_platform.is_out_of_tree(): + from flash_attn import flash_attn_varlen_func + is_vllm_fa = False try: from flashinfer import BatchPrefillWithRaggedKVCacheWrapper # /------------------------ Metax Modification -------------------------\ - # from flashinfer.prefill import ( # noqa: F401 - # cudnn_batch_prefill_with_kv_cache) + cudnn_batch_prefill_with_kv_cache = None # type: ignore # \------------------------- Metax Modification -------------------------/ - logger.info("cudnn_batch_prefill_with_kv_cache is not supported in Metax.") flashinfer_available = True except ImportError: flashinfer_available = False +# workaround for type checking def is_rocm_aiter_fp8bmm_enabled() -> bool: - return current_platform.is_rocm() \ - and envs.VLLM_ROCM_USE_AITER_FP8BMM \ - and envs.VLLM_ROCM_USE_AITER + return False -if is_rocm_aiter_fp8bmm_enabled(): - from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 # isort: skip - batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant - as aiter_triton_fp8_bmm) +aiter_triton_fp8_bmm = None # type: ignore - def dynamic_per_batched_tensor_quant( - x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn): - DTYPE_MAX = torch.finfo(dtype).max - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) - scale = DTYPE_MAX / amax - x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX) - return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() +def dynamic_per_batched_tensor_quant(x: torch.Tensor, dtype: torch.dtype): + pass -logger = init_logger(__name__) + +from vllm.v1.attention.backends.mla.common import logger CUDNN_WORKSPACE_SIZE = 12800 class MLACommonBackend(AttentionBackend): - accept_output_buffer: bool = True @staticmethod @@ -307,12 +329,13 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." 
+ ) @dataclass class MLACommonPrefillMetadata: - """ Prefill Specific Metadata """ + """Prefill Specific Metadata""" @dataclass class ChunkedContextMetadata: @@ -326,40 +349,40 @@ class ChunkedContextMetadata: workspace: torch.Tensor # for mla DCP - cp_chunk_seq_lens: Optional[list[list[int]]] = None - origin_context_lens: Optional[list[int]] = None - cp_cu_seq_lens: Optional[torch.Tensor] = None - chunk_size: Optional[int] = None - cu_seq_lens_lst: Optional[list[list[int]]] = None + cp_chunk_seq_lens: list[list[int]] | None = None + origin_context_lens: list[int] | None = None + cp_cu_seq_lens: torch.Tensor | None = None + chunk_size: int | None = None + cu_seq_lens_lst: list[list[int]] | None = None block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int - chunked_context: Optional[ChunkedContextMetadata] = None + chunked_context: ChunkedContextMetadata | None = None + query_seq_lens: torch.Tensor | None = None @dataclass class FlashInferPrefillMetadata(MLACommonPrefillMetadata): - prefill_main: Optional['BatchPrefillWithRaggedKVCacheWrapper'] = None - prefill_chunks: list['BatchPrefillWithRaggedKVCacheWrapper'] = field( - default_factory=list) + prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None + prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = field( + default_factory=list + ) @dataclass class CudnnPrefillMetadata(MLACommonPrefillMetadata): - - class ChunkedContextMetadata( - MLACommonPrefillMetadata.ChunkedContextMetadata): + class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata): seq_lens: torch.Tensor - query_seq_lens: Optional[torch.Tensor] = None - cudnn_workspace: Optional[torch.Tensor] = None + cudnn_workspace: torch.Tensor | None = None @dataclass class MLACommonDecodeMetadata: block_table: torch.Tensor seq_lens: torch.Tensor + dcp_tot_seq_lens: torch.Tensor | None D = TypeVar("D", bound=MLACommonDecodeMetadata) @@ -372,6 +395,7 @@ class MLACommonMetadata(Generic[D]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -395,12 +419,15 @@ class MLACommonMetadata(Generic[D]): num_prefills: int # The dimension of the attention heads - head_dim: Optional[int] = None + head_dim: int | None = None - decode: Optional[D] = None - prefill: Optional[Union[MLACommonPrefillMetadata, - FlashInferPrefillMetadata, - CudnnPrefillMetadata]] = None + decode: D | None = None + prefill: ( + MLACommonPrefillMetadata + | FlashInferPrefillMetadata + | CudnnPrefillMetadata + | None + ) = None def __post_init__(self): if self.head_dim is not None: @@ -414,18 +441,38 @@ def __post_init__(self): def use_flashinfer_prefill() -> bool: # For blackwell default to flashinfer prefill if it's available since # it is faster than FA2. 
- return (not envs.VLLM_DISABLE_FLASHINFER_PREFILL and flashinfer_available - and not envs.VLLM_USE_CUDNN_PREFILL - and current_platform.is_device_capability(100)) + return ( + not envs.VLLM_DISABLE_FLASHINFER_PREFILL + and flashinfer_available + and not envs.VLLM_USE_CUDNN_PREFILL + and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL + and current_platform.is_device_capability(100) + ) def use_cudnn_prefill() -> bool: - logger.info("cudnn prefill is not supported in Metax.") + logger.info("cudnn prefill is not supported on Maca.") return False # /------------------------ Metax Modification -------------------------\ - # return (flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL - # and current_platform.is_device_capability(100) - # and has_nvidia_artifactory()) + return ( + flashinfer_available + and envs.VLLM_USE_CUDNN_PREFILL + and current_platform.is_device_capability(100) + and has_nvidia_artifactory() + ) + # \------------------------ Metax Modification -------------------------/ + + +def use_trtllm_ragged_deepseek_prefill() -> bool: + """Check if TRT-LLM ragged DeepSeek prefill should be used.""" + logger.info("TRT-LLM ragged DeepSeek prefill is not supported on Maca.") + return False + # /------------------------ Metax Modification -------------------------\ + return ( + flashinfer_available + and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL + and current_platform.is_device_capability(100) + ) # \------------------------ Metax Modification -------------------------/ @@ -440,19 +487,34 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + + # Defines the level of query length support for this backend. + # - SINGLE_ONLY: Only single-token queries (no spec decode support) + # - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths) + # - VARLEN: Supports variable-length queries (spec decode with mixed lengths) + # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when + # speculative decoding is enabled. + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY + + # The threshold for reordering the batch into decode and prefill requests. + # If > 1, the batch will be reordered such that requests with + # query length <= threshold are classified as decode requests. + # Use `query_len_support` (above) to set this automatically + # when speculative decoding is enabled. 
reorder_batch_threshold: int = 1 @staticmethod - def determine_chunked_prefill_workspace_size( - vllm_config: VllmConfig) -> int: + def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int: scheduler_config = vllm_config.scheduler_config cache_config = vllm_config.cache_config model_config = vllm_config.model_config chunked_prefill_workspace_size = min( # Try for 8 full length request or at least 4 pages per-request - max(8 * model_config.max_model_len, - 4 * scheduler_config.max_num_seqs * cache_config.block_size), + max( + 8 * model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * cache_config.block_size, + ), # For long-context models try not to over-allocate limiting # kv-cache space, limiting it to 64k tokens, # which would result in the workspace being: @@ -461,32 +523,37 @@ def determine_chunked_prefill_workspace_size( # which would result in up-projected context being # 2*(192*128)*(64*1024) = 3gb # (assuming 192 QK head dim, 128 heads, and fp16) - 64 * 1024) + 64 * 1024, + ) # Enforce that we enough for at least 1 page per request chunked_prefill_workspace_size = max( chunked_prefill_workspace_size, - scheduler_config.max_num_seqs * cache_config.block_size) + scheduler_config.max_num_seqs * cache_config.block_size, + ) return chunked_prefill_workspace_size - def __init__(self, - kv_cache_spec: AttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - metadata_cls: Optional[type[M]] = None): - self.metadata_cls = metadata_cls \ - if metadata_cls is not None else MLACommonMetadata + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: type[M] | None = None, + ): + self.metadata_cls = ( + metadata_cls if metadata_cls is not None else MLACommonMetadata + ) self.kv_cache_spec = kv_cache_spec scheduler_config = vllm_config.scheduler_config self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.compilation_config = vllm_config.compilation_config + self.vllm_config = vllm_config self.device = device - self.num_heads = self.model_config.get_num_attention_heads( - parallel_config) + self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) # /------------------------ Metax Modification -------------------------\ self.aot_schedule = False @@ -504,52 +571,62 @@ def __init__(self, if self.aot_schedule: self.page_size = self.kv_cache_spec.block_size - self.chunked_prefill_workspace_size = \ + self.chunked_prefill_workspace_size = ( self.determine_chunked_prefill_workspace_size(vllm_config) + ) if self.dcp_world_size > 1: # Note(hc): The local kvcache is incomplete when DCP is triggered, # an additional kvcache allgather across the DCP group is therefore # required, so the workspace has to be enlarged by 1/DCP relative # to the original TP allocation. 
- assert self.chunked_prefill_workspace_size % \ - self.dcp_world_size == 0 + assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0 self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size + - self.chunked_prefill_workspace_size // self.dcp_world_size, - self.model_config.get_head_size()), + ( + self.chunked_prefill_workspace_size + + self.chunked_prefill_workspace_size // self.dcp_world_size, + self.model_config.get_head_size(), + ), dtype=self.model_config.dtype, device=device, ) else: self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), + ( + self.chunked_prefill_workspace_size, + self.model_config.get_head_size(), + ), dtype=self.model_config.dtype, device=device, ) self._use_cudnn_prefill = use_cudnn_prefill() self._use_fi_prefill = use_flashinfer_prefill() + self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill() self.prefill_metadata_cls = ( FlashInferPrefillMetadata - if self._use_fi_prefill else CudnnPrefillMetadata - if self._use_cudnn_prefill else MLACommonPrefillMetadata) + if self._use_fi_prefill + else CudnnPrefillMetadata + if self._use_cudnn_prefill + else MLACommonPrefillMetadata + ) if self._use_fi_prefill: self._workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=device) + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device + ) - self._fi_prefill_main: Optional[ - BatchPrefillWithRaggedKVCacheWrapper] = None - self._fi_prefill_chunks: list[ - BatchPrefillWithRaggedKVCacheWrapper] = [] + self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None + self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = [] self._global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(vllm_config, layer_names, - MLACommonImpl)) + get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) + ) + + if self._use_trtllm_ragged_prefill: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device + ) if self._use_cudnn_prefill: self.cudnn_workspace = torch.empty( @@ -558,6 +635,18 @@ def __init__(self, device=device, ) + supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY + self._init_reorder_batch_threshold( + self.reorder_batch_threshold, supports_spec_decode + ) + + # Validate consistency between query_len_support and reorder_batch_threshold + if self.query_len_support == QueryLenSupport.SINGLE_ONLY: + assert self.reorder_batch_threshold == 1, ( + f"reorder_batch_threshold must be 1 when query_len_support is " + f"SINGLE_ONLY, got {self.reorder_batch_threshold}" + ) + def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): qo_indptr = prefill.query_start_loc @@ -568,7 +657,8 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): if self._fi_prefill_main is None: self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper( - self._workspace_buffer, "NHD", backend="cutlass") + self._workspace_buffer, "NHD", backend="cutlass" + ) if has_context: num_chunks = chunked_context.cu_seq_lens.shape[0] @@ -577,7 +667,9 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): for _ in range(len(self._fi_prefill_chunks), num_chunks): self._fi_prefill_chunks.append( BatchPrefillWithRaggedKVCacheWrapper( - self._workspace_buffer, "NHD", backend="cutlass")) + self._workspace_buffer, "NHD", backend="cutlass" + ) + ) assert 
num_chunks <= len(self._fi_prefill_chunks) # In MLA, the non-latent num_qo_heads == num_kv_heads @@ -588,8 +680,7 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): assert self.kv_cache_spec.num_kv_heads == 1 # Get non-latent head_dim_qk and head_dim_vo - head_dim_qk = (self.mla_dims.qk_nope_head_dim + - self.mla_dims.qk_rope_head_dim) + head_dim_qk = self.mla_dims.qk_nope_head_dim + self.mla_dims.qk_rope_head_dim head_dim_vo = self.mla_dims.v_head_dim # For main run, qo_indptr == kv_indptr @@ -625,45 +716,52 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): causal=False, # This is context run sm_scale=self._global_hyperparameters.sm_scale, window_left=self._global_hyperparameters.window_left, - logits_soft_cap=self._global_hyperparameters. - logits_soft_cap, + logits_soft_cap=self._global_hyperparameters.logits_soft_cap, q_data_type=self.model_config.dtype, ) prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> MLACommonDecodeMetadata: + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + dcp_tot_seq_lens_device: torch.Tensor | None, + ) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: """ This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with MLA. """ m = common_attn_metadata - assert m.num_reqs <= (m.num_actual_tokens * - self.reorder_batch_threshold), \ - "MLA only supports decode-only full CUDAGraph capture. " \ + assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), ( + "MLA only supports decode-only full CUDAGraph capture. " "Make sure all cudagraph capture sizes <= max_num_seq." 
+ ) assert m.max_query_len <= self.reorder_batch_threshold # decode only return self.build(0, m) - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> M: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> M: num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -680,21 +778,28 @@ def build(self, query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu + dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu - - query_seq_lens_cpu) + num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=self.reorder_batch_threshold) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=(self.query_len_support != QueryLenSupport.VARLEN), + ) + ) # Note(hc): update seq_lens of decode reqs under DCP. if self.dcp_world_size > 1: - seq_lens[:num_decodes] = seq_lens[:num_decodes] \ - // self.dcp_world_size + (self.dcp_rank <= \ - (seq_lens[:num_decodes] - 1) % self.dcp_world_size) + assert dcp_local_seq_lens is not None + dcp_local_seq_lens[:num_decodes] = seq_lens[ + :num_decodes + ] // self.dcp_world_size + ( + self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size + ) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -705,13 +810,15 @@ def build(self, context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] # Note(hc): The context lengths in the perspective of dcp rank0. 
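The decode-path expression above, seq_lens // dcp_world_size + (dcp_rank <= (seq_lens - 1) % dcp_world_size), gives each decode-context-parallel rank its local share of a request's KV length, with the leading ranks absorbing the remainder tokens. A standalone sketch of that arithmetic; the function name and the driver loop are illustrative and not part of the patch:

import torch

def dcp_local_seq_lens(seq_lens: torch.Tensor, dcp_world_size: int, dcp_rank: int) -> torch.Tensor:
    # Base share per rank, plus one extra token on the leading ranks;
    # the comparison evaluates to a 0/1 tensor that is added elementwise.
    return seq_lens // dcp_world_size + (dcp_rank <= (seq_lens - 1) % dcp_world_size)

seq_lens = torch.tensor([10, 7, 13])
for rank in range(4):
    print(rank, dcp_local_seq_lens(seq_lens, 4, rank).tolist())
# rank 0 -> [3, 2, 4], rank 1 -> [3, 2, 3], rank 2 -> [2, 2, 3], rank 3 -> [2, 1, 3]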
- cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() / - self.dcp_world_size).int() + cp_context_lens_cpu = torch.ceil( + context_lens_cpu.float() / self.dcp_world_size + ).int() origin_context_lens = context_lens_cpu.tolist() max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() - prefill_query_start_loc = query_start_loc[ - reqs_start:] - query_start_loc[reqs_start] + prefill_query_start_loc = ( + query_start_loc[reqs_start:] - query_start_loc[reqs_start] + ) chunked_context_metadata = None if max_context_len_cpu > 0: @@ -723,16 +830,16 @@ def build(self, # prefill in the batch, we could probably use a more advanced # algorithm here and allocate more workspace to prefills with # longer context lengths - max_context_chunk = (self.chunked_prefill_workspace_size // - num_prefills_with_context_cpu) + max_context_chunk = ( + self.chunked_prefill_workspace_size // num_prefills_with_context_cpu + ) if self.aot_schedule: # align max_context_chunk to page_size by rounding down, # currently the `gather_and_maybe_dequant_cache` kernel # cannot handle `context_chunk_starts` that are not aligned # to page_size - max_context_chunk = round_down(max_context_chunk, - self.page_size) + max_context_chunk = round_down(max_context_chunk, self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(max_context_len_cpu, max_context_chunk) @@ -743,22 +850,23 @@ def build(self, # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] # Note(simon): this is done in CPU because of downstream's # of `to_list`. - chunk_starts = \ - torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, num_prefills) \ + chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) * max_context_chunk - chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), - chunk_starts + max_context_chunk) + ) + chunk_ends = torch.min( + context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk + ) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - cu_seq_lens_cpu = torch.zeros(num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(chunk_seq_lens, - dim=1, - out=cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) + cu_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32 + ) if self.dcp_world_size > 1: # Note(hc): The above max_context_chunk already enforces @@ -767,36 +875,37 @@ def build(self, # cp_gather_cache which not require `cp_chunk_starts` # aligned to page_size. 
assert max_context_chunk % self.dcp_world_size == 0 - cp_max_context_chunk = max_context_chunk // \ - self.dcp_world_size - cp_chunk_starts = \ - torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, num_prefills) \ + cp_max_context_chunk = max_context_chunk // self.dcp_world_size + cp_chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) * cp_max_context_chunk + ) cp_chunk_ends = torch.min( cp_context_lens_cpu.unsqueeze(0), - cp_chunk_starts + cp_max_context_chunk) - cp_chunk_seq_lens = (cp_chunk_ends - - cp_chunk_starts).clamp(min=0) - - cp_cu_seq_lens_cpu = torch.zeros(num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(cp_chunk_seq_lens, - dim=1, - out=cp_cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) - - chunked_context_metadata_cls = \ - CudnnPrefillMetadata.ChunkedContextMetadata \ - if self._use_cudnn_prefill else \ - MLACommonPrefillMetadata.ChunkedContextMetadata + cp_chunk_starts + cp_max_context_chunk, + ) + cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0) + + cp_cu_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + cp_chunk_seq_lens, + dim=1, + out=cp_cu_seq_lens_cpu[:, 1:], + dtype=torch.int32, + ) + + chunked_context_metadata_cls = ( + CudnnPrefillMetadata.ChunkedContextMetadata + if self._use_cudnn_prefill + else MLACommonPrefillMetadata.ChunkedContextMetadata + ) if self.dcp_world_size > 1: - chunked_context_metadata = \ - chunked_context_metadata_cls( - cu_seq_lens=cu_seq_lens_cpu \ - .to(device, non_blocking=True), + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=cp_chunk_starts.to(device, non_blocking=True), seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), @@ -804,16 +913,13 @@ def build(self, workspace=self.chunked_prefill_workspace, cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(), origin_context_lens=origin_context_lens, - cp_cu_seq_lens=cp_cu_seq_lens_cpu \ - .to(device, non_blocking=True), + cp_cu_seq_lens=cp_cu_seq_lens_cpu.to(device, non_blocking=True), chunk_size=max_context_chunk, cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), ) else: - chunked_context_metadata = \ - chunked_context_metadata_cls( - cu_seq_lens=cu_seq_lens_cpu \ - .to(device, non_blocking=True), + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=chunk_starts.to(device, non_blocking=True), seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), @@ -824,8 +930,10 @@ def build(self, if self._use_cudnn_prefill: chunked_context_metadata.seq_lens = chunk_seq_lens - assert max(chunked_context_metadata.max_seq_lens) <= \ - self.chunked_prefill_workspace_size + assert ( + max(chunked_context_metadata.max_seq_lens) + <= self.chunked_prefill_workspace_size + ) prefill_metadata = self.prefill_metadata_cls( block_table=block_table_tensor[reqs_start:, ...], @@ -836,19 +944,30 @@ def build(self, if self._use_cudnn_prefill: assert isinstance(prefill_metadata, CudnnPrefillMetadata) - prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \ - - prefill_query_start_loc[:-1] + prefill_metadata.query_seq_lens = ( + prefill_query_start_loc[1:] - prefill_query_start_loc[:-1] + ) prefill_metadata.cudnn_workspace = self.cudnn_workspace + if self._use_trtllm_ragged_prefill: + 
prefill_metadata.query_seq_lens = ( + prefill_query_start_loc[1:] - prefill_query_start_loc[:-1] + ) + decode_metadata = None if num_decodes > 0: decode_metadata = self._build_decode( block_table_tensor=block_table_tensor[:num_decodes, ...], seq_lens_cpu=seq_lens_cpu[:num_decodes], - seq_lens_device=seq_lens[:num_decodes], - query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1], - query_start_loc_device=query_start_loc[:num_decodes + 1], + seq_lens_device=dcp_local_seq_lens[:num_decodes] + if self.dcp_world_size > 1 and dcp_local_seq_lens is not None + else seq_lens[:num_decodes], + query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], + query_start_loc_device=query_start_loc[: num_decodes + 1], num_decode_tokens=num_decode_tokens, + dcp_tot_seq_lens_device=seq_lens[:num_decodes] + if self.dcp_world_size > 1 + else None, ) attn_metadata = self.metadata_cls( @@ -904,12 +1023,14 @@ def reorg_kvcache( k_pe_segments = [] src_token_idx = 0 max_seq_len_check = 0 - for cp_chunk_seq_len, origin_context_len in zip(cp_chunk_seq_lens_lst, - origin_context_lens): + for cp_chunk_seq_len, origin_context_len in zip( + cp_chunk_seq_lens_lst, origin_context_lens + ): chunk_context_len = chunk_size if cp_chunk_seq_len != 0: chunk_context_len = min( - chunk_context_len, origin_context_len - chunk_size * chunk_idx) + chunk_context_len, origin_context_len - chunk_size * chunk_idx + ) cp_target_rank = (chunk_context_len - 1) % cp_world_size cur_seq_len = 0 for rank in range(cp_world_size): @@ -918,14 +1039,16 @@ def reorg_kvcache( else: real_cp_chunk_seq_len = cp_chunk_seq_len if real_cp_chunk_seq_len: - kv_c_segment = allgatered_kv_c_normed[rank * toks + - src_token_idx:rank * - toks + src_token_idx + - real_cp_chunk_seq_len] - k_pe_segment = allgatered_k_pe[rank * toks + - src_token_idx:rank * toks + - src_token_idx + - real_cp_chunk_seq_len] + kv_c_segment = allgatered_kv_c_normed[ + rank * toks + src_token_idx : rank * toks + + src_token_idx + + real_cp_chunk_seq_len + ] + k_pe_segment = allgatered_k_pe[ + rank * toks + src_token_idx : rank * toks + + src_token_idx + + real_cp_chunk_seq_len + ] kv_c_segments.append(kv_c_segment) k_pe_segments.append(k_pe_segment) cur_seq_len += real_cp_chunk_seq_len @@ -953,14 +1076,14 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, # MLA Specific Arguments - q_lora_rank: Optional[int], + q_lora_rank: int | None, kv_lora_rank: int, qk_nope_head_dim: int, qk_rope_head_dim: int, @@ -968,7 +1091,7 @@ def __init__( v_head_dim: int, kv_b_proj: ColumnParallelLinear, indexer=None, - q_pad_num_heads: Optional[int] = None, + q_pad_num_heads: int | None = None, ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported for MLA") @@ -990,25 +1113,24 @@ def __init__( self.q_pad_num_heads = q_pad_num_heads def process_weights_after_loading(self, act_dtype: torch.dtype): - def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") for attr in WEIGHT_NAMES: if hasattr(layer, attr): return getattr(layer, attr) raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") + f"Layer '{layer}' has no recognized weight 
attribute: {WEIGHT_NAMES}." + ) def get_and_maybe_dequant_weights(layer: LinearBase): if not isinstance(layer.quant_method, UnquantizedLinearMethod): # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) + eye = torch.eye( + layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device, + ) + dequant_weights = layer.quant_method.apply(layer, eye, bias=None) del eye # standardize to (output, input) return dequant_weights.T @@ -1020,12 +1142,14 @@ def get_and_maybe_dequant_weights(layer: LinearBase): kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, @@ -1033,15 +1157,18 @@ def get_and_maybe_dequant_weights(layer: LinearBase): ) W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) if is_rocm_aiter_fp8bmm_enabled(): W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype()) + W_K, dtype=current_platform.fp8_dtype() + ) self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype()) + W_V, dtype=current_platform.fp8_dtype() + ) # The kernel operates on non-padded inputs. 
Hence, pre-compiling # triton kernel to avoid runtime compilation for unseen batch sizes @@ -1057,23 +1184,23 @@ def get_and_maybe_dequant_weights(layer: LinearBase): ) for m in pre_compilation_list: - x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device) - aiter_triton_fp8_bmm(x, - self.W_K, - self.W_K_scale, - group_size=128, - transpose_bm=True) - - x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device) - aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) + x = torch.empty( + (self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device, + ) + aiter_triton_fp8_bmm( # type: ignore + x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True + ) + + x = torch.empty( + (self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device, + ) + aiter_triton_fp8_bmm( # type: ignore + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) else: # Convert from (L, N, V) to (N, L, V) self.W_UV = W_UV.transpose(0, 1) @@ -1083,13 +1210,12 @@ def get_and_maybe_dequant_weights(layer: LinearBase): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + if is_rocm_aiter_fp8bmm_enabled(): # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) + x = aiter_triton_fp8_bmm( # type: ignore + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) # Convert from (B, N, V) to (B, N * V) x = x.reshape(-1, self.num_heads * self.v_head_dim) # Copy result @@ -1102,8 +1228,7 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" # Convert from (N, B, V) to (B, N * V) - out_new = out.transpose(0, 1).reshape( - -1, self.num_heads * self.v_head_dim) + out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) # Adjust output buffer shape back to the original (B, N * V) N, B, V = out.shape @@ -1125,10 +1250,16 @@ def __init__(self, *args, **kwargs) -> None: self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi self._pad_v = False + elif use_trtllm_ragged_deepseek_prefill(): + logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA") + self._run_prefill_context_chunk = ( + self._run_prefill_context_chunk_trtllm_ragged + ) + self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged + self._pad_v = False elif use_cudnn_prefill(): logger.debug_once("Using CUDNN prefill for MLA") - self._run_prefill_context_chunk = \ - self._run_prefill_context_chunk_cudnn + self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn self._pad_v = False else: # Use FlashAttention @@ -1143,9 +1274,9 @@ def __init__(self, *args, **kwargs) -> None: self.flash_attn_varlen_func = flash_attn_varlen_func self.vllm_flash_attn_version = get_flash_attn_version() if self.vllm_flash_attn_version is not None: - self.flash_attn_varlen_func = \ - functools.partial(flash_attn_varlen_func, - fa_version=self.vllm_flash_attn_version) + self.flash_attn_varlen_func = functools.partial( + flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version + ) 
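For the unquantized branch of process_weights_after_loading above, _v_up_proj reduces to a per-head batched matmul: latent attention outputs of shape (B, N, L) are multiplied by W_UV of shape (N, L, V) and flattened back to (B, N * V). A shape-only sketch with illustrative sizes; it omits the preallocated output buffer and the ROCm fp8 path:

import torch

B, N, L, V = 4, 16, 512, 128          # tokens, heads, kv_lora_rank, v_head_dim
x = torch.randn(B, N, L)              # attention output in the latent space
W_UV = torch.randn(N, L, V)           # per-head up-projection, stored as (N, L, V)

x_nbl = x.transpose(0, 1)             # (B, N, L) -> (N, B, L)
out = torch.bmm(x_nbl, W_UV)          # (N, B, L) x (N, L, V) -> (N, B, V)
out = out.transpose(0, 1).reshape(B, N * V)   # flatten per token for the output projection
assert out.shape == (B, N * V)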
# For MLA the v head dim is smaller than qk head dim so we pad out # v with 0s to match the qk head dim for attention backends that do @@ -1153,25 +1284,25 @@ def __init__(self, *args, **kwargs) -> None: # We don't need to pad V if we are on a hopper system with FA3 self._pad_v = self.vllm_flash_attn_version is None or not ( self.vllm_flash_attn_version == 3 - and current_platform.get_device_capability()[0] == 9) + and current_platform.get_device_capability()[0] == 9 + ) - self.dcp_world_size: Optional[int] = None + self.dcp_world_size: int | None = None - self.chunked_prefill_workspace_size = \ + self.chunked_prefill_workspace_size = ( MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( - get_current_vllm_config()) - - def _flash_attn_varlen_diff_headdims(self, - q, - k, - v, - return_softmax_lse=False, - softmax_scale=None, - **kwargs): + get_current_vllm_config() + ) + ) + + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): maybe_padded_v = v if self._pad_v: maybe_padded_v = torch.nn.functional.pad( - v, [0, q.shape[-1] - v.shape[-1]], value=0) + v, [0, q.shape[-1] - v.shape[-1]], value=0 + ) if is_vllm_fa: kwargs["return_softmax_lse"] = return_softmax_lse @@ -1179,6 +1310,8 @@ def _flash_attn_varlen_diff_headdims(self, # ROCm leverages the upstream flash_attn, which takes a parameter # called "return_attn_probs" instead of return_softmax_lse kwargs["return_attn_probs"] = return_softmax_lse + if vllm_is_batch_invariant(): + kwargs["num_splits"] = 1 attn_out = self.flash_attn_varlen_func( q=q, @@ -1199,8 +1332,9 @@ def _flash_attn_varlen_diff_headdims(self, return attn_out, lse return attn_out - def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q, - k, v, return_softmax_lse): + def _run_prefill_new_tokens_fa( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): return self._flash_attn_varlen_diff_headdims( q=q, k=k, @@ -1214,10 +1348,12 @@ def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q, return_softmax_lse=return_softmax_lse, ) - def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q, - k, v, return_softmax_lse): + def _run_prefill_new_tokens_fi( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): assert isinstance(prefill, FlashInferPrefillMetadata) assert prefill.prefill_main is not None + ret = prefill.prefill_main.run( q=q, k=k, @@ -1226,15 +1362,15 @@ def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q, ) if isinstance(ret, tuple): - # Convert from (q_len, num_heads) to (num_heads, q_len) return ret[0], ret[1].transpose(0, 1).contiguous() return ret - def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata, - q, k, v, return_softmax_lse): + def _run_prefill_new_tokens_cudnn( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.query_seq_lens is not None - output, lse = cudnn_batch_prefill_with_kv_cache( + output, lse = cudnn_batch_prefill_with_kv_cache( # type: ignore q=q, k_cache=k, v_cache=v, @@ -1245,16 +1381,18 @@ def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata, actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1), causal=True, - return_lse=True, # do not support False for now - is_cuda_graph_compatible= - True, #Indicates actual_seq_lens are on GPU or CPU. 
+ # Do not support False for now + return_lse=True, + # Indicates actual_seq_lens are on GPU or CPU. + is_cuda_graph_compatible=True, ) if return_softmax_lse: return output, lse return output - def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata, - chunk_idx: int, q, k, v): + def _run_prefill_context_chunk_fa( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): assert prefill.chunked_context is not None return self._flash_attn_varlen_diff_headdims( q=q, @@ -1269,26 +1407,29 @@ def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata, return_softmax_lse=True, ) - def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata, - chunk_idx: int, q, k, v): + def _run_prefill_context_chunk_fi( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): assert isinstance(prefill, FlashInferPrefillMetadata) + attn_out, lse = prefill.prefill_chunks[chunk_idx].run( q=q, k=k, v=v, return_lse=True, ) + # Convert from (q_len, num_heads) to (num_heads, q_len) return attn_out, lse.transpose(0, 1).contiguous() - def _run_prefill_context_chunk_cudnn(self, - prefill: MLACommonPrefillMetadata, - chunk_idx: int, q, k, v): + def _run_prefill_context_chunk_cudnn( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.chunked_context is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None assert prefill.query_seq_lens is not None - return cudnn_batch_prefill_with_kv_cache( + return cudnn_batch_prefill_with_kv_cache( # type: ignore q=q, k_cache=k, v_cache=v, @@ -1297,34 +1438,109 @@ def _run_prefill_context_chunk_cudnn(self, max_token_per_sequence=prefill.max_query_len, max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx], actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), - actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx]. - view(-1, 1, 1, 1), + actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view( + -1, 1, 1, 1 + ), causal=False, return_lse=True, - is_cuda_graph_compatible= - True, #Indicates actual_seq_lens are on GPU or CPU. + # Indicates actual_seq_lens are on GPU or CPU. 
+ is_cuda_graph_compatible=True, ) - def process_weights_after_loading(self, act_dtype: torch.dtype): + def _run_prefill_new_tokens_trtllm_ragged( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): + """TRT-LLM ragged attention for new tokens (causal).""" + from flashinfer.prefill import trtllm_ragged_attention_deepseek + + assert prefill.query_seq_lens is not None + + ret = trtllm_ragged_attention_deepseek( + query=q, + key=k, + value=v, + workspace_buffer=self._workspace_buffer, + seq_lens=prefill.query_seq_lens, + max_q_len=prefill.max_query_len, + max_kv_len=prefill.max_query_len, + bmm1_scale=self.scale, + bmm2_scale=1.0, + o_sf_scale=1.0, + batch_size=prefill.query_seq_lens.shape[0], + window_left=-1, + cum_seq_lens_q=prefill.query_start_loc, + cum_seq_lens_kv=prefill.query_start_loc, + enable_pdl=False, + is_causal=True, + return_lse=return_softmax_lse, + ) + + if isinstance(ret, tuple): + # Convert from (q_len, num_heads) to (num_heads, q_len) + return ret[0], ret[1].transpose(0, 1).contiguous() + return ret + + def _run_prefill_context_chunk_trtllm_ragged( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): + """TRT-LLM ragged attention for context chunks (non-causal).""" + from flashinfer.prefill import trtllm_ragged_attention_deepseek + + assert prefill.chunked_context is not None + assert prefill.chunked_context.seq_lens[chunk_idx] is not None + out = torch.zeros( + q.shape[0], + q.shape[1], + v.shape[2], + device=q.device, + dtype=q.dtype, + ) + self._workspace_buffer.fill_(0) + + attn_out, lse = trtllm_ragged_attention_deepseek( + query=q, + key=k, + value=v, + workspace_buffer=self._workspace_buffer, + seq_lens=prefill.chunked_context.seq_lens[chunk_idx], + max_q_len=prefill.max_query_len, + max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx], + bmm1_scale=self.scale, + bmm2_scale=1.0, + o_sf_scale=1.0, + batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0], + window_left=-1, + cum_seq_lens_q=prefill.query_start_loc, + cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx], + enable_pdl=False, + is_causal=False, + return_lse=True, + out=out, + ) + + # Convert from (q_len, num_heads) to (num_heads, q_len) + return attn_out, lse.transpose(0, 1).contiguous() + + def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") for attr in WEIGHT_NAMES: if hasattr(layer, attr): return getattr(layer, attr) raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") + f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}." 
+ ) def get_and_maybe_dequant_weights(layer: LinearBase): if not isinstance(layer.quant_method, UnquantizedLinearMethod): # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) + eye = torch.eye( + layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device, + ) + dequant_weights = layer.quant_method.apply(layer, eye, bias=None) del eye # standardize to (output, input) return dequant_weights.T @@ -1336,12 +1552,14 @@ def get_and_maybe_dequant_weights(layer: LinearBase): kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, @@ -1349,15 +1567,18 @@ def get_and_maybe_dequant_weights(layer: LinearBase): ) W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) if is_rocm_aiter_fp8bmm_enabled(): W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype()) + W_K, dtype=current_platform.fp8_dtype() + ) self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype()) + W_V, dtype=current_platform.fp8_dtype() + ) # The kernel operates on non-padded inputs. 
Hence, pre-compiling # triton kernel to avoid runtime compilation for unseen batch sizes @@ -1373,23 +1594,23 @@ def get_and_maybe_dequant_weights(layer: LinearBase): ) for m in pre_compilation_list: - x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device) - aiter_triton_fp8_bmm(x, - self.W_K, - self.W_K_scale, - group_size=128, - transpose_bm=True) - - x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device) - aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) + x = torch.empty( + (self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device, + ) + aiter_triton_fp8_bmm( # type: ignore + x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True + ) + + x = torch.empty( + (self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device, + ) + aiter_triton_fp8_bmm( # type: ignore + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) else: # Convert from (L, N, V) to (N, L, V) self.W_UV = W_UV.transpose(0, 1) @@ -1425,18 +1646,15 @@ def _compute_prefill_context( seq_starts=prefill_metadata.chunked_context.starts[i], ) - kv_c_normed = workspace[:toks]\ - [..., :self.kv_lora_rank] - k_pe = workspace[:toks]\ - [..., self.kv_lora_rank:].unsqueeze(1) + kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] + k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1501,44 +1719,45 @@ def _context_parallel_compute_prefill_context( # |------- N tokens --------|--------- N*dcp_size tokens ----------| # |<- use for loca_gather ->|<--------- use for allgather -------->| allgather_offset = workspace.shape[0] // (dcp_world_size + 1) - assert allgather_offset * (dcp_world_size + - 1) == workspace.shape[0] + assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0] assert toks <= allgather_offset local_gathered_kvcache = workspace[:toks] cur_allgather_workspace = workspace[ - allgather_offset:allgather_offset * (1 + dcp_world_size)] + allgather_offset : allgather_offset * (1 + dcp_world_size) + ] assert toks * dcp_world_size <= cur_allgather_workspace.shape[0] - cur_allgather_kvcache = cur_allgather_workspace[:toks * - dcp_world_size] - cur_allgather_kvcache.copy_(get_dcp_group().all_gather( - local_gathered_kvcache, dim=0)) - assert cur_allgather_kvcache.shape[ - -1] == self.kv_lora_rank + self.qk_rope_head_dim - allgatered_kv_c_normed, allgatered_k_pe = \ - cur_allgather_kvcache.unsqueeze( - 1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size] + cur_allgather_kvcache.copy_( + get_dcp_group().all_gather(local_gathered_kvcache, dim=0) + ) + assert ( + cur_allgather_kvcache.shape[-1] + == self.kv_lora_rank + self.qk_rope_head_dim + ) + allgatered_kv_c_normed, 
allgatered_k_pe = cur_allgather_kvcache.unsqueeze( + 1 + ).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed, k_pe = reorg_kvcache( allgatered_kv_c_normed, allgatered_k_pe, - cp_chunk_seq_lens_lst=prefill_metadata.chunked_context. - cp_chunk_seq_lens[i], - origin_context_lens=prefill_metadata.chunked_context. - origin_context_lens, + cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[ + i + ], + origin_context_lens=prefill_metadata.chunked_context.origin_context_lens, cp_world_size=dcp_world_size, - sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i] - [-1], + sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], chunk_size=prefill_metadata.chunked_context.chunk_size, chunk_idx=i, - toks=toks) + toks=toks, + ) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1580,12 +1799,11 @@ def _forward_prefill( assert attn_metadata.prefill is not None assert self.dcp_world_size is not None - # TODO: need to check if this is supported on maca has_context = attn_metadata.prefill.chunked_context is not None - kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) @@ -1600,14 +1818,19 @@ def _forward_prefill( if has_context: suffix_output, suffix_lse = output if self.dcp_world_size > 1: - context_output, context_lse = \ + context_output, context_lse = ( self._context_parallel_compute_prefill_context( - q, kv_c_and_k_pe_cache, attn_metadata, - k_scale=None, dcp_world_size=self.dcp_world_size) + q, + kv_c_and_k_pe_cache, + attn_metadata, + k_scale=None, + dcp_world_size=self.dcp_world_size, + ) + ) else: - context_output, context_lse = \ - self._compute_prefill_context( - q, kv_c_and_k_pe_cache, attn_metadata, k_scale) + context_output, context_lse = self._compute_prefill_context( + q, kv_c_and_k_pe_cache, attn_metadata, k_scale + ) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1620,18 +1843,18 @@ def _forward_prefill( # unpad if necessary if self._pad_v: - output = output[..., :v.shape[-1]] + output = output[..., : v.shape[-1]] return output.flatten(start_dim=-2) @abstractmethod def _forward_decode( self, - q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, layer: AttentionLayer, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: raise NotImplementedError def forward( @@ -1642,24 +1865,27 @@ def forward( k_pe: torch.Tensor, # value 
in unified attn kv_cache: torch.Tensor, attn_metadata: M, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for MLACommonImpl") + "fused output quantization is not yet supported for MLACommonImpl" + ) if attn_metadata is None: # During the profile run try to simulate to worse case output size # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` # since this can be large _ = torch.empty( - (self.chunked_prefill_workspace_size, self.num_heads, - self.qk_nope_head_dim + self.v_head_dim), + ( + self.chunked_prefill_workspace_size, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ), device=k_c_normed.device, dtype=k_c_normed.dtype, ) @@ -1683,9 +1909,11 @@ def forward( k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] - assert attn_metadata.num_decodes is not None and \ - attn_metadata.num_prefills is not None and \ - attn_metadata.num_decode_tokens is not None + assert ( + attn_metadata.num_decodes is not None + and attn_metadata.num_prefills is not None + and attn_metadata.num_decode_tokens is not None + ) has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 @@ -1713,61 +1941,74 @@ def forward( if has_prefill: output[num_decode_tokens:] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata, layer._k_scale) + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + kv_cache, + attn_metadata, + layer._k_scale, + ) if has_decode: assert attn_metadata.decode is not None + decode_q_nope, decode_q_pe = decode_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) # Pads the head_dim if necessary (for the underlying kernel) if self.q_pad_num_heads is not None: B, N, L = decode_q_pe.shape - decode_pe_padded = decode_q_pe.new_empty( - (B, self.q_pad_num_heads, L)) + decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) decode_pe_padded.resize_((B, N, L)) decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded if is_rocm_aiter_fp8bmm_enabled(): # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) - decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope, - self.W_K, - self.W_K_scale, - group_size=128, - transpose_bm=True) + decode_ql_nope = aiter_triton_fp8_bmm( # type: ignore + decode_q_nope, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True, + ) else: # Pads the head_dim if necessary (for the underlying kernel) N, B, P = decode_q_nope.shape _, _, L = self.W_UK_T.shape + if self.q_pad_num_heads is not None: decode_ql_nope = decode_q_nope.new_empty( - (self.q_pad_num_heads, B, L)) + (self.q_pad_num_heads, B, L) + ) decode_ql_nope.resize_((N, B, L)) - else: decode_ql_nope = decode_q_nope.new_empty((N, B, L)) # Multiply (N, B, P) x (N, P, L) -> (N, B, L) torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope) + # Convert from (N, B, L) to (B, N, L) decode_ql_nope = decode_ql_nope.transpose(0, 1) if fp8_attention: ql_nope_shape 
= decode_ql_nope.shape decode_ql_nope, _ = ops.scaled_fp8_quant( - decode_ql_nope.reshape([ - ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2] - ]), layer._q_scale) + decode_ql_nope.reshape( + [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]] + ), + layer._q_scale, + ) decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) q_pe_shape = decode_q_pe.shape decode_q_pe, _ = ops.scaled_fp8_quant( - decode_q_pe.reshape( - [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), - layer._q_scale) + decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + layer._q_scale, + ) decode_q_pe = decode_q_pe.reshape(q_pe_shape) decode_q = (decode_ql_nope, decode_q_pe) @@ -1779,8 +2020,9 @@ def forward( decode_q = get_dcp_group().all_gather(decode_q, dim=1) # call decode attn - attn_out, lse = self._forward_decode(decode_q, kv_cache, - attn_metadata, layer) + attn_out, lse = self._forward_decode( + decode_q, kv_cache, attn_metadata, layer + ) # recorect dcp attn_out with lse. if self.dcp_world_size > 1: diff --git a/vllm_metax/v1/attention/backends/mla/flashmla.py b/vllm_metax/v1/attention/backends/mla/flashmla.py index c66871edc..7fc14bd42 100644 --- a/vllm_metax/v1/attention/backends/mla/flashmla.py +++ b/vllm_metax/v1/attention/backends/mla/flashmla.py @@ -2,27 +2,40 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional, Union +from typing import ClassVar import torch -from vllm.attention.backends.abstract import AttentionLayer, AttentionType -from vllm_metax.attention.ops.flashmla import (flash_mla_with_kvcache, - get_mla_metadata, - is_flashmla_supported) +from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf +from vllm_metax.attention.ops.flashmla import ( + flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_dense_supported, +) from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm_metax.v1.attention.backends.mla.common import ( - MLACommonBackend, MLACommonDecodeMetadata, MLACommonImpl, - MLACommonMetadata, MLACommonMetadataBuilder) -from vllm.v1.attention.backends.utils import AttentionCGSupport + MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + QueryLenSupport, +) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + reshape_attn_output_for_spec_decode, + reshape_query_for_spec_decode, +) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) class MacaFlashMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "FLASHMLA" @@ -39,6 +52,10 @@ def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + return [64] + @dataclass class FlashMLADecodeMetadata(MLACommonDecodeMetadata): @@ -52,19 +69,29 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_BATCH + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM + reorder_batch_threshold: int = 512 # process small prefills with decode pathway + # ^ 
TODO(matt): tune this - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - super().__init__(kv_cache_spec, layer_names, vllm_config, device, - FlashMLAMetadata) + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__( + kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata + ) self.num_q_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) self.cg_buf_tile_scheduler_metadata = None self.cg_buf_num_splits = None + self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8") device_properties = torch.cuda.get_device_properties(self.device) num_sms = device_properties.multi_processor_count @@ -80,19 +107,28 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.cg_buf_num_splits = torch.empty( (vllm_config.scheduler_config.max_num_seqs + 1), device=self.device, - dtype=torch.int32) - - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> FlashMLADecodeMetadata: - tile_scheduler_metadata, num_splits = \ - get_mla_metadata( + dtype=torch.int32, + ) + + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + dcp_tot_seq_lens_device: torch.Tensor | None, + ) -> FlashMLADecodeMetadata: + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + # we use the max but all should be the same due to uniform length requirement + max_query_len = query_lens_cpu.max().item() + num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1 + tile_scheduler_metadata, num_splits = get_mla_metadata( seq_lens_device, - self.num_q_heads, - 1, # MQA for the decode path + num_q_tokens_per_head_k, + 1, # MQA for the decode path + is_fp8_kvcache=self.is_fp8_kvcache, ) # TODO: we can disambiguate between decode and mixed-prefill decode here @@ -105,8 +141,9 @@ def _build_decode(self, block_table_tensor: torch.Tensor, sm_parts = tile_scheduler_metadata.size(0) # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0) - tile_scheduler_metadata_view = \ - self.cg_buf_tile_scheduler_metadata[:sm_parts] + tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[ + :sm_parts + ] tile_scheduler_metadata_view.copy_(tile_scheduler_metadata) tile_scheduler_metadata = tile_scheduler_metadata_view @@ -127,54 +164,67 @@ def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens=seq_lens_device, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): - can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, 
num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - - is_supported, reason = is_flashmla_supported() + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None, + attn_type: str, + kv_sharing_target_layer_name: str | None, + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) + + is_supported, reason = is_flashmla_dense_supported() assert is_supported, reason unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl" + ) def _forward_decode( self, - q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLAMetadata, layer: AttentionLayer, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: # TODO: (zyongye) decode function for mla here assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None @@ -182,20 +232,56 @@ def _forward_decode( if type(q) is tuple: q = torch.cat(q, dim=-1) + # mypy assertion: q is now always a tensor assert isinstance(q, torch.Tensor) + + num_decodes = attn_metadata.num_decodes + q = reshape_query_for_spec_decode(q, num_decodes) + + tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata + num_splits = attn_metadata.decode.num_splits + if vllm_is_batch_invariant(): + device = q.device + dtype = torch.int32 + + B = q.shape[0] + # block_table shape: [batch_size, max_num_blocks_per_seq] + # The number of blocks per sequence is in the second dimension + topk = attn_metadata.decode.block_table.shape[-1] + B_TOPK = 64 + assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}" + end_block_idx = topk // B_TOPK + + # Single partition => num_sm_parts = 1 + # TileSchedulerMetaDataSize = 8, layout: + # [begin_idx, begin_block_idx, end_idx, end_block_idx, + # begin_n_split_idx, _, _, _] + tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device) + tile_scheduler_metadata[0, 0] = 0 # begin_idx + tile_scheduler_metadata[0, 1] = 0 # sched_begin_block_idx + tile_scheduler_metadata[0, 2] = B - 1 # end_idx + tile_scheduler_metadata[0, 3] = end_block_idx + tile_scheduler_metadata[0, 4] = 0 # begin_n_split_idx + # fields [5..7] stay 0 + + # Non-split path ignores num_splits, but the API requires it: + # zeros of length B+1 + num_splits = torch.zeros((B + 1,), dtype=dtype, device=device) + o, lse = flash_mla_with_kvcache( - q=q.unsqueeze(1), # Add seqlen dim of 1 (decode) + q=q, k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 block_table=attn_metadata.decode.block_table, 
cache_seqlens=attn_metadata.decode.seq_lens, head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=attn_metadata.decode. - tile_scheduler_metadata, - num_splits=attn_metadata.decode.num_splits, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, softmax_scale=self.scale, causal=True, descale_q=layer._q_scale.reshape(1), descale_k=layer._k_scale.reshape(1), ) + o = reshape_attn_output_for_spec_decode(o) + return o, lse diff --git a/vllm_metax/v1/attention/backends/mla/flashmla_sparse.py b/vllm_metax/v1/attention/backends/mla/flashmla_sparse.py index 06103fb61..e7e689296 100644 --- a/vllm_metax/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm_metax/v1/attention/backends/mla/flashmla_sparse.py @@ -8,21 +8,28 @@ import torch from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, - AttentionMetadata) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionLayer, + AttentionMetadata, +) from vllm.attention.backends.utils import get_mla_dims -from vllm_metax.attention.ops.flashmla import (flash_mla_sparse_prefill, - flash_mla_with_kvcache, - get_mla_metadata) +from vllm_metax.attention.ops.flashmla import ( + flash_mla_sparse_prefill, + flash_mla_with_kvcache, + get_mla_metadata, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import cdiv +from vllm.utils.math_utils import cdiv from vllm_metax.v1.attention.backends.mla.common import MLACommonBaseImpl -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec if TYPE_CHECKING: @@ -44,19 +51,12 @@ """ -def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor: - # Convert base-2 LSE to natural-log LSE - # Keep FP32 for numerical stability during the merge. 
- return (lse_base2.to(torch.float32) * math.log(2.0)) - - class MacaFlashMLASparseBackend(AttentionBackend): - accept_output_buffer: bool = True @staticmethod def get_name() -> str: - return "FLASHMLA_SPARSE_VLLM_V1" + return "FLASHMLA_SPARSE" @staticmethod def get_metadata_cls() -> type[AttentionMetadata]: @@ -94,35 +94,6 @@ def get_supported_head_sizes(cls) -> list[int]: return [576] -@dataclass -class MLASparsePrefillMetadata: - # NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because - # the kernel is not from flashmla - block_table: torch.Tensor - has_context: bool = False - context_lens: Optional[torch.Tensor] = None - - -@dataclass -class FlashMLASparseDecodeAndContextMetadata: - scheduler_metadata: torch.Tensor = None - num_splits: torch.Tensor = None - cache_lens: torch.Tensor = None - prefill_context_lengths: Optional[torch.Tensor] = None - prefill_new_k_start_locs: Optional[torch.Tensor] = None - dummy_block_table: torch.Tensor = None - - def filter_prefill_indices( - self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - assert self.prefill_context_lengths is not None - prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1) - context_indices = torch.where(indices < prefill_context_lengths, - indices, -1) - new_token_indices = torch.where(indices >= prefill_context_lengths, - indices - prefill_context_lengths, -1) - return context_indices, new_token_indices - - @dataclass class FlashMLASparseMetadata: num_reqs: int @@ -140,12 +111,12 @@ class FlashMLASparseMetadata: @dataclass class FP8KernelMetadata: - scheduler_metadata: Optional[torch.Tensor] + scheduler_metadata: torch.Tensor | None num_splits: torch.Tensor dummy_block_table: torch.Tensor cache_lens: torch.Tensor - fp8_extra_metadata: Optional[FP8KernelMetadata] = None + fp8_extra_metadata: FP8KernelMetadata | None = None @triton.jit @@ -194,8 +165,9 @@ def _convert_req_index_to_global_index_kernel( base = tl.load(bt_ptr, mask=valid_block, other=0) # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset - out_val = tl.where(is_invalid_tok | (~valid_block), -1, - base * BLOCK_SIZE + inblock_off) + out_val = tl.where( + is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off + ) # Store results out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 @@ -203,31 +175,30 @@ def _convert_req_index_to_global_index_kernel( def triton_convert_req_index_to_global_index( - req_id: torch.Tensor, # int32 [num_tokens] - block_table: torch. - Tensor, # int32 [num_requests, max_num_blocks_per_req] - token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] - BLOCK_SIZE: int = 64, - NUM_TOPK_TOKENS: int = 2048, - BLOCK_N: int = 128, # tile width along columns + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128, # tile width along columns ): """ out[token_id, indice_id] = - block_table[req_id[token_id], + block_table[req_id[token_id], token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + token_indices[token_id, indice_id] % BLOCK_SIZE Only when token_indices[token_id, indice_id] == -1 do we output -1. - For safety, we also output -1 if the derived block_id would be + For safety, we also output -1 if the derived block_id would be out-of-bounds. 
""" assert req_id.dtype == torch.int32 assert block_table.dtype == torch.int32 assert token_indices.dtype == torch.int32 assert token_indices.shape[1] == NUM_TOPK_TOKENS - assert NUM_TOPK_TOKENS % BLOCK_N == 0, \ - f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \ - f"BLOCK_N ({BLOCK_N})" + assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})" + ) num_tokens = req_id.shape[0] num_requests, max_num_blocks_per_req = block_table.shape @@ -268,14 +239,16 @@ def triton_convert_req_index_to_global_index( @dataclass -class FlashMLASparseMetadataBuilder( - AttentionMetadataBuilder[FlashMLASparseMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_BATCH - - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): +class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec self.model_config = vllm_config.model_config @@ -285,28 +258,27 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], props = torch.cuda.get_device_properties(device) sm_count = props.multi_processor_count - self.num_heads = self.model_config.get_num_attention_heads( - parallel_config) + self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.topk_tokens = vllm_config.model_config.hf_config.index_topk self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" - self.topk_tokens_tensor = torch.tensor([self.topk_tokens], - device=device, - dtype=torch.int32) + self.topk_tokens_tensor = torch.tensor( + [self.topk_tokens], device=device, dtype=torch.int32 + ) self.max_model_len_tensor = torch.tensor( - [self.model_config.max_model_len], - device=device, - dtype=torch.int32) + [self.model_config.max_model_len], device=device, dtype=torch.int32 + ) # this is ignored by `flash_mla_with_kvcache` if indices not None - self.dummy_block_table = torch.empty((1, 1), - dtype=torch.int32, - device=self.device) + self.dummy_block_table = torch.empty( + (1, 1), dtype=torch.int32, device=self.device + ) # Equation taken from FlashMLA/csrc/pybind.cpp h_q, h_k = self.num_heads, 1 s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest max_num_sm_parts = int( - max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)) + max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1) + ) if current_platform.is_device_capability(100): max_num_sm_parts *= 2 self.tile_scheduler_metadata_buffer = torch.empty( @@ -314,34 +286,38 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # see: FlashMLA/csrc/params.h (max_num_sm_parts, 8), dtype=torch.int32, - device=device) + device=device, + ) self.num_splits_buffer = torch.empty( # We pack all the tokens into one batch for sparse attention. # Otherwise, we can exceed the sm of `get_mla_metadata`. 
- ( - 2, ), + (2,), dtype=torch.int32, - device=device) + device=device, + ) self.req_id_per_token_buffer = torch.empty( - (vllm_config.scheduler_config.max_num_batched_tokens, ), + (vllm_config.scheduler_config.max_num_batched_tokens,), dtype=torch.int32, - device=device) - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashMLASparseMetadata: + device=device, + ) + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashMLASparseMetadata: num_tokens = common_attn_metadata.num_actual_tokens - starts = np.asarray(common_attn_metadata.query_start_loc_cpu, - dtype=np.int32) + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) seg_lengths = np.diff(starts) req_id_per_token = np.repeat( - np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths) + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths + ) # Zero-fill for cudagraphs self.req_id_per_token_buffer.fill_(0) - self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\ - .copy_(torch.from_numpy(req_id_per_token), non_blocking=True) + self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( + torch.from_numpy(req_id_per_token), non_blocking=True + ) req_id_per_token = self.req_id_per_token_buffer[:num_tokens] fp8_extra_metadata = None @@ -357,8 +333,9 @@ def build(self, num_sm_parts = tile_scheduler_metadata.size(0) # Copy to persistent buffer for full-CG support - tile_scheduler_metadata_buffer = \ - self.tile_scheduler_metadata_buffer[:num_sm_parts] + tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[ + :num_sm_parts + ] tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) self.num_splits_buffer.copy_(num_splits) @@ -371,7 +348,8 @@ def build(self, # accidentally mark indices invalid, we will use -1 exclusively # to mark invalid indices cache_lens=self.max_model_len_tensor, - dummy_block_table=self.dummy_block_table) + dummy_block_table=self.dummy_block_table, + ) metadata = FlashMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, @@ -390,62 +368,79 @@ def build(self, class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - topk_indice_buffer: Optional[torch.Tensor] = None, - indexer: Optional["Indexer"] = None, - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None, + attn_type: str, + kv_sharing_target_layer_name: str | None, + # MLA Specific Arguments + topk_indice_buffer: torch.Tensor | None = None, + indexer: Optional["Indexer"] = None, + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) self.softmax_scale = scale assert indexer is not None 
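# A minimal standalone sketch (not part of this diff) of how
# FlashMLASparseMetadataBuilder.build() above derives req_id_per_token:
# query_start_loc holds cumulative token counts per request, so np.diff gives the
# per-request token counts and np.repeat assigns every flattened token the index of
# the request it belongs to. The numbers below are made up for illustration.
import numpy as np

query_start_loc = np.array([0, 3, 5, 9], dtype=np.int32)  # 3 requests with 3, 2 and 4 tokens
seg_lengths = np.diff(query_start_loc)                    # -> [3, 2, 4]
req_id_per_token = np.repeat(np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
# req_id_per_token -> [0, 0, 0, 1, 1, 2, 2, 2, 2]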
self.topk_indices_buffer = indexer.topk_indices_buffer - self.padding = 128 if current_platform.is_device_capability( - 100) else 64 + self.padding = 128 if current_platform.is_device_capability(100) else 64 def _forward_bf16_kv( - self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, - attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: num_tokens = q.shape[0] kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( - -1, 1, kv_c_and_k_pe_cache.shape[-1]) + -1, 1, kv_c_and_k_pe_cache.shape[-1] + ) # NOTE(Chen): kernel requires num_local_head to be a multiple of # 64 on hopper and 128 on blackwell if self.num_heads % self.padding != 0: assert self.padding % self.num_heads == 0 - logger.warning_once(f"padding num_heads to {self.padding} \ - due to sparse attn kernel requirement") + logger.warning_once( + f"padding num_heads to {self.padding} \ + due to sparse attn kernel requirement" + ) q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2])) - q_padded[:, :self.num_heads, :] = q + q_padded[:, : self.num_heads, :] = q q = q_padded topk_indices = topk_indices.view(num_tokens, 1, -1) - output = flash_mla_sparse_prefill(q, kv_c_and_k_pe_cache, topk_indices, - self.softmax_scale)[0] - output = output[:, :self.num_heads, :] + output = flash_mla_sparse_prefill( + q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale + )[0] + output = output[:, : self.num_heads, :] return output - def _forward_fp8_kv(self, q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, - attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: - + def _forward_fp8_kv( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: assert attn_metadata.fp8_extra_metadata is not None extra_metadata = attn_metadata.fp8_extra_metadata @@ -472,9 +467,9 @@ def forward( k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, attn_metadata: FlashMLASparseMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use # MQA 576/512 approach for both prefill and decode @@ -483,8 +478,8 @@ def forward( if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for MLACommonImpl") + "fused output quantization is not yet supported for MLACommonImpl" + ) if attn_metadata is None: # The zero fill is required when used with DP + EP @@ -500,8 +495,7 @@ def forward( k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] 
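# Standalone sketch (not part of this diff) of the head-padding workaround used in
# _forward_bf16_kv above: the sparse prefill kernel is assumed to require the local
# query-head count to be a multiple of 64 (Hopper) or 128 (Blackwell), so q is padded
# along the head dimension before the call and the extra heads are sliced off again
# (zero-filled here for simplicity; the kernel output for padded heads is discarded).
# `run_sparse_kernel` is a hypothetical stand-in for flash_mla_sparse_prefill.
import torch

def pad_heads_and_run(q: torch.Tensor, padding: int, run_sparse_kernel):
    # q: [num_tokens, num_heads, head_dim]
    num_heads = q.shape[1]
    if num_heads % padding != 0:
        q_padded = q.new_zeros((q.shape[0], padding, q.shape[2]))
        q_padded[:, :num_heads, :] = q
        q = q_padded
    out = run_sparse_kernel(q)      # [num_tokens, padded_num_heads, v_head_dim]
    return out[:, :num_heads, :]    # drop the padded heads from the output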
- q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # Convert from (B, N, P) to (N, B, P) q_nope = q_nope.transpose(0, 1) # Multiply (N, B, P) x (N, P, L) -> (N, B, L) @@ -534,11 +528,13 @@ def forward( ) if self.kv_cache_dtype != "fp8_ds_mla": - attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices_global, - attn_metadata) + attn_out = self._forward_bf16_kv( + q, kv_cache, topk_indices_global, attn_metadata + ) else: - attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global, - attn_metadata) + attn_out = self._forward_fp8_kv( + q, kv_cache, topk_indices_global, attn_metadata + ) self._v_up_proj(attn_out, out=output[:num_actual_toks]) return output diff --git a/vllm_metax/v1/attention/backends/mla/indexer.py b/vllm_metax/v1/attention/backends/mla/indexer.py index 851f633fc..eb7b3d09d 100644 --- a/vllm_metax/v1/attention/backends/mla/indexer.py +++ b/vllm_metax/v1/attention/backends/mla/indexer.py @@ -1,25 +1,29 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar import torch -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + MultipleOf, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) logger = init_logger(__name__) class MacaDeepseekV32IndexerBackend(AttentionBackend): - @staticmethod def get_metadata_cls() -> type["AttentionMetadata"]: return DeepseekV32IndexerMetadata @@ -47,6 +51,10 @@ def get_kv_cache_shape( def get_kv_cache_stride_order() -> tuple[int, ...]: return (0, 1, 2) + @classmethod + def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: + return [64] + @dataclass class DeepseekV32IndexerPrefillChunkMetadata: @@ -76,7 +84,6 @@ class DeepSeekV32IndexerDecodeMetadata: @dataclass class DeepseekV32IndexerMetadata: - # FIXME (zyongye) # hacky way to access the data now, need to be in chunked meta seq_lens: torch.Tensor @@ -98,33 +105,33 @@ class DeepseekV32IndexerMetadata: num_prefills: int num_prefill_tokens: int - decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None - prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None + decode: DeepSeekV32IndexerDecodeMetadata | None = None + prefill: DeepseekV32IndexerPrefillMetadata | None = None # TODO (zyongye) optimize this, this is now vibe coded def kv_spans_from_batches( - start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, - device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: """ Args: - start_seq_loc: 1D long tensor [B+1], cumulative counts of + start_seq_loc: 1D long tensor [B+1], cumulative counts of selected tokens per batch. - Example: [0, 2, 4, 7] -> + Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total. 
- seq_len_per_batch: 1D long tensor [B], + seq_len_per_batch: 1D long tensor [B], full sequence length (KV length) of each batch. Example: [5, 9, 4]. Returns: - start_tensor: 1D long tensor [N], start offset in the + start_tensor: 1D long tensor [N], start offset in the concatenated KV cache for each token's batch. - end_location: 1D long tensor [N], + end_location: 1D long tensor [N], **exclusive** end = start + token's local position. (So the attended KV slice is kv[start:end].) - Assumes each batch contributes its full `seq_len_per_batch[i]` - keys to the KV cache, andthe selected tokens within a batch + Assumes each batch contributes its full `seq_len_per_batch[i]` + keys to the KV cache, andthe selected tokens within a batch are the **last** `counts[i]` positions of that sequence. """ q = start_seq_loc.to(dtype=torch.long) @@ -138,8 +145,10 @@ def kv_spans_from_batches( B = L.numel() if N == 0: - return (torch.empty(0, dtype=torch.long, device=device), - torch.empty(0, dtype=torch.long, device=device)) + return ( + torch.empty(0, dtype=torch.long, device=device), + torch.empty(0, dtype=torch.long, device=device), + ) # KV start offsets per batch in the concatenated KV cache kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] @@ -155,8 +164,9 @@ def kv_spans_from_batches( L_expand = torch.repeat_interleave(L, counts) # [N] m_expand = torch.repeat_interleave(counts, counts) # [N] # position within the selected block: 1..counts[b] - pos_within = (torch.arange(N, dtype=torch.long) - - torch.repeat_interleave(q[:-1], counts) + 1) + pos_within = ( + torch.arange(N, dtype=torch.long) - torch.repeat_interleave(q[:-1], counts) + 1 + ) local_pos = L_expand - m_expand + pos_within # [N], 1-based end_location = start_tensor + local_pos # exclusive end @@ -171,9 +181,9 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig): return max_model_len * 2 -def split_prefill_chunks(seq_lens_cpu: torch.Tensor, - max_prefill_buffer_size: int, - reqs_start: int) -> list[tuple[int, int]]: +def split_prefill_chunks( + seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int +) -> list[tuple[int, int]]: """ Split the prefill chunks into a list of tuples of (reqs_start, reqs_end) such that the total sequence length of each chunk is less than the @@ -183,7 +193,7 @@ def split_prefill_chunks(seq_lens_cpu: torch.Tensor, seq_lens_cpu: The sequence lengths of the prefill requests. max_prefill_buffer_size: The maximum prefill buffer size. reqs_start: The start index of the prefill requests. - + Returns: A list of tuples of (reqs_start, reqs_end). """ @@ -203,20 +213,22 @@ def split_prefill_chunks(seq_lens_cpu: torch.Tensor, class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): - cudagraph_support: ClassVar[AttentionCGSupport] = \ + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) reorder_batch_threshold: int = 1 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) scheduler_config = self.vllm_config.scheduler_config - #NOTE(Chen):an estimated max size of flattened_kv. Need to double check. - self.max_prefill_buffer_size = get_max_prefill_buffer_size( - self.vllm_config) + # NOTE(Chen):an estimated max size of flattened_kv. Need to double check. 
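# Worked example (not part of this diff) for kv_spans_from_batches defined above, using
# the inputs from its docstring: start_seq_loc = [0, 2, 4, 7], seq_len_per_batch = [5, 9, 4].
# The reference below recomputes the same spans with an explicit loop: batches are laid
# out back to back in the flattened KV cache and each batch's selected tokens are taken
# to be the last counts[i] positions of its sequence.
import torch

def kv_spans_reference(start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor):
    starts, ends = [], []
    counts = torch.diff(start_seq_loc)              # selected tokens per batch
    kv_start = 0
    for cnt, kv_len in zip(counts.tolist(), seq_len_per_batch.tolist()):
        for i in range(cnt):
            starts.append(kv_start)
            ends.append(kv_start + kv_len - cnt + i + 1)  # exclusive end
        kv_start += kv_len
    return torch.tensor(starts), torch.tensor(ends)

starts, ends = kv_spans_reference(torch.tensor([0, 2, 4, 7]), torch.tensor([5, 9, 4]))
# starts -> [0, 0, 5, 5, 14, 14, 14]
# ends   -> [4, 5, 13, 14, 16, 17, 18]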
+ self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config) self.num_speculative_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config else 0) + if self.vllm_config.speculative_config + else 0 + ) # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2 self.reorder_batch_threshold += min(self.num_speculative_tokens, 1) @@ -225,31 +237,38 @@ def __init__(self, *args, **kwargs): self.num_sms = sm_count self.decode_lens_buffer = torch.empty( - (scheduler_config.max_num_seqs, ), - dtype=torch.int32, - device=self.device) + (scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device + ) # See: DeepGMM/csrc/apis/attention.hpp - self.scheduler_metadata_buffer = torch.empty((self.num_sms + 1, 2), - dtype=torch.int32, - device=self.device) - - def build_one_prefill_chunk(self, reqs_start, reqs_end, - query_start_loc_cpu, seq_lens_cpu, - block_table): - prefill_query_start_loc = query_start_loc_cpu[ - reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start] + self.scheduler_metadata_buffer = torch.empty( + (self.num_sms + 1, 2), dtype=torch.int32, device=self.device + ) + + def build_one_prefill_chunk( + self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table + ): + prefill_query_start_loc = ( + query_start_loc_cpu[reqs_start : reqs_end + 1] + - query_start_loc_cpu[reqs_start] + ) cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches( - prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], - self.device) + prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device + ) token_start = query_start_loc_cpu[reqs_start].item() token_end = query_start_loc_cpu[reqs_end].item() total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum() assert total_seq_lens <= self.max_prefill_buffer_size - cu_seq_lens = torch.cat([ - torch.zeros(1, dtype=torch.int32), - seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0) - ]).to(torch.int32).to(self.device) + cu_seq_lens = ( + torch.cat( + [ + torch.zeros(1, dtype=torch.int32), + seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0), + ] + ) + .to(torch.int32) + .to(self.device) + ) return DeepseekV32IndexerPrefillChunkMetadata( cu_seqlen_ks=cu_seqlen_ks, cu_seqlen_ke=cu_seqlen_ke, @@ -261,19 +280,21 @@ def build_one_prefill_chunk(self, reqs_start, reqs_end, num_reqs=reqs_end - reqs_start, ) - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> DeepseekV32IndexerMetadata: - + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> DeepseekV32IndexerMetadata: num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -287,33 +308,39 @@ def build(self, ) chunks = [ self.build_one_prefill_chunk( - reqs_start, reqs_end, query_start_loc_cpu, + reqs_start, + reqs_end, + query_start_loc_cpu, common_attn_metadata.seq_lens_cpu, - common_attn_metadata.block_table_tensor) + 
common_attn_metadata.block_table_tensor, + ) for reqs_start, reqs_end in chunk_seq_ids ] prefill_metadata = DeepseekV32IndexerPrefillMetadata( - chunks=chunks, ) + chunks=chunks, + ) decode_metadata = None if num_decodes > 0: - torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1], - out=self.decode_lens_buffer[:num_decodes]) + torch.diff( + common_attn_metadata.query_start_loc[: num_decodes + 1], + out=self.decode_lens_buffer[:num_decodes], + ) decode_lens = self.decode_lens_buffer[:num_decodes] decode_lens_cpu = torch.diff( - common_attn_metadata.query_start_loc_cpu[:num_decodes + 1]) + common_attn_metadata.query_start_loc_cpu[: num_decodes + 1] + ) # Use CPU to avoid GPU sync; breaking async scheduling - requires_padding = (decode_lens_cpu.max() - > decode_lens_cpu.min()).item() + requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item() seq_lens = common_attn_metadata.seq_lens[:num_decodes] self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( - seq_lens, self.kv_cache_spec.block_size, self.num_sms) + seq_lens, self.kv_cache_spec.block_size, self.num_sms + ) decode_metadata = DeepSeekV32IndexerDecodeMetadata( - block_table=common_attn_metadata. - block_table_tensor[:num_decodes, ...], + block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...], seq_lens=common_attn_metadata.seq_lens[:num_decodes], decode_lens=decode_lens, requires_padding=requires_padding, diff --git a/vllm_metax/v1/attention/backends/mla/triton_mla.py b/vllm_metax/v1/attention/backends/mla/triton_mla.py index aa2fdb7b8..ab70328a4 100644 --- a/vllm_metax/v1/attention/backends/mla/triton_mla.py +++ b/vllm_metax/v1/attention/backends/mla/triton_mla.py @@ -1,28 +1,33 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, - is_quantized_kv_cache) -from vllm_metax.attention.ops.triton_decode_attention import decode_attention_fwd +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + is_quantized_kv_cache, +) +from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON - -from vllm_metax.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata) +from vllm_metax.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, +) logger = init_logger(__name__) class MacaTritonMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "TRITON_MLA" @@ -36,54 +41,64 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, 
**mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None, + attn_type: str, + kv_sharing_target_layer_name: str | None, + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "TritonMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonMLAImpl" + ) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( - "TritonMLA V1 with FP8 KV cache not yet supported") + "TritonMLA V1 with FP8 KV cache not yet supported" + ) self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN self.triton_fa_func = triton_attention if HAS_TRITON else None - def _flash_attn_varlen_diff_headdims_rocm(self, - q, - k, - v, - softmax_scale=None, - **kwargs): + def _flash_attn_varlen_diff_headdims_rocm( + self, q, k, v, softmax_scale=None, **kwargs + ): assert self.triton_fa_func is not None # Triton Attention requires a padded V - padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) + padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0) # The output of triton_attention is a tuple of # [output_tensor, encoded_softmax] where encoded_softmax is always None output_tensor, _ = self.triton_fa_func( @@ -102,18 +117,17 @@ def _flash_attn_varlen_diff_headdims_rocm(self, return output_tensor - def _flash_attn_varlen_diff_headdims(self, - q, - k, - v, - return_softmax_lse=False, - softmax_scale=None, - **kwargs): - if current_platform.is_rocm() \ - and self.use_triton_flash_attn \ - and not return_softmax_lse: + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): + if ( + current_platform.is_rocm() + and self.use_triton_flash_attn + and not return_softmax_lse + ): return self._flash_attn_varlen_diff_headdims_rocm( - q, k, v, softmax_scale=softmax_scale, **kwargs) + q, k, v, softmax_scale=softmax_scale, **kwargs + ) else: return super()._flash_attn_varlen_diff_headdims( q, @@ -121,15 +135,16 @@ def _flash_attn_varlen_diff_headdims(self, v, return_softmax_lse=return_softmax_lse, softmax_scale=softmax_scale, - **kwargs) + **kwargs, + ) def _forward_decode( self, - q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, layer: AttentionLayer, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None @@ -142,13 +157,13 @@ def _forward_decode( assert isinstance(q, torch.Tensor) B = q.shape[0] q_num_heads = q.shape[1] - o = torch.zeros(B, - q_num_heads, - 
self.kv_lora_rank, - dtype=q.dtype, - device=q.device) + o = torch.zeros( + B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device + ) lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device) - num_kv_splits = 4 # TODO: heuristic + + # For batch invariance, use only 1 split to ensure deterministic reduction + num_kv_splits = 1 if vllm_is_batch_invariant() else 4 # TODO(lucas) Allocate ahead of time attn_logits = torch.empty( @@ -166,13 +181,22 @@ def _forward_decode( # Add a head dim of 1 kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) - kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] PAGE_SIZE = kv_c_and_k_pe_cache.size(1) # Run MQA - decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, lse, - attn_metadata.decode.block_table, - attn_metadata.decode.seq_lens, attn_logits, - num_kv_splits, self.scale, PAGE_SIZE) + decode_attention_fwd( + q, + kv_c_and_k_pe_cache, + kv_c_cache, + o, + lse, + attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, + attn_logits, + num_kv_splits, + self.scale, + PAGE_SIZE, + ) return o, lse diff --git a/vllm_metax/v1/attention/backends/tree_attn.py b/vllm_metax/v1/attention/backends/tree_attn.py index 5980ebde3..413007c3d 100644 --- a/vllm_metax/v1/attention/backends/tree_attn.py +++ b/vllm_metax/v1/attention/backends/tree_attn.py @@ -4,31 +4,32 @@ import ast from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import Optional import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + MultipleOf, +) from vllm_metax.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - -from vllm import _custom_ops as ops - logger = init_logger(__name__) -class TreeAttentionBackend(AttentionBackend): - +class MacaTreeAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @classmethod @@ -39,6 +40,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() @@ -48,7 +53,8 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." 
+ ) @staticmethod def get_name() -> str: @@ -98,7 +104,7 @@ class TreeAttentionMetadata: num_prefills: int = 0 num_decodes: int = 0 - tree_attn_bias: Optional[torch.Tensor] = None + tree_attn_bias: torch.Tensor | None = None # Cached Prefill/decode metadata. _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None @@ -114,9 +120,9 @@ def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]: # metadata structure return self._cached_prefill_metadata - q_start_loc = self.query_start_loc[self.num_decodes:] + q_start_loc = self.query_start_loc[self.num_decodes :] q_seqlens = torch.diff(q_start_loc) - kv_seqlens = self.seq_lens[self.num_decodes:] + kv_seqlens = self.seq_lens[self.num_decodes :] # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = TreeAttentionMetadata( num_actual_tokens=self.num_prefill_tokens, @@ -124,8 +130,8 @@ def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]: query_start_loc=q_start_loc - q_start_loc[0], max_seq_len=int(kv_seqlens.max().item()), seq_lens=kv_seqlens, - block_table=self.block_table[self.num_decodes:], - slot_mapping=self.slot_mapping[self.num_decode_tokens:], + block_table=self.block_table[self.num_decodes :], + slot_mapping=self.slot_mapping[self.num_decode_tokens :], ) return self._cached_prefill_metadata @@ -139,9 +145,9 @@ def decode_metadata(self) -> Optional["TreeAttentionMetadata"]: # metadata structure return self._cached_decode_metadata - q_start_loc = self.query_start_loc[:self.num_decodes + 1] + q_start_loc = self.query_start_loc[: self.num_decodes + 1] q_seqlens = torch.diff(q_start_loc) - kv_seqlens = self.seq_lens[:self.num_decodes] + kv_seqlens = self.seq_lens[: self.num_decodes] # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = TreeAttentionMetadata( num_actual_tokens=self.num_decode_tokens, @@ -149,16 +155,14 @@ def decode_metadata(self) -> Optional["TreeAttentionMetadata"]: query_start_loc=q_start_loc, max_seq_len=int(kv_seqlens.max().item()), seq_lens=kv_seqlens, - block_table=self.block_table[:self.num_decodes], - slot_mapping=self.slot_mapping[:self.num_decode_tokens], + block_table=self.block_table[: self.num_decodes], + slot_mapping=self.slot_mapping[: self.num_decode_tokens], tree_attn_bias=self.tree_attn_bias, ) return self._cached_decode_metadata -class TreeAttentionMetadataBuilder( - AttentionMetadataBuilder[TreeAttentionMetadata]): - +class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadata]): def __init__( self, kv_cache_spec: AttentionSpec, @@ -172,10 +176,9 @@ def __init__( spec_config = vllm_config.speculative_config spec_token_tree = (spec := spec_config) and spec.speculative_token_tree - tree_choices: list[tuple[int, - ...]] = (ast.literal_eval(spec_token_tree) - if spec_token_tree is not None else - [(0, )]) + tree_choices: list[tuple[int, ...]] = ( + ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)] + ) # Construct the tree attention bias. 
depth_counts = _get_depth_counts(tree_choices) self.tree_attn_bias = _prepare_tree_attn_bias( @@ -185,12 +188,7 @@ def __init__( device=device, ) - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - return reorder_batch_to_split_decodes_and_prefills( - input_batch, - scheduler_output, - decode_threshold=self.tree_attn_bias.shape[0]) + self.reorder_batch_threshold = self.tree_attn_bias.shape[0] def build( self, @@ -200,8 +198,10 @@ def build( ) -> TreeAttentionMetadata: decode_threshold = self.tree_attn_bias.shape[0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=decode_threshold)) + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=decode_threshold + ) + ) num_actual_tokens = common_attn_metadata.num_actual_tokens q_start_loc = common_attn_metadata.query_start_loc @@ -241,8 +241,7 @@ def build_for_drafting( # Slice the tree attention bias for drafting. Exclude # the root level. start, end = 1, 1 + common_attn_metadata.max_query_len - self.tree_attn_bias = self.tree_attn_bias[start:end, - start:end].contiguous() + self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous() # Build attention bias. attn_metadata = self.build(0, common_attn_metadata, fast_build=True) @@ -268,15 +267,14 @@ def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]: def _prepare_tree_attn_bias( sorted_tree_choices: list[tuple[int, ...]], depth_counts: list[int], - dtype: Optional[torch.dtype], - device: Optional[torch.device], + dtype: torch.dtype | None, + device: torch.device | None, ) -> torch.Tensor: # +1 comes from the additional root node. tree_len = len(sorted_tree_choices) + 1 - tree_attn_mask = torch.full((tree_len, tree_len), - -torch.inf, - device=device, - dtype=dtype) + tree_attn_mask = torch.full( + (tree_len, tree_len), -torch.inf, device=device, dtype=dtype + ) # Set diagonal to all zeros. Each token should # attend to itself. @@ -298,26 +296,26 @@ def _prepare_tree_attn_bias( ancestor_idx = [] for c in range(len(cur_tree_choice) - 1): ancestor_idx.append( - sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1) + sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1 + ) tree_attn_mask[j + start + 1, ancestor_idx] = mask_val start += depth_counts[i] return tree_attn_mask class TreeAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, + kv_sharing_target_layer_name: str | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -338,13 +336,15 @@ def __init__( else: self.sliding_window = (sliding_window - 1, 0) - TreeAttentionBackend.validate_head_size(head_size) + MacaTreeAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TreeAttentionImpl.") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TreeAttentionImpl." 
+ ) def forward( self, @@ -354,9 +354,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: TreeAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with TreeAttention. @@ -374,12 +374,12 @@ def forward( if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for TreeAttentionImpl") + "fused output quantization is not yet supported for TreeAttentionImpl" + ) if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) # Cache the input KVs. key_cache, value_cache = kv_cache.unbind(0) @@ -404,8 +404,7 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens num_decode_tokens = attn_metadata.num_decode_tokens - descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, - key.shape[1]) + descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, key.shape[1]) if prefill_meta := attn_metadata.prefill_metadata: unified_attention( q=query[num_decode_tokens:num_actual_tokens], diff --git a/vllm_metax/v1/attention/backends/triton_attn.py b/vllm_metax/v1/attention/backends/triton_attn.py index ae11adcb7..0f354ce2e 100644 --- a/vllm_metax/v1/attention/backends/triton_attn.py +++ b/vllm_metax/v1/attention/backends/triton_attn.py @@ -1,28 +1,37 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """High-Performance Triton-only Attention layer.""" + from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + MultipleOf, +) from vllm.attention.ops.triton_reshape_and_cache_flash import ( - triton_reshape_and_cache_flash) -from vllm.attention.ops.triton_unified_attention import unified_attention + triton_reshape_and_cache_flash, +) +from vllm_metax.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym) + QuantKey, + kFp8StaticTensorSym, +) from vllm.platforms import current_platform -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm_metax import _custom_ops as ops - logger = init_logger(__name__) @@ -47,30 +56,34 @@ class TritonAttentionMetadata: # For cascade attention. 
use_cascade: bool common_prefix_len: int - cu_prefix_query_lens: Optional[torch.Tensor] - prefix_kv_lens: Optional[torch.Tensor] - suffix_kv_lens: Optional[torch.Tensor] + cu_prefix_query_lens: torch.Tensor | None + prefix_kv_lens: torch.Tensor | None + suffix_kv_lens: torch.Tensor | None # Optional aot scheduling - scheduler_metadata: Optional[torch.Tensor] = None - prefix_scheduler_metadata: Optional[torch.Tensor] = None + scheduler_metadata: torch.Tensor | None = None + prefix_scheduler_metadata: torch.Tensor | None = None -class TritonAttentionMetadataBuilder( - AttentionMetadataBuilder[TritonAttentionMetadata]): +class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.block_size = kv_cache_spec.block_size model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( - vllm_config.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) + self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() def build_for_cudagraph_capture( @@ -83,10 +96,12 @@ def build_for_cudagraph_capture( attn_metadata.seq_lens.fill_(1) return attn_metadata - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> TritonAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> TritonAttentionMetadata: num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -99,14 +114,13 @@ def build(self, use_cascade = common_prefix_len > 0 if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) - suffix_kv_lens = (common_attn_metadata.seq_lens_cpu - - common_prefix_len) + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) + suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len suffix_kv_lens = suffix_kv_lens.to(self.device) else: cu_prefix_query_lens = None @@ -133,16 +147,15 @@ def build(self, class MacaTritonAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16, torch.float32] - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + return [MultipleOf(16)] @classmethod def validate_head_size(cls, head_size: int) -> None: @@ -152,7 +165,8 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by TritonAttention." f"Head sizes need to be larger or equal 32 for this backend. 
" "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -188,23 +202,25 @@ def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: class TritonAttentionImpl(AttentionImpl): - def fused_output_quant_supported(self, quant_key: QuantKey): return quant_key == kFp8StaticTensorSym + def supports_quant_query_input(self) -> bool: + return current_platform.is_cuda() + def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[int] = None, - sinks: Optional[torch.Tensor] = None, + kv_sharing_target_layer_name: int | None = None, + sinks: torch.Tensor | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -229,10 +245,12 @@ def __init__( MacaTritonAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonAttentionImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonAttentionImpl" + ) self.fp8_dtype = current_platform.fp8_dtype() @@ -241,7 +259,8 @@ def __init__( assert sinks.shape[0] == num_heads, ( "Sinks must have the same number of heads as the number of " f"heads in the layer. Sinks shape: {sinks.shape}, " - f"num_heads: {num_heads}.") + f"num_heads: {num_heads}." + ) def forward( self, @@ -251,9 +270,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: TritonAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with Paged Attention impl. in Triton. @@ -272,11 +291,12 @@ def forward( if output_block_scale is not None: raise NotImplementedError( "fused block_scale output quantization is not yet supported" - " for TritonAttentionImpl") + " for TritonAttentionImpl" + ) if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) assert attn_metadata.use_cascade is False @@ -316,18 +336,9 @@ def forward( if key_cache.dtype != self.fp8_dtype: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) - num_tokens, num_heads, head_size = query.shape - assert layer._q_scale_float == 1.0, \ + assert layer._q_scale_float == 1.0, ( "A non 1.0 q_scale is not currently supported." - if current_platform.is_cuda(): - # Skip Q quantization on ROCm and XPU, enable this on cuda - # only, since dequantizing back to f32 in the attention kernel - # is not supported. 
- query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) - query = query.reshape((num_tokens, num_heads, head_size)) + ) cu_seqlens_q = attn_metadata.query_start_loc seqused_k = attn_metadata.seq_lens diff --git a/vllm_metax/version.py b/vllm_metax/version.py index 8329d7bec..f368c16dd 100644 --- a/vllm_metax/version.py +++ b/vllm_metax/version.py @@ -5,9 +5,7 @@ except Exception as e: import warnings - warnings.warn(f"Failed to read commit hash:\n{e}", - RuntimeWarning, - stacklevel=2) + warnings.warn(f"Failed to read commit hash:\n{e}", RuntimeWarning, stacklevel=2) __version__ = "dev" __version_tuple__ = (0, 0, __version__)
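# Reference sketch (not part of this diff) of the mapping implemented by the Triton
# kernel triton_convert_req_index_to_global_index in flashmla_sparse.py above: per-token
# top-k indices (local to each request) are translated into global KV-cache slot ids via
# the request's block table, with -1 propagated for invalid tokens and for block ids that
# would fall outside the block table.
import torch

def convert_req_index_to_global_index_reference(
    req_id: torch.Tensor,         # int32 [num_tokens]
    block_table: torch.Tensor,    # int32 [num_requests, max_num_blocks_per_req]
    token_indices: torch.Tensor,  # int32 [num_tokens, num_topk], -1 marks invalid entries
    block_size: int = 64,
) -> torch.Tensor:
    max_blocks = block_table.shape[1]
    block_id = token_indices // block_size
    in_block_off = token_indices % block_size
    invalid = (token_indices == -1) | (block_id >= max_blocks)
    # Clamp only to keep the gather in range; invalid entries are overwritten below.
    safe_block_id = block_id.clamp(min=0, max=max_blocks - 1).long()
    base = block_table[req_id.long().unsqueeze(1), safe_block_id]
    out = base * block_size + in_block_off
    out[invalid] = -1
    return out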