Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
9a4bb0b
Add AMD AITER MLA fusion optimization for DeepSeek models
khairulkabir1661 Feb 27, 2026
dda6084
Add comprehensive tests for MLA fusion on AMD AITER
khairulkabir1661 Mar 2, 2026
1eafc57
Fix pre-commit issues in test_mla_fusion.py
khairulkabir1661 Mar 2, 2026
b020a60
Fix pytest mark warnings in test_mla_fusion.py
khairulkabir1661 Mar 2, 2026
cb71c3b
test: Remove placeholder tests from MLA fusion test suite
khairulkabir1661 Mar 3, 2026
2de8c58
Fix code review issues: improve exception handling and add logging
khairulkabir1661 Mar 4, 2026
d95dd82
Fix MLA fusion: use custom op pattern and clean up tests
khairulkabir1661 Mar 5, 2026
dc47e37
Fix MLA fusion tests: compare FP8-fused vs FP8-unfused
khairulkabir1661 Mar 5, 2026
882debd
Remove test_deterministic_outputs from MLA fusion tests
khairulkabir1661 Mar 5, 2026
d132b95
Fix MLA fusion custom op registration and optimize tests
khairulkabir1661 Mar 5, 2026
59895ac
[ROCm][FP8] Add x_scale parameter support for MLA fusion (Option 2)
khairulkabir1661 Mar 6, 2026
8c38d23
Fix mypy signature compatibility for x_scale/input_scale parameters
khairulkabir1661 Mar 6, 2026
576b09e
Clean up mla.py comments (lines 259-287)
khairulkabir1661 Mar 27, 2026
6ff38d8
Clarify q_c_scale comment in mla.py (line 240)
khairulkabir1661 Mar 27, 2026
16bea47
Clean up fusion init comments (lines 213-231)
khairulkabir1661 Mar 27, 2026
f49af5e
Clean up AITER fusion helper comments (lines 15-112)
khairulkabir1661 Mar 27, 2026
5a505ea
Remove test_mla_fusion.py test file
khairulkabir1661 Mar 27, 2026
b54784a
Clean up comments in fp8_utils.py
khairulkabir1661 Mar 27, 2026
709d0ed
Fix input_scale and output_dtype handling in FP8 quantization
khairulkabir1661 Mar 27, 2026
7e677a7
Remove is_layer_moe_router_gate check from batch invariance
khairulkabir1661 Mar 27, 2026
ba89242
[ROCm] Add unified AITER RoPE + KV cache kernel for MLA
khairulkabir1661 Mar 27, 2026
97d95fc
Fix batch invariant: use envs.VLLM_BATCH_INVARIANT instead of function
khairulkabir1661 Mar 27, 2026
b0054ae
Match FA3/FA4 padding logic with main branch
khairulkabir1661 Mar 27, 2026
76d270b
Add missing logger.info_once for MLA prefill backends
khairulkabir1661 Mar 27, 2026
4a51328
Add support for quantized layers without .weight attribute
khairulkabir1661 Mar 27, 2026
5c130c2
Add missing logger.info_once for FP8 prefill attention
khairulkabir1661 Mar 27, 2026
c375285
Remove comment from use_flashinfer_prefill to match main
khairulkabir1661 Mar 27, 2026
96cdcd5
Update get_kv_cache_stride_order to match main branch
khairulkabir1661 Mar 27, 2026
6e65048
Add missing XPU flash_attn support section
khairulkabir1661 Mar 27, 2026
3bc4c1a
Clean up comments and remove debug statements
khairulkabir1661 Mar 27, 2026
26d37ae
Fix kv_cache indexing and clean up debug comments
khairulkabir1661 Mar 27, 2026
17b4fed
Revert FP4/FP8 BMM comments to match main branch
khairulkabir1661 Mar 27, 2026
cc8730e
Reorganize comments in unified RoPE+KV fusion section
khairulkabir1661 Mar 27, 2026
d221c67
Clean up forward_impl parameter comments
khairulkabir1661 Mar 27, 2026
4be0a72
Simplify comments for custom ops path
khairulkabir1661 Mar 27, 2026
a64e721
Clean up forward method comments and fix kv_cache indexing
khairulkabir1661 Mar 27, 2026
d7105b9
Fix kv_cache initialization to match main branch
khairulkabir1661 Mar 27, 2026
8b14e38
Restore VLLM_BATCH_INVARIANT to envs.py to match main branch
khairulkabir1661 Mar 27, 2026
420048f
Reorganize and simplify comments in mla.py
khairulkabir1661 Mar 27, 2026
045a604
Remove @torch_compile_guard to make fusion transparent to compiler
khairulkabir1661 Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 42 additions & 34 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@
VLLM_CPU_OMP_THREADS_BIND: str = "auto"
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
VLLM_CPU_SGL_KERNEL: bool = False
VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_XLA_CHECK_RECOMPILATION: bool = False
VLLM_FUSED_MOE_CHUNK_SIZE: int = 16 * 1024
VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
Expand Down Expand Up @@ -97,7 +98,6 @@
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE: bool = True
VLLM_DISABLE_PYNCCL: bool = False
VLLM_USE_OINK_OPS: bool = False
VLLM_ROCM_USE_AITER: bool = False
Expand All @@ -117,6 +117,9 @@
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_USE_AITER_FUSED: bool = True
VLLM_USE_AITER_PREFILL_FUSED: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
Expand Down Expand Up @@ -169,7 +172,7 @@
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
"latency"
)
VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto"
VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "trtllm"
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
Expand Down Expand Up @@ -246,8 +249,6 @@
VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False
VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False
VLLM_NIXL_EP_MAX_NUM_RANKS: int = 32
VLLM_XPU_ENABLE_XPU_GRAPH: bool = False


def get_default_cache_root():
Expand Down Expand Up @@ -295,16 +296,6 @@ def use_aot_compile() -> bool:
)


def use_mega_aot_artifact():
from vllm.utils.torch_utils import is_torch_equal_or_newer

default_value = (
"1" if is_torch_equal_or_newer("2.12.0.dev") and use_aot_compile() else "0"
)

return os.environ.get("VLLM_USE_MEGA_AOT_ARTIFACT", default_value) == "1"


def env_with_choices(
env_name: str,
default: str | None,
Expand Down Expand Up @@ -628,7 +619,10 @@ def _get_or_set_default() -> str:
# Enable loading compiled models directly from cached standalone compile artifacts
# without re-splitting graph modules. This reduces overhead during model
# loading by using reconstruct_serializable_fn_from_mega_artifact.
"VLLM_USE_MEGA_AOT_ARTIFACT": use_mega_aot_artifact,
"VLLM_USE_MEGA_AOT_ARTIFACT": lambda: os.environ.get(
"VLLM_USE_MEGA_AOT_ARTIFACT", "0"
)
== "1",
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
Expand Down Expand Up @@ -719,11 +713,6 @@ def _get_or_set_default() -> str:
else None,
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
# (Zen CPU backend) eagerly prepack weights into ZenDNN blocked layout
# at model load time. Eliminates per-inference layout conversion overhead.
"VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool(
int(os.getenv("VLLM_ZENTORCH_WEIGHT_PREPACK", "1"))
),
# If the env var is set, Ray Compiled Graph uses the specified
# channel type to communicate between workers belonging to
# different pipeline-parallel stages.
Expand Down Expand Up @@ -841,6 +830,15 @@ def _get_or_set_default() -> str:
),
# Enable SPMD mode for TPU backend.
"VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
"VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(
os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(16 * 1024))
),
# Control whether to use fused MoE activation chunking. Current chunking
# logic is incompatible with torch.compile and causes IMA. See issue
# https://github.com/vllm-project/vllm/issues/19631.
"VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING": lambda: bool(
int(os.getenv("VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING", "1"))
),
# If set, the OpenAI API server will stay alive even after the underlying
# AsyncLLMEngine errors and stops serving requests
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool(
Expand Down Expand Up @@ -910,9 +908,6 @@ def _get_or_set_default() -> str:
"VLLM_DISABLED_KERNELS": lambda: []
if "VLLM_DISABLED_KERNELS" not in os.environ
else os.environ["VLLM_DISABLED_KERNELS"].split(","),
"VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool(
int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "1"))
),
# Disable pynccl (using torch.distributed instead)
"VLLM_DISABLE_PYNCCL": lambda: (
os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")
Expand Down Expand Up @@ -993,6 +988,19 @@ def _get_or_set_default() -> str:
"VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "True").lower() in ("true", "1")
),
# Enable AITER fused kernels for MLA (ROCm only, prefill and decode).
# Fuses RoPE + concat + KV cache write (prefill) or BMM + RoPE +
# concat + KV cache write (decode) into one kernel.
# Enabled by default for AMD GPUs with FP8 support.
"VLLM_USE_AITER_FUSED": lambda: (
os.getenv("VLLM_USE_AITER_FUSED", "True").lower() in ("true", "1")
),
# AITER fused RoPE + KV cache write for prefill tokens.
# Enabled by default when VLLM_USE_AITER_FUSED is enabled.
"VLLM_USE_AITER_PREFILL_FUSED": lambda: (
os.getenv("VLLM_USE_AITER_PREFILL_FUSED", "True").lower() in ("true", "1")
),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")
Expand All @@ -1001,6 +1009,10 @@ def _get_or_set_default() -> str:
"VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
# Pad the weights for the moe kernel
"VLLM_ROCM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))),
# custom paged attention kernel for MI3* cards
"VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: (
os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")
),
# Whether to use the shuffled kv cache layout
"VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: (
os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1")
Expand Down Expand Up @@ -1305,9 +1317,14 @@ def _get_or_set_default() -> str:
["throughput", "latency", "masked_gemm"],
),
# Flashinfer fused allreduce backend.
# "auto" will default to "mnnvl", which performs about the same as or better
# than "trtllm". However, the "mnnvl" backend does not support fusion with
# quantization.
# TODO: The default is "trtllm" right now because "mnnvl" has issues with
# cudagraph: https://github.com/vllm-project/vllm/issues/35772
# Switch back to "auto" once the issue is resolved.
"VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices(
"VLLM_FLASHINFER_ALLREDUCE_BACKEND",
"auto",
"trtllm",
["auto", "trtllm", "mnnvl"],
),
# Control the workspace buffer size for the FlashInfer backend.
Expand Down Expand Up @@ -1640,14 +1657,6 @@ def _get_or_set_default() -> str:
"VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS": lambda: bool(
int(os.getenv("VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS", "0"))
),
# NIXL EP environment variables
"VLLM_NIXL_EP_MAX_NUM_RANKS": lambda: int(
os.getenv("VLLM_NIXL_EP_MAX_NUM_RANKS", "32")
),
# Whether enable XPU graph on Intel GPU
"VLLM_XPU_ENABLE_XPU_GRAPH": lambda: bool(
int(os.getenv("VLLM_XPU_ENABLE_XPU_GRAPH", "0"))
),
}


Expand Down Expand Up @@ -1784,7 +1793,6 @@ def compile_factors() -> dict[str, object]:
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
"VLLM_CPU_KVCACHE_SPACE",
"VLLM_CPU_MOE_PREPACK",
"VLLM_ZENTORCH_WEIGHT_PREPACK",
"VLLM_TEST_FORCE_LOAD_FORMAT",
"VLLM_ENABLE_CUDA_COMPATIBILITY",
"VLLM_CUDA_COMPATIBILITY_PATH",
Expand Down
Loading
Loading