vllm/compilation/backends.py (17 changes: 0 additions & 17 deletions)

@@ -927,23 +927,6 @@ def collect_standalone_compile_artifacts(
         return standalone_compile_artifacts, sym_shape_indices_map, returns_tuple_map

     def configure_post_pass(self) -> None:
-        # TODO proper PassManager?
-        pre_grad_pass_key = "pre_grad_custom_pass"
-        assert self.pass_key != pre_grad_pass_key
-        assert pre_grad_pass_key not in self.inductor_config
-        self.inductor_config[pre_grad_pass_key] = VllmIRInplaceFunctionalizationPass(
-            self.vllm_config
-        )
-
-        # Make sure pre_grad_custom_pass is not pickled
-        # as part of AOTAutograd built-in cache key
-        # TODO(luka) is there a cleaner way to do this
-        import torch._inductor.config as inductor_config
-
-        ignore = inductor_config._cache_config_ignore_prefix + [pre_grad_pass_key]
-        assert "_cache_config_ignore_prefix" not in self.inductor_config
-        self.inductor_config["_cache_config_ignore_prefix"] = ignore
-
         # Configure the (nominally post-grad) pass manager
         self.pass_manager.configure(self.vllm_config)
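
For context, the block removed above registered an unpicklable custom pass in the Inductor config dict and then excluded that key from the config hash that feeds the AOTAutograd cache key. A minimal sketch of that pattern, assuming a recent PyTorch where torch._inductor.config exposes _cache_config_ignore_prefix (install_uncached_pass is a hypothetical helper, not vllm API):

import torch._inductor.config as inductor_config

def install_uncached_pass(cfg: dict, key: str, pass_obj) -> None:
    # Store the (unpicklable) pass object under `key`...
    cfg[key] = pass_obj
    # ...and extend the ignore list so config hashing skips that key and
    # the pass object never needs to be pickled into the cache key.
    cfg["_cache_config_ignore_prefix"] = (
        list(inductor_config._cache_config_ignore_prefix) + [key]
    )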

vllm/compilation/passes/pass_manager.py (10 changes: 4 additions & 6 deletions)

@@ -20,6 +20,7 @@

 if rocm_aiter_ops.is_enabled():
     from .fusion.allreduce_rms_fusion import (
+        AllReduceFusionPass,
         RocmAiterAllReduceFusionPass,

    [Contributor review comment, severity: high]
    The import of RocmAiterAllReduceFusionPass is now unused because it has
    been replaced by AllReduceFusionPass in the configure method. It should
    be removed.

     )
     from .fusion.rocm_aiter_fusion import (

@@ -116,8 +117,6 @@ def __call__(self, graph: fx.Graph) -> None:
         # DCE handles mutating ops correctly as well.
         self.ir_lowering(graph)
         VllmInductorPass.dump_prefix += 1
-        self.clone_elimination(graph)
-        VllmInductorPass.dump_prefix += 1

         # clean up after lowering again
         self.post_cleanup(graph)

@@ -143,10 +142,9 @@ def configure(self, config: VllmConfig) -> None:
             self.passes += [AsyncTPPass(config)]

         if self.pass_config.fuse_allreduce_rms:
-            if rocm_aiter_ops.is_enabled():
-                self.passes += [RocmAiterAllReduceFusionPass(config)]
-            else:
-                self.passes += [AllReduceFusionPass(config)]
+            # RocmAiterAllReduceFusionPass is disabled: it corrupts HIP
+            # graph replay when sparse MLA attention (nhead=32) is active.
+            self.passes += [AllReduceFusionPass(config)]

         if self.pass_config.fuse_minimax_qk_norm:
             self.passes += [MiniMaxQKNormPass(config)]
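
To make the configure/apply flow above concrete, here is a small self-contained sketch; TinyPassManager and the stand-in pass function are illustrative names, not vllm classes. Flags on the pass config decide which passes get registered, and __call__ then runs them over the graph in order:

from dataclasses import dataclass, field
from typing import Any, Callable

def allreduce_rms_fusion_pass(graph: Any) -> None:
    # Stand-in for AllReduceFusionPass(config); a real pass rewrites the
    # fx.Graph to fuse the allreduce + rmsnorm sequence.
    print("fusing allreduce+rmsnorm")

@dataclass
class TinyPassManager:
    passes: list[Callable[[Any], None]] = field(default_factory=list)

    def configure(self, fuse_allreduce_rms: bool) -> None:
        # Flag-gated registration, mirroring the configure method above.
        if fuse_allreduce_rms:
            self.passes.append(allreduce_rms_fusion_pass)

    def __call__(self, graph: Any) -> None:
        for p in self.passes:
            p(graph)  # apply passes in registration order

pm = TinyPassManager()
pm.configure(fuse_allreduce_rms=True)
pm(graph=None)  # prints "fusing allreduce+rmsnorm"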

vllm/distributed/parallel_state.py (10 changes: 1 addition & 9 deletions)

@@ -472,7 +472,6 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None
         # only cuda uses this function,
         # so we don't abstract it into the base class
         maybe_ca_context = nullcontext()
-        maybe_aiter_context = nullcontext()
         from vllm.distributed.device_communicators.cuda_communicator import (
             CudaCommunicator,
         )

@@ -483,20 +482,13 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None
             if ca_comm is not None:
                 maybe_ca_context = ca_comm.capture()  # type: ignore

-            from vllm._aiter_ops import rocm_aiter_ops
-
-            if rocm_aiter_ops.is_enabled():
-                aiter_ar = rocm_aiter_ops.get_aiter_allreduce()
-                if aiter_ar is not None:
-                    maybe_aiter_context = aiter_ar.capture()  # type: ignore
-
         # ensure all initialization operations complete before attempting to
         # capture the graph on another stream
         curr_stream = torch.cuda.current_stream()
         if curr_stream != stream:
             stream.wait_stream(curr_stream)

-        with torch.cuda.stream(stream), maybe_ca_context, maybe_aiter_context:
+        with torch.cuda.stream(stream), maybe_ca_context:
            yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
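
The graph_capture change above leans on a pattern worth spelling out: optional capture contexts default to contextlib.nullcontext() so the final with-statement composes without branching. A runnable sketch under that assumption (DemoComm is hypothetical, not a vllm communicator class):

from contextlib import contextmanager, nullcontext

class DemoComm:
    @contextmanager
    def capture(self):
        print("communicator capture mode on")
        yield
        print("communicator capture mode off")

def graph_capture_demo(comm: DemoComm | None = None) -> None:
    # Fall back to a no-op context when no communicator is present.
    maybe_ctx = comm.capture() if comm is not None else nullcontext()
    with maybe_ctx:
        print("capturing graph")

graph_capture_demo()            # plain capture, no communicator hooks
graph_capture_demo(DemoComm())  # capture wrapped by the communicator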

vllm/model_executor/models/deepseek_v2.py (3 changes: 1 addition & 2 deletions)

@@ -351,9 +351,8 @@ def __init__(
             self.is_rocm_aiter_moe_enabled
             and self.gate.e_score_correction_bias is not None
         ):
-            gate_out_dtype = self.gate.out_dtype or self.gate.weight.dtype
             self.gate.e_score_correction_bias.data = (
-                self.gate.e_score_correction_bias.data.to(gate_out_dtype)
+                self.gate.e_score_correction_bias.data.to(self.gate.out_dtype)
             )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
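
This change drops the weight-dtype fallback and casts the correction bias straight to gate.out_dtype, relying on out_dtype always being set when AITER MoE is enabled. A standalone illustration of the cast itself, with stand-in tensor names:

import torch

bias = torch.zeros(4, dtype=torch.float32)  # e_score_correction_bias stand-in
out_dtype = torch.bfloat16                  # gate.out_dtype stand-in
bias = bias.to(out_dtype)                   # Tensor.to returns a cast copy
assert bias.dtype == torch.bfloat16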

vllm/v1/attention/backends/mla/rocm_aiter_mla.py (2 changes: 1 addition & 1 deletion)

@@ -396,7 +396,7 @@ class AiterMLAHelper:
     """

     _AITER_MIN_MLA_HEADS: Final = 16
-    _AITER_UNSUPPORTED_HEADS = [32]
+    _AITER_UNSUPPORTED_HEADS: list[int] = []

     @staticmethod
     def check_num_heads_validity(num_heads: int):
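
Emptying _AITER_UNSUPPORTED_HEADS means num_heads=32 now passes validation, consistent with disabling the RocmAiter allreduce fusion pass that corrupted HIP graph replay at that head count. A hypothetical sketch of how the helper's check could use these constants; the real check_num_heads_validity body is not part of this diff:

from typing import Final

class AiterMLAHelperSketch:
    _AITER_MIN_MLA_HEADS: Final = 16
    _AITER_UNSUPPORTED_HEADS: list[int] = []

    @classmethod
    def check_num_heads_validity(cls, num_heads: int) -> None:
        if num_heads < cls._AITER_MIN_MLA_HEADS:
            raise ValueError(
                f"AITER MLA requires at least {cls._AITER_MIN_MLA_HEADS} "
                f"heads, got {num_heads}"
            )
        if num_heads in cls._AITER_UNSUPPORTED_HEADS:
            raise ValueError(f"num_heads={num_heads} is unsupported by AITER MLA")

AiterMLAHelperSketch.check_num_heads_validity(32)  # valid after this change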