diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index be7886354392..1e72c2931688 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -40,16 +40,16 @@ docker run \ python3 examples/basic/offline_inference/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp python3 examples/basic/offline_inference/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --attention-backend=TRITON_ATTN python3 examples/basic/offline_inference/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager --quantization fp8 - python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager + python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192 python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel cd tests pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py --ignore=v1/core/test_scheduler_e2e.py pytest -v -s v1/engine pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py - pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py + pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py --ignore=v1/worker/test_worker_memory_snapshot.py pytest -v -s v1/structured_output pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py --ignore=v1/spec_decode/test_speculators_eagle3.py --ignore=v1/spec_decode/test_acceptance_length.py - pytest -v -s v1/kv_connector/unit 
--ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_example_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py + pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_example_connector.py --ignore=v1/kv_connector/unit/test_lmcache_integration.py -k "not (test_register_kv_caches and FLASH_ATTN and True)" pytest -v -s v1/test_serial_utils.py ' diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index eb331aaf9d43..a4a8778fe620 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -1573,7 +1573,7 @@ steps: - tests/compile/fullgraph/test_basic_correctness.py - examples/offline_inference/rlhf.py - examples/offline_inference/rlhf_colocate.py - - examples/offline_inference/new_weight_syncing/ + - examples/rl/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed - tests/v1/engine/test_engine_core_client.py @@ -1615,7 +1615,7 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd # NEW rlhf examples - - pushd ../examples/offline_inference/new_weight_syncing + - pushd ../examples/rl - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py @@ -2660,7 +2660,7 @@ steps: - tests/v1/entrypoints/openai/test_multi_api_servers.py - tests/v1/shutdown - tests/v1/worker/test_worker_memory_snapshot.py - - examples/offline_inference/new_weight_syncing/ + - examples/rl/ commands: # Work around HIP bug tracked here: https://github.com/ROCm/hip/issues/3876 # TODO: Remove when the bug is fixed in a future ROCm release @@ -3325,7 +3325,7 @@ steps: - tests/compile/fullgraph/test_basic_correctness.py - examples/offline_inference/rlhf.py - 
examples/offline_inference/rlhf_colocate.py - - examples/offline_inference/new_weight_syncing/ + - examples/rl/ - tests/examples/offline_inference/data_parallel.py - tests/v1/distributed - tests/v1/engine/test_engine_core_client.py @@ -3367,7 +3367,7 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd # NEW rlhf examples - - pushd ../examples/offline_inference/new_weight_syncing + - pushd ../examples/rl - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index 331103ceebd1..03ffc5a274a3 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -82,7 +82,7 @@ steps: - label: Distributed Torchrun + Examples (4 GPUs) timeout_in_minutes: 30 - working_dir: "/vllm-workspace/tests" + working_dir: "/vllm-workspace" num_devices: 4 source_file_dependencies: - vllm/distributed/ @@ -90,33 +90,28 @@ steps: - tests/distributed/test_torchrun_example_moe.py - examples/offline_inference/rlhf.py - examples/offline_inference/rlhf_colocate.py - - examples/offline_inference/new_weight_syncing/ + - examples/rl/ - tests/examples/offline_inference/data_parallel.py commands: # https://github.com/NVIDIA/nccl/issues/1838 - export NCCL_CUMEM_HOST_ENABLE=0 # test with torchrun tp=2 and external_dp=2 - - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + - torchrun --nproc-per-node=4 tests/distributed/test_torchrun_example.py # test with torchrun tp=2 and pp=2 - - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + - PP_SIZE=2 torchrun --nproc-per-node=4 tests/distributed/test_torchrun_example.py # test with torchrun tp=4 and dp=1 - - TP_SIZE=4 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + - TP_SIZE=4 torchrun 
--nproc-per-node=4 tests/distributed/test_torchrun_example_moe.py # test with torchrun tp=2, pp=2 and dp=1 - - PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + - PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 tests/distributed/test_torchrun_example_moe.py # test with torchrun tp=1 and dp=4 with ep - - DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + - DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 tests/distributed/test_torchrun_example_moe.py # test with torchrun tp=2 and dp=2 with ep - - TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + - TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 tests/distributed/test_torchrun_example_moe.py # test with internal dp - - python3 ../examples/offline_inference/data_parallel.py --enforce-eager - # OLD rlhf examples - - cd ../examples/offline_inference - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - # NEW rlhf examples - - cd new_weight_syncing - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py + - python3 examples/offline_inference/data_parallel.py --enforce-eager + # rlhf examples + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 examples/rl/rlhf_nccl.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 examples/rl/rlhf_ipc.py - label: Distributed DP Tests (4 GPUs) timeout_in_minutes: 30 diff --git a/.buildkite/test_areas/expert_parallelism.yaml b/.buildkite/test_areas/expert_parallelism.yaml index 1443d847eaf5..63404fc5df66 100644 --- a/.buildkite/test_areas/expert_parallelism.yaml +++ b/.buildkite/test_areas/expert_parallelism.yaml @@ -24,8 +24,7 @@ steps: - label: Elastic EP Scaling Test timeout_in_minutes: 20 - device: b200 - optional: true + device: h100 working_dir: "/vllm-workspace/tests" num_devices: 4 
source_file_dependencies: diff --git a/CMakeLists.txt b/CMakeLists.txt index bbadfdc5e9e3..693070b5f476 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -999,6 +999,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu" "csrc/moe/grouped_topk_kernels.cu" + "csrc/moe/gpt_oss_router_gemm.cu" "csrc/moe/router_gemm.cu") endif() diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index cf49232fd72d..515406aa9ce0 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -750,17 +750,20 @@ def get_weight_block_size_safety(config, default_value=None): def get_model_params(config): - if config.architectures[0] == "DbrxForCausalLM": + architectures = getattr(config, "architectures", None) or [type(config).__name__] + architecture = architectures[0] + + if architecture == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k intermediate_size = config.ffn_config.ffn_hidden_size hidden_size = config.hidden_size - elif config.architectures[0] == "JambaForCausalLM": + elif architecture == "JambaForCausalLM": E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.intermediate_size hidden_size = config.hidden_size - elif config.architectures[0] in ( + elif architecture in ( "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM", @@ -774,7 +777,7 @@ def get_model_params(config): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size hidden_size = config.hidden_size - elif config.architectures[0] in ( + elif architecture in ( "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM", "Qwen3NextForCausalLM", @@ -783,23 +786,27 @@ def get_model_params(config): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size hidden_size = config.hidden_size - elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration": + elif architecture in ( + 
"Qwen3VLMoeForConditionalGeneration", + "Qwen3_5MoeForConditionalGeneration", + "Qwen3_5MoeTextConfig", + ): text_config = config.get_text_config() E = text_config.num_experts topk = text_config.num_experts_per_tok intermediate_size = text_config.moe_intermediate_size hidden_size = text_config.hidden_size - elif config.architectures[0] == "HunYuanMoEV1ForCausalLM": + elif architecture == "HunYuanMoEV1ForCausalLM": E = config.num_experts topk = config.moe_topk[0] intermediate_size = config.moe_intermediate_size[0] hidden_size = config.hidden_size - elif config.architectures[0] == "Qwen3OmniMoeForConditionalGeneration": + elif architecture == "Qwen3OmniMoeForConditionalGeneration": E = config.thinker_config.text_config.num_experts topk = config.thinker_config.text_config.num_experts_per_tok intermediate_size = config.thinker_config.text_config.moe_intermediate_size hidden_size = config.thinker_config.text_config.hidden_size - elif config.architectures[0] == "PixtralForConditionalGeneration": + elif architecture == "PixtralForConditionalGeneration": # Pixtral can contain different LLM architectures, # recurse to get their parameters return get_model_params(config.get_text_config()) @@ -814,6 +821,23 @@ def get_model_params(config): return E, topk, intermediate_size, hidden_size +def resolve_dtype(config) -> torch.dtype: + if current_platform.is_rocm(): + return torch.float16 + + dtype = getattr(config, "dtype", None) + if dtype is not None: + return dtype + + if hasattr(config, "get_text_config"): + text_config = config.get_text_config() + dtype = getattr(text_config, "dtype", None) + if dtype is not None: + return dtype + + return torch.bfloat16 + + def get_quantization_group_size(config) -> int | None: """Extract the quantization group size from the HF model config. 
@@ -861,7 +885,7 @@ def main(args: argparse.Namespace): else: ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size") shard_intermediate_size = 2 * intermediate_size // args.tp_size - dtype = torch.float16 if current_platform.is_rocm() else config.dtype + dtype = resolve_dtype(config) use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" use_int4_w4a16 = args.dtype == "int4_w4a16" diff --git a/benchmarks/kernels/benchmark_router_gemm.py b/benchmarks/kernels/benchmark_router_gemm.py new file mode 100644 index 000000000000..cc63f8904c27 --- /dev/null +++ b/benchmarks/kernels/benchmark_router_gemm.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.transformers_utils.config import get_config +from vllm.triton_utils import triton +from vllm.utils.argparse_utils import FlexibleArgumentParser + +# Dimensions supported by the DSV3 specialized kernel +DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] +DSV3_SUPPORTED_HIDDEN_SIZES = [7168] + +# Dimensions supported by the gpt-oss specialized kernel +GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128] +GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880] + + +def get_batch_size_range(max_batch_size): + return [2**x for x in range(14) if 2**x <= max_batch_size] + + +def get_model_params(config): + if config.architectures[0] in ( + "DeepseekV2ForCausalLM", + "DeepseekV3ForCausalLM", + "DeepseekV32ForCausalLM", + ): + num_experts = config.n_routed_experts + hidden_size = config.hidden_size + elif config.architectures[0] in ("GptOssForCausalLM",): + num_experts = config.num_local_experts + hidden_size = config.hidden_size + else: + raise ValueError(f"Unsupported architecture: {config.architectures}") + return num_experts, hidden_size + + +def get_benchmark(model, max_batch_size, 
trust_remote_code): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=get_batch_size_range(max_batch_size), + x_log=False, + line_arg="provider", + line_vals=[ + "torch", + "vllm", + ], + line_names=["PyTorch", "vLLM"], + styles=([("blue", "-"), ("red", "-")]), + ylabel="TFLOPs", + plot_name=f"{model} router gemm throughput", + args={}, + ) + ) + def benchmark(batch_size, provider): + config = get_config(model=model, trust_remote_code=trust_remote_code) + num_experts, hidden_size = get_model_params(config) + + mat_a = torch.randn( + (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda" + ).contiguous() + mat_b = torch.randn( + (num_experts, hidden_size), dtype=torch.bfloat16, device="cuda" + ).contiguous() + bias = torch.randn( + num_experts, dtype=torch.bfloat16, device="cuda" + ).contiguous() + + is_hopper_or_blackwell = current_platform.is_device_capability( + 90 + ) or current_platform.is_device_capability_family(100) + allow_dsv3_router_gemm = ( + is_hopper_or_blackwell + and num_experts in DSV3_SUPPORTED_NUM_EXPERTS + and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES + ) + allow_gpt_oss_router_gemm = ( + is_hopper_or_blackwell + and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS + and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES + ) + + has_bias = False + if allow_gpt_oss_router_gemm: + has_bias = True + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch": + + def runner(): + if has_bias: + F.linear(mat_a, mat_b, bias) + else: + F.linear(mat_a, mat_b) + elif provider == "vllm": + + def runner(): + if allow_dsv3_router_gemm: + ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16) + elif allow_gpt_oss_router_gemm: + ops.gpt_oss_router_gemm(mat_a, mat_b, bias) + else: + raise ValueError("Unsupported router gemm") + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + runner, quantiles=quantiles + ) + + def tflops(t_ms): + flops = 2 * batch_size * hidden_size * num_experts + return flops / (t_ms * 1e-3) / 
1e12 + + return tflops(ms), tflops(max_ms), tflops(min_ms) + + return benchmark + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--model", type=str, default="openai/gpt-oss-20b") + parser.add_argument("--max-batch-size", default=16, type=int) + parser.add_argument("--trust-remote-code", action="store_true") + args = parser.parse_args() + + # Get the benchmark function + benchmark = get_benchmark(args.model, args.max_batch_size, args.trust_remote_code) + # Run performance benchmark + benchmark.run(print_data=True) diff --git a/csrc/moe/gpt_oss_router_gemm.cu b/csrc/moe/gpt_oss_router_gemm.cu new file mode 100644 index 000000000000..0294cd36aa8f --- /dev/null +++ b/csrc/moe/gpt_oss_router_gemm.cu @@ -0,0 +1,144 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include "gpt_oss_router_gemm.cuh" + +void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB, + __nv_bfloat16* gC, __nv_bfloat16* bias, + int batch_size, int output_features, + int input_features, cudaStream_t stream) { + static int const WARP_TILE_M = 16; + static int const TILE_M = WARP_TILE_M; + static int const TILE_N = 8; + static int const TILE_K = 64; + static int const STAGES = 16; + static int const STAGE_UNROLL = 4; + static bool const PROFILE = false; + + CUtensorMap weight_map{}; + CUtensorMap activation_map{}; + + constexpr uint32_t rank = 2; + uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features}; + uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)}; + uint32_t box_size[rank] = {TILE_K, TILE_M}; + uint32_t elem_stride[rank] = {1, 1}; + + CUresult res = cuTensorMapEncodeTiled( + &weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank, + gB, size, stride, box_size, elem_stride, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + TORCH_CHECK(res == CUDA_SUCCESS, + "cuTensorMapEncodeTiled failed for weight_map, error code=", + static_cast(res)); + + size[1] = batch_size; + box_size[1] = TILE_N; + + res = cuTensorMapEncodeTiled( + &activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + rank, gA, size, stride, box_size, elem_stride, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + TORCH_CHECK(res == CUDA_SUCCESS, + "cuTensorMapEncodeTiled failed for activation_map, error code=", + static_cast(res)); + + int smem_size = STAGES * STAGE_UNROLL * + (TILE_M * TILE_K 
* sizeof(__nv_bfloat16) + + TILE_N * TILE_K * sizeof(__nv_bfloat16)); + + gpuErrChk(cudaFuncSetAttribute( + gpt_oss_router_gemm_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + int tiles_m = (output_features + TILE_M - 1) / TILE_M; + int tiles_n = (batch_size + TILE_N - 1) / TILE_N; + + dim3 grid(tiles_m, tiles_n); + dim3 block(384); + + cudaLaunchConfig_t config; + cudaLaunchAttribute attrs[1]; + config.gridDim = grid; + config.blockDim = block; + config.dynamicSmemBytes = smem_size; + config.stream = stream; + config.attrs = attrs; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = 1; + config.numAttrs = 1; + + cudaLaunchKernelEx( + &config, + &gpt_oss_router_gemm_kernel, + gC, gA, gB, bias, output_features, batch_size, input_features, weight_map, + activation_map, nullptr); +} + +void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output, + torch::Tensor input, torch::Tensor weight, + torch::Tensor bias) { + auto const batch_size = input.size(0); + auto const input_dim = input.size(1); + auto const output_dim = weight.size(0); + + auto stream = at::cuda::getCurrentCUDAStream(); + + if (input.scalar_type() == at::ScalarType::BFloat16) { + launch_gpt_oss_router_gemm((__nv_bfloat16*)input.data_ptr(), + (__nv_bfloat16*)weight.data_ptr(), + (__nv_bfloat16*)output.mutable_data_ptr(), + (__nv_bfloat16*)bias.data_ptr(), batch_size, + output_dim, input_dim, stream); + } else { + throw std::invalid_argument("Unsupported dtype, only supports bfloat16"); + } +} + +void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input, + torch::Tensor weight, torch::Tensor bias) { + TORCH_CHECK(input.dim() == 2, "input must be 2D"); + TORCH_CHECK(weight.dim() == 2, "weight must be 2D"); + TORCH_CHECK(bias.dim() == 1, "bias must be 1D"); + TORCH_CHECK(input.sizes()[1] == weight.sizes()[1], + "input.size(1) must match weight.size(1)"); + TORCH_CHECK(weight.sizes()[0] == 
bias.sizes()[0], + "weight.size(0) must match bias.size(0)"); + TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16, + "input tensor must be bfloat16"); + TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16, + "weight tensor must be bfloat16"); + TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16, + "bias tensor must be bfloat16"); + gpt_oss_router_gemm_cuda_forward(output, input, weight, bias); +} diff --git a/csrc/moe/gpt_oss_router_gemm.cuh b/csrc/moe/gpt_oss_router_gemm.cuh new file mode 100644 index 000000000000..5cc653f19cfb --- /dev/null +++ b/csrc/moe/gpt_oss_router_gemm.cuh @@ -0,0 +1,447 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "cuda_bf16.h" +#include +#include +#include + +#include "cuda_pipeline.h" +#include +#include +#include +#include + +using barrier = cuda::barrier; +namespace cde = cuda::device::experimental; +namespace ptx = cuda::ptx; + +#define gpuErrChk(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + } + +inline void gpuAssert(cudaError_t code, char const* file, int line, + bool abort = true) { + if (code != cudaSuccess) { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, + line); + if (abort) { + throw std::runtime_error(cudaGetErrorString(code)); + } + } +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +__device__ uint64_t gclock64() { + unsigned long long int rv; + asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv)); + return rv; +} + +__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) { + int dst; + asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(smem_ptr)); + int* rvi = reinterpret_cast(&rv[0]); + rvi[0] = dst; +} + +__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) { + int x, y; + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(x), "=r"(y) + : "r"(smem_ptr)); + + int* rvi = reinterpret_cast(&rv[0]); + rvi[0] = x; + rvi[1] = y; +} + +__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) { + int x, y, z, w; + asm volatile( + "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(x), "=r"(y), "=r"(z), "=r"(w) + : "r"(smem_ptr)); + int* rvi = reinterpret_cast(&rv[0]); + rvi[0] = x; + rvi[1] = y; + rvi[2] = z; + rvi[3] = w; +} + +__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2], + float c[4]) { + uint32_t const* A = reinterpret_cast(&a[0]); + uint32_t const* B = reinterpret_cast(&b[0]); + float const* C = reinterpret_cast(&c[0]); + float* D = reinterpret_cast(&d[0]); + + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " + 
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), + "f"(C[3])); +} + +__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4], + float c[4]) { + uint32_t const* A = reinterpret_cast(&a[0]); + uint32_t const* B = reinterpret_cast(&b[0]); + float const* C = reinterpret_cast(&c[0]); + float* D = reinterpret_cast(&d[0]); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); +} + +__device__ void bar_wait(uint32_t bar_ptr, int phase) { + asm volatile( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" ::"r"(bar_ptr), + "r"(phase)); +} + +__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) { + uint32_t success; + #ifdef INTERNAL + asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory"); + #endif + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(success) + : "r"(bar_ptr), "r"(phase)); + return success; +} + +__device__ uint32_t elect_one_sync() { + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +} +#endif + +struct Profile { + uint64_t start; + uint64_t weight_load_start; + uint64_t act_load_start; + uint64_t compute_start; + uint64_t complete; +}; + +template +__global__ __launch_bounds__(384, 1) void 
gpt_oss_router_gemm_kernel( + __nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations, + __nv_bfloat16* bias, int M, int N, int K, + const __grid_constant__ CUtensorMap weight_map, + const __grid_constant__ CUtensorMap activation_map, + Profile* profile = nullptr) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + + if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0) + profile[blockIdx.x].start = gclock64(); + + extern __shared__ __align__(128) char smem[]; + + __nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0]; + __nv_bfloat16* sh_activations = + (__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K * + sizeof(__nv_bfloat16)]; + + #pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ barrier bar_wt_ready[STAGES]; + __shared__ barrier bar_act_ready[STAGES]; + __shared__ barrier bar_data_consumed[STAGES]; + + __shared__ float4 reduction_buffer[128]; + + __shared__ nv_bfloat16 sh_bias[TILE_M]; + + if (threadIdx.x == 0) { + for (int i = 0; i < STAGES; i++) { + init(&bar_wt_ready[i], 1); + init(&bar_act_ready[i], 1); + init(&bar_data_consumed[i], 32); + } + ptx::fence_proxy_async(ptx::space_shared); + asm volatile("prefetch.tensormap [%0];" + : + : "l"(reinterpret_cast(&weight_map)) + : "memory"); + asm volatile("prefetch.tensormap [%0];" + : + : "l"(reinterpret_cast(&activation_map)) + : "memory"); + } + __syncthreads(); + + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + + int phase = 0; + + int mib = blockIdx.x * TILE_M; + int ni = blockIdx.y * TILE_N; + + float accum[4]; + for (int i = 0; i < 4; i++) accum[i] = 0.f; + + int const K_LOOPS_DMA = + (K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL)); + int const K_LOOPS_COMPUTE = K_LOOPS_DMA; + + // Data loading thread + if (warp_id >= 4 && elect_one_sync()) { + int stage = warp_id % 4; + + bool weight_warp = warp_id < 8; + if (!weight_warp) { + cudaGridDependencySynchronize(); + cudaTriggerProgrammaticLaunchCompletion(); + } + + for 
(int ki = 0; ki < K_LOOPS_DMA; ki++) { + int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL; + + uint64_t desc_ptr_wt = reinterpret_cast(&weight_map); + uint64_t desc_ptr_act = reinterpret_cast(&activation_map); + + uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]); + uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]); + int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16); + int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16); + + bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); + + if (weight_warp) + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + : + : "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt)); + if (!weight_warp) + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + : + : "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act)); + + if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp) + profile[blockIdx.x].weight_load_start = gclock64(); + if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp) + profile[blockIdx.x].act_load_start = gclock64(); + + for (int i = 0; i < STAGE_UNROLL; i++) { + uint32_t smem_ptr_wt = __cvta_generic_to_shared( + &sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]); + uint32_t crd0 = k + i * TILE_K; + uint32_t crd1 = mib; + if (weight_warp) + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_" + "tx::bytes [%0], [%1, {%3,%4}], " + "[%2];" + : + : "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0), + "r"(crd1) + : "memory"); + + uint32_t smem_ptr_act = __cvta_generic_to_shared( + &sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]); + crd0 = k + i * TILE_K; + crd1 = ni; + if (!weight_warp) + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_" + "tx::bytes [%0], [%1, {%3,%4}], " + "[%2];" + : + : "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act), + "r"(crd0), "r"(crd1) + : "memory"); + } + + stage += 4; + if (stage >= STAGES) { + 
stage = warp_id % 4; + phase ^= 1; + } + } + // Wait for pending loads to be consumed before exiting, to avoid race + for (int i = 0; i < (STAGES / 4) - 1; i++) { + bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1); + stage += 4; + if (stage >= STAGES) { + stage = warp_id % 4; + phase ^= 1; + } + } + } + // Compute threads + else if (warp_id < 4) { + // Sneak the bias load into the compute warps since they're just waiting for + // stuff anyway + if (threadIdx.x < TILE_M) sh_bias[threadIdx.x] = bias[mib + threadIdx.x]; + + int stage = warp_id; + + int phase = 0; + int lane_id_div8 = lane_id / 8; + int lane_id_mod8 = lane_id % 8; + + int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0; + int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0; + + int row_wt = lane_id_mod8 + lane_row_offset_wt; + int row_act = lane_id_mod8; + + int row_offset_wt = (reinterpret_cast(sh_weights) / 128) % 8; + int row_offset_act = row_offset_wt; + + uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]); + uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]); + + bool weight_ready = bar_try_wait(bar_ptr_wt, phase); + bool act_ready = bar_try_wait(bar_ptr_act, phase); + + #pragma unroll 2 + for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) { + int next_stage = stage + 4; + int next_phase = phase; + if (next_stage >= STAGES) { + next_stage = warp_id; + next_phase ^= 1; + } + + while (!weight_ready || !act_ready) { + weight_ready = bar_try_wait(bar_ptr_wt, phase); + act_ready = bar_try_wait(bar_ptr_act, phase); + } + + if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0) + profile[blockIdx.x].compute_start = gclock64(); + + if (ki + 1 < K_LOOPS_COMPUTE) { + weight_ready = bar_try_wait( + __cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase); + act_ready = bar_try_wait( + __cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase); + } + + #pragma unroll + for (int su = 0; su < STAGE_UNROLL; su++) { + __nv_bfloat16* 
ptr_weights = + &sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K]; + __nv_bfloat16* ptr_act = + &sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K]; + + #pragma unroll + for (int kii = 0; kii < TILE_K / 16; kii++) { + __nv_bfloat16 a[8]; + __nv_bfloat16 b[4]; + + int col = 2 * kii + lane_col_offset_wt; + int col_sw = ((row_wt + row_offset_wt) % 8) ^ col; + + ldmatrix4(a, __cvta_generic_to_shared( + &ptr_weights[row_wt * TILE_K + col_sw * 8])); + + col = 2 * kii + lane_id_div8; + col_sw = ((row_act + row_offset_act) % 8) ^ col; + + ldmatrix2(b, __cvta_generic_to_shared( + &ptr_act[row_act * TILE_K + 8 * col_sw])); + + HMMA_16816(accum, a, b, accum); + } + } + + uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]); + asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c)); + + stage = next_stage; + phase = next_phase; + } + + float4 accum4; + accum4.x = accum[0]; + accum4.y = accum[1]; + accum4.z = accum[2]; + accum4.w = accum[3]; + reduction_buffer[threadIdx.x] = accum4; + + __syncthreads(); + + if (warp_id == 0) { + int mi = mib + warp_id * WARP_TILE_M; + int tm = mi + lane_id / 4; + int tn = ni + 2 * (lane_id % 4); + + float4 accum1 = reduction_buffer[32 + threadIdx.x]; + float4 accum2 = reduction_buffer[64 + threadIdx.x]; + float4 accum3 = reduction_buffer[96 + threadIdx.x]; + + accum[0] = accum[0] + accum1.x + accum2.x + accum3.x; + accum[1] = accum[1] + accum1.y + accum2.y + accum3.y; + accum[2] = accum[2] + accum1.z + accum2.z + accum3.z; + accum[3] = accum[3] + accum1.w + accum2.w + accum3.w; + + float bias_lo = __bfloat162float(sh_bias[tm - mib]); + float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]); + + if (tn < N && tm < M) + output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo); + if (tn + 1 < N && tm < M) + output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo); + if (tn < N && tm + 8 < M) + output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi); + if (tn + 1 < N && tm + 8 < 
M) + output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi); + + if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0) + profile[blockIdx.x].complete = gclock64(); + } + } +#endif // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index d8d962887dab..de931dc76467 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -70,4 +70,8 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input, // Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168 void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const torch::Tensor& mat_b); + +// gpt-oss optimized router GEMM kernel for SM90+ +void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input, + torch::Tensor weight, torch::Tensor bias); #endif diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 7b627a6f8760..4cd74366ea4d 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -132,6 +132,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // DeepSeek V3 optimized router GEMM for SM90+ m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); // conditionally compiled so impl registration is in source file + + // gpt-oss optimized router GEMM kernel for SM90+ + m.def( + "gpt_oss_router_gemm(Tensor! 
output, Tensor input, Tensor weights, " + "Tensor bias) -> ()"); + m.impl("gpt_oss_router_gemm", torch::kCUDA, &gpt_oss_router_gemm); #endif } diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index e886a91e6573..194db05e395e 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -23,15 +23,18 @@ def title(text: str) -> str: # Custom substitutions subs = { "io": "IO", - "api": "API", + "rl": "RL", + "api(s?)": r"API\1", "cli": "CLI", "cpu": "CPU", + "ipc": "IPC", "llm": "LLM", "mae": "MAE", "ner": "NER", "tpu": "TPU", "gguf": "GGUF", "lora": "LoRA", + "nccl": "NCCL", "rlhf": "RLHF", "vllm": "vLLM", "openai": "OpenAI", @@ -196,6 +199,11 @@ def generate(self) -> str: def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): + # Monkey-patch dirname_to_title in awesome-nav so that sub-directory names are + # title-cased (e.g. "Offline Inference" instead of "Offline inference"). + import mkdocs_awesome_nav.nav.directory as _nav_dir + + _nav_dir.dirname_to_title = title logger.info("Generating example documentation") logger.debug("Root directory: %s", ROOT_DIR.resolve()) logger.debug("Example directory: %s", EXAMPLE_DIR.resolve()) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index dea60155ac02..f36f74308c88 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -707,7 +707,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | | `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | T + I+ + V+ | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | | | `HCXVisionV2ForCausalLM` | HyperCLOVAX-SEED-Think-32B | T + I+ + V+ | `naver-hyperclovax/HyperCLOVAX-SEED-Think-32B` | | | -| `H2OVLChatModel` | H2OVL | T + IE+ | 
`h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | +| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎ | | `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + IE+ | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | | `IsaacForConditionalGeneration` | Isaac | T + I+ | `PerceptronAI/Isaac-0.1` | ✅︎ | ✅︎ | diff --git a/docs/training/async_rl.md b/docs/training/async_rl.md new file mode 100644 index 000000000000..172466f89039 --- /dev/null +++ b/docs/training/async_rl.md @@ -0,0 +1,63 @@ +# Async Reinforcement Learning + +## Overview + +In a standard RL training loop, generation and training happen sequentially: the policy generates rollouts, then training runs on those rollouts, and the cycle repeats. During generation the training accelerators sit idle, and vice versa. + +The **one-off pipelining** approach separates the generation and training phases into two parallel coroutines, allowing the model to generate new samples while simultaneously training on previously generated data. This can lead to better GPU utilization and greater training throughput. + +However, this overlap introduces a complication: weights must be updated in the inference engine mid-flight, while requests may still be in progress. + +## The Pause and Resume API + +To safely update weights while the inference engine is running, vLLM provides `pause_generation` and `resume_generation` methods. These let the trainer coordinate a clean window for weight synchronization without losing in-flight work. 
+ +### pause_generation + +```python +await engine.pause_generation(mode="keep", clear_cache=True) +``` + +The `mode` parameter controls how in-flight requests are handled: + +| Mode | Behavior | +| ---- | -------- | +| `"abort"` | Abort all in-flight requests immediately and return partial results (default) | +| `"wait"` | Wait for all in-flight requests to finish before pausing | +| `"keep"` | Freeze requests in the queue; they resume when `resume_generation` is called | + +The `clear_cache` parameter controls whether to clear the KV cache and prefix cache after pausing. + +### resume_generation + +```python +await engine.resume_generation() +``` + +Resumes the scheduler after a pause. Any requests frozen with `mode="keep"` will continue generating. + +### HTTP Endpoints + +When using the vLLM HTTP server, the same functionality is available via: + +- `POST /pause?mode=keep` - Pause generation +- `POST /resume` - Resume generation + +!!! note "Data Parallelism" + When using data parallelism with vLLM's **internal load balancer** (i.e. `data_parallel_backend="ray"`), pause and resume are handled automatically across all DP ranks -- a single call is sufficient. When using an **external load balancer** (i.e. multiple independent vLLM instances behind a proxy), you must send pause and resume requests to **every** engine instance individually before and after the weight update. + +## Typical Async RL Flow + +A typical async RL loop with weight syncing looks like this: + +1. Start generating rollouts from the current policy +2. Once the trainer has new weights to update to, pause generation with `mode="keep"` +3. Sync the updated weights from the trainer to the inference engine (see [Weight Transfer](weight_transfer/README.md)) +4. Resume generation -- in-flight requests continue with the new weights +5.
Repeat + +The key insight is that requests paused with `mode="keep"` will produce tokens from the **old** weights before the pause and tokens from the **new** weights after resume. The `clear_cache` parameter controls whether the KV cache is invalidated during the pause. When `clear_cache=True`, previously cached key-value entries are discarded, so all tokens generated after resume will be computed entirely with the new weights. When `clear_cache=False`, existing KV cache entries are retained, meaning some tokens in context may still reflect the old weights (stale KV cache). + +## Example + +The [async RLHF example](../examples/rl/rlhf_async_new_apis.md) demonstrates this pattern with `vllm.AsyncLLMEngine`, NCCL weight transfer, and mid-flight pause/resume with validation. diff --git a/docs/training/rlhf.md b/docs/training/rlhf.md index 0b7e384dc8d6..3eddd4fbecfb 100644 --- a/docs/training/rlhf.md +++ b/docs/training/rlhf.md @@ -16,11 +16,9 @@ The following open-source RL libraries use vLLM for fast rollouts (sorted alphab - [Unsloth](https://github.com/unslothai/unsloth) - [verl](https://github.com/volcengine/verl) -See the following basic examples to get started if you don't want to use an existing library: +For weight synchronization between training and inference, see the [Weight Transfer](weight_transfer/README.md) documentation, which covers the pluggable backend system with [NCCL](weight_transfer/nccl.md) (multi-GPU) and [IPC](weight_transfer/ipc.md) (same-GPU) engines. 
-- [Training and inference processes are located on separate GPUs (inspired by OpenRLHF)](../examples/offline_inference/rlhf.md) -- [Training and inference processes are colocated on the same GPUs using Ray](../examples/offline_inference/rlhf_colocate.md) -- [Utilities for performing RLHF with vLLM](../examples/offline_inference/rlhf_utils.md) +For pipelining generation and training to improve GPU utilization and throughput, see the [Async Reinforcement Learning](async_rl.md) guide, which covers the pause/resume API for safely updating weights mid-flight. See the following notebooks showing how to use vLLM for GRPO: diff --git a/docs/training/weight_transfer/README.md b/docs/training/weight_transfer/README.md new file mode 100644 index 000000000000..17afd2bc8965 --- /dev/null +++ b/docs/training/weight_transfer/README.md @@ -0,0 +1,78 @@ +# Weight Transfer + +vLLM provides a pluggable weight transfer system for synchronizing model weights from a training process to the inference engine during reinforcement learning (RL) workflows. This is essential for RLHF, GRPO, and other online RL methods where the policy model is iteratively updated during training and the updated weights must be reflected in the inference engine for rollout generation. + +## Architecture + +The weight transfer system follows a **two-phase protocol** with a pluggable backend design: + +1. **Initialization** (`init_weight_transfer_engine`): Establishes the communication channel between the trainer and inference workers. Called once before the training loop begins. +2. **Weight Update** (`update_weights`): Transfers updated weights from the trainer to the inference engine. Called after each training step (or batch of steps). 
+ +## Available Backends + +| Backend | Transport | Use Case | +| ------- | --------- | -------- | +| [NCCL](nccl.md) | NCCL broadcast | Separate GPUs for training and inference | +| [IPC](ipc.md) | CUDA IPC handles | Colocated training and inference on same GPU | + +## Configuration + +Specify the weight transfer backend through `WeightTransferConfig`. The backend determines which engine handles the weight synchronization. + +### Programmatic (Offline Inference) + +```python +from vllm import LLM +from vllm.config import WeightTransferConfig + +llm = LLM( + model="my-model", + weight_transfer_config=WeightTransferConfig(backend="nccl"), # or "ipc" +) +``` + +### CLI (Online Serving) + +```bash +vllm serve my-model \ + --weight-transfer-config '{"backend": "nccl"}' +``` + +The `backend` field accepts `"nccl"` (default) or `"ipc"`. + +## API Endpoints + +When running vLLM as an HTTP server, the following endpoints are available for weight transfer: + +| Endpoint | Method | Description | +| -------- | ------ | ----------- | +| `/init_weight_transfer_engine` | POST | Initialize the weight transfer engine with backend-specific info | +| `/update_weights` | POST | Trigger a weight update with backend-specific metadata | +| `/pause` | POST | Pause generation before weight sync to handle inflight requests | +| `/resume` | POST | Resume generation after weight sync | +| `/get_world_size` | GET | Get the number of inference workers (useful for NCCL world size calculation) | + +!!! note + The HTTP weight transfer endpoints require `VLLM_SERVER_DEV_MODE=1` to be set. + +## Trainer-Side API + +Both backends provide static methods that the trainer calls to send weights. The general pattern is: + +```python +# 1. Initialize the transfer engine (backend-specific) +EngineClass.trainer_init(init_info) + +# 2. 
Send weights to inference workers +EngineClass.trainer_send_weights( + iterator=model.named_parameters(), + trainer_args=backend_specific_args, +) +``` + +See the [NCCL](nccl.md) and [IPC](ipc.md) pages for backend-specific trainer APIs and full examples. + +## Extending the System + +The weight transfer system is designed to be extensible. You can implement custom backends by subclassing `WeightTransferEngine` and registering them with the factory. See the [Base Class](base.md) page for details. diff --git a/docs/training/weight_transfer/base.md b/docs/training/weight_transfer/base.md new file mode 100644 index 000000000000..973ec8ad9f55 --- /dev/null +++ b/docs/training/weight_transfer/base.md @@ -0,0 +1,162 @@ +# Base Class and Custom Engines + +The weight transfer system is built on an abstract base class that defines the contract between vLLM's worker infrastructure and the transport backend. You can implement custom backends by subclassing `WeightTransferEngine` and registering them with the `WeightTransferEngineFactory`. + +## WeightTransferEngine + +The `WeightTransferEngine` is a generic abstract class parameterized by two dataclass types: + +- **`TInitInfo`** (extends `WeightTransferInitInfo`): Backend-specific initialization parameters. +- **`TUpdateInfo`** (extends `WeightTransferUpdateInfo`): Backend-specific weight update metadata. 
+ +### Abstract Methods + +Subclasses must implement these four methods: + +| Method | Side | Description | +| ------ | ---- | ----------- | +| `init_transfer_engine(init_info)` | Inference | Initialize the communication channel on each inference worker | +| `receive_weights(update_info, load_weights)` | Inference | Receive weights and call `load_weights` incrementally | +| `shutdown()` | Inference | Clean up resources | +| `trainer_send_weights(iterator, trainer_args)` | Trainer | Static method to send weights from the trainer process | + +### Request Classes + +The API-level request classes provide backend-agnostic serialization using plain dictionaries. The engine's `parse_init_info` and `parse_update_info` methods convert these dictionaries into typed dataclasses. + +```python +from vllm.distributed.weight_transfer.base import ( + WeightTransferInitRequest, + WeightTransferUpdateRequest, +) + +# Init request (dict is converted to backend-specific TInitInfo) +init_request = WeightTransferInitRequest( + init_info={"master_address": "10.0.0.1", "master_port": 29500, ...} +) + +# Update request (dict is converted to backend-specific TUpdateInfo) +update_request = WeightTransferUpdateRequest( + update_info={"names": [...], "dtype_names": [...], "shapes": [...]} +) +``` + +### WeightTransferUpdateInfo + +The base `WeightTransferUpdateInfo` includes an `is_checkpoint_format` flag: + +```python +@dataclass +class WeightTransferUpdateInfo(ABC): + is_checkpoint_format: bool = True +``` + +When `is_checkpoint_format=True` (the default), vLLM applies layerwise weight processing (repacking, renaming, etc.) on the received weights before loading them. Set to `False` if the trainer has already converted weights to the kernel format expected by the model. + +## Implementing a Custom Engine + +To create a custom weight transfer backend: + +### 1. 
Define Info Dataclasses + +```python +from dataclasses import dataclass +from vllm.distributed.weight_transfer.base import ( + WeightTransferEngine, + WeightTransferInitInfo, + WeightTransferUpdateInfo, +) + +@dataclass +class MyInitInfo(WeightTransferInitInfo): + endpoint: str + token: str + +@dataclass +class MyUpdateInfo(WeightTransferUpdateInfo): + names: list[str] + dtype_names: list[str] + shapes: list[list[int]] + # Add custom fields as needed +``` + +### 2. Implement the Engine + +```python +from collections.abc import Callable, Iterator +from typing import Any +import torch + +class MyWeightTransferEngine(WeightTransferEngine[MyInitInfo, MyUpdateInfo]): + init_info_cls = MyInitInfo + update_info_cls = MyUpdateInfo + + def init_transfer_engine(self, init_info: MyInitInfo) -> None: + # Set up connection to trainer using init_info.endpoint, etc. + ... + + def receive_weights( + self, + update_info: MyUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + # Receive each weight and call load_weights incrementally + for name, dtype_name, shape in zip( + update_info.names, update_info.dtype_names, update_info.shapes + ): + dtype = getattr(torch, dtype_name) + weight = self._fetch_weight(name, shape, dtype) + load_weights([(name, weight)]) + + def shutdown(self) -> None: + # Clean up resources + ... + + @staticmethod + def trainer_send_weights( + iterator: Iterator[tuple[str, torch.Tensor]], + trainer_args: dict[str, Any], + ) -> None: + # Send weights from the trainer process + for name, tensor in iterator: + # Send tensor via custom transport + ... +``` + +!!! important + The `load_weights` callable passed to `receive_weights` should be called **incrementally** (one or a few weights at a time) rather than accumulating all weights first. This avoids GPU out-of-memory errors with large models. + +### 3. 
Register with the Factory + +```python +from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory + +# Option 1: Lazy loading (recommended for built-in engines) +WeightTransferEngineFactory.register_engine( + "my_backend", + "my_package.my_module", + "MyWeightTransferEngine", +) + +# Option 2: Direct class registration +WeightTransferEngineFactory.register_engine( + "my_backend", + MyWeightTransferEngine, +) +``` + +Once registered, users can select your backend via `WeightTransferConfig(backend="my_backend")`. + +## WeightTransferEngineFactory + +The factory uses a registry pattern with lazy loading. Built-in engines (`nccl` and `ipc`) are registered at import time but their modules are only loaded when the backend is actually requested. This avoids importing heavy dependencies (like NCCL communicators) when they aren't needed. + +```python +from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory + +# Create an engine from config +engine = WeightTransferEngineFactory.create_engine( + config=weight_transfer_config, + parallel_config=parallel_config, +) +``` diff --git a/docs/training/weight_transfer/ipc.md b/docs/training/weight_transfer/ipc.md new file mode 100644 index 000000000000..8e19fa7b429b --- /dev/null +++ b/docs/training/weight_transfer/ipc.md @@ -0,0 +1,73 @@ +# IPC Engine + +The IPC weight transfer engine uses **CUDA IPC** (Inter-Process Communication) handles to share GPU memory directly between the trainer and inference workers on the **same node and same GPU**. This avoids any data copying, making it an efficient option when colocating training and inference. + +## When to Use IPC + +- Training and inference on the **same GPU** (colocated) +- You want to minimize memory overhead by sharing tensors in-place + +## How It Works + +1. The trainer creates CUDA tensors for each weight and generates IPC handles using `torch.multiprocessing.reductions.reduce_tensor`. +2.
IPC handles are sent to the inference engine via **Ray.remote()** or **HTTP POST**. +3. The inference worker reconstructs the tensors from the handles, reading directly from the trainer's GPU memory. + +!!! warning + IPC handles involve sending serialized Python objects. When using HTTP transport, you must set `VLLM_ALLOW_INSECURE_SERIALIZATION=1` on both the server and client. This is because IPC handles are pickled and base64-encoded for HTTP transmission. + +## Initialization + +The IPC backend requires no initialization on either side. The `init_transfer_engine` call is a no-op for IPC. + +## Sending Weights + +IPC supports two transport modes for delivering the handles: + +### Ray Mode + +Used when vLLM is running as a Ray actor: + +```python +from vllm.distributed.weight_transfer.ipc_engine import ( + IPCTrainerSendWeightsArgs, + IPCWeightTransferEngine, +) + +trainer_args = IPCTrainerSendWeightsArgs( + mode="ray", + llm_handle=llm_actor_handle, +) + +IPCWeightTransferEngine.trainer_send_weights( + iterator=model.named_parameters(), + trainer_args=trainer_args, +) +``` + +In Ray mode, the engine calls `llm_handle.update_weights.remote(...)` directly, passing the IPC handles via Ray's serialization. + +### HTTP Mode + +Used when vLLM is running as an HTTP server: + +```python +trainer_args = IPCTrainerSendWeightsArgs( + mode="http", + url="http://localhost:8000", +) + +IPCWeightTransferEngine.trainer_send_weights( + iterator=model.named_parameters(), + trainer_args=trainer_args, +) +``` + +In HTTP mode, IPC handles are pickled, base64-encoded, and sent as JSON to the `/update_weights` endpoint. + +See [`IPCTrainerSendWeightsArgs`](https://github.com/vllm-project/vllm/blob/main/vllm/distributed/weight_transfer/ipc_engine.py) for the full list of configurable fields. 
+ +## Examples + +- [RLHF with IPC weight syncing (offline, Ray)](../../examples/rl/rlhf_ipc.md) - Colocated training and inference on a single GPU using Ray placement groups and CUDA IPC handles +- [RLHF with IPC weight syncing (online serving, HTTP)](../../examples/rl/rlhf_http_ipc.md) - Weight transfer with a vLLM HTTP server where both server and trainer share the same GPU diff --git a/docs/training/weight_transfer/nccl.md b/docs/training/weight_transfer/nccl.md new file mode 100644 index 000000000000..a50b3664d89d --- /dev/null +++ b/docs/training/weight_transfer/nccl.md @@ -0,0 +1,110 @@ +# NCCL Engine + +The NCCL weight transfer engine uses [NCCL](https://developer.nvidia.com/nccl) broadcast operations to transfer weights from the trainer to inference workers. It supports **multi-node** and **multi-GPU** setups where the trainer and inference engine run on separate GPUs. + +## When to Use NCCL + +- Training and inference on **separate GPUs** (possibly across nodes) +- **Tensor-parallel** inference with multiple workers that all need the updated weights +- You need high-bandwidth, low-latency weight transfer over NVLink or InfiniBand + +## How It Works + +1. The trainer and all inference workers join a shared NCCL process group using `StatelessProcessGroup` (vLLM's torch.distributed-independent group abstraction). +2. The trainer broadcasts weights to all workers simultaneously. Each worker receives and loads weights incrementally. +3. Optionally, **packed tensor broadcasting** batches multiple small tensors into larger buffers with double/triple buffering and CUDA stream overlap for higher throughput. This implementation is based on [NeMo-RL's packed tensor](https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/utils/packed_tensor.py). + +## Initialization + +NCCL requires explicit process group setup. The trainer and inference workers must agree on a master address, port, and world size. 
+ +### Inference Side + +```python +from vllm.distributed.weight_transfer.base import WeightTransferInitRequest + +# rank_offset accounts for the trainer occupying rank 0 +llm.init_weight_transfer_engine( + WeightTransferInitRequest( + init_info=dict( + master_address=master_address, + master_port=master_port, + rank_offset=1, + world_size=world_size, # trainer + all inference workers + ) + ) +) +``` + +### Trainer Side + +```python +from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLWeightTransferEngine, +) + +group = NCCLWeightTransferEngine.trainer_init( + dict( + master_address=master_address, + master_port=master_port, + world_size=world_size, + ) +) +``` + +!!! note + `trainer_init` always assigns the trainer to rank 0. Inference workers start at `rank_offset` (typically 1). + +## Sending Weights + +```python +from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLTrainerSendWeightsArgs, + NCCLWeightTransferEngine, +) + +trainer_args = NCCLTrainerSendWeightsArgs( + group=group, + packed=True, # use packed broadcasting for efficiency +) + +NCCLWeightTransferEngine.trainer_send_weights( + iterator=model.named_parameters(), + trainer_args=trainer_args, +) +``` + +See [`NCCLTrainerSendWeightsArgs`](https://github.com/vllm-project/vllm/blob/main/vllm/distributed/weight_transfer/nccl_engine.py) for the full list of configurable fields. + +### Packed Tensor Broadcasting + +When `packed=True`, multiple weight tensors are packed into large contiguous buffers before broadcasting. This reduces the number of NCCL operations and uses double/triple buffering with dedicated CUDA streams for overlap between packing, broadcasting, and unpacking. + +Both the trainer (`NCCLTrainerSendWeightsArgs`) and inference side (`NCCLWeightTransferUpdateInfo`) must use matching `packed_buffer_size_bytes` and `packed_num_buffers` values. 
+ +## Receiving Weights (Inference Side) + +The inference side triggers weight reception by calling `update_weights`: + +```python +from vllm.distributed.weight_transfer.base import WeightTransferUpdateRequest + +llm.update_weights( + WeightTransferUpdateRequest( + update_info=dict( + names=names, + dtype_names=dtype_names, + shapes=shapes, + packed=True, + ) + ) +) +``` + +The `names`, `dtype_names`, and `shapes` lists describe each parameter. These must match the order in which the trainer iterates over its parameters. + +## Examples + +- [RLHF with NCCL weight syncing (offline, Ray)](../../examples/rl/rlhf_nccl.md) - Trainer on one GPU, 2x tensor-parallel vLLM engine on two others, with packed NCCL weight broadcast +- [RLHF with async weight syncing (offline, Ray)](../../examples/rl/rlhf_async_new_apis.md) - Async generation with mid-flight pause, weight sync, resume, and validation against a fresh model +- [RLHF with NCCL weight syncing (online serving, HTTP)](../../examples/rl/rlhf_http_nccl.md) - Weight transfer with a running vLLM HTTP server using HTTP control plane and NCCL data plane diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py deleted file mode 100644 index 6f05968ce065..000000000000 --- a/examples/offline_inference/rlhf.py +++ /dev/null @@ -1,147 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray. - -The script separates training and inference workloads onto distinct GPUs -so that Ray can manage process placement and inter-process communication. -A Hugging Face Transformer model occupies GPU 0 for training, whereas a -tensor-parallel vLLM inference engine occupies GPU 1–2. - -The example performs the following steps: - -* Load the training model on GPU 0. -* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism - and Ray placement groups. 
-* Generate text from a list of prompts using the inference engine. -* Update the weights of the training model and broadcast the updated weights - to the inference engine by using a Ray collective RPC group. Note that - for demonstration purposes we simply zero out the weights. - -For a production-ready implementation that supports multiple training and -inference replicas, see the OpenRLHF framework: -https://github.com/OpenRLHF/OpenRLHF - -This example assumes a single-node cluster with three GPUs, but Ray -supports multi-node clusters. vLLM expects the GPUs are only used for vLLM -workloads. Residual GPU activity interferes with vLLM memory profiling and -causes unexpected behavior. -""" - -import os - -import ray -import torch -from ray.util.placement_group import placement_group -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from rlhf_utils import stateless_init_process_group -from transformers import AutoModelForCausalLM - -from vllm import LLM, SamplingParams -from vllm.utils.network_utils import get_ip, get_open_port - - -class MyLLM(LLM): - """Configure the vLLM worker for Ray placement group execution.""" - - def __init__(self, *args, **kwargs): - # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray - # so that vLLM can manage its own device placement within the worker. - os.environ.pop("CUDA_VISIBLE_DEVICES", None) - super().__init__(*args, **kwargs) - - -# Load the OPT-125M model onto GPU 0 for the training workload. -train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") -train_model.to("cuda:0") - -# Initialize Ray and set the visible devices. The vLLM engine will -# be placed on GPUs 1 and 2. -os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" -ray.init() - -# Create a placement group that reserves GPU 1–2 for the vLLM inference engine. 
-# Learn more about Ray placement groups: -# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html -pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) -ray.get(pg_inference.ready()) -scheduling_inference = PlacementGroupSchedulingStrategy( - placement_group=pg_inference, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=0, -) - -# Launch the vLLM inference engine. The `enforce_eager` flag reduces -# start-up latency. -llm = ray.remote( - num_cpus=0, - num_gpus=0, - scheduling_strategy=scheduling_inference, -)(MyLLM).remote( - model="facebook/opt-125m", - enforce_eager=True, - worker_extension_cls="rlhf_utils.WorkerExtension", - tensor_parallel_size=2, - distributed_executor_backend="ray", -) - -# Generate text from the prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] - -sampling_params = SamplingParams(temperature=0) - -outputs = ray.get(llm.generate.remote(prompts, sampling_params)) - -print("-" * 50) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) - -# Set up the communication channel between the training process and the -# inference engine. -master_address = get_ip() -master_port = get_open_port() - -handle = llm.collective_rpc.remote( - "init_weight_update_group", args=(master_address, master_port, 1, 3) -) - -model_update_group = stateless_init_process_group( - master_address, master_port, 0, 3, torch.device("cuda:0") -) -ray.get(handle) - -# Simulate a training step by zeroing out all model weights. -# In a real RLHF training loop the weights would be updated using the gradient -# from an RL objective such as PPO on a reward model. -for name, p in train_model.named_parameters(): - p.data.zero_() - -# Synchronize the updated weights to the inference engine. 
-for name, p in train_model.named_parameters(): - dtype_name = str(p.dtype).split(".")[-1] - handle = llm.collective_rpc.remote( - "update_weight", args=(name, dtype_name, p.shape) - ) - model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) - ray.get(handle) - -# Verify that the inference weights have been updated. -assert all(ray.get(llm.collective_rpc.remote("check_weights_changed"))) - -# Generate text with the updated model. The output is expected to be nonsense -# because the weights are zero. -outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) -print("-" * 50) -for output in outputs_updated: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py deleted file mode 100644 index ea4b3a6b911e..000000000000 --- a/examples/offline_inference/rlhf_colocate.py +++ /dev/null @@ -1,256 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Demonstrates how to co-locate a vLLM inference worker and training -actors on the same set of GPUs for reinforcement learning from human feedback -(RLHF) workloads. - -Ray serves as the distributed execution framework in this example. Ray -placement groups allocate both training actors and vLLM workers to the -same GPU bundles, enabling fast, in-GPU communication between the two -components. - -The script shows how to do the following: - -* Configure environment variables (`VLLM_RAY_PER_WORKER_GPUS` and - `VLLM_RAY_BUNDLE_INDICES`) so that vLLM workers land on the desired - devices. -* Exchange tensors between processes by means of CUDA inter-process - communication (IPC). CUDA IPC sidesteps NCCL limitations that occur - when multiple processes share a single GPU. 
- -Note that this example assumes a single-node cluster with four GPUs, but Ray -supports multi-node clusters. vLLM expects exclusive use of the GPUs during -its initialization for memory profiling. Residual GPU activity interferes -with vLLM memory profiling and causes unexpected behavior. - -Learn more about Ray placement groups: -https://docs.ray.io/en/latest/placement-groups.html -""" - -import gc -import os -import sys - -import ray -import torch -import zmq -from ray.util.placement_group import placement_group -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from torch.multiprocessing.reductions import reduce_tensor - -from vllm import LLM - -if torch.version.hip is not None: - print("Skipping test for ROCm. Ray is unsupported on vLLM ROCm.") - sys.exit(0) - - -class MyLLM(LLM): - """Configure the vLLM worker for Ray placement group execution. - - The constructor sets environment variables that allow multiple vLLM - workers to share a single physical GPU and that encode the bundle - indices assigned by the placement group. - - Args: - *args: Positional arguments forwarded to `vllm.LLM`. - bundle_indices (list[int]): Placement-group bundle indices - assigned to this worker. - **kwargs: Keyword arguments forwarded to `vllm.LLM`. - """ - - def __init__(self, *args, bundle_indices: list[int], **kwargs): - # Prevent Ray from manipulating the top-level CUDA_VISIBLE_DEVICES variable - # so that vLLM can its own device placement inside the worker. - os.environ.pop("CUDA_VISIBLE_DEVICES", None) - # Each worker uses 0.4 GPU so that two instances fit on the same GPUs. - os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4" - os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) - print(f"creating LLM with bundle_indices={bundle_indices}") - super().__init__(*args, **kwargs) - - -class RayTrainingActor: - """Training actor that hosts a Facebook OPT-125M model from Hugging Face. 
- - The model is loaded onto the first GPU assigned to this actor, and expose - the CUDA IPC handles so that colocated vLLM workers can map tensors - directly. - """ - - def __init__(self): - # Ray sets CUDA_VISIBLE_DEVICES to the GPUs assigned to this actor. - from transformers import AutoModelForCausalLM - - self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") - self.model.to("cuda:0") - # Zero out all the parameters. - for name, p in self.model.named_parameters(): - p.data.zero_() - torch.accelerator.synchronize() - # The argument for `get_device_uuid` is the index of the GPU in the - # list of visible devices. - from vllm.platforms import current_platform - - self.device_uuid = current_platform.get_device_uuid(0) - self.zmq_context = zmq.Context() - self.zmq_address_counter = 0 - self.zmq_handle = None - - def report_device_id(self) -> str: - return self.device_uuid - - def get_zmq_handles(self) -> dict[str, str]: - suffix = f"{self.device_uuid}-{self.zmq_address_counter}" - self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock" - self.zmq_address_counter += 1 - return {self.device_uuid: self.zmq_handle} - - def update_weights(self): - # align size to avoid misaligned address - align_size = 256 - - def get_size(p: torch.Tensor) -> int: - return (p.nbytes + align_size - 1) // align_size * align_size - - named_parameters: dict[str, torch.nn.Parameter] = dict( - self.model.named_parameters() - ) - max_tensor_size = max(get_size(p) for p in named_parameters.values()) - # use max_tensor_size * 2 as buffer size - buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0") - s = self.zmq_context.socket(zmq.REQ) - s.bind(self.zmq_handle) - handle = reduce_tensor(buffer) - - offset = 0 - buckets: list[tuple[list[dict], list[torch.Tensor]]] = [] - named_tensors: list[dict] = [] - real_tensors: list[torch.Tensor] = [] - for name, p in named_parameters.items(): - size = get_size(p) - if offset + size > buffer.numel(): - 
buckets.append((named_tensors, real_tensors)) - named_tensors, real_tensors = [], [] - offset = 0 - # assume tensors are contiguous - named_tensors.append( - {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset} - ) - real_tensors.append(p) - offset += size - if named_tensors: - buckets.append((named_tensors, real_tensors)) - s.send_pyobj(handle) - s.recv() - for named_tensors, real_tensors in buckets: - offset = 0 - for p in real_tensors: - buffer[offset : offset + p.nbytes].data.copy_( - p.data.view(-1).view(dtype=torch.uint8), non_blocking=True - ) - offset += get_size(p) - torch.accelerator.synchronize() - s.send_pyobj(named_tensors) - s.recv() - s.send_pyobj(None) - s.recv() - s.close() - del buffer - gc.collect() - torch.accelerator.empty_cache() - - -# Ray manages four GPUs. - -os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" -ray.init() - -# Co-locate vLLM instances and training actors on the same set of GPUs: -# * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0 -# (tensor parallelism = 2). -# * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1 -# (tensor parallelism = 2). 
- -pg = placement_group([{"GPU": 1, "CPU": 0}] * 4) -ray.get(pg.ready()) -print(f"placement group has bundles {pg.bundle_specs=}") - -training_actors = [] -training_actor_device_ids = [] -inference_engines = [] -inference_engine_device_ids = [] - -for bundle_index in [0, 1, 2, 3]: - training_actor = ray.remote( - num_cpus=0, - num_gpus=0.4, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=bundle_index, - ), - )(RayTrainingActor).remote() - training_actors.append(training_actor) - -for bundle_index, training_actor in enumerate(training_actors): - device_id = ray.get(training_actor.report_device_id.remote()) - print(f"training actor {bundle_index} is on {device_id}") - training_actor_device_ids.append(device_id) - -for i, bundle_indices in enumerate([[0, 1], [2, 3]]): - # Use the following syntax instead of the @ray.remote decorator so that - # the placement group is customized for each bundle. - llm = ray.remote( - num_cpus=0, - num_gpus=0, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_capture_child_tasks=True, - ), - )(MyLLM).remote( - model="facebook/opt-125m", - enforce_eager=True, - worker_extension_cls="rlhf_utils.ColocateWorkerExtension", - tensor_parallel_size=2, - distributed_executor_backend="ray", - gpu_memory_utilization=0.4, - bundle_indices=bundle_indices, - ) - inference_engines.append(llm) - # Do not call any method on the inference engine at this point; the call - # blocks until the vLLM instance finishes initialization. - -for i, llm in enumerate(inference_engines): - inference_engine_device_ids.append( - ray.get(llm.collective_rpc.remote("report_device_id", args=tuple())) - ) - print(f"inference engine {i} is on {inference_engine_device_ids[-1]}") - -# Verify placement: the first two training actors share the same GPUs as -# the first inference engine. 
-assert training_actor_device_ids[:2] == inference_engine_device_ids[0] -# Verify placement: the last two training actors share the same GPUs as -# the second inference engine. -assert training_actor_device_ids[2:] == inference_engine_device_ids[1] - -print("Gather all the ZMQ handles from the training actors.") -zmq_handles = {} -for actor in training_actors: - zmq_handles.update(ray.get(actor.get_zmq_handles.remote())) - -print(f"ZMQ handles: {zmq_handles}") - -print("Update the weights of the inference engines.") -ray.get( - [actor.update_weights.remote() for actor in training_actors] - + [ - llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,)) - for llm in inference_engines - ] -) - -print("Check if the weights are updated.") -for llm in inference_engines: - assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple())) diff --git a/examples/offline_inference/rlhf_online_quant.py b/examples/offline_inference/rlhf_online_quant.py deleted file mode 100644 index 2d98ad22c589..000000000000 --- a/examples/offline_inference/rlhf_online_quant.py +++ /dev/null @@ -1,162 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray. - -The script separates training and inference workloads onto distinct GPUs -so that Ray can manage process placement and inter-process communication. -A Hugging Face Transformer model occupies GPU 0 for training, whereas a -tensor-parallel vLLM inference engine occupies GPU 1–2. - -The example performs the following steps: - -* Load the training model on GPU 0. -* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism - and Ray placement groups. -* Generate text from a list of prompts using the inference engine. 
-* Update the weights of the training model and broadcast the updated weights - to the inference engine by using a Ray collective RPC group. Note that - for demonstration purposes we simply zero out the weights. - -For a production-ready implementation that supports multiple training and -inference replicas, see the OpenRLHF framework: -https://github.com/OpenRLHF/OpenRLHF - -This example assumes a single-node cluster with three GPUs, but Ray -supports multi-node clusters. vLLM expects the GPUs are only used for vLLM -workloads. Residual GPU activity interferes with vLLM memory profiling and -causes unexpected behavior. -""" - -import json -import os - -import ray -import torch -from ray.util.placement_group import placement_group -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from rlhf_utils import stateless_init_process_group -from torchao.core.config import config_to_dict -from torchao.quantization import ( - Float8DynamicActivationFloat8WeightConfig, - PerRow, -) -from transformers import AutoModelForCausalLM - -from vllm import LLM, SamplingParams -from vllm.utils.network_utils import get_ip, get_open_port - - -class MyLLM(LLM): - """Configure the vLLM worker for Ray placement group execution.""" - - def __init__(self, *args, **kwargs): - # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray - # so that vLLM can manage its own device placement within the worker. - os.environ.pop("CUDA_VISIBLE_DEVICES", None) - super().__init__(*args, **kwargs) - - -# Load the OPT-125M model onto GPU 0 for the training workload. -train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") -train_model.to("cuda:0") - -# Initialize Ray and set the visible devices. The vLLM engine will -# be placed on GPUs 1 and 2. -os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" -ray.init() - -# Create a placement group that reserves GPU 1–2 for the vLLM inference engine. 
-# Learn more about Ray placement groups: -# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html -pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) -ray.get(pg_inference.ready()) -scheduling_inference = PlacementGroupSchedulingStrategy( - placement_group=pg_inference, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=0, -) - -# Launch the vLLM inference engine. The `enforce_eager` flag reduces -# start-up latency. - -# generate torchao quantization config for RL rollout -# see https://github.com/vllm-project/vllm/pull/23014 for instructions to -# use serialized config files instead of passing around json string -config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) - -json_str = json.dumps(config_to_dict(config)) - -llm = ray.remote( - num_cpus=0, - num_gpus=0, - scheduling_strategy=scheduling_inference, -)(MyLLM).remote( - model="facebook/opt-125m", - hf_overrides={"quantization_config_dict_json": json_str}, - enforce_eager=True, - worker_extension_cls="rlhf_utils.WorkerExtension", - tensor_parallel_size=2, - distributed_executor_backend="ray", -) - -# Generate text from the prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] - -sampling_params = SamplingParams(temperature=0) - -outputs = ray.get(llm.generate.remote(prompts, sampling_params)) - -print("-" * 50) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) - -# Set up the communication channel between the training process and the -# inference engine. 
-master_address = get_ip() -master_port = get_open_port() - -handle = llm.collective_rpc.remote( - "init_weight_update_group", args=(master_address, master_port, 1, 3) -) - -model_update_group = stateless_init_process_group( - master_address, master_port, 0, 3, torch.device("cuda:0") -) -ray.get(handle) - -# Simulate a training step by zeroing out all model weights. -# In a real RLHF training loop the weights would be updated using the gradient -# from an RL objective such as PPO on a reward model. -for name, p in train_model.named_parameters(): - p.data.zero_() - -# Synchronize the updated weights to the inference engine. -for name, p in train_model.named_parameters(): - dtype_name = str(p.dtype).split(".")[-1] - handle = llm.collective_rpc.remote( - "update_weight", args=(name, dtype_name, p.shape) - ) - model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) - ray.get(handle) - -# Verify that the inference weights have been updated. -assert all(ray.get(llm.collective_rpc.remote("check_weights_changed"))) - -# Generate text with the updated model. The output is expected to be nonsense -# because the weights are zero. 
-outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) -print("-" * 50) -for output in outputs_updated: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - print("-" * 50) diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py deleted file mode 100644 index e9fc393bb549..000000000000 --- a/examples/offline_inference/rlhf_utils.py +++ /dev/null @@ -1,168 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import gc -from collections.abc import Callable -from typing import TypedDict - -import torch -import zmq - - -def stateless_init_process_group(master_address, master_port, rank, world_size, device): - """ - vLLM provides `StatelessProcessGroup` to create a process group - without considering the global process group in torch.distributed. - It is recommended to create `StatelessProcessGroup`, and then initialize - the data-plane communication (NCCL) between external (train processes) - and vLLM workers. - """ - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator - from vllm.distributed.utils import StatelessProcessGroup - - pg = StatelessProcessGroup.create( - host=master_address, port=master_port, rank=rank, world_size=world_size - ) - pynccl = PyNcclCommunicator(pg, device=device) - return pynccl - - -class WorkerExtension: - """ - The class for vLLM's worker to inherit from. - By defining an extension class, the code can work no matter what is - the underlying worker class. - - NOTE: we define this class in a separate module, and the main module - should pass the full qualified name as `worker_extension_cls` argument. 
- """ - - def init_weight_update_group( - self, master_address, master_port, rank_offset, world_size - ): - from vllm.distributed.parallel_state import get_world_group - - rank = get_world_group().rank + rank_offset - self.model_update_group = stateless_init_process_group( - master_address, - master_port, - rank, - world_size, - self.device, - ) - - def update_weight(self, name, dtype_name, shape): - dtype = getattr(torch, dtype_name) - weight = torch.empty(shape, dtype=dtype, device="cuda") - self.model_update_group.broadcast( - weight, src=0, stream=torch.cuda.current_stream() - ) - - self.model_runner.model.load_weights(weights=[(name, weight)]) - - del weight - - def check_weights_changed(self): - """ - Check if the weights are updated to 0. - """ - weights_updated = True - for name, p in self.model_runner.model.named_parameters(): - weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p)) - return weights_updated - - -def rebuild_ipc( - handle: tuple[Callable, tuple], device_id: int | None = None -) -> torch.Tensor: - func, args = handle - list_args = list(args) - if device_id is not None: - # the key is to change device id to the current device id - # in case two processes have different CUDA_VISIBLE_DEVICES - list_args[6] = device_id - buffer = func(*list_args) - return buffer - - -class FlattenedTensorMetadata(TypedDict): - name: str - shape: torch.Size - dtype: torch.dtype - # specify the start offset of this tensor in shared ipc_buffer tensor - offset: int - - -class ColocateWorkerExtension: - """ - The class for vLLM's worker to inherit from, in the colocate setting. - By defining an extension class, the code can work no matter what is - the underlying worker class. - - NOTE: we define this class in a separate module, and the main module - should pass the full qualified name as `worker_extension_cls` argument. 
- """ - - def update_weights_from_ipc(self, zmq_handles: dict[str, str]): - from vllm.model_executor.model_loader.utils import process_weights_after_loading - - assert self.device is not None - if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None: - self._zmq_ctx = zmq.Context() - socket = self._zmq_ctx.socket(zmq.REP) - socket.connect(zmq_handles[self.report_device_id()]) - buffer: torch.Tensor | None = None - while True: - payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = ( - socket.recv_pyobj() - ) - if payload is None: - # means the update is done - process_weights_after_loading( - self.model_runner.model, self.model_config, self.device - ) - torch.accelerator.synchronize() - socket.send(b"") - break - if isinstance(payload, tuple): - # an ipc handle that vLLM can use `func, args = handle` - # and `func(*args)` to rebuild GPU tensor. - buffer = rebuild_ipc(payload, self.device.index) - assert buffer.dtype == torch.uint8 - socket.send(b"") - continue - assert isinstance(payload, list) - assert buffer is not None - weights = [] - for item in payload: - shape = item["shape"] - if isinstance(shape, (list, tuple)): - shape = torch.Size(shape) - assert isinstance(shape, torch.Size) - dtype, offset = item["dtype"], item["offset"] - size = dtype.itemsize * shape.numel() - tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) - weights.append((item["name"], tensor)) - self.model_runner.model.load_weights(weights=weights) - del weights - torch.accelerator.synchronize() - socket.send(b"") - - socket.close() - del buffer - gc.collect() - torch.accelerator.empty_cache() - - def report_device_id(self) -> str: - from vllm.platforms import current_platform - - self.device_uuid = current_platform.get_device_uuid(self.device.index) - return self.device_uuid - - def check_weights_changed(self): - """ - Check if the weights are updated to 0. 
- """ - weights_updated = True - for name, p in self.model_runner.model.named_parameters(): - weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p)) - return weights_updated diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/rl/rlhf_async_new_apis.py similarity index 100% rename from examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py rename to examples/rl/rlhf_async_new_apis.py diff --git a/examples/online_serving/new_weight_syncing/rlhf_http_ipc.py b/examples/rl/rlhf_http_ipc.py similarity index 100% rename from examples/online_serving/new_weight_syncing/rlhf_http_ipc.py rename to examples/rl/rlhf_http_ipc.py diff --git a/examples/online_serving/new_weight_syncing/rlhf_http_nccl.py b/examples/rl/rlhf_http_nccl.py similarity index 100% rename from examples/online_serving/new_weight_syncing/rlhf_http_nccl.py rename to examples/rl/rlhf_http_nccl.py diff --git a/examples/offline_inference/new_weight_syncing/rlhf_ipc.py b/examples/rl/rlhf_ipc.py similarity index 100% rename from examples/offline_inference/new_weight_syncing/rlhf_ipc.py rename to examples/rl/rlhf_ipc.py diff --git a/examples/offline_inference/new_weight_syncing/rlhf_nccl.py b/examples/rl/rlhf_nccl.py similarity index 100% rename from examples/offline_inference/new_weight_syncing/rlhf_nccl.py rename to examples/rl/rlhf_nccl.py diff --git a/requirements/common.txt b/requirements/common.txt index d96928f06b60..05666c5d14b0 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -12,7 +12,7 @@ tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf >= 5.29.6, !=6.30.*, !=6.31.*, !=6.32.*, !=6.33.0.*, !=6.33.1.*, !=6.33.2.*, !=6.33.3.*, !=6.33.4.* # Required by LlamaTokenizer, gRPC. CVE-2026-0994 fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. 
aiohttp >= 3.13.3 -openai >= 1.99.1, < 2.25.0 # For Responses API with reasoning content +openai >= 2.0.0 # For Responses API with reasoning content pydantic >= 2.12.0 prometheus_client >= 0.18.0 pillow # Required for image processing diff --git a/tests/compile/fusions_e2e/conftest.py b/tests/compile/fusions_e2e/conftest.py index 873f92cfe6ce..5716c95bb241 100644 --- a/tests/compile/fusions_e2e/conftest.py +++ b/tests/compile/fusions_e2e/conftest.py @@ -82,6 +82,10 @@ def run( f"attention backend '{attn_backend.backend.name}'" ) + # TODO: remove this after finishing migration from envs to model kwargs + if model_name == "openai/gpt-oss-20b": + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1") + # Disable, compile cache to make sure custom passes run. # Otherwise, we can't verify fusion happened through the logs. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") diff --git a/tests/compile/fusions_e2e/models.py b/tests/compile/fusions_e2e/models.py index 9d6c202648e2..1a5f18cc0d50 100644 --- a/tests/compile/fusions_e2e/models.py +++ b/tests/compile/fusions_e2e/models.py @@ -162,3 +162,12 @@ # async_tp=n_layers * 2, ), ) + +gpt_oss_20b = ModelFusionInfo( + model_name="openai/gpt-oss-20b", + matches=lambda n_layers: Matches( + ar_rms_fusion=n_layers * 2 + 1, + sequence_parallel=n_layers * 2 + 1, + async_tp=n_layers * 2, + ), +) diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py index 8ffadbfaf298..301409b2bf6a 100644 --- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py +++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py @@ -20,6 +20,7 @@ FLASHINFER_MLA_ATTN, TRITON_ATTN, deepseek_v3_fp8, + gpt_oss_20b, llama3_8b, llama3_8b_fp4, llama3_8b_fp8, @@ -158,7 +159,7 @@ def test_tp2_ar_rms_fp4_fusions( @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( "model_name, matches_fn, model_kwargs, hf_overrides", - [llama3_8b, qwen3_a3b], + [llama3_8b, qwen3_a3b, gpt_oss_20b], ) 
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN]) @pytest.mark.parametrize("n_layers", [4]) diff --git a/tests/entrypoints/openai/chat_completion/test_audio_in_video.py b/tests/entrypoints/openai/chat_completion/test_audio_in_video.py index 9e56b03027a5..8c024995b938 100644 --- a/tests/entrypoints/openai/chat_completion/test_audio_in_video.py +++ b/tests/entrypoints/openai/chat_completion/test_audio_in_video.py @@ -9,7 +9,7 @@ import pytest_asyncio from tests.conftest import VideoTestAssets -from tests.utils import RemoteOpenAIServer +from tests.utils import ROCM_EXTRA_ARGS, RemoteOpenAIServer MODEL_NAME = "Qwen/Qwen2.5-Omni-3B" @@ -22,6 +22,7 @@ def server(): "--enforce-eager", "--limit-mm-per-prompt", json.dumps({"audio": 3, "video": 3}), + *ROCM_EXTRA_ARGS, ] with RemoteOpenAIServer( diff --git a/tests/entrypoints/openai/responses/conftest.py b/tests/entrypoints/openai/responses/conftest.py index 68fdbbba3b02..a1d16b123166 100644 --- a/tests/entrypoints/openai/responses/conftest.py +++ b/tests/entrypoints/openai/responses/conftest.py @@ -370,7 +370,7 @@ def log_response_diagnostics( def default_server_args(): return [ "--max-model-len", - "8192", + "18192", "--enforce-eager", # For faster startup. 
"--enable-auto-tool-choice", "--structured-outputs-config.backend", diff --git a/tests/entrypoints/openai/responses/test_function_call.py b/tests/entrypoints/openai/responses/test_function_call.py index 36627f92d7d7..bacb084c7eb6 100644 --- a/tests/entrypoints/openai/responses/test_function_call.py +++ b/tests/entrypoints/openai/responses/test_function_call.py @@ -118,7 +118,6 @@ async def test_function_tool_use( tool_choice=tool_choice, temperature=0.0, ) - assert len(response.output) >= 1 tool_call = None reasoning = None @@ -127,11 +126,15 @@ async def test_function_tool_use( tool_call = out if out.type == "reasoning": reasoning = out - assert tool_call is not None - assert tool_call.type == "function_call" - assert json.loads(tool_call.arguments) is not None - assert reasoning is not None - assert reasoning.type == "reasoning" + if response.incomplete_details is None: + assert tool_call is not None + assert tool_call.type == "function_call" + assert json.loads(tool_call.arguments) is not None + assert reasoning is not None + assert reasoning.type == "reasoning" + else: + print(response.model_dump_json(indent=2)) + assert response.incomplete_details.reason == "max_output_tokens" @pytest.mark.asyncio diff --git a/tests/kernels/attention/test_trtllm_kvfp8_dequant.py b/tests/kernels/attention/test_trtllm_kvfp8_dequant.py index a2ea372c0c15..c49ceb03f5b1 100644 --- a/tests/kernels/attention/test_trtllm_kvfp8_dequant.py +++ b/tests/kernels/attention/test_trtllm_kvfp8_dequant.py @@ -12,6 +12,12 @@ from vllm.platforms import current_platform +if current_platform.is_rocm(): + pytest.skip( + "trtllm kvfp8 dequant is not supported on ROCm.", + allow_module_level=True, + ) + FP8_DTYPE = current_platform.fp8_dtype() NUM_BLOCKS = 128 diff --git a/tests/kernels/moe/test_router_gemm.py b/tests/kernels/moe/test_router_gemm.py new file mode 100644 index 000000000000..906e47708f29 --- /dev/null +++ b/tests/kernels/moe/test_router_gemm.py @@ -0,0 +1,37 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for optimized router GEMM kernel + +Run `pytest tests/kernels/moe/test_router_gemm.py`. +""" + +import pytest +import torch + +import vllm._custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_random_seed + + +@pytest.mark.skipif( + not ( + current_platform.is_cuda() + and ( + current_platform.is_device_capability(90) + or current_platform.is_device_capability_family(100) + ) + ), + reason="This test only runs on Hopper or Blackwell GPUs.", +) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880]) +@pytest.mark.parametrize("output_dim", [32, 64, 128]) +def test_gpt_oss_router_gemm(batch_size, input_dim, output_dim): + set_random_seed(0) + x = torch.randn(batch_size, input_dim, device="cuda", dtype=torch.bfloat16) + weight = torch.randn(output_dim, input_dim, device="cuda", dtype=torch.bfloat16) + bias = torch.randn(output_dim, device="cuda", dtype=torch.bfloat16) + + output = ops.gpt_oss_router_gemm(x, weight, bias) + output_ref = torch.nn.functional.linear(x, weight, bias) + torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 97dc6c51c5a9..c16efd065e1b 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -777,6 +777,7 @@ max_model_len=8192, max_num_seqs=2, auto_cls=AutoModelForCausalLM, + patch_hf_runner=model_utils.paddleocr_vl_patch_hf_runner, image_size_factors=[(0.25,)], marks=[ pytest.mark.skipif( diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index b8e31e274de4..9bdedb3c5c25 100644 --- 
a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -489,13 +489,14 @@ def __init__(self, hf_runner: HfRunner): self.image_size = self.vision_config.image_size def __call__(self, text: str, images: Image | list[Image], **kwargs): - from vllm.model_executor.models.h2ovl import ( - IMG_CONTEXT, - IMG_END, - IMG_START, + from vllm.transformers_utils.processors.h2ovl import ( image_to_pixel_values_h2ovl, ) + IMG_START = "" + IMG_END = "" + IMG_CONTEXT = "" + images = [images] if isinstance(images, Image) else images pixel_values = [ image_to_pixel_values_h2ovl( @@ -751,16 +752,17 @@ def __init__(self, hf_runner: HfRunner): self.image_size = self.vision_config.image_size def __call__(self, text: str, images: Image | list[Image], **kwargs): - from vllm.model_executor.models.skyworkr1v import ( - IMG_CONTEXT, - IMG_END, - IMG_START, - image_to_pixel_values_skyworkr1v, + from vllm.transformers_utils.processors.internvl import ( + image_to_pixel_values_internvl, ) + IMG_START = "" + IMG_END = "" + IMG_CONTEXT = "" + images = [images] if isinstance(images, Image) else images pixel_values = [ - image_to_pixel_values_skyworkr1v( + image_to_pixel_values_internvl( image, input_size=self.image_size, min_num=self.min_num, @@ -815,14 +817,15 @@ def __call__( videos: npt.NDArray | list[npt.NDArray] = None, **kwargs, ): - from vllm.model_executor.models.internvl import ( - IMG_CONTEXT, - IMG_END, - IMG_START, + from vllm.transformers_utils.processors.internvl import ( image_to_pixel_values_internvl, video_to_pixel_values_internvl, ) + IMG_START = "" + IMG_END = "" + IMG_CONTEXT = "" + images = [images] if isinstance(images, Image) else images videos = [videos] if isinstance(videos, np.ndarray) else videos if images is not None: @@ -1149,6 +1152,31 @@ def processor(*args, text="", images=None, videos=None, **kwargs): return hf_model +def paddleocr_vl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + 
"""Patches the HfRunner to fix create_causal_mask API mismatch. + + The PaddleOCR-VL HF model passes `inputs_embeds` to create_causal_mask, + but transformers renamed this parameter to `input_embeds`. + """ + import sys + + model_module = sys.modules.get(type(hf_model.model.model).__module__) + if model_module is None: + return hf_model + + original_create_causal_mask = getattr(model_module, "create_causal_mask", None) + if original_create_causal_mask is None: + return hf_model + + def patched_create_causal_mask(*args, **kwargs): + if "inputs_embeds" in kwargs: + kwargs["input_embeds"] = kwargs.pop("inputs_embeds") + return original_create_causal_mask(*args, **kwargs) + + model_module.create_causal_mask = patched_create_causal_mask # type: ignore[attr-defined] + return hf_model + + def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner for Qwen2.5-Omni.""" thinker = hf_model.model.thinker diff --git a/tests/models/registry.py b/tests/models/registry.py index 47551d7eb187..aac707a9065b 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -779,7 +779,8 @@ def check_available_online( "rednote-hilab/dots.ocr", trust_remote_code=True ), "Eagle2_5_VLForConditionalGeneration": _HfExamplesInfo( - "nvidia/Eagle2.5-8B", trust_remote_code=True, is_available_online=False + "nvidia/Eagle2.5-8B", + trust_remote_code=True, ), "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo( diff --git a/tests/multimodal/media/test_video.py b/tests/multimodal/media/test_video.py index 9c04d991aba0..a1223ebc07e2 100644 --- a/tests/multimodal/media/test_video.py +++ b/tests/multimodal/media/test_video.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import io from pathlib import Path import numpy as np import numpy.typing as npt +import pybase64 import pytest from 
PIL import Image @@ -235,3 +237,53 @@ def test_video_media_io_backend_env_var_fallback(monkeypatch: pytest.MonkeyPatch frames_missing, metadata_missing = videoio_missing.load_bytes(b"test") np.testing.assert_array_equal(frames_missing, FAKE_OUTPUT_2) assert metadata_missing["video_backend"] == "test_video_backend_override_2" + + +def test_load_base64_jpeg_returns_metadata(): + """Regression test: load_base64 with video/jpeg must return metadata. + + Previously, base64 JPEG frame sequences returned an empty dict for + metadata, which broke downstream consumers that rely on fields like + total_num_frames and fps. See PR #37301. + """ + + num_test_frames = 3 + frame_width, frame_height = 8, 8 + + # Build a few tiny JPEG frames and base64-encode them + b64_frames = [] + for i in range(num_test_frames): + img = Image.new("RGB", (frame_width, frame_height), color=(i * 80, 0, 0)) + buf = io.BytesIO() + img.save(buf, format="JPEG") + b64_frames.append(pybase64.b64encode(buf.getvalue()).decode("ascii")) + + data = ",".join(b64_frames) + + imageio = ImageMediaIO() + videoio = VideoMediaIO(imageio, num_frames=num_test_frames) + frames, metadata = videoio.load_base64("video/jpeg", data) + + # Frames array shape: (num_frames, H, W, 3) + assert frames.shape[0] == num_test_frames + + # All required metadata keys must be present + required_keys = { + "total_num_frames", + "fps", + "duration", + "video_backend", + "frames_indices", + "do_sample_frames", + } + assert required_keys.issubset(metadata.keys()), ( + f"Missing metadata keys: {required_keys - metadata.keys()}" + ) + + assert metadata["total_num_frames"] == num_test_frames + assert metadata["video_backend"] == "jpeg_sequence" + assert metadata["frames_indices"] == list(range(num_test_frames)) + assert metadata["do_sample_frames"] is False + # Default fps=1 → duration == num_frames + assert metadata["fps"] == 1.0 + assert metadata["duration"] == float(num_test_frames) diff --git 
a/tests/tool_parsers/test_glm47_moe_tool_parser.py b/tests/tool_parsers/test_glm47_moe_tool_parser.py new file mode 100644 index 000000000000..c7170e67500f --- /dev/null +++ b/tests/tool_parsers/test_glm47_moe_tool_parser.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 +"""Tests for the GLM-4.7 tool call parser.""" + +import json +from unittest.mock import Mock + +import pytest + +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + FunctionDefinition, +) +from vllm.tokenizers import get_tokenizer +from vllm.tool_parsers.glm47_moe_tool_parser import Glm47MoeModelToolParser + +MODEL = "zai-org/GLM-4.5" + + +@pytest.fixture(scope="module") +def glm47_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture +def glm47_tool_parser(glm47_tokenizer): + return Glm47MoeModelToolParser(glm47_tokenizer) + + +@pytest.fixture +def mock_request() -> ChatCompletionRequest: + request = Mock(spec=ChatCompletionRequest) + request.tools = [ + ChatCompletionToolsParam( + function=FunctionDefinition(name="get_current_date", parameters={}), + ), + ChatCompletionToolsParam( + function=FunctionDefinition( + name="get_weather", + parameters={ + "type": "object", + "properties": { + "city": {"type": "string"}, + "date": {"type": "string"}, + }, + }, + ), + ), + ] + request.tool_choice = "auto" + return request + + +class TestGlm47ExtractToolCalls: + def test_no_tool_call(self, glm47_tool_parser, mock_request): + out = "This is a plain response." 
+ r = glm47_tool_parser.extract_tool_calls(out, request=mock_request) + assert not r.tools_called + assert r.content == out + + def test_zero_arg_inline(self, glm47_tool_parser, mock_request): + out = "get_current_date" + r = glm47_tool_parser.extract_tool_calls(out, request=mock_request) + assert r.tools_called + assert r.tool_calls[0].function.name == "get_current_date" + assert json.loads(r.tool_calls[0].function.arguments) == {} + assert r.content is None + + def test_zero_arg_newline(self, glm47_tool_parser, mock_request): + out = "get_current_date\n" + r = glm47_tool_parser.extract_tool_calls(out, request=mock_request) + assert r.tools_called + assert r.tool_calls[0].function.name == "get_current_date" + + def test_args_same_line(self, glm47_tool_parser, mock_request): + out = "get_weathercityBeijing" + r = glm47_tool_parser.extract_tool_calls(out, request=mock_request) + assert r.tools_called + assert json.loads(r.tool_calls[0].function.arguments) == {"city": "Beijing"} + + def test_args_with_newlines(self, glm47_tool_parser, mock_request): + out = "get_weather\ncity\nBeijing\n" + r = glm47_tool_parser.extract_tool_calls(out, request=mock_request) + assert r.tools_called + assert json.loads(r.tool_calls[0].function.arguments) == {"city": "Beijing"} + + def test_content_before(self, glm47_tool_parser, mock_request): + out = "Checking.get_current_date" + r = glm47_tool_parser.extract_tool_calls(out, request=mock_request) + assert r.tools_called + assert r.content == "Checking." 
+ + def test_multiple(self, glm47_tool_parser, mock_request): + out = ( + "get_weathercityBeijing" + "get_weathercityShanghai" + ) + r = glm47_tool_parser.extract_tool_calls(out, request=mock_request) + assert len(r.tool_calls) == 2 + + def test_empty_content_none(self, glm47_tool_parser, mock_request): + out = "get_current_date" + r = glm47_tool_parser.extract_tool_calls(out, request=mock_request) + assert r.content is None + + def test_whitespace_content_none(self, glm47_tool_parser, mock_request): + out = " \n get_current_date" + r = glm47_tool_parser.extract_tool_calls(out, request=mock_request) + assert r.content is None + + +def _reset(parser): + parser._buffer = "" + parser._in_tool_call = False + parser.current_tool_name_sent = False + parser._current_tool_name = None + parser._pending_key = None + parser._streaming_string_value = False + parser.prev_tool_call_arr = [] + parser.current_tool_id = -1 + parser.streamed_args_for_tool = [] + parser._tool_call_ids = [] + parser._args_started = [] + parser._args_closed = [] + parser._seen_keys = [] + + +class TestGlm47Streaming: + def test_no_args(self, glm47_tool_parser, mock_request): + _reset(glm47_tool_parser) + for chunk in ["", "get_current_date", ""]: + glm47_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="", + delta_text=chunk, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=mock_request, + ) + assert len(glm47_tool_parser.prev_tool_call_arr) >= 1 + + def test_with_args(self, glm47_tool_parser, mock_request): + _reset(glm47_tool_parser) + # Split chunks so that the incremental string streaming path + # processes the value, its closing tag, and the tool-call closing + # tag in separate calls. 
+ for chunk in [ + "", + "get_weather\n", + "city", + "", + "Beijing", + "", + "", + ]: + glm47_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text="", + delta_text=chunk, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=mock_request, + ) + assert glm47_tool_parser.prev_tool_call_arr[0]["arguments"]["city"] == "Beijing" diff --git a/tests/tool_parsers/test_glm4_moe_tool_parser.py b/tests/tool_parsers/test_glm4_moe_tool_parser.py index 9ee9ea008f3f..213cc75db7ea 100644 --- a/tests/tool_parsers/test_glm4_moe_tool_parser.py +++ b/tests/tool_parsers/test_glm4_moe_tool_parser.py @@ -107,7 +107,7 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser, mock_request): ) ) ], - "", + None, ), ( """get_current_weather @@ -152,7 +152,7 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser, mock_request): ) ), ], - "", + None, ), ( """I'll help you check the weather. get_current_weather @@ -202,7 +202,7 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser, mock_request): ) ) ], - "", + None, ), ( """I will help you get the weather.get_weather diff --git a/tests/v1/executor/test_executor.py b/tests/v1/executor/test_executor.py index e9f635378e57..494e8aa67dd8 100644 --- a/tests/v1/executor/test_executor.py +++ b/tests/v1/executor/test_executor.py @@ -14,12 +14,35 @@ from vllm.sampling_params import SamplingParams from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.llm_engine import LLMEngine +from vllm.v1.executor.abstract import Executor from vllm.v1.executor.multiproc_executor import MultiprocExecutor +from vllm.v1.executor.uniproc_executor import ( + ExecutorWithExternalLauncher, + UniProcExecutor, +) class Mock: ... 
+def test_supports_async_scheduling_base_executor(): + assert Executor.supports_async_scheduling() is False + + +def test_supports_async_scheduling_uniproc_executor(): + assert UniProcExecutor.supports_async_scheduling() is True + + +def test_supports_async_scheduling_executor_with_external_launcher(): + # ExecutorWithExternalLauncher inherits from UniProcExecutor and does not + # override supports_async_scheduling, so it should return True. + assert ExecutorWithExternalLauncher.supports_async_scheduling() is True + + +def test_supports_async_scheduling_multiproc_executor(): + assert MultiprocExecutor.supports_async_scheduling() is True + + class CustomMultiprocExecutor(MultiprocExecutor): def collective_rpc( self, diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 6acc486292a1..671a80137b63 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -231,10 +231,11 @@ def test_multi_example_connector_consistency(): ] # First three events are from initialization (register_kv_caches, # set_host_xfer_buffer_ops, get_handshake_metadata), then generate() events. - assert events["storage1-WORKER"][:7] == [ + assert events["storage1-WORKER"][:8] == [ "register_kv_caches", "set_host_xfer_buffer_ops", "get_handshake_metadata", + "handle_preemptions", "bind_connector_metadata", "start_load_kv", "wait_for_layer_load", @@ -246,10 +247,11 @@ def test_multi_example_connector_consistency(): "update_state_after_alloc num_blocks=[0] 0", "build_connector_meta", ] - assert events["storage2-WORKER"][:7] == [ + assert events["storage2-WORKER"][:8] == [ "register_kv_caches", "set_host_xfer_buffer_ops", "get_handshake_metadata", + "handle_preemptions", "bind_connector_metadata", "start_load_kv", "wait_for_layer_load", @@ -399,8 +401,8 @@ def test_multi_connector_handle_preemptions_integration(): # testing the delegation behavior of MultiConnector here. 
# The connector attribute contains the KV connector. assert scheduler.connector is not None, "Scheduler should have a connector" - preempted_req_ids = {"req-1", "req-2", "req-3"} - scheduler.connector.handle_preemptions(preempted_req_ids) + connector_md = scheduler.connector.build_connector_meta(scheduler.schedule()) + scheduler.connector.handle_preemptions(connector_md) # Verify both connectors received the handle_preemptions call events = get_connector_events() diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 095bd4c3dd98..bda9e43c7829 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -694,16 +694,18 @@ def test_async_load_kv( ) @pytest.mark.parametrize("local_tp_size", [1, 2]) def test_prefill_tp_size_greater_than_decode_tp_size( - self, local_tp_size: int, default_vllm_config, dist_init + self, local_tp_size: int, default_vllm_config, dist_init, monkeypatch ): """ Verify remote TP > local TP handshake succeeds with different remote configurations. """ + monkeypatch.setattr( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", + lambda: local_tp_size, + ) vllm_config = create_vllm_config() - local_tp_size = 1 - vllm_config.parallel_config.tensor_parallel_size = local_tp_size connector = NixlConnector( vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) @@ -738,10 +740,10 @@ def check_handshake(remote_tp_size: int): remote_agents = worker._nixl_handshake( host="localhost", port=1234, - remote_tp_size=2, + remote_tp_size=4, expected_engine_id=worker.REMOTE_ENGINE_ID, ) - check_handshake(2) + check_handshake(4) # NOTE flexibility: a second remote with higher number of ranks is # discovered. 
This is not a scenario we actively support right now, but @@ -759,9 +761,8 @@ def check_handshake(remote_tp_size: int): "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", FakeNixlWrapper, ) - @pytest.mark.parametrize("local_tp_size", [1, 2]) def test_prefill_tp_size_greater_than_decode_tp_size_mla( - self, local_tp_size: int, default_vllm_config, dist_init + self, default_vllm_config, dist_init ): """ Verify remote TP > local TP handshake succeeds with different @@ -1369,7 +1370,13 @@ def run_test_and_cleanup(): "NIXL_TELEMETRY_ENABLE": "1", }, } - ray.init(runtime_env=runtime_env) + # On XPU/ROCm, vLLM expects Ray's device key to be "GPU". + # Explicitly reserving GPU resources here prevents false negatives + # when Ray cannot auto-detect accelerator resources in test envs. + ray_init_kwargs: dict[str, Any] = {"runtime_env": runtime_env} + if not current_platform.is_cuda(): + ray_init_kwargs["num_gpus"] = 1 + ray.init(**ray_init_kwargs) try: run_test_and_cleanup() finally: @@ -2005,7 +2012,7 @@ def test_transfer_failure_logging( connector = NixlConnector( vllm_config, KVConnectorRole.WORKER, - make_kv_cache_config(block_size=16, hma_enabled=enable_hma), + make_kv_cache_config(block_size=16, swa_enabled=enable_hma), ) connector.connector_worker = FakeNixlConnectorWorker( vllm_config, diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index d4b0c28a5de5..898f8e4b35ba 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Unit tests for NixlConnectorScheduler sw_sizes calculation with HMA.""" +"""Unit tests for NixlConnectorScheduler with HMA and Mamba N-1 prefill.""" from unittest.mock import patch @@ -14,24 +14,26 @@ ) from .utils import ( + create_request, 
create_vllm_config, make_kv_cache_config, + make_nixl_scheduler, ) @pytest.mark.cpu_test @pytest.mark.parametrize( - "hma_enabled,expected_sw_sizes", + "swa_enabled,expected_sw_sizes", [ - # HMA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128) + # SWA enabled: FullAttentionSpec (0) + SlidingWindowSpec (2048/16=128) (True, [0, 128 + 1]), - # HMA disabled: only FullAttentionSpec (0) + # SWA disabled: only FullAttentionSpec (0) (False, [0]), ], ) @patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") -def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes): - """Test sw_sizes is correctly computed based on HMA enabled/disabled.""" +def test_sw_sizes(mock_platform, swa_enabled, expected_sw_sizes): + """Test sw_sizes is correctly computed based on SWA enabled/disabled.""" from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( NixlConnectorScheduler, ) @@ -42,7 +44,7 @@ def test_sw_sizes(mock_platform, hma_enabled, expected_sw_sizes): vllm_config = create_vllm_config(block_size=block_size) # SW 2048 tokens=>128 blocks kv_cache_config = make_kv_cache_config( - block_size=block_size, hma_enabled=hma_enabled, sw_size=2048 + block_size=block_size, swa_enabled=swa_enabled, sw_size=2048 ) scheduler = NixlConnectorScheduler( @@ -75,7 +77,7 @@ def test_logical_to_kernel_block_ids_with_hma(): # So each logical block maps to 2 kernel blocks eg [0]->[0,1] worker._physical_blocks_per_logical_kv_block = 2 # FA + SW groups (neither is MambaSpec, so both get expanded) - worker.kv_cache_config = make_kv_cache_config(block_size=16, hma_enabled=True) + worker.kv_cache_config = make_kv_cache_config(block_size=16, swa_enabled=True) # Test conversion: FA + SW group logical_block_ids = [[0, 1, 2], [3, 4]] @@ -313,3 +315,106 @@ def test_nixl_metadata_hybrid_ssm_block_ids(): assert list(req_meta.remote.block_ids[0]) == [10, 11, 12, 13, 14, 15, 16, 17] assert list(req_meta.remote.block_ids[1]) == [20, 21] assert 
len(req_meta.remote.block_ids[0]) != len(req_meta.remote.block_ids[1]) + + +# ── Mamba N-1 prefill tests ────────────────────────────────────────────── + + +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "has_mamba,is_hma_required,expected_count", + [ + (True, True, 9), + (False, False, 10), + (False, True, 10), + ], + ids=["mamba", "fa_only", "swa_only"], +) +def test_mamba_n1_d_side(has_mamba, is_hma_required, expected_count): + """D-side: Mamba gets N-1 matched tokens, non-Mamba gets N.""" + sched = make_nixl_scheduler(has_mamba=has_mamba, is_hma_required=is_hma_required) + req = create_request(num_tokens=10, do_remote_prefill=True) + + count, is_async = sched.get_num_new_matched_tokens(req, num_computed_tokens=0) + assert count == expected_count + assert is_async is True + + +@pytest.mark.cpu_test +def test_mamba_n1_p_side_truncation(): + """P-side: Mamba truncates prompt to N-1, sets max_tokens=1. + + Also verifies idempotency (calling again is a no-op) which is + needed for preemption safety via the _p_side_truncated guard, + and that non-Mamba models skip truncation entirely. 
+ """ + sched = make_nixl_scheduler(has_mamba=True, is_hma_required=True) + req = create_request(num_tokens=10, do_remote_decode=True) + req.max_tokens = 128 + original_len = len(req.prompt_token_ids) + + count, is_async = sched.get_num_new_matched_tokens(req, num_computed_tokens=0) + + assert count == 0 + assert is_async is False + assert len(req.prompt_token_ids) == original_len - 1 + assert req.num_prompt_tokens == original_len - 1 + assert req.max_tokens == 1 + assert req.kv_transfer_params["_p_side_truncated"] is True + + # Idempotency: second call must not truncate further + sched.get_num_new_matched_tokens(req, num_computed_tokens=0) + assert len(req.prompt_token_ids) == original_len - 1 + + # Non-Mamba: truncation is skipped + fa_sched = make_nixl_scheduler(has_mamba=False, is_hma_required=False) + fa_req = create_request(num_tokens=10, do_remote_decode=True) + fa_original = len(fa_req.prompt_token_ids) + + fa_sched.get_num_new_matched_tokens(fa_req, num_computed_tokens=0) + assert len(fa_req.prompt_token_ids) == fa_original + + +@pytest.mark.cpu_test +@pytest.mark.parametrize( + "swa_enabled,mamba_enabled,expected_has_mamba,expected_is_hma", + [ + (True, True, True, True), + (True, False, False, True), + (False, False, False, False), + ], + ids=["fa_swa_mamba", "fa_swa_only", "fa_only"], +) +@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") +def test_has_mamba_init( + mock_platform, + swa_enabled, + mamba_enabled, + expected_has_mamba, + expected_is_hma, +): + """Test _has_mamba / _is_hma_required derived from kv_cache_groups.""" + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorScheduler, + ) + + mock_platform.device_type = "cpu" + + block_size = 16 + vllm_config = create_vllm_config(block_size=block_size) + # VllmConfig.__post_init__ auto-disables HMA when kv_transfer_config + # is set; override so we can test the scheduler's own derivation. 
+ vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False + kv_cache_config = make_kv_cache_config( + block_size=block_size, + swa_enabled=swa_enabled, + mamba_enabled=mamba_enabled, + ) + + scheduler = NixlConnectorScheduler( + vllm_config=vllm_config, + engine_id="test-engine", + kv_cache_config=kv_cache_config, + ) + assert scheduler._has_mamba is expected_has_mamba + assert scheduler._is_hma_required is expected_is_hma diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py index 893a5d8d4d78..cf118f7f3c60 100644 --- a/tests/v1/kv_connector/unit/test_offloading_connector.py +++ b/tests/v1/kv_connector/unit/test_offloading_connector.py @@ -13,11 +13,15 @@ from vllm.config import KVTransferConfig, VllmConfig from vllm.distributed.kv_events import BlockRemoved, BlockStored from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole -from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import ( - OffloadingConnector, +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ( OffloadingConnectorMetadata, +) +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import ( OffloadingConnectorStats, ) +from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import ( + OffloadingConnector, +) from vllm.forward_context import ForwardContext from vllm.utils.hashing import sha256 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend @@ -363,10 +367,7 @@ def _run(self, decoded_tokens: list[int], complete_transfers: bool): assert kv_connector_metadata is not None assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata) - if scheduler_output.preempted_req_ids: - self.worker_connector.handle_preemptions( - scheduler_output.preempted_req_ids - ) + self.worker_connector.handle_preemptions(kv_connector_metadata) self.worker_connector.bind_connector_metadata(kv_connector_metadata) 
self.worker_connector.start_load_kv(self._dummy_ctx) diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index f48dc0fff602..283b4f25e6e4 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -1,10 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +from unittest.mock import patch import pytest -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + KVConnectorOutput, + ModelRunnerOutput, +) from vllm.v1.request import FinishReason, RequestStatus from .utils import ( @@ -13,6 +18,7 @@ create_request, create_scheduler, create_vllm_config, + make_kv_cache_config, ) pytestmark = pytest.mark.cpu_test @@ -579,3 +585,73 @@ def test_cannot_recv(): scheduler.update_from_output(scheduler_output, model_runner_output) _ = scheduler.schedule() assert_scheduler_empty(scheduler) + + +@patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform") +def test_p_side_chunked_prefill_mamba(mock_platform): + """P-side integration: Mamba N-1 truncation + chunked prefill completes. + + A 64-token P-side request is truncated to 63 by the N-1 fix, then + chunked into two prefill steps (32 + 31) and finishes with + LENGTH_CAPPED because max_tokens is set to 1. 
+ """ + mock_platform.device_type = "cpu" + + BATCH_SIZE = 32 + NUM_TOKENS = 64 + BLOCK_SIZE = 16 + + vllm_config = create_vllm_config( + max_num_batched_tokens=BATCH_SIZE, + block_size=BLOCK_SIZE, + ) + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False + + kv_cache_config = make_kv_cache_config( + block_size=BLOCK_SIZE, + mamba_enabled=True, + num_blocks=10000, + ) + + scheduler = create_scheduler(vllm_config, kv_cache_config=kv_cache_config) + + request = create_request( + num_tokens=NUM_TOKENS, + do_remote_decode=True, + block_size=BLOCK_SIZE, + ) + request.max_tokens = 128 + scheduler.add_request(request) + request_id = request.request_id + + # ── Step 1: first chunk ── + scheduler_output = scheduler.schedule() + + assert len(request.prompt_token_ids) == NUM_TOKENS - 1 + assert request.max_tokens == 1 + assert scheduler_output.num_scheduled_tokens[request_id] == BATCH_SIZE + assert request.num_computed_tokens == BATCH_SIZE + + # Model returns no tokens for intermediate prefill chunk + intermediate_output = ModelRunnerOutput( + req_ids=[request.request_id], + req_id_to_index={request.request_id: 0}, + sampled_token_ids=[[]], + ) + scheduler.update_from_output(scheduler_output, intermediate_output) + + # ── Step 2: remaining chunk ── + scheduler_output = scheduler.schedule() + + remaining = NUM_TOKENS - 1 - BATCH_SIZE # 31 + assert scheduler_output.num_scheduled_tokens[request_id] == remaining + assert request.num_computed_tokens == NUM_TOKENS - 1 + + # Prefill complete: model generates 1 decode token + final_output = create_model_runner_output([request]) + engine_core_outputs = scheduler.update_from_output(scheduler_output, final_output) + + # max_tokens=1 → request finishes with LENGTH + outputs = engine_core_outputs[0].outputs + assert len(outputs) == 1 + assert outputs[0].finish_reason == FinishReason.LENGTH diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 6e00cf8d5bed..1e2a05f0e345 100644 --- 
a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -37,6 +37,7 @@ FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, + MambaSpec, SlidingWindowSpec, ) from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -423,7 +424,8 @@ def wait_for_save(self): def make_kv_cache_config( block_size: int, - hma_enabled: bool = False, + swa_enabled: bool = False, + mamba_enabled: bool = False, sw_size: int = 128, num_blocks: int = 100, ) -> KVCacheConfig: @@ -438,7 +440,7 @@ def make_kv_cache_config( ), ) ] - if hma_enabled: + if swa_enabled: kv_cache_groups.append( KVCacheGroupSpec( ["layer1", "layer3"], @@ -451,6 +453,32 @@ def make_kv_cache_config( ), ) ) + if mamba_enabled: + kv_cache_groups.append( + KVCacheGroupSpec( + ["mamba0", "mamba1"], + MambaSpec( + block_size=block_size, + shapes=((16,), (16,)), + dtypes=(torch.float16,), + ), + ) + ) return KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=[], kv_cache_groups=kv_cache_groups ) + + +def make_nixl_scheduler(has_mamba: bool = False, is_hma_required: bool = False): + """Create a NixlConnectorScheduler via __new__ (skipping __init__). + + Only sets the two flags needed by the N-1 prefill logic. 
+ """ + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlConnectorScheduler, + ) + + sched = object.__new__(NixlConnectorScheduler) + sched._has_mamba = has_mamba + sched._is_hma_required = is_hma_required + return sched diff --git a/tests/v1/kv_offload/test_cpu_gpu.py b/tests/v1/kv_offload/test_cpu_gpu.py index 9d14e3cff89e..3f4ef7d07f98 100644 --- a/tests/v1/kv_offload/test_cpu_gpu.py +++ b/tests/v1/kv_offload/test_cpu_gpu.py @@ -135,19 +135,19 @@ def test_transfer( # set transfer direction if gpu_to_cpu: handler = handlers.gpu_to_cpu_handler - src_spec_class = GPULoadStoreSpec - dst_spec_class = CPULoadStoreSpec src_blocks = gpu_blocks dst_blocks = cpu_blocks + src_spec = GPULoadStoreSpec(src_blocks, group_sizes=(len(src_blocks),)) + dst_spec = CPULoadStoreSpec(dst_blocks) src_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size dst_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size dst_size_in_kernel_blocks = num_cpu_blocks * kernel_blocks_per_cpu_block else: handler = handlers.cpu_to_gpu_handler - src_spec_class = CPULoadStoreSpec - dst_spec_class = GPULoadStoreSpec src_blocks = cpu_blocks dst_blocks = gpu_blocks + src_spec = CPULoadStoreSpec(src_blocks) + dst_spec = GPULoadStoreSpec(dst_blocks, group_sizes=(len(dst_blocks),)) src_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size dst_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size dst_size_in_kernel_blocks = num_gpu_blocks * kernel_blocks_per_gpu_block @@ -159,10 +159,6 @@ def test_transfer( ): dst_to_src[dst_block] = src_block - # build transfer specs - src_spec = src_spec_class(src_blocks) - dst_spec = dst_spec_class(dst_blocks) - # clone src and dst tensors before transfer orig_src_caches = [x.clone() for x in handler.src_tensors] orig_dst_caches = [x.clone() for x in handler.dst_tensors] diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py index 103675608c69..d3db828dc60e 100644 --- 
a/tests/v1/kv_offload/test_cpu_offloading.py +++ b/tests/v1/kv_offload/test_cpu_offloading.py @@ -22,6 +22,17 @@ elif current_platform.is_rocm(): ATTN_BACKENDS = ["TRITON_ATTN"] +# Maximum time (seconds) to wait for the async CPU offload transfer +# to complete before giving up. +_RESET_CACHE_TIMEOUT = 30 if current_platform.is_rocm() else 10 + +# ZMQ poll timeout (ms) for the first event. +_FIRST_EVENT_POLL_MS = 10_000 if current_platform.is_rocm() else 1000 + +# Hard ceiling (seconds) on how long get_new_cpu_stored_events may loop, +# to prevent hangs if non-CPU events keep arriving indefinitely. +_EVENT_DRAIN_TIMEOUT = 60 + class MockSubscriber: """Helper class to receive and verify published events""" @@ -47,9 +58,10 @@ def get_new_cpu_stored_events(self) -> list[BlockStored]: poller = zmq.Poller() poller.register(self.sub, zmq.POLLIN) - timeout = 1000 # 1 second - while True: - events = dict(poller.poll(timeout)) + poll_ms = _FIRST_EVENT_POLL_MS + deadline = time.monotonic() + _EVENT_DRAIN_TIMEOUT + while time.monotonic() < deadline: + events = dict(poller.poll(poll_ms)) if events.get(self.sub) != zmq.POLLIN: return cpu_stored_events @@ -63,13 +75,32 @@ def get_new_cpu_stored_events(self) -> list[BlockStored]: for event in event_batch.events: if isinstance(event, BlockStored) and event.medium == "CPU": cpu_stored_events.append(event) - timeout = 100 + poll_ms = 100 + + return cpu_stored_events def close(self): """Clean up resources""" self.sub.close() +def _wait_for_prefix_cache_reset(llm: LLM) -> None: + """Wait for async offload transfers to finish so prefix cache can reset. + + The GPU-to-CPU offload runs on a CUDA stream asynchronously. While blocks + are still held by the offload worker, ``reset_prefix_cache`` returns + ``False``. Retry with a short sleep until it succeeds or we time out. 
+ """ + deadline = time.monotonic() + _RESET_CACHE_TIMEOUT + while not llm.reset_prefix_cache(): + if time.monotonic() > deadline: + raise TimeoutError( + "reset_prefix_cache did not succeed within " + f"{_RESET_CACHE_TIMEOUT}s - async offload may be stuck" + ) + time.sleep(0.1) + + def _latency_test(llm: LLM, subscriber: MockSubscriber): sampling_params = SamplingParams(max_tokens=1) @@ -95,10 +126,16 @@ def _latency_test(llm: LLM, subscriber: MockSubscriber): gpu_hit_time = time.time() - start_time total_gpu_hit_time += gpu_hit_time - # reset prefix cache to avoid GPU hit. - llm.reset_prefix_cache() + # Wait for the async CPU offload to finish, then reset prefix cache + # so the next generate() must reload from CPU rather than GPU. + _wait_for_prefix_cache_reset(llm) - assert subscriber.get_new_cpu_stored_events() + # Verify CPU stored events arrived (offload is done before we + # attempt to load from CPU). + assert subscriber.get_new_cpu_stored_events(), ( + f"No CPU stored events received on iteration {i}; " + "async offload may not have completed in time" + ) # run generation again - this should trigger loading from CPU start_time = time.time() @@ -185,6 +222,8 @@ def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None: kv_events_config=kv_events_config, kv_transfer_config=kv_transfer_config, attention_config={"backend": attn_backend}, + # ROCm: batch size 1 to reduce variability + **({"max_num_seqs": 1} if current_platform.is_rocm() else {}), ) events_endpoint = events_endpoint.replace("*", "127.0.0.1") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a01f44e1649d..a45caac7c9e2 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2362,6 +2362,19 @@ def dsv3_router_gemm( return output +def gpt_oss_router_gemm( + hidden_states: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + output = torch.empty( + hidden_states.shape[0], + weight.shape[0], + device=hidden_states.device, + 
dtype=hidden_states.dtype, + ) + torch.ops._moe_C.gpt_oss_router_gemm(output, hidden_states, weight, bias) + return output + + def topk_softmax( topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index edd84403fea5..1e0a63dd6eb3 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -183,6 +183,68 @@ def get_random_lora_request( ) return lora_request + def get_round_robin_lora_request( + self, + index: int, + max_loras: int | None = None, + lora_path: str | None = None, + ) -> LoRARequest | None: + """ + Optionally select a LoRA request using deterministic round-robin. + + This method cycles through LoRA IDs in order based on the request + index, providing reproducible LoRA assignment. + + Args: + index (int): The request index used for round-robin selection. + max_loras (Optional[int]): The maximum number of LoRAs available. + If `None`, LoRA is not used. + lora_path (Optional[str]): Path to the LoRA parameters on disk. + If `None`, LoRA is not used. + + Returns: + A new [`LoRARequest`][vllm.lora.request.LoRARequest] + (or `None` if not applicable). + """ + if max_loras is None or lora_path is None: + return None + + # Deterministic round-robin: cycle through [1, max_loras] + lora_id = index % max_loras + 1 + lora_request = LoRARequest( + lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(lora_path), + ) + return lora_request + + def get_lora_request( + self, + index: int, + max_loras: int | None = None, + lora_path: str | None = None, + lora_assignment: str = "random", + ) -> LoRARequest | None: + """ + Select a LoRA request using the specified assignment strategy. + + Args: + index (int): The request index (used for round-robin). + max_loras (Optional[int]): The maximum number of LoRAs available. + lora_path (Optional[str]): Path to the LoRA parameters on disk. + lora_assignment (str): Strategy for LoRA selection. 
+ 'random' (default) or 'round-robin'. + + Returns: + A new [`LoRARequest`][vllm.lora.request.LoRARequest] + (or `None` if not applicable). + """ + if lora_assignment == "round-robin": + return self.get_round_robin_lora_request( + index=index, max_loras=max_loras, lora_path=lora_path + ) + return self.get_random_lora_request(max_loras=max_loras, lora_path=lora_path) + @abstractmethod def sample( self, @@ -478,6 +540,9 @@ def sample( input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, batchsize: int = 1, + max_loras: int | None = None, + lora_path: str | None = None, + lora_assignment: str = "random", **kwargs, ) -> list[SampleRequest]: # validate total input tokens (prefix + sampled) is at least 1. @@ -522,11 +587,18 @@ def sample( allowed_tokens=allowed_tokens, ) token_mismatch_total += token_mismatch + lora_req = self.get_lora_request( + index=i, + max_loras=max_loras, + lora_path=lora_path, + lora_assignment=lora_assignment, + ) requests.append( SampleRequest( prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), + lora_request=lora_req, request_id=request_id_prefix + str(i), ) ) @@ -1263,6 +1335,7 @@ def sample( enable_multimodal_chat: bool = False, request_id_prefix: str = "", no_oversample: bool = False, + lora_assignment: str = "random", **kwargs, ) -> list: samples: list = [] @@ -1275,8 +1348,11 @@ def sample( entry["conversations"][1]["value"], ) - lora_request = self.get_random_lora_request( - max_loras=max_loras, lora_path=lora_path + lora_request = self.get_lora_request( + index=ind, + max_loras=max_loras, + lora_path=lora_path, + lora_assignment=lora_assignment, ) prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids @@ -2413,6 +2489,7 @@ def sample( lora_path: str | None = None, request_id_prefix: str = "", no_oversample: bool = False, + lora_assignment: str = "random", **kwargs, ) -> list[SampleRequest]: samples = [] @@ -2420,8 +2497,11 @@ def sample( for i in 
range(num_requests): input_len = int(data[i][2]) output_len = int(data[i][3]) - lora_req = self.get_random_lora_request( - max_loras=max_loras, lora_path=lora_path + lora_req = self.get_lora_request( + index=i, + max_loras=max_loras, + lora_path=lora_path, + lora_assignment=lora_assignment, ) vocab_size = tokenizer.vocab_size # Generate a synthetic prompt: a list of token IDs computed as (i + diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index fca01e17ea17..53ae6ca6a804 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -624,6 +624,7 @@ async def benchmark( lora_modules: Iterable[str] | None, extra_headers: dict | None, extra_body: dict | None, + lora_assignment: Literal["random", "round-robin"] = "random", ramp_up_strategy: Literal["linear", "exponential"] | None = None, ramp_up_start_rps: int | None = None, ramp_up_end_rps: int | None = None, @@ -731,10 +732,20 @@ async def warmup_limited_request_func(): print("Starting main benchmark run...") if lora_modules: - # For each input request, choose a LoRA module at random. - lora_modules = iter( - [random.choice(lora_modules) for _ in range(len(input_requests))] - ) + lora_modules_list = list(lora_modules) + if lora_assignment == "round-robin": + # Deterministic round-robin assignment across requests. + lora_modules = iter( + [ + lora_modules_list[i % len(lora_modules_list)] + for i in range(len(input_requests)) + ] + ) + else: + # For each input request, choose a LoRA module at random. + lora_modules = iter( + [random.choice(lora_modules_list) for _ in range(len(input_requests))] + ) if profile: print("Starting profiler...") @@ -1523,7 +1534,18 @@ def add_cli_args(parser: argparse.ArgumentParser): default=None, help="A subset of LoRA module names passed in when " "launching the server. For each request, the " - "script chooses a LoRA module at random.", + "script chooses a LoRA module at random by default. 
" + "Use --lora-assignment to control selection strategy.", + ) + + parser.add_argument( + "--lora-assignment", + type=str, + default="random", + choices=["random", "round-robin"], + help="Strategy for assigning LoRA modules to requests. " + "'random' (default) selects a LoRA at random for each request. " + "'round-robin' cycles through LoRA modules deterministically.", ) parser.add_argument( @@ -1788,6 +1810,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, + lora_assignment=args.lora_assignment, extra_headers=headers, extra_body=extra_body, ramp_up_strategy=args.ramp_up_strategy, diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index ad6f44404613..1af8cf900b7a 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -350,6 +350,7 @@ def get_requests(args, tokenizer): "tokenizer": tokenizer, "lora_path": args.lora_path, "max_loras": args.max_loras, + "lora_assignment": getattr(args, "lora_assignment", "random"), "num_requests": args.num_prompts, } @@ -778,6 +779,15 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Path to the lora adapters to use. This can be an absolute path, " "a relative path, or a Hugging Face model identifier.", ) + parser.add_argument( + "--lora-assignment", + type=str, + default="random", + choices=["random", "round-robin"], + help="Strategy for assigning LoRA adapters to requests. " + "'random' (default) selects a LoRA at random for each request. 
" + "'round-robin' cycles through LoRAs deterministically.", + ) parser.add_argument( "--prefix-len", type=int, diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 51dff720b307..3526099dc7dc 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -371,13 +371,15 @@ def autograd_cache_key(*args, **kwargs): logger.info_once( "Cache the graph of compile range %s for later use", str(compile_range), + scope="local", ) - logger.debug( + logger.debug_once( "Store the %s-th graph for compile range%s from %s via handle %s", graph_index, str(compile_range), self.compiler.name, handle, + scope="local", ) # after compiling the last graph, record the end time diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index d4048a4731ef..add011ca40a9 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +import socket from collections.abc import Callable from typing import TYPE_CHECKING, Any, Literal, overload @@ -266,33 +267,9 @@ class is dynamically inherited by the worker class. This is used to inject Set to be private as it's not intended to be configured by users. """ - _stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list) - """List of open ports for stateless DP groups when enable_elastic_ep is True. - Set to be private as it's not intended to be configured by users. - It is a list of list[int], with each inner list contains a set of 3 ports - to be used for setting up the stateless CPU/device/TCPStore groups - in StatelessGroupCoordinator. The number of inner lists is equal to - the number of DP groups, - i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size, - and len(self._stateless_dp_group_port_list[i]) == 3 for all i. 
- """ - - _stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list) - """List of open ports for stateless EP groups when enable_elastic_ep is True. - Set to be private as it's not intended to be configured by users. - len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size, - """ - - _stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list) - """List of open ports for stateless EPLB groups when enable_elastic_ep is True. - Same topology as EP but separate NCCL communicator to avoid deadlocks. - """ - - _stateless_world_group_port_list: list[list[int]] = Field(default_factory=list) - """List of open ports for stateless world group when enable_elastic_ep is True. - Set to be private as it's not intended to be configured by users. - len(self._stateless_world_group_port_list) == 1, - """ + _coord_store_port: int = 0 + """Port of the coordination TCPStore. Can be set by the API server; workers + connect as clients to exchange self-picked group ports at runtime.""" decode_context_parallel_size: int = 1 """Number of decode context parallel groups, because the world size does @@ -465,65 +442,32 @@ def get_next_dp_init_port(self) -> int: return answer - def allocate_elastic_ep_ports(self) -> None: - """Allocate all ports for elastic EP (stateless groups + DP master). + def _pick_stateless_dp_port(self) -> tuple[int, socket.socket | None]: + """Return ``(port, listen_socket)`` for DP group init. - Must be called AFTER ray.init() so that ports claimed by Ray's - idle worker pool are already in use and won't be returned by - get_open_ports_list(). + With a coord store, rank 0 binds a socket and publishes the port; + others read it. Without one, pops a pre-allocated port and + returns ``listen_socket=None``. 
""" - if not self.enable_elastic_ep: - return - if self._stateless_world_group_port_list: - return - - num_world_groups = 1 - dp_size = self.data_parallel_size - ep_size = self.data_parallel_size * self.world_size_across_dp - num_dp_groups = max(1, self.world_size_across_dp // dp_size) - num_ep_groups = max(1, self.world_size_across_dp // ep_size) - num_eplb_groups = num_ep_groups - total_stateless_ports = ( - num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups - ) * 3 - num_dp_master_ports = 5 - - all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports) - - self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:] - self.data_parallel_master_port = self._data_parallel_master_port_list.pop() - all_ports = all_ports[:-num_dp_master_ports] - - self._stateless_world_group_port_list = [ - all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3) - ] - start_idx = num_world_groups * 3 - self._stateless_dp_group_port_list = [ - all_ports[i : i + 3] - for i in range(start_idx, start_idx + num_dp_groups * 3, 3) - ] - start_idx += num_dp_groups * 3 - self._stateless_ep_group_port_list = [ - all_ports[i : i + 3] - for i in range(start_idx, start_idx + num_ep_groups * 3, 3) - ] - start_idx += num_ep_groups * 3 - self._stateless_eplb_group_port_list = [ - all_ports[i : i + 3] - for i in range(start_idx, start_idx + num_eplb_groups * 3, 3) - ] - - def get_next_stateless_world_group_port(self) -> list[int]: - return self._stateless_world_group_port_list.pop() - - def get_next_stateless_dp_group_port(self) -> list[int]: - return self._stateless_dp_group_port_list.pop() - - def get_next_stateless_ep_group_port(self) -> list[int]: - return self._stateless_ep_group_port_list.pop() - - def get_next_stateless_eplb_group_port(self) -> list[int]: - return self._stateless_eplb_group_port_list.pop() + if not self._coord_store_port: + return self.get_next_dp_init_port(), None + + from vllm.distributed.utils import 
get_cached_tcp_store_client + + store = get_cached_tcp_store_client( + self.data_parallel_master_ip, self._coord_store_port + ) + + key = "dp_master_port" + if self.data_parallel_rank == 0: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind((self.data_parallel_master_ip, 0)) + s.listen() + port = s.getsockname()[1] + store.set(key, str(port).encode()) + return port, s + else: + return int(store.get(key).decode()), None @overload def stateless_init_dp_group( @@ -553,14 +497,16 @@ def stateless_init_dp_group( last_exc: Exception | None = None for _ in range(max_retries): try: + port, listen_socket = self._pick_stateless_dp_port() # use gloo since the engine process might not have cuda device return stateless_init_torch_distributed_process_group( self.data_parallel_master_ip, - self.get_next_dp_init_port(), + port, self.data_parallel_rank, self.data_parallel_size, backend="gloo", return_store=return_store, + listen_socket=listen_socket, ) except DistNetworkError as e: # We only want to retry when the root cause is EADDRINUSE. 
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 9f6284c4b389..584080ae12a0 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -228,9 +228,10 @@ def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None: self.encoder_cache_size = self.max_num_batched_tokens if self.enable_chunked_prefill: - logger.info( + logger.info_once( "Chunked prefill is enabled with max_num_batched_tokens=%d.", self.max_num_batched_tokens, + scope="local", ) if self.max_num_partial_prefills > 1: diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 8cd114481053..948335d6cd61 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -682,12 +682,11 @@ def __post_init__(self): self.model_config, self.load_config ) + from vllm.v1.executor.abstract import Executor + executor_backend = self.parallel_config.distributed_executor_backend - executor_supports_async_sched = executor_backend in ( - "mp", - "uni", - "external_launcher", - ) + executor_class = Executor.get_class(self) + executor_supports_async_sched = executor_class.supports_async_scheduling() if self.scheduler_config.async_scheduling: # Async scheduling explicitly enabled, hard fail any incompatibilities. @@ -711,9 +710,7 @@ def __post_init__(self): ) if not executor_supports_async_sched: raise ValueError( - "Currently, async scheduling only supports `mp`, `uni`, or " - "`external_launcher` distributed executor backend, but you chose " - f"`{executor_backend}`." + f"`{executor_backend}` does not support async scheduling yet." ) elif self.scheduler_config.async_scheduling is None: # Enable async scheduling unless there is an incompatible option. 
@@ -742,8 +739,7 @@ def __post_init__(self): elif not executor_supports_async_sched: logger.warning_once( "Async scheduling will be disabled because it is not supported " - "with the `%s` distributed executor backend (only `mp`, `uni`, and " - "`external_launcher` are supported).", + "with the `%s` distributed executor backend. ", executor_backend, scope="local", ) diff --git a/vllm/distributed/elastic_ep/elastic_execute.py b/vllm/distributed/elastic_ep/elastic_execute.py index 516d2c256726..00ac6d84b425 100644 --- a/vllm/distributed/elastic_ep/elastic_execute.py +++ b/vllm/distributed/elastic_ep/elastic_execute.py @@ -162,10 +162,8 @@ def create_standby_groups( new_dp_size=new_dp_size, new_world_size_across_dp=new_world_size_across_dp, master_ip=reconfig_request.new_data_parallel_master_ip, - world_group_ports=reconfig_request.new_stateless_world_group_port_list, - dp_group_ports=reconfig_request.new_stateless_dp_group_port_list, - ep_group_ports=reconfig_request.new_stateless_ep_group_port_list, - eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list, + coord_store_port=reconfig_request.coord_store_port, + enable_eplb=updated_config.parallel_config.enable_eplb, ) self.worker.model_runner.eep_eplb_suppressed = True standby_ep_group = get_standby_ep_group() diff --git a/vllm/distributed/elastic_ep/elastic_state.py b/vllm/distributed/elastic_ep/elastic_state.py index fce0d83611d9..cd989a49a2b8 100644 --- a/vllm/distributed/elastic_ep/elastic_state.py +++ b/vllm/distributed/elastic_ep/elastic_state.py @@ -563,15 +563,4 @@ def _update_parallel_config(self): parallel_config._data_parallel_master_port_list = ( reconfig_request.new_data_parallel_master_port_list ) - parallel_config._stateless_world_group_port_list = ( - reconfig_request.new_stateless_world_group_port_list - ) - parallel_config._stateless_dp_group_port_list = ( - reconfig_request.new_stateless_dp_group_port_list - ) - parallel_config._stateless_ep_group_port_list = ( - 
reconfig_request.new_stateless_ep_group_port_list - ) - parallel_config._stateless_eplb_group_port_list = ( - reconfig_request.new_stateless_eplb_group_port_list - ) + parallel_config._coord_store_port = reconfig_request.coord_store_port diff --git a/vllm/distributed/elastic_ep/standby_state.py b/vllm/distributed/elastic_ep/standby_state.py index d11e0b550531..846793a955f6 100644 --- a/vllm/distributed/elastic_ep/standby_state.py +++ b/vllm/distributed/elastic_ep/standby_state.py @@ -38,10 +38,8 @@ def create_standby_groups( new_dp_size: int, new_world_size_across_dp: int, master_ip: str, - world_group_ports: list[list[int]], - dp_group_ports: list[list[int]], - ep_group_ports: list[list[int]], - eplb_group_ports: list[list[int]] | None = None, + coord_store_port: int, + enable_eplb: bool = True, backend: str | None = None, ) -> None: global \ @@ -51,19 +49,23 @@ def create_standby_groups( _STANDBY_EP, \ _STANDBY_EPLB + from vllm.distributed.utils import get_cached_tcp_store_client + assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size world_group = get_world_group() assert isinstance(world_group, StatelessGroupCoordinator) backend = backend or world_group.backend + coord_store = get_cached_tcp_store_client(master_ip, coord_store_port) + standby_world_ranks = [list(range(new_world_size_across_dp))] _STANDBY_WORLD = _init_stateless_group( standby_world_ranks, "world", - world_group_ports, master_ip, backend, use_device_communicator=False, + coord_store=coord_store, ) _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group) @@ -76,7 +78,7 @@ def create_standby_groups( standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0) standby_dp_ranks = [x.tolist() for x in standby_dp_ranks] _STANDBY_DP = _init_stateless_group( - standby_dp_ranks, "dp", dp_group_ports, master_ip, backend + standby_dp_ranks, "dp", master_ip, backend, coord_store=coord_store ) standby_ep_ranks = ( @@ -84,12 +86,16 @@ def 
create_standby_groups( ) standby_ep_ranks = [x.tolist() for x in standby_ep_ranks] _STANDBY_EP = _init_stateless_group( - standby_ep_ranks, "ep", ep_group_ports, master_ip, backend + standby_ep_ranks, "ep", master_ip, backend, coord_store=coord_store ) - if eplb_group_ports is not None: + if enable_eplb: _STANDBY_EPLB = _init_stateless_group( - standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend + standby_ep_ranks, + "eplb", + master_ip, + backend, + coord_store=coord_store, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 2abbe6bf610a..ef143cba7fb5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -25,8 +25,8 @@ Worker-side: runs in each worker, loads/saves KV cache to/from the Connector based on the metadata. - handle_preemptions() - called if there are preempted requests, - before their blocks are overwritten + handle_preemptions() - called for handling preempted requests + or request evicted blocks before they are overwritten start_load_kv() - starts loading all KVs (maybe async) wait_for_layer_load() - blocks until layer i load is done @@ -288,9 +288,9 @@ def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): """ return - def handle_preemptions(self, preempted_req_ids: set[str]): + def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata): """ - Handle preempted requests BEFORE their blocks are overwritten. + Handle preempted requests or evicted blocks BEFORE they are overwritten. 
Needed for connectors which use async saves (e.g., OffloadingConnector) """ return diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 7cc80129a3a1..3888d2e0f44c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -315,10 +315,11 @@ def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): for c in self._connectors: c.set_host_xfer_buffer_ops(copy_operation) - def handle_preemptions(self, preempted_req_ids: set[str]): + def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata): """Handle preempted requests for all sub-connectors.""" - for c in self._connectors: - c.handle_preemptions(preempted_req_ids) + assert isinstance(kv_connector_metadata, MultiKVConnectorMetadata) + for c, cm in zip(self._connectors, kv_connector_metadata.metadata): + c.handle_preemptions(cm) def get_finished_count(self) -> int | None: # TODO(https://github.com/vllm-project/vllm/issues/33400) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 9001e31810ff..ed53c35c9ed9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -572,6 +572,10 @@ def __init__( for g in kv_cache_config.kv_cache_groups ) ) + self._has_mamba = any( + isinstance(g.kv_cache_spec, MambaSpec) + for g in kv_cache_config.kv_cache_groups + ) logger.info("Initializing NIXL Scheduler %s", engine_id) if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: @@ -717,6 +721,39 @@ def _nixl_handshake_listener( logger.warning("Connection listener got unexpected message %s", msg) sock.send_multipart((identity, b"", encoded_data[target_tp_rank])) + def _mamba_prefill_token_count(self, num_prompt_tokens: int) -> int: + """D-side 
only. Returns N-1 for Mamba models since the decoder + always recomputes the last token and must start from h(N-1).""" + if self._has_mamba and num_prompt_tokens > 1: + return num_prompt_tokens - 1 + return num_prompt_tokens + + def _truncate_mamba_request_for_prefill(self, request: "Request") -> None: + """P-side only: drop the last prompt token so the prefiller computes + h(N-1) instead of h(N). The decoder recomputes the last token to + derive h(N) correctly. + + Guarded by ``_p_side_truncated`` to avoid repeated truncation if the + request is preempted and rescheduled.""" + params = request.kv_transfer_params + if ( + params is not None + # Guard against repeated truncation after preemption/reschedule. + and not params.get("_p_side_truncated") + and request.num_prompt_tokens > 1 + ): + if request.prompt_token_ids is not None: + request.prompt_token_ids.pop() + elif request.prompt_embeds is not None: + request.prompt_embeds = request.prompt_embeds[:-1] + else: + return + + request._all_token_ids.pop() + request.num_prompt_tokens -= 1 + request.max_tokens = 1 + params["_p_side_truncated"] = True + def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int, bool]: @@ -746,10 +783,14 @@ def get_num_new_matched_tokens( if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. token_ids = request.prompt_token_ids or [] - count = len(token_ids) - num_computed_tokens + actual = self._mamba_prefill_token_count(len(token_ids)) + count = actual - num_computed_tokens if count > 0: return count, True + if params is not None and params.get("do_remote_decode") and self._has_mamba: + self._truncate_mamba_request_for_prefill(request) + # No remote prefill for this request. return 0, False @@ -1318,12 +1359,12 @@ def _nixl_handshake( f"Expected {expected_engine_id}," f"received {metadata.engine_id}." ) - setup_agent_time = time.perf_counter() # Register Remote agent. 
remote_agent_name = self.add_remote_agent( metadata, remote_rank, remote_tp_size ) + setup_agent_time = time.perf_counter() logger.debug( "NIXL handshake: add agent took: %s", setup_agent_time - got_metadata_time, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/common.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/common.py new file mode 100644 index 000000000000..06a727a27b55 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/common.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata +from vllm.v1.kv_offload.worker.worker import TransferSpec + +ReqId = str + + +@dataclass +class OffloadingConnectorMetadata(KVConnectorMetadata): + reqs_to_load: dict[ReqId, TransferSpec] + reqs_to_store: dict[ReqId, TransferSpec] + reqs_to_flush: set[str] | None = None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/metrics.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/metrics.py new file mode 100644 index 000000000000..0839b2727ccc --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/metrics.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Any + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorPromMetrics, + KVConnectorStats, + PromMetric, + PromMetricT, +) +from vllm.logger import init_logger +from vllm.v1.kv_offload.worker.worker import TransferType + +logger = 
init_logger(__name__) + + +@dataclass +class OffloadingOperationMetrics: + op_size: int + op_time: float + + +@dataclass +class OffloadingConnectorStats(KVConnectorStats): + def __post_init__(self): + if not self.data: + # Empty container init, no data is passed in. + self.reset() + + def reset(self): + self.data: dict[str, list[OffloadingOperationMetrics]] = {} + + def aggregate(self, other: KVConnectorStats) -> KVConnectorStats: + if not other.is_empty(): + for k, v in other.data.items(): + if k not in self.data: + self.data[k] = v + else: + accumulator = self.data[k] + assert isinstance(accumulator, list) + accumulator.extend(v) + return self + + def reduce(self) -> dict[str, int | float]: + """ + Reduce the observations collected during a time interval to one or + more representative values (eg avg/median/sum of the series). + This is meant to be called by the logger to produce a summary of the + stats for the last time interval. + """ + return_dict: dict[str, int | float] = {} + for transfer_type, ops_list in self.data.items(): + assert isinstance(ops_list, list) + total_bytes = 0 + total_time = 0.0 + for op in ops_list: + assert isinstance(op, dict) + total_bytes += op["op_size"] + total_time += op["op_time"] + return_dict[f"{transfer_type}_total_bytes"] = total_bytes + return_dict[f"{transfer_type}_total_time"] = total_time + return return_dict + + def is_empty(self) -> bool: + return not self.data + + def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType): + src, dst = transfer_type + transfer_type_key = src + "_to_" + dst + op = OffloadingOperationMetrics(num_bytes, time) + if transfer_type_key in self.data: + self.data[transfer_type_key].append(op) + else: + self.data[transfer_type_key] = [op] + + +class OffloadPromMetrics(KVConnectorPromMetrics): + def __init__( + self, + vllm_config: VllmConfig, + metric_types: dict[type[PromMetric], type[PromMetricT]], + labelnames: list[str], + per_engine_labelvalues: dict[int, 
list[object]], + ): + super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues) + # (engine_idx, transfer_type) -> (metric with bounded labels) + self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {} + self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {} + self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {} + buckets = [ # In bytes + 1e6, + 5e6, + 10e6, + 20e6, + 40e6, + 60e6, + 80e6, + 100e6, + 150e6, + 200e6, + ] + + self._counter_kv_bytes = self._counter_cls( + name="vllm:kv_offload_total_bytes", + documentation="Number of bytes offloaded by KV connector", + labelnames=labelnames + ["transfer_type"], + ) + + self._counter_kv_transfer_time = self._counter_cls( + name="vllm:kv_offload_total_time", + documentation="Total time measured by all KV offloading operations", + labelnames=labelnames + ["transfer_type"], + ) + + self._histogram_transfer_size = self._histogram_cls( + name="vllm:kv_offload_size", + documentation="Histogram of KV offload transfer size, in bytes.", + buckets=buckets[:], + labelnames=labelnames + ["transfer_type"], + ) + + def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): + """ + Observe transfer statistics from the new data structure. 
+ transfer_stats_data is expected to be a dict where: + - keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu") + - values are lists of OffloadingOperationMetrics objects + """ + + for transfer_type, ops in transfer_stats_data.items(): + # Cache: + if (engine_idx, transfer_type) not in self.histogram_transfer_size: + self.histogram_transfer_size[(engine_idx, transfer_type)] = ( + self._histogram_transfer_size.labels( + *(self.per_engine_labelvalues[engine_idx] + [transfer_type]) + ) + ) + self.counter_kv_bytes[(engine_idx, transfer_type)] = ( + self._counter_kv_bytes.labels( + *(self.per_engine_labelvalues[engine_idx] + [transfer_type]) + ) + ) + self.counter_kv_transfer_time[(engine_idx, transfer_type)] = ( + self._counter_kv_transfer_time.labels( + *(self.per_engine_labelvalues[engine_idx] + [transfer_type]) + ) + ) + + # Process ops: + assert isinstance(ops, list) + for op in ops: # ops is a list of serialized OffloadingOperationMetrics + assert isinstance(op, dict) + # Observe size histogram + self.histogram_transfer_size[(engine_idx, transfer_type)].observe( + op["op_size"] + ) + + # Increment byte and time counters + self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"]) + + self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc( + op["op_time"] + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py new file mode 100644 index 000000000000..c28fe5e96593 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py @@ -0,0 +1,353 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import defaultdict +from collections.abc import Iterable +from itertools import islice +from typing import Any + +from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent +from vllm.distributed.kv_transfer.kv_connector.utils 
import yield_req_data +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ( + OffloadingConnectorMetadata, + ReqId, +) +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_offload.abstract import OffloadingManager +from vllm.v1.kv_offload.mediums import GPULoadStoreSpec +from vllm.v1.kv_offload.spec import OffloadingSpec +from vllm.v1.kv_offload.worker.worker import TransferSpec +from vllm.v1.outputs import KVConnectorOutput +from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class OffloadingConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, spec: OffloadingSpec): + assert len(spec.gpu_block_size) == 1 + self.gpu_block_size = spec.gpu_block_size[0] + self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor + self.block_size_factor = spec.block_size_factor + self.manager: OffloadingManager = spec.get_manager() + + self._requests: dict[ReqId, Request] = {} + # list of GPU block IDs per request + self._request_block_ids: dict[ReqId, list[int]] = {} + # requests to load for the current scheduler step + self._reqs_to_load: dict[ReqId, TransferSpec] = {} + # request blocks are stored in order + # index of next block (of size offloaded_block_size) to offload + self._next_stored_block_idx: dict[ReqId, int] = {} + # if GPU prefix caching is enabled, + # track loaded blocks to avoid redundant loads + self._blocks_being_loaded: set[BlockHash] | None = ( + set() if spec.vllm_config.cache_config.enable_prefix_caching else None + ) + + # request ID -> set(block hashes being stored/load) + self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set) + self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set) + + def 
_get_block_hashes( + self, + req: Request, + start_idx: int = 0, + end_idx: int | None = None, + ) -> Iterable[BlockHash]: + return islice( + req.block_hashes, + self.block_size_factor * start_idx + self.block_size_factor - 1, + self.block_size_factor * end_idx if end_idx else None, + self.block_size_factor, + ) + + def get_num_new_matched_tokens( + self, request: Request, num_computed_tokens: int + ) -> tuple[int | None, bool]: + """ + Get number of new tokens that can be loaded beyond the + num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + A tuple with the following elements: + - The number of tokens that can be loaded beyond what is + already computed. + If None, it means that the connector needs more time to + determine the number of matched tokens, and the scheduler + should query for this request again later. + - `True` if tokens will be loaded asynchronously + (between scheduler steps). 
+ """ + num_blocks = request.num_tokens // self.offloaded_block_size + + assert len(request.block_hashes) // self.block_size_factor == num_blocks + block_hashes = self._get_block_hashes(request) + + self.manager.touch(block_hashes) + + full_block_tokens = self.offloaded_block_size * num_blocks + if full_block_tokens - num_computed_tokens < self.offloaded_block_size: + # we can load less than a block, skip + return 0, False + + start_block_idx = num_computed_tokens // self.offloaded_block_size + hits = self.manager.lookup( + self._get_block_hashes(request, start_idx=start_block_idx) + ) + if hits is None: + # indicates a lookup that should be tried later + return None, False + if hits == 0: + return 0, False + + num_hit_tokens = ( + self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens + ) + logger.debug( + "Request %s hit %s offloaded tokens after %s GPU hit tokens", + request.request_id, + num_hit_tokens, + num_computed_tokens, + ) + if num_hit_tokens < self.offloaded_block_size: + return 0, False + + if self._blocks_being_loaded: + block_hashes = self._get_block_hashes( + request, start_idx=start_block_idx, end_idx=start_block_idx + hits + ) + + if any( + block_hash in self._blocks_being_loaded for block_hash in block_hashes + ): + # hit blocks are being loaded, delay request + logger.debug( + "Delaying request %s since some of its blocks are already" + " being loaded", + request.request_id, + ) + return None, False + + return num_hit_tokens, True + + def update_state_after_alloc( + self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int + ): + self._requests[request.request_id] = request + # the block ids are updated in _get_reqs_to_store + self._request_block_ids[request.request_id] = [] + + if num_external_tokens == 0: + return + + block_groups = blocks.get_block_ids() + block_ids = block_groups[0] + + num_computed_gpu_blocks = sum( + block.block_hash is not None for block in blocks.blocks[0] + ) + num_computed_tokens = 
num_computed_gpu_blocks * self.gpu_block_size + full_block_tokens = num_computed_tokens + num_external_tokens + assert full_block_tokens % self.offloaded_block_size == 0 + + num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks + assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size + + start_block_idx = num_computed_tokens // self.offloaded_block_size + num_blocks = full_block_tokens // self.offloaded_block_size + + assert len(request.block_hashes) // self.block_size_factor >= num_blocks + block_hashes = self._get_block_hashes( + request, start_idx=start_block_idx, end_idx=num_blocks + ) + + src_spec = self.manager.prepare_load(block_hashes) + dst_spec = GPULoadStoreSpec( + block_ids[num_computed_gpu_blocks:], + group_sizes=(num_pending_gpu_blocks,), + block_indices=(num_computed_gpu_blocks,), + ) + + block_hashes = self._get_block_hashes( + request, start_idx=start_block_idx, end_idx=num_blocks + ) + + self._reqs_to_load[request.request_id] = (src_spec, dst_spec) + req_blocks_being_loaded = self._reqs_being_loaded[request.request_id] + req_blocks_being_loaded.update(block_hashes) + self._next_stored_block_idx[request.request_id] = num_blocks + + if self._blocks_being_loaded is not None: + self._blocks_being_loaded.update(req_blocks_being_loaded) + + def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): + reqs_to_store: dict[ReqId, TransferSpec] = {} + # iterate over both new and cached requests + for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output): + if preempted: + self._request_block_ids[req_id] = [] + + if new_block_id_groups: + new_block_ids = new_block_id_groups[0] + self._request_block_ids[req_id] += new_block_ids + + block_ids = self._request_block_ids[req_id] + + req = self._requests[req_id] + new_tokens = scheduler_output.num_scheduled_tokens[req_id] + expected_tokens = req.num_computed_tokens + new_tokens + # with async scheduling, some tokens may be missing + total_tokens = 
min(expected_tokens, req.num_tokens) + num_blocks = total_tokens // self.offloaded_block_size + start_block_idx = self._next_stored_block_idx.get(req_id, 0) + num_new_blocks = num_blocks - start_block_idx + + if num_new_blocks <= 0: + continue + + num_gpu_blocks = num_blocks * self.block_size_factor + assert len(req.block_hashes) >= num_gpu_blocks + + new_block_hashes = self._get_block_hashes( + req, start_idx=start_block_idx, end_idx=num_blocks + ) + store_output = self.manager.prepare_store(new_block_hashes) + if store_output is None: + logger.warning( + "Request %s: cannot store %s blocks", req_id, num_new_blocks + ) + continue + + self._next_stored_block_idx[req_id] = num_blocks + + if not store_output.block_hashes_to_store: + continue + block_hashes_to_store = set(store_output.block_hashes_to_store) + + block_hashes = self._get_block_hashes(req, end_idx=num_blocks) + self.manager.touch(block_hashes) + + new_block_hashes = self._get_block_hashes( + req, start_idx=start_block_idx, end_idx=num_blocks + ) + dst_spec = store_output.store_spec + src_block_ids: list[int] = [] + for idx, blk_hash in enumerate(new_block_hashes): + if blk_hash not in block_hashes_to_store: + continue + offloaded_block_idx = start_block_idx + idx + gpu_block_idx = offloaded_block_idx * self.block_size_factor + for i in range(self.block_size_factor): + src_block_ids.append(block_ids[gpu_block_idx + i]) + src_spec = GPULoadStoreSpec( + src_block_ids, group_sizes=(len(src_block_ids),) + ) + + reqs_to_store[req_id] = (src_spec, dst_spec) + self._reqs_being_stored[req_id] |= block_hashes_to_store + + logger.debug( + "Request %s offloading %s blocks starting from block #%d", + req_id, + len(block_hashes_to_store), + start_block_idx, + ) + + return reqs_to_store + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + meta = OffloadingConnectorMetadata( + reqs_to_load=self._reqs_to_load, + reqs_to_store=self._get_reqs_to_store(scheduler_output), + 
reqs_to_flush=scheduler_output.preempted_req_ids, + ) + self._reqs_to_load = {} + + # NOTE (orozery): we should move this logic to update_connector_output + # once KVConnectorOutput allows us to report completed transfers + for req_id in scheduler_output.preempted_req_ids or (): + block_hashes = self._reqs_being_stored.get(req_id) + if block_hashes: + self.manager.complete_store(block_hashes) + block_hashes.clear() + + return meta + + def update_connector_output(self, connector_output: KVConnectorOutput): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. + """ + for req_id in connector_output.finished_sending or []: + block_hashes = self._reqs_being_stored.pop(req_id, None) + if block_hashes: + self.manager.complete_store(block_hashes) + + for req_id in connector_output.finished_recving or []: + block_hashes = self._reqs_being_loaded.pop(req_id, None) + if block_hashes: + if self._blocks_being_loaded: + self._blocks_being_loaded.difference_update(block_hashes) + self.manager.complete_load(block_hashes) + + def request_finished( + self, + request: Request, + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. 
+ """ + req_id = request.request_id + self._requests.pop(req_id, None) + self._request_block_ids.pop(req_id, None) + + # TODO(orozery): possibly kickoff offload for last block + # which may have been deferred due to async scheduling + self._next_stored_block_idx.pop(req_id, None) + + request_being_stored = req_id in self._reqs_being_stored + return request_being_stored, None + + def take_events(self) -> Iterable[KVCacheEvent]: + """Take the KV cache events from the connector. + + Returns: + A list of KV cache events. + """ + for event in self.manager.take_events(): + if event.removed: + yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium) + else: + yield BlockStored( + block_hashes=event.block_hashes, + parent_block_hash=None, + token_ids=[], + lora_id=None, + block_size=event.block_size, + medium=event.medium, + lora_name=None, + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py new file mode 100644 index 000000000000..63f1d0133f3c --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import defaultdict + +import torch + +from vllm.config import get_layers_from_vllm_config +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorStats, +) +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ( + OffloadingConnectorMetadata, + ReqId, +) +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import ( + OffloadingConnectorStats, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.v1.attention.backend import AttentionBackend +from vllm.v1.kv_offload.spec import OffloadingSpec +from vllm.v1.kv_offload.worker.worker import ( + 
OffloadingWorker, + TransferSpec, +) + +logger = init_logger(__name__) + + +class OffloadingConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, spec: OffloadingSpec): + self.spec = spec + self.worker = OffloadingWorker() + + self._job_counter = 0 + + self.kv_connector_stats = OffloadingConnectorStats() + # req_id -> (job_id, store) + self._jobs: dict[int, tuple[ReqId, bool]] = {} + # req_id -> active job IDs + self._load_job: dict[ReqId, int] = {} + # req_id -> set(active job IDs) + self._store_jobs = defaultdict[ReqId, set[int]](set) + # list of store jobs pending submission (job_id, transfer_spec) + self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = [] + + self._finished_reqs_waiting_for_store: set[ReqId] = set() + + def _generate_job_id(self) -> int: + job_id = self._job_counter + self._job_counter = job_id + 1 + return job_id + + def _register_handlers( + self, + kv_caches: dict[str, torch.Tensor], + attn_backends: dict[str, type[AttentionBackend]], + ): + for src_cls, dst_cls, handler in self.spec.get_handlers( + kv_caches, attn_backends + ): + self.worker.register_handler(src_cls, dst_cls, handler) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + layer_names = list(kv_caches.keys()) + layers = get_layers_from_vllm_config( + self.spec.vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + layer_names, + ) + attn_backends = { + layer_name: layers[layer_name].get_attn_backend() + for layer_name in layer_names + } + self._register_handlers(kv_caches, attn_backends) + + def register_cross_layers_kv_cache( + self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] + ): + cross_layer_name = "ALL_LAYERS" + kv_caches = {cross_layer_name: kv_cache} + attn_backends = {cross_layer_name: attn_backend} + self._register_handlers(kv_caches, attn_backends) + + def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata): + for job_id, transfer_spec in 
self._unsubmitted_store_jobs: + success = self.worker.transfer_async(job_id, transfer_spec) + assert success + self._unsubmitted_store_jobs.clear() + + for req_id in kv_connector_metadata.reqs_to_flush or (): + job_ids = self._store_jobs.get(req_id) + if job_ids: + self.worker.wait(job_ids) + + def start_kv_transfers(self, metadata: OffloadingConnectorMetadata): + for job_id, transfer_spec in self._unsubmitted_store_jobs: + success = self.worker.transfer_async(job_id, transfer_spec) + assert success + self._unsubmitted_store_jobs.clear() + + for req_id, transfer_spec in metadata.reqs_to_load.items(): + job_id = self._generate_job_id() + self._jobs[job_id] = (req_id, False) + assert req_id not in self._load_job + self._load_job[req_id] = job_id + success = self.worker.transfer_async(job_id, transfer_spec) + assert success + + def prepare_store_kv(self, metadata: OffloadingConnectorMetadata): + for req_id, transfer_spec in metadata.reqs_to_store.items(): + job_id = self._generate_job_id() + self._jobs[job_id] = (req_id, True) + self._store_jobs[req_id].add(job_id) + # NOTE(orozery): defer the store to the beginning of the next engine step, + # so that offloading starts AFTER transfers related to token sampling, + # thereby avoiding delays to token generation due to offloading. + self._unsubmitted_store_jobs.append((job_id, transfer_spec)) + + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + Returns a list of request IDs that finished loading or storing. + + Returns: + ids of requests that have finished asynchronous transfer + tuple of (sending/saving ids, recving/loading ids). 
+ """ + finished_sending = set() + finished_recving = set() + for transfer_result in self.worker.get_finished(): + # we currently do not support job failures + job_id = transfer_result.job_id + assert transfer_result.success + req_id, store = self._jobs.pop(job_id) + if ( + transfer_result.transfer_time + and transfer_result.transfer_size is not None + and transfer_result.transfer_type is not None + ): + self.kv_connector_stats.record_transfer( + num_bytes=transfer_result.transfer_size, + time=transfer_result.transfer_time, + transfer_type=transfer_result.transfer_type, + ) + if store: + req_jobs = self._store_jobs[req_id] + req_jobs.remove(job_id) + if req_jobs: + continue + + if req_id in self._finished_reqs_waiting_for_store: + self._finished_reqs_waiting_for_store.remove(req_id) + finished_sending.add(req_id) + del self._store_jobs[req_id] + else: + req_job = self._load_job[req_id] + assert job_id == req_job + del self._load_job[req_id] + finished_recving.add(req_id) + + for req_id in finished_req_ids: + pending_req_jobs = self._store_jobs.get(req_id) + if pending_req_jobs: + self._finished_reqs_waiting_for_store.add(req_id) + elif pending_req_jobs is not None: + finished_sending.add(req_id) + del self._store_jobs[req_id] + + return finished_sending, finished_recving + + def get_kv_connector_stats(self) -> KVConnectorStats | None: + """ + Get the KV transfer stats for the connector. 
+ """ + + if self.kv_connector_stats.is_empty(): + return None + # Clear stats for next iteration + kv_connector_stats = self.kv_connector_stats + self.kv_connector_stats = OffloadingConnectorStats() + return kv_connector_stats diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 4c850fd2f8bd..547ee2578a12 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -1,16 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from collections.abc import Iterable -from dataclasses import dataclass -from itertools import islice from typing import Any import torch -from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent -from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data +from vllm.config import VllmConfig +from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.v1 import ( KVConnectorBase_V1, KVConnectorRole, @@ -22,96 +18,28 @@ PromMetric, PromMetricT, ) +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import ( + OffloadingConnectorMetadata, +) +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import ( + OffloadingConnectorStats, + OffloadPromMetrics, +) +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.scheduler import ( + OffloadingConnectorScheduler, +) +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.worker import ( + OffloadingConnectorWorker, +) from vllm.forward_context import ForwardContext -from vllm.logger import init_logger -from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.v1.attention.backend import 
AttentionBackend, AttentionMetadata from vllm.v1.core.kv_cache_manager import KVCacheBlocks -from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.kv_offload.abstract import OffloadingManager from vllm.v1.kv_offload.factory import OffloadingSpecFactory -from vllm.v1.kv_offload.mediums import GPULoadStoreSpec -from vllm.v1.kv_offload.spec import OffloadingSpec -from vllm.v1.kv_offload.worker.worker import ( - OffloadingWorker, - TransferSpec, - TransferType, -) from vllm.v1.outputs import KVConnectorOutput from vllm.v1.request import Request -ReqId = str - -logger = init_logger(__name__) - - -@dataclass -class OffloadingOperationMetrics: - op_size: int - op_time: float - - -@dataclass -class OffloadingConnectorStats(KVConnectorStats): - def __post_init__(self): - if not self.data: - # Empty container init, no data is passed in. - self.reset() - - def reset(self): - self.data: dict[str, list[OffloadingOperationMetrics]] = {} - - def aggregate(self, other: KVConnectorStats) -> KVConnectorStats: - if not other.is_empty(): - for k, v in other.data.items(): - if k not in self.data: - self.data[k] = v - else: - accumulator = self.data[k] - assert isinstance(accumulator, list) - accumulator.extend(v) - return self - - def reduce(self) -> dict[str, int | float]: - """ - Reduce the observations collected during a time interval to one or - more representative values (eg avg/median/sum of the series). - This is meant to be called by the logger to produce a summary of the - stats for the last time interval. 
- """ - return_dict: dict[str, int | float] = {} - for transfer_type, ops_list in self.data.items(): - assert isinstance(ops_list, list) - total_bytes = 0 - total_time = 0.0 - for op in ops_list: - assert isinstance(op, dict) - total_bytes += op["op_size"] - total_time += op["op_time"] - return_dict[f"{transfer_type}_total_bytes"] = total_bytes - return_dict[f"{transfer_type}_total_time"] = total_time - return return_dict - - def is_empty(self) -> bool: - return not self.data - - def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType): - src, dst = transfer_type - transfer_type_key = src + "_to_" + dst - op = OffloadingOperationMetrics(num_bytes, time) - if transfer_type_key in self.data: - self.data[transfer_type_key].append(op) - else: - self.data[transfer_type_key] = [op] - - -@dataclass -class OffloadingConnectorMetadata(KVConnectorMetadata): - reqs_to_load: dict[ReqId, TransferSpec] - reqs_to_store: dict[ReqId, TransferSpec] - class OffloadingConnector(KVConnectorBase_V1): @property @@ -146,9 +74,10 @@ def register_cross_layers_kv_cache( assert self.connector_worker is not None self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend) - def handle_preemptions(self, preempted_req_ids: set[str]): + def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata): assert self.connector_worker is not None - self.connector_worker.handle_preemptions(preempted_req_ids) + assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata) + self.connector_worker.handle_preemptions(kv_connector_metadata) def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None @@ -240,570 +169,3 @@ def build_prom_metrics( return OffloadPromMetrics( vllm_config, metric_types, labelnames, per_engine_labelvalues ) - - -class OffloadingConnectorScheduler: - """Implementation of Scheduler side methods""" - - def __init__(self, spec: OffloadingSpec): - assert 
len(spec.gpu_block_size) == 1 - self.gpu_block_size = spec.gpu_block_size[0] - self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor - self.block_size_factor = spec.block_size_factor - self.manager: OffloadingManager = spec.get_manager() - - self._requests: dict[ReqId, Request] = {} - # list of GPU block IDs per request - self._request_block_ids: dict[ReqId, list[int]] = {} - # requests to load for the current scheduler step - self._reqs_to_load: dict[ReqId, TransferSpec] = {} - # request blocks are stored in order - # index of next block (of size offloaded_block_size) to offload - self._next_stored_block_idx: dict[ReqId, int] = {} - # if GPU prefix caching is enabled, - # track loaded blocks to avoid redundant loads - self._blocks_being_loaded: set[BlockHash] | None = ( - set() if spec.vllm_config.cache_config.enable_prefix_caching else None - ) - - # request ID -> set(block hashes being stored/load) - self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set) - self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set) - - def _get_block_hashes( - self, - req: Request, - start_idx: int = 0, - end_idx: int | None = None, - ) -> Iterable[BlockHash]: - return islice( - req.block_hashes, - self.block_size_factor * start_idx + self.block_size_factor - 1, - self.block_size_factor * end_idx if end_idx else None, - self.block_size_factor, - ) - - def get_num_new_matched_tokens( - self, request: Request, num_computed_tokens: int - ) -> tuple[int | None, bool]: - """ - Get number of new tokens that can be loaded beyond the - num_computed_tokens. - - Args: - request (Request): the request object. - num_computed_tokens (int): the number of locally - computed tokens for this request - - Returns: - A tuple with the following elements: - - The number of tokens that can be loaded beyond what is - already computed. 
- If None, it means that the connector needs more time to - determine the number of matched tokens, and the scheduler - should query for this request again later. - - `True` if tokens will be loaded asynchronously - (between scheduler steps). - """ - num_blocks = request.num_tokens // self.offloaded_block_size - - assert len(request.block_hashes) // self.block_size_factor == num_blocks - block_hashes = self._get_block_hashes(request) - - self.manager.touch(block_hashes) - - full_block_tokens = self.offloaded_block_size * num_blocks - if full_block_tokens - num_computed_tokens < self.offloaded_block_size: - # we can load less than a block, skip - return 0, False - - start_block_idx = num_computed_tokens // self.offloaded_block_size - hits = self.manager.lookup( - self._get_block_hashes(request, start_idx=start_block_idx) - ) - if hits is None: - # indicates a lookup that should be tried later - return None, False - if hits == 0: - return 0, False - - num_hit_tokens = ( - self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens - ) - logger.debug( - "Request %s hit %s offloaded tokens after %s GPU hit tokens", - request.request_id, - num_hit_tokens, - num_computed_tokens, - ) - if num_hit_tokens < self.offloaded_block_size: - return 0, False - - if self._blocks_being_loaded: - block_hashes = self._get_block_hashes( - request, start_idx=start_block_idx, end_idx=start_block_idx + hits - ) - - if any( - block_hash in self._blocks_being_loaded for block_hash in block_hashes - ): - # hit blocks are being loaded, delay request - logger.debug( - "Delaying request %s since some of its blocks are already" - " being loaded", - request.request_id, - ) - return None, False - - return num_hit_tokens, True - - def update_state_after_alloc( - self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int - ): - self._requests[request.request_id] = request - # the block ids are updated in _get_reqs_to_store - self._request_block_ids[request.request_id] = 
[] - - if num_external_tokens == 0: - return - - block_groups = blocks.get_block_ids() - block_ids = block_groups[0] - - num_computed_gpu_blocks = sum( - block.block_hash is not None for block in blocks.blocks[0] - ) - num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size - full_block_tokens = num_computed_tokens + num_external_tokens - assert full_block_tokens % self.offloaded_block_size == 0 - - num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks - assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size - - start_block_idx = num_computed_tokens // self.offloaded_block_size - num_blocks = full_block_tokens // self.offloaded_block_size - - assert len(request.block_hashes) // self.block_size_factor >= num_blocks - block_hashes = self._get_block_hashes( - request, start_idx=start_block_idx, end_idx=num_blocks - ) - - src_spec = self.manager.prepare_load(block_hashes) - dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:]) - - block_hashes = self._get_block_hashes( - request, start_idx=start_block_idx, end_idx=num_blocks - ) - - self._reqs_to_load[request.request_id] = (src_spec, dst_spec) - req_blocks_being_loaded = self._reqs_being_loaded[request.request_id] - req_blocks_being_loaded.update(block_hashes) - self._next_stored_block_idx[request.request_id] = num_blocks - - if self._blocks_being_loaded is not None: - self._blocks_being_loaded.update(req_blocks_being_loaded) - - def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): - reqs_to_store: dict[ReqId, TransferSpec] = {} - # iterate over both new and cached requests - for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output): - if preempted: - self._request_block_ids[req_id] = [] - - if new_block_id_groups: - new_block_ids = new_block_id_groups[0] - self._request_block_ids[req_id] += new_block_ids - - block_ids = self._request_block_ids[req_id] - - req = self._requests[req_id] - new_tokens = 
scheduler_output.num_scheduled_tokens[req_id] - expected_tokens = req.num_computed_tokens + new_tokens - # with async scheduling, some tokens may be missing - total_tokens = min(expected_tokens, req.num_tokens) - num_blocks = total_tokens // self.offloaded_block_size - start_block_idx = self._next_stored_block_idx.get(req_id, 0) - num_new_blocks = num_blocks - start_block_idx - - if num_new_blocks <= 0: - continue - - num_gpu_blocks = num_blocks * self.block_size_factor - assert len(req.block_hashes) >= num_gpu_blocks - - new_block_hashes = self._get_block_hashes( - req, start_idx=start_block_idx, end_idx=num_blocks - ) - store_output = self.manager.prepare_store(new_block_hashes) - if store_output is None: - logger.warning( - "Request %s: cannot store %s blocks", req_id, num_new_blocks - ) - continue - - self._next_stored_block_idx[req_id] = num_blocks - - if not store_output.block_hashes_to_store: - continue - block_hashes_to_store = set(store_output.block_hashes_to_store) - - block_hashes = self._get_block_hashes(req, end_idx=num_blocks) - self.manager.touch(block_hashes) - - new_block_hashes = self._get_block_hashes( - req, start_idx=start_block_idx, end_idx=num_blocks - ) - dst_spec = store_output.store_spec - src_block_ids: list[int] = [] - for idx, blk_hash in enumerate(new_block_hashes): - if blk_hash not in block_hashes_to_store: - continue - offloaded_block_idx = start_block_idx + idx - gpu_block_idx = offloaded_block_idx * self.block_size_factor - for i in range(self.block_size_factor): - src_block_ids.append(block_ids[gpu_block_idx + i]) - src_spec = GPULoadStoreSpec(src_block_ids) - - reqs_to_store[req_id] = (src_spec, dst_spec) - self._reqs_being_stored[req_id] |= block_hashes_to_store - - logger.debug( - "Request %s offloading %s blocks starting from block #%d", - req_id, - len(block_hashes_to_store), - start_block_idx, - ) - - return reqs_to_store - - def build_connector_meta( - self, scheduler_output: SchedulerOutput - ) -> KVConnectorMetadata: - 
meta = OffloadingConnectorMetadata( - reqs_to_load=self._reqs_to_load, - reqs_to_store=self._get_reqs_to_store(scheduler_output), - ) - self._reqs_to_load = {} - - # NOTE (orozery): we should move this logic to update_connector_output - # once KVConnectorOutput allows us to report completed transfers - for req_id in scheduler_output.preempted_req_ids or (): - block_hashes = self._reqs_being_stored.get(req_id) - if block_hashes: - self.manager.complete_store(block_hashes) - block_hashes.clear() - - return meta - - def update_connector_output(self, connector_output: KVConnectorOutput): - """ - Update KVConnector state from worker-side connectors output. - - Args: - connector_output (KVConnectorOutput): the worker-side - connectors output. - """ - for req_id in connector_output.finished_sending or []: - block_hashes = self._reqs_being_stored.pop(req_id, None) - if block_hashes: - self.manager.complete_store(block_hashes) - - for req_id in connector_output.finished_recving or []: - block_hashes = self._reqs_being_loaded.pop(req_id, None) - if block_hashes: - if self._blocks_being_loaded: - self._blocks_being_loaded.difference_update(block_hashes) - self.manager.complete_load(block_hashes) - - def request_finished( - self, - request: Request, - block_ids: list[int], - ) -> tuple[bool, dict[str, Any] | None]: - """ - Called when a request has finished, before its blocks are freed. - - Returns: - True if the request is being saved/sent asynchronously and blocks - should not be freed until the request_id is returned from - get_finished(). - Optional KVTransferParams to be included in the request outputs - returned by the engine. 
- """ - req_id = request.request_id - self._requests.pop(req_id, None) - self._request_block_ids.pop(req_id, None) - - # TODO(orozery): possibly kickoff offload for last block - # which may have been deferred due to async scheduling - self._next_stored_block_idx.pop(req_id, None) - - request_being_stored = req_id in self._reqs_being_stored - return request_being_stored, None - - def take_events(self) -> Iterable[KVCacheEvent]: - """Take the KV cache events from the connector. - - Returns: - A list of KV cache events. - """ - for event in self.manager.take_events(): - if event.removed: - yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium) - else: - yield BlockStored( - block_hashes=event.block_hashes, - parent_block_hash=None, - token_ids=[], - lora_id=None, - block_size=event.block_size, - medium=event.medium, - lora_name=None, - ) - - -class OffloadingConnectorWorker: - """Implementation of Worker side methods""" - - def __init__(self, spec: OffloadingSpec): - self.spec = spec - self.worker = OffloadingWorker() - - self._job_counter = 0 - - self.kv_connector_stats = OffloadingConnectorStats() - # req_id -> (job_id, store) - self._jobs: dict[int, tuple[ReqId, bool]] = {} - # req_id -> active job IDs - self._load_job: dict[ReqId, int] = {} - # req_id -> set(active job IDs) - self._store_jobs = defaultdict[ReqId, set[int]](set) - # list of store jobs pending submission (job_id, transfer_spec) - self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = [] - - self._finished_reqs_waiting_for_store: set[ReqId] = set() - - def _generate_job_id(self) -> int: - job_id = self._job_counter - self._job_counter = job_id + 1 - return job_id - - def _register_handlers( - self, - kv_caches: dict[str, torch.Tensor], - attn_backends: dict[str, type[AttentionBackend]], - ): - for src_cls, dst_cls, handler in self.spec.get_handlers( - kv_caches, attn_backends - ): - self.worker.register_handler(src_cls, dst_cls, handler) - - def register_kv_caches(self, 
kv_caches: dict[str, torch.Tensor]): - layer_names = list(kv_caches.keys()) - layers = get_layers_from_vllm_config( - self.spec.vllm_config, - AttentionLayerBase, # type: ignore[type-abstract] - layer_names, - ) - attn_backends = { - layer_name: layers[layer_name].get_attn_backend() - for layer_name in layer_names - } - self._register_handlers(kv_caches, attn_backends) - - def register_cross_layers_kv_cache( - self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] - ): - cross_layer_name = "ALL_LAYERS" - kv_caches = {cross_layer_name: kv_cache} - attn_backends = {cross_layer_name: attn_backend} - self._register_handlers(kv_caches, attn_backends) - - def handle_preemptions(self, preempted_req_ids: set[str]): - for job_id, transfer_spec in self._unsubmitted_store_jobs: - success = self.worker.transfer_async(job_id, transfer_spec) - assert success - self._unsubmitted_store_jobs.clear() - - for req_id in preempted_req_ids: - job_ids = self._store_jobs.get(req_id) - if job_ids: - self.worker.wait(job_ids) - - def start_kv_transfers(self, metadata: OffloadingConnectorMetadata): - for job_id, transfer_spec in self._unsubmitted_store_jobs: - success = self.worker.transfer_async(job_id, transfer_spec) - assert success - self._unsubmitted_store_jobs.clear() - - for req_id, transfer_spec in metadata.reqs_to_load.items(): - job_id = self._generate_job_id() - self._jobs[job_id] = (req_id, False) - assert req_id not in self._load_job - self._load_job[req_id] = job_id - success = self.worker.transfer_async(job_id, transfer_spec) - assert success - - def prepare_store_kv(self, metadata: OffloadingConnectorMetadata): - for req_id, transfer_spec in metadata.reqs_to_store.items(): - job_id = self._generate_job_id() - self._jobs[job_id] = (req_id, True) - self._store_jobs[req_id].add(job_id) - # NOTE(orozery): defer the store to the beginning of the next engine step, - # so that offloading starts AFTER transfers related to token sampling, - # thereby avoiding delays to 
token generation due to offloading. - self._unsubmitted_store_jobs.append((job_id, transfer_spec)) - - def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: - """ - Notifies worker-side connector ids of requests that have - finished generating tokens. - Returns a list of request IDs that finished loading or storing. - - Returns: - ids of requests that have finished asynchronous transfer - tuple of (sending/saving ids, recving/loading ids). - """ - finished_sending = set() - finished_recving = set() - for transfer_result in self.worker.get_finished(): - # we currently do not support job failures - job_id = transfer_result.job_id - assert transfer_result.success - req_id, store = self._jobs.pop(job_id) - if ( - transfer_result.transfer_time - and transfer_result.transfer_size is not None - and transfer_result.transfer_type is not None - ): - self.kv_connector_stats.record_transfer( - num_bytes=transfer_result.transfer_size, - time=transfer_result.transfer_time, - transfer_type=transfer_result.transfer_type, - ) - if store: - req_jobs = self._store_jobs[req_id] - req_jobs.remove(job_id) - if req_jobs: - continue - - if req_id in self._finished_reqs_waiting_for_store: - self._finished_reqs_waiting_for_store.remove(req_id) - finished_sending.add(req_id) - del self._store_jobs[req_id] - else: - req_job = self._load_job[req_id] - assert job_id == req_job - del self._load_job[req_id] - finished_recving.add(req_id) - - for req_id in finished_req_ids: - pending_req_jobs = self._store_jobs.get(req_id) - if pending_req_jobs: - self._finished_reqs_waiting_for_store.add(req_id) - elif pending_req_jobs is not None: - finished_sending.add(req_id) - del self._store_jobs[req_id] - - return finished_sending, finished_recving - - def get_kv_connector_stats(self) -> KVConnectorStats | None: - """ - Get the KV transfer stats for the connector. 
- """ - - if self.kv_connector_stats.is_empty(): - return None - # Clear stats for next iteration - kv_connector_stats = self.kv_connector_stats - self.kv_connector_stats = OffloadingConnectorStats() - return kv_connector_stats - - -class OffloadPromMetrics(KVConnectorPromMetrics): - def __init__( - self, - vllm_config: VllmConfig, - metric_types: dict[type[PromMetric], type[PromMetricT]], - labelnames: list[str], - per_engine_labelvalues: dict[int, list[object]], - ): - super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues) - # (engine_idx, transfer_type) -> (metric with bounded labels) - self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {} - self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {} - self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {} - buckets = [ # In bytes - 1e6, - 5e6, - 10e6, - 20e6, - 40e6, - 60e6, - 80e6, - 100e6, - 150e6, - 200e6, - ] - - self._counter_kv_bytes = self._counter_cls( - name="vllm:kv_offload_total_bytes", - documentation="Number of bytes offloaded by KV connector", - labelnames=labelnames + ["transfer_type"], - ) - - self._counter_kv_transfer_time = self._counter_cls( - name="vllm:kv_offload_total_time", - documentation="Total time measured by all KV offloading operations", - labelnames=labelnames + ["transfer_type"], - ) - - self._histogram_transfer_size = self._histogram_cls( - name="vllm:kv_offload_size", - documentation="Histogram of KV offload transfer size, in bytes.", - buckets=buckets[:], - labelnames=labelnames + ["transfer_type"], - ) - - def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): - """ - Observe transfer statistics from the new data structure. 
- transfer_stats_data is expected to be a dict where: - - keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu") - - values are lists of OffloadingOperationMetrics objects - """ - - for transfer_type, ops in transfer_stats_data.items(): - # Cache: - if (engine_idx, transfer_type) not in self.histogram_transfer_size: - self.histogram_transfer_size[(engine_idx, transfer_type)] = ( - self._histogram_transfer_size.labels( - *(self.per_engine_labelvalues[engine_idx] + [transfer_type]) - ) - ) - self.counter_kv_bytes[(engine_idx, transfer_type)] = ( - self._counter_kv_bytes.labels( - *(self.per_engine_labelvalues[engine_idx] + [transfer_type]) - ) - ) - self.counter_kv_transfer_time[(engine_idx, transfer_type)] = ( - self._counter_kv_transfer_time.labels( - *(self.per_engine_labelvalues[engine_idx] + [transfer_type]) - ) - ) - - # Process ops: - assert isinstance(ops, list) - for op in ops: # ops is a list of serialized OffloadingOperationMetrics - assert isinstance(op, dict) - # Observe size histogram - self.histogram_transfer_size[(engine_idx, transfer_type)].observe( - op["op_size"] - ) - - # Increment byte and time counters - self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"]) - - self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc( - op["op_time"] - ) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index af1bc6b14b59..04187b34ec7a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -40,13 +40,16 @@ import torch.distributed import torch.distributed._functional_collectives as funcol import torch.distributed._symmetric_memory -from torch.distributed import Backend, ProcessGroup +from torch.distributed import Backend, ProcessGroup, Store import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( DeviceCommunicatorBase, ) -from vllm.distributed.utils import StatelessProcessGroup +from vllm.distributed.utils import 
( + StatelessProcessGroup, + get_cached_tcp_store_client, +) from vllm.logger import init_logger from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.network_utils import get_distributed_init_method @@ -1164,9 +1167,9 @@ def init_model_parallel_group( def _init_stateless_group( group_ranks: list[list[int]], group_name: str, - group_ports: list[list[int]], host: str, backend: str, + coord_store: Store, use_device_communicator: bool = True, ) -> "StatelessGroupCoordinator": """Create a StatelessGroupCoordinator with the given parameters.""" @@ -1180,7 +1183,7 @@ def _init_stateless_group( use_device_communicator=use_device_communicator, group_name=group_name, host=host, - group_ports=group_ports, + coord_store=coord_store, global_rank=world.rank, global_world_size=world.world_size, ) @@ -1321,7 +1324,9 @@ def _init_elastic_ep_world( group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)] if global_rank in all_ranks: group_ranks = [all_ranks] - group_ports = [parallel_config.get_next_stateless_world_group_port()] + coord_store = get_cached_tcp_store_client( + parallel_config.data_parallel_master_ip, parallel_config._coord_store_port + ) world = StatelessGroupCoordinator( group_ranks=group_ranks, local_rank=local_rank, @@ -1329,7 +1334,7 @@ def _init_elastic_ep_world( use_device_communicator=False, group_name="world", host=parallel_config.data_parallel_master_ip, - group_ports=group_ports, + coord_store=coord_store, global_rank=global_rank, global_world_size=global_world_size, ) @@ -1513,7 +1518,13 @@ def initialize_model_parallel( config = get_current_vllm_config() data_parallel_size = config.parallel_config.data_parallel_size enable_elastic_ep = config.parallel_config.enable_elastic_ep + parallel_config = config.parallel_config + coord_store: Store | None = None if enable_elastic_ep: + coord_store = get_cached_tcp_store_client( + parallel_config.data_parallel_master_ip, + parallel_config._coord_store_port, + ) # Use stateless 
world group for global information world_size = get_world_group().world_size rank = get_world_group().rank @@ -1633,16 +1644,12 @@ def initialize_model_parallel( group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] if enable_elastic_ep: - parallel_config = config.parallel_config - dp_ports = [ - parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks - ] _DP = _init_stateless_group( group_ranks, "dp", - dp_ports, parallel_config.data_parallel_master_ip, backend, + coord_store=coord_store, ) else: _DP = init_model_parallel_group( @@ -1665,16 +1672,12 @@ def initialize_model_parallel( ) group_ranks = [x.tolist() for x in group_ranks] if enable_elastic_ep: - parallel_config = config.parallel_config - ep_ports = [ - parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks - ] _EP = _init_stateless_group( group_ranks, "ep", - ep_ports, parallel_config.data_parallel_master_ip, backend, + coord_store=coord_store, ) else: _EP = init_model_parallel_group( @@ -1693,16 +1696,12 @@ def initialize_model_parallel( and config.parallel_config.enable_eplb ): if enable_elastic_ep: - eplb_ports = [ - parallel_config.get_next_stateless_eplb_group_port() - for _ in group_ranks - ] _EPLB = _init_stateless_group( group_ranks, "eplb", - eplb_ports, parallel_config.data_parallel_master_ip, backend, + coord_store=coord_store, ) else: _EPLB = init_model_parallel_group( diff --git a/vllm/distributed/stateless_coordinator.py b/vllm/distributed/stateless_coordinator.py index f2126fdbaa32..549284df32df 100644 --- a/vllm/distributed/stateless_coordinator.py +++ b/vllm/distributed/stateless_coordinator.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import socket +import struct from typing import Any, Optional import torch -from torch.distributed import Backend, ProcessGroup +from torch.distributed import 
Backend, ProcessGroup, Store from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator from vllm.distributed.parallel_state import ( @@ -23,6 +25,38 @@ logger = init_logger(__name__) +_PORTS_FMT = "!3I" + + +def _allocate_group_ports( + key: str, + host: str, + coord_store: Store, +) -> tuple[list[int], list[socket.socket]]: + """Bind 3 sockets and publish the ports to *coord_store*. + + Called by rank 0 only. Returns ``(ports, sockets)`` with the + sockets still open. + """ + socks: list[socket.socket] = [] + ports: list[int] = [] + for _ in range(3): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind((host, 0)) + s.listen() + socks.append(s) + ports.append(s.getsockname()[1]) + coord_store.set(key, struct.pack(_PORTS_FMT, *ports)) + return ports, socks + + +def _fetch_group_ports(key: str, coord_store: Store) -> list[int]: + """Read 3 ports published by rank 0 from *coord_store*. + + Blocks until the key is available. + """ + return list(struct.unpack(_PORTS_FMT, coord_store.get(key))) + class StatelessGroupCoordinator(GroupCoordinator): """ @@ -39,10 +73,10 @@ def __init__( local_rank: int, torch_distributed_backend: str | Backend, use_device_communicator: bool, + coord_store: Store, use_message_queue_broadcaster: bool = False, group_name: str | None = None, host: str = "127.0.0.1", - group_ports: list[list[int]] | None = None, global_rank: int = 0, global_world_size: int = 1, ): @@ -61,17 +95,23 @@ def __init__( backend = str(torch_distributed_backend) self.backend = backend - assert group_ports is not None, "group_ports is not provided" for idx, ranks in enumerate(group_ranks): if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) self.rank_in_group = ranks.index(self.rank) - ports = group_ports[idx] - device_port = ports[0] - cpu_port = ports[1] - tcp_store_port = ports[2] + key = f"{group_name}_{idx}" + if self.rank_in_group == 0: + ports, socks = _allocate_group_ports( + key, + host, + coord_store, + 
) + else: + ports = _fetch_group_ports(key, coord_store) + socks = [] + device_port, cpu_port, tcp_store_port = ports device_group = stateless_init_torch_distributed_process_group( host=host, @@ -80,6 +120,7 @@ def __init__( world_size=self.world_size, backend=backend, group_name=f"{self.unique_name}_device", + listen_socket=socks[0] if socks else None, ) cpu_group = stateless_init_torch_distributed_process_group( host=host, @@ -88,12 +129,14 @@ def __init__( world_size=self.world_size, backend="gloo", group_name=f"{self.unique_name}_cpu", + listen_socket=socks[1] if socks else None, ) tcp_store_group = StatelessProcessGroup.create( host=host, port=tcp_store_port, rank=self.rank_in_group, world_size=self.world_size, + listen_socket=socks[2] if socks else None, ) self_device_group = device_group diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 102f2f727b75..9991ab1ddc23 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -6,6 +6,7 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
import dataclasses +import functools import os import pickle import socket @@ -139,6 +140,29 @@ def get_pp_indices( return (start_layer, end_layer) +def create_tcp_store( + host: str, + port: int, + listen_socket: socket.socket | None = None, + **kwargs: Any, +) -> TCPStore: + """Create a TCPStore, optionally taking ownership of ``listen_socket``.""" + if listen_socket is None: + return TCPStore(host_name=host, port=port, **kwargs) + + listen_fd = listen_socket.detach() + try: + return TCPStore( + host_name=host, + port=port, + master_listen_fd=listen_fd, + **kwargs, + ) + except Exception: + socket.close(listen_fd) + raise + + @dataclasses.dataclass class StatelessProcessGroup: """A dataclass to hold a metadata store, and the rank, world_size of the @@ -150,9 +174,6 @@ class StatelessProcessGroup: world_size: int store: torch._C._distributed_c10d.Store - # stores a reference to the socket so that the file descriptor stays alive - socket: socket.socket | None - data_expiration_seconds: int = 3600 # 1 hour # dst rank -> counter @@ -419,6 +440,7 @@ def create( world_size: int, data_expiration_seconds: int = 3600, store_timeout: int = 300, + listen_socket: socket.socket | None = None, ) -> "StatelessProcessGroup": """A replacement for `torch.distributed.init_process_group` that does not pollute the global state. @@ -436,36 +458,39 @@ def create( C, and D can call `StatelessProcessGroup.create` to form another group. 
""" # noqa launch_server = rank == 0 - if launch_server: - # listen on the specified interface (instead of 0.0.0.0) + if launch_server and listen_socket is None: listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) listen_socket.bind((host, port)) listen_socket.listen() - listen_fd = listen_socket.fileno() - else: - listen_socket = None - listen_fd = None - - store = TCPStore( - host_name=host, - port=port, + store = create_tcp_store( + host, + port, + listen_socket=listen_socket, world_size=world_size, is_master=launch_server, timeout=timedelta(seconds=store_timeout), use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 - master_listen_fd=listen_fd, ) return StatelessProcessGroup( rank=rank, world_size=world_size, store=store, - socket=listen_socket, data_expiration_seconds=data_expiration_seconds, ) +@functools.lru_cache(maxsize=1) +def get_cached_tcp_store_client(host: str, port: int) -> TCPStore: + """Return a cached TCPStore client. + + Cached so that every call with the same ``(host, port)`` reuses the + same connection. A new ``(host, port)`` evicts the old entry. + """ + return TCPStore(host, port, is_master=False, wait_for_workers=False) + + def init_gloo_process_group( prefix_store: PrefixStore, group_rank: int, @@ -504,6 +529,7 @@ def stateless_init_torch_distributed_process_group( backend: str, group_name: str | None = None, return_store: bool = False, + listen_socket: socket.socket | None = None, ) -> ProcessGroup | tuple[ProcessGroup, Store]: """ A replacement for `torch.distributed.init_process_group` that does not @@ -535,14 +561,30 @@ def stateless_init_torch_distributed_process_group( are the same as process 1 and 5, the main communication channel is always formed with process 1, 2, ..., 8, and the additional communication channel is formed with process 9 and 10. 
+ + When *listen_socket* is provided, the rendezvous step + is skipped and a ``TCPStore`` server is created directly using the + pre-bound socket. This is useful for eliminating TOCTOU races + between port allocation and binding. """ init_method = get_tcp_uri(host, port) backend = Backend(backend) # it is basically string timeout = _get_default_timeout(backend) - store, rank, world_size = next( - rendezvous(init_method, rank, world_size, timeout=timeout) - ) + if listen_socket is not None: + store = create_tcp_store( + host, + port, + listen_socket=listen_socket, + world_size=world_size, + is_master=True, + timeout=timeout, + multi_tenant=True, + ) + else: + store, rank, world_size = next( + rendezvous(init_method, rank, world_size, timeout=timeout) + ) store.set_timeout(timeout) group_rank = rank diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d0bdd4916144..730641a184fc 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -108,6 +108,7 @@ from vllm.utils.torch_utils import resolve_kv_cache_dtype_string from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.sample.logits_processor import LogitsProcessor +from vllm.version import __version__ as VLLM_VERSION if TYPE_CHECKING: from vllm.model_executor.layers.quantization import QuantizationMethods @@ -243,6 +244,14 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]: ) +def _maybe_add_docs_url(cls: Any) -> str: + """Generate API docs URL for a vllm config class.""" + if not cls.__module__.startswith("vllm.config"): + return "" + version = f"v{VLLM_VERSION}" if "dev" not in VLLM_VERSION else "latest" + return f"\n\nAPI docs: https://docs.vllm.ai/en/{version}/api/vllm/config/#vllm.config.{cls.__name__}" + + @functools.lru_cache(maxsize=30) def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]: # Save time only getting attr docs if we're generating help text @@ -293,6 +302,7 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any: 
raise argparse.ArgumentTypeError(repr(e)) from e kwargs[name]["type"] = parse_dataclass + kwargs[name]["help"] += _maybe_add_docs_url(dataclass_cls) kwargs[name]["help"] += f"\n\n{json_tip}" elif contains_type(type_hints, bool): # Creates --no- and -- flags diff --git a/vllm/entrypoints/openai/responses/protocol.py b/vllm/entrypoints/openai/responses/protocol.py index 2adcd9eaa09c..a5f62bdd8c39 100644 --- a/vllm/entrypoints/openai/responses/protocol.py +++ b/vllm/entrypoints/openai/responses/protocol.py @@ -27,6 +27,7 @@ ResponseReasoningTextDeltaEvent, ResponseReasoningTextDoneEvent, ResponseStatus, + ResponseTextConfig, ResponseWebSearchCallCompletedEvent, ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallSearchingEvent, @@ -38,20 +39,13 @@ from openai.types.responses import ( ResponseInProgressEvent as OpenAIResponseInProgressEvent, ) -from openai.types.responses.tool import Tool -from openai_harmony import Message as OpenAIHarmonyMessage - -# Backward compatibility for OpenAI client versions -try: # For older openai versions (< 1.100.0) - from openai.types.responses import ResponseTextConfig -except ImportError: # For newer openai versions (>= 1.100.0) - from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig - from openai.types.responses.response import IncompleteDetails, ToolChoice from openai.types.responses.response_reasoning_item import ( Content as ResponseReasoningTextContent, ) +from openai.types.responses.tool import Tool from openai.types.shared import Metadata, Reasoning +from openai_harmony import Message as OpenAIHarmonyMessage from pydantic import ( Field, ValidationError, diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index dd42a6a56600..b2428e97e20d 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -1012,6 +1012,7 @@ def _make_response_output_items( parser = self.parser(tokenizer) return 
parser.extract_response_outputs( model_output=final_output.text, + model_output_token_ids=final_output.token_ids, request=request, enable_auto_tools=self.enable_auto_tools, tool_call_id_type=self.tool_call_id_type, diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py index 1f3fdea2cdaf..235f40b73852 100644 --- a/vllm/lora/layers/__init__.py +++ b/vllm/lora/layers/__init__.py @@ -13,6 +13,7 @@ QKVParallelLinearWithShardedLoRA, ) from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA +from vllm.lora.layers.gate_linear import GateLinearWithLoRA from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA from vllm.lora.layers.row_parallel_linear import ( @@ -38,6 +39,7 @@ "RowParallelLinearWithLoRA", "RowParallelLinearWithShardedLoRA", "ReplicatedLinearWithLoRA", + "GateLinearWithLoRA", "LoRAMapping", "LoRAMappingType", "FusedMoEWithLoRA", diff --git a/vllm/lora/layers/gate_linear.py b/vllm/lora/layers/gate_linear.py new file mode 100644 index 000000000000..9bcaaa5b8e20 --- /dev/null +++ b/vllm/lora/layers/gate_linear.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config.lora import LoRAConfig +from vllm.model_executor.custom_op import maybe_get_oot_by_class +from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear + +from .replicated_linear import ReplicatedLinearWithLoRA + + +class GateLinearWithLoRA(ReplicatedLinearWithLoRA): + def __init__(self, base_layer: GateLinear) -> None: + super().__init__( + base_layer, + ) + + # GateLinearWithLoRA should always be replaced, regardless of the fully + # sharded LoRAs setting, because it is, by definition, copied per GPU. 
+ @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: PretrainedConfig | None = None, + ) -> bool: + return type(source_layer) is maybe_get_oot_by_class(GateLinear) diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py index 12d6f719a5c7..a84c399c3fd0 100644 --- a/vllm/lora/model_manager.py +++ b/vllm/lora/model_manager.py @@ -161,9 +161,9 @@ def _maybe_init_mm( device=self.device, lora_config=self.lora_config, ) + lm_prefix = self.mm_mapping.language_model[0] self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper - if self.lora_config.enable_tower_connector_lora: self.supports_tower_connector_lora = self.supports_mm and hasattr( self.model, "get_num_mm_encoder_tokens" @@ -171,6 +171,18 @@ def _maybe_init_mm( if not self.supports_tower_connector_lora: return + if ( + vllm_config.model_config.multimodal_config + and vllm_config.model_config.multimodal_config.language_model_only + ): + if self.supports_tower_connector_lora: + logger.warning( + "Disabling `enable_tower_connector_lora` because the multimodal " + "model is configured to initialize the language model only." + ) + self.supports_tower_connector_lora = False + return + logger.warning( "LoRA for the tower and connector of multimodal models is " "experimental and may contain bugs. 
Please report any related issues on " diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_fp8_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_fp8_op.py index 015d434165d4..deb34cfe435c 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_fp8_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_fp8_op.py @@ -10,11 +10,10 @@ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) +from vllm.lora.ops.triton_ops.utils import supports_pdl from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op -from .utils import supports_pdl - @triton.jit def _get_lora_id( diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 2349ace70846..75ed9674af56 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -21,6 +21,7 @@ ColumnParallelLinearWithShardedLoRA, FusedMoE3DWithLoRA, FusedMoEWithLoRA, + GateLinearWithLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearVariableSliceWithLoRA, MergedColumnParallelLinearWithLoRA, @@ -81,6 +82,7 @@ def get_lora_id(): MergedQKVParallelLinearWithLoRA, RowParallelLinearWithLoRA, ReplicatedLinearWithLoRA, + GateLinearWithLoRA, LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index 46d461c38b3f..6755e9af9e65 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -227,7 +227,9 @@ def __init__( if self.attn_backend == AttentionBackendEnum.FLASHINFER: _get_flashinfer_workspace_buffer() - logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") + logger.info_once( + f"Using {self.attn_backend} for MMEncoderAttention.", scope="local" + ) @classmethod def enabled(cls) -> bool: diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..689e553e1c2f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, 
+ "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 88cd173fe6a8..f6a303e7988e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -101,6 +101,11 @@ def topk_indices_dtype(self) -> torch.dtype | None: return self.moe_kernel.prepare_finalize.topk_indices_dtype() return None + @property + def skip_forward_padding(self) -> bool: + """Whether to skip the padding in the forward before applying the moe method.""" + return False + @property def supports_eplb(self) -> bool: return False diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index 77d8e756026d..e8ed8a5249d1 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ 
b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -3,9 +3,11 @@ import torch from torch.nn.parameter import Parameter +import vllm._custom_ops as ops from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op @PluggableLayer.register("gate_linear") @@ -13,8 +15,9 @@ class GateLinear(ReplicatedLinear): """MoE gate linear layer with three-tier GEMM dispatch: 1. DSV3 specialized kernel (SM90+, batch<=16, supported dims) - 2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) - 3. F.linear via ReplicatedLinear (ultimate fallback) + 2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims) + 3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) + 4. F.linear via ReplicatedLinear (ultimate fallback) The ``out_dtype`` attribute is mutable and can be set after init (e.g. when the required dtype depends on the expert quantization @@ -25,6 +28,10 @@ class GateLinear(ReplicatedLinear): DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] DSV3_SUPPORTED_HIDDEN_SIZES = [7168] + # Dimensions supported by the gpt-oss specialized kernel + GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128] + GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880] + def __init__( self, input_size: int, @@ -65,6 +72,15 @@ def __init__( and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES ) + # gpt-oss specialized kernel eligibility (SM90+, exact dims) + self.allow_gpt_oss_router_gemm = ( + self.weight.dtype == torch.bfloat16 + and current_platform.is_cuda() + and is_hopper_or_blackwell + and output_size in self.GPT_OSS_SUPPORTED_NUM_EXPERTS + and input_size in self.GPT_OSS_SUPPORTED_HIDDEN_SIZES + ) + # cuBLAS bf16→fp32 eligibility self.allow_cublas_router_gemm = ( self.allow_specialized_router_gemm @@ -92,8 +108,6 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None: def forward( self, x: torch.Tensor ) -> torch.Tensor | tuple[torch.Tensor, 
Parameter | None]: - import vllm._custom_ops as ops - # Tier 1: DSV3 specialized kernel if self.allow_dsv3_router_gemm and x.shape[0] <= 16: output = ops.dsv3_router_gemm( @@ -103,15 +117,47 @@ def forward( ) return output, None - # Tier 2: cuBLAS bf16→fp32 + # Tier 2: gpt-oss specialized kernel + if self.allow_gpt_oss_router_gemm: + output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias) + return output, None + + # Tier 3: cuBLAS bf16→fp32 if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16: output = ops.router_gemm_bf16_fp32(x, self.weight) return output, None - # Tier 3: F.linear (ReplicatedLinear) + # Tier 4: F.linear (ReplicatedLinear) if self.out_dtype is not None and x.dtype != self.weight.dtype: x = x.to(self.weight.dtype) output, output_bias = super().forward(x) if self.out_dtype is not None and output.dtype != self.out_dtype: output = output.to(self.out_dtype) return output, output_bias + + +def gpt_oss_router_gemm_impl( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + """ + Dynamically run min-latency gemm if num_tokens <= 128. + This must be wrapped in a custom op because our torch.compile integration + does not support runtime dispatching on num_tokens. 
+ """ + if x.shape[0] <= 128: + return ops.gpt_oss_router_gemm(x, weight, bias) + else: + return torch.nn.functional.linear(x, weight, bias) + + +def gpt_oss_router_gemm_fake( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return x.new_empty((x.shape[0], weight.shape[0])) + + +direct_register_custom_op( + op_name="gpt_oss_router_gemm", + op_func=gpt_oss_router_gemm_impl, + fake_impl=gpt_oss_router_gemm_fake, +) diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py index b6313776e85d..12b560493fa2 100644 --- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py @@ -415,7 +415,10 @@ def forward( # This is the dimension after transform (for routed expert output slicing) transformed_hidden_dim = hidden_states.shape[-1] - if self.moe_config.hidden_dim != transformed_hidden_dim: + if ( + not self.quant_method.skip_forward_padding + and self.moe_config.hidden_dim != transformed_hidden_dim + ): hidden_states = F.pad( hidden_states, (0, self.moe_config.hidden_dim - transformed_hidden_dim), diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 1ad024a6fd48..f992d0f86c4e 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -294,6 +294,12 @@ def __init__(self, moe: FusedMoEConfig): # Initialized in process_weights_after_loading for CUTLASS/SM90 backends self.moe_kernel: mk.FusedMoEKernel | None = None + @property + def skip_forward_padding(self) -> bool: + # SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant + # so can skip the padding in the forward before applying the moe method + return self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + def create_weights( self, layer: torch.nn.Module, @@ -1130,9 +1136,17 @@ 
def apply_monolithic( elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM: from flashinfer import mxfp8_quantize - x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 + # x_quant is padded in hidden dimension with alignment=256 + x_quant, x_scale = mxfp8_quantize( + x, + is_sf_swizzled_layout=False, + alignment=256, + ) x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1) + # output with original unpadded hidden size + output = torch.empty_like(x) + trtllm_gen_output = trtllm_fp4_block_scale_moe( routing_logits=router_logits.to(torch.bfloat16), routing_bias=None, @@ -1161,6 +1175,7 @@ def apply_monolithic( routing_method_type=1 if layer.renormalize else 0, do_finalize=True, tune_max_num_tokens=max(self.max_capture_size, 1), + output=output, )[0] return trtllm_gen_output elif self.mxfp4_backend == Mxfp4Backend.CK: diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index a8d81024421d..5c9c97f4b64a 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -320,6 +320,13 @@ def _init_ep_weight_filter(self, model_config: ModelConfig) -> None: ): return + # When EPLB is enabled, redundant physical expert slots may map to + # logical experts that belong to other ranks in the default partition. + # The weight loader needs to see ALL logical expert weights so it can + # populate these redundant slots. Skip the filter entirely. 
+ if parallel_config.enable_eplb: + return + num_experts = model_config.get_num_experts() if num_experts <= 0: return diff --git a/vllm/model_executor/model_loader/ep_weight_filter.py b/vllm/model_executor/model_loader/ep_weight_filter.py index 1ef7f0174511..190842379253 100644 --- a/vllm/model_executor/model_loader/ep_weight_filter.py +++ b/vllm/model_executor/model_loader/ep_weight_filter.py @@ -73,4 +73,9 @@ def should_skip_weight( if eid is None: # Not an expert weight (dense / shared-expert / embedding) → keep. return False + # Only skip heavy weight tensors, never scale/metadata tensors. + # Scale tensors are tiny and some backends need them from ALL experts + # (e.g. FlashInfer NVFP4 computes a global max of activation scales). + if not weight_name.endswith(".weight"): + return False return eid not in local_expert_ids diff --git a/vllm/model_executor/models/eagle2_5_vl.py b/vllm/model_executor/models/eagle2_5_vl.py index 3e6182db586c..30b8173f19cf 100644 --- a/vllm/model_executor/models/eagle2_5_vl.py +++ b/vllm/model_executor/models/eagle2_5_vl.py @@ -16,7 +16,10 @@ from vllm.model_executor.models.siglip import SiglipVisionModel from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.processors.eagle2_5_vl import Eagle2_5_VLProcessor +from vllm.transformers_utils.processors.internvl import ( + InternVLImageProcessor, + InternVLProcessor, +) from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( @@ -68,12 +71,35 @@ class Eagle2_5_VLImageEmbeddingInputs(TensorSchema): class Eagle2_5_VLProcessingInfo(BaseInternVLProcessingInfo): """Processing info for Eagle2.5-VL model.""" - def get_hf_processor(self, **kwargs) -> Eagle2_5_VLProcessor: - return self.ctx.init_processor( - Eagle2_5_VLProcessor, - config=self.ctx.get_hf_config(), + def get_image_processor(self, **kwargs): + config = self.get_hf_config() + vision_config = config.vision_config + + kwargs = 
self.ctx.get_merged_mm_kwargs(kwargs) + kwargs.setdefault( + "image_size", config.force_image_size or vision_config.image_size + ) + kwargs.setdefault("min_dynamic_patch", config.min_dynamic_patch) + kwargs.setdefault("max_dynamic_patch", config.max_dynamic_patch) + kwargs.setdefault("dynamic_image_size", config.dynamic_image_size) + kwargs.setdefault("use_thumbnail", config.use_thumbnail) + + return InternVLImageProcessor(**kwargs) + + def get_hf_processor(self, **kwargs) -> InternVLProcessor: + config = self.get_hf_config() + vision_config = config.vision_config + + image_processor = self.get_image_processor(**kwargs) + image_size = image_processor.image_size + patch_size = vision_config.patch_size + downsample_ratio = config.downsample_ratio + image_seq_length = int((image_size // patch_size) ** 2 * (downsample_ratio**2)) + + return InternVLProcessor( tokenizer=self.get_tokenizer(), - **kwargs, + image_processor=image_processor, + image_seq_length=image_seq_length, ) diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 4434d10369e9..83af8ea86cd9 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -395,13 +395,13 @@ def get_image_processor(self, **kwargs): vision_config = config.vision_config image_size = vision_config["image_size"] + kwargs = self.ctx.get_merged_mm_kwargs(kwargs) kwargs.setdefault("size", {"width": image_size, "height": image_size}) return GLM4VImageProcessorFast(**kwargs) def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor: - return self.ctx.init_processor( - GLM4VProcessor, + return GLM4VProcessor( tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor(**kwargs), ) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index c3111489c0ca..482056250a1e 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -20,12 +20,11 @@ tensor_model_parallel_all_gather, ) 
from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -175,13 +174,11 @@ def __init__( self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 - self.router = ReplicatedLinear( + self.router = GateLinear( config.hidden_size, config.num_local_experts, bias=True, - quant_config=None, prefix=f"{prefix}.router", - return_bias=False, ) assert config.intermediate_size % self.world_size == 0 self.experts = FusedMoE( @@ -209,7 +206,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self, x[:, : self.hidden_size], self.router.weight, self.router.bias ) else: - g = self.router(x) + g, _ = self.router(x) x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size] if self.is_sequence_parallel: @@ -273,7 +270,6 @@ def __init__( self.config = vllm_config.model_config.hf_config self.quant_config = vllm_config.quant_config self.parallel_config = vllm_config.parallel_config - self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( self.config.vocab_size, self.config.hidden_size, diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 3b01985c4458..e684280fef36 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -28,7 +28,7 @@ PromptUpdate, TimingContext, ) -from vllm.transformers_utils.processors.h2ovl import H2OVLProcessor +from vllm.transformers_utils.processors.h2ovl import H2OVLImageProcessor, 
H2OVLProcessor from .intern_vit import InternVisionModel from .internvl import ( @@ -40,12 +40,34 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): + def get_image_processor(self, **kwargs): + config = self.get_hf_config() + vision_config = config.vision_config + + kwargs = self.ctx.get_merged_mm_kwargs(kwargs) + kwargs.setdefault("image_size", vision_config.image_size) + kwargs.setdefault("min_dynamic_patch", config.min_dynamic_patch) + kwargs.setdefault("max_dynamic_patch", config.max_dynamic_patch) + kwargs.setdefault("dynamic_image_size", config.dynamic_image_size) + kwargs.setdefault("use_thumbnail", config.use_thumbnail) + kwargs.setdefault("use_msac", config.use_msac) + + return H2OVLImageProcessor(**kwargs) + def get_hf_processor(self, **kwargs: object) -> H2OVLProcessor: - return self.ctx.init_processor( - H2OVLProcessor, - config=self.get_hf_config(), + config = self.get_hf_config() + vision_config = config.vision_config + + image_processor = self.get_image_processor(**kwargs) + image_size = image_processor.image_size + patch_size = vision_config.patch_size + downsample_ratio = config.downsample_ratio + image_seq_length = int((image_size // patch_size) ** 2 * (downsample_ratio**2)) + + return H2OVLProcessor( tokenizer=self.get_tokenizer(), - **kwargs, + image_processor=image_processor, + image_seq_length=image_seq_length, ) def get_num_image_tokens( @@ -106,7 +128,7 @@ def get_replacement_internvl(item_idx: int): if num_patches is not None: assert isinstance(num_patches, int) - return hf_processor.get_image_repl(feature_size, num_patches) + return hf_processor.get_image_repl(num_patches, num_features=feature_size) return [ PromptReplacement( @@ -163,3 +185,17 @@ def _init_vision_model( else: msg = "Monolith mode is not applicable to H2OVL" raise NotImplementedError(msg) + + def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int: + if num_image_tokens <= 0 or self.num_image_token <= 0: + return 0 + + num_patches = num_image_tokens // 
self.num_image_token + return num_patches * (self.patch_tokens + 1) + + def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int: + if num_vision_tokens <= 0 or self.num_image_token <= 0: + return 0 + + num_patches = num_vision_tokens // (self.patch_tokens + 1) + return num_patches * self.num_image_token diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 8126391b269e..3c33da212f1d 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -9,6 +9,7 @@ # -------------------------------------------------------- from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property from typing import Annotated, Literal, TypeAlias, TypeVar import torch @@ -45,8 +46,9 @@ ) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processors.internvl import ( - BaseInternVLProcessor, + InternVLImageProcessor, InternVLProcessor, + InternVLVideoProcessor, ) from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -123,7 +125,7 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): """Basic image-only ProcessingInfo for InternVL-style models.""" @abstractmethod - def get_hf_processor(self, **kwargs: object) -> BaseInternVLProcessor: + def get_hf_processor(self, **kwargs: object) -> InternVLProcessor: raise NotImplementedError def get_supported_mm_limits(self) -> Mapping[str, int | None]: @@ -134,7 +136,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: BaseInternVLProcessor, + processor: InternVLProcessor, ) -> int: return processor.get_num_image_tokens( image_width=image_width, @@ -143,8 +145,9 @@ def get_num_image_tokens( def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() + image_processor = processor.image_processor - base_size = processor.image_size + base_size = image_processor.image_size target_ratios = 
processor.resolve_target_ratios() largest_feature_size, largest_feature_pinpoint = 0, None @@ -226,7 +229,7 @@ def _call_hf_processor( ) hf_processor = self.info.get_hf_processor(**mm_kwargs) - image_token_id = hf_processor.image_token_id + image_token_id = hf_processor.ctx_image_token_id # Since there may be extra tokens in the feature placeholders, # we need to pass the image token ID to the model to select the @@ -291,7 +294,7 @@ def get_replacement_internvl(item_idx: int): if num_patches is not None: assert isinstance(num_patches, int) - return hf_processor.get_image_repl(feature_size, num_patches) + return hf_processor.get_image_repl(num_patches, num_features=feature_size) return [ PromptReplacement( @@ -305,23 +308,73 @@ def get_replacement_internvl(item_idx: int): class InternVLProcessingInfo(BaseInternVLProcessingInfo): """InternVL ProcessingInfo extended for video processing""" - @property - def supports_video(self): - return self.get_hf_processor().supports_video + def get_image_processor(self, **kwargs): + config = self.get_hf_config() + vision_config = config.vision_config - def get_supported_mm_limits(self): - video_limit = {"video": None} if self.supports_video else {} - return {**super().get_supported_mm_limits(), **video_limit} + kwargs = self.ctx.get_merged_mm_kwargs(kwargs) + kwargs.setdefault("image_size", vision_config.image_size) + kwargs.setdefault("min_dynamic_patch", config.min_dynamic_patch) + kwargs.setdefault("max_dynamic_patch", config.max_dynamic_patch) + kwargs.setdefault("dynamic_image_size", config.dynamic_image_size) + kwargs.setdefault("use_thumbnail", config.use_thumbnail) + + return InternVLImageProcessor(**kwargs) + + def get_video_processor(self, **kwargs): + config = self.get_hf_config() + vision_config = config.vision_config - def get_video_token(self) -> str | None: + kwargs = self.ctx.get_merged_mm_kwargs(kwargs) + kwargs.setdefault("image_size", vision_config.image_size) + + return InternVLVideoProcessor(**kwargs) + + 
@cached_property + def ctx_video_token(self): text_model_type = self.get_hf_config().get_text_config().model_type - video_token_map = { + ctx_video_token_map = { "qwen2": "<|video_pad|>", "qwen3": "<|video_pad|>", "qwen3_moe": "<|video_pad|>", "gpt_oss": "<|reserved_200000|>", } - return video_token_map.get(text_model_type) + + if text_model_type not in ctx_video_token_map: + return None + + ctx_video_token = ctx_video_token_map[text_model_type] + if ctx_video_token not in self.get_tokenizer().get_vocab(): + return None + + return ctx_video_token + + def get_hf_processor(self, **kwargs: object) -> InternVLProcessor: + config = self.get_hf_config() + vision_config = config.vision_config + + image_processor = self.get_image_processor(**kwargs) + image_size = image_processor.image_size + patch_size = vision_config.patch_size + downsample_ratio = config.downsample_ratio + image_seq_length = int((image_size // patch_size) ** 2 * (downsample_ratio**2)) + + ctx_video_token = self.ctx_video_token + video_processor = ( + self.get_video_processor(**kwargs) if ctx_video_token else None + ) + + return InternVLProcessor( + tokenizer=self.get_tokenizer(), + image_processor=image_processor, + video_processor=video_processor, + image_seq_length=image_seq_length, + ctx_video_token=ctx_video_token, + ) + + def get_supported_mm_limits(self): + video_limit = {"video": None} if self.ctx_video_token else {} + return {**super().get_supported_mm_limits(), **video_limit} def get_num_frames_with_most_features( self, @@ -332,22 +385,14 @@ def get_num_frames_with_most_features( max_videos = mm_counts.get("video", 0) processor = self.get_hf_processor() + num_image_token = processor.image_seq_length max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token + max_total_frames = (seq_len - max_image_tokens) // num_image_token max_frames_per_video = max_total_frames // max(max_videos, 1) return 
max(max_frames_per_video, 1) - def get_hf_processor(self, **kwargs: object) -> InternVLProcessor: - return self.ctx.init_processor( - InternVLProcessor, - config=self.get_hf_config(), - tokenizer=self.get_tokenizer(), - video_token=self.get_video_token(), - **kwargs, - ) - class InternVLDummyInputsBuilder( BaseInternVLDummyInputsBuilder[InternVLProcessingInfo] @@ -366,7 +411,7 @@ def get_dummy_mm_data( mm_options: Mapping[str, BaseDummyOptions], ) -> MultiModalDataDict: dummy_image = super().get_dummy_mm_data(seq_len, mm_counts, mm_options) - if self.info.supports_video: + if self.info.ctx_video_token: config = self.info.get_hf_config() image_size: int = config.vision_config.image_size target_num_frames = self.info.get_num_frames_with_most_features( @@ -405,11 +450,9 @@ def _call_hf_processor( ) hf_processor = self.info.get_hf_processor(**mm_kwargs) - if ( - self.info.supports_video - and (video_token_id := hf_processor.video_token_id) is not None - ): + if (video_token_id := hf_processor.ctx_video_token_id) is not None: processed_outputs["video_token_id"] = torch.tensor(video_token_id) + return processed_outputs def _get_mm_fields_config( @@ -418,7 +461,7 @@ def _get_mm_fields_config( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs) - if self.info.supports_video: + if self.info.ctx_video_token: video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0)) num_videos = len(video_num_patches) video_fields = dict( @@ -444,6 +487,8 @@ def _get_prompt_updates( hf_processor_mm_kwargs=hf_processor_mm_kwargs, out_mm_kwargs=out_mm_kwargs, ) + if self.info.ctx_video_token is None: + return prompt_repl hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -456,26 +501,20 @@ def _get_prompt_updates( video_num_patches = [] def get_video_replacement_internvl(item_idx: int): - feature_size = hf_processor.num_image_token num_patches = 
video_num_patches[item_idx] if num_patches is not None: assert isinstance(num_patches, int) - return hf_processor.get_video_repl( - feature_size, num_patches, video_context_token=hf_processor.video_token - ) - - if self.info.supports_video: - prompt_repl = [ - *prompt_repl, - PromptReplacement( - modality="video", - target="