Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions tests/entrypoints/openai/chat_completion/test_serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
Expand All @@ -38,6 +39,7 @@
from vllm.exceptions import VLLMValidationError
from vllm.inputs import TokensPrompt
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.parser import ParserManager
from vllm.renderers.hf import HfRenderer
from vllm.renderers.mistral import MistralRenderer
from vllm.tokenizers import get_tokenizer
Expand Down Expand Up @@ -655,6 +657,162 @@ async def return_model_name(*args):
assert await serving_chat.create_chat_completion(req) == MODEL_NAME


@pytest.mark.asyncio
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
"parser_name",
["minimax_m2", "minimax_m2_append_think"],
)
async def test_chat_usage_includes_reasoning_tokens_for_minimax_parser(
parser_name: str,
):
class FakeTokenizer:
def __init__(self):
self._vocab = {"<think>": 1, "</think>": 2}

def get_vocab(self):
return self._vocab

mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = MagicMock()

serving_chat = _build_serving_chat(mock_engine)
serving_chat.reasoning_parser_cls = ParserManager.get_reasoning_parser(parser_name)

tokenizer = FakeTokenizer()
parser_cls = ParserManager.get_parser(reasoning_parser_name=parser_name)
assert parser_cls is not None
parser = parser_cls(tokenizer)
completion = CompletionOutput(
index=0,
text="reasoning</think>final",
token_ids=[10, 11, 2, 20],
cumulative_logprob=0.0,
logprobs=None,
finish_reason="stop",
stop_reason=None,
)
req_output = RequestOutput(
request_id="req",
prompt="hi",
prompt_token_ids=[7, 8],
prompt_logprobs=None,
outputs=[completion],
finished=True,
num_cached_tokens=0,
)

async def result_generator():
yield req_output

request = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "hi"}],
)
request_id = "req"

response = await serving_chat.chat_completion_full_generator(
request=request,
result_generator=result_generator(),
request_id=request_id,
model_name=MODEL_NAME,
conversation=[],
tokenizer=tokenizer,
request_metadata=RequestResponseMetadata(request_id=request_id),
parser=parser,
)

assert response.usage.completion_tokens == 4
assert response.usage.completion_tokens_details is not None
assert response.usage.completion_tokens_details.reasoning_tokens == 2


@pytest.mark.asyncio
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
"parser_name",
["minimax_m2", "minimax_m2_append_think"],
)
async def test_chat_stream_usage_includes_reasoning_tokens_for_minimax_parser(
parser_name: str,
):
class FakeTokenizer:
def __init__(self):
self._vocab = {"<think>": 1, "</think>": 2}

def get_vocab(self):
return self._vocab

mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = MagicMock()

serving_chat = _build_serving_chat(mock_engine)
serving_chat.reasoning_parser_cls = ParserManager.get_reasoning_parser(parser_name)

tokenizer = FakeTokenizer()
reasoning_parser = serving_chat.reasoning_parser_cls(tokenizer)
completion = CompletionOutput(
index=0,
text="reasoning</think>final",
token_ids=[10, 11, 2, 20],
cumulative_logprob=0.0,
logprobs=None,
finish_reason="stop",
stop_reason=None,
)
req_output = RequestOutput(
request_id="req",
prompt="hi",
prompt_token_ids=[7, 8],
prompt_logprobs=None,
outputs=[completion],
finished=True,
num_cached_tokens=0,
)

async def result_generator():
yield req_output

request = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "hi"}],
stream=True,
stream_options={"include_usage": True},
)
request_id = "req"
stream = serving_chat.chat_completion_stream_generator(
request=request,
result_generator=result_generator(),
request_id=request_id,
model_name=MODEL_NAME,
conversation=[],
tokenizer=tokenizer,
request_metadata=RequestResponseMetadata(request_id=request_id),
reasoning_parser=reasoning_parser,
)

final_usage = None
async for chunk_str in stream:
if chunk_str.strip() == "data: [DONE]":
continue
chunk = ChatCompletionStreamResponse.model_validate_json(chunk_str[6:].strip())
if chunk.usage is not None and not chunk.choices:
final_usage = chunk.usage

assert final_usage is not None
assert final_usage.completion_tokens == 4
assert final_usage.completion_tokens_details is not None
assert final_usage.completion_tokens_details.reasoning_tokens == 2


@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=AsyncLLM)
Expand Down
38 changes: 38 additions & 0 deletions tests/reasoning/test_minimax_m2_append_reasoning_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,41 @@ def test_reasoning(
output_ids = minimax_m2_tokenizer.convert_tokens_to_ids(output)
is_reasoning_end = parser.is_reasoning_end(output_ids)
assert is_reasoning_end == param_dict["is_reasoning_end"]


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
("output_text", "reasoning_text"),
[
pytest.param(
SIMPLE_OUTPUT["output"],
"This is reasoning",
id="reasoning-before-end-token",
),
pytest.param(
NO_END_TOKEN["output"],
NO_END_TOKEN["output"],
id="all-tokens-are-reasoning-before-end-token",
),
pytest.param(
ONLY_END_TOKEN["output"],
"",
id="end-token-first-means-no-reasoning-tokens",
),
],
)
def test_count_reasoning_tokens(
output_text: str,
reasoning_text: str,
minimax_m2_tokenizer,
):
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
minimax_m2_tokenizer
)

output_ids = minimax_m2_tokenizer.encode(output_text, add_special_tokens=False)
reasoning_ids = minimax_m2_tokenizer.encode(
reasoning_text, add_special_tokens=False
)

assert parser.count_reasoning_tokens(output_ids) == len(reasoning_ids)
38 changes: 38 additions & 0 deletions tests/reasoning/test_minimax_m2_reasoning_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,41 @@ def test_reasoning(
else:
content = parser.extract_content_ids(output)
assert content == []


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
("output_text", "reasoning_text"),
[
pytest.param(
SIMPLE_REASONING["output"],
SIMPLE_REASONING["reasoning"],
id="reasoning-before-end-token",
),
pytest.param(
NO_END_TOKEN["output"],
NO_END_TOKEN["output"],
id="all-tokens-are-reasoning-before-end-token",
),
pytest.param(
SHORTEST_REASONING_NO_STREAMING["output"],
"",
id="end-token-first-means-no-reasoning-tokens",
),
],
)
def test_count_reasoning_tokens(
output_text: str,
reasoning_text: str,
minimax_m2_tokenizer,
):
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
minimax_m2_tokenizer
)

output_ids = minimax_m2_tokenizer.encode(output_text, add_special_tokens=False)
reasoning_ids = minimax_m2_tokenizer.encode(
reasoning_text, add_special_tokens=False
)

assert parser.count_reasoning_tokens(output_ids) == len(reasoning_ids)
Loading
Loading