
fix no think of GLM-4.5 / GLM-4.7 #31449

Merged
chaunceyjiang merged 2 commits into vllm-project:main from zRzRzRzRzRzRzR:parser
Jan 4, 2026

Conversation

@zRzRzRzRzRzRzR
Contributor

Following the logic of the DeepSeek-R1 and Qwen3 parsers, the reasoning content can be extracted even when the <think> start token does not appear in the output.

Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>

@gemini-code-assist bot left a comment


Code Review

This pull request refactors Glm4MoeModelReasoningParser to inherit from BaseThinkingReasoningParser, simplifying the code. It also introduces a fix to handle cases where GLM-4.5 models might not output a <think> start token, which is a good improvement. However, the implementation introduces code duplication between the subclass and the base class. I've left a comment highlighting this maintainability issue and suggesting a refactoring approach.

Comment on lines 48 to 68
if (
    ret is not None
    and self.start_token_id not in previous_token_ids
    and self.start_token_id not in delta_token_ids
):
    return None

if self.start_token_id in previous_token_ids:
    if self.end_token_id in delta_token_ids:
        # start token in previous, end token in delta:
        # extract reasoning content and content
        end_index = delta_text.find(self.end_token)
        reasoning = delta_text[:end_index]
        content = delta_text[end_index + len(self.end_token) :]
        return DeltaMessage(
            reasoning=reasoning,
            content=content if content else None,
        )
    elif self.end_token_id in previous_token_ids:
        # end token in previous: thinking content has ended
        return DeltaMessage(content=delta_text)
    else:
        # start token in previous, no end token yet:
        # reasoning content continues
        return DeltaMessage(reasoning=delta_text)
elif self.start_token_id in delta_token_ids:
    if self.end_token_id in delta_token_ids:
        # start and end tokens both in delta: extract reasoning content
        start_index = delta_text.find(self.start_token)
        end_index = delta_text.find(self.end_token)
        reasoning = delta_text[
            start_index + len(self.start_token) : end_index
        ]
        content = delta_text[end_index + len(self.end_token) :]
        return DeltaMessage(
            reasoning=reasoning,
            content=content if content else None,
        )
    else:
        # start token in delta, no end token yet:
        # reasoning content continues
        return DeltaMessage(reasoning=delta_text)


Severity: high

The logic within this if block for handling streaming extraction when a start token is missing is a duplication of existing logic in BaseThinkingReasoningParser for when a start token is present. This code duplication introduces a maintenance risk: future changes to the parsing logic will need to be made in two places, which is error-prone.

To improve maintainability, this duplicated logic should be extracted into a shared helper method within the base class.


@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 48 to 52
if (
ret is not None
and self.start_token_id not in previous_token_ids
and self.start_token_id not in delta_token_ids
):


P1: Normal outputs treated as reasoning when think tokens absent

The new branch that runs when no <think> start token has ever been seen now routes text into the reasoning field even though the model never emitted thinking markers, and the class now inherits the base extract_reasoning (basic_parsers.py 151-175) which likewise returns the entire output as reasoning when <think> is missing. For GLM calls with enable_thinking=False or for variants that simply omit the tags, content stays None, so the OpenAI response builders (serving_responses.py 844-889) skip creating the assistant message and the user receives no answer despite the model producing one. The previous parser returned the text as content when the tags were absent, so non-thinking generations are now silently dropped.


Collaborator

@chaunceyjiang left a comment


It seems exactly the same as DeepSeekR1ReasoningParser. I suggest adding a new entry glm47 in __init__.py:

    "glm47": (  # name
        "deepseek_r1_reasoning_parser",  # filename
        "DeepSeekR1ReasoningParser",  # class_name
    ),

@zRzRzRzRzRzRzR
Contributor Author

It seems exactly the same as DeepSeekR1ReasoningParser. I suggest adding a new entry glm47 in __init__.py:

    "glm47": (  # name
        "deepseek_r1_reasoning_parser",  # filename
        "DeepSeekR1ReasoningParser",  # class_name
    ),


This approach works too. Do I need to submit a separate PR? Does this change require modifying the startup command?

@chaunceyjiang
Collaborator

Does this change require modifying the startup command?

No.
For the R1 reasoning-parser: if it fully supports both GLM-4.5 and GLM-4.7, I suggest simply renaming R1 rather than copying and pasting duplicate code.

Since the official GLM blog recommends using the GLM-4.5 reasoning-parser for the GLM-4.7 model, end users can continue using the GLM-4.5 reasoning-parser (even though it is effectively an R1 implementation) without needing to modify any startup parameters.

Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
@zRzRzRzRzRzRzR
Contributor Author

zRzRzRzRzRzRzR commented Dec 29, 2025

I understand, I modified the code. Is this what you mean?

@chaunceyjiang
Collaborator

@zRzRzRzRzRzRzR Could you test examples/online_serving/openai_chat_completion_with_reasoning_streaming.py and paste the results from your GLM-4.5 test here?

@zRzRzRzRzRzRzR
Contributor Author

root@node196:/mnt/vllm# python examples/online_serving/openai_chat_completion_with_reasoning_streaming.py
client: Start streaming chat completions...
reasoning:1.  **Analyze the Request:** The user is asking to compare two numbers: 9.11 and 9.8, and determine which one is greater.

2.  **Identify the Numbers:**
    *   Number A: 9.11
    *   Number B: 9.8 (which can be written as 9.80)

3.  **Compare the Numbers:**
    *   *Integer part:* Both have 9. They are equal so far.
    *   *Decimal part (Tenths place):*
        *   9.**1**1 has a 1 in the tenths place.
        *   9.**8** has an 8 in the tenths place.
    *   Since 8 > 1, the number with the 8 in the tenths place is larger, regardless of what comes after the hundredths place.
    *   *Alternative method (aligning decimals):*
        *   9.11
        *   9.80
        *   11 < 80.

4.  **Determine the Result:** 9.8 is greater than 9.11.

5.  **Formulate the Answer:**
    *   State the answer clearly: 9.8 is greater.
    *   Provide a brief explanation to avoid confusion (since some people mistakenly think 11 > 8 because of whole-number logic).
    *   Explanation: compare the digits immediately after the decimal point (the tenths place). The tenths place in 9.11 is 1. The tenths place in 9.8 is 8. Since 8 is greater than 1, 9.8 is larger.

6.  **Refine for User Experience:** Keep it direct but helpful.

    *   *Draft:* 9.8 is greater. This is because the digit in the tenths place (8) is larger than the digit in the tenths place of 9.11 (1).

7.  **Final Output Generation:** (Matches the drafted response).
content:**9.8** is greater.

Here is the breakdown:
*   **9.11** has a **1** in the tenths place.
*   **9.8** has an **8** in the tenths place.

Since 8 is greater than 1, 9.8 is the larger number.
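For reference, the streaming test above boils down to accumulating the reasoning and content fields of each delta separately; a minimal sketch (delta payloads shown as plain dicts, field names matching the output above):

```python
def accumulate_stream(deltas: list[dict]) -> tuple[str, str]:
    # Collect streamed reasoning and content deltas into two final strings,
    # mirroring what the example client prints under each label.
    reasoning_parts: list[str] = []
    content_parts: list[str] = []
    for delta in deltas:
        if delta.get("reasoning"):
            reasoning_parts.append(delta["reasoning"])
        if delta.get("content"):
            content_parts.append(delta["content"])
    return "".join(reasoning_parts), "".join(content_parts)
```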

@chaunceyjiang
Collaborator

Thank you for your patience. @zRzRzRzRzRzRzR I discussed this with the other maintainers.
They all believe that vllm/reasoning/glm4_moe_reasoning_parser.py is exactly the same as R1.
Using the approach from #28128 is easier to maintain.

However, considering that the file vllm/reasoning/glm4_moe_reasoning_parser.py has already gone through several versions, I will merge this PR now and remove the file in a few future releases.

@chaunceyjiang chaunceyjiang added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 29, 2025
@zRzRzRzRzRzRzR
Contributor Author

Alright, I won't modify this PR anymore, looking forward to the merge.

@chaunceyjiang chaunceyjiang enabled auto-merge (squash) December 30, 2025 04:31
@chaunceyjiang chaunceyjiang disabled auto-merge January 4, 2026 03:42
@chaunceyjiang chaunceyjiang enabled auto-merge (squash) January 4, 2026 03:42
Collaborator

@chaunceyjiang left a comment


Thanks~

@chaunceyjiang chaunceyjiang merged commit 0d4044e into vllm-project:main Jan 4, 2026
47 checks passed
@athenacykes

athenacykes commented Jan 4, 2026

I tried to use the deepseek_r1 parser with GLM-4.7, and when {"chat_template_kwargs": {"enable_thinking": False}} is added to the request, the deepseek_r1 parser seems to wrap the entire non-reasoning output in a <think> ... </think> tag.

Simply extending DeepSeek-R1's reasoning parser class will introduce the same issue.

@zhangsongqing

curl http://localhost:18001/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "glm-4.7-fp8",
    "messages": [{"role": "user", "content": "续写蜀道难,100字"}],
    "max_tokens": 2048,
    "temperature": 0.6,
    "stream": false,
    "chat_template_kwargs": {"enable_thinking": false}
  }'

{"id":"chatcmpl-9ba8e0f3887b48b7","object":"chat.completion","created":1767525526,"model":"glm-4.7-fp8","choices":[{"index":0,"message":{"role":"assistant","content":null,"refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":"青泥何盘盘,百步九折萦岩峦。扪参历井仰胁息,以手抚膺坐长叹。问君西游何时还?畏途巉岩不可攀。但见悲鸟号古木,雄飞雌从绕林间。又闻子规啼夜月,愁空山。蜀道之难,难于上青天,使人听此凋朱颜!连峰去天不盈尺,枯松倒挂倚绝壁。飞湍瀑流争喧豗,砯崖转石万壑雷。其险也如此,嗟尔远道之人胡为乎来哉!剑阁峥嵘而崔嵬,一夫当关,万夫莫开。所守或匪亲,化为狼与豺。朝避猛虎,夕避长蛇;磨牙吮血,杀人如麻。锦城虽云乐,不如早还家。蜀道之难,难于上青天,侧身西望长咨嗟!","reasoning_content":"青泥何盘盘,百步九折萦岩峦。扪参历井仰胁息,以手抚膺坐长叹。问君西游何时还?畏途巉岩不可攀。但见悲鸟号古木,雄飞雌从绕林间。又闻子规啼夜月,愁空山。蜀道之难,难于上青天,使人听此凋朱颜!连峰去天不盈尺,枯松倒挂倚绝壁。飞湍瀑流争喧豗,砯崖转石万壑雷。其险也如此,嗟尔远道之人胡为乎来哉!剑阁峥嵘而崔嵬,一夫当关,万夫莫开。所守或匪亲,化为狼与豺。朝避猛虎,夕避长蛇;磨牙吮血,杀人如麻。锦城虽云乐,不如早还家。蜀道之难,难于上青天,侧身西望长咨嗟!"},"logprobs":null,"finish_reason":"stop","stop_reason":151336,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":13,"total_tokens":242,"completion_tokens":229,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}

The content is null, and the answer has been placed into reasoning instead?

@zhangsongqing

curl http://localhost:18001/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "glm-4.7-fp8",
    "messages": [{"role": "user", "content": "续写蜀道难,100字"}],
    "max_tokens": 2048,
    "temperature": 0.6,
    "stream": false
  }'
{"id":"chatcmpl-bd5871b01f84afe5","object":"chat.completion","created":1767525726,"model":"glm-4.7-fp8","choices":[{"index":0,"message":{"role":"assistant","content":"云开雾散见青天,......踏破青山人未还。","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":"1. 分析请求:\n * 主题: 续写李白的《蜀道难》。\n * 篇幅: 大约100字。\n * 风格: 古典诗词风格(模仿李白的浪漫、夸张。。。。。最终输出生成。","reasoning_content":"1. 分析请求:\n * 。。。。。。。。。\n\n12. 最终输出生成。"},"logprobs":null,"finish_reason":"stop","stop_reason":151336,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":13,"total_tokens":1455,"completion_tokens":1442,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}

I did not set "chat_template_kwargs": {"enable_thinking": false}, but the response still contains both "reasoning" and "reasoning_content".

@chaunceyjiang
Collaborator

I did not set "chat_template_kwargs": {"enable_thinking": false}, but the response still contains both "reasoning" and "reasoning_content".

/cc @zRzRzRzRzRzRzR WDYT?

@hhd52859

hhd52859 commented Jan 6, 2026

I tried to use the deepseek_r1 parser with GLM-4.7, and when {"chat_template_kwargs": {"enable_thinking": False}} is added to the request, the deepseek_r1 parser seems to wrap the entire non-reasoning output in a <think> ... </think> tag.

Simply extending DeepSeek-R1's reasoning parser class will introduce the same issue.

I'm hitting the same problem, cc @chaunceyjiang

LucasWilkinson pushed a commit to neuralmagic/vllm that referenced this pull request Jan 6, 2026
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 9, 2026
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
@zRzRzRzRzRzRzR zRzRzRzRzRzRzR deleted the parser branch January 15, 2026 08:43
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
