diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 195579f206a2..b83503754dcb 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -14,7 +14,7 @@ permissions: jobs: pre-commit: - runs-on: ubuntu-latest + runs-on: self-hosted steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 @@ -26,3 +26,5 @@ jobs: - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 with: extra_args: --all-files --hook-stage manual + env: + SKIP: shellcheck diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index ef45a5fbebf6..4eb4b464a216 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4 + GIT_TAG b99f8c821771fd11feb66d5c89661e9858fde359 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/examples/online_serving/openai_response_api_gpt_oss.py b/examples/online_serving/openai_response_api_gpt_oss.py new file mode 100644 index 000000000000..55bbf66d8980 --- /dev/null +++ b/examples/online_serving/openai_response_api_gpt_oss.py @@ -0,0 +1,604 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +vllm serve /data/woosuk/os-mini-weights/pytorch-rc-20b --enforce-eager +""" + +import argparse +import json +import time + +import requests +from openai import BadRequestError, NotFoundError, OpenAI + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, required=False, choices=["gpt-4.1", "o4-mini"]) +parser.add_argument("--port", type=int, required=False, default=8000) +args = parser.parse_args() + +MODEL = args.model +if MODEL is None: + openai_api_key = "EMPTY" + openai_api_base = f"http://localhost:{args.port}/v1" +else: + openai_api_key = None + openai_api_base = None + +client = OpenAI(api_key=openai_api_key, base_url=openai_api_base) + + +def test_basic(): + response = client.responses.create( + model=MODEL, + input="What is 13 * 24?", + # max_output_tokens=10, + ) + print(response) + + +def test_basic_with_instructions(): + response = client.responses.create( + model=MODEL, + input="What is 13 * 24?", + instructions="Respond in Korean.", + ) + print(response) + + +def test_basic_with_reasoning_effort(): + response = client.responses.create( + model=MODEL, + input="What is the capital of South Korea?", + reasoning={"effort": "low"}, + ) + print(response) + + +def test_chat(): + response = client.responses.create( + model=MODEL, + input=[ + {"role": "system", "content": "Respond in Korean."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hello! How can I help you today?"}, + {"role": "user", "content": "What is 13 * 24? 
Explain your answer."}, + ], + ) + print(response) + + +def test_chat_with_input_type(): + response = client.responses.create( + model=MODEL, + input=[ + { + "role": "user", + "content": [{"type": "input_text", "text": "What is 13*24?"}], + }, + ], + ) + print(response) + + +def test_structured_output(): + response = client.responses.create( + model=MODEL, + input=[ + {"role": "system", "content": "Extract the event information."}, + { + "role": "user", + "content": "Alice and Bob are going to a science fair on Friday.", + }, + ], + text={ + "format": { + "type": "json_schema", + "name": "calendar_event", + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "date": {"type": "string"}, + "participants": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["name", "date", "participants"], + "additionalProperties": False, + }, + "description": "A calendar event.", + "strict": True, + } + }, + ) + print(response) + + +def test_structured_output_with_parse(): + from pydantic import BaseModel + + class CalendarEvent(BaseModel): + name: str + date: str + participants: list[str] + + response = client.responses.parse( + model=MODEL, + input="Alice and Bob are going to a science fair on Friday", + instructions="Extract the event information", + text_format=CalendarEvent, + ) + print(response) + + +def test_store(): + for store in [True, False]: + response = client.responses.create( + model=MODEL, + input="What is 13 * 24?", + store=store, + ) + print(response) + + try: + response = client.responses.retrieve(response.id) + print(response) + except NotFoundError: + is_not_found = True + else: + is_not_found = False + assert is_not_found == (not store) + + +def test_background(): + response = client.responses.create( + model=MODEL, + input="What is 13 * 24?", + background=True, + ) + print(response) + + while True: + response = client.responses.retrieve(response.id) + if response.status == "completed": + break + time.sleep(1) + print(response) + + +def test_background_cancel(): + response = client.responses.create( + model=MODEL, + input="Write a long story about a cat.", + background=True, + ) + print(response) + time.sleep(1) + response = client.responses.cancel(response.id) + print(response) + + +def test_stateful_multi_turn(): + response1 = client.responses.create( + model=MODEL, + input="What is 13 * 24?", + ) + print(response1) + + response2 = client.responses.create( + model=MODEL, + input="What if I increase both numbers by 1?", + previous_response_id=response1.id, + ) + print(response2) + + response3 = client.responses.create( + model=MODEL, + input="Divide the result by 2.", + previous_response_id=response2.id, + ) + print(response3) + + +def test_streaming(): + promts = [ + "tell me a story about a cat in 20 words", + "What is 13 * 24? Use python to calculate the result.", + "When did Jensen found NVIDIA? 
Search it and answer the year only.", + ] + for prompt in promts: + print(f"\n{prompt}\n") + response = client.responses.create( + model=MODEL, + input=prompt, + reasoning={"effort": "low"}, + tools=[ + {"type": "web_search_preview"}, + {"type": "code_interpreter", "container": {"type": "auto"}}, + ], + stream=True, + ) + + events = [] + current_event_mode = None + + for event in response: + if current_event_mode != event.type: + current_event_mode = event.type + print(f"\n[{event.type}] ", end="", flush=True) + + if "text.delta" in event.type: + print(event.delta, end="", flush=True) + elif "reasoning_text.delta" in event.type: + print(f"{event.delta}", end="", flush=True) + elif "response.code_interpreter_call_code.done" in event.type: + print(f"Code: {event.code}", end="", flush=True) + elif ( + "response.output_item.added" in event.type + and event.item.type == "web_search_call" + ): + print(f"Web search: {event.item.action}", end="", flush=True) + events.append(event) + + print("\n--------------------------------\n") + + +def test_web_search(): + response = client.responses.create( + model=MODEL, + input="Who is the president of South Korea as of now?", + tools=[{"type": "web_search_preview"}], + ) + print(response) + + +def test_code_interpreter(): + response = client.responses.create( + model=MODEL, + input="Multiply 643258029438.6132 * 23516705917230.84279 using Python.", + tools=[{"type": "code_interpreter", "container": {"type": "auto"}}], + ) + print(response) + + +def get_weather(latitude, longitude): + response = requests.get( + f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}¤t=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m" # noqa + ) + data = response.json() + return data["current"]["temperature_2m"] + + +def get_place_to_travel(): + return "Paris" + + +def call_function(name, args): + if name == "get_weather": + return get_weather(**args) + elif name == "get_place_to_travel": + return get_place_to_travel() + else: + raise ValueError(f"Unknown function: {name}") + + +def test_function_calling(): + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, + }, + "required": ["latitude", "longitude"], + "additionalProperties": False, + }, + "strict": True, + } + ] + + response = client.responses.create( + model=MODEL, + input="What's the weather like in Paris today?", + tools=tools, + ) + print("The first response:") + print(response) + print("output:") + for out in response.output: + print(out) + print("--------------------------------") + + assert len(response.output) == 2 + assert response.output[0].type == "reasoning" + assert response.output[1].type == "function_call" + + tool_call = response.output[1] + + name = tool_call.name + args = json.loads(tool_call.arguments) + + result = call_function(name, args) + print("tool call result: ", result, type(result)) + + response_2 = client.responses.create( + model=MODEL, + input=[ + { + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + } + ], + tools=tools, + previous_response_id=response.id, + ) + print("The second response:") + print(response_2) + print("output:") + for out in response_2.output: + print(out) + print("--------------------------------") + print(response_2.output_text) + + # NOTE: 
chain-of-thought should be removed. + response_3 = client.responses.create( + model=MODEL, + input="What's the weather like in Paris today?", + tools=tools, + previous_response_id=response_2.id, + ) + print("The third response:") + print(response_3) + print("output:") + for out in response_3.output: + print(out) + print("--------------------------------") + print(response_3.output_text) + + +def test_function_calling_multi_turn(): + tools = [ + { + "type": "function", + "name": "get_place_to_travel", + "description": "Get a random place to travel", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + "additionalProperties": False, + }, + "strict": True, + }, + { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, + }, + "required": ["latitude", "longitude"], + "additionalProperties": False, + }, + "strict": True, + }, + ] + + response = client.responses.create( + model=MODEL, + input="Help me plan a trip to a random place. And tell me the weather there.", + tools=tools, + ) + print("The first response:") + print(response) + print("output:") + for out in response.output: + print(out) + print("--------------------------------") + + assert len(response.output) == 2 + assert response.output[0].type == "reasoning" + assert response.output[1].type == "function_call" + + tool_call = response.output[1] + + name = tool_call.name + args = json.loads(tool_call.arguments) + + result = call_function(name, args) + print("tool call result: ", result, type(result)) + + response_2 = client.responses.create( + model=MODEL, + input=[ + { + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + } + ], + tools=tools, + previous_response_id=response.id, + ) + print("The second response:") + print(response_2) + print("output:") + for out in response_2.output: + print(out) + print("--------------------------------") + assert len(response_2.output) == 2 + assert response_2.output[0].type == "reasoning" + assert response_2.output[1].type == "function_call" + + tool_call = response_2.output[1] + + name = tool_call.name + args = json.loads(tool_call.arguments) + + result = call_function(name, args) + print("tool call result: ", result, type(result)) + + response_3 = client.responses.create( + model=MODEL, + input=[ + { + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + } + ], + tools=tools, + previous_response_id=response_2.id, + ) + print("The third response:") + print(response_3) + print("output:") + for out in response_3.output: + print(out) + print("--------------------------------") + print(response_3.output_text) + + +def test_function_calling_required(): + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, + }, + "required": ["latitude", "longitude"], + "additionalProperties": False, + }, + "strict": True, + } + ] + try: + _response = client.responses.create( + model=MODEL, + input="What's the weather like in Paris today?", + tools=tools, + tool_choice="required", + ) + except BadRequestError as e: + print(e) + return + else: + raise ValueError("Should raise BadRequestError") + + +def 
test_function_calling_full_history(): + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, + }, + "required": ["latitude", "longitude"], + "additionalProperties": False, + }, + "strict": True, + } + ] + + input_messages = [ + {"role": "user", "content": "What's the weather like in Paris today?"} + ] + + response = client.responses.create( + model=MODEL, + input=input_messages, + tools=tools, + ) + + print(response) + print("output:") + for out in response.output: + print(out) + print("--------------------------------") + + tool_call = response.output[-1] + name = tool_call.name + args = json.loads(tool_call.arguments) + + result = call_function(name, args) + print("tool call result: ", result, type(result)) + + input_messages.extend(response.output) # append model's function call message + input_messages.append( + { # append result message + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + } + ) + + print("input_messages: ", input_messages) + + response_2 = client.responses.create( + model=MODEL, + input=input_messages, + tools=tools, + ) + print(response_2.output_text) + + +if __name__ == "__main__": + # 1. Stateless & Non-streaming tests: + print("===test_basic:") + test_basic() + print("===test_basic_with_instructions:") + test_basic_with_instructions() + print("===test_basic_with_reasoning_effort:") + test_basic_with_reasoning_effort() + print("===test_chat:") + test_chat() # should we overwrite system message? + print("===test_chat_with_input_type:") + test_chat_with_input_type() + print("===test_structured_output:") + test_structured_output() + print("===test_structured_output_with_parse:") + test_structured_output_with_parse() + + # 2. Stateful & Non-streaming tests: + print("===test_store:") + test_store() + print("===test_background:") + test_background() + print("===test_background_cancel:") + test_background_cancel() + print("===test_stateful_multi_turn:") + test_stateful_multi_turn() + + # 3. Streaming tests: + print("===test_streaming:") + test_streaming() + + # 4. Tool tests: + print("===test_web_search:") + test_web_search() # can crash occasionally + print("===test_code_interpreter:") + test_code_interpreter() + print("===test_function_calling:") + test_function_calling() + print("===test_function_calling_multi_turn:") + test_function_calling_multi_turn() # can crash occasionally + print("===test_function_calling_required:") + test_function_calling_required() + print("===test_function_calling_full_history:") + test_function_calling_full_history() diff --git a/pyproject.toml b/pyproject.toml index dfad5d2cdf31..25c368e09afc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,6 @@ requires = [ "packaging>=24.2", "setuptools>=77.0.3,<80.0.0", "setuptools-scm>=8.0", - "torch == 2.7.1", "wheel", "jinja2", ] diff --git a/requirements/common.txt b/requirements/common.txt index 6b57a3d2f1d0..daededa8dc82 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -13,8 +13,8 @@ tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. 
aiohttp -openai >= 1.87.0 # Ensure modern openai package (ensure ResponsePrompt exists in type.responses and max_completion_tokens field support) -pydantic >= 2.10 +openai >= 1.98.0 # Ensure modern openai package (ensure ResponsePrompt exists in type.responses and max_completion_tokens field support) +pydantic >= 2.11.7 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 @@ -49,3 +49,5 @@ ninja # Required for xgrammar, rocm, tpu, xpu pybase64 # fast base64 implementation cbor2 # Required for cross-language serialization of hashable objects setproctitle # Used to set process names for better debugging and monitoring +mcp==1.12.3 +openai-harmony diff --git a/requirements/test.txt b/requirements/test.txt index 691420df87c4..7d76a9f58b7b 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -749,7 +749,7 @@ pycparser==2.22 # via cffi pycryptodomex==3.22.0 # via blobfile -pydantic==2.11.5 +pydantic==2.11.7 # via # -r requirements/test.in # albumentations diff --git a/tests/kernels/moe/test_triton_kernel_oai.py b/tests/kernels/moe/test_triton_kernel_oai.py new file mode 100644 index 000000000000..e3e5d8f7ecbd --- /dev/null +++ b/tests/kernels/moe/test_triton_kernel_oai.py @@ -0,0 +1,371 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass, fields + +import pytest +import torch +import torch.nn.functional as F +import triton_kernels.swiglu +from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig +from triton_kernels.numerics import InFlexData +from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp, + upcast_from_mxfp) +from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor +from triton_kernels.tensor_details import layout +from triton_kernels.testing import assert_close + +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedPrepareAndFinalize) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.triton_kernels_moe import ( + BatchedOAITritonExperts, triton_kernel_moe_forward) +from vllm.model_executor.layers.utils import shuffle_weight +from vllm.utils import round_up + + +def deshuffle(w: torch.Tensor): + first = w[..., ::2] + second = w[..., 1::2] + + deshuffled = torch.concat((first, second), dim=-1) + return deshuffled + + +def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int): + randbits = [torch.randperm(E) for _ in range(M)] + x_list = [ + (-1)**i * + ((16384 + + ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16)) + for i, bits in enumerate(randbits) + ] + exp_data = torch.stack(x_list).to( + device="cuda") # simulating gate_output (M, E) + + # create input tensor + x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda") + w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda") + w1_bias = torch.randn((E, 2 * N), dtype=torch.bfloat16, device="cuda") + + w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda") + w2_bias = torch.randn((E, K), dtype=torch.bfloat16, device="cuda") + + exp_data_tri = exp_data.clone() + x_tri = x.clone() + w1_tri = w1.clone() + w2_tri = w2.clone() + + w1_bias_tri = w1_bias.clone() + w2_bias_tri = w2_bias.clone() + w1_bias_tri = w1_bias_tri.to(torch.float32) + w2_bias_tri = w2_bias_tri.to(torch.float32) + + dtype_dict = { + 
"bf16": torch.bfloat16, + "fp8_e4m3": torch.float8_e4m3fn, + "fp8_e5m2": torch.float8_e5m2 + } + + x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16) + if w_dtype != "mx4": + # simulate quantization support on reference impl + w1 = w1.to(dtype_dict[w_dtype]).to(torch.bfloat16) + w2 = w2.to(dtype_dict[w_dtype]).to(torch.bfloat16) + + # triton moe kernel use transposed shape for matmul + w1_tri = w1_tri.transpose(-2, -1) + w2_tri = w2_tri.transpose(-2, -1) + + # shuffle weights + w1_tri = shuffle_weight(w1_tri) + w1_bias_tri = shuffle_weight(w1_bias_tri) + + # quant triton_weights + x_tri = x.to(dtype_dict[a_dtype]) + if w_dtype != "mx4": + pytest.skip("NYI") + else: # quantize to mx4 + # careful on the padding here, the activation padding need to be + # multiple of 64, the actual engine is not implemented + w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1] + w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2] + + w2_bottom_pad = w1_right_pad // 2 + w2_right_pad = w1_bottom_pad + + x_pad = w1_bottom_pad + + w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0), + mode="constant", + value=0) + w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0), + mode="constant", + value=0) + + w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0), + mode="constant", + value=0) + w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0), + mode="constant", + value=0) + + x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0) + + w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout( + mx_axis=1) + w_scale_layout, w_scale_layout_opts = ( + layout.make_default_matmul_mxfp4_w_scale_layout( + mx_axis=1, num_warps=num_warps)) + + w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1) + w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1) + + w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1) + w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1) + + w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, + **w_layout_opts) + w1_scale_tri = convert_layout(wrap_torch_tensor(w1_scale_tri), + w_scale_layout, **w_scale_layout_opts) + + w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, + **w_layout_opts) + w2_scale_tri = convert_layout(wrap_torch_tensor(w2_scale_tri), + w_scale_layout, **w_scale_layout_opts) + + pc1 = PrecisionConfig(weight_scale=w1_scale_tri, + flex_ctx=FlexCtx(rhs_data=InFlexData())) + pc2 = PrecisionConfig(weight_scale=w2_scale_tri, + flex_ctx=FlexCtx(rhs_data=InFlexData())) + + # tucuate so the rest can run properly + w1 = w1[..., :K, :2 * N] + w2 = w2[..., :N, :K] + + w1 = deshuffle(w1) + + w1 = w1.transpose(-1, -2).contiguous() + w2 = w2.transpose(-1, -2).contiguous() + + return (x, w1, w1_bias, w2, w2_bias, exp_data, x_tri, w1_tri, w2_tri, + exp_data_tri, w1_bias_tri, w2_bias_tri, pc1, pc2) + + +@dataclass +class ModelConfig: + num_hidden_layers: int = 36 + num_experts: int = 128 + experts_per_token: int = 4 + vocab_size: int = 201088 + hidden_size: int = 2880 + intermediate_size: int = 2880 + head_dim: int = 64 + num_attention_heads: int = 64 + num_key_value_heads: int = 8 + sliding_window: int = 128 + initial_context_length: int = 4096 + rope_theta: float = 150000.0 + rope_scaling_factor: float = 32.0 + rope_ntk_alpha: float = 1.0 + rope_ntk_beta: float = 32.0 + + +def swiglu(x, alpha: float = 1.702, limit: float = 1.0): + # Note we add an extra bias of 1 to the linear layer + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + if limit is 
not None: + x_glu = x_glu.clamp(max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + if limit is not None: + x_linear = x_linear.clamp(min=-limit, max=limit) + return out_glu * (x_linear + 1) + + +def oai_moe_forward( + hidden_states: torch.Tensor, # (M, K) + w1: torch.Tensor, # (E, 2N) + w1_bias: torch.Tensor, # (E, 2N, K) + w2: torch.Tensor, # (E, K, N) + w2_bias: torch.Tensor, # (E, N) + gating_output: torch.Tensor, # (M, E) + topk: int): + # model.py 309:330, assuming gating and norm + t = hidden_states + experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True) + expert_weights = torch.nn.functional.softmax(experts.values, dim=1) + expert_indices = experts.indices + + # MLP #1 + mlp1_weight = w1[expert_indices, ...] + mlp1_bias = w1_bias[expert_indices, ...] + t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias + t = swiglu(t, limit=7) + + # MLP #2 + mlp2_weight = w2[expert_indices, ...] + mlp2_bias = w2_bias[expert_indices, ...] + t = torch.einsum("beck,bek->bec", mlp2_weight, t) + t += mlp2_bias + + # Weighted sum of experts + t = torch.einsum("bec,be->bc", t, expert_weights) + + return t + + +@dataclass +class Case: + a_dtype: str + w_dtype: str + + +@pytest.mark.parametrize( + ", ".join(f.name for f in fields(Case)), + [ + tuple(getattr(case, f.name) for f in fields(Case)) for case in [ + # Case(a_dtype="bf16", w_dtype="bf16"), + # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"), + Case(a_dtype="bf16", w_dtype="mx4") + ] + ], +) +@pytest.mark.parametrize("num_token", [2]) +@pytest.mark.parametrize("tp", [1, 2, 4, 8]) +def test_equiv(num_token, a_dtype, w_dtype, tp): + M = num_token + E = ModelConfig.num_experts + K = ModelConfig.hidden_size + N = ModelConfig.intermediate_size // tp + topk = ModelConfig.experts_per_token + + x, w1, w1_bias, w2, w2_bias, exp_data, \ + x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri,\ + w2_bias_tri, pc1, pc2 = init_compute_data( + M, K, N, E, a_dtype, w_dtype, num_warps=8) + + out_triton_monolithic = triton_kernel_moe_forward( + hidden_states=x_tri, + w1=w1_tri, + w2=w2_tri, + gating_output=exp_data_tri, + topk=topk, + renormalize=True, + w1_bias=w1_bias_tri, + w2_bias=w2_bias_tri, + w1_precision=pc1, + w2_precision=pc2) + out_triton_monolithic = out_triton_monolithic[..., :K] + + out_ref = oai_moe_forward(hidden_states=x, + w1=w1, + w1_bias=w1_bias, + w2=w2, + w2_bias=w2_bias, + gating_output=exp_data, + topk=topk) + assert_close(ref=out_ref, + tri=out_triton_monolithic, + maxtol=0.025, + rmstol=0.005) + + +def batched_moe(a: torch.Tensor, w1, w2, gating_output: torch.Tensor, + topk: int, renormalize: bool, w1_bias: torch.Tensor, + w2_bias: torch.Tensor, w1_precision: PrecisionConfig, + w2_precision: PrecisionConfig) -> torch.Tensor: + max_num_tokens = round_up(a.shape[0], 64) + + fused_experts = FusedMoEModularKernel( + BatchedPrepareAndFinalize(max_num_tokens, + num_dispatchers=1, + num_local_experts=w1.shape[0], + rank=0), + BatchedOAITritonExperts( + None, + max_num_tokens=max_num_tokens, + num_dispatchers=1, + w1_precision=w1_precision, + w2_precision=w2_precision, + ), + ) + + topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize) + + return fused_experts( + a, + w1, + w2, + topk_weight, + topk_ids, + w1_bias=w1_bias, + w2_bias=w2_bias, + ) + + +@pytest.mark.parametrize( + ", ".join(f.name for f in fields(Case)), + [ + tuple(getattr(case, f.name) for f in fields(Case)) for case in [ + # Case(a_dtype="bf16", w_dtype="bf16"), + # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"), + Case(a_dtype="bf16", 
w_dtype="mx4") + ] + ], +) +@pytest.mark.parametrize("num_token", [64]) +@pytest.mark.parametrize("ep", [1, 2, 4, 8]) +def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep): + M = num_token + E = ModelConfig.num_experts // ep + K = ModelConfig.hidden_size + N = ModelConfig.intermediate_size + topk = ModelConfig.experts_per_token + + x, w1, w1_bias, w2, w2_bias, exp_data, \ + x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \ + w2_bias_tri, pc1, pc2 = init_compute_data( + M, K, N, E, a_dtype, w_dtype, num_warps=4) + + out_tri = batched_moe(a=x_tri, + w1=w1_tri, + w2=w2_tri, + gating_output=exp_data_tri, + topk=topk, + renormalize=True, + w1_bias=w1_bias_tri, + w2_bias=w2_bias_tri, + w1_precision=pc1, + w2_precision=pc2) + out_tri = out_tri[..., :K] + + out_ref = oai_moe_forward(hidden_states=x, + w1=w1, + w1_bias=w1_bias, + w2=w2, + w2_bias=w2_bias, + gating_output=exp_data, + topk=topk) + assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005) + + +def test_unit_shuffle(): + N = ModelConfig.intermediate_size + K = ModelConfig.hidden_size + m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda") + + x = torch.randn(K, dtype=torch.bfloat16, device="cuda") + + m_shuffled = shuffle_weight(m) + + out_ref = x @ m + out_ref = swiglu(out_ref, limit=1.0) + + out = x @ m_shuffled + out = triton_kernels.swiglu.swiglu_torch( + out, + alpha=1.702, + precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0)) + + assert_close(ref=out_ref, tri=out) diff --git a/tests/models/registry.py b/tests/models/registry.py index d86bd20fb0e3..5a751521c31b 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -169,6 +169,8 @@ def check_available_online( trust_remote_code=True), "DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501 trust_remote_code=True), + "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst", + min_transformers_version="4.53"), "Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT", min_transformers_version="4.54"), "Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT", @@ -198,6 +200,9 @@ def check_available_online( {"6b": "EleutherAI/gpt-j-6b"}), "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m", {"1b": "EleutherAI/pythia-1.4b"}), + "GptOssForCausalLM": _HfExamplesInfo("openai/gpt-oss-20b", + trust_remote_code=True, + is_available_online=False), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"), # noqa: E501 @@ -235,6 +240,8 @@ def check_available_online( "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"), "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501 + "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", + trust_remote_code=True), "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True), "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B", @@ -297,10 +304,6 @@ def check_available_online( tokenizer="meta-llama/Llama-2-7b", trust_remote_code=True), "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), - "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", - trust_remote_code=True), - "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst", - min_transformers_version="4.53"), # 
[Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), diff --git a/tests/v1/entrypoints/openai/responses/test_basic.py b/tests/v1/entrypoints/openai/responses/test_basic.py index 974ea8673c44..c6cc30b80f34 100644 --- a/tests/v1/entrypoints/openai/responses/test_basic.py +++ b/tests/v1/entrypoints/openai/responses/test_basic.py @@ -17,7 +17,8 @@ async def test_simple_input(client: openai.AsyncOpenAI): # Whether the output contains the reasoning. assert outputs[0].type == "reasoning" - assert outputs[0].text != "" + assert outputs[0].content[0].type == "reasoning_text" + assert outputs[0].content[0].text != "" @pytest.mark.asyncio diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index 4f839348e522..3bd7627742a4 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -28,6 +28,7 @@ def kernel_paged_attention_2d( query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] @@ -95,7 +96,17 @@ def kernel_paged_attention_2d( block_table_offset = seq_idx * block_table_stride - M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) + if sink_ptr is None: + M = tl.full([num_queries_per_kv_padded], + float("-inf"), + dtype=tl.float32) + else: + M = tl.load( + sink_ptr + query_head_idx, + mask=head_mask, + other=float("-inf"), + ).to(dtype=tl.float32) + L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32) @@ -206,23 +217,24 @@ def kernel_paged_attention_2d( def chunked_prefill_paged_decode( - query, - key, - value, - output, - kv_cache_dtype, - key_cache, - value_cache, - block_table, - query_start_loc, - seq_lens, - max_seq_len, - max_query_len, - k_scale, - v_scale, - alibi_slopes=None, - sliding_window=None, - sm_scale=None, + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_table, + query_start_loc, + seq_lens, + max_seq_len, + max_query_len, + k_scale, + v_scale, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, + sinks=None, # Optional tensor for sinks ): if sm_scale is None: @@ -253,6 +265,7 @@ def chunked_prefill_paged_decode( sliding_window=sliding_window, sm_scale=sm_scale, skip_decode=True, + sinks=sinks, ) block_size = value_cache.shape[3] @@ -281,11 +294,9 @@ def chunked_prefill_paged_decode( num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16) - use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, - block_size, - num_queries_per_kv, - max_seq_len, sliding_window, - kv_cache_dtype, alibi_slopes) + use_custom = use_rocm_custom_paged_attention( + query.dtype, head_size, block_size, num_queries_per_kv, max_seq_len, + sliding_window, kv_cache_dtype, alibi_slopes, sinks) if use_custom: _PARTITION_SIZE_ROCM = 256 max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // @@ -334,6 +345,7 @@ def chunked_prefill_paged_decode( query_ptr=query, key_cache_ptr=key_cache, value_cache_ptr=value_cache, + sink_ptr=sinks, block_tables_ptr=block_table, seq_lens_ptr=seq_lens, 
alibi_slopes_ptr=alibi_slopes, diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 13bef96722d2..64c90337970f 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -38,6 +38,7 @@ def _fwd_kernel(Q, V, K_cache, V_cache, + sink_ptr, B_Loc, sm_scale, k_scale, @@ -126,7 +127,15 @@ def _fwd_kernel(Q, other=0.0) # [M,D] # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + if sink_ptr is None: + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + else: + m_i = tl.load( + sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64), + mask=(offs_m < cur_batch_query_len), + other=float("-inf"), + ).to(dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] @@ -732,7 +741,8 @@ def context_attention_fwd(q, alibi_slopes=None, sliding_window=None, sm_scale=None, - skip_decode=False): + skip_decode=False, + sinks=None): q_dtype_is_f32 = q.dtype is torch.float32 @@ -781,6 +791,7 @@ def context_attention_fwd(q, sliding_window = 0 if alibi_slopes is not None: + assert sinks is None, "Sinks arg is not supported with alibi" # need to reduce num. blocks when using fp32 # due to increased use of GPU shared memory # if q.dtype is torch.float32: @@ -843,7 +854,7 @@ def context_attention_fwd(q, max_seq_len = 0 if max_seq_len is None else max_seq_len extra_kargs = {} if current_platform.is_rocm(): - extra_kargs = {"kpack": 2, "waves_per_eu": 2} + extra_kargs = {"kpack": 1, "waves_per_eu": 2} grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) @@ -853,6 +864,7 @@ def context_attention_fwd(q, v, k_cache, v_cache, + sinks, b_loc, sm_scale, k_scale, diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 0fdba569f93f..0a2a636b3c80 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -52,6 +52,7 @@ def kernel_unified_attention_2d( query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + sink_ptr, # [num_query_heads] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] @@ -131,7 +132,15 @@ def kernel_unified_attention_2d( block_table_offset = seq_idx * block_table_stride - M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + if sink_ptr is None: + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + else: + M = tl.load( + sink_ptr + query_offset_1, + mask=query_mask_1, + other=float("-inf"), + ).to(dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) @@ -292,6 +301,7 @@ def kernel_unified_attention_3d( query_ptr, # [num_tokens, num_query_heads, head_size] key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] @@ -383,7 +393,15 @@ def kernel_unified_attention_3d( block_table_offset = seq_idx * block_table_stride - M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + if sink_ptr is None or segm_idx != 0: + M = tl.full([BLOCK_M], float("-inf"), 
dtype=tl.float32) + else: + M = tl.load( + sink_ptr + query_offset_1, + mask=query_mask_1, + other=float("-inf"), + ).to(dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) @@ -609,24 +627,25 @@ def reduce_segments( def unified_attention( - q, - k, - v, - out, - cu_seqlens_q, - max_seqlen_q, - seqused_k, - max_seqlen_k, - softmax_scale, - causal, - window_size, - block_table, - softcap, - q_descale, - k_descale, - v_descale, - alibi_slopes=None, - qq_bias=None, + q, + k, + v, + out, + cu_seqlens_q, + max_seqlen_q, + seqused_k, + max_seqlen_k, + softmax_scale, + causal, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + alibi_slopes=None, + qq_bias=None, + sinks=None, # Optional tensor for sinks ): assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -635,6 +654,10 @@ def unified_attention( assert q.element_size() >= 2 or block_size >= 32, \ "Block size must be at least 32 for fp8" + if sinks is not None: + assert sinks.shape[0] == q.shape[1], \ + "Sinks must be num_query_heads size" + use_alibi_slopes = alibi_slopes is not None use_qq_bias = qq_bias is not None @@ -669,6 +692,7 @@ def unified_attention( query_ptr=q, key_cache_ptr=k, value_cache_ptr=v, + sink_ptr=sinks, block_tables_ptr=block_table, seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, @@ -741,6 +765,7 @@ def unified_attention( query_ptr=q, key_cache_ptr=k, value_cache_ptr=v, + sink_ptr=sinks, block_tables_ptr=block_table, seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py new file mode 100644 index 000000000000..965353a6461f --- /dev/null +++ b/vllm/entrypoints/context.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import logging +from abc import ABC, abstractmethod +from typing import Union + +from mcp import ClientSession +from openai_harmony import Author, Message, Role, StreamState, TextContent + +from vllm.entrypoints.harmony_utils import ( + get_encoding, get_streamable_parser_for_assistant, render_for_completion) +from vllm.entrypoints.tool import Tool +from vllm.outputs import RequestOutput + +logger = logging.getLogger(__name__) + + +class ConversationContext(ABC): + + @abstractmethod + def append_output(self, output) -> None: + pass + + @abstractmethod + async def call_tool(self) -> list[Message]: + pass + + @abstractmethod + def need_builtin_tool_call(self) -> bool: + pass + + @abstractmethod + def render_for_completion(self) -> list[int]: + pass + + +class SimpleContext(ConversationContext): + + def __init__(self): + self.last_output = None + + def append_output(self, output) -> None: + self.last_output = output + + def need_builtin_tool_call(self) -> bool: + return False + + async def call_tool(self) -> list[Message]: + raise NotImplementedError("Should not be called.") + + def render_for_completion(self) -> list[int]: + raise NotImplementedError("Should not be called.") + + +class HarmonyContext(ConversationContext): + + def __init__( + self, + messages: list, + tool_sessions: dict[str, Union[ClientSession, Tool]], + ): + # TODO: Remove the hack of Union[ClientSession, Tool] by using MCP + # when demo. 
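+ # tool_sessions maps the built-in tool names ("browser", "python") to their ClientSession or Tool; call_tool() below dispatches on the last message's recipient.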
+ self._messages = messages + self.tool_sessions = tool_sessions + + self.parser = get_streamable_parser_for_assistant() + self.num_init_messages = len(messages) + # TODO + self.num_prompt_tokens = 0 + self.num_cached_tokens = 0 + self.num_output_tokens = 0 + self.num_reasoning_tokens = 0 + + def append_output(self, output) -> None: + if isinstance(output, RequestOutput): + output_token_ids = output.outputs[0].token_ids + for token_id in output_token_ids: + self.parser.process(token_id) + output_msgs = self.parser.messages + else: + # Tool output. + output_msgs = output + self._messages.extend(output_msgs) + + @property + def messages(self) -> list: + return self._messages + + def need_builtin_tool_call(self) -> bool: + last_msg = self.messages[-1] + recipient = last_msg.recipient + return recipient is not None and (recipient.startswith("browser.") + or recipient.startswith("python")) + + async def call_tool(self) -> list[Message]: + if not self.messages: + return [] + last_msg = self.messages[-1] + recipient = last_msg.recipient + if recipient is not None: + if recipient.startswith("browser."): + return await self.call_search_tool( + self.tool_sessions["browser"], last_msg) + elif recipient.startswith("python"): + return await self.call_python_tool( + self.tool_sessions["python"], last_msg) + raise ValueError("No tool call found") + + def render_for_completion(self) -> list[int]: + return render_for_completion(self.messages) + + async def call_search_tool(self, tool_session: Union[ClientSession, Tool], + last_msg: Message) -> list[Message]: + if isinstance(tool_session, Tool): + return await tool_session.get_result(self) + tool_name = last_msg.recipient.split(".")[1] + args = json.loads(last_msg.content[0].text) + result = await tool_session.call_tool(tool_name, args) + result_str = result.content[0].text + content = TextContent(text=result_str) + author = Author(role=Role.TOOL, name=last_msg.recipient) + return [ + Message(author=author, content=[content], recipient=Role.ASSISTANT) + ] + + async def call_python_tool(self, tool_session: Union[ClientSession, Tool], + last_msg: Message) -> list[Message]: + if isinstance(tool_session, Tool): + return await tool_session.get_result(self) + param = { + "code": last_msg.content[0].text, + } + result = await tool_session.call_tool("python", param) + result_str = result.content[0].text + + content = TextContent(text=result_str) + author = Author(role=Role.TOOL, name="python") + + return [ + Message(author=author, + content=[content], + channel=last_msg.channel, + recipient=Role.ASSISTANT) + ] + + +class StreamingHarmonyContext(HarmonyContext): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.last_output = None + + self.parser = get_streamable_parser_for_assistant() + self.encoding = get_encoding() + self.last_tok = None + + @property + def messages(self) -> list: + return self.parser.messages + + def append_output(self, output) -> None: + if isinstance(output, RequestOutput): + tok = output.outputs[0].token_ids[0] + self.parser.process(tok) + self.last_tok = tok + else: + # Handle the case of tool output in direct message format + assert len(output) == 1, "Tool output should be a single message" + msg = output[0] + # Sometimes the recipient is not set for tool messages, + # so we set it to "assistant" + if msg.author.role == Role.TOOL and msg.recipient is None: + msg.recipient = "assistant" + toks = self.encoding.render(msg) + for tok in toks: + self.parser.process(tok) + self.last_tok = toks[-1] + + def 
is_expecting_start(self) -> bool: + return self.parser.state == StreamState.EXPECT_START + + def is_assistant_action_turn(self) -> bool: + return self.last_tok in self.encoding.stop_tokens_for_assistant_actions( + ) + + def render_for_completion(self) -> list[int]: + # now this list of tokens as next turn's starting tokens + # `<|start|>assistant``, + # we need to process them in parser. + rendered_tokens = super().render_for_completion() + + last_n = -1 + to_process = [] + while rendered_tokens[last_n] != self.last_tok: + to_process.append(rendered_tokens[last_n]) + last_n -= 1 + for tok in reversed(to_process): + self.parser.process(tok) + + return rendered_tokens diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py new file mode 100644 index 000000000000..cf374f8bfffe --- /dev/null +++ b/vllm/entrypoints/harmony_utils.py @@ -0,0 +1,321 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import datetime +import json +from collections.abc import Iterable +from typing import Literal, Optional, Union + +from openai.types.responses import (ResponseOutputItem, ResponseOutputMessage, + ResponseOutputText) +from openai.types.responses.response_function_tool_call import ( + ResponseFunctionToolCall) +from openai.types.responses.response_function_web_search import ( + ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch) +from openai.types.responses.tool import Tool +from openai_harmony import (Author, Conversation, DeveloperContent, + HarmonyEncodingName, Message, ReasoningEffort, + Role, StreamableParser, SystemContent, TextContent, + ToolDescription, load_harmony_encoding) + +from vllm.entrypoints.openai.protocol import (ResponseInputOutputItem, + ResponseReasoningItem, + ResponseReasoningTextContent) +from vllm.utils import random_uuid + +REASONING_EFFORT = { + "high": ReasoningEffort.HIGH, + "medium": ReasoningEffort.MEDIUM, + "low": ReasoningEffort.LOW, +} + +_harmony_encoding = None + + +def get_encoding(): + global _harmony_encoding + if _harmony_encoding is None: + _harmony_encoding = load_harmony_encoding( + HarmonyEncodingName.HARMONY_GPT_OSS) + return _harmony_encoding + + +def get_system_message( + model_identity: Optional[str] = None, + reasoning_effort: Optional[Literal["high", "medium", "low"]] = None, + start_date: Optional[str] = None, + browser_description: Optional[str] = None, + python_description: Optional[str] = None, +) -> Message: + sys_msg_content = SystemContent.new() + if model_identity is not None: + sys_msg_content = sys_msg_content.with_model_identity(model_identity) + if reasoning_effort is not None: + sys_msg_content = sys_msg_content.with_reasoning_effort( + REASONING_EFFORT[reasoning_effort]) + if start_date is None: + start_date = datetime.datetime.now().strftime("%Y-%m-%d") + sys_msg_content = sys_msg_content.with_conversation_start_date(start_date) + if browser_description is not None: + sys_msg_content = sys_msg_content.with_tools(browser_description) + if python_description is not None: + sys_msg_content = sys_msg_content.with_tools(python_description) + sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) + return sys_msg + + +def get_developer_message(instructions: Optional[str] = None, + tools: Optional[list[Tool]] = None) -> Message: + dev_msg_content = DeveloperContent.new() + if instructions is not None: + dev_msg_content = dev_msg_content.with_instructions(instructions) + if tools is not None: + function_tools = [] + for tool in 
tools: + if tool.type in ("web_search_preview", "code_interpreter"): + # These are built-in tools that are added to the system message. + pass + elif tool.type == "function": + function_tools.append(tool) + else: + raise ValueError(f"tool type {tool.type} not supported") + if function_tools: + function_tool_descriptions = [ + ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.parameters, + ) for tool in function_tools + ] + dev_msg_content = dev_msg_content.with_function_tools( + function_tool_descriptions) + dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content) + return dev_msg + + +def get_user_message(content: str) -> Message: + return Message.from_role_and_content(Role.USER, content) + + +def parse_response_input( + response_msg: ResponseInputOutputItem, + prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]] +) -> Message: + if not isinstance(response_msg, dict): + response_msg = response_msg.model_dump() + if "type" not in response_msg or response_msg["type"] == "message": + role = response_msg["role"] + content = response_msg["content"] + if role == "system": + # User is trying to set a system message. Change it to: + # <|start|>developer<|message|># Instructions + # {instructions}<|end|> + role = "developer" + text_prefix = "Instructions:\n" + else: + text_prefix = "" + if isinstance(content, str): + msg = Message.from_role_and_content(role, text_prefix + content) + else: + contents = [ + TextContent(text=text_prefix + c["text"]) for c in content + ] + msg = Message.from_role_and_contents(role, contents) + elif response_msg["type"] == "function_call_output": + call_id = response_msg["call_id"] + call_response: Optional[ResponseFunctionToolCall] = None + for prev_response in reversed(prev_responses): + if isinstance(prev_response, ResponseFunctionToolCall + ) and prev_response.call_id == call_id: + call_response = prev_response + break + if call_response is None: + raise ValueError(f"No call message found for {call_id}") + msg = Message.from_author_and_content( + Author.new(Role.TOOL, f"functions.{call_response.name}"), + response_msg["output"]) + elif response_msg["type"] == "reasoning": + content = response_msg["content"] + assert len(content) == 1 + msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"]) + elif response_msg["type"] == "function_call": + msg = Message.from_role_and_content(Role.ASSISTANT, + response_msg["arguments"]) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(f"functions.{response_msg['name']}") + msg = msg.with_content_type("json") + else: + raise ValueError(f"Unknown input type: {response_msg['type']}") + return msg + + +def parse_response_output(output: ResponseOutputItem) -> Message: + if isinstance(output, ResponseOutputMessage): + role = output.role + contents = [TextContent(text=c.text) for c in output.content] + msg = Message.from_role_and_contents(role, contents) + return msg + elif isinstance(output, ResponseFunctionToolCall): + msg = Message.from_role_and_content(Role.ASSISTANT, output.arguments) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(output.name) + msg = msg.with_content_type("json") + return msg + else: + raise ValueError(f"Unknown output type: {type(output)}") + + +def parse_chat_input(chat_msg) -> Message: + role = chat_msg["role"] + content = chat_msg["content"] + if isinstance(content, str): + contents = [TextContent(text=content)] + else: + # TODO: Support refusal. 
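+ # Assumes each structured content part is a dict carrying a "text" field; refusal parts are not handled yet (see TODO above).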
+ contents = [TextContent(text=c["text"]) for c in content] + msg = Message.from_role_and_contents(role, contents) + return msg + + +def render_for_completion(messages: list[Message]) -> list[int]: + conversation = Conversation.from_messages(messages) + token_ids = get_encoding().render_conversation_for_completion( + conversation, Role.ASSISTANT) + return token_ids + + +def get_stop_tokens_for_assistant_actions() -> list[int]: + return get_encoding().stop_tokens_for_assistant_actions() + + +def get_streamable_parser_for_assistant() -> StreamableParser: + return StreamableParser(get_encoding(), role=Role.ASSISTANT) + + +def parse_output_message(message: Message): + if message.author.role != "assistant": + # This is a message from a tool to the assistant (e.g., search result). + # Don't include it in the final output for now. This aligns with + # OpenAI's behavior on models like o4-mini. + return [] + + output_items = [] + recipient = message.recipient + if recipient is not None and recipient.startswith("browser."): + if len(message.content) != 1: + raise ValueError("Invalid number of contents in browser message") + content = message.content[0] + browser_call = json.loads(content.text) + # TODO: translate to url properly! + if recipient == "browser.search": + action = ActionSearch( + query=f"cursor:{browser_call.get('query', '')}", type="search") + elif recipient == "browser.open": + action = ActionOpenPage( + url=f"cursor:{browser_call.get('url', '')}", type="open_page") + elif recipient == "browser.find": + action = ActionFind(pattern=browser_call["pattern"], + url=f"cursor:{browser_call.get('url', '')}", + type="find") + else: + raise ValueError(f"Unknown browser action: {recipient}") + web_search_item = ResponseFunctionWebSearch( + id=f"ws_{random_uuid()}", + action=action, + status="completed", + type="web_search_call", + ) + output_items.append(web_search_item) + elif message.channel == "analysis": + for content in message.content: + reasoning_item = ResponseReasoningItem( + content=[ResponseReasoningTextContent(text=content.text)], + status=None, + ) + output_items.append(reasoning_item) + elif message.channel == "commentary": + if message.recipient.startswith("functions."): + function_name = message.recipient.split(".")[-1] + for content in message.content: + random_id = random_uuid() + response_item = ResponseFunctionToolCall( + arguments=content.text, + call_id=f"call_{random_id}", + type="function_call", + name=function_name, + id=f"ft_{random_id}", + ) + output_items.append(response_item) + elif message.recipient.startswith( + "python") or message.recipient.startswith("browser"): + for content in message.content: + reasoning_item = ResponseReasoningItem( + text=content.text, + status=None, + ) + output_items.append(reasoning_item) + else: + raise ValueError(f"Unknown recipient: {message.recipient}") + elif message.channel == "final": + contents = [] + for content in message.content: + output_text = ResponseOutputText( + text=content.text, + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + contents.append(output_text) + text_item = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=contents, + role=message.author.role, + status="completed", + type="message", + ) + output_items.append(text_item) + else: + raise ValueError(f"Unknown channel: {message.channel}") + return output_items + + +def parse_remaining_state(parser: StreamableParser): + if not parser.current_content: + return [] + if parser.current_role != Role.ASSISTANT: + return [] + 
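+ # Flush partially parsed assistant content: in-flight browser calls are skipped, "analysis" text becomes a reasoning item, and "final" text becomes an output message.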
current_recipient = parser.current_recipient + if (current_recipient is not None + and current_recipient.startswith("browser.")): + return [] + + if parser.current_channel == "analysis": + reasoning_item = ResponseReasoningItem( + content=[ + ResponseReasoningTextContent(text=parser.current_content) + ], + status=None, + ) + return [reasoning_item] + elif parser.current_channel == "final": + output_text = ResponseOutputText( + text=parser.current_content, + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + text_item = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=[output_text], + role="assistant", + status="completed", + type="message", + ) + return [text_item] + return [] + + +def parse_output_into_messages(token_ids: Iterable[int]): + parser = get_streamable_parser_for_assistant() + for token_id in token_ids: + parser.process(token_id) + return parser diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 9bf470232078..2c176f0eab29 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -92,6 +92,8 @@ from vllm.entrypoints.openai.serving_transcription import ( OpenAIServingTranscription, OpenAIServingTranslation) from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.openai.tool_server import (DemoToolServer, MCPToolServer, + ToolServer) from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, log_non_default_args, with_cancellation) from vllm.logger import init_logger @@ -1620,6 +1622,14 @@ async def init_app_state( "This discrepancy may lead to performance degradation.", resolved_chat_template, args.model) + if args.tool_server == "demo": + tool_server: Optional[ToolServer] = DemoToolServer() + elif args.tool_server: + tool_server = MCPToolServer() + await tool_server.add_tool_server(args.tool_server) + else: + tool_server = None + # Merge default_mm_loras into the static lora_modules default_mm_loras = (vllm_config.lora_config.default_mm_loras if vllm_config.lora_config is not None else {}) @@ -1654,6 +1664,7 @@ async def init_app_state( return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser, + tool_server=tool_server, reasoning_parser=args.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index dfbc9cde3d5b..38d588cc9bea 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -147,6 +147,9 @@ class FrontendArgs: """Special the tool parser plugin write to parse the model-generated tool into OpenAI API format, the name register in this plugin can be used in `--tool-call-parser`.""" + tool_server: Optional[str] = None + """Comma-separated list of host:port pairs (IPv4, IPv6, or hostname). 
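+ The special value "demo" starts the built-in DemoToolServer; any other value is treated as MCP tool server addresses.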
+ Examples: 127.0.0.1:8000, [::1]:8000, localhost:1234""" log_config_file: Optional[str] = envs.VLLM_LOGGING_CONFIG_PATH """Path to logging config JSON file for both vllm and uvicorn""" max_log_len: Optional[int] = None diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d77aee345843..6087852f0515 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -17,9 +17,11 @@ from openai.types.chat.chat_completion_message import ( Annotation as OpenAIAnnotation) # yapf: enable -from openai.types.responses import (ResponseInputParam, ResponseOutputItem, - ResponseOutputMessage, ResponsePrompt, - ResponseStatus, ResponseTextConfig) +from openai.types.responses import (ResponseFunctionToolCall, + ResponseFunctionToolCallOutputItem, + ResponseInputItemParam, ResponseOutputItem, + ResponsePrompt, ResponseStatus, + ResponseTextConfig) from openai.types.responses.response import ToolChoice from openai.types.responses.tool import Tool from openai.types.shared import Metadata, Reasoning @@ -234,6 +236,11 @@ def get_logits_processors(processors: Optional[LogitsProcessors], return None +ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam, + "ResponseReasoningItem", + ResponseFunctionToolCall] + + class ResponsesRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/responses/create @@ -248,7 +255,7 @@ class ResponsesRequest(OpenAIBaseModel): "reasoning.encrypted_content", ], ]] = None - input: Union[str, ResponseInputParam] + input: Union[str, list[ResponseInputOutputItem]] instructions: Optional[str] = None max_output_tokens: Optional[int] = None max_tool_calls: Optional[int] = None @@ -323,6 +330,7 @@ def to_sampling_params( if (top_p := self.top_p) is None: top_p = default_sampling_params.get( "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + stop_token_ids = default_sampling_params.get("stop_token_ids") # Structured output guided_decoding = None @@ -340,6 +348,7 @@ def to_sampling_params( top_p=top_p, max_tokens=max_tokens, logprobs=self.top_logprobs, + stop_token_ids=stop_token_ids, output_kind=(RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY), guided_decoding=guided_decoding, @@ -404,6 +413,8 @@ class ChatCompletionRequest(OpenAIBaseModel): Literal["required"], ChatCompletionNamedToolChoiceParam, ]] = "none" + reasoning_effort: Optional[Literal["low", "medium", "high"]] = None + include_reasoning: bool = True # NOTE this will be ignored by vLLM -- the model determines the behavior parallel_tool_calls: Optional[bool] = False @@ -1707,15 +1718,74 @@ class TranscriptionStreamResponse(OpenAIBaseModel): usage: Optional[UsageInfo] = Field(default=None) +class ResponseReasoningTextContent(OpenAIBaseModel): + text: str + type: Literal["reasoning_text"] = "reasoning_text" + + class ResponseReasoningItem(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"rs_{random_uuid()}") - text: str + content: list[ResponseReasoningTextContent] = Field(default_factory=list) summary: list = Field(default_factory=list) type: Literal["reasoning"] = "reasoning" encrypted_content: Optional[str] = None status: Optional[Literal["in_progress", "completed", "incomplete"]] +class InputTokensDetails(OpenAIBaseModel): + cached_tokens: int + + +class OutputTokensDetails(OpenAIBaseModel): + reasoning_tokens: int + + +class ResponseUsage(OpenAIBaseModel): + input_tokens: int + input_tokens_details: InputTokensDetails + output_tokens: int + 
output_tokens_details: OutputTokensDetails + total_tokens: int + + +class ResponseReasoningTextDeltaEvent(OpenAIBaseModel): + type: Literal[ + "response.reasoning_text.delta"] = "response.reasoning_text.delta" + item_id: str = "item_1234" + output_index: int + content_index: int + delta: str + sequence_number: int = -1 + + +class ResponseReasoningTextDoneEvent(OpenAIBaseModel): + type: Literal[ + "response.reasoning_text.done"] = "response.reasoning_text.done" + item_id: str = "item_1234" + output_index: int + content_index: int + text: str + sequence_number: int = -1 + + +class ResponseContentPartDoneEvent(OpenAIBaseModel): + type: Literal["response.content_part.done"] = "response.content_part.done" + item_id: str = "item_1234" + output_index: int + content_index: int + part: Union[ResponseOutputItem, ResponseReasoningItem] + sequence_number: int = -1 + + +class ResponseOutputItemDoneEvent(OpenAIBaseModel): + type: Literal["response.output_item.done"] = "response.output_item.done" + item_id: str = "item_1234" + output_index: int + item: Union[ResponseOutputItem, ResponseReasoningItem, + ResponseFunctionToolCallOutputItem] + sequence_number: int = -1 + + class ResponsesResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"resp_{random_uuid()}") created_at: int = Field(default_factory=lambda: int(time.time())) @@ -1725,7 +1795,7 @@ class ResponsesResponse(OpenAIBaseModel): metadata: Optional[Metadata] = None model: str object: Literal["response"] = "response" - output: list[Union[ResponseOutputMessage, ResponseReasoningItem]] + output: list[Union[ResponseOutputItem, ResponseReasoningItem]] parallel_tool_calls: bool temperature: float tool_choice: ToolChoice @@ -1742,7 +1812,7 @@ class ResponsesResponse(OpenAIBaseModel): text: Optional[ResponseTextConfig] = None top_logprobs: int truncation: Literal["auto", "disabled"] - usage: Optional[UsageInfo] = None + usage: Optional[ResponseUsage] = None user: Optional[str] = None @classmethod @@ -1752,9 +1822,9 @@ def from_request( sampling_params: SamplingParams, model_name: str, created_time: int, - output: list[ResponseOutputItem], + output: list[Union[ResponseOutputItem, ResponseReasoningItem]], status: ResponseStatus, - usage: Optional[UsageInfo] = None, + usage: Optional[ResponseUsage] = None, ) -> "ResponsesResponse": return cls( id=request.request_id, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index e1d8a31672ed..d7e149d76849 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -12,6 +12,7 @@ import partial_json_parser import regex as re from fastapi import Request +from openai_harmony import Message as OpenAIMessage from pydantic import TypeAdapter from vllm.config import ModelConfig @@ -19,6 +20,10 @@ from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, ConversationMessage, random_tool_call_id) +from vllm.entrypoints.harmony_utils import ( + get_developer_message, get_stop_tokens_for_assistant_actions, + get_streamable_parser_for_assistant, get_system_message, parse_chat_input, + parse_output_into_messages, render_for_completion) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -35,6 +40,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( MistralToolCall) from vllm.entrypoints.utils import get_max_tokens +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger 
import init_logger from vllm.outputs import CompletionOutput, RequestOutput from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -125,6 +131,23 @@ def __init__( logger.info("Using default chat sampling params from %s: %s", source, self.default_sampling_params) + self.use_harmony = model_config.hf_config.model_type == "gpt_oss" + if self.use_harmony: + if "stop_token_ids" not in self.default_sampling_params: + self.default_sampling_params["stop_token_ids"] = [] + self.default_sampling_params["stop_token_ids"].extend( + get_stop_tokens_for_assistant_actions()) + + # NOTE(woosuk): While OpenAI's chat completion API supports browsing + # for some models, currently vLLM doesn't support it. Please use the + # Responses API instead. + self.supports_browsing = False + self.browser_tool = None + # NOTE(woosuk): Chat completion API does not support code interpreter. + # Please use the Responses API instead. + self.supports_code_interpreter = False + self.python_tool = None + async def create_chat_completion( self, request: ChatCompletionRequest, @@ -169,7 +192,8 @@ async def create_chat_completion( if (request.tool_choice == "auto" and not (self.enable_auto_tools and tool_parser is not None) - and not isinstance(tokenizer, MistralTokenizer)): + and not isinstance(tokenizer, MistralTokenizer) + and not self.use_harmony): # for hf tokenizers, "auto" tools requires # --enable-auto-tool-choice and --tool-call-parser return self.create_error_response( @@ -184,25 +208,35 @@ async def create_chat_completion( else: tool_dicts = [tool.model_dump() for tool in request.tools] - ( - conversation, - request_prompts, - engine_prompts, - ) = await self._preprocess_chat( - request, - tokenizer, - request.messages, - chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self.chat_template_content_format, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, - tool_dicts=tool_dicts, - documents=request.documents, - chat_template_kwargs=request.chat_template_kwargs, - tool_parser=tool_parser, - truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) + if not self.use_harmony: + # Common case. + ( + conversation, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + tool_dicts=tool_dicts, + documents=request.documents, + chat_template_kwargs=request.chat_template_kwargs, + tool_parser=tool_parser, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + else: + # For GPT-OSS. 
+ ( + conversation, + request_prompts, + engine_prompts, + ) = self._make_request_with_harmony(request) except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") @@ -436,6 +470,11 @@ async def chat_completion_stream_generator( finish_reason_sent = [False] * num_choices num_prompt_tokens = 0 num_cached_tokens = None + if self.use_harmony: + harmony_parsers = [ + get_streamable_parser_for_assistant() + for _ in range(num_choices) + ] if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name @@ -597,7 +636,18 @@ async def chat_completion_stream_generator( else: logprobs = None - delta_text = output.text + if self.use_harmony: + harmony_parser = harmony_parsers[i] + for token_id in output.token_ids: + harmony_parser.process(token_id) + # FIXME(woosuk): Support function calling + is_final = harmony_parser.current_channel == "final" + if not (request.include_reasoning or is_final): + # Skip the reasoning content. + continue + delta_text = harmony_parser.last_content_delta or "" + else: + delta_text = output.text if not delta_text and not output.token_ids and \ not previous_num_tokens[i]: @@ -607,7 +657,8 @@ async def chat_completion_stream_generator( delta_message: Optional[DeltaMessage] # just update previous_texts and previous_token_ids - if tool_choice_auto or self.reasoning_parser: + if ((tool_choice_auto or self.reasoning_parser) + and not self.use_harmony): assert previous_texts is not None assert all_previous_token_ids is not None previous_text = previous_texts[i] @@ -621,8 +672,14 @@ async def chat_completion_stream_generator( else: current_token_ids = list(output.token_ids) + if self.use_harmony: + if is_final: + delta_message = DeltaMessage(content=delta_text) + else: + delta_message = DeltaMessage( + reasoning_content=delta_text) # handle streaming deltas for tools with named tool_choice - if tool_choice_function_name: + elif tool_choice_function_name: if (self.reasoning_parser and not reasoning_end_arr[i] and not reasoning_parser.is_reasoning_end( previous_token_ids)): @@ -990,7 +1047,59 @@ async def chat_completion_full_generator( ) else: logprobs = None - auto_tools_called = False + + if self.use_harmony: + parser = parse_output_into_messages(token_ids) + output_msgs = parser.messages + if len(output_msgs) == 0: + # The generation has stopped during reasoning. + is_tool_call = False + reasoning_content = parser.current_content + final_content = None + elif len(output_msgs) == 1: + # The generation has stopped during final message. 
+ is_tool_call = False + reasoning_content = output_msgs[0].content[0].text + final_content = parser.current_content + else: + if len(output_msgs) != 2: + raise ValueError( + "Expected 2 output messages (reasoning and final), " + f"but got {len(output_msgs)}.") + reasoning_msg, final_msg = output_msgs + reasoning_content = reasoning_msg.content[0].text + final_content = final_msg.content[0].text + is_tool_call = final_msg.recipient is not None + + if not request.include_reasoning: + reasoning_content = None + if is_tool_call: + # Tool call TODO + raise NotImplementedError( + "Tool call is not supported yet.") + else: + # Normal message + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=final_content, + ) + + if is_tool_call: + finish_reason = "tool_calls" + elif output.finish_reason: + finish_reason = output.finish_reason + else: + finish_reason = "stop" + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason=finish_reason, + stop_reason=output.stop_reason, + ) + choices.append(choice_data) + continue if self.reasoning_parser: try: @@ -1003,10 +1112,13 @@ async def chat_completion_full_generator( reasoning_content, content = ( reasoning_parser.extract_reasoning_content( output.text, request=request)) + if not request.include_reasoning: + reasoning_content = None else: reasoning_content = None content = output.text + auto_tools_called = False # if auto tools are not enabled, and a named tool choice using # outlines is not being used if (not self.enable_auto_tools or not self.tool_parser) and \ @@ -1261,3 +1373,33 @@ def _should_check_for_unstreamed_tool_arg_tokens( and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments is not None ) + + def _make_request_with_harmony( + self, + request: ChatCompletionRequest, + ): + messages: list[OpenAIMessage] = [] + + # Add system message. + # In Chat Completion API, browsing is enabled by default if the model + # supports it. + assert not self.supports_browsing + assert not self.supports_code_interpreter + sys_msg = get_system_message( + reasoning_effort=request.reasoning_effort, + browser_description=None, + python_description=None) + messages.append(sys_msg) + + # Add developer message. + dev_msg = get_developer_message() + messages.append(dev_msg) + + # Add user message. + for chat_msg in request.messages: + messages.append(parse_chat_input(chat_msg)) + + # Render prompt token ids. 
+ prompt_token_ids = render_for_completion(messages) + engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + return messages, [prompt_token_ids], [engine_prompt] diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 71976fea1ee7..822f1868406c 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -35,6 +35,7 @@ apply_mistral_chat_template, parse_chat_messages_futures, resolve_chat_template_content_format) +from vllm.entrypoints.context import ConversationContext from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, @@ -948,6 +949,61 @@ async def _preprocess_chat( return conversation, [request_prompt], [engine_prompt] + async def _generate_with_builtin_tools( + self, + request_id: str, + request_prompt: RequestPrompt, + engine_prompt: EngineTokensPrompt, + sampling_params: SamplingParams, + context: ConversationContext, + lora_request: Optional[LoRARequest] = None, + priority: int = 0, + **kwargs, + ): + orig_priority = priority + while True: + self._log_inputs( + request_id, + request_prompt, + params=sampling_params, + lora_request=lora_request, + ) + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id, + lora_request=lora_request, + priority=priority, + **kwargs, + ) + async for res in generator: + context.append_output(res) + # NOTE(woosuk): The stop condition is handled by the engine. + yield context + + if not context.need_builtin_tool_call(): + # The model did not ask for a tool call, so we're done. + break + + # Call the tool and update the context with the result. + tool_output = await context.call_tool() + context.append_output(tool_output) + + # TODO: uncomment this and enable tool output streaming + # yield context + + # Create inputs for the next turn. + # Render the next prompt token ids. + prompt_token_ids = context.render_for_completion() + engine_prompt = EngineTokensPrompt( + prompt_token_ids=prompt_token_ids) + request_prompt = prompt_token_ids + # Update the sampling params. 
+ sampling_params.max_tokens = (self.max_model_len - + len(prompt_token_ids)) + # OPTIMIZATION + priority = orig_priority - 1 + def _load_prompt_embeds( self, prompt_embeds: Optional[Union[bytes, list[bytes]]], diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index e009529fbd2a..0795ed7e92de 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -2,34 +2,56 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import json import time from collections.abc import AsyncGenerator, AsyncIterator +from contextlib import AsyncExitStack +from copy import copy from http import HTTPStatus -from typing import Callable, Final, Optional, Union +from typing import Any, Callable, Final, Optional, Union import jinja2 +import openai.types.responses as openai_responses_types from fastapi import Request from openai.types.responses import ResponseOutputMessage, ResponseOutputText +from openai.types.responses.response_function_tool_call import ( + ResponseFunctionToolCall) +from openai_harmony import Message as OpenAIMessage +from pydantic import BaseModel from vllm import envs from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, ChatTemplateContentFormatOption) +from vllm.entrypoints.context import (ConversationContext, HarmonyContext, + SimpleContext, StreamingHarmonyContext) +from vllm.entrypoints.harmony_utils import ( + get_developer_message, get_stop_tokens_for_assistant_actions, + get_system_message, get_user_message, parse_output_message, + parse_remaining_state, parse_response_input, render_for_completion) from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable from vllm.entrypoints.openai.protocol import (ErrorResponse, - PromptTokenUsageInfo, + InputTokensDetails, + OutputTokensDetails, RequestResponseMetadata, + ResponseContentPartDoneEvent, + ResponseOutputItemDoneEvent, ResponseReasoningItem, + ResponseReasoningTextContent, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, ResponsesRequest, - ResponsesResponse, UsageInfo) + ResponsesResponse, ResponseUsage) # yapf: enable from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.tool_server import MCPToolServer, ToolServer +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger -from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput, RequestOutput from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -53,6 +75,7 @@ def __init__( reasoning_parser: str = "", enable_auto_tools: bool = False, tool_parser: Optional[str] = None, + tool_server: Optional[ToolServer] = None, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, ) -> None: @@ -90,6 +113,33 @@ def __init__( logger.info("Using default chat sampling params from %s: %s", source, self.default_sampling_params) + self.supports_browsing = tool_server.has_tool( + "browser") if tool_server else False + self.supports_code_interpreter = tool_server.has_tool( + "python") if tool_server else False + self.tool_server = tool_server + self.use_harmony = 
model_config.hf_config.model_type == "gpt_oss" + if self.use_harmony: + logger.warning("For gpt-oss, we ignore --enable-auto-tool-choice " + "and always enable tool use.") + # OpenAI models have two EOS-like tokens: <|return|> and <|call|>. + # We need to add them to the stop token ids. + if "stop_token_ids" not in self.default_sampling_params: + self.default_sampling_params["stop_token_ids"] = [] + self.default_sampling_params["stop_token_ids"].extend( + get_stop_tokens_for_assistant_actions()) + + # set up tool use + self.enable_auto_tools: bool = enable_auto_tools + if self.enable_auto_tools: + logger.info( + "\"auto\" tool choice has been enabled please note that while" + " the parallel_tool_calls client option is preset for " + "compatibility reasons, it will be ignored.") + if not self.use_harmony: + raise NotImplementedError("Auto tool choice is not supported " + "yet unless using Harmony") + # If False (default), the "store" option is (silently) ignored and the # response is not stored. If True, the response is stored in memory. # NOTE(woosuk): This may not be intuitive for users, as the default @@ -161,21 +211,20 @@ async def create_responses( return self._make_not_found_error(prev_response_id) else: prev_response = None - # Construct the input messages. - messages = self._construct_input_messages(request, prev_response) try: lora_request = self._maybe_get_adapters(request) model_name = self._get_model_name(request.model, lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request) - _, request_prompts, engine_prompts = await self._preprocess_chat( - request, - tokenizer, - messages, - chat_template=self.chat_template, - chat_template_content_format=self.chat_template_content_format, - ) + if self.use_harmony: + messages, request_prompts, engine_prompts = ( + self._make_request_with_harmony(request, prev_response)) + else: + messages, request_prompts, engine_prompts = ( + await self._make_request(request, prev_response, + tokenizer)) + except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") @@ -186,98 +235,181 @@ async def create_responses( if raw_request: raw_request.state.request_metadata = request_metadata + if self.tool_server is not None and isinstance( + self.tool_server, MCPToolServer + ) and (request.background or request.stream) and request.tools and any( + tool.type in ["web_search_preview", "code_interpreter"] + for tool in request.tools): + return self.create_error_response( + "MCP tool server is not supported in background mode and " + "streaming mode") # Schedule the request and get the result generator. 
generators: list[AsyncGenerator[RequestOutput, None]] = [] - try: - for i, engine_prompt in enumerate(engine_prompts): - default_max_tokens = self.max_model_len - len( - engine_prompt["prompt_token_ids"]) - sampling_params = request.to_sampling_params( - default_max_tokens, self.default_sampling_params) - - self._log_inputs(request.request_id, - request_prompts[i], - params=sampling_params, - lora_request=lora_request) - - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) - - generator = self.engine_client.generate( - engine_prompt, + if self.use_harmony: + tool_list = [] + if self.supports_browsing: + tool_list.append("browser") + if self.supports_code_interpreter: + tool_list.append("python") + async with AsyncExitStack() as exit_stack: + try: + if self.tool_server is not None: + tool_session_ctxs: dict[str, Any] = { + tool_name: + exit_stack.enter_async_context( + self.tool_server.get_tool_session(tool_name)) + for tool_name in tool_list + } + tool_sessions = {} + for tool_name in tool_list: + tool_sessions[tool_name] = ( + await tool_session_ctxs[tool_name]) + else: + assert len(tool_list) == 0 + tool_sessions = {} + for i, engine_prompt in enumerate(engine_prompts): + default_max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"]) + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers( + raw_request.headers)) + + context: ConversationContext + if self.use_harmony: + if request.stream: + context = StreamingHarmonyContext( + messages, tool_sessions) + else: + context = HarmonyContext(messages, tool_sessions) + else: + context = SimpleContext() + generator = self._generate_with_builtin_tools( + request.request_id, + request_prompts[i], + engine_prompt, + sampling_params, + context, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert len(generators) == 1 + result_generator, = generators + + # Store the input messages. + if request.store: + self.msg_store[request.request_id] = messages + + if request.background: + created_time = int(time.time()) + response = ResponsesResponse.from_request( + request, sampling_params, - request.request_id, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, + model_name=model_name, + created_time=created_time, + output=[], + status="queued", + usage=None, ) - generators.append(generator) - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + async with self.response_store_lock: + self.response_store[response.id] = response - assert len(generators) == 1 - result_generator, = generators + # Run the request in the background. + task = asyncio.create_task( + self._run_background_request( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + created_time, + ), + name=f"create_{response.id}", + ) - # Store the input messages. - if request.store: - self.msg_store[request.request_id] = messages + # For cleanup. 
+ response_id = response.id + self.background_tasks[response_id] = task + task.add_done_callback( + lambda _: self.background_tasks.pop(response_id, None)) + return response - if request.background: - created_time = int(time.time()) - response = ResponsesResponse.from_request( - request, - sampling_params, - model_name=model_name, - created_time=created_time, - output=[], - status="queued", - usage=None, - ) - async with self.response_store_lock: - self.response_store[response.id] = response - - # Run the request in the background. - task = asyncio.create_task( - self._run_background_request( + if request.stream: + return self.responses_stream_generator( request, sampling_params, - result_generator, + result_generator, # type: ignore[arg-type] + context, # type: ignore[arg-type] model_name, tokenizer, request_metadata, - created_time, - ), - name=f"create_{response.id}", - ) + ) - # For cleanup. - response_id = response.id - self.background_tasks[response_id] = task - task.add_done_callback( - lambda _: self.background_tasks.pop(response_id, None)) - return response + try: + result: Union[ + ErrorResponse, + ResponsesResponse] = await self.responses_full_generator( + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + ) + return result + except Exception as e: + return self.create_error_response(str(e)) + return self.create_error_response("Unknown error") - if request.stream: - raise NotImplementedError("Streaming responses are not supported") + async def _make_request( + self, + request: ResponsesRequest, + prev_response: Optional[ResponsesResponse], + tokenizer: AnyTokenizer, + ): + # Construct the input messages. + messages = self._construct_input_messages(request, prev_response) + _, request_prompts, engine_prompts = await self._preprocess_chat( + request, + tokenizer, + messages, + chat_template=self.chat_template, + chat_template_content_format=self.chat_template_content_format, + ) + return messages, request_prompts, engine_prompts - try: - return await self.responses_full_generator( - request, - sampling_params, - result_generator, - model_name, - tokenizer, - request_metadata, - ) - except Exception as e: - return self.create_error_response(str(e)) + def _make_request_with_harmony( + self, + request: ResponsesRequest, + prev_response: Optional[ResponsesResponse], + ): + if request.tool_choice != "auto": + raise NotImplementedError( + "Only 'auto' tool_choice is supported in " + "response API") + messages = self._construct_input_messages_with_harmony( + request, prev_response) + prompt_token_ids = render_for_completion(messages) + engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + return messages, [prompt_token_ids], [engine_prompt] async def responses_full_generator( self, request: ResponsesRequest, sampling_params: SamplingParams, result_generator: AsyncIterator[RequestOutput], + context: ConversationContext, model_name: str, tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, @@ -285,21 +417,76 @@ async def responses_full_generator( ) -> Union[ErrorResponse, ResponsesResponse]: if created_time is None: created_time = int(time.time()) - final_res: Optional[RequestOutput] = None try: - async for res in result_generator: - final_res = res + async for _ in result_generator: + pass except asyncio.CancelledError: return self.create_error_response("Client disconnected") except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - assert 
final_res is not None - assert len(final_res.outputs) == 1 - final_output = final_res.outputs[0] + if self.use_harmony: + assert isinstance(context, HarmonyContext) + output = self._make_response_output_items_with_harmony(context) + # TODO: these are all 0 for now! + num_prompt_tokens = context.num_prompt_tokens + num_generated_tokens = context.num_output_tokens + num_cached_tokens = context.num_cached_tokens + num_reasoning_tokens = context.num_reasoning_tokens + else: + assert isinstance(context, SimpleContext) + final_res = context.last_output + assert final_res is not None + assert len(final_res.outputs) == 1 + final_output = final_res.outputs[0] + + output = self._make_response_output_items(request, final_output, + tokenizer) + + # Calculate usage. + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = len(final_output.token_ids) + num_cached_tokens = final_res.num_cached_tokens + num_reasoning_tokens = 0 + + usage = ResponseUsage( + input_tokens=num_prompt_tokens, + output_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=num_cached_tokens), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=num_reasoning_tokens), + ) + + response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=output, + status="completed", + usage=usage, + ) + if request.store: + async with self.response_store_lock: + stored_response = self.response_store.get(response.id) + # If the response is already cancelled, don't update it. + if (stored_response is None + or stored_response.status != "cancelled"): + self.response_store[response.id] = response + return response + + def _make_response_output_items( + self, + request: ResponsesRequest, + final_output: CompletionOutput, + tokenizer: AnyTokenizer, + ): if self.reasoning_parser: try: reasoning_parser = self.reasoning_parser(tokenizer) @@ -314,13 +501,13 @@ async def responses_full_generator( reasoning_content = None content = final_output.text - output = [] + output_items = [] if reasoning_content: reasoning_item = ResponseReasoningItem( - text=reasoning_content, + content=[ResponseReasoningTextContent(text=reasoning_content)], status=None, # NOTE: Only the last output item has status. ) - output.append(reasoning_item) + output_items.append(reasoning_item) if content: output_text = ResponseOutputText( text=content, @@ -335,40 +522,22 @@ async def responses_full_generator( status="completed", type="message", ) - output.append(message) - - # Calculate usage. 
- assert final_res.prompt_token_ids is not None - num_prompt_tokens = len(final_res.prompt_token_ids) - num_generated_tokens = len(final_output.token_ids) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - ) - if self.enable_prompt_tokens_details and final_res.num_cached_tokens: - usage.prompt_tokens_details = PromptTokenUsageInfo( - cached_tokens=final_res.num_cached_tokens) - request_metadata.final_usage_info = usage - - response = ResponsesResponse.from_request( - request, - sampling_params, - model_name=model_name, - created_time=created_time, - output=output, - status="completed", - usage=usage, - ) + output_items.append(message) + return output_items - if request.store: - async with self.response_store_lock: - stored_response = self.response_store.get(response.id) - # If the response is already cancelled, don't update it. - if (stored_response is None - or stored_response.status != "cancelled"): - self.response_store[response.id] = response - return response + def _make_response_output_items_with_harmony( + self, + context: HarmonyContext, + ): + output_items = [] + num_init_messages = context.num_init_messages + for msg in context.messages[num_init_messages:]: + output_items.extend(parse_output_message(msg)) + # Handle the generation stopped in the middle (if any). + last_items = parse_remaining_state(context.parser) + if last_items: + output_items.extend(last_items) + return output_items def _construct_input_messages( self, @@ -390,13 +559,14 @@ def _construct_input_messages( # Add the previous output. for output_item in prev_response.output: - # NOTE: We skip the reasoning output. - if isinstance(output_item, ResponseOutputMessage): - for content in output_item.content: - messages.append({ - "role": "assistant", - "content": content.text, - }) + # NOTE: We skip the reasoning output of the previous response. + if isinstance(output_item, ResponseReasoningItem): + continue + for content in output_item.content: + messages.append({ + "role": "assistant", + "content": content.text, + }) # Append the new input. # Responses API supports simple text inputs without chat format. @@ -406,6 +576,74 @@ def _construct_input_messages( messages.extend(request.input) # type: ignore return messages + def _construct_input_messages_with_harmony( + self, + request: ResponsesRequest, + prev_response: Optional[ResponsesResponse], + ) -> list[OpenAIMessage]: + messages: list[OpenAIMessage] = [] + if prev_response is None: + # New conversation. + reasoning_effort = (request.reasoning.effort + if request.reasoning else None) + tool_types = [tool.type for tool in request.tools] + enable_browser = ("web_search_preview" in tool_types + and self.tool_server is not None) + enable_code_interpreter = ("code_interpreter" in tool_types + and self.tool_server is not None) + sys_msg = get_system_message( + reasoning_effort=reasoning_effort, + browser_description=self.tool_server.get_tool_description( + "browser") + if self.tool_server and enable_browser else None, + python_description=self.tool_server.get_tool_description( + "python") + if self.tool_server and enable_code_interpreter else None, + ) + messages.append(sys_msg) + dev_msg = get_developer_message(request.instructions, + request.tools) + messages.append(dev_msg) + else: + # Continue the previous conversation. + # FIXME(woosuk): Currently, request params like reasoning and + # instructions are ignored. 
+ prev_msgs = self.msg_store[prev_response.id] + # Remove the previous chain-of-thoughts if there is a new "final" + # message. + if len(prev_msgs) > 0 and hasattr( + prev_msgs[-1], 'channel' + ) and prev_msgs[-1].channel == "final": # type: ignore[union-attr] + prev_final_msg_idx = -1 + for i in range(len(prev_msgs) - 2, -1, -1): + if hasattr(prev_msgs[i], 'channel') and prev_msgs[ + i].channel == "final": # type: ignore[union-attr] + prev_final_msg_idx = i + break + recent_turn_msgs = prev_msgs[prev_final_msg_idx + 1:] + del prev_msgs[prev_final_msg_idx + 1:] + for msg in recent_turn_msgs: + if hasattr( + msg, 'channel' + ) and msg.channel != "analysis": # type: ignore[union-attr] + prev_msgs.append(msg) + messages.extend(prev_msgs) + # Append the new input. + # Responses API supports simple text inputs without chat format. + if isinstance(request.input, str): + messages.append(get_user_message(request.input)) + else: + if prev_response is not None: + prev_outputs = copy(prev_response.output) + else: + prev_outputs = [] + for response_msg in request.input: + messages.append( + parse_response_input(response_msg, prev_outputs)) + if isinstance(response_msg, ResponseFunctionToolCall): + prev_outputs.append(response_msg) + return messages + + async def _run_background_request( self, request: ResponsesRequest, @@ -498,3 +736,412 @@ def _make_store_not_supported_error(self) -> ErrorResponse: "starting the vLLM server."), status_code=HTTPStatus.BAD_REQUEST, ) + + async def responses_stream_generator( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + result_generator: AsyncIterator[StreamingHarmonyContext], + context: StreamingHarmonyContext, + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + created_time: Optional[int] = None, + ) -> AsyncGenerator[str, None]: + # TODO: + # 1.
Handle disconnect + + created_time = created_time or int(time.time()) + + sequence_number = 0 + + def _send_event(event: BaseModel): + nonlocal sequence_number + # Set sequence_number if the event has this attribute + if hasattr(event, 'sequence_number'): + event.sequence_number = sequence_number + sequence_number += 1 + # Get event type from the event's type field if it exists + event_type = getattr(event, 'type', 'unknown') + return (f"event: {event_type}\n" + f"data: {event.model_dump_json(indent=None)}\n\n") + + current_content_index = 0 + current_output_index = 0 + current_item_id = "" + sent_output_item_added = False + + initial_response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=[], + status="in_progress", + usage=None, + ).model_dump() + yield _send_event( + openai_responses_types.ResponseCreatedEvent( + type="response.created", + sequence_number=-1, + response=initial_response, + )) + yield _send_event( + openai_responses_types.ResponseInProgressEvent( + type="response.in_progress", + sequence_number=-1, + response=initial_response, + )) + + async for ctx in result_generator: + + if ctx.is_expecting_start(): + current_output_index += 1 + sent_output_item_added = False + + if len(ctx.parser.messages) > 0: + previous_item = ctx.parser.messages[-1] + if previous_item.recipient is not None: + # Deal with tool call here + pass + elif previous_item.channel == "analysis": + reasoning_item = ResponseReasoningItem( + type="reasoning", + content=[ + ResponseReasoningTextContent( + text=previous_item.content[0].text), + ], + status="completed", + ) + yield _send_event( + ResponseReasoningTextDoneEvent( + type="response.reasoning_text.done", + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=previous_item.content[0].text, + )) + yield _send_event( + ResponseContentPartDoneEvent( + type="response.content_part.done", + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + part=reasoning_item, + )) + yield _send_event( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=reasoning_item, + )) + elif previous_item.channel == "final": + text_content = ResponseOutputText( + type="output_text", + text=previous_item.content[0].text, + annotations=[], + ) + yield _send_event( + openai_responses_types.ResponseTextDoneEvent( + type="response.output_text.done", + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=previous_item.content[0].text, + logprobs=[], + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types. + ResponseContentPartDoneEvent( + type="response.content_part.done", + sequence_number=-1, + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=text_content, + )) + yield _send_event( + openai_responses_types.ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[text_content], + status="completed", + ), + )) + + if ctx.parser.last_content_delta: + if (ctx.parser.current_channel == "final" + and ctx.parser.current_recipient is None): + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + openai_responses_types. 
+ ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[], + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types. + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=openai_responses_types.ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + )) + yield _send_event( + openai_responses_types.ResponseTextDeltaEvent( + type="response.output_text.delta", + sequence_number=-1, + content_index=current_content_index, + output_index=current_output_index, + item_id=current_item_id, + delta=ctx.parser.last_content_delta, + # TODO, use logprobs from ctx.last_request_output + logprobs=[], + )) + elif (ctx.parser.current_channel == "analysis" + and ctx.parser.current_recipient is None): + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + openai_responses_types. + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseReasoningItem( + type="reasoning", + id=current_item_id, + summary=[], + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types. + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + # TODO: migrate this to + # ResponseReasoningTextContent for now + part=openai_responses_types.ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + )) + # TODO: migrate to OpenAI types once updated. + yield _send_event( + ResponseReasoningTextDeltaEvent( + type="response.reasoning_text.delta", + output_index=current_output_index, + content_index=current_content_index, + delta=ctx.parser.last_content_delta, + sequence_number=-1, + )) + + if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0: + previous_item = ctx.parser.messages[-1] + if (self.supports_browsing + and previous_item.recipient is not None + and previous_item.recipient.startswith("browser.")): + function_name = previous_item.recipient[len("browser."):] + action = None + parsed_args = json.loads(previous_item.content[0].text) + if function_name == "search": + action = (openai_responses_types. + response_function_web_search.ActionSearch( + type="search", + query=parsed_args["query"], + )) + elif function_name == "open": + action = ( + openai_responses_types. + response_function_web_search.ActionOpenPage( + type="open_page", + # TODO: translate to url + url=f"cursor:{parsed_args.get('cursor', '')}", + )) + elif function_name == "find": + action = ( + openai_responses_types. + response_function_web_search.ActionFind( + type="find", + pattern=parsed_args["pattern"], + # TODO: translate to url + url=f"cursor:{parsed_args.get('cursor', '')}", + )) + else: + raise ValueError( + f"Unknown function name: {function_name}") + + yield _send_event( + openai_responses_types.ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + response_function_web_search. 
+ ResponseFunctionWebSearch( + # TODO: generate a unique id for web search call + type="web_search_call", + id=current_item_id, + action=action, + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types. + ResponseWebSearchCallInProgressEvent( + type="response.web_search_call.in_progress", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types. + ResponseWebSearchCallSearchingEvent( + type="response.web_search_call.searching", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + + # enqueue + yield _send_event( + openai_responses_types. + ResponseWebSearchCallCompletedEvent( + type="response.web_search_call.completed", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types.ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseFunctionWebSearch( + type="web_search_call", + id=current_item_id, + action=action, + status="completed", + ), + )) + + if (self.supports_code_interpreter + and previous_item.recipient is not None + and previous_item.recipient.startswith("python")): + yield _send_event( + openai_responses_types.ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=current_item_id, + code="", + container_id="auto", + outputs=[], + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallInProgressEvent( + type="response.code_interpreter_call.in_progress", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + # TODO: do we need to add delta event here? + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallCodeDoneEvent( + type="response.code_interpreter_call_code.done", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + code=previous_item.content[0].text)) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallInterpretingEvent( + type="response.code_interpreter_call.interpreting", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallCompletedEvent( + type="response.code_interpreter_call.completed", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types.ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. 
+ ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=current_item_id, + code=previous_item.content[0].text, + container_id="auto", + # TODO: add outputs here + outputs=[], + status="completed", + ), + )) + + async def empty_async_generator(): + if False: + yield + + final_response = await self.responses_full_generator( + request, + sampling_params, + empty_async_generator(), + context, + model_name, + tokenizer, + request_metadata, + created_time=created_time, + ) + yield _send_event( + openai_responses_types.ResponseCompletedEvent( + type="response.completed", + sequence_number=-1, + response=final_response.model_dump(), + )) diff --git a/vllm/entrypoints/openai/tool_server.py b/vllm/entrypoints/openai/tool_server.py new file mode 100644 index 000000000000..31e3c7524a5b --- /dev/null +++ b/vllm/entrypoints/openai/tool_server.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any + +from mcp import ClientSession +from mcp.client.sse import sse_client +from mcp.types import ListToolsResult +from openai_harmony import ToolDescription, ToolNamespaceConfig + +from vllm.entrypoints.tool import HarmonyBrowserTool, HarmonyPythonTool, Tool +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +async def list_server_and_tools(server_url: str): + + async with sse_client(url=server_url) as streams, ClientSession( + *streams) as session: + initialize_response = await session.initialize() + list_tools_response = await session.list_tools() + return initialize_response, list_tools_response + + +def trim_schema(schema: dict) -> dict: + # Turn JSON Schema from MCP generated into Harmony's variant. + if "title" in schema: + del schema["title"] + if "default" in schema and schema["default"] is None: + del schema["default"] + if "anyOf" in schema: + # Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}] + # into "type": ["type-1", "type-2"] + # if there's more than 1 types, also remove "null" type as Harmony will + # just ignore it + types = [ + type_dict["type"] for type_dict in schema["anyOf"] + if type_dict["type"] != 'null' + ] + schema["type"] = types + del schema["anyOf"] + if "properties" in schema: + schema["properties"] = { + k: trim_schema(v) + for k, v in schema["properties"].items() + } + return schema + + +def post_process_tools_description( + list_tools_result: ListToolsResult) -> ListToolsResult: + # Adapt the MCP tool result for Harmony + for tool in list_tools_result.tools: + tool.inputSchema = trim_schema(tool.inputSchema) + + # Some tools schema don't need to be part of the prompt (e.g. simple text + # in text out for Python) + list_tools_result.tools = [ + tool for tool in list_tools_result.tools + if getattr(tool.annotations, "include_in_prompt", True) + ] + + return list_tools_result + + +class ToolServer(ABC): + + @abstractmethod + def has_tool(self, tool_name: str): + pass + + @abstractmethod + def get_tool_description(self, tool_name: str): + pass + + @abstractmethod + def get_tool_session(self, + tool_name: str) -> AbstractAsyncContextManager[Any]: + ... 
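# --- Illustrative sketch (editor's addition, not part of this patch): one way the
# ToolServer interface defined above could be satisfied by a trivial in-memory
# server. The StaticToolServer name and its constructor arguments are
# hypothetical; only the has_tool / get_tool_description / get_tool_session
# contract is taken from the abstract class above.
from contextlib import asynccontextmanager
from typing import Any


class StaticToolServer(ToolServer):
    """Hypothetical in-memory ToolServer, for illustration only."""

    def __init__(self, tools: dict[str, Any], descriptions: dict[str, Any]):
        # Pre-built tool objects keyed by name, plus their Harmony tool
        # descriptions (e.g. ToolNamespaceConfig instances).
        self._tools = tools
        self._descriptions = descriptions

    def has_tool(self, tool_name: str):
        return tool_name in self._tools

    def get_tool_description(self, tool_name: str):
        # Return None for unknown tools, mirroring MCPToolServer below.
        return self._descriptions.get(tool_name)

    @asynccontextmanager
    async def get_tool_session(self, tool_name: str):
        # The "session" here is just the tool object itself; a networked
        # implementation could open a client session instead (see
        # MCPToolServer below).
        yield self._tools[tool_name]
# --- End of illustrative sketch.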
+ + +class MCPToolServer(ToolServer): + + def __init__(self): + self.harmony_tool_descriptions = {} + + async def add_tool_server(self, server_url: str): + tool_urls = server_url.split(",") + self.harmony_tool_descriptions = {} + self.urls: dict[str, str] = {} + for url in tool_urls: + url = f"http://{url}/sse" + initialize_response, list_tools_response = ( + await list_server_and_tools(url)) + + list_tools_response = post_process_tools_description( + list_tools_response) + + tool_from_mcp = ToolNamespaceConfig( + name=initialize_response.serverInfo.name, + description=initialize_response.instructions, + tools=[ + ToolDescription.new(name=tool.name, + description=tool.description, + parameters=tool.inputSchema) + for tool in list_tools_response.tools + ]) + self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp + if tool_from_mcp.name not in self.urls: + self.urls[tool_from_mcp.name] = url + else: + logger.warning( + "Tool %s already exists. Ignoring duplicate tool server %s", + tool_from_mcp.name, url) + + def has_tool(self, tool_name: str): + return tool_name in self.harmony_tool_descriptions + + def get_tool_description(self, tool_name: str): + return self.harmony_tool_descriptions.get(tool_name) + + @asynccontextmanager + async def get_tool_session(self, tool_name: str): + url = self.urls.get(tool_name) + if url: + async with sse_client(url=url) as streams, ClientSession( + *streams) as session: + await session.initialize() + yield session + else: + logger.warning("Tool %s not found", tool_name) + + +class DemoToolServer(ToolServer): + + def __init__(self): + self.tools: dict[str, Tool] = {} + browser_tool = HarmonyBrowserTool() + if browser_tool.enabled: + self.tools["browser"] = browser_tool + python_tool = HarmonyPythonTool() + if python_tool.enabled: + self.tools["python"] = python_tool + + def has_tool(self, tool_name: str): + return tool_name in self.tools + + def get_tool_description(self, tool_name: str): + if tool_name not in self.tools: + return None + if tool_name == "browser": + return ToolNamespaceConfig.browser() + elif tool_name == "python": + return ToolNamespaceConfig.python() + else: + raise ValueError(f"Unknown tool {tool_name}") + + @asynccontextmanager + async def get_tool_session(self, tool_name: str): + yield self.tools[tool_name] diff --git a/vllm/entrypoints/tool.py b/vllm/entrypoints/tool.py new file mode 100644 index 000000000000..01ee77414f13 --- /dev/null +++ b/vllm/entrypoints/tool.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from vllm.logger import init_logger + +if TYPE_CHECKING: + # Avoid circular import. 
+ from vllm.entrypoints.context import ConversationContext + +logger = init_logger(__name__) + + +class Tool(ABC): + + @abstractmethod + async def get_result(self, context: "ConversationContext") -> Any: + pass + + +class HarmonyBrowserTool(Tool): + + def __init__(self): + self.enabled = True + exa_api_key = os.getenv("EXA_API_KEY") + if not exa_api_key: + self.enabled = False + logger.warning_once("EXA_API_KEY is not set, browsing is disabled") + return + + try: + from gpt_oss.tools.simple_browser import SimpleBrowserTool + from gpt_oss.tools.simple_browser.backend import ExaBackend + except ImportError: + self.enabled = False + logger.warning_once( + "gpt_oss is not installed, browsing is disabled") + return + + browser_backend = ExaBackend(source="web", api_key=exa_api_key) + self.browser_tool = SimpleBrowserTool(backend=browser_backend) + logger.info_once("Browser tool initialized") + + async def get_result(self, context: "ConversationContext") -> Any: + from vllm.entrypoints.context import HarmonyContext + assert isinstance(context, HarmonyContext) + last_msg = context.messages[-1] + tool_output_msgs = [] + async for msg in self.browser_tool.process(last_msg): + tool_output_msgs.append(msg) + return tool_output_msgs + + @property + def tool_config(self) -> Any: + return self.browser_tool.tool_config + + +class HarmonyPythonTool(Tool): + + def __init__(self): + self.enabled = True + + try: + from gpt_oss.tools.python_docker.docker_tool import PythonTool + except ImportError: + self.enabled = False + logger.warning_once( + "gpt_oss is not installed, code interpreter is disabled") + return + + self.python_tool = PythonTool() + logger.info_once("Code interpreter tool initialized") + + async def get_result(self, context: "ConversationContext") -> Any: + from vllm.entrypoints.context import HarmonyContext + assert isinstance(context, HarmonyContext) + last_msg = context.messages[-1] + tool_output_msgs = [] + async for msg in self.python_tool.process(last_msg): + tool_output_msgs.append(msg) + return tool_output_msgs + + @property + def tool_config(self) -> Any: + return self.python_tool.tool_config diff --git a/vllm/envs.py b/vllm/envs.py index 78f955f78a98..f23827c106dd 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -17,6 +17,7 @@ LD_LIBRARY_PATH: Optional[str] = None VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False + VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_FLASH_ATTN_VERSION: Optional[int] = None LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -150,8 +151,10 @@ VLLM_USE_CUDNN_PREFILL: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" + VLLM_NSYS_PROFILE_START_STOP: str = "None" VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False + VLLM_PRINT_LOGO: bool = True def get_default_cache_root(): @@ -327,6 +330,12 @@ def get_vllm_port() -> Optional[int]: (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in ("true", "1")), + # Use AITER triton unified attention for V1 attention + "VLLM_USE_AITER_UNIFIED_ATTENTION": + lambda: + (os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in + ("true", "1")), + # Force vllm to use a specific flash-attention version (2 or 3), only valid # when using the flash-attention backend. 
"VLLM_FLASH_ATTN_VERSION": @@ -1027,6 +1036,10 @@ def get_vllm_port() -> Optional[int]: "VLLM_USE_CUDNN_PREFILL": lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))), + # If set to 1, use the TRTLLM Context Attention backend in flashinfer. + "VLLM_USE_TRTLLM_CONTEXT_ATTENTION": + lambda: os.getenv("VLLM_USE_TRTLLM_CONTEXT_ATTENTION", None), + # If set to 1, use the TRTLLM Decode Attention backend in flashinfer. "VLLM_USE_TRTLLM_DECODE_ATTENTION": lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None), @@ -1037,10 +1050,22 @@ def get_vllm_port() -> Optional[int]: "VLLM_ENABLE_CUDAGRAPH_GC": lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))), + # If set to 1, use the FlashInfer MXFP4 x MXFP8 MoE backend. + "VLLM_USE_FLASHINFER_MXFP4_MOE": + lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MXFP4_MOE", "0"))), + + # If set to 1, use the FlashInfer MXFP4 x BF16 MoE backend. + "VLLM_USE_FLASHINFER_MXFP4_BF16_MOE": + lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MXFP4_BF16_MOE", "0"))), + # Used to force set up loopback IP "VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""), + # Used to set the start and stop iteration for nsys profile + "VLLM_NSYS_PROFILE_START_STOP": + lambda: os.getenv("VLLM_NSYS_PROFILE_START_STOP", "None"), + # Used to set the process name prefix for vLLM processes. # This is useful for debugging and monitoring purposes. # The default value is "VLLM". @@ -1069,6 +1094,11 @@ def get_vllm_port() -> Optional[int]: # never removed from memory until the server terminates. "VLLM_ENABLE_RESPONSES_API_STORE": lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))), + + # If set, print logo + "VLLM_PRINT_LOGO": + lambda: bool(int(os.getenv("VLLM_PRINT_LOGO", "1"))), + } # --8<-- [end:env-vars-definition] diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 3ccddb52998b..6600dfdc0f8b 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -260,7 +260,9 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index fc30e84e6656..171d40fa998a 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -138,7 +138,9 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, 
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -149,6 +151,6 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, assert experts is not None experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids, activation, global_num_experts, expert_map, w1_scale, - w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, - workspace2, expert_tokens_meta, + w2_scale, w1_zp, w2_zp, w1_bias, w2_bias, a1q_scale, + a2_scale, workspace13, workspace2, expert_tokens_meta, apply_router_weight_on_input, extra_expert_args) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 2585a2953c9d..63675a9be607 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -299,7 +299,9 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -625,8 +627,10 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, topk_ids: torch.Tensor, activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor, w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], - a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], a2_scale: torch.Tensor, + workspace13: Optional[torch.Tensor], workspace2: Optional[torch.Tensor], expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index bd3605378b6d..5ce6cdc7bc02 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -207,6 +207,8 @@ def apply( w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], + w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 3e79a1a8c24b..901447b1465e 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -143,6 +143,8 @@ def apply( w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], + w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], # Not used workspace13: Optional[torch.Tensor], diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 9a5c85e120cc..8a7845524f43 100644 --- 
a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -694,7 +694,9 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -900,7 +902,9 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 56d1dfe135b3..ec265c8dabd1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1843,6 +1843,8 @@ def apply( w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], + w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9e7296feeae1..609d1b3d7328 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -33,7 +33,8 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx +from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx, + has_triton_kernels, round_up) from vllm.utils.flashinfer import has_flashinfer if current_platform.is_cuda_alike(): @@ -719,6 +720,20 @@ def __init__( self.global_num_experts = num_experts + num_redundant_experts + if quant_config is not None and quant_config.get_name() == "mxfp4": + if has_triton_kernels: + self.use_triton_kernels = True + if (current_platform.is_rocm() + or self.moe_parallel_config.use_deepep_ll_kernels + or envs.VLLM_USE_FLASHINFER_MXFP4_MOE + or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE + or self.moe_parallel_config.use_deepep_ll_kernels): + # For ROCm or DEEPEP low latency, + # we need to round up the hidden size + hidden_size = round_up(hidden_size, 256) + else: + raise ValueError("triton_kernels must be installed first") + # For smuggling this layer into the fused moe custom op compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: @@ -1060,6 +1075,18 @@ def weight_loader(self, shard_id: str, expert_id: int, return_success: bool = False) -> Optional[bool]: + # if expert_id is None, then + # all the experts are loaded at the same time + if not expert_id and 
self.quant_config.get_name() == "mxfp4": + if "bias" in weight_name: + dim1 = loaded_weight.shape[1] + param.data[:, :dim1].copy_(loaded_weight) + else: + dim1 = loaded_weight.shape[1] + dim2 = loaded_weight.shape[2] + param.data[:, :dim1, :dim2].copy_(loaded_weight) + return True if return_success else None + expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: # Failed to load this param since it's not local to this rank @@ -1455,11 +1482,18 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): # TODO: Once the OOM issue for the TPU backend is resolved, we will # switch to using the moe_forward custom op. + og_hidden_states = hidden_states.shape[-1] + if self.hidden_size != og_hidden_states: + hidden_states = F.pad(hidden_states, + (0, self.hidden_size - og_hidden_states), + mode='constant', + value=0.0) if current_platform.is_tpu(): return self.forward_impl(hidden_states, router_logits) else: - return torch.ops.vllm.moe_forward(hidden_states, router_logits, - self.layer_name) + return torch.ops.vllm.moe_forward( + hidden_states, router_logits, + self.layer_name)[..., :og_hidden_states] def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): @@ -1588,6 +1622,7 @@ def forward_impl(self, hidden_states: torch.Tensor, final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( final_hidden_states) + # manually crop the tensor since oai kernel pad the output return final_hidden_states @classmethod diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 6262904e4dca..0184086288f7 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -362,6 +362,8 @@ def apply( w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], + w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -461,6 +463,7 @@ def _do_fused_experts( expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], + w1_bias: Optional[torch.Tensor], w2_bias: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], @@ -503,6 +506,8 @@ def _do_fused_experts( w2_scale=w2_scale, w1_zp=w1_zp, w2_zp=w2_zp, + w1_bias=w1_bias, + w2_bias=w2_bias, a1q_scale=a1q_scale, a2_scale=a2_scale, workspace13=workspace13, @@ -529,6 +534,8 @@ def _maybe_chunk_fused_experts( w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], + w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], @@ -558,6 +565,8 @@ def _maybe_chunk_fused_experts( w2_scale=w2_scale, w1_zp=w1_zp, w2_zp=w2_zp, + w1_bias=w1_bias, + w2_bias=w2_bias, a1q_scale=a1q_scale, a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, @@ -658,6 +667,8 @@ def slice_expert_tokens_metadata( w2_scale=w2_scale, w1_zp=w1_zp, w2_zp=w2_zp, + w1_bias=w1_bias, + w2_bias=w2_bias, a1q_scale=c_a1q_scale, a2_scale=c_a2_scale, expert_tokens_meta=c_expert_tokens_meta, @@ -681,6 +692,8 @@ def forward( w2_scale: Optional[torch.Tensor] = None, 
w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, @@ -782,6 +795,8 @@ def forward( w2_scale=w2_scale, w1_zp=w1_zp, w2_zp=w2_zp, + w1_bias=w1_bias, + w2_bias=w2_bias, a1q_scale=a1q_scale, a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 1b31368c79cd..c1d7d2822414 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -125,7 +125,9 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], @@ -152,6 +154,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, w2_scale, w1_zp, w2_zp, + w1_bias, + w2_bias, a1q_scale, a2_scale, workspace13, diff --git a/vllm/model_executor/layers/fused_moe/triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/triton_kernels_moe.py new file mode 100644 index 000000000000..604a919b1d39 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/triton_kernels_moe.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import torch +import triton_kernels.swiglu +from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation, + PrecisionConfig, matmul_ogs) +from triton_kernels.routing import (GatherIndx, RoutingData, ScatterIndx, + routing) + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) + + +def triton_kernel_moe_forward( + hidden_states: torch.Tensor, + w1, + w2, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + w1_precision: Optional[PrecisionConfig] = None, + w2_precision: Optional[PrecisionConfig] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + + routing_data, gather_idx, scatter_idx = routing(gating_output, + topk, + sm_first=not renormalize) + + return triton_kernel_fused_experts( + None, + hidden_states, + w1, + w2, + routing_data, + gather_idx, + scatter_idx, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=use_fp8_w8a8, + per_channel_quant=per_channel_quant, + 
global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_precision=w1_precision, + w2_precision=w2_precision, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape) + + +# This is a triton implementation of the fused_experts function +def triton_kernel_fused_experts( + output_tensor: torch.Tensor, + hidden_states: torch.Tensor, + w1, + w2, + routing_data: RoutingData, + gather_indx: GatherIndx, + scatter_indx: ScatterIndx, + activation: str = "silu", + swiglu_alpha: float = 1.702, + swiglu_limit: float = 7.0, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, + w1_precision: Optional[PrecisionConfig] = None, + w2_precision: Optional[PrecisionConfig] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, +) -> torch.Tensor: + + # type check, uint8 means mxfp4 + assert hidden_states.dtype == torch.bfloat16 + # assert w1.dtype in (torch.bfloat16, torch.uint8) + # assert w2.dtype in (torch.bfloat16, torch.uint8) + assert w1_bias is None or w1_bias.dtype == torch.float32 + assert w2_bias is None or w2_bias.dtype == torch.float32 + + # Shape check, only check non-mxfp4 + # assert hidden_states.ndim == 2 + assert hidden_states.shape[-1] == w1.shape[-2] + assert w2.shape[-1] == w1.shape[1] + + # M, K = hidden_states.shape + E, _, N = w1.shape + # n_expts_act = routing_data.n_expts_act + # dtype = hidden_states.dtype + + if global_num_experts == -1: + global_num_experts = E + + act = FusedActivation( + FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), + (swiglu_alpha, swiglu_limit), 2) + gammas = routing_data.gate_scal if routing_data else None + + intermediate_cache1 = matmul_ogs( + hidden_states, + w1, + w1_bias, + routing_data, + gather_indx=gather_indx, + precision_config=w1_precision, + gammas=gammas if apply_router_weight_on_input else None, + fused_activation=act) + + intermediate_cache3 = matmul_ogs( + intermediate_cache1, + w2, + w2_bias, + routing_data, + scatter_indx=scatter_indx, + precision_config=w2_precision, + gammas=None if apply_router_weight_on_input else gammas, + y=output_tensor, + ) + return intermediate_cache3 + + +class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__(self, quant_config, max_num_tokens: int, num_dispatchers: int, + w1_precision: PrecisionConfig, w2_precision: PrecisionConfig): + super().__init__(quant_config) + self.max_num_tokens = max_num_tokens + self.num_dispatchers = num_dispatchers + self.w1_precision = w1_precision + self.w2_precision = w2_precision + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. 
+ return TopKWeightAndReduceDelegate() + + def workspace_shapes( + self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, + topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata] + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # workspace are allocated inside the kernel + assert a.dim() == 2 + num_dp = self.num_dispatchers + num_experts = local_num_experts + max_num_tokens = self.max_num_tokens + workspace2 = (0, 0, 0) + output = (num_experts, max_num_tokens * num_dp, N) + return (output, workspace2, output, a.dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + w1_bias: Optional[torch.Tensor], + w2_bias: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], + ): + return triton_kernel_fused_experts( + output, + hidden_states, + w1, + w2, + None, + None, + None, + activation=activation, + apply_router_weight_on_input=False, + use_fp8_w8a8=False, + per_channel_quant=False, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_precision=self.w1_precision, + w2_precision=self.w2_precision, + a1_scale=a1q_scale, + a2_scale=a2_scale) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 95aea912a150..8d63027e1863 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -37,6 +37,7 @@ "auto-round", "rtn", "inc", + "mxfp4", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -110,6 +111,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .marlin import MarlinConfig from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config from .moe_wna16 import MoeWNA16Config + from .mxfp4 import Mxfp4Config from .neuron_quant import NeuronQuantConfig from .ptpc_fp8 import PTPCFp8Config from .qqq import QQQConfig @@ -148,6 +150,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "auto-round": AutoRoundConfig, "rtn": RTNConfig, "inc": INCConfig, + "mxfp4": Mxfp4Config, } # Update the `method_to_config` with customized quantization methods. 
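# --- Editor's sketch (not part of this patch): the string -> config-class registry
# pattern that the quantization hunk above extends when it adds "mxfp4". The names
# below (ToyQuantConfig, SUPPORTED_METHODS, get_config) are hypothetical stand-ins
# for Mxfp4Config, QUANTIZATION_METHODS and get_quantization_config(). ---
class ToyQuantConfig:
    """Placeholder for a QuantizationConfig subclass such as Mxfp4Config."""

    @classmethod
    def get_name(cls) -> str:
        return "mxfp4"


SUPPORTED_METHODS = ["fp8", "gptq", "mxfp4"]      # mirrors QUANTIZATION_METHODS
_METHOD_TO_CONFIG = {"mxfp4": ToyQuantConfig}     # mirrors method_to_config


def get_config(method: str) -> type:
    # Unknown names fail fast, mirroring how unsupported methods are rejected.
    if method not in SUPPORTED_METHODS or method not in _METHOD_TO_CONFIG:
        raise ValueError(f"Unsupported quantization method: {method}")
    return _METHOD_TO_CONFIG[method]


print(get_config("mxfp4").get_name())  # -> "mxfp4"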
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py new file mode 100644 index 000000000000..58b40b437d99 --- /dev/null +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -0,0 +1,507 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import envs +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, + FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + _can_support_triton_kernels, _swizzle_mxfp4) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import next_power_of_2, round_up + +if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE + or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE): + # from flashinfer.fused_moe import cutlass_fused_moe + from flashinfer import (mxfp8_quantize, shuffle_matrix_a, + shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) + +logger = init_logger(__name__) + + +class Mxfp4Config(QuantizationConfig): + + def __init__(self, ignored_layers: Optional[list[str]] = None): + super().__init__() + self.ignored_layers = ignored_layers + + @classmethod + def from_config(cls, config): + return cls() + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "mxfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + if self.ignored_layers and is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping): + return UnquantizedLinearMethod() + raise NotImplementedError("Mxfp4 linear layer is not implemented") + elif isinstance(layer, FusedMoE): + return Mxfp4MoEMethod(layer.moe_config) + elif isinstance(layer, Attention): + raise NotImplementedError( + "Mxfp4 attention layer is not implemented") + return None + + +class Mxfp4MoEMethod(FusedMoEMethodBase): + + def __init__(self, moe: FusedMoEConfig): + super().__init__() + self.topk_indices_dtype = None + self.moe = moe + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + self.num_experts = num_experts + weight_dtype = torch.uint8 + scale_dtype = torch.uint8 + + # FIXME (zyongye): ship after torch and safetensors support mxfp4 + # is_torch_mxfp4_available = ( + # hasattr(torch, "float4_e2m1fn_x2") and + # hasattr(torch, "float8_e8m0fnu")) + # if 
is_torch_mxfp4_available: + # weight_dtype = torch.float4_e2m1fn_x2 + # scale_dtype = torch.float8_e8m0fnu + + mxfp4_block = 32 + + # pad the intermediate size to be a multiple of 2 * mxfp4_block + # for to hold non-uniform sharded tensor as well as swizzling + if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE + or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE): + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 256) + hidden_size = round_up(hidden_size, 256) + elif current_platform.is_rocm(): + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 128) + else: + intermediate_size_per_partition_after_pad = round_up( + intermediate_size_per_partition, 64) + + self.intermediate_size = intermediate_size_per_partition_after_pad + self.hidden_size = hidden_size + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + hidden_size // 2, + dtype=weight_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + hidden_size // mxfp4_block, + dtype=scale_dtype), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w13_bias = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition_after_pad, + dtype=torch.bfloat16), + requires_grad=False) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter(torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition_after_pad // 2, + dtype=weight_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter(torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition_after_pad // mxfp4_block, + dtype=scale_dtype), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + w2_bias = torch.nn.Parameter(torch.zeros(num_experts, + hidden_size, + dtype=torch.bfloat16), + requires_grad=False) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + def process_weights_after_loading(self, layer): + if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE + or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE): + logger.info_once("Shuffling MoE weights, it might take a while...") + layer.gemm1_alpha = Parameter(torch.tensor( + [1.702] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + layer.gemm1_beta = Parameter(torch.tensor( + [1.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + layer.gemm1_clamp_limit = Parameter(torch.tensor( + [7.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False) + sf_block_size = 32 # mxfp4 block size + + assert (layer.w13_weight.dim() == 3 + and layer.w13_weight.shape[0] == self.num_experts + and layer.w13_weight.shape[1] == self.intermediate_size * 2 + and layer.w13_weight.shape[2] == self.hidden_size // 2) + assert (layer.w13_weight_scale.dim() == 3 + and layer.w13_weight_scale.shape[0] == self.num_experts + and layer.w13_weight_scale.shape[1] + == 
self.intermediate_size * 2 + and layer.w13_weight_scale.shape[2] + == self.hidden_size // sf_block_size) + assert (layer.w2_weight.dim() == 3 + and layer.w2_weight.shape[0] == self.num_experts + and layer.w2_weight.shape[1] == self.hidden_size and + layer.w2_weight.shape[2] == self.intermediate_size // 2) + assert (layer.w2_weight_scale.dim() == 3 + and layer.w2_weight_scale.shape[1] == self.hidden_size + and layer.w2_weight_scale.shape[2] + == self.intermediate_size // sf_block_size) + assert (layer.w13_bias.dim() == 2 + and layer.w13_bias.shape[0] == self.num_experts + and layer.w13_bias.shape[1] == self.intermediate_size * 2) + assert (layer.w2_bias.dim() == 2 + and layer.w2_bias.shape[0] == self.num_experts + and layer.w2_bias.shape[1] == self.hidden_size) + + w13_weight_scale = layer.w13_weight_scale.data + w2_weight_scale = layer.w2_weight_scale.data + w13_weight = layer.w13_weight.data + w2_weight = layer.w2_weight.data + w13_bias = layer.w13_bias.data.to(torch.float32) + w2_bias = layer.w2_bias.data.to(torch.float32) + + # Swap w1 and w3 as the defenition of + # swiglu is different in the trtllm-gen + def swap_every_two_rows(x, axis=-1): + shape = x.shape + if axis < 0: + axis = len(shape) + axis + + # Create a new shape with pairs swapped along specified axis + new_shape = list(shape) + new_shape[axis] = shape[axis] // 2 + new_shape.insert(axis + 1, 2) + + # Reshape to expose pairs, swap them, and reshape back + x = x.reshape(*new_shape) + x = x.flip(axis + 1) + new_shape = list(shape) + return x.reshape(*new_shape) + + w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2) + w13_weight = swap_every_two_rows(w13_weight, -2) + w13_bias = swap_every_two_rows(w13_bias, -1) + + # Do not interleave as the checkpoint is already interleaved + + # Shuffle weights and scaling factors for transposed mma output + gemm1_weights_mxfp4_shuffled = [] + gemm1_scales_mxfp4_shuffled = [] + gemm2_weights_mxfp4_shuffled = [] + gemm2_scales_mxfp4_shuffled = [] + gemm1_bias_shuffled = [] + gemm2_bias_shuffled = [] + epilogue_tile_m = 128 # FIXME: this depends on the kernel internals + for i in range(self.num_experts): + gemm1_weights_mxfp4_shuffled.append( + shuffle_matrix_a(w13_weight[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_scales_mxfp4_shuffled.append( + shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + gemm1_bias_shuffled.append( + shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1), + epilogue_tile_m)) + + gemm2_weights_mxfp4_shuffled.append( + shuffle_matrix_a(w2_weight[i].view(torch.uint8), + epilogue_tile_m)) + gemm2_scales_mxfp4_shuffled.append( + shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), + epilogue_tile_m)) + gemm2_bias_shuffled.append( + shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1), + epilogue_tile_m)) + + w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled) + w13_weight_scale = torch.stack( + gemm1_scales_mxfp4_shuffled).reshape( + self.num_experts, 2 * self.intermediate_size, + self.hidden_size // sf_block_size).view( + torch.float8_e4m3fn) + + w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled) + w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape( + self.num_experts, self.hidden_size, self.intermediate_size // + sf_block_size).view(torch.float8_e4m3fn) + + layer.w13_weight = Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = Parameter(w13_weight_scale, + requires_grad=False) + layer.w2_weight = Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = 
Parameter(w2_weight_scale, + requires_grad=False) + layer.w13_bias = Parameter( + torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1), + requires_grad=False) + layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape( + self.num_experts, -1), + requires_grad=False) + return + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + + w13_bias = layer.w13_bias.to(torch.float32) + w2_bias = layer.w2_bias.to(torch.float32) + + layer.w13_bias = Parameter(w13_bias, requires_grad=False) + layer.w2_bias = Parameter(w2_bias, requires_grad=False) + + # FIXME warp need to be adjusted based on batch size + # only apply to batched mode + if self.moe.use_ep: + num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 + else: + num_warps = 8 + + w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( + layer.w13_weight, layer.w13_weight_scale, num_warps) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(layer.w2_weight, + layer.w2_weight_scale, + num_warps) + + self.w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)) + self.w2_precision_config = PrecisionConfig( + weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)) + + self.w13_weight_triton_tensor = w13_weight + self.w2_weight_triton_tensor = w2_weight + + # need to delete the original weights to save memory on single GPU + del layer.w13_weight + del layer.w2_weight + layer.w13_weight = None + layer.w2_weight = None + torch.cuda.empty_cache() + + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> FusedMoEPermuteExpertsUnpermute: + # this is happen after the weights are loaded + # so we can safely use precision config + from vllm.model_executor.layers.fused_moe.triton_kernels_moe import ( + BatchedOAITritonExperts) + if (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts): + max_num_tokens_per_rank = ( + prepare_finalize.max_num_tokens_per_rank()) + assert max_num_tokens_per_rank is not None + return BatchedOAITritonExperts( + None, + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + w1_precision=self.w13_precision_config, + w2_precision=self.w2_precision_config, + ) + else: + raise NotImplementedError( + "Mxfp4 does not support non-batched experts format for EP") + + def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int): + # Number of tokens in the input tensor. + num_tokens = x.shape[0] + # Factor to account for the imbalance of the experts. + # factor equals to the + # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert + # - 1.0 means perfect expert distribution. + # - > 1.0 means some experts have more + # tokens than the perfect distribution. + # - < 1.0 does not make sense. + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // self.num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile + # as it's the range supported by the kernel. 
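# --- Editor's sketch (not part of this patch): a standalone, pure-Python version of
# the tile-size heuristic in _get_tile_tokens_dim above. next_pow2 is my stand-in for
# vllm.utils.next_power_of_2; the 1.3 imbalance factor and the 8..64 clamp are taken
# from the surrounding code. ---
def next_pow2(n: int) -> int:
    # Smallest power of two >= n (returns 1 for n <= 0).
    return 1 if n <= 0 else 1 << (n - 1).bit_length()


def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int,
                    imbalance_factor: float = 1.3) -> int:
    per_expert = (num_tokens * top_k) // num_experts     # perfect distribution
    per_expert = int(per_expert * imbalance_factor)      # allow for expert imbalance
    return min(max(next_pow2(per_expert), 8), 64)        # kernel supports 8..64


# e.g. 512 tokens, top_k=4, 128 experts -> 16 per expert -> ~20 -> next pow2 is 32
print(tile_tokens_dim(512, 4, 128))  # 32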
+ tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + + return tile_tokens_dim + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # avoid import error when triton_kernel is not installed + from vllm.model_executor.layers.fused_moe.triton_kernels_moe import ( + triton_kernel_moe_forward) + + if enable_eplb: + raise NotImplementedError("EPLB is not supported for mxfp4") + + if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE + or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE): + assert not self.moe.use_ep, ( + "EP is not supported for flashinfer mxfp4 moe backend yet.") + if envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE: + assert x.dtype == torch.bfloat16 + x_quant = x + x_scale = None + else: + x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 + x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1) + trtllm_gen_output = trtllm_fp4_block_scale_moe( + router_logits.to(torch.bfloat16), + None, # routing_bias + x_quant, + x_scale, + layer.w13_weight, # uint8 (e2m1 x 2) + layer.w13_weight_scale, # uint8 (e4m3 x 2) + layer.w13_bias, # fp32 per expert per channel + layer.gemm1_alpha, # fp32 per expert + layer.gemm1_beta, # fp32 per expert + layer.gemm1_clamp_limit, # fp32 per expert + layer.w2_weight, # uint8 (e2m1 x 2) + layer.w2_weight_scale, # ue8m0 + layer.w2_bias, # fp32 per expert per channel + None, # output1_scale_scalar + None, # output1_scale_gate_scalar + None, # output2_scale_scalar + self.num_experts, + top_k, + None, # n_group + None, # topk_group + self.intermediate_size, # padded to multiple of 256 + 0, # local_expert_offset + self.num_experts, # local num experts + None, + self._get_tile_tokens_dim(x, top_k), + 1, # routing_method_type, renormalize + True, # do finalize + )[0] + return trtllm_gen_output + else: + assert _can_support_triton_kernels( + use_grouped_topk, topk_group, num_expert_group, + custom_routing_function, e_score_correction_bias, scoring_func, + activation, expert_load_view, logical_to_physical_map, + logical_replica_count), ( + "Triton kernels are not supported for \ + mxfp4 MoE with the current configuration.") + if self.moe.use_ep: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + return self.fused_experts( + x, + self.w13_weight_triton_tensor, + self.w2_weight_triton_tensor, + topk_weights, + topk_ids, + 
global_num_experts=global_num_experts, + expert_map=expert_map, + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + ) + else: + return triton_kernel_moe_forward( + hidden_states=x, + w1=self.w13_weight_triton_tensor, + w2=self.w2_weight_triton_tensor, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_precision=self.w13_precision_config, + w2_precision=self.w2_precision_config, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 1119045db072..dd8f591e71c4 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -2,11 +2,54 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op OCP_MX_BLOCK_SIZE = 32 +def _can_support_triton_kernels(use_grouped_topk, topk_group, num_expert_group, + custom_routing_function, + e_score_correction_bias, scoring_func, + activation, expert_load_view, + logical_to_physical_map, + logical_replica_count): + return not (use_grouped_topk or topk_group or num_expert_group + or custom_routing_function or e_score_correction_bias + or scoring_func != "softmax" or activation != "silu" + or expert_load_view or logical_to_physical_map + or logical_replica_count) + + +def _swizzle_mxfp4(quant_tensor, scale, num_warps): + """ weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel + """ + import triton_kernels.matmul_ogs_details.opt_flags as opt_flags + from triton_kernels.numerics import InFlexData + from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor + from triton_kernels.tensor_details import layout + value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout( + mx_axis=1) + scale_layout, scale_layout_opts = ( + layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, + num_warps=num_warps)) + if current_platform.is_cuda() and \ + current_platform.is_device_capability(100): + constraints = { + "is_persistent": True, + "epilogue_subtile": 1, + } + opt_flags.update_opt_flags_constraints(constraints) + # transpose the tensor so that the quantization axis is on dim1 + quant_tensor = quant_tensor.transpose(-2, -1) + scale = scale.transpose(-2, -1) + quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4), + value_layout, **value_layout_opts) + scale = convert_layout(wrap_torch_tensor(scale), scale_layout, + **scale_layout_opts) + return quant_tensor, InFlexData(), scale + + def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype) -> torch.Tensor: try: diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index cd32f12f3c26..48a347a8f561 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -11,6 +11,27 @@ from vllm.utils import direct_register_custom_op +def shuffle_weight(w: torch.Tensor) -> torch.Tensor: + # Shuffle weight along the last dimension so that + # we folded the weights to adjance location + # Example: + # input: + # [[1, 2, 3, 4, 5, 6], + # [7, 8, 9, 10, 11, 12]] + # output: + # [[1, 4, 2, 5, 3, 6], + # [7, 10, 8, 11, 9, 12]] + # This will be used together with triton swiglu kernel + shape = w.shape + N = 
shape[-1] + first = w[..., :N // 2] + second = w[..., N // 2:] + + stacked = torch.stack((first, second), dim=-1) + w_shuffled = stacked.reshape(shape) + return w_shuffled + + def get_token_bin_counts_and_mask( tokens: torch.Tensor, vocab_size: int, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 9030ff307bee..9fe5e5c3698c 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -218,6 +218,33 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: config.max_model_len) +class GptOssConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + decoding_config = vllm_config.decoding_config + if decoding_config.reasoning_backend == "": + decoding_config.reasoning_backend = "openai" + + # Increase the max capture size from 512 to 1024 for performance. + # NOTE(woosuk): This will increase the number of CUDA graphs + # from 67 to 83. + scheduler_config = vllm_config.scheduler_config + if len(scheduler_config.cuda_graph_sizes) == 1: + max_capture_size = scheduler_config.cuda_graph_sizes[0] + # FIXME(woosuk): When using full cuda graph with FA3, the max + # supported size is 992. + if max_capture_size < 1024: + cuda_graph_sizes = [1, 2, 4] + # Step size 8 for small batch sizes + cuda_graph_sizes += [i for i in range(8, 256, 8)] + # Step size 16 for larger batch sizes + cuda_graph_sizes += [i for i in range(256, 1025, 16)] + scheduler_config.cuda_graph_sizes = cuda_graph_sizes + logger.info("Overriding cuda graph sizes to %s", + cuda_graph_sizes) + + class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): @classmethod @@ -313,4 +340,5 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "XLMRobertaModel": JinaRobertaModelConfig, "JinaVLForRanking": JinaVLForSequenceClassificationConfig, "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig, + "GptOssForCausalLM": GptOssConfig, } diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py new file mode 100644 index 000000000000..5ce35c20dc79 --- /dev/null +++ b/vllm/model_executor/models/gpt_oss.py @@ -0,0 +1,472 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.distributed as dist +from torch import nn + +from vllm import envs +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import GptOssConfig +from 
vllm.utils import cdiv + +from .utils import extract_layer_index, maybe_prefix + + +class OAIAttention(nn.Module): + + def __init__( + self, + config: GptOssConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.head_dim = config.head_dim + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.hidden_size = config.hidden_size + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + dtype=torch.float32, + rope_scaling={ + "rope_type": + "yarn", + "factor": + config.rope_scaling["factor"], + "original_max_position_embeddings": + config.rope_scaling["original_max_position_embeddings"], + "beta_fast": + config.rope_ntk_beta, + "beta_slow": + config.rope_ntk_alpha, + }, + is_neox_style=True, + ) + + tp_size = get_tensor_model_parallel_world_size() + + attention_sink_dtype = ( + torch.float32 if envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION + or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION else torch.bfloat16) + self.sinks = torch.nn.Parameter( + torch.empty(config.num_attention_heads // tp_size, + dtype=attention_sink_dtype, + requires_grad=False)) + + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + + self.q_size = self.num_attention_heads * self.head_dim // tp_size + self.kv_size = self.num_key_value_heads * self.head_dim // tp_size + self.scaling = self.head_dim**-0.5 + self.rope_theta = config.rope_theta + + self.qkv = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.num_attention_heads, + total_num_kv_heads=self.num_key_value_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.num_attention_heads * self.head_dim, + output_size=self.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.num_local_attention_heads = config.num_attention_heads // tp_size + self.num_local_key_value_heads = config.num_key_value_heads // tp_size + + # Only apply sliding window to every other layer + sliding_window = (config.sliding_window if self.layer_idx % + 2 == 0 else None) + self.attn = Attention( + self.num_local_attention_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_local_key_value_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=AttentionType.DECODER, + prefix=f"{prefix}.attn", + sinks=self.sinks, + ) + + def forward(self, hidden_states: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: + t = self.norm(hidden_states) + + qkv, _ = self.qkv(t) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + v = v.contiguous() + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + + return output + hidden_states + + +class MLPBlock(torch.nn.Module): + + def __init__( + self, + config: GptOssConfig, + layer_idx: int, + quant_config: QuantizationConfig, + prefix: str = "", + ): + super().__init__() + self.layer_idx = layer_idx + self.num_experts = config.num_local_experts + self.experts_per_token = config.num_experts_per_tok + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + self.router = 
torch.nn.Linear(config.hidden_size, + config.num_local_experts, + dtype=torch.bfloat16) + assert config.intermediate_size % self.world_size == 0 + self.experts = FusedMoE(num_experts=config.num_local_experts, + top_k=config.num_experts_per_token, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts", + apply_router_weight_on_input=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + t = self.norm(x) + g = self.router(t) + t = self.experts(hidden_states=t, router_logits=g) + return x + t + + +class TransformerBlock(torch.nn.Module): + + def __init__( + self, + config: GptOssConfig, + quant_config: QuantizationConfig, + prefix: str = "", + ): + super().__init__() + self.layer_idx = extract_layer_index(prefix) + self.attn = OAIAttention(config, prefix=f"{prefix}.attn") + self.mlp = MLPBlock(config, + self.layer_idx, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward(self, hidden_states: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: + attn_output = self.attn(hidden_states, positions) + output = self.mlp(attn_output) + return output + + +@support_torch_compile +class GptOssModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + self.config.hidden_size = self.config.hidden_size + self.embedding = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + ) + self.layers = torch.nn.ModuleList([ + TransformerBlock( + self.config, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, f"block.{layer_idx}"), + ) for layer_idx in range(self.config.num_hidden_layers) + ]) + self.norm = RMSNorm(self.config.hidden_size, eps=1e-5) + + def forward(self, input_ids: torch.Tensor, + positions: torch.Tensor) -> torch.Tensor: + x = self.embedding(input_ids) + for layer in self.layers: + x = layer(x, positions) + x = self.norm(x) + return x + + +class GptOssForCausalLM(nn.Module): + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config.hf_config + self.model = GptOssModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + self.lm_head = ParallelLMHead( + self.model_config.vocab_size, + self.model_config.hidden_size, + ) + self.logits_processor = LogitsProcessor(self.model_config.vocab_size) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + assert intermediate_tensors is None + assert inputs_embeds is None + return self.model(input_ids, positions) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + rename_mapping = { + "self_attn": "attn", + "input_layernorm.weight": "attn.norm.weight", + "post_attention_layernorm.weight": "mlp.norm.weight", + "embed_tokens": "embedding", + } + + def maybe_rename(name: str) -> str: + for remap_name, new_name in rename_mapping.items(): + if remap_name in name: + return 
name.replace(remap_name, new_name) + return name + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + mxfp4_block = 32 + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + intermediate_size = self.model_config.intermediate_size + intermediate_size_block = intermediate_size // mxfp4_block + per_rank_intermediate_size_block = cdiv(intermediate_size_block, + tp_size) + per_rank_intermediate_size = (per_rank_intermediate_size_block * + mxfp4_block) + + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + # Attention heads per rank + heads_per_rank = self.model_config.num_attention_heads // tp_size + head_start = tp_rank * heads_per_rank + + use_ep = self.vllm_config.parallel_config.enable_expert_parallel + ep_size = get_ep_group().world_size + ep_rank = get_ep_group().rank + num_experts = self.model_config.num_local_experts + experts_per_rank = num_experts // ep_size + ep_rank_start = ep_rank * experts_per_rank + ep_rank_end = (ep_rank + 1) * experts_per_rank + + for name, weight in weights: + weight = weight.cuda() + + if "gate_up_proj_blocks" in name: + # Handle MLP gate and up projection weights + new_name = name.replace("gate_up_proj_blocks", "w13_weight") + + # flat weight from (E, 2 * N, block_size, entry_per_block) + # to (E, 2 * N, -1), shouldn't trigger copy for contiguous + weight = weight.view(num_experts, 2 * intermediate_size, + -1).contiguous() + + # Extract gate and up projection parts + # since the weight is shuffled, we can slice directly + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "down_proj_blocks" in name: + # Handle MLP down projection weights + new_name = name.replace("down_proj_blocks", "w2_weight") + # same flatten here, but since 2 mx4 value are packed in 1 + # uint8, divide by 2 + weight = weight.view(num_experts, -1, + intermediate_size // 2).contiguous() + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., + tp_rank_start // 2:tp_rank_end // 2] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "gate_up_proj_scales" in name: + # Handle MLP gate and up projection weights scale + new_name = name.replace("gate_up_proj_scales", + "w13_weight_scale") + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "down_proj_scales" in name: + # Handle MLP down projection weights + new_name = name.replace("down_proj_scales", "w2_weight_scale") + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
+ else: + narrow_weight = weight[..., tp_rank_start // + mxfp4_block:tp_rank_end // + mxfp4_block] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + elif "gate_up_proj_bias" in name: + # Handle MLP gate and up projection biases + new_name = name.replace("gate_up_proj_bias", "w13_bias") + + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + + elif "down_proj_bias" in name: + # Handle MLP down projection bias + new_name = name.replace("down_proj_bias", "w2_bias") + param = params_dict[new_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] + else: + # (only load on rank 0 to avoid duplication) + if tp_rank != 0: + weight.zero_() + weight_loader(param, + weight, + weight_name=new_name, + shard_id=None, + expert_id=None) + loaded_params.add(new_name) + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + name = name.replace("self_attn", "attn") + param = params_dict[name] + narrow_weight = weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + elif "q_proj" in name or "k_proj" in name or "v_proj" in name: + shard_id = ("q" if "q_proj" in name else + "k" if "k_proj" in name else "v") + name = name.replace("self_attn", "attn") + param_name = name.replace(f"{shard_id}_proj", "qkv") + param = params_dict[param_name] + weight_loader = param.weight_loader + weight_loader(param, weight, loaded_shard_id=shard_id) + loaded_params.add(param_name) + else: + # Handle all other weights with potential renaming + renamed_name = maybe_rename(name) + if renamed_name not in params_dict: + print(f"Warning: {renamed_name} not found in params_dict") + continue + param = params_dict[renamed_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(renamed_name) + + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 9b6ab52d8680..b6e50bcebc93 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -74,6 +74,7 @@ "GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"), + "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), @@ -138,6 +139,7 @@ "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"), "XverseForCausalLM": ("llama", "LlamaForCausalLM"), "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"), + # [Encoder-decoder] "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 54ffc83cd565..d26e4b335038 100644 --- 
a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -127,7 +127,8 @@ def use_rocm_custom_paged_attention( max_seq_len: int, sliding_window: int, kv_cache_dtype: str, - alibi_slopes: Optional[torch.Tensor] = None) -> bool: + alibi_slopes: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None) -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) @@ -145,7 +146,7 @@ def use_rocm_custom_paged_attention( and max_seq_len <= 128 * 1024 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN - and envs.VLLM_ROCM_USE_AITER)) + and envs.VLLM_ROCM_USE_AITER) and sinks is None) else: return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0 @@ -155,7 +156,7 @@ def use_rocm_custom_paged_attention( and (gqa_ratio >= 3 and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 and alibi_slopes is None and kv_cache_dtype == "auto" - and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None) class RocmPlatform(Platform): @@ -170,7 +171,7 @@ class RocmPlatform(Platform): supported_quantization: list[str] = [ "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", - "quark", "ptpc_fp8" + "quark", "ptpc_fp8", "mxfp4" ] @classmethod @@ -469,4 +470,4 @@ def device_count(cls) -> int: @classmethod def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: - return True \ No newline at end of file + return True diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index 1c3f78f2edbf..33d64b413f90 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -7,6 +7,7 @@ from .granite_reasoning_parser import GraniteReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser from .mistral_reasoning_parser import MistralReasoningParser +from .openai_reasoning_parser import OpenAIReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser from .step3_reasoning_parser import Step3ReasoningParser @@ -20,4 +21,5 @@ "Glm4MoeModelReasoningParser", "MistralReasoningParser", "Step3ReasoningParser", + "OpenAIReasoningParser", ] diff --git a/vllm/reasoning/openai_reasoning_parser.py b/vllm/reasoning/openai_reasoning_parser.py new file mode 100644 index 000000000000..1b3f7a3c941f --- /dev/null +++ b/vllm/reasoning/openai_reasoning_parser.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Optional, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("openai") +class OpenAIReasoningParser(ReasoningParser): + """ + Reasoning parser for OpenAI model. + + The OpenAI model uses harmony to extract reasoning content and this parser + is only used for detecting the end of the reasoning content. 
+ """ + + # <|start|>assistant<|channel|>final<|message|> + # token_ids: list[int] = [200006, 173781, 200005, 17196, 200008] + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + self.reasoning_end_token_ids = self.model_tokenizer.encode( + "<|start|>assistant<|channel|>final<|message|>") + print("reasoning_end_token_ids", self.reasoning_end_token_ids) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + + def find_last_index(lst, value): + return len(lst) - 1 - list(reversed(lst)).index(value) + + last_start = find_last_index(input_ids, + self.reasoning_end_token_ids[0]) + if last_start is None: + return False + for i in range(5): + if last_start + i >= len(input_ids): + return False + if input_ids[last_start + i] != self.reasoning_end_token_ids[i]: + return False + return True + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + raise RuntimeError( + "OpenAI model uses harmony to extract reasoning content. This " + "function should not be called.") + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + raise RuntimeError( + "OpenAI model uses harmony to extract reasoning content. This " + "function should not be called.") + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[Optional[str], Optional[str]]: + raise RuntimeError( + "OpenAI model uses harmony to extract reasoning content. This " + "function should not be called.") diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index cc41a771d06c..c6ef3445a4a3 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -30,9 +30,10 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config, - EAGLEConfig, JAISConfig, - KimiVLConfig, MedusaConfig, - MllamaConfig, MLPSpeculatorConfig, + EAGLEConfig, GptOssConfig, + JAISConfig, KimiVLConfig, + MedusaConfig, MllamaConfig, + MLPSpeculatorConfig, Nemotron_Nano_VL_Config, NemotronConfig, NVLM_D_Config, RWConfig, SpeculatorsConfig, @@ -88,6 +89,7 @@ def _get_hf_token() -> Optional[str]: "ultravox": UltravoxConfig, "step3_vl": Step3VLConfig, "step3_text": Step3TextConfig, + "gpt_oss": GptOssConfig, **_CONFIG_REGISTRY_OVERRIDE_HF } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 64ace167a5a0..b37b82583535 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -14,6 +14,7 @@ # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. 
from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.gpt_oss import GptOssConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig from vllm.transformers_utils.configs.medusa import MedusaConfig @@ -50,4 +51,5 @@ "Step3VLConfig", "Step3VisionEncoderConfig", "Step3TextConfig", + "GptOssConfig", ] diff --git a/vllm/transformers_utils/configs/gpt_oss.py b/vllm/transformers_utils/configs/gpt_oss.py new file mode 100644 index 000000000000..27a08a87a4d0 --- /dev/null +++ b/vllm/transformers_utils/configs/gpt_oss.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from transformers import PretrainedConfig + + +class GptOssConfig(PretrainedConfig): + model_type = "gpt_oss" + + def __init__(self, + num_hidden_layers: int = 36, + num_experts: int = 128, + num_experts_per_token: int = 4, + vocab_size: int = 201088, + hidden_size: int = 2880, + intermediate_size: int = 2880, + head_dim: int = 64, + num_attention_heads: int = 64, + num_key_value_heads: int = 8, + sliding_window: int = 128, + rope_theta: float = 150000.0, + rope_scaling_factor: float = 32.0, + rope_ntk_alpha: float = 1.0, + rope_ntk_beta: float = 32.0, + **kwargs): + super().__init__(**kwargs) + self.num_hidden_layers = num_hidden_layers + self.num_experts = num_experts + self.num_experts_per_token = num_experts_per_token + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.sliding_window = sliding_window + self.rope_theta = rope_theta + self.rope_scaling_factor = rope_scaling_factor + self.rope_ntk_alpha = rope_ntk_alpha + self.rope_ntk_beta = rope_ntk_beta diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index ce62282c2199..ad58225b521c 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3264,6 +3264,34 @@ def set_process_title(name: str, setproctitle.setproctitle(name) +def print_logo(): + + logo = """ + LL LL MMM MMM + LL LL MMMM MMMM + V LL LL MM MM MM MM +vvvv VVVV LL LL MM MM MM MM +vvvv VVVV LL LL MM MMM MM + vvv VVVV LL LL MM M MM + vvVVVV LL LL MM MM + VVVV LLLLLLLLLL LLLLLLLLL M M +""" + + if sys.stdout.isatty(): + + def color256(index): + return f"\033[38;5;{index}m" + + reset = "\033[0m" + colored_logo = ( + logo.replace("v", color256(214) + "v" + reset) \ + .replace("V", color256(33) + "V" + reset) + ) + print(colored_logo) + else: + print(logo) + + def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: """Prepend each output line with process-specific prefix""" @@ -3311,3 +3339,9 @@ def decorate_logs(process_name: Optional[str] = None) -> None: pid = os.getpid() _add_prefix(sys.stdout, process_name, pid) _add_prefix(sys.stderr, process_name, pid) + + +def has_triton_kernels() -> bool: + """Whether the optional `triton_kernels` package is available.""" + + return _has_module("triton_kernels") diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 29967bc51671..664e65cc05b5 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -160,7 +160,7 @@ def use_trtllm_decode_attention( # Check if the dimensions are supported by TRTLLM decode attention if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None or num_qo_heads // num_kv_heads > 8 - or num_qo_heads % 
num_kv_heads != 0 or attn_head_size != 128): + or num_qo_heads % num_kv_heads != 0): return False env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION @@ -185,6 +185,43 @@ def use_trtllm_decode_attention( return use_trtllm +def use_trtllm_context_attention( + num_tokens: int, + max_seq_len: int, + kv_cache_dtype: str, + num_qo_heads: Optional[int], + num_kv_heads: Optional[int], + attn_head_size: Optional[int], +) -> bool: + # Requires SM100 and NVIDIA artifactory to be accessible to download cubins + if not (current_platform.is_device_capability(100) + and has_nvidia_artifactory()): + return False + + # TODO: update the check to compatible with latest trtllm-gen kernel + # Check if the dimensions are supported by TRTLLM decode attention + if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None + or num_qo_heads // num_kv_heads > 8 + or num_qo_heads % num_kv_heads != 0): + return False + + env_value = envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION + if env_value is not None: + logger.info_once("VLLM_USE_TRTLLM_CONTEXT_ATTENTION is set to %s", + env_value) + # Environment variable is set - respect it + # Making the conditional check for zero because + # the path is automatically enabled if the batch size condition + # is satisfied. + no_use_trtllm = (env_value == "0") + if not no_use_trtllm: + logger.info_once("Using TRTLLM context attention.") + return not no_use_trtllm + + # TODO: add heuristic + return False + + __all__ = [ "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f086bab2556e..95bf43a2fc66 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -373,6 +373,8 @@ def __init__( logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + sinks: Optional[torch.Tensor] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -410,6 +412,14 @@ def __init__( raise NotImplementedError( "FlashAttention does not support fp8 kv-cache on this device.") + self.sinks = sinks + if self.sinks is not None: + assert self.vllm_flash_attn_version == 3, ( + "Sinks are only supported in FlashAttention 3") + assert self.sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads as the number of " + "heads in the layer") + def forward( self, layer: torch.nn.Module, @@ -534,6 +544,7 @@ def forward( k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), num_splits=attn_metadata.max_num_splits, + s_aux=self.sinks, ) return output diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 3697cb9387a9..13b8f8199ffa 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -19,7 +19,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import use_trtllm_decode_attention +from vllm.utils.flashinfer import (use_trtllm_context_attention, + use_trtllm_decode_attention) from vllm.v1.attention.backends.flash_attn import use_cascade_attention # yapf conflicts with isort for this block # yapf: disable @@ -249,10 +250,12 @@ def _get_workspace_buffer(self): device=self.device) return self._workspace_buffer - def _get_prefill_wrapper(self): + def _get_prefill_wrapper(self, 
backend = 'auto'): if self._prefill_wrapper is None: self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), get_kv_cache_layout()) + self._get_workspace_buffer(), + get_kv_cache_layout(), + backend=backend) return self._prefill_wrapper def _get_decode_wrapper(self, @@ -344,7 +347,15 @@ def _plan(self, num_prefills: int, num_decodes: int, if num_prefills > 0: # Decodes are first so prefills start after the last decode prefill_start = num_decodes - attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + use_trtllm_context_attention_ = use_trtllm_context_attention( + num_prefills, attn_metadata.max_seq_len, + self.cache_config.cache_dtype, + attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, + attn_metadata.head_dim) + backend = ("trtllm-gen" if use_trtllm_context_attention_ + else 'auto') + attn_metadata.prefill_wrapper = \ + self._get_prefill_wrapper(backend=backend) assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[ 0] == num_prefills + 1 assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[ @@ -372,6 +383,8 @@ def _plan(self, num_prefills: int, num_decodes: int, logits_soft_cap, q_data_type=attn_metadata.q_data_type, kv_data_type=attn_metadata.kv_data_type, + seq_lens=attn_metadata.seq_lens[prefill_start:], + block_tables=attn_metadata.block_table_tensor[prefill_start:], ) if num_decodes > 0: @@ -578,6 +591,7 @@ def __init__( logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -601,6 +615,14 @@ def __init__( "encoder/decoder cross-attention " "are not implemented for " "FlashInferImpl") + self.sinks: Optional[torch.Tensor] = None + if sinks is not None: + assert sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads " + "as the number of heads in the layer" + ) + assert sinks.dtype == torch.float32, "Sinks must be of type float32" + self.sinks = sinks def forward( self, @@ -703,13 +725,19 @@ def forward( assert prefill_query.shape[0] == num_prefill_tokens assert prefill_wrapper is not None assert prefill_wrapper._causal - assert prefill_wrapper._window_left == window_left - assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap - or 0.0) - assert prefill_wrapper._sm_scale == self.scale + if not use_trtllm_context_attention( + attn_metadata.num_prefills, attn_metadata.max_seq_len, + self.kv_cache_dtype, attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, attn_metadata.head_dim): + assert prefill_wrapper._window_left == window_left + assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap + or 0.0) + assert prefill_wrapper._sm_scale == self.scale prefill_wrapper.run( prefill_query, kv_cache_permute, + window_left=window_left, + sinks=self.sinks, k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, out=output[num_decode_tokens:], @@ -729,6 +757,7 @@ def forward( decode_wrapper.run( decode_query, kv_cache_permute, + window_left=window_left, k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, out=output[:num_decode_tokens], @@ -760,6 +789,8 @@ def forward( max_seq_len=attn_metadata.max_seq_len, bmm1_scale=layer._k_scale_float * self.scale, bmm2_scale=layer._v_scale_float, + window_left=window_left, + sinks=self.sinks, out=output[:num_decode_tokens], ) return output_padded diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 
942cb95eefa2..c33afbfebcde 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with PagedAttention and Triton prefix prefill.""" from dataclasses import dataclass +from functools import cache from typing import ClassVar, Optional import torch @@ -13,7 +14,6 @@ from vllm.attention.ops.chunked_prefill_paged_decode import ( chunked_prefill_paged_decode) from vllm.attention.ops.paged_attn import PagedAttention -from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -193,6 +193,15 @@ def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: return TritonAttentionMetadataBuilder +@cache +def use_aiter_unified_attention() -> bool: + """Check if aiter unified attention should be used.""" + # VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set + # to 1 as default + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_USE_AITER_UNIFIED_ATTENTION + + class TritonAttentionImpl(AttentionImpl): def __init__( @@ -207,6 +216,7 @@ def __init__( logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -240,6 +250,29 @@ def __init__( self.force_prefill_decode_attn = \ envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + if not self.force_prefill_decode_attn: + # If not using prefill decode attention, we use the Triton + # unified attention implementation. + if use_aiter_unified_attention(): + logger.info_once( + "Using aiter unified attention for TritonAttentionImpl") + from aiter.ops.triton.unified_attention import ( + unified_attention) + self.unified_attention = unified_attention + else: + logger.info_once( + "Using vllm unified attention for TritonAttentionImpl") + from vllm.attention.ops.triton_unified_attention import ( + unified_attention) + self.unified_attention = unified_attention + + self.sinks = sinks + if sinks is not None: + assert sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads as the number of " + f"heads in the layer. Sinks shape: {sinks.shape}, " + f"num_heads: {num_heads}.") + def forward( self, layer: torch.nn.Module, @@ -342,28 +375,31 @@ def forward( if use_prefill_decode_attn: # Compute attention and update output up to `num_actual_tokens`. 
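# --- Illustrative aside (not part of the diff) -------------------------------
# Several backends above now thread a per-head `sinks` tensor into the kernels
# (`s_aux=` for FlashAttention 3, `sinks=` for FlashInfer and the Triton
# paths) and assert `sinks.shape[0] == num_heads`. The sketch below is an
# assumption about the intended semantics, not the fused implementation: each
# head carries one extra "sink" logit that competes in the softmax but has no
# value vector, so it only absorbs attention mass. Masking is omitted for
# brevity.
import torch


def sdpa_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    sinks: torch.Tensor, scale: float) -> torch.Tensor:
    # q: [num_heads, q_len, head_dim], k/v: [num_heads, kv_len, head_dim],
    # sinks: [num_heads] (one learned logit per query head).
    logits = torch.einsum("hqd,hkd->hqk", q, k) * scale
    sink_col = sinks.to(logits.dtype).view(-1, 1, 1).expand(
        -1, logits.shape[1], 1)
    probs = torch.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)
    # Drop the sink column: it takes probability mass but adds no value.
    return torch.einsum("hqk,hkd->hqd", probs[..., :-1], v)
# ------------------------------------------------------------------------------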
- chunked_prefill_paged_decode(query=query[:num_actual_tokens], - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], - output=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - key_cache=key_cache, - value_cache=value_cache, - block_table=block_table, - query_start_loc=cu_seqlens_q, - seq_lens=seqused_k, - max_seq_len=max_seqlen_k, - max_query_len=max_seqlen_q, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale) + chunked_prefill_paged_decode( + query=query[:num_actual_tokens], + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + output=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + key_cache=key_cache, + value_cache=value_cache, + block_table=block_table, + query_start_loc=cu_seqlens_q, + seq_lens=seqused_k, + max_seq_len=max_seqlen_k, + max_query_len=max_seqlen_q, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0], + sm_scale=self.scale, + sinks=self.sinks, + ) else: descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - unified_attention( + self.unified_attention( q=query[:num_actual_tokens], k=key_cache, v=value_cache, @@ -381,6 +417,7 @@ def forward( q_descale=None, # Not supported k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + sinks=self.sinks, ) return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 7aeea40b25a6..f521d94331b5 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -254,7 +254,11 @@ def get_kv_cache_layout(): # Override with format specified by the user. cache_layout = envs.VLLM_KV_CACHE_LAYOUT if cache_layout is None: - cache_layout = get_kv_connector_cache_layout() + if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION + or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION): + cache_layout = "HND" + else: + cache_layout = get_kv_connector_cache_layout() else: logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \ "detected. Setting KV cache layout to %s.", cache_layout) @@ -272,7 +276,9 @@ def set_kv_cache_layout(cache_layout: str): class PerLayerParameters: """ Currently, FlashInfer backend only support models in which all layers share - the same values for the following hyperparameters. + the same values for the following hyperparameters. Should not be used for + trtllm-gen backend since it supports different values for the following + hyperparameters. """ window_left: int @@ -310,7 +316,8 @@ def get_per_layer_parameters( def infer_global_hyperparameters( per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: """ - Currently, FlashInfer backend only support models in which all layers share + Currently, FlashInfer backend other than trtllm-gen + only support models in which all layers share the same values for the following hyperparameters: - `window_left` - `logits_soft_cap` @@ -324,15 +331,20 @@ def infer_global_hyperparameters( param_sets = list(per_layer_params.values()) global_params = param_sets[0] - for params in param_sets: - if params.window_left != global_params.window_left: - raise ValueError( - "Window left is not the same for all layers. 
One potential fix " - "is to set disable_sliding_window=True") - assert params == global_params, ( - "FlashInfer backend currently only supports models in which all " - "layers share the same values for the following hyperparameters: " - "`window_left`, `logits_soft_cap`, `sm_scale`.") + + # trtllm attention doesn't need global hyper params so disable the check + if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION + and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION): + for params in param_sets: + if params.window_left != global_params.window_left: + raise ValueError( + "Window left is not the same for all layers. " \ + "One potential fix is to set disable_sliding_window=True") + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all" + "layers share the same values " + "for the following hyperparameters:" + "`window_left`, `logits_soft_cap`, `sm_scale`.") return global_params diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 79c47e102888..69d26f37dc03 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -14,8 +14,10 @@ from typing import Any, Callable, Optional, TypeVar, Union import msgspec +import torch.cuda.profiler as profiler import zmq +import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.logger import init_logger @@ -24,7 +26,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import (decorate_logs, make_zmq_socket, +from vllm.utils import (decorate_logs, make_zmq_socket, print_logo, resolve_obj_by_qualname, set_process_title) from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) @@ -71,6 +73,9 @@ def __init__(self, logger.info("Initializing a V1 LLM engine (v%s) with config: %s", VLLM_VERSION, vllm_config) + if envs.VLLM_PRINT_LOGO: + print_logo() + self.log_stats = log_stats # Setup Model. @@ -128,6 +133,18 @@ def __init__(self, self.mm_input_cache_server = MirroredProcessingCache( vllm_config.model_config) + # cudaProfilerApi Support + self._perf_iter = 0 + self._profiler_running = False + _perf_env_str = envs.VLLM_NSYS_PROFILE_START_STOP + if '-' in _perf_env_str: + start, stop = _perf_env_str.strip().split('-') + self._start_perf_iter = int(start) + self._stop_perf_iter = int(stop) + else: + self._start_perf_iter = -1 + self._stop_perf_iter = -1 + # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously # schedule and execute batches, and is required by pipeline parallelism @@ -265,6 +282,18 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: was executed. """ + # Profiler Start and Stop + if self._perf_iter == self._start_perf_iter: + logger.info("Starting profiler") + profiler.start() + self._profiler_running = True + + if self._perf_iter == self._stop_perf_iter: + logger.info("Stopping profiler") + profiler.stop() + self._profiler_running = False + self._perf_iter += 1 + # Check for any requests remaining in the scheduler - unfinished, # or finished and not yet removed from the batch. 
         if not self.scheduler.has_requests():
@@ -331,6 +360,12 @@ def step_with_batch_queue(
 
         return engine_core_outputs, scheduled_batch
 
     def shutdown(self):
+
+        # Stop the profiler if it is still running so the capture ends cleanly
+        if self._profiler_running:
+            logger.info("Stopping profiler")
+            profiler.stop()
+
         self.structured_output_manager.clear_backend()
         if self.model_executor:
             self.model_executor.shutdown()
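# --- Illustrative aside (not part of the diff) -------------------------------
# The iteration window added to EngineCore.step() drives torch.cuda.profiler
# (cudaProfilerStart/Stop) from VLLM_NSYS_PROFILE_START_STOP, formatted as
# "<start_iter>-<stop_iter>". The helper below mirrors that parsing as a
# stand-alone sketch; the nsys flags in the comment are illustrative
# assumptions, not taken from this diff:
#
#   VLLM_NSYS_PROFILE_START_STOP=100-110 \
#       nsys profile --capture-range=cudaProfilerApi --capture-range-end=stop \
#       vllm serve <model>


def parse_profile_window(value: str) -> tuple[int, int]:
    """Return (start_iter, stop_iter); (-1, -1) leaves profiling disabled."""
    if "-" in value:
        start, stop = value.strip().split("-")
        return int(start), int(stop)
    return -1, -1


assert parse_profile_window("100-110") == (100, 110)
assert parse_profile_window("") == (-1, -1)
# ------------------------------------------------------------------------------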