-
-
Notifications
You must be signed in to change notification settings - Fork 11.6k
[Hardware][AMD] integrate aiter into vllm #17710
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
d483fc2
4f85566
ae85e79
87ea0ba
db4bc55
efe59bd
40654e4
3ff8565
dcbcd68
bc2afe5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,6 +81,7 @@ | |
| VLLM_ROCM_USE_AITER_MOE: bool = True | ||
| VLLM_ROCM_USE_AITER_RMSNORM: bool = True | ||
| VLLM_ROCM_USE_AITER_MLA: bool = True | ||
| VLLM_ROCM_USE_AITER_MHA: bool = True | ||
| VLLM_ROCM_USE_SKINNY_GEMM: bool = True | ||
| VLLM_ROCM_FP8_PADDING: bool = True | ||
| VLLM_ROCM_MOE_PADDING: bool = True | ||
|
|
@@ -581,6 +582,13 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: | |
| "VLLM_ROCM_USE_AITER_MLA": | ||
| lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in | ||
| ("true", "1")), | ||
|
|
||
| # Whether to use aiter mha ops. | ||
| # By default is enabled. | ||
| "VLLM_ROCM_USE_AITER_MHA": | ||
| lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we want to override #16828 by default? |
||
| ("true", "1")), | ||
|
|
||
| # use rocm skinny gemms | ||
| "VLLM_ROCM_USE_SKINNY_GEMM": | ||
| lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
| import vllm.envs as envs | ||
| from vllm.model_executor.custom_op import CustomOp | ||
| from vllm.platforms import current_platform | ||
| from vllm.utils import direct_register_custom_op | ||
|
|
||
|
|
||
| def is_rocm_aiter_rmsnorm_enabled() -> bool: | ||
|
|
@@ -42,36 +43,70 @@ def fused_add_rms_norm( | |
| return x, residual | ||
|
|
||
|
|
||
| def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, | ||
| variance_epsilon: float) -> torch.Tensor: | ||
| if is_rocm_aiter_rmsnorm_enabled(): | ||
|
|
||
| import aiter as rocm_aiter | ||
| if x.dim() > 2: | ||
| x_original_shape = x.shape | ||
| x = x.reshape(-1, x_original_shape[-1]) | ||
| x = rocm_aiter.rms_norm(x, weight, variance_epsilon) | ||
| return x.reshape(x_original_shape) | ||
| def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor, | ||
| variance_epsilon: float) -> torch.Tensor: | ||
|
|
||
| return rocm_aiter.rms_norm(x, weight, variance_epsilon) | ||
| import aiter as rocm_aiter | ||
| if x.dim() > 2: | ||
| x_original_shape = x.shape | ||
| x = x.reshape(-1, x_original_shape[-1]) | ||
| x = rocm_aiter.rms_norm(x, weight, variance_epsilon) | ||
| return x.reshape(x_original_shape) | ||
|
|
||
| return rocm_aiter.rms_norm(x, weight, variance_epsilon) | ||
|
|
||
| def rocm_aiter_fused_add_rms_norm( | ||
| x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, | ||
| variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: | ||
| def rocm_aiter_rms_norm_fake(input: torch.Tensor, weight: torch.Tensor, | ||
| variance_epsilon: float) -> torch.Tensor: | ||
| return input.clone() | ||
|
|
||
| import aiter as rocm_aiter | ||
| try: | ||
| direct_register_custom_op( | ||
| op_name="rocm_aiter_rms_norm", | ||
| op_func=rocm_aiter_rms_norm_impl, | ||
| mutates_args=[], | ||
| fake_impl=rocm_aiter_rms_norm_fake, | ||
| ) | ||
| rocm_aiter_rms_norm = torch.ops.vllm.rocm_aiter_rms_norm | ||
|
|
||
| except AttributeError: | ||
|
||
| rocm_aiter_rms_norm = rocm_aiter_rms_norm_impl | ||
|
|
||
| def rocm_aiter_fused_add_rms_norm_impl( | ||
| x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, | ||
| variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: | ||
|
|
||
| import aiter as rocm_aiter | ||
| residual_out = torch.empty_like(residual) | ||
| output = torch.empty_like(x) | ||
| rocm_aiter.rmsnorm2d_fwd_with_add( | ||
| output, # output | ||
| x, # input | ||
| residual, # residual input | ||
| residual_out, # residual output | ||
| weight, | ||
| variance_epsilon, | ||
| ) | ||
| return output, residual_out | ||
|
|
||
| def rocm_aiter_fused_add_rms_norm_fake( | ||
| x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, | ||
| variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: | ||
| return x.clone(), residual.clone() | ||
|
|
||
| try: | ||
| direct_register_custom_op( | ||
| op_name="rocm_aiter_fused_add_rms_norm", | ||
| op_func=rocm_aiter_fused_add_rms_norm_impl, | ||
| mutates_args=[], | ||
| fake_impl=rocm_aiter_fused_add_rms_norm_fake, | ||
| ) | ||
| rocm_aiter_fused_add_rms_norm = \ | ||
| torch.ops.vllm.rocm_aiter_fused_add_rms_norm | ||
|
|
||
| residual_out = torch.empty_like(residual) | ||
| output = torch.empty_like(x) | ||
| rocm_aiter.rmsnorm2d_fwd_with_add( | ||
| output, # output | ||
| x, # input | ||
| residual, # residual input | ||
| residual_out, # residual output | ||
| weight, | ||
| variance_epsilon, | ||
| ) | ||
| return output, residual_out | ||
| except AttributeError: | ||
| rocm_aiter_fused_add_rms_norm = rocm_aiter_fused_add_rms_norm_impl | ||
|
|
||
|
|
||
| def dispatch_cuda_rmsnorm_func(add_residual: bool): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wondering what's the difference between VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_MHA?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Main switch and submodule switch.