Skip to content

vlm: remove redundant d2h movement of mm feature tensors#9987

Merged
JustinTong0323 merged 10 commits intosgl-project:mainfrom
AlienKevin:fast-hash-feature
Sep 17, 2025
Merged

vlm: remove redundant d2h movement of mm feature tensors#9987
JustinTong0323 merged 10 commits intosgl-project:mainfrom
AlienKevin:fast-hash-feature

Conversation

@AlienKevin
Copy link
Contributor

Motivation

I found that surprisingly discovery is that the optimized gpu_tensor_hash is not triggered at all and the slow SHA256 data_hash is still used for hashing image requests. The reason was a recent change from a month ago that moved pixel_values to CPU to the GPU to prevent a memory leak. Now that the memory leak is identified and fixed, we can remove this redundant movement and use the fast GPU image hashing again!

Benchmarking and Profiling

After this PR, 7.5% boost in throughput on MMMU (Math) and 5.2% boost on fixed ISL1000/OSL1 benchmarks.

Details:
Hashing on CPU with SHA256 (Current behavior):
ISL1000/OSL1, max-concurrency=256 (best of 3 trials): 13483 tok/s

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf       
Max request concurrency:                 256       
Successful requests:                     768       
Benchmark duration (s):                  58.27     
Total input tokens:                      784933    
Total input text tokens:                 15397     
Total input vision tokens:               769536    
Total generated tokens:                  768       
Total generated tokens (retokenized):    768       
Request throughput (req/s):              13.18     
Input token throughput (tok/s):          13469.78  
Output token throughput (tok/s):         13.18     
Total token throughput (tok/s):          13482.96  
Concurrency:                             224.41     
==================================================

MMMU (Math): 10403 tok/s

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf       
Max request concurrency:                 256       
Successful requests:                     505       
Benchmark duration (s):                  20.66     
Total input tokens:                      214424    
Total input text tokens:                 44483     
Total input vision tokens:               169941    
Total generated tokens:                  505       
Total generated tokens (retokenized):    505       
Request throughput (req/s):              24.44     
Input token throughput (tok/s):          10378.11  
Output token throughput (tok/s):         24.44     
Total token throughput (tok/s):          10402.55  
Concurrency:                             218.94

Hashing on GPU with custom kernel:
ISL1000/OSL1, max-concurrency=128: 14525 tok/s

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf       
Max request concurrency:                 128       
Successful requests:                     768       
Benchmark duration (s):                  54.09     
Total input tokens:                      784923    
Total input text tokens:                 15386     
Total input vision tokens:               769537    
Total generated tokens:                  768       
Total generated tokens (retokenized):    768       
Request throughput (req/s):              14.20     
Input token throughput (tok/s):          14510.63  
Output token throughput (tok/s):         14.20     
Total token throughput (tok/s):          14524.82  
Concurrency:                             119.99     
==================================================

ISL1000/OSL1, max-concurrency=256 (best of 3 trials): 14190 tok/s

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf       
Max request concurrency:                 256       
Successful requests:                     768       
Benchmark duration (s):                  55.37     
Total input tokens:                      784918    
Total input text tokens:                 15382     
Total input vision tokens:               769536    
Total generated tokens:                  768       
Total generated tokens (retokenized):    768       
Request throughput (req/s):              13.87     
Input token throughput (tok/s):          14175.75  
Output token throughput (tok/s):         13.87     
Total token throughput (tok/s):          14189.62  
Concurrency:                             222.80        
==================================================

ISL1000/OSL1, max-concurrency=512: 13123 tok/s

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf       
Max request concurrency:                 512       
Successful requests:                     768       
Benchmark duration (s):                  59.87     
Total input tokens:                      784904    
Total input text tokens:                 15366     
Total input vision tokens:               769538    
Total generated tokens:                  768       
Total generated tokens (retokenized):    768       
Request throughput (req/s):              12.83     
Input token throughput (tok/s):          13109.67  
Output token throughput (tok/s):         12.83     
Total token throughput (tok/s):          13122.50  
Concurrency:                             384.82         
==================================================

