
[Feature] Optimizations for class Qwen3VLMoeVisionModel (Conv3d to Linear) in Qwen3VL #19788

Closed
wili-65535 wants to merge 1 commit into sgl-project:main from wili-65535:wili/VisionPathchEmbed

Conversation

@wili-65535
Contributor

@wili-65535 wili-65535 commented Mar 3, 2026

Motivation

Modifications

  • Use Linear rather than Conv3d to compute the Qwen3VLVisionPatchEmbed.
  • Add unit test test/srt/test_conv3d_to_linear_unittest.py to ensure the replacement is equivalent.
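The replacement works because a Conv3d whose kernel_size equals its stride never slides over overlapping inputs, so each output element is a single dot product over one flattened patch. A minimal sketch of the equivalence (shapes are illustrative, not the model's actual config):

```python
import torch
import torch.nn as nn

# Illustrative sizes: temporal_patch_size=2, patch_size=16, 3 channels, 64-dim embed.
t, p, c, d = 2, 16, 3, 64
conv = nn.Conv3d(c, d, kernel_size=(t, p, p), stride=(t, p, p), bias=True)
linear = nn.Linear(c * t * p * p, d, bias=True)

# Fold the Conv3d kernel into the Linear weight matrix: each output channel's
# (c, t, p, p) kernel becomes one row of the (d, c*t*p*p) weight.
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(d, -1))
    linear.bias.copy_(conv.bias)

# Five patches, flattened in the same (c, t, p, p) order the model uses.
x = torch.randn(5, c * t * p * p)
y_conv = conv(x.view(-1, c, t, p, p)).view(5, d)
y_lin = linear(x)
print(torch.allclose(y_conv, y_lin, atol=1e-4))  # True (up to float noise)
```

This is the same check the unit test performs: both paths must agree to within floating-point tolerance.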

Accuracy Tests

  1. By unit test test/srt/test_conv3d_to_linear_unittest.py.

  2. lm-eval shows no drop between the main branch and this PR; both produce the same scores:

  • Command:
lm-eval run --model sglang --model_args pretrained=/workspace/Qwen3-VL-8B-Instruct,dtype=auto,tp_size=1 --tasks gpqa_diamond_zeroshot gpqa_extended_zeroshot gpqa_main_zeroshot gpqa_diamond_n_shot gpqa_extended_n_shot gpqa_main_n_shot
  • Before this optimization (main branch baseline)

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gpqa_diamond_n_shot | 2 | none | 0 | acc | 0.3687 | ± 0.0344 |
| | | none | 0 | acc_norm | 0.3687 | ± 0.0344 |
| gpqa_diamond_zeroshot | 1 | none | 0 | acc | 0.3788 | ± 0.0346 |
| | | none | 0 | acc_norm | 0.3788 | ± 0.0346 |
| gpqa_extended_n_shot | 2 | none | 0 | acc | 0.3791 | ± 0.0208 |
| | | none | 0 | acc_norm | 0.3791 | ± 0.0208 |
| gpqa_extended_zeroshot | 1 | none | 0 | acc | 0.3773 | ± 0.0208 |
| | | none | 0 | acc_norm | 0.3773 | ± 0.0208 |
| gpqa_main_n_shot | 2 | none | 0 | acc | 0.3862 | ± 0.0230 |
| | | none | 0 | acc_norm | 0.3862 | ± 0.0230 |
| gpqa_main_zeroshot | 1 | none | 0 | acc | 0.4085 | ± 0.0232 |
| | | none | 0 | acc_norm | 0.4085 | ± 0.0232 |
  • After the optimization (this PR)
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gpqa_diamond_n_shot | 2 | none | 0 | acc | 0.3687 | ± 0.0344 |
| | | none | 0 | acc_norm | 0.3687 | ± 0.0344 |
| gpqa_diamond_zeroshot | 1 | none | 0 | acc | 0.3788 | ± 0.0346 |
| | | none | 0 | acc_norm | 0.3788 | ± 0.0346 |
| gpqa_extended_n_shot | 2 | none | 0 | acc | 0.3791 | ± 0.0208 |
| | | none | 0 | acc_norm | 0.3791 | ± 0.0208 |
| gpqa_extended_zeroshot | 1 | none | 0 | acc | 0.3773 | ± 0.0208 |
| | | none | 0 | acc_norm | 0.3773 | ± 0.0208 |
| gpqa_main_n_shot | 2 | none | 0 | acc | 0.3862 | ± 0.0230 |
| | | none | 0 | acc_norm | 0.3862 | ± 0.0230 |
| gpqa_main_zeroshot | 1 | none | 0 | acc | 0.4085 | ± 0.0232 |
| | | none | 0 | acc_norm | 0.4085 | ± 0.0232 |
  • lmms_eval also shows no drop between the main branch and this PR; both produce similar scores:

  • Command:

python3 -m lmms_eval --model sglang --model_args model_path=/workspace/Qwen3-VL-8B-Instruct --tasks mmmu_val
  • Before this optimization (main branch baseline)
{'Overall-Art and Design': {'num': 120, 'acc': 0.68333}, 'Art': {'num': 30, 'acc': 0.66667}, 'Art_Theory': {'num': 30, 'acc': 0.9}, 'Design': {'num': 30, 'acc': 0.73333}, 'Music': {'num': 30, 'acc': 0.43333}, 'Overall-Business': {'num': 150, 'acc': 0.40667}, 'Accounting': {'num': 30, 'acc': 0.3}, 'Economics': {'num': 30, 'acc': 0.5}, 'Finance': {'num': 30, 'acc': 0.26667}, 'Manage': {'num': 30, 'acc': 0.5}, 'Marketing': {'num': 30, 'acc': 0.46667}, 'Overall-Science': {'num': 150, 'acc': 0.45333}, 'Biology': {'num': 30, 'acc': 0.53333}, 'Chemistry': {'num': 30, 'acc': 0.36667}, 'Geography': {'num': 30, 'acc': 0.53333}, 'Math': {'num': 30, 'acc': 0.3}, 'Physics': {'num': 30, 'acc': 0.53333}, 'Overall-Health and Medicine': {'num': 150, 'acc': 0.53333}, 'Basic_Medical_Science': {'num': 30, 'acc': 0.66667}, 'Clinical_Medicine': {'num': 30, 'acc': 0.66667}, 'Diagnostics_and_Laboratory_Medicine': {'num': 30, 'acc': 0.36667}, 'Pharmacy': {'num': 30, 'acc': 0.5}, 'Public_Health': {'num': 30, 'acc': 0.46667}, 'Overall-Humanities and Social Science': {'num': 120, 'acc': 0.675}, 'History': {'num': 30, 'acc': 0.6}, 'Literature': {'num': 30, 'acc': 0.83333}, 'Sociology': {'num': 30, 'acc': 0.63333}, 'Psychology': {'num': 30, 'acc': 0.63333}, 'Overall-Tech and Engineering': {'num': 210, 'acc': 0.38095}, 'Agriculture': {'num': 30, 'acc': 0.53333}, 'Architecture_and_Engineering': {'num': 30, 'acc': 0.3}, 'Computer_Science': {'num': 30, 'acc': 0.5}, 'Electronics': {'num': 30, 'acc': 0.3}, 'Energy_and_Power': {'num': 30, 'acc': 0.3}, 'Materials': {'num': 30, 'acc': 0.4}, 'Mechanical_Engineering': {'num': 30, 'acc': 0.33333}, 'Overall': {'num': 900, 'acc': 0.50222}}
2026-03-16T05:00:15.655750+0000 | save_results_aggregated | INFO - Output path not provided, skipping saving results aggregated
sglang (model_path=/workspace/Qwen3-VL-8B-Instruct), gen_kwargs: (), limit: None, offset: 0, num_fewshot: None, batch_size: 1

LMMs-Eval: Probing Intelligence in the Real World
> The unified evaluation toolkit for frontier models.

branch: main
commit: v0.6-72-g88b23e2b

| Tasks  |Filter|n-shot| Metric |   |Value |   |Stderr|
|--------|------|-----:|--------|---|-----:|---|------|
|mmmu_val|none  |     0|mmmu_acc|↑  |0.5022|±  |N/A   |
  • After this optimization (this PR)
{'Overall-Art and Design': {'num': 120, 'acc': 0.68333}, 'Art': {'num': 30, 'acc': 0.66667}, 'Art_Theory': {'num': 30, 'acc': 0.9}, 'Design': {'num': 30, 'acc': 0.73333}, 'Music': {'num': 30, 'acc': 0.43333}, 'Overall-Business': {'num': 150, 'acc': 0.40667}, 'Accounting': {'num': 30, 'acc': 0.3}, 'Economics': {'num': 30, 'acc': 0.5}, 'Finance': {'num': 30, 'acc': 0.26667}, 'Manage': {'num': 30, 'acc': 0.5}, 'Marketing': {'num': 30, 'acc': 0.46667}, 'Overall-Science': {'num': 150, 'acc': 0.45333}, 'Biology': {'num': 30, 'acc': 0.53333}, 'Chemistry': {'num': 30, 'acc': 0.36667}, 'Geography': {'num': 30, 'acc': 0.53333}, 'Math': {'num': 30, 'acc': 0.3}, 'Physics': {'num': 30, 'acc': 0.53333}, 'Overall-Health and Medicine': {'num': 150, 'acc': 0.53333}, 'Basic_Medical_Science': {'num': 30, 'acc': 0.66667}, 'Clinical_Medicine': {'num': 30, 'acc': 0.66667}, 'Diagnostics_and_Laboratory_Medicine': {'num': 30, 'acc': 0.36667}, 'Pharmacy': {'num': 30, 'acc': 0.5}, 'Public_Health': {'num': 30, 'acc': 0.46667}, 'Overall-Humanities and Social Science': {'num': 120, 'acc': 0.675}, 'History': {'num': 30, 'acc': 0.6}, 'Literature': {'num': 30, 'acc': 0.83333}, 'Sociology': {'num': 30, 'acc': 0.63333}, 'Psychology': {'num': 30, 'acc': 0.63333}, 'Overall-Tech and Engineering': {'num': 210, 'acc': 0.38095}, 'Agriculture': {'num': 30, 'acc': 0.53333}, 'Architecture_and_Engineering': {'num': 30, 'acc': 0.3}, 'Computer_Science': {'num': 30, 'acc': 0.5}, 'Electronics': {'num': 30, 'acc': 0.3}, 'Energy_and_Power': {'num': 30, 'acc': 0.3}, 'Materials': {'num': 30, 'acc': 0.4}, 'Mechanical_Engineering': {'num': 30, 'acc': 0.33333}, 'Overall': {'num': 900, 'acc': 0.50222}}
2026-03-16T05:11:01.956747+0000 | save_results_aggregated | INFO - Output path not provided, skipping saving results aggregated
sglang (model_path=/workspace/Qwen3-VL-8B-Instruct), gen_kwargs: (), limit: None, offset: 0, num_fewshot: None, batch_size: 1

LMMs-Eval: Probing Intelligence in the Real World
> The unified evaluation toolkit for frontier models.

branch: main
commit: v0.6-72-g88b23e2b

| Tasks  |Filter|n-shot| Metric |   |Value |   |Stderr|
|--------|------|-----:|--------|---|-----:|---|------|
|mmmu_val|none  |     0|mmmu_acc|↑  |0.5022|±  |N/A   |

Benchmarking and Profiling

  • Part of the performance data is in the original issue; here we provide more detailed measurements.
  • We use an H100 GPU to run the Qwen3-VL-8B model (the larger the model, the more significant the improvement), sending requests each containing one JPEG image at various resolutions.
  • Add the code below (in forward() of class Qwen3VLMoeVisionModel) to measure the time.
        # original code
        x = x.to(device=self.device, dtype=self.dtype)

        # timing code: warm up, then average over 30 timed runs with CUDA events
        for _ in range(10):
            y = self.patch_embed(x)
        patch_embed_start = torch.cuda.Event(enable_timing=True)
        patch_embed_end = torch.cuda.Event(enable_timing=True)
        patch_embed_start.record()
        for _ in range(30):
            y = self.patch_embed(x)
        patch_embed_end.record()
        torch.cuda.synchronize(x.device)
        patch_embed_elapsed_ms = patch_embed_start.elapsed_time(patch_embed_end) / 30
        print(f"###[{x.shape[0]}]: {patch_embed_elapsed_ms:.3f} ms")
  • The results before and after this PR are shown below. On average a 6.9x speedup is achieved, and the larger the image, the greater the gain.
  • In our earlier scenario (Qwen3-VL-235B + TP8 + DP1, one request with 20 images of 960x1280), an approximately 11.0x speedup was observed.
| Size | Conv3d / ms | Linear / ms | Speedup |
|---|---|---|---|
| 32x32 | 0.122 | 0.070 | 1.74 |
| 64x64 | 0.122 | 0.070 | 1.74 |
| 96x96 | 0.123 | 0.070 | 1.76 |
| 128x128 | 0.123 | 0.071 | 1.73 |
| 160x160 | 0.123 | 0.071 | 1.73 |
| 192x192 | 0.124 | 0.071 | 1.75 |
| 224x224 | 0.129 | 0.071 | 1.82 |
| 256x256 | 0.136 | 0.073 | 1.86 |
| 288x288 | 0.136 | 0.072 | 1.89 |
| 320x320 | 0.175 | 0.070 | 2.50 |
| 352x352 | 0.228 | 0.069 | 3.30 |
| 384x384 | 0.279 | 0.083 | 3.36 |
| 416x416 | 0.261 | 0.076 | 3.43 |
| 448x448 | 0.270 | 0.069 | 3.91 |
| 480x480 | 0.308 | 0.074 | 4.16 |
| 512x512 | 0.334 | 0.069 | 4.84 |
| 544x544 | 0.375 | 0.069 | 5.43 |
| 576x576 | 0.414 | 0.069 | 6.00 |
| 608x608 | 0.499 | 0.070 | 7.13 |
| 640x640 | 0.619 | 0.070 | 8.84 |
| 672x672 | 0.646 | 0.070 | 9.23 |
| 704x704 | 0.663 | 0.069 | 9.61 |
| 736x736 | 0.672 | 0.069 | 9.74 |
| 768x768 | 0.707 | 0.070 | 10.10 |
| 800x800 | 0.747 | 0.068 | 10.99 |
| 832x832 | 0.875 | 0.069 | 12.68 |
| 864x864 | 0.847 | 0.069 | 12.28 |
| 896x896 | 1.028 | 0.069 | 14.90 |
| 928x928 | 1.065 | 0.070 | 15.21 |
| 960x960 | 1.127 | 0.071 | 15.87 |
| 992x992 | 1.174 | 0.071 | 16.54 |
| 1024x1024 | 1.235 | 0.084 | 14.70 |

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@wili-65535 wili-65535 changed the title v0.1: conv3d to linear [Feature] Optimizations for class Qwen3VLMoeVisionModel (Conv3d to Linear) in Qwen3VL Mar 3, 2026
@gemini-code-assist
Contributor

Summary of Changes


This pull request focuses on optimizing the Qwen3VLVisionPatchEmbed module by replacing the Conv3d layer with a Linear layer. This change leverages the existing Conv3d layer for initial weight loading and then transfers the weights to the Linear layer for actual computation, resulting in significant performance improvements, especially with larger image sizes.

Highlights

  • Optimization: Replaced Conv3d with Linear in Qwen3VLVisionPatchEmbed to improve performance.
  • Weight Handling: Utilized the original Conv3d to accept weight data, then copied the weight into Linear for computation.
  • Performance Improvement: Achieved an average of 6.9x acceleration, with larger images showing better performance gains.


Changelog
  • python/sglang/srt/models/qwen3_vl.py
    • Replaced Conv3d with Linear for Qwen3VLVisionPatchEmbed.
    • Added a method to copy weights from Conv3d to Linear.
    • Modified the forward pass to use the Linear layer.
    • Added a call to copy_conv3d_weight_to_linear after loading weights.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization by replacing the Conv3d layer with a Linear layer in Qwen3VLVisionPatchEmbed. The weights are loaded into the Conv3d layer and then copied to the Linear layer, after which the Conv3d layer is deleted to save memory. The changes are logical and well-implemented. My feedback includes a suggestion to make the weight copying method idempotent for increased robustness.
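The idempotency the review asks for can be sketched as follows; the class and method bodies here are simplified stand-ins for the actual qwen3_vl.py implementation, with illustrative sizes:

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Sketch: Conv3d receives checkpoint weights, Linear does the compute."""

    def __init__(self, in_ch=3, t=2, p=16, dim=64):
        super().__init__()
        self.proj = nn.Conv3d(in_ch, dim, kernel_size=(t, p, p), stride=(t, p, p))
        self.linear = nn.Linear(in_ch * t * p * p, dim)

    def copy_conv3d_weight_to_linear(self):
        # Idempotent: after the first call the Conv3d is dropped, so a
        # second call (e.g. on repeated weight loading) is a no-op.
        if self.proj is None:
            return
        with torch.no_grad():
            self.linear.weight.copy_(self.proj.weight.view(self.linear.out_features, -1))
            self.linear.bias.copy_(self.proj.bias)
        self.proj = None  # free the Conv3d parameters

    def forward(self, x):  # x: (num_patches, in_ch * t * p * p)
        return self.linear(x)

m = PatchEmbed()
m.copy_conv3d_weight_to_linear()
m.copy_conv3d_weight_to_linear()  # safe second call
out = m(torch.randn(4, 3 * 2 * 16 * 16))
print(out.shape)  # torch.Size([4, 64])
```

Guarding on the deleted Conv3d makes repeated weight-loading passes harmless, which is the robustness the review is after.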

@wili-65535 wili-65535 force-pushed the wili/VisionPathchEmbed branch 2 times, most recently from ee2b18a to 528ed81 Compare March 6, 2026 08:32
@samuellees
Contributor

Could you please provide accuracy result?

zhou9402 pushed a commit to sgl-project/sglang-omni that referenced this pull request Mar 6, 2026
Conv3d with kernel_size == stride does not slide over the input, making
it equivalent to a reshape + linear projection. Replace it with
nn.Linear after loading HF checkpoint weights for significantly faster
patch embedding, especially on older cuDNN versions where Conv3d with
large non-sliding kernels is extremely slow.

Safety checks ensure the optimization only applies when kernel_size ==
stride, padding == 0, dilation == 1, and groups == 1.

Reference: sgl-project/sglang#19788
Made-with: Cursor
FrankLeeeee pushed a commit to sgl-project/sglang-omni that referenced this pull request Mar 6, 2026
…up (#132)


Co-authored-by: Hongli Mi <hmi@nvidia.com>
@wili-65535 wili-65535 force-pushed the wili/VisionPathchEmbed branch 2 times, most recently from 4b3a6f9 to 241bb50 Compare March 6, 2026 19:11
@wili-65535
Contributor Author

Could you please provide accuracy result?

Sure, we added a unit test test/srt/test_conv3d_to_linear_unittest.py for this replacement and used lmms_eval to compare scores between the main branch and this PR.
See the results in the description.

@yuan-luo
Collaborator

yuan-luo commented Mar 7, 2026

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Mar 7, 2026
@yuan-luo
Collaborator

yuan-luo commented Mar 7, 2026

/rerun-failed-ci

@wili-65535 wili-65535 force-pushed the wili/VisionPathchEmbed branch from 241bb50 to c0e9724 Compare March 8, 2026 13:50
@samuellees
Contributor

/rerun-failed-ci

@wili-65535
Contributor Author

wili-65535 commented Mar 9, 2026

Hi @yuan-luo , I see your PR20033 for GLM4v has been merged. Cheers!
I've adjusted this PR to align with yours. Now I wonder:

  • How can we move this PR forward? Has our CI pipeline resumed working?
  • Are the accuracy and performance results in the description sufficient for this PR?
  • Do we still need the unit test TestConv3dToLinear besides your patch_embed_linear_conv3d?

Oh, I see the CI is rerunning, thank you...

It seems the CI failures are not related to our code change?

registered/metrics/test_priority_metrics.py (exit code 1)
registered/lora/test_multi_lora_backend.py (exit code 1)
sglang/multimodal_gen/test/server/ascend/test_server_2_npu.py::TestDiffusionServerTwoNpu::test_diffusion_generation[flux_2_image_t2i_2npu]
...

@yuan-luo
Collaborator

yuan-luo commented Mar 9, 2026

Hi @yuan-luo , I see your PR20033 for GLM4v has been merged. Cheers! I've adjusted this PR to align with yours. Now I wonder:

  • How can we move this PR forward? Has our CI pipeline resumed working?
  • Are the accuracy and performance results in the description sufficient for this PR?
  • Do we still need the unit test TestConv3dToLinear besides your patch_embed_linear_conv3d?

Oh, I see the CI is rerunning, thank you...

@wili-65535 We can probably keep this unit test, since it is related to Qwen3VLMoE. BTW, can we move the test to the vlm folder?

Let's wait for the CI passed.

@yuan-luo
Collaborator

yuan-luo commented Mar 9, 2026

This CI failure may need special attention:
registered/vlm/test_vision_openai_server_a.py (exit code 1)

@yuan-luo
Collaborator

yuan-luo commented Mar 9, 2026

/rerun-failed-ci

@wili-65535 wili-65535 force-pushed the wili/VisionPathchEmbed branch 3 times, most recently from 6bc6b65 to 27612f4 Compare March 9, 2026 14:25
@wili-65535
Contributor Author

wili-65535 commented Mar 10, 2026

Here is also some error information reproduced stably in the CI pipeline.

registered/lora/test_multi_lora_backend.py

KeyError: '/loky-7350-k06btve4'

xpu/test_intel_xpu_backend.py

AttributeError: module 'torch.xpu' has no attribute 'graph_pool_handle'

In addition, the output of the Qwen3VL model in test registered/vlm/test_vision_openai_server_a.py::TestQwen3VLServer is completely garbled, but it has no problem in my local tests; I will try to find the reason.

@samuellees
Contributor

/rerun-failed-ci

@wili-65535 wili-65535 force-pushed the wili/VisionPathchEmbed branch from 27612f4 to fc1f1c2 Compare March 10, 2026 07:04
@wili-65535
Contributor Author

wili-65535 commented Mar 10, 2026

OK, the error in test registered/vlm/test_vision_openai_server_a.py::TestQwen3VLServer is resolved.
The root cause is that the change must be applied in both qwen3_vl.py and qwen3_vl_moe.py.
Now only errors unrelated to this PR should remain.

@yhyang201
Collaborator

yhyang201 commented Mar 10, 2026

The latest CI results seem to indicate that this change may have introduced an accuracy issue. Could you please take a look when you have a chance? Many thanks!
https://github.com/sgl-project/sglang/actions/runs/22891107783/job/66415614526?pr=19788

@yhyang201
Collaborator

I opened #20282 which builds on your idea and generalizes it into a unified Conv2dLayer/Conv3dLayer abstraction layer (sglang/srt/layers/conv.py). Conv3dLayer enables unfold+linear by default, which fixes the CuDNN < 9.15 compatibility issue and also gives a nice speedup for patch embeddings. It covers all 3 Conv3d models and 12 Conv2d patch embedding models, and removes the global CuDNN compatibility check from server_args.py.

What do you think about this approach? Would appreciate it if you could help review when you have a chance. Thanks!
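The unfold+linear approach that #20282 generalizes can be sketched for the 2-D case as follows; sizes are illustrative, and this is only a sketch of the idea behind the Conv2dLayer/Conv3dLayer abstraction in sglang/srt/layers/conv.py, not its actual code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# A Conv2d patch embed with kernel_size == stride (non-overlapping patches).
p, c, d = 4, 3, 8
conv = nn.Conv2d(c, d, kernel_size=p, stride=p)
x = torch.randn(2, c, 12, 12)

# F.unfold extracts each p x p patch as one column: (N, c*p*p, L) with L patches.
cols = F.unfold(x, kernel_size=p, stride=p)

# A single matmul over the flattened kernel replaces the convolution.
y_lin = torch.einsum("oc,ncl->nol", conv.weight.view(d, -1), cols) + conv.bias.view(1, d, 1)
y_conv = conv(x).flatten(2)  # (N, d, L), same layout for comparison
print(torch.allclose(y_conv, y_lin, atol=1e-4))  # True (up to float noise)
```

Because the kernel never overlaps itself, unfold+matmul computes exactly the same result while avoiding the slow Conv paths on older cuDNN versions.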

@wili-65535
Contributor Author

I opened #20282 which builds on your idea and generalizes it into a unified Conv2dLayer/Conv3dLayer abstraction layer (sglang/srt/layers/conv.py). Conv3dLayer enables unfold+linear by default, which fixes the CuDNN < 9.15 compatibility issue and also gives a nice speedup for patch embeddings. It covers all 3 Conv3d models and 12 Conv2d patch embedding models, and removes the global CuDNN compatibility check from server_args.py.

What do you think about this approach? Would appreciate it if you could help review when you have a chance. Thanks!

Great idea! I had actually considered a common implementation for models with similar sub-modules, but got blocked at Qwen3VL due to my limited familiarity with sglang layers and other models. Really appreciate you picking this up and making it more general. I'll review the PR and let you know if I have any feedback.

@wili-65535 wili-65535 force-pushed the wili/VisionPathchEmbed branch from fc1f1c2 to 5f0da01 Compare March 10, 2026 22:45
@github-actions github-actions bot added the Multi-modal multi-modal language model label Mar 10, 2026
@wili-65535
Contributor Author

The latest CI results seem to indicate that this change may have introduced an accuracy issue. Could you please take a look when you have a chance? Many thanks! https://github.com/sgl-project/sglang/actions/runs/22891107783/job/66415614526?pr=19788

Fixed; the root cause is that Qwen3Omni needs to be adjusted as well.

@wili-65535 wili-65535 force-pushed the wili/VisionPathchEmbed branch 2 times, most recently from bd9ef2b to bf19d8e Compare March 10, 2026 23:43
@samuellees
Contributor

/rerun-failed-ci

2 similar comments
@samuellees
Contributor

/rerun-failed-ci

@samuellees
Contributor

/rerun-failed-ci

@wili-65535
Contributor Author

Hi @yhyang201 @samuellees
I updated the GPQA test results in the description, and the remaining CI errors now appear unrelated to this PR (shown below).
Could we move the PR forward?

ImportError: cannot import name 'intel' from 'triton._C.libtriton' (/home/sdp/miniforge3/envs/py3.10/lib/python3.10/site-packages/triton/_C/libtriton.so)

filename='registered/ep/test_deepep_small.py', elapsed=1200, estimated_time=531.0
Error: The action 'Run test' has timed out after 20 minutes.

@yhyang201
Collaborator

Hello, this PR (#20282) has already been merged. Could you please help check its performance and quality? Thank you!

v1.0: align to PR20033

v1.1: add unit test

v1.2: fix CI error
@wili-65535 wili-65535 force-pushed the wili/VisionPathchEmbed branch from bf19d8e to 262a4ac Compare March 25, 2026 03:24
@wili-65535 wili-65535 closed this Mar 29, 2026
@wili-65535 wili-65535 deleted the wili/VisionPathchEmbed branch March 30, 2026 02:21

Labels

Multi-modal multi-modal language model run-ci
