
model: support OpenAI Whisper#8064

Closed
MahmoudAshraf97 wants to merge 18 commits into sgl-project:main from MahmoudAshraf97:whisper

Conversation

Contributor

@MahmoudAshraf97 MahmoudAshraf97 commented Jul 15, 2025

Motivation

This PR aims to implement the Whisper model for speech-to-text (STT).

Modifications

Audio modality is already supported by many models, but most of them use an encoder projection to convert audio into audio tokens that the language model can consume. Whisper instead uses encoder-decoder cross-attention; since the attention backends already dispatch on cross-attention, I assume it is supported.
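To make the distinction concrete, here is a minimal single-head cross-attention sketch in NumPy: unlike self-attention, the decoder queries attend over a fixed set of encoder states (1500 frames for Whisper's 30 s window). All names here are illustrative, not SGLang APIs.

```python
import numpy as np

def cross_attention(q, enc_k, enc_v):
    """Single-head cross-attention: decoder queries attend over the
    (fixed) encoder states instead of previous decoder tokens."""
    d = q.shape[-1]
    scores = q @ enc_k.T / np.sqrt(d)  # (n_dec, n_enc)
    # Numerically stable softmax over the encoder axis
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ enc_v  # (n_dec, d)

dec_q = np.random.randn(4, 64)      # 4 decoder positions
enc_kv = np.random.randn(1500, 64)  # Whisper's 1500 encoder frames
out = cross_attention(dec_q, enc_kv, enc_kv)
print(out.shape)  # (4, 64)
```

The key property for caching: `enc_k`/`enc_v` never change during decoding, so they only need to be computed (or stored) once per request.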

I still need help getting this completed, because the process gets killed without any error output, which makes it hard to debug.
Edit: the script was being killed at the resampling stage; using a 16 kHz audio file avoids that for now.
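As a workaround for the resampling crash, audio can be converted to mono 16 kHz offline before handing it to the engine. This is a rough NumPy-only sketch (linear interpolation, no anti-aliasing filter); the function name is hypothetical, and a real pipeline would use `ffmpeg`, `torchaudio`, or `librosa` instead.

```python
import numpy as np

def to_mono_16k(samples: np.ndarray, sr: int, target_sr: int = 16_000) -> np.ndarray:
    """Downmix to mono and linearly resample to target_sr (crude sketch)."""
    if samples.ndim == 2:  # (frames, channels) -> average channels to mono
        samples = samples.mean(axis=1)
    if sr == target_sr:
        return samples
    n_out = int(round(len(samples) * target_sr / sr))
    # Interpolate onto the new sample grid
    old_t = np.arange(len(samples)) / sr
    new_t = np.arange(n_out) / target_sr
    return np.interp(new_t, old_t, samples)

# Example: 1 second of 44.1 kHz stereo becomes 16000 mono samples
audio = np.random.randn(44_100, 2)
mono = to_mono_16k(audio, 44_100)
print(mono.shape)  # (16000,)
```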

import sglang as sgl
import asyncio


async def main():
    engine = sgl.Engine(
        model_path="openai/whisper-tiny",
        disable_cuda_graph=True,
        grammar_backend="none",
        # attention_backend="flashinfer",
        # attention_backend="triton",
        attention_backend="fa3",
        # attention_backend="torch_native",
        disable_radix_cache=True,
    )

    a = engine.async_generate(
        # prompt="<|en|>",pip
        input_ids=[50258, 50259, 50359, 50363],
        audio_data=["/mnt/e/Projects/whisper-diarization/monofile2.wav"], # must be mono 16khz
        sampling_params={"temperature": 0.0, }
    )
    result = await a
    print(result)


# The __main__ guard is necessary here because we use "spawn" to create subprocesses.
# Spawn starts a fresh program every time; without the guard, sgl.Engine would keep
# spawning processes in an infinite loop.
if __name__ == "__main__":
    asyncio.run(main())

Checklist

To Do:

This function is hardcoded to support the current models, but it's not generalizable because each model has its own quirks. I suggest offloading it to each model's implementation, as is done with pad_input_ids:

def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
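The suggested dispatch could look like the following sketch: a default on a base class, overridden per model, mirroring how `pad_input_ids` is delegated. The class names, the `ENCODER_LEN` constant, and the return convention are all hypothetical; only the dispatch pattern is the point.

```python
from typing import List, Optional

class ModelBase:
    # Default: models with no special encoder bookkeeping return nothing
    def prepare_encoder_info_extend(
        self, input_ids: List[int], seq_lens: List[int]
    ) -> Optional[List[int]]:
        return None

class WhisperLike(ModelBase):
    # Whisper's encoder always emits 1500 frames (30 s of audio, 50 frames/s)
    ENCODER_LEN = 1500

    def prepare_encoder_info_extend(self, input_ids, seq_lens):
        # One fixed-length encoder segment per sequence in the batch
        return [self.ENCODER_LEN] * len(seq_lens)

model = WhisperLike()
print(model.prepare_encoder_info_extend([50258], [1]))  # [1500]
```

The runner would then call `model.prepare_encoder_info_extend(...)` instead of hardcoding per-model branches.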

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @MahmoudAshraf97, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the core infrastructure to support OpenAI's Whisper model for Speech-to-Text functionality. It encompasses the model's definition, custom adaptations for its encoder and weight loading, and a specialized processor to handle audio input and prepare it for the framework's multimodal capabilities.

Highlights

  • Whisper Model Integration: I've added the foundational support for integrating OpenAI's Whisper model, specifically WhisperForConditionalGeneration, into the framework. This enables Speech-to-Text (STT) capabilities.
  • Custom Encoder and Weight Loading: The integration includes a custom WhisperForConditionalGeneration class that utilizes a MiniCPMWhisperEncoder for its encoder component. I've also implemented specific weight loading logic to handle parameter mapping for efficient model initialization.
  • Multimodal Audio Processing: A dedicated WhisperProcessor has been introduced to handle audio input. This processor is responsible for loading raw audio data, converting it into the necessary input features, and preparing it as multimodal data items for the SGLang framework.
  • Model Registration: The new WhisperForConditionalGeneration model has been registered within the system's model_config.py to ensure it's recognized as a valid generation model.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This PR introduces support for the OpenAI Whisper model. The changes look like a good start, but there are a few areas that need attention.

In python/sglang/srt/models/whisper.py, there are unused parameters and variables in __init__ and load_weights, which should be cleaned up.

In python/sglang/srt/multimodal/processors/whisper.py, the process_mm_data_async method contains blocking calls for I/O and CPU-intensive work, which will impact server performance. These should be made asynchronous using the provided executors.

Additionally, the is_audio_model function in python/sglang/srt/configs/model_config.py needs to be updated to correctly identify Whisper models.

Addressing these points will improve the correctness and performance of the implementation.

@JustinTong0323
Collaborator

Thanks for the contribution. Could you please add some tests?

@MahmoudAshraf97
Contributor Author

Thanks for the contribution. Could you please add some tests?

I will add them once I'm able to run the model. As I mentioned above, the script crashes with no error, so I can't debug it. I'm not very experienced with SGLang, so I'll need help from the team or other contributors; I can handle all the modeling and inference logic of Whisper, but not the SGL internals.

@JustinTong0323
Collaborator

JustinTong0323 commented Jul 22, 2025

Thanks for the contribution. Could you please add some tests?

I will add them once I'm able to run the model. As I mentioned above, the script crashes with no error, so I can't debug it. I'm not very experienced with SGLang, so I'll need help from the team or other contributors; I can handle all the modeling and inference logic of Whisper, but not the SGL internals.

Got you, @byjiang1996 may you take a look? Thanks!

@mickqian mickqian changed the title from "feat: Implement OpenAI Whisper" to "model: Implement OpenAI Whisper" Jul 23, 2025
@mickqian mickqian changed the title from "model: Implement OpenAI Whisper" to "model: support OpenAI Whisper" Jul 23, 2025
Collaborator

@byjiang1996 byjiang1996 Jul 23, 2025


Could you please follow this class to add a unit test for WhisperForConditionalGeneration to avoid future regressions? We can use a small OpenAI Whisper model in the unit test so it won't take too long to run:

https://github.com/sgl-project/sglang/pull/8048/files#diff-40b31588286beebffb07cbcba6ac68278f65bfab723831352e4084b49d6ab84e

@MahmoudAshraf97
Contributor Author

I reimplemented the model using SGL modules to handle the KV caching, and I'm getting the following error:

[2025-07-25 01:59:56] TpModelWorkerClient hit an exception: Traceback (most recent call last):
  File "/mnt/e/Projects/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 140, in forward_thread_func
    self.forward_thread_func_()
  File "/home/mahmoud/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/e/Projects/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 175, in forward_thread_func_
    self.worker.forward_batch_generation(
  File "/mnt/e/Projects/sglang/python/sglang/srt/managers/tp_worker.py", line 228, in forward_batch_generation
    logits_output, can_run_cuda_graph = self.model_runner.forward(
  File "/mnt/e/Projects/sglang/python/sglang/srt/model_executor/model_runner.py", line 1547, in forward
    output = self._forward_raw(
  File "/mnt/e/Projects/sglang/python/sglang/srt/model_executor/model_runner.py", line 1582, in _forward_raw
    ret = self.forward_extend(
  File "/mnt/e/Projects/sglang/python/sglang/srt/model_executor/model_runner.py", line 1492, in forward_extend
    return self.model.forward(
  File "/mnt/e/Projects/sglang/python/sglang/srt/models/whisper.py", line 410, in forward
    encoder_outputs = self.encoder(features.to(dtype), forward_batch)
  File "/home/mahmoud/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mahmoud/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/Projects/sglang/python/sglang/srt/models/whisper.py", line 276, in forward
    hidden_states = encoder_layer(hidden_states, forward_batch)
  File "/home/mahmoud/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mahmoud/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/Projects/sglang/python/sglang/srt/models/whisper.py", line 134, in forward
    hidden_states = self.self_attn(hidden_states, forward_batch)
  File "/home/mahmoud/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mahmoud/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/Projects/sglang/python/sglang/srt/models/whisper.py", line 92, in forward
    attn_output = self.attn(q, k, v, forward_batch)
  File "/home/mahmoud/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mahmoud/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/Projects/sglang/python/sglang/srt/layers/radix_attention.py", line 100, in forward
    return forward_batch.attn_backend.forward(
  File "/mnt/e/Projects/sglang/python/sglang/srt/layers/attention/base_attn_backend.py", line 79, in forward
    return self.forward_extend(
  File "/mnt/e/Projects/sglang/python/sglang/srt/layers/attention/flashinfer_backend.py", line 475, in forward_extend
    forward_batch.token_to_kv_pool.set_kv_buffer(
  File "/mnt/e/Projects/sglang/python/sglang/srt/mem_cache/memory_pool.py", line 459, in set_kv_buffer
    self.k_buffer[layer_id - self.start_layer][loc] = cache_k
RuntimeError: shape mismatch: value tensor of shape [1500, 6, 64] cannot be broadcast to indexing result of shape [1, 6, 64]

It seems to be a problem with the encoder KV cache: we don't need to store KV values for the encoder during forward_extend; only the final encoder hidden state should be stored for the decoding phase.

@MahmoudAshraf97
Contributor Author

MahmoudAshraf97 commented Jul 25, 2025

What should be done now is:

  1. In prefill (forward_extend), encode the features and store them, without saving a KV cache for the encoder. As far as I know, there is no mechanism for this in SGL: all multimodal models process MM embeddings as tokens, so they end up in the decoder KV cache and the decoding stage no longer needs the encoder embeddings, but Whisper still needs the encoder embeddings at every decoding iteration.
  2. In decoding, use the stored encoder output and decode tokens autoregressively, saving the KV cache for both self-attention and cross-attention.
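Step 1 amounts to a per-request store for encoder hidden states, written once at prefill and read at every decode step. A toy sketch (the class and method names are invented for illustration; SGL's real memory pool would own this):

```python
import numpy as np

class EncoderOutputCache:
    """Toy per-request store for encoder hidden states."""
    def __init__(self):
        self._store = {}

    def put(self, req_id, hidden):
        # Written once at prefill (forward_extend)
        self._store[req_id] = hidden

    def get(self, req_id):
        # Read by cross-attention at every decode step
        return self._store[req_id]

    def free(self, req_id):
        # Released when the request finishes
        self._store.pop(req_id, None)

cache = EncoderOutputCache()
cache.put("req-0", np.zeros((1500, 384)))  # whisper-tiny: 1500 frames x d_model=384
print(cache.get("req-0").shape)  # (1500, 384)
```

Cross-attention KV derived from these states is also fixed per request, so it could be cached once rather than recomputed per step.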

@JustinTong0323

@MahmoudAshraf97
Contributor Author

MahmoudAshraf97 commented Jul 26, 2025

I'm blocked on two things:

  1. I need a way to handle the caching and retrieval of the encoder output.
  2. If the input IDs for a given batch element contain 27 tokens, why does the positions tensor hold a single element in the prefill phase? Or does it not represent the position_ids?
input_ids=tensor([[50258, 50363, 50259, ... , 50257]], device='cuda:0'), shape=torch.Size([1, 27])
seq_lens=tensor([1], device='cuda:0')
out_cache_loc=tensor([1], device='cuda:0')
seq_lens_sum=1
positions=tensor([0], device='cuda:0')

EDIT: point 2 was solved by removing the batch dimension from the multimodal processor outputs.
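The fix for point 2 can be sketched as follows: HF-style feature extractors return a leading batch axis of size 1, while the scheduler expects per-item tensors, so the processor output should be squeezed. The function name and shapes here are illustrative (80 mel bins x 3000 frames is Whisper's standard log-mel layout).

```python
import numpy as np

def strip_batch_dim(features: np.ndarray) -> np.ndarray:
    """Drop the leading batch axis from a single-item processor output.

    HF feature extractors return (1, n_mels, n_frames); downstream code
    that counts sequence lengths per item expects (n_mels, n_frames).
    """
    if features.ndim == 3 and features.shape[0] == 1:
        features = features[0]
    return features

feats = np.zeros((1, 80, 3000))  # Whisper log-mel features for one 30 s clip
print(strip_batch_dim(feats).shape)  # (80, 3000)
```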

@Guobing-Chen

Any update on this PR? We are very eager to try Whisper with SGLang.

@yanbing-j
Contributor

yanbing-j commented Nov 11, 2025

Hi @MahmoudAshraf97, I tried to run Whisper on a 4090D using this PR branch. The example script in the description appears to be the client script, so I first launch the server with python3 -m sglang.launch_server --model openai/whisper-tiny --disable-overlap-schedule --mem-fraction-static 0.3 --max-total-tokens 63356 --enable-multimodal. It fails because mm_inputs is None, which is the default for text input in sglang.launch_server.

Could you please support the mm_inputs is None scenario as well? If I've made any mistakes, please correct me. Thanks!
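The requested guard might look like this sketch: instead of asserting that `mm_inputs` is present, the forward path would skip the encoder for text-only batches (e.g. the server's warmup request or CUDA-graph capture). All names and the return shape below are invented to illustrate the control flow, not SGLang's actual signatures.

```python
def forward(input_ids, mm_inputs=None):
    """Sketch: the real forward should not `assert mm_inputs is not None`,
    because warmup and CUDA-graph capture send text-only batches."""
    if mm_inputs is None:
        # No audio attached: decoder-only path, skip the encoder entirely
        encoder_out = None
    else:
        # Stand-in for running the Whisper encoder on the audio features
        encoder_out = f"encoded:{mm_inputs}"
    return {"tokens": len(input_ids), "encoder_out": encoder_out}

print(forward([50258, 50259]))            # text-only batch works
print(forward([50258], mm_inputs="wav"))  # audio batch runs the encoder
```

The decoder's cross-attention layers would then need a corresponding no-op (or skip) when `encoder_out` is None.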

$ python3 -m sglang.launch_server --model openai/whisper-tiny --disable-overlap-schedule --disable-radix-cache --log-requests --log-requests-level 3 --mem-fraction-static 0.3 --max-total-tokens 63356 --enable-multimodal
/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
config.json: 1.98kB [00:00, 3.92MB/s]
generation_config.json: 3.75kB [00:00, 8.08MB/s]
[2025-11-11 10:56:51] INFO model_config.py:884: Downcasting torch.float32 to torch.float16.
[2025-11-11 10:56:52] WARNING server_args.py:1183: Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-11-11 10:56:52] INFO trace.py:52: opentelemetry package is not installed, tracing disabled
[2025-11-11 10:56:52] server_args=ServerArgs(model_path='openai/whisper-tiny', tokenizer_path='openai/whisper-tiny', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=True, revision=None, model_impl='auto', host='127.0.0.1', port=30000, grpc_mode=False, skip_server_warmup=False, warmups=None, nccl_port=None, checkpoint_engine_wait_weights_before_ready=False, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', enable_fp32_lm_head=False, modelopt_quant=None, modelopt_checkpoint_restore_path=None, modelopt_checkpoint_save_path=None, modelopt_export_path=None, quantize_and_serve=False, mem_fraction_static=0.3, max_running_requests=None, max_queued_requests=None, max_total_tokens=63356, chunked_prefill_size=2048, max_prefill_tokens=16384, schedule_policy='fcfs', enable_priority_scheduling=False, abort_on_priority_when_disabled=False, schedule_low_priority_values_first=False, priority_scheduling_preemption_threshold=10, schedule_conservativeness=1.0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, radix_eviction_policy='lru', device='cuda', tp_size=1, pp_size=1, pp_max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=885639532, constrained_json_whitespace_pattern=None, constrained_json_disable_any_whitespace=False, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=True, log_requests_level=3, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, tokenizer_metrics_custom_labels_header='x-custom-labels', tokenizer_metrics_allowed_custom_labels=None, bucket_time_to_first_token=None, bucket_inter_token_latency=None, 
bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, gc_warning_threshold_secs=0.0, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, enable_trace=False, otlp_traces_endpoint='localhost:4317', api_key=None, served_model_name='openai/whisper-tiny', weight_version='default', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, sampling_defaults='model', dp_size=1, load_balance_method='round_robin', load_watch_interval=0.1, prefill_round_robin_balance=False, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_eviction_policy='lru', lora_backend='csgmv', max_lora_chunk_size=16, attention_backend='flashinfer', decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, nsa_prefill_backend='flashmla_sparse', nsa_decode_backend='fa3', speculative_algorithm=None, speculative_draft_model_path=None, speculative_draft_model_revision=None, speculative_draft_load_format=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, speculative_attention_mode='prefill', speculative_moe_runner_backend=None, speculative_ngram_min_match_window_size=1, speculative_ngram_max_match_window_size=12, speculative_ngram_min_bfs_breadth=1, speculative_ngram_max_bfs_breadth=10, speculative_ngram_match_type='BFS', speculative_ngram_branch_length=18, speculative_ngram_capacity=10000000, ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', 
flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, elastic_ep_backend=None, mooncake_ib_device=None, max_mamba_cache_size=None, mamba_ssm_dtype='float32', mamba_full_memory_ratio=0.9, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, enable_lmcache=False, kt_weight_path=None, kt_method='AMXINT4', kt_cpuinfer=None, kt_threadpool_count=2, kt_num_gpu_experts=None, kt_max_deferred_experts_per_token=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=True, cuda_graph_max_bs=24, cuda_graph_bs=[1, 2, 4, 8, 12, 16, 24], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_tokenizer_batch_decode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=True, enable_mixed_chunk=False, 
enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, enable_piecewise_cuda_graph=False, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=4096, piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096], piecewise_cuda_graph_compiler='eager', torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, enable_weights_cpu_backup=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, rl_on_policy_target=None, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_layers=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, custom_weight_loader=[], 
weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8, mm_max_concurrent_calls=32, mm_per_request_timeout=10.0, decrypted_config_file=None, decrypted_draft_config_file=None)
[2025-11-11 10:56:52] Downcasting torch.float32 to torch.float16.
/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
preprocessor_config.json: 185kB [00:00, 112MB/s]
tokenizer_config.json: 283kB [00:00, 111MB/s]
vocab.json: 836kB [00:00, 1.08MB/s]
[2025-11-11 10:57:02] INFO trace.py:52: opentelemetry package is not installed, tracing disabled
[2025-11-11 10:57:02] INFO trace.py:52: opentelemetry package is not installed, tracing disabled
tokenizer.json: 2.48MB [00:00, 2.51MB/s]
merges.txt: 494kB [00:00, 1.93MB/s]
[2025-11-11 10:57:04] Downcasting torch.float32 to torch.float16.
normalizer.json: 52.7kB [00:00, 206kB/s]
added_tokens.json: 34.6kB [00:00, 134kB/s]
special_tokens_map.json: 2.19kB [00:00, 5.70MB/s]
[2025-11-11 10:57:10] No chat template found, defaulting to 'string' content format
[2025-11-11 10:57:11] Downcasting torch.float32 to torch.float16.
[2025-11-11 10:57:11] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-11-11 10:57:12] Init torch distributed ends. mem usage=0.00 GB
[2025-11-11 10:57:12] MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected
[2025-11-11 10:57:13] Load weight begin. avail mem=23.17 GB
[2025-11-11 10:57:14] Using model weights format ['*.safetensors']
model.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 151M/151M [00:12<00:00, 12.4MB/s]
[2025-11-11 10:57:28] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 108.21it/s]

[2025-11-11 10:57:29] Load weight end. type=WhisperForConditionalGeneration, dtype=torch.float16, avail mem=23.00 GB, mem usage=0.17 GB.
[2025-11-11 10:57:29] Using KV cache dtype: torch.float16
[2025-11-11 10:57:29] KV Cache is allocated. #tokens: 63356, K size: 0.18 GB, V size: 0.18 GB
[2025-11-11 10:57:29] Memory pool end. avail mem=22.57 GB
[2025-11-11 10:57:29] Capture cuda graph begin. This can take up to several minutes. avail mem=21.98 GB
[2025-11-11 10:57:29] Capture cuda graph bs [1, 2, 4, 8, 12, 16, 24]
Capturing batches (bs=24 avail_mem=21.96 GB):   0%|                                                                                          | 0/7 [00:12<?, ?it/s]
[2025-11-11 10:57:42] Scheduler hit an exception: Traceback (most recent call last):
  File "/home/yanbingj/projects/sglang/python/sglang/srt/managers/scheduler.py", line 2672, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/home/yanbingj/projects/sglang/python/sglang/srt/managers/scheduler.py", line 311, in __init__
    self.tp_worker = TpModelWorker(
                     ^^^^^^^^^^^^^^
  File "/home/yanbingj/projects/sglang/python/sglang/srt/managers/tp_worker.py", line 237, in __init__
    self._model_runner = ModelRunner(
                         ^^^^^^^^^^^^
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/model_runner.py", line 323, in __init__
    self.initialize(min_per_gpu_memory)
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/model_runner.py", line 490, in initialize
    self.init_device_graphs()
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/model_runner.py", line 2006, in init_device_graphs
    self.graph_runner = graph_runners[self.device](self)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 381, in __init__
    self.capture()
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 500, in capture
    ) = self.capture_one_batch_size(bs, forward)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 692, in capture_one_batch_size
    run_once()
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 679, in run_once
    logits_output_or_pp_proxy_tensors = forward(
                                        ^^^^^^^^
  File "/home/yanbingj/projects/sglang/python/sglang/srt/models/whisper.py", line 416, in forward
    assert mm_inputs is not None
           ^^^^^^^^^^^^^^^^^^^^^
AssertionError

[2025-11-11 10:57:42] Received sigquit from a child process. It usually means the child failed.
Killed

@yhyang201
Collaborator

Hi @MahmoudAshraf97 , I try to run Whisper on 4090D using this PR branch. The example script in the description should be the client script. So I launch server first by python3 -m sglang.launch_server --model openai/whisper-tiny --disable-overlap-schedule --mem-fraction-static 0.3 --max-total-tokens 63356 --enable-multimodal. It fails in mm_inputs is None, which should be the default text input in sglang.launch_server.

Could you please support the scenario of mm_inputs is None as well? If there is anything mistake, please correct me. Thanks!

flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, elastic_ep_backend=None, mooncake_ib_device=None, max_mamba_cache_size=None, mamba_ssm_dtype='float32', mamba_full_memory_ratio=0.9, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, enable_lmcache=False, kt_weight_path=None, kt_method='AMXINT4', kt_cpuinfer=None, kt_threadpool_count=2, kt_num_gpu_experts=None, kt_max_deferred_experts_per_token=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=True, cuda_graph_max_bs=24, cuda_graph_bs=[1, 2, 4, 8, 12, 16, 24], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_tokenizer_batch_decode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=True, enable_mixed_chunk=False, 
enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, enable_piecewise_cuda_graph=False, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=4096, piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096], piecewise_cuda_graph_compiler='eager', torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, enable_weights_cpu_backup=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, rl_on_policy_target=None, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_layers=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, custom_weight_loader=[], 
weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8, mm_max_concurrent_calls=32, mm_per_request_timeout=10.0, decrypted_config_file=None, decrypted_draft_config_file=None)
[2025-11-11 10:56:52] Downcasting torch.float32 to torch.float16.
/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
/home/yanbingj/miniforge3/envs/sglang/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
preprocessor_config.json: 185kB [00:00, 112MB/s]
tokenizer_config.json: 283kB [00:00, 111MB/s]
vocab.json: 836kB [00:00, 1.08MB/s]
[2025-11-11 10:57:02] INFO trace.py:52: opentelemetry package is not installed, tracing disabled
[2025-11-11 10:57:02] INFO trace.py:52: opentelemetry package is not installed, tracing disabled
tokenizer.json: 2.48MB [00:00, 2.51MB/s]
merges.txt: 494kB [00:00, 1.93MB/s]
[2025-11-11 10:57:04] Downcasting torch.float32 to torch.float16.
normalizer.json: 52.7kB [00:00, 206kB/s]
added_tokens.json: 34.6kB [00:00, 134kB/s]
special_tokens_map.json: 2.19kB [00:00, 5.70MB/s]
[2025-11-11 10:57:10] No chat template found, defaulting to 'string' content format
[2025-11-11 10:57:11] Downcasting torch.float32 to torch.float16.
[2025-11-11 10:57:11] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-11-11 10:57:12] Init torch distributed ends. mem usage=0.00 GB
[2025-11-11 10:57:12] MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected
[2025-11-11 10:57:13] Load weight begin. avail mem=23.17 GB
[2025-11-11 10:57:14] Using model weights format ['*.safetensors']
model.safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 151M/151M [00:12<00:00, 12.4MB/s]
[2025-11-11 10:57:28] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 108.21it/s]

[2025-11-11 10:57:29] Load weight end. type=WhisperForConditionalGeneration, dtype=torch.float16, avail mem=23.00 GB, mem usage=0.17 GB.
[2025-11-11 10:57:29] Using KV cache dtype: torch.float16
[2025-11-11 10:57:29] KV Cache is allocated. #tokens: 63356, K size: 0.18 GB, V size: 0.18 GB
[2025-11-11 10:57:29] Memory pool end. avail mem=22.57 GB
[2025-11-11 10:57:29] Capture cuda graph begin. This can take up to several minutes. avail mem=21.98 GB
[2025-11-11 10:57:29] Capture cuda graph bs [1, 2, 4, 8, 12, 16, 24]
Capturing batches (bs=24 avail_mem=21.96 GB):   0%|                                                                                          | 0/7 [00:12<?, ?it/s]
[2025-11-11 10:57:42] Scheduler hit an exception: Traceback (most recent call last):
  File "/home/yanbingj/projects/sglang/python/sglang/srt/managers/scheduler.py", line 2672, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/home/yanbingj/projects/sglang/python/sglang/srt/managers/scheduler.py", line 311, in __init__
    self.tp_worker = TpModelWorker(
                     ^^^^^^^^^^^^^^
  File "/home/yanbingj/projects/sglang/python/sglang/srt/managers/tp_worker.py", line 237, in __init__
    self._model_runner = ModelRunner(
                         ^^^^^^^^^^^^
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/model_runner.py", line 323, in __init__
    self.initialize(min_per_gpu_memory)
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/model_runner.py", line 490, in initialize
    self.init_device_graphs()
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/model_runner.py", line 2006, in init_device_graphs
    self.graph_runner = graph_runners[self.device](self)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 381, in __init__
    self.capture()
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 500, in capture
    ) = self.capture_one_batch_size(bs, forward)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 692, in capture_one_batch_size
    run_once()
  File "/home/yanbingj/projects/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 679, in run_once
    logits_output_or_pp_proxy_tensors = forward(
                                        ^^^^^^^^
  File "/home/yanbingj/projects/sglang/python/sglang/srt/models/whisper.py", line 416, in forward
    assert mm_inputs is not None
           ^^^^^^^^^^^^^^^^^^^^^
AssertionError

[2025-11-11 10:57:42] Received sigquit from a child process. It usually means the child failed.
Killed

Right now, sglang doesn’t have very good support for cross-attention, and I feel that this implementation has some issues.
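For the mm_inputs is None failure during CUDA graph capture, one option is replacing the bare assert with an explicit guard. A minimal sketch; the names (WhisperForward, forward_batch.mm_inputs) mirror the traceback above but are hypothetical stand-ins, not sglang's actual classes:

```python
class WhisperForward:
    """Hypothetical sketch: instead of `assert mm_inputs is not None`, fail
    with an explicit message when a text-only batch (e.g. a dummy batch from
    CUDA graph capture) reaches the Whisper forward pass."""

    def forward(self, input_ids, forward_batch):
        mm_inputs = getattr(forward_batch, "mm_inputs", None)
        if mm_inputs is None:
            # CUDA graph capture and warmup send text-only batches; Whisper
            # always conditions on encoder (audio) features, so reject clearly
            # rather than crash with a bare AssertionError.
            raise ValueError(
                "Whisper requires audio input; got a text-only batch. "
                "Launch with --disable-cuda-graph or provide audio_data."
            )
        return mm_inputs  # the real model would run the audio encoder here
```

Skipping the encoder entirely would be another option, but since Whisper always conditions on audio, failing fast seems the safer behavior until cross-attention CUDA graph capture is supported.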

@yanbing-j
Contributor

Right now, sglang doesn’t have very good support for cross-attention, and I feel that this implementation has some issues.

Does SGLang have plans to support this in the future? Or is there a workable patch to enable Whisper?

@yhyang201
Collaborator

Right now, sglang doesn’t have very good support for cross-attention, and I feel that this implementation has some issues.

Does SGLang have plans to support this in the future? Or is there a workable patch to enable Whisper?

I don’t think there are any official plans for that. Gracefully supporting cross-attention would require quite a bit of effort, and for now, it seems that not many models actually need it.

@MahmoudAshraf97
Contributor Author

@yanbing-j there is no case where the text input to Whisper can be None, so I don't understand why it should need to be supported.

@yhyang201 this implementation is correct; I've verified it against multiple intermediate tensors. I haven't committed the workarounds needed to make the model run to this branch, because the underlying problems need to be solved on the SGLang side. Whisper uses very standard MHA, so I don't understand why it needs these workarounds to work.

@yhyang201
Collaborator

@yanbing-j there is no case where the text input to Whisper can be None, so I don't understand why it should need to be supported.

@yhyang201 this implementation is correct; I've verified it against multiple intermediate tensors. I haven't committed the workarounds needed to make the model run to this branch, because the underlying problems need to be solved on the SGLang side. Whisper uses very standard MHA, so I don't understand why it needs these workarounds to work.

Your implementation should be fine. However, SGLang's support for cross attention is not very complete (it seems somewhat coupled with mllama); you can see the relevant part of the code. Given the current state of cross-attention support in SGLang, I suspect the model might not function properly.

@MahmoudAshraf97
Contributor Author

@yanbing-j there is no case where the text input to Whisper can be None, so I don't understand why it should need to be supported.
@yhyang201 this implementation is correct; I've verified it against multiple intermediate tensors. I haven't committed the workarounds needed to make the model run to this branch, because the underlying problems need to be solved on the SGLang side. Whisper uses very standard MHA, so I don't understand why it needs these workarounds to work.

Your implementation should be fine. However, SGLang's support for cross attention is not very complete (it seems somewhat coupled with mllama); you can see the relevant part of the code. Given the current state of cross-attention support in SGLang, I suspect the model might not function properly.

I agree: since mllama was the first implemented model that uses cross attention, the support is somewhat coupled to it. Anyway, the workaround to make the model work is to override cross attention with a manual computation of the attention.
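The manual-computation workaround mentioned above amounts to plain scaled dot-product attention over the encoder states. A minimal sketch (NumPy for illustration; the actual override would work on torch tensors inside the model):

```python
import numpy as np


def cross_attention(q, k, v):
    """Plain scaled dot-product cross-attention: decoder queries `q` (Tq, d)
    attend over encoder keys/values `k`, `v` (Tk, d). This sketches the
    fallback computation, not sglang's attention-backend kernels."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                  # (Tq, Tk) similarity logits
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over encoder axis
    return weights @ v                              # (Tq, d) attended values
```

In torch, the equivalent one-liner would be torch.nn.functional.scaled_dot_product_attention applied per head with the encoder outputs as keys and values.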

@yanbing-j
Contributor

@yanbing-j there is no case where the text input to Whisper can be None, so I don't understand why it should need to be supported.

@MahmoudAshraf97 Could you please share how to run Whisper using this PR branch? I listed my command in the comments above; I tried to launch the server and run the client part using your example script from the description, but it fails when launching the server.

@MahmoudAshraf97
Contributor Author

@yanbing-j there is no case where the text input to Whisper can be None, so I don't understand why it should need to be supported.

@MahmoudAshraf97 Could you please share how to run Whisper using this PR branch? I listed my command in the comments above; I tried to launch the server and run the client part using your example script from the description, but it fails when launching the server.

Just disable CUDA graphs and the server will run as expected. I've also updated the script above to be simpler, since the implementation is more mature now.

@mobicham
Contributor

@MahmoudAshraf97 thanks for the work. I tried it with a 3:36 audio file but got the following, so I guess this PR doesn't support longer audio files?

python/sglang/srt/managers/schedule_batch.py", line 1133, in prepare_encoder_info_extend
    len(self.out_cache_loc) == self.extend_num_tokens
AssertionError: Expected 0, got -1496

@MahmoudAshraf97
Contributor Author

@MahmoudAshraf97 thanks for the work. I tried it with a 3:36 audio file but got the following, so I guess this PR doesn't support longer audio files?

python/sglang/srt/managers/schedule_batch.py", line 1133, in prepare_encoder_info_extend
    len(self.out_cache_loc) == self.extend_num_tokens
AssertionError: Expected 0, got -1496

No, this error is unrelated to the length of the audio file; use this patch to fix it.
This PR does not intend to support segments longer than 30 s, though. Segmentation should be handled by the client, not the server, which is the case in vLLM and TRT-LLM.
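Client-side segmentation along those lines is straightforward: split the waveform into 30-second windows at Whisper's 16 kHz input rate and submit each chunk as its own request. A minimal sketch (the constants match Whisper's fixed receptive field; the request loop itself is left out):

```python
import numpy as np

SAMPLE_RATE = 16_000   # Whisper expects mono 16 kHz input
CHUNK_SECONDS = 30     # Whisper's fixed 30-second receptive window


def chunk_audio(waveform, sample_rate=SAMPLE_RATE, chunk_seconds=CHUNK_SECONDS):
    """Split a mono waveform (1-D array of samples) into chunks of at most
    `chunk_seconds`, to be sent to the server as separate requests."""
    chunk_len = sample_rate * chunk_seconds
    return [
        waveform[i : i + chunk_len]
        for i in range(0, len(waveform), chunk_len)
    ]
```

A production client would additionally prefer splitting on silence near the chunk boundary (as faster-whisper and WhisperX do) to avoid cutting words in half, but fixed-size windows are enough to exercise the server.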
