[Intel GPU] Enable DeepSeek V4 Inference on XPU#25336
Merged
mingfeima merged 25 commits on May 18, 2026
Adds full SM120 (RTX PRO 6000 / RTX 5090 / DGX Spark) support for DeepSeek-V4 on SGLang, rebased onto main branch.
Key changes:
- Triton MXFP4 MoE kernel for SM120 (no MARLIN/tcgen05 on desktop Blackwell)
- Triton FlashMLA sparse decode kernel for SM120
- MQA wq-precompute with vectorized batch for CUDA graph compatibility
- DeepGEMM/PDL guards for SM120 (no TMEM/tcgen05)
- NSA backend SM120 dispatch (tilelang default, skip DeepGEMM metadata)
- FlashMLA SM120 adapter for deepseek_v4_backend
- 3 CUDA-graph-breaking paths fixed (MoE .unique/.item, NSA/Compressed MQA)
Results (8x RTX PRO 6000, TP=8):
- Decode: 10.26 tok/s BS=1 with CUDA graph (2.4x vs without)
- GSM8K 5-shot: 98.0% accuracy (200 questions)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… missing
Detect tvm_ffi availability at module import time via _HAS_TVM_FFI.
When unavailable (e.g. on XPU where the JIT-compiled CUDA toolchain is
not present), CompressorPrefillPlan.generate dispatches to a new
_torch_plan_compress_prefill helper that mirrors plan_prefill_host in
jit_kernel/csrc/deepseek_v4/common.cuh, packing each PrefillPlan
{ragged_id, batch_id, position, window_len} as 4 little-endian uint32
(16 bytes) into the (num_q_tokens, 16) uint8 buffer.
Signed-off-by: P V R K Jyothendra Varma <polisettyvarma@gmail.com>
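The byte layout described in this commit message can be sketched in pure torch. This is only an illustration of packing four uint32 fields per row, little-endian, into a (num_q_tokens, 16) uint8 buffer; the function name pack_prefill_plan and its list-based inputs are hypothetical and are not the PR's _torch_plan_compress_prefill.

```python
import torch


def pack_prefill_plan(ragged_id, batch_id, position, window_len):
    """Pack four equal-length integer lists into an (N, 16) uint8 buffer.

    Hypothetical helper: each row holds {ragged_id, batch_id, position,
    window_len} as four little-endian uint32 values (16 bytes total).
    """
    fields = torch.tensor(
        [ragged_id, batch_id, position, window_len], dtype=torch.int64
    ).T.contiguous()  # (N, 4)
    # Split every value into 4 bytes, least-significant byte first.
    shifts = torch.tensor([0, 8, 16, 24], dtype=torch.int64)
    as_bytes = (fields.unsqueeze(-1) >> shifts) & 0xFF  # (N, 4, 4)
    return as_bytes.to(torch.uint8).reshape(fields.shape[0], 16)


buf = pack_prefill_plan([1, 258], [2, 0], [3, 7], [0xFFFFFFFF, 4])
print(buf.shape)  # (2, 16): one 16-byte record per query token
```

Building the bytes with shifts rather than a dtype reinterpretation keeps the layout little-endian regardless of host byte order, which is an assumption about what the consumer of the buffer expects.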
sgl_kernel's fused hadamard_transform is CUDA/HIP-only, so importing it from rotate_activation() fails on Intel XPU with ImportError when the NSA indexer path runs (e.g. DeepSeek-V4 compressed attention). Add an iterative Walsh-Hadamard transform in pure torch and dispatch to it when _is_xpu is True. Same contract as the fused op: operates on the last dim, requires power-of-two size, multiplies by the supplied scale. CUDA/HIP/SM103/NPU paths are unchanged.
Signed-off-by: Rahul Vijayaraghavan <rahul.vijayaraghavan@intel.com>
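For readers unfamiliar with the technique, an iterative Walsh-Hadamard transform in pure torch might look like the sketch below. It follows the contract stated in the commit message (last dim, power-of-two size, multiply by scale), but the function name and the natural (Sylvester) output ordering are assumptions, not the PR's code.

```python
import torch


def hadamard_transform_torch(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    """Walsh-Hadamard transform over the last dim (hypothetical sketch)."""
    n = x.shape[-1]
    if n <= 0 or n & (n - 1):
        raise ValueError("last dimension must be a power of two")
    out = x.reshape(-1, n)
    h = 1
    while h < n:
        # View each row as blocks of size 2*h and butterfly the two halves.
        blocks = out.reshape(-1, n // (2 * h), 2, h)
        a, b = blocks[:, :, 0, :], blocks[:, :, 1, :]
        out = torch.stack((a + b, a - b), dim=2).reshape(-1, n)
        h *= 2
    return (out * scale).reshape(x.shape)


x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(hadamard_transform_torch(x))  # tensor([[10., -2., -4.,  0.]])
```

With scale set to n ** -0.5 the transform is its own inverse, which gives a cheap sanity check against a fused implementation.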
Signed-off-by: Rahul Vijayaraghavan <rahul.vijayaraghavan@intel.com>
Signed-off-by: P V R K Jyothendra Varma <polisettyvarma@gmail.com>
…p OOB indices
The pure-PyTorch FlashMLA fallback used by Intel Arc / SM120 dequantized the full (B, s_q, topk) KV tensor at once. On long-context prefills with large topk this allocated multi-GB int64 advanced-index tensors plus the bf16 KV gather, often OOMing or stalling on per-call USM allocations inside the L0 driver.
Refactor:
- _gather_and_dequant flattens the raw paged byte buffer once and gathers in fixed-size token chunks (16k tokens), keeping the int64 index tensors small and reusable across pages.
- _sm120_sparse_decode_fwd flattens (B, s_q) into rows and processes them in chunks sized to a configurable peak memory budget (SGLANG_SM120_SPARSE_CHUNK_MIB, default 256 MiB).
- Indices outside [0, num_pages*page_size) are now treated as invalid and clamped before gather, matching the CUDA kernel's tile-scheduler behavior which simply skips OOB entries.
No behavior change for in-bounds indices; matches the previous non-chunked path numerically.
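The chunked gather with clamped out-of-bounds indices can be illustrated roughly as follows. This is a simplified sketch over an already-flattened (num_tokens, D) tensor with no dequantization; the name gather_rows_chunked, the returned validity mask, and zero-filling invalid rows are assumptions made for the example, not the behavior of _gather_and_dequant.

```python
import torch


def gather_rows_chunked(flat_kv, indices, chunk_tokens=16384):
    """Gather rows of flat_kv (num_tokens, D) in fixed-size chunks.

    Hypothetical sketch: out-of-range indices are clamped before the gather
    and reported through a mask; their rows are zero-filled (an assumption).
    Assumes flat_kv has at least one row.
    """
    num_tokens = flat_kv.shape[0]
    idx = indices.reshape(-1).to(torch.int64)
    valid = (idx >= 0) & (idx < num_tokens)
    safe = idx.clamp(0, num_tokens - 1)
    out = torch.empty(
        (idx.numel(), flat_kv.shape[1]), dtype=flat_kv.dtype, device=flat_kv.device
    )
    # Each iteration touches at most chunk_tokens indices, bounding peak memory.
    for start in range(0, idx.numel(), chunk_tokens):
        end = min(start + chunk_tokens, idx.numel())
        out[start:end] = flat_kv.index_select(0, safe[start:end])
    zero = torch.zeros((), dtype=out.dtype, device=out.device)
    out = torch.where(valid.unsqueeze(-1), out, zero)
    return out.reshape(*indices.shape, -1), valid.reshape(indices.shape)


kv = torch.arange(12.0).reshape(6, 2)
rows, ok = gather_rows_chunked(kv, torch.tensor([[0, 5, -1], [7, 2, 3]]), chunk_tokens=2)
print(ok)  # marks -1 and 7 as invalid
```

In an attention kernel the mask would more likely be used to exclude those positions from the softmax than to zero the gathered rows; the zero-fill here only keeps the example self-checking.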
The prior code allocated topk_weights ([M, topk] fp32) and topk_ids ([M, topk] int32) and then immediately overwrote both with the result of torch.topk, so the empty allocations were pure waste (and on XPU each empty() may take an L0 USM allocation path). Inline the scoring result and cast topk_ids to int32 to preserve the original dtype.
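A minimal sketch of the inlined form described here, assuming a plain [M, num_experts] fp32 score tensor; select_experts is a hypothetical name and the PR's actual scoring function is not shown on this page.

```python
import torch


def select_experts(scores: torch.Tensor, topk: int):
    """Return (topk_weights, topk_ids) without pre-allocating either tensor."""
    # torch.topk already allocates its outputs, so an earlier torch.empty(...)
    # for the same names would be overwritten and wasted.
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    # torch.topk returns int64 indices; cast to keep the original int32 dtype.
    return topk_weights, topk_ids.to(torch.int32)


scores = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.3, 0.9]])
weights, ids = select_experts(scores, topk=2)
print(ids)  # int32 indices of the two largest scores per row
```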
…rized
This vectorized topk -> page-index transform runs once per layer per
prefill chunk, and several patterns in it caused per-call host syncs or
allocations on Intel L0 (XPU) that dominated indexer latency:
* ``masked_scores[~valid_mask] = float("-inf")`` -- boolean-indexed
scatter, a known L0 sync hot spot (count-nonzero + USM alloc +
scatter sync per call). Replaced with a single ``torch.where``
against a cached scalar ``-inf`` tensor.
* ``scores[batch_idx.flatten(), raw.flatten()].view(B, TOPK)`` --
materializes a flat int64 advanced-index tensor of size B*TOPK
every call. Replaced with a single ``torch.gather(scores, dim=1,
index=...)`` kernel.
* ``if needs_sequential.any():`` -- forces a D2H sync per call.
Removed the guard and run the sequential override unconditionally
via ``torch.where``; it is cheap when no row hits it.
* ``torch.tensor(-1, device=device, ...)`` and
``torch.tensor(float("-inf"), device=device, ...)`` constructed
inside the function fire an H2D copy + sync per call. Hoisted
into per-device caches (``_neg_inf_scalar`` /
``_neg_one_i32_scalar``) so the scalar is built once per device
and reused across layers.
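The first, second, and fourth replacements in the list above can be sketched together. The helper names below (neg_inf_scalar, mask_and_gather) are hypothetical stand-ins, not the PR's _neg_inf_scalar, and the shapes are assumed to be (B, N) scores with a (B, TOPK) int64 index tensor.

```python
import torch

_NEG_INF_CACHE = {}


def neg_inf_scalar(device, dtype=torch.float32):
    """Build the 0-dim -inf tensor once per (device, dtype) and reuse it."""
    key = (torch.device(device), dtype)
    cached = _NEG_INF_CACHE.get(key)
    if cached is None:
        cached = torch.full((), float("-inf"), dtype=dtype, device=device)
        _NEG_INF_CACHE[key] = cached
    return cached


def mask_and_gather(scores, valid_mask, raw_idx):
    # torch.where instead of masked_scores[~valid_mask] = -inf:
    # no boolean-indexed scatter, so no count-nonzero step.
    masked = torch.where(valid_mask, scores, neg_inf_scalar(scores.device, scores.dtype))
    # torch.gather instead of scores[batch_idx.flatten(), raw.flatten()]:
    # no flat B*TOPK advanced-index tensor is materialized.
    picked = torch.gather(masked, dim=1, index=raw_idx)
    return masked, picked


scores = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
valid = torch.tensor([[True, False, True], [True, True, False]])
idx = torch.tensor([[2, 0], [1, 2]])
print(mask_and_gather(scores, valid, idx)[1])
```

Whether these forms actually avoid a host sync depends on the backend; the commit message reports the effect on Intel L0, and this sketch does not measure it.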
…rt time
The sparse decode fwd previously called os.environ.get(...) every invocation to read the per-chunk peak-memory budget. This runs once per layer per decode step, and os.environ access is not free on the hot path. Read the variable once at module import into _SM120_SPARSE_CHUNK_MIB and reference the cached int from the forward path. No behavior change.
The pure-PyTorch fallback path for the DeepSeek-V4 compress/write/RoPE
kernels had four recurring patterns that each force-drain the L0
command queue every layer on Intel XPU and turned a single decode step
into minutes:
* ``_decode_prefill_plan`` did a D->H copy + numpy ``view("<u4")``
bitcast of the plan tensor. Replaced with an on-device
``view(torch.int32).reshape(-1, 4).to(torch.int64) & 0xFFFFFFFF``
so the kInvalid sentinel 0xFFFFFFFF still compares correctly after
the signed->unsigned promotion.
* ``bool(do_fwd.any()) / do_fwd.nonzero()`` early-out on the per-step
decode hot path. Removed in ``_torch_c4_decode``,
``_torch_c128_decode``, and ``_torch_fused_norm_rope`` (mode 1).
These now compute over the full B-batch and mask the writeback
with ``torch.where``; ``positions`` is clamped where needed so the
in-flight gather stays in-bounds.
* ``valid = (plan[:, 0] != _INVALID); cp = plan[valid].to(device)``
boolean-mask gather on prefill plans. Prefill plans are already
pre-sliced to their exact valid length on the host before the H->D
copy (see ``CompressorPrefillPlan.generate``) and never use
cuda-graph padding, so an unconditional ``if plan.shape[0] > 0``
is sufficient. Removed in ``_torch_c4_prefill``,
``_torch_c4_prefill_write``, ``_torch_c128_prefill``, and
``_torch_fused_norm_rope`` (mode 0).
* ``bool(sl4.any())`` guard around the seq_len==4 special case in
``_torch_c4_decode`` and ``_torch_c4_prefill``. The masked
``torch.where`` is cheap enough to apply unconditionally; rows
where seq_len != 4 see no change.
Signed-off-by: P V R K Jyothendra Varma <polisettyvarma@gmail.com>
Signed-off-by: Rahul Vijayaraghavan <rahul.vijayaraghavan@intel.com>
Signed-off-by: Rahul Vijayaraghavan <rahul.vijayaraghavan@intel.com>
Signed-off-by: P V R K Jyothendra Varma <polisettyvarma@gmail.com>
Signed-off-by: P V R K Jyothendra Varma <polisettyvarma@gmail.com>
…rrectly
Signed-off-by: P V R K Jyothendra Varma <polisettyvarma@gmail.com>
Collaborator
@jianan-gu @sunjiweiswift please help review this one~
jianan-gu
reviewed
May 15, 2026
Signed-off-by: P V R K Jyothendra Varma <polisettyvarma@gmail.com>
SGLANG_SM120_TRITON_FLASHMLA=0
SGLANG_FP8_PAGED_MQA_LOGITS_TRITON=true
SGLANG_TOPK_TRANSFORM_512_TORCH=true
SGLANG_OPT_SWIGLU_CLAMP_FUSION=false
SGLANG_OPT_USE_FUSED_HASH_TOPK=false
SGLANG_OPT_BF16_FP32_GEMM_ALGO=torch
SGLANG_OPT_USE_TILELANG_MHC_POST=false
SGLANG_OPT_USE_FUSED_STORE_CACHE=false
SGLANG_OPT_USE_TILELANG_MHC_PRE=false
SGLANG_OPT_DEEPGEMM_HC_PRENORM=false
Signed-off-by: P V R K Jyothendra Varma <polisettyvarma@gmail.com>
Please update the command in the description after removing the flags.
Contributor
Author
done
msinnha1
approved these changes
May 15, 2026
msinnha1
left a comment
Some minor cleanup is required before this is upstreamed, mostly around duplicate functions.
jianan-gu
reviewed
May 18, 2026
SGLANG_FP8_PAGED_MQA_LOGITS_TORCH = EnvBool(False)
SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(False)
SGLANG_FP8_PAGED_MQA_LOGITS_TRITON = EnvBool(True)
SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(True)
Contributor
@polisettyvarma please fix this one. Hardcoding SGLANG_TOPK_TRANSFORM_512_TORCH=True sends other platforms down this path too, which blocks the CPU path! Please make sure your changes are XPU-only!
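One way to address this kind of request is to derive the default from the platform instead of hardcoding it, while still letting the environment variable override. The sketch below is standalone: env_bool is a hypothetical stand-in for SGLang's EnvBool, and how the caller determines is_xpu is left open because it is not shown in this thread.

```python
import os


def env_bool(name: str, default: bool) -> bool:
    """Hypothetical stand-in for an EnvBool-style flag reader."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")


def topk_transform_512_torch(is_xpu: bool) -> bool:
    # Default to the torch path only on XPU; other platforms keep False
    # unless the user sets the variable explicitly.
    return env_bool("SGLANG_TOPK_TRANSFORM_512_TORCH", default=is_xpu)


print(topk_transform_512_torch(is_xpu=False))
```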
dependency - 89cb865
#24692
offline throughput command
--json-model-override-args '{"num_hidden_layers": 5}' --profile --profile-activities "CPU" "XPU"
Add the above to profile a reduced model, with env SGLANG_TORCH_PROFILER_DIR=/tmp
server:
client:
Offline throughput comes to around 14.8, and serving throughput to around 13.5, on 8x B70.
CI States
Latest PR Test: ❌ Missing run-ci label; add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked; run-ci is required first.