
[Whisper] Fix audio feature device placement in encoder forward #22296

Closed
shenxiul wants to merge 1 commit into sgl-project:main from shenxiul:fix/whisper-device-placement

Conversation

@shenxiul

@shenxiul shenxiul commented Apr 8, 2026

Motivation

After #21190 enabled CUDA graph for Whisper, the server crashes on the first transcription request with:

RuntimeError: Expected all tensors to be on the same device, but got
weight is on cuda:0, different from other tensors on cpu

This happens because features.to(dtype) at line 462 of whisper.py only converts the dtype without moving the tensor to GPU. The multimodal processor returns input_features as a CPU tensor (via HuggingFace's feature extractor with return_tensors="pt"), and the default keep_mm_feature_on_device=False keeps them on CPU. When the encoder's conv1 (on CUDA) receives a CPU input, F.conv1d fails.

Similarly, encoder_position_ids is created on features.device (CPU) via torch.arange(...).to(features.device), so it also lands on CPU.
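
For context, a tiny standalone demonstration (assuming PyTorch is available) of why the dtype-only call is not enough: `Tensor.to(dtype)` changes the dtype but leaves the tensor on whatever device it already occupies.

```python
import torch

# A CPU tensor standing in for the input_features returned by the
# multimodal processor (the shape is illustrative, not Whisper's exact one).
features = torch.randn(1, 128, 3000)

# dtype-only .to(): converts the dtype but does NOT move the tensor.
casted = features.to(torch.float16)
assert casted.dtype == torch.float16
assert casted.device.type == "cpu"  # still on CPU, so a CUDA conv1 would fail

# The fixed form moves and casts in one call ("cpu" shown here; the actual
# fix uses self.encoder.conv1.weight.device).
fixed = features.to(device="cpu", dtype=torch.float16)
```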

Modifications

  • Changed features.to(dtype) to features.to(device=device, dtype=dtype) where device is obtained from self.encoder.conv1.weight.device
  • Added encoder_position_ids.to(device) to move position IDs to the same device
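
A minimal sketch of the fixed forward path, with hypothetical module and layer sizes (the real code lives in `sglang/srt/models/whisper.py`; this only illustrates the two device moves, and runs on CPU):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyWhisperEncoder(nn.Module):
    """Illustrative stand-in for the Whisper encoder; channel sizes are made up."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv1d(128, 64, kernel_size=3, padding=1)

    def forward(self, features: torch.Tensor):
        weight = self.conv1.weight
        # Fix 1: move AND cast in one call, anchored to the conv weight's device.
        features = features.to(device=weight.device, dtype=weight.dtype)
        # Fix 2: create position ids on the model's device, not features' old one.
        position_ids = torch.arange(features.shape[-1]).to(weight.device)
        embeds = F.gelu(self.conv1(features))
        return embeds, position_ids


enc = ToyWhisperEncoder()
embeds, pos = enc(torch.randn(1, 128, 3000))
```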

Repro

# Terminal 1: launch server
python -m sglang.launch_server --model-path openai/whisper-large-v3 --port 30000

# Terminal 2: run benchmark (auto-downloads dataset, crashes on first request)
python benchmark/asr/bench_sglang.py \
  --base-url http://127.0.0.1:30000 \
  --model openai/whisper-large-v3 \
  --api-type transcription --language en --concurrency 1

Server crashes with:

File "sglang/srt/models/whisper.py", line 462, in forward
    req_encoder_output = self.encoder(
File "sglang/srt/models/whisper.py", line 295, in forward
    inputs_embeds = torch.nn.functional.gelu(self.conv1(input_features))
RuntimeError: Expected all tensors to be on the same device, but got
weight is on cuda:0, different from other tensors on cpu

Benchmark (after fix)

Tested on NVIDIA GB300, openai/whisper-large-v3, dataset D4nt3/esb-datasets-earnings22-validation-tiny-filtered (511 samples), concurrency=1:

|                   | CUDA graph | No CUDA graph |
|-------------------|------------|---------------|
| default           | 4.74 req/s | 1.01 req/s    |
| keep_mm_on_device | 2.29 req/s | 1.19 req/s    |

@gemini-code-assist
Contributor


`features.to(dtype)` only converts the dtype without moving the tensor
to GPU, causing `RuntimeError: Expected all tensors to be on the same
device` when `conv1` weights are on CUDA but `input_features` remain on
CPU. This happens by default because `keep_mm_feature_on_device=False`
moves features to CPU after preprocessing.

Similarly, `encoder_position_ids` is created on `features.device` (CPU)
via `torch.arange(...).to(features.device)`, so it also needs explicit
device placement.

Fix: use `features.to(device=..., dtype=...)` and
`encoder_position_ids.to(device)` to explicitly move both tensors to the
model's device before calling the encoder.

## Bug introduced in
sgl-project#21190 (feat: enable CUDA graph and timestamp for whisper model)

## Repro
```bash
# Launch server (default config, no special flags needed)
python -m sglang.launch_server --model-path openai/whisper-large-v3 --port 30000

# Send a transcription request — server crashes with:
# RuntimeError: Expected all tensors to be on the same device,
# but got weight is on cuda:0, different from other tensors on cpu
curl http://localhost:30000/v1/audio/transcriptions \
  -F file=@test.wav -F model=openai/whisper-large-v3
```

The crash occurs on the first request because the multimodal
processor returns `input_features` as a CPU tensor (via
`return_tensors="pt"` from HuggingFace's feature extractor),
and the default `keep_mm_feature_on_device=False` keeps them on CPU.
The encoder's `conv1` weights are on CUDA, so `F.conv1d` fails with
a device mismatch.

## Benchmark (after fix)
Tested on NVIDIA GB300, openai/whisper-large-v3, D4nt3/esb-datasets-earnings22-validation-tiny-filtered (511 samples), concurrency=1:

|                     | CUDA graph    | No CUDA graph |
|---------------------|---------------|---------------|
| default             | 4.74 req/s    | 1.01 req/s    |
| keep_mm_on_device   | 2.29 req/s    | 1.19 req/s    |

- WER: 12.78% across all configs (matches sgl-project#21190)
- CUDA graph gives 4.7x throughput improvement
@shenxiul shenxiul force-pushed the fix/whisper-device-placement branch from 6d2e157 to 4961f0f on April 8, 2026 at 02:45
@JustinTong0323
Collaborator

Duplicated with #22293

@JustinTong0323
Collaborator

Verified & Root Cause Analysis

Tested on B200 — the fix is correct. The bug is 100% reproducible on current main.

Root Cause

The device placement issue was introduced by #22038 ([VLM] Chunk-aware ViT encoding with per-image cache and lazy device transfer).

Before #22038: schedule_batch.py had a global pixel_values.to(self.device) in ForwardBatch init that moved all mm features to GPU for every model.

After #22038: That global transfer was removed (-15 lines from schedule_batch.py). The GPU transfer was moved to mm_utils.py:_move_items_to_device(), which is only called in the VLM embedding path (_get_chunked_embedding_full, etc.).

Problem: Whisper does not go through the VLM embedding path — it directly accesses mm_input.mm_items[0].feature in its model forward. So after #22038, Whisper's features stay on CPU → features.to(dtype) only casts dtype without moving to GPU → F.conv1d fails with CPU input vs CUDA weight.

Verification

| Test                                | Result                      |
|-------------------------------------|-----------------------------|
| Basic transcription                 | PASS                        |
| Consistency (3 sequential requests) | PASS (identical outputs)    |
| WER (c=1, 20 samples)               | 12.68%                      |
| WER (c=64, 511 samples)             | 12.76% (stable under load)  |
| Throughput (c=64)                   | 42.70 req/s                 |

The fix is the right approach — models should not assume the caller has already moved tensors to the correct device.
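
The defensive pattern described above, a model normalizing its own input placement rather than trusting the caller, can be sketched as a small helper. The name `match_placement` is hypothetical, not sglang API; `Tensor.to` already returns the input unchanged when device and dtype match, so the helper is cheap on the happy path.

```python
import torch


def match_placement(t: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Return t on ref's device/dtype; a no-op when placement already matches."""
    return t.to(device=ref.device, dtype=ref.dtype)


weight = torch.zeros(4, dtype=torch.float16)  # stand-in for a CUDA model weight
inp = torch.randn(4, dtype=torch.float32)     # caller-provided CPU tensor
out = match_placement(inp, weight)
```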

@yhyang201
Collaborator

My bad, and Whisper will be added to CI afterwards.

@shenxiul
Author

shenxiul commented Apr 8, 2026

Quick question: for the 511-sample test runs, I'm getting the following:

┌────────────────────┬──────┬───────┬───────┬───────┐
│       Config       │ c=1  │  c=4  │ c=16  │ c=64  │
├────────────────────┼──────┼───────┼───────┼───────┤
│ sglang_cg          │ 6.99 │ 12.81 │ 20.45 │ 24.74 │
├────────────────────┼──────┼───────┼───────┼───────┤
│ sglang_cg_keepmm   │ 7.05 │ 13.13 │ 21.32 │ 25.62 │
├────────────────────┼──────┼───────┼───────┼───────┤
│ sglang_nocg        │ 1.53 │ 4.66  │ 10.72 │ 15.06 │
├────────────────────┼──────┼───────┼───────┼───────┤
│ sglang_nocg_keepmm │ 1.55 │ 4.72  │ 10.70 │ 14.69 │
├────────────────────┼──────┼───────┼───────┼───────┤
│ vllm               │ 0.94 │ 3.21  │ 6.20  │ 6.08  │
└────────────────────┴──────┴───────┴───────┴───────┘

All numbers are req/s. These look quite different from what you're getting; mind sharing the precise commands so I can double-check?

@JustinTong0323
Collaborator

The benchmark commands used during #21190 development:

Server launch:

CUDA_VISIBLE_DEVICES=0 python -m sglang.launch_server \
  --model-path openai/whisper-large-v3 --port 30000

No extra flags — all defaults (gpu_memory_utilization, cuda graph, etc.).

Benchmark:

# High concurrency (511 samples)
python benchmark/asr/bench_sglang.py \
  --base-url http://localhost:30000 \
  --model openai/whisper-large-v3 \
  --api-type transcription \
  --language en \
  --concurrency 64

# Single request (50 samples)
python benchmark/asr/bench_sglang.py \
  --base-url http://localhost:30000 \
  --model openai/whisper-large-v3 \
  --api-type transcription \
  --language en \
  --concurrency 1 \
  --n-examples 50

Hardware: single NVIDIA B200 (183 GB). Our c=64 numbers were ~48 req/s on B200; the difference from your ~25 req/s on GB300 is likely hardware-dependent (memory bandwidth, clock speeds, etc.).

@JustinTong0323
Collaborator

Verified on GB300 as well. Here are the numbers side by side:

SGLang (main + this fix), openai/whisper-large-v3, 511 samples

| Concurrency | B200 (req/s) | GB300 (req/s) |
|-------------|--------------|---------------|
| 1           | 7.02         | 7.56          |
| 4           | 22.95        | 13.85         |
| 16          | 36.33        | 21.99         |
| 64          | 43.03        | 25.77         |

WER stays at ~12.77% across all configs on both GPUs.

@JustinTong0323
Collaborator

Solved in #22293.