ISL1000/OSL1, max-concurrency=768
OOM

MMMU (Math): 11184 tok/s

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf       
Max request concurrency:                 256       
Successful requests:                     505       
Benchmark duration (s):                  19.22     
Total input tokens:                      214424    
Total input text tokens:                 44483     
Total input vision tokens:               169941    
Total generated tokens:                  505       
Total generated tokens (retokenized):    505       
Request throughput (req/s):              26.28     
Input token throughput (tok/s):          11157.85  
Output token throughput (tok/s):         26.28     
Total token throughput (tok/s):          11184.13  
Concurrency:                             216.03        
==================================================

Measured on my bench_serving enhancement PR (synced with latest main):

Server command:

SGLANG_VLM_CACHE_SIZE_MB=0 python -m sglang.launch_server \
    --model-path Qwen/Qwen2.5-VL-7B-Instruct \
    --mem-fraction-static 0.8 \
    --chat-template 'qwen2-vl' \
    --tp 1 \
    --disable-radix-cache \
    --cuda-graph-bs 256 \
    --cuda-graph-max-bs 256 \
    --chunked-prefill-size 8192 \
    --max-prefill-tokens 8192 \
    --max-running-requests 256 \
    --enable-multimodal

Client command:

python3 -m sglang.bench_serving \
    --backend sglang-oai-chat \
    --dataset-name image \
    --num-prompts 768 \
    --apply-chat-template \
    --random-output-len 1 \
    --random-input-len 1 \
    --image-resolution 1120x700 \
    --image-format jpeg \
    --image-count 1 \
    --image-content random \
    --random-range-ratio 1 \
    --max-concurrency X

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @AlienKevin, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request optimizes the processing of image feature tensors by re-enabling GPU-based hashing. This was made possible by a recent fix for a memory leak that previously necessitated a workaround involving CPU-based hashing. The change eliminates an unnecessary data transfer step, resulting in improved overall system throughput for multimodal benchmarks.

Highlights

  • Performance Improvement: Re-enabled the optimized GPU tensor hashing for image requests, which was previously bypassed, leading to significant throughput gains.
  • Code Optimization: Removed redundant device-to-host (GPU to CPU) movement of image feature tensors, streamlining the data processing pipeline.
  • Impact: Achieved a 7.5% boost in throughput on MMMU (Math) and a 5.2% boost on fixed ISL1000/OSL1 benchmarks by leveraging faster GPU operations.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request removes a redundant device-to-host data transfer for image feature tensors within the process_mm_data method. The change is well-justified, as it eliminates a workaround previously implemented to mitigate a memory leak that has since been resolved. By ensuring feature tensors remain on the GPU, this modification successfully re-enables a more performant GPU-based hashing function, which is validated by the significant throughput improvements shown in the benchmarks. The code is cleaner and more efficient as a result. This is an excellent and well-documented improvement.

@AlienKevin
Copy link
Contributor Author

AlienKevin commented Sep 3, 2025

@JustinTong0323 Verified no memory leak with 7,000 requests under max-concurrency of 256. VRAM usage stays consistently around 92% throughout.

Screenshot 2025-09-03 at 4 16 12 PM
============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf       
Max request concurrency:                 256       
Successful requests:                     7000      
Benchmark duration (s):                  438.06    
Total input tokens:                      7154145   
Total input text tokens:                 140143    
Total input vision tokens:               7014002   
Total generated tokens:                  7000      
Total generated tokens (retokenized):    7000      
Request throughput (req/s):              15.98     
Input token throughput (tok/s):          16331.59  
Output token throughput (tok/s):         15.98     
Total token throughput (tok/s):          16347.57  
Concurrency:                             251.51    
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   15739.46  
Median E2E Latency (ms):                 15532.16  
---------------Time to First Token----------------
Mean TTFT (ms):                          15739.42  
Median TTFT (ms):                        15532.11  
P99 TTFT (ms):                           22711.53  

Before this PR:
Screenshot 2025-09-03 at 5 03 49 PM

Server cmd:

python3 -m sglang.launch_server \
    --model-path Qwen/Qwen2.5-VL-7B-Instruct \
    --mem-fraction-static 0.8 \
    --tp 1 \
    --disable-radix-cache \
    --cuda-graph-bs 256 \
    --cuda-graph-max-bs 256 \
    --chunked-prefill-size 8192 \
    --max-prefill-tokens 8192 \
    --max-running-requests 256

Client cmd:

python3 -m sglang.bench_serving \
    --backend sglang-oai-chat \
    --dataset-name image \
    --num-prompts 7000 \
    --apply-chat-template \
    --random-output-len 1 \
    --random-input-len 1 \
    --image-resolution 1120x700 \
    --image-format jpeg \
    --image-count 1 \
    --image-content random \
    --random-range-ratio 1 \
    --max-concurrency 256

@Swipe4057
Copy link
Contributor

AlienKevin Hi! Very cool mr, please can you also test with --enable-mixed-chunk?

@mickqian mickqian changed the title Remove redundant device to host movement of image feature tensors vlm: remove redundant d2h movement of image feature tensors Sep 4, 2025
Copy link
Collaborator

@mickqian mickqian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, and cc @yhyang201 as the original author of this code

@mickqian mickqian changed the title vlm: remove redundant d2h movement of image feature tensors vlm: remove redundant d2h movement of mm feature tensors Sep 4, 2025
@yhyang201
Copy link
Collaborator

I’m curious—if TP is enabled now, how are the various feature tensors transferred from GPU0 to the other GPUs?

@AlienKevin
Copy link
Contributor Author

AlienKevin commented Sep 4, 2025

@yhyang201 thanks for bringing this up.

With TP=2, rank 0 memory usage increased from 92% to 96% and stablized. Finished successfully without issue.

Screenshot 2025-09-04 at 1 48 26 PM

With TP=4, an OOM occurred:

OOM trace
[2025-09-04 20:13:50] INFO:     127.0.0.1:39038 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:13:50] INFO:     127.0.0.1:39050 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:13:58 TP0] Prefill batch. #new-seq: 9, #new-token: 8192, #cached-token: 0, token usage: 0.00, #running-req: 18, #queue-req: 239, 
[2025-09-04 20:14:06 TP3] Scheduler hit an exception: Traceback (most recent call last):
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/scheduler.py", line 2672, in run_scheduler_process
    scheduler.event_loop_overlap()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/scheduler.py", line 811, in event_loop_overlap
    recv_reqs = self.recv_requests()
                ^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/scheduler.py", line 1074, in recv_requests
    recv_reqs = broadcast_pyobj(
                ^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/utils.py", line 1059, in broadcast_pyobj
    data = pickle.loads(serialized_data)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/storage.py", line 534, in _load_from_bytes
    return torch.load(io.BytesIO(b), weights_only=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 1554, in load
    return _legacy_load(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 1812, in _legacy_load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 1747, in persistent_load
    obj = restore_location(obj, location)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 698, in default_restore_location
    result = fn(storage, location)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/serialization.py", line 637, in _deserialize
    return obj.to(device=device)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/storage.py", line 291, in to
    return _to(self, device, non_blocking)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_utils.py", line 101, in _to
    untyped_storage = torch.UntypedStorage(self.size(), device=device)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 18.00 MiB. GPU 0 has a total capacity of 139.80 GiB of which 6.00 MiB is free. Process 550224 has 5.17 GiB memory in use. Process 550531 has 4.87 GiB memory in use. Process 550533 has 4.81 GiB memory in use. Process 550530 has 120.04 GiB memory in use. Process 550532 has 4.87 GiB memory in use. Of the allocated memory 4.31 GiB is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

[2025-09-04 20:14:06] Received sigquit from a child process. It usually means the child failed.
./qwen_server.sh: line 13: 144010 Killed                  python -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --mem-fraction-static 0.8 --tp 4 --disable-radix-cache --cuda-graph-bs 256 --cuda-graph-max-bs 256 --chunked-prefill-size 8192 --max-prefill-tokens 8192 --max-running-requests 256

Reducing max_concurrency from 256 to 128 prevented OOM and memory usage stays at 93%:

Screenshot 2025-09-04 at 2 17 14 PM

@AlienKevin
Copy link
Contributor Author

AlienKevin commented Sep 4, 2025

@Swipe4057 I found there's a dynamo issue when running with --enable-mixed-chunk even without this PR.

dynamo error
user@sglang-bench:/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang$ ./qwen_server.sh
All deep_gemm operations loaded successfully!
W0904 20:24:16.581000 165808 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
W0904 20:24:16.581000 165808 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-09-04 20:24:17] server_args=ServerArgs(model_path='Qwen/Qwen2.5-VL-7B-Instruct', tokenizer_path='Qwen/Qwen2.5-VL-7B-Instruct', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=30000, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.8, max_running_requests=256, max_queued_requests=9223372036854775807, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=8192, schedule_policy='fcfs', schedule_conservativeness=1.0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=912179317, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, gc_warning_threshold_secs=0.0, api_key=None, served_model_name='Qwen/Qwen2.5-VL-7B-Instruct', weight_version='default', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', disable_radix_cache=True, cuda_graph_max_bs=256, cuda_graph_bs=[256], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=False, enable_mixed_chunk=True, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, torch_compile_max_bs=32, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, offload_mm_feature=False, enable_return_hidden_states=False, scheduler_recv_interval=1, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, pdlb_url=None, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3, enable_ep_moe=False, enable_deepep_moe=False, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_triton_kernel_moe=False, enable_flashinfer_mxfp4_moe=False)
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-04 20:24:17] MOE_RUNNER_BACKEND is not initialized, using triton backend
[2025-09-04 20:24:19] Using default HuggingFace chat template with detected content format: openai
All deep_gemm operations loaded successfully!
W0904 20:24:22.469000 166083 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
W0904 20:24:22.469000 166083 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
All deep_gemm operations loaded successfully!
W0904 20:24:22.933000 166084 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
W0904 20:24:22.933000 166084 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-04 20:24:24] Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-09-04 20:24:24] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-09-04 20:24:25] Init torch distributed ends. mem usage=0.00 GB
[2025-09-04 20:24:25] MOE_RUNNER_BACKEND is not initialized, using triton backend
[2025-09-04 20:24:25] Load weight begin. avail mem=139.28 GB
[2025-09-04 20:24:25] Multimodal attention backend not set. Use fa3.
[2025-09-04 20:24:25] Using fa3 as multimodal attention backend.
[2025-09-04 20:24:25] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  20% Completed | 1/5 [00:00<00:02,  1.84it/s]
Loading safetensors checkpoint shards:  40% Completed | 2/5 [00:01<00:01,  1.82it/s]
Loading safetensors checkpoint shards:  60% Completed | 3/5 [00:01<00:01,  1.80it/s]
Loading safetensors checkpoint shards:  80% Completed | 4/5 [00:02<00:00,  1.81it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:02<00:00,  2.39it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:02<00:00,  2.09it/s]

Enable PytHooks!

[2025-09-04 20:24:28] Load weight end. type=Qwen2_5_VLForConditionalGeneration, dtype=torch.bfloat16, avail mem=123.49 GB, mem usage=15.79 GB.
[2025-09-04 20:24:28] KV Cache is allocated. #tokens: 1790758, K size: 47.82 GB, V size: 47.82 GB
[2025-09-04 20:24:28] Memory pool end. avail mem=27.59 GB
[2025-09-04 20:24:28] Capture cuda graph begin. This can take up to several minutes. avail mem=27.01 GB
[2025-09-04 20:24:28] Capture cuda graph bs [256]
Capturing batches (bs=256 avail_mem=26.71 GB): 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.34s/it]
[2025-09-04 20:24:31] Capture cuda graph end. Time elapsed: 2.53 s. mem usage=0.50 GB. avail mem=26.51 GB.
[2025-09-04 20:24:33] max_total_num_tokens=1790758, chunked_prefill_size=8192, max_prefill_tokens=8192, max_running_requests=256, context_len=128000, available_gpu_mem=26.51 GB
[2025-09-04 20:24:33] INFO:     Started server process [165808]
[2025-09-04 20:24:33] INFO:     Waiting for application startup.
[2025-09-04 20:24:33] INFO:     Application startup complete.
[2025-09-04 20:24:33] INFO:     Uvicorn running on http://127.0.0.1:30000 (Press CTRL+C to quit)
[2025-09-04 20:24:34] INFO:     127.0.0.1:49828 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-09-04 20:24:34] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-09-04 20:24:36] INFO:     127.0.0.1:49838 - "POST /generate HTTP/1.1" 200 OK
[2025-09-04 20:24:36] The server is fired up and ready to roll!
[2025-09-04 20:24:42] INFO:     127.0.0.1:43182 - "GET /v1/models HTTP/1.1" 200 OK
[2025-09-04 20:25:51] INFO:     127.0.0.1:36234 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:52] Prefill batch. #new-seq: 1, #new-token: 1024, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-09-04 20:25:54] INFO:     127.0.0.1:34174 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34186 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34202 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34204 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34214 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34226 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34240 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34250 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34256 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] Prefill batch. #new-seq: 1, #new-token: 1022, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-09-04 20:25:54] Prefill batch. #new-seq: 1, #new-token: 1022, #cached-token: 0, token usage: 0.00, #running-req: 1, #queue-req: 0, 
[2025-09-04 20:25:54] Prefill batch. #new-seq: 4, #new-token: 4088, #cached-token: 0, token usage: 0.00, #running-req: 1, #queue-req: 0, 
[2025-09-04 20:25:54] INFO:     127.0.0.1:34498 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34508 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34522 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34528 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34544 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34554 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34556 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34570 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34578 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34582 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34588 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34604 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34610 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:54] INFO:     127.0.0.1:34614 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:55] INFO:     127.0.0.1:34628 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2] failed while attempting to run meta for aten.view.default
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2] Traceback (most recent call last):
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2717, in _dispatch_impl
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]     r = func(*args, **kwargs)
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]         ^^^^^^^^^^^^^^^^^^^^^
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 829, in __call__
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]     return self._op(*args, **kwargs)
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/_meta_registrations.py", line 364, in _view_meta
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]     return torch._refs._reshape_view_helper(a, *shape, allow_copy=False)
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/_refs/__init__.py", line 3823, in _reshape_view_helper
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]     shape = utils.infer_size(shape, a.numel())
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/__init__.py", line 1018, in infer_size
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]     torch._check(
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1684, in _check
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]     _check_with(RuntimeError, cond, message)
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]   File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1666, in _check_with
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2]     raise error_type(message_evaluated)
[rank0]:E0904 20:25:55.048000 166083 torch/_subclasses/fake_tensor.py:2721] [0/2] RuntimeError: shape '[s18, -1, 128]' is invalid for input of size s5*s83
[2025-09-04 20:25:55] TpModelWorkerClient hit an exception: Traceback (most recent call last):
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 141, in forward_thread_func
    self.forward_thread_func_()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 176, in forward_thread_func_
    self.worker.forward_batch_generation(
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/tp_worker.py", line 239, in forward_batch_generation
    logits_output, can_run_cuda_graph = self.model_runner.forward(
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/model_executor/model_runner.py", line 1753, in forward
    output = self._forward_raw(
             ^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/model_executor/model_runner.py", line 1798, in _forward_raw
    ret = self.forward_extend(
          ^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/model_executor/model_runner.py", line 1698, in forward_extend
    return self.model.forward(
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/models/qwen2_5_vl.py", line 582, in forward
    hidden_states = general_mm_embed_routine(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/mm_utils.py", line 663, in general_mm_embed_routine
    hidden_states = language_model(
                    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/models/qwen2.py", line 340, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/models/qwen2.py", line 244, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/models/qwen2.py", line 182, in forward
    q, k = self.rotary_emb(positions, q, k)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 736, in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1495, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1272, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 629, in __call__
    return _compile(
           ^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1111, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 793, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 832, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 267, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 753, in transform
    tracer.run()
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 3497, in run
    super().run()
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
    while self.step():
          ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 834, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2910, in CALL
    self._call(inst)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2904, in _call
    self.call_function(fn, args, kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1193, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/misc.py", line 1111, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/tensor.py", line 712, in call_method
    return wrap_fx_proxy(
           ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2559, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2625, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
           ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2723, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3355, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3253, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 2753, in wrap_fake_exception
    return fn()
           ^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3254, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3462, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3432, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1352, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2058, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1487, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2717, in _dispatch_impl
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_meta_registrations.py", line 364, in _view_meta
    return torch._refs._reshape_view_helper(a, *shape, allow_copy=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_refs/__init__.py", line 3823, in _reshape_view_helper
    shape = utils.infer_size(shape, a.numel())
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/__init__.py", line 1018, in infer_size
    torch._check(
  File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1684, in _check
    _check_with(RuntimeError, cond, message)
  File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1666, in _check_with
    raise error_type(message_evaluated)
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_method view(*(FakeTensor(..., device='cuda:0', size=(s5, s83), dtype=torch.bfloat16), s18, -1, 128), **{}): got RuntimeError("shape '[s18, -1, 128]' is invalid for input of size s5*s83")

from user code:
   File "/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang/python/sglang/srt/layers/rotary_embedding.py", line 1066, in forward
    query = query.view(num_tokens, -1, self.head_size)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


[2025-09-04 20:25:55] INFO:     127.0.0.1:34122 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-09-04 20:25:55] Received sigquit from a child process. It usually means the child failed.
./qwen_server.sh: line 14: 165808 Killed                  python -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --mem-fraction-static 0.8 --tp 1 --disable-radix-cache --cuda-graph-bs 256 --cuda-graph-max-bs 256 --chunked-prefill-size 8192 --max-prefill-tokens 8192 --max-running-requests 256 --enable-mixed-chunk
user@sglang-bench:/home/scratch.lik_gpu/benchmarking_toolkit/my_sglang$

@Swipe4057
Copy link
Contributor

AlienKevin got it, thank you very much! I think this flag was without errors with vlm before,

@yhyang201
Copy link
Collaborator

If we remove the GPU-to-CPU logic, it might reintroduce the issue described in #5732. For now, I haven’t found a way to address this issue without potentially impacting performance.

@AlienKevin
Copy link
Contributor Author

@yhyang201 I'm not sure if there's any memory leak like in #5732, it seems that when max_concurrency set appropriately, memory usage is stable across all devices as shown in my nvitop screenshots above. To support GPUs with less memory, I've added a --offload-mm-feature flag to offload the pixel_values to CPU if needed as suggested by @JustinTong0323. Let me know what you think.

@yhyang201
Copy link
Collaborator

I think adding the --offload-mm-feature flag is indeed a great idea.
LGTM—thank you!

@JustinTong0323
Copy link
Collaborator

AlienKevin got it, thank you very much! I think this flag was without errors with vlm before,

@AlienKevin kevin, could you help raise a issue with that? Thanks!

disable_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
offload_mm_feature: bool = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As it would easily cause OOM, could we default it to true? Or just use flag like "load_mm_feature_to_gpu".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll update the flag

@AlienKevin
Copy link
Contributor Author

@JustinTong0323 Renamed offload_mm_feature to keep_mm_feature_on_device and defaults to False as requested.

@AlienKevin
Copy link
Contributor Author

AlienKevin got it, thank you very much! I think this flag was without errors with vlm before,

@AlienKevin kevin, could you help raise a issue with that? Thanks!

Issue filed on --enable-mixed-chunk: #10179

@AlienKevin
Copy link
Contributor Author

@JustinTong0323 Just following up, is this PR ready to be merged?

@JustinTong0323 JustinTong0323 merged commit de28f8e into sgl-project:main Sep 17, 2025
93 of 107 checks passed
@mickqian
Copy link
Collaborator

Better make this arg default in the future

HanHan009527 pushed a commit to HanHan009527/sglang that referenced this pull request Oct 9, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants