
[Perf] Enable dual stream execution of input projection for Qwen3#36795

Merged
DarkLight1337 merged 3 commits into vllm-project:main from xyang16:multi_stream
Mar 18, 2026

Conversation


@xyang16 xyang16 commented Mar 11, 2026

Purpose

This PR enables dual-stream execution of the input projection for Qwen3 Next.

  • Parallelize the execution of in_proj_qkvz and in_proj_ba on two CUDA streams, since their outputs are independent of each other.
  • Wrap the implementation in a custom op so it works under torch.compile.
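The overlap idea can be illustrated with a small, CPU-safe sketch of a stream helper (duck-typed so it runs without CUDA; the names and the event choreography mirror the pattern described in this PR, but this is an illustrative sketch, not vLLM's actual implementation):

```python
from typing import Any, Callable, Optional

def maybe_execute_in_parallel(
    fn0: Callable[[], Any],
    fn1: Callable[[], Any],
    event0: Any = None,                 # torch.cuda.Event in the real helper
    event1: Any = None,                 # torch.cuda.Event in the real helper
    aux_stream: Optional[Any] = None,   # torch.cuda.Stream, or None
) -> tuple[Any, Any]:
    if aux_stream is None:
        # Non-CUDA platforms (or single-stream mode): run sequentially.
        return fn0(), fn1()
    import torch  # only needed on the CUDA path
    event0.record()                     # fence: aux stream waits for prior work
    with torch.cuda.stream(aux_stream):
        event0.wait()
        out1 = fn1()                    # e.g. in_proj_ba on the aux stream
        event1.record()
    out0 = fn0()                        # e.g. in_proj_qkvz on the default stream
    event1.wait()                       # rejoin before consuming out1
    return out0, out1

# Sequential fallback path (no aux stream supplied):
qkvz, ba = maybe_execute_in_parallel(lambda: "qkvz", lambda: "ba")
```

Because the two projections read the same hidden_states and write disjoint outputs, no extra synchronization is needed between them beyond the entry/exit events.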

Profiling

Main:

[Screenshot: profiler trace, 2026-03-10]

PR:

[Screenshot: profiler trace, 2026-03-10]

Main: nvjet_tst_64x8_64x16_4x2_h_bz_TNT (in_proj_qkvz) and nvjet_tst_64x8_64x16_1x2_h_bz_TNT (in_proj_ba) kernels launched sequentially.

PR: kernels launched in parallel.

Benchmarking

Benchmarked on H200.

  • Qwen3
vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct \
    --tensor-parallel-size 1 \
    --max-num-seqs 16 \
    --no-enable-prefix-caching
vllm bench serve \
        --model Qwen/Qwen3-Next-80B-A3B-Instruct \
        --dataset-name sharegpt \
        --dataset-path /tmp/ShareGPT_V3_unfiltered_cleaned_split.json \
        --sharegpt-output-len 300 \
        --num-prompts ${num_prompts} \
        --max-concurrency 16 \
        --num-warmups 50 \
        --ignore-eos \
        --temperature 0

Main:

============ Serving Benchmark Result ============
Successful requests:                     960       
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  121.46    
Total input tokens:                      219140    
Total generated tokens:                  288000    
Request throughput (req/s):              7.90      
Output token throughput (tok/s):         2371.20   
Peak output token throughput (tok/s):    2640.00   
Peak concurrent requests:                32.00     
Total token throughput (tok/s):          4175.46   
---------------Time to First Token----------------
Mean TTFT (ms):                          176.72    
Median TTFT (ms):                        193.47    
P99 TTFT (ms):                           222.77    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.18      
Median TPOT (ms):                        6.15      
P99 TPOT (ms):                           6.48      
---------------Inter-token Latency----------------
Mean ITL (ms):                           6.18      
Median ITL (ms):                         6.13      
P99 ITL (ms):                            6.94      
==================================================

PR:

============ Serving Benchmark Result ============
Successful requests:                     960       
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  120.20    
Total input tokens:                      219140    
Total generated tokens:                  288000    
Request throughput (req/s):              7.99      
Output token throughput (tok/s):         2396.09   
Peak output token throughput (tok/s):    2672.00   
Peak concurrent requests:                32.00     
Total token throughput (tok/s):          4219.28   
---------------Time to First Token----------------
Mean TTFT (ms):                          191.90    
Median TTFT (ms):                        214.37    
P99 TTFT (ms):                           249.63    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.06      
Median TPOT (ms):                        6.04      
P99 TPOT (ms):                           6.40      
---------------Inter-token Latency----------------
Mean ITL (ms):                           6.06      
Median ITL (ms):                         6.02      
P99 ITL (ms):                            6.72      
==================================================
  • Qwen3 fp8
vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 \
    --tensor-parallel-size 1 \
    --max-num-seqs 16 \
    --no-enable-prefix-caching
vllm bench serve \
        --model Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 \
        --dataset-name sharegpt \
        --dataset-path /tmp/ShareGPT_V3_unfiltered_cleaned_split.json \
        --sharegpt-output-len 300 \
        --num-prompts ${num_prompts} \
        --max-concurrency 16 \
        --num-warmups 50 \
        --ignore-eos \
        --temperature 0

Main:

============ Serving Benchmark Result ============
Successful requests:                     960       
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  195.14    
Total input tokens:                      227546    
Total generated tokens:                  288000    
Request throughput (req/s):              4.92      
Output token throughput (tok/s):         1475.89   
Peak output token throughput (tok/s):    1648.00   
Peak concurrent requests:                32.00     
Total token throughput (tok/s):          2641.98   
---------------Time to First Token----------------
Mean TTFT (ms):                          162.16    
Median TTFT (ms):                        156.37    
P99 TTFT (ms):                           234.19    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.33     
Median TPOT (ms):                        10.31     
P99 TPOT (ms):                           10.77     
---------------Inter-token Latency----------------
Mean ITL (ms):                           10.33     
Median ITL (ms):                         10.24     
P99 ITL (ms):                            11.43     
==================================================

PR:

============ Serving Benchmark Result ============
Successful requests:                     960       
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  191.31    
Total input tokens:                      219140    
Total generated tokens:                  288000    
Request throughput (req/s):              5.02      
Output token throughput (tok/s):         1505.38   
Peak output token throughput (tok/s):    1712.00   
Peak concurrent requests:                32.00     
Total token throughput (tok/s):          2650.83   
---------------Time to First Token----------------
Mean TTFT (ms):                          236.09    
Median TTFT (ms):                        236.95    
P99 TTFT (ms):                           380.30    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.87      
Median TPOT (ms):                        9.83      
P99 TPOT (ms):                           10.49     
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.87      
Median ITL (ms):                         9.81      
P99 ITL (ms):                            10.99     
==================================================
  • Qwen3.5
vllm serve Qwen/Qwen3.5-35B-A3B \
    --tensor-parallel-size 1 \
    --max-num-seqs 16 \
    --no-enable-prefix-caching
vllm bench serve \
        --model Qwen/Qwen3.5-35B-A3B \
        --dataset-name sharegpt \
        --dataset-path /tmp/ShareGPT_V3_unfiltered_cleaned_split.json \
        --sharegpt-output-len 300 \
        --num-prompts ${num_prompts} \
        --max-concurrency 16 \
        --num-warmups 50 \
        --ignore-eos \
        --temperature 0

Main:

============ Serving Benchmark Result ============
Successful requests:                     960       
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  197.96    
Total input tokens:                      227546    
Total generated tokens:                  288000    
Request throughput (req/s):              4.85      
Output token throughput (tok/s):         1454.81   
Peak output token throughput (tok/s):    1648.00   
Peak concurrent requests:                32.00     
Total token throughput (tok/s):          2604.25   
---------------Time to First Token----------------
Mean TTFT (ms):                          142.74    
Median TTFT (ms):                        152.15    
P99 TTFT (ms):                           199.72    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.55     
Median TPOT (ms):                        10.59     
P99 TPOT (ms):                           11.10     
---------------Inter-token Latency----------------
Mean ITL (ms):                           10.55     
Median ITL (ms):                         10.24     
P99 ITL (ms):                            12.05     
==================================================

PR:

============ Serving Benchmark Result ============
Successful requests:                     960       
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  190.61    
Total input tokens:                      227546    
Total generated tokens:                  288000    
Request throughput (req/s):              5.04      
Output token throughput (tok/s):         1510.93   
Peak output token throughput (tok/s):    1715.00   
Peak concurrent requests:                32.00     
Total token throughput (tok/s):          2704.71   
---------------Time to First Token----------------
Mean TTFT (ms):                          169.85    
Median TTFT (ms):                        173.71    
P99 TTFT (ms):                           254.99    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.05     
Median TPOT (ms):                        10.06     
P99 TPOT (ms):                           10.28     
---------------Inter-token Latency----------------
Mean ITL (ms):                           10.05     
Median ITL (ms):                         10.06     
P99 ITL (ms):                            11.37     
==================================================

Accuracy Testing

  • Qwen3
python3 -m lm_eval --model local-completions \
  --model_args model=Qwen/Qwen3-Next-80B-A3B-Instruct,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=16 \
  --tasks gsm8k

Main:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8575|±  |0.0096|
|     |       |strict-match    |     5|exact_match|↑  |0.8150|±  |0.0107|

PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8552|±  |0.0097|
|     |       |strict-match    |     5|exact_match|↑  |0.8082|±  |0.0108|
  • Qwen3 fp8
python3 -m lm_eval --model local-completions \
  --model_args model=Qwen/Qwen3-Next-80B-A3B-Instruct-FP8,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=16 \
  --tasks gsm8k

Main:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8491|±  |0.0099|
|     |       |strict-match    |     5|exact_match|↑  |0.8127|±  |0.0107|

PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8575|±  |0.0096|
|     |       |strict-match    |     5|exact_match|↑  |0.8127|±  |0.0107|
  • Qwen3.5
python3 -m lm_eval --model local-completions \
  --model_args model=Qwen/Qwen3.5-35B-A3B,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=16 \
  --tasks gsm8k

Main:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8476|±  |0.0099|
|     |       |strict-match    |     5|exact_match|↑  |0.8370|±  |0.0102|

PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8499|±  |0.0098|
|     |       |strict-match    |     5|exact_match|↑  |0.8332|±  |0.0103|


@xyang16 xyang16 requested a review from sighingnow as a code owner March 11, 2026 14:39
@mergify mergify bot added the qwen Related to Qwen models label Mar 11, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This PR introduces dual-stream execution for input projection in the Qwen3 Next model to improve performance by parallelizing in_proj_qkvz and in_proj_ba operations. It also wraps the implementation in a custom op for torch.compile. The changes include adding an auxiliary stream, modifying the forward pass to use the custom op, and introducing a new function for dual-stream execution. I have identified a critical issue related to potential deadlocks when using auxiliary streams.

@xyang16 xyang16 force-pushed the multi_stream branch 2 times, most recently from 5d76ed2 to 911a451 on March 11, 2026 15:09

mergify bot commented Mar 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xyang16.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 12, 2026

ZJY0516 commented Mar 12, 2026

Could you please also apply this to qwen 3.5?

@robertgshaw2-redhat
Collaborator

I would avoid passing the aux_stream through the class constructors

@xyang16 xyang16 force-pushed the multi_stream branch 3 times, most recently from 01d0d02 to df1763c on March 13, 2026 00:45
@benchislett
Collaborator

consider leveraging a maybe_execute_in_parallel primitive as in #35968


xyang16 commented Mar 13, 2026

Could you please also apply this to qwen 3.5?

@ZJY0516 Thanks for review! I have put the benchmark and accuracy testing result in PR description.


xyang16 commented Mar 13, 2026

@robertgshaw2-redhat Thanks for review! I have removed passing aux_stream through the class constructors.


ZJY0516 commented Mar 13, 2026

Could you please also apply this to qwen 3.5?

@ZJY0516 Thanks for review! I have put the benchmark and accuracy testing result in PR description.

I don't see qwen 3.5 related code change

@xyang16 xyang16 force-pushed the multi_stream branch 2 times, most recently from 2cbaae6 to 4e75f43 on March 13, 2026 03:46

mergify bot commented Mar 13, 2026

Hi @xyang16, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10


xyang16 commented Mar 13, 2026

Could you please also apply this to qwen 3.5?

@ZJY0516 Thanks for review! I have put the benchmark and accuracy testing result in PR description.

I don't see qwen 3.5 related code change

I have pushed the change in qwen 3.5. Thanks!


xyang16 commented Mar 13, 2026

consider leveraging a maybe_execute_in_parallel primitive as in #35968

@benchislett I have added maybe_execute_in_parallel. Thanks!

@xyang16 xyang16 changed the title [Perf] Enable dual stream execution of input projection for Qwen3 Next [Perf] Enable dual stream execution of input projection for Qwen3 Mar 13, 2026
xyang16 added 2 commits March 12, 2026 23:11
Signed-off-by: Xin Yang <xyangx@amazon.com>

xyang16 commented Mar 13, 2026

@robertgshaw2-redhat Could you please review again? Thanks!

@jhaotingc
Contributor

Thanks for the implementation! #32828

def _forward_in_proj(
    self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    projected_states_qkvz, projected_states_ba = maybe_execute_in_parallel(
Member

I have a small question about the naming here. maybe means it may not run in parallel, but in this case, we always run in parallel, right?

Contributor Author

@ZJY0516 In maybe_execute_in_parallel, the functions run in parallel when aux_stream is not None; otherwise they run sequentially. aux_stream is None on non-CUDA platforms. Thanks!

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 17, 2026
Comment on lines +10 to +21
def maybe_execute_in_parallel(
    fn0: Callable[[], Any],
    fn1: Callable[[], Any],
    event0: torch.cuda.Event,
    event1: torch.cuda.Event,
    aux_stream: torch.cuda.Stream | None = None,
) -> tuple[Any, Any]:
    """Run two functions potentially in parallel on separate CUDA streams.

    When aux_stream is provided, fn0 runs on the current (default) stream and
    fn1 runs on aux_stream, synchronized via CUDA events. When aux_stream is
    None, both functions execute sequentially on the current stream.
Member

I do like this utility as a pattern to apply generally

Comment on lines +1699 to +1710
def gdn_in_proj(
    hidden_states: torch.Tensor,
    qkvz_output_size: int,
    ba_output_size: int,
    layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Custom op for the input projection.
    """
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    return self._forward_in_proj(hidden_states)
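The reason the op receives qkvz_output_size and ba_output_size as plain integers is torch.compile's fake (meta) implementation: during tracing, the compiler must derive output shapes without running the real projection or touching module state. A minimal sketch of the shape logic such a fake implementation needs (the function name here is hypothetical, for illustration only):

```python
def gdn_in_proj_fake_shapes(
    hidden_states_shape: tuple[int, int],
    qkvz_output_size: int,
    ba_output_size: int,
) -> tuple[tuple[int, int], tuple[int, int]]:
    # During tracing only shapes matter: each projection maps
    # [num_tokens, hidden] -> [num_tokens, output_size].
    num_tokens, _hidden = hidden_states_shape
    return (num_tokens, qkvz_output_size), (num_tokens, ba_output_size)

# e.g. 8 tokens through projections with illustrative output sizes:
shapes = gdn_in_proj_fake_shapes((8, 2048), 12288, 96)
```

This is why the call site passes sizes rather than the module itself: the fake path never sees the weights.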
Member

This indirection is pretty gross though. Could we avoid this somehow? I found it very confusing that you were passing in self.in_proj_qkvz.weight.shape[0] to this op instead of the module itself.

Also there is the concern of wrapping these MergedColumnParallelLinear modules that could be quantized - it seems we would lose the potential of torch.compile fusing the input quantization with previous ops or reaching inside of the linear op itself (less valid concern)

Contributor Author

@mgoin Thanks for review!

Actually this layer is already wrapped here: https://github.com/vllm-project/vllm/blob/v0.18.0rc0/vllm/model_executor/models/qwen3_next.py#L1673-L1692. And I agree this should be improved once torch.compile supports multi-stream execution.

    ba_output_size: int,
    layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
Collaborator

Can we add a tracking issue somewhere for porting this over to native Inductor multi-stream support?

Contributor Author

I have created #37372 to track this. Thanks!

@DarkLight1337 DarkLight1337 merged commit f174000 into vllm-project:main Mar 18, 2026
54 checks passed
JartX added a commit to JartX/vllm that referenced this pull request Mar 18, 2026
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
yewentao256 pushed a commit that referenced this pull request Mar 18, 2026
ikaadil pushed a commit to ikaadil/vllm that referenced this pull request Mar 19, 2026
…ct#36795) (vllm-project#37427)

Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
JWriter20 added a commit to JWriter20/vllm that referenced this pull request Mar 19, 2026
…with LoRA

The `gdn_in_proj` custom op (introduced in f174000 / PR vllm-project#36795) uses
`self.in_proj_qkvz.weight.shape[0]` to communicate the output tensor
size to torch.compile's fake implementation. With LoRA + AWQ/GPTQ
quantization, `.weight` returns the quantized `qweight` whose shape is
packed (e.g. input_size // 8 for 4-bit), causing a dimension mismatch
in the subsequent `.split()` call.

Fix: compute output sizes analytically from model dimensions
(key_dim, value_dim, num_v_heads, tp_size) instead of reading from
the weight tensor shape. These computed values are identical to
weight.shape[0] for non-quantized models, so there is no regression.

Tested with:
- cyankiwi/Qwen3.5-9B-AWQ-4bit + LoRA adapters (torch.compile)
- Qwen/Qwen3.5-9B without quantization (torch.compile)
- Qwen/Qwen3.5-9B + LoRA adapters without quantization (eager)
- Qwen/Qwen3.5-35B-A3B-GPTQ-Int4 (torch.compile)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
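The mismatch described in the commit message above can be seen with simple arithmetic (the dimensions below are illustrative, not Qwen3-Next's actual sizes; packing layout per the GPTQ-style example in the message):

```python
PACK_FACTOR = 8                          # eight 4-bit weights per int32
input_size, output_size = 4096, 12288    # illustrative linear-layer dims

# Unquantized nn.Linear stores weight as [output_size, input_size],
# so weight.shape[0] is the logical output size the .split() expects:
fp16_shape0 = output_size                # 12288

# A packed qweight's leading dim reflects the packed input dimension
# (e.g. input_size // 8 for 4-bit), not the logical output size:
packed_shape0 = input_size // PACK_FACTOR  # 512

# Reading .weight.shape[0] under quantization therefore feeds the wrong
# size to the subsequent .split(), hence the analytic-size fix.
mismatch = fp16_shape0 != packed_shape0
```

Computing the sizes from model dimensions sidesteps the weight layout entirely, which is why the fix is layout-agnostic.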
JWriter20 added a commit to JWriter20/vllm that referenced this pull request Mar 19, 2026
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
@xyang16 xyang16 deleted the multi_stream branch March 19, 2026 22:42
Comment on lines +183 to +188
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
    hidden_states,
    self.in_proj_qkvz.weight.shape[0],
    self.in_proj_ba.weight.shape[0],
    self.prefix,
)
Collaborator

This regresses cold compile times by baking in a string into the compiled graph. We should really make a lint rule for this or something

Contributor Author

@xyang16 xyang16 Mar 23, 2026

@zou3519 Thanks for your comment, looking into this. By the way, I was following the torch.ops.vllm.gdn_attention_core op in the same forward().

Collaborator

torch.ops.vllm.gdn_attention_core is not included in the subgraph, so it doesn't cause problems with compile times. I'm trying to figure out what to do with this. In theory, we have a fix for this in PyTorch 2.11.

Contributor Author

@xyang16 xyang16 Mar 23, 2026

I can revert using this torch.ops.vllm.gdn_in_proj op and wait for PyTorch 2.11. Please let me know what you think. Thanks!

Collaborator

@xyang16 are you able to refactor this so that the gdn_in_proj op does NOT need to pass a string as an input? Basically we would avoid stashing state into a side table. How difficult do you think that would be?

Contributor Author

@xyang16 xyang16 Mar 24, 2026

@zou3519 Sure, I will look into this today.

Contributor Author

@zou3519 I see your PR #38123. Will it fix this issue?

Collaborator

It will fix the issue for PyTorch 2.11. But vLLM is going to do one more release (0.19.0, branch cut this Monday) without PyTorch 2.11.

If we can wait for the performance improvement in this PR, the easiest thing for us to do is just revert this PR and re-merge it after #38123 lands and we upgrade to 2.11 (probably Tuesday).

Contributor Author

@zou3519 Thanks for the help. I have created #38152 to revert this PR. cc @benchislett

SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026

Labels

qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed

9 participants