[Spyre-Next] Add RowParallelLinear and ColumnParallelLinear(MLP) wrappers #869
Merged · +168 −0
@@ -0,0 +1,166 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Spyre-specific linear layer implementations using out-of-tree (OOT) registration.

This module provides Spyre-device-specific replacements for the parallel linear
layer classes used inside MLP blocks:

- SpyreMergedColumnParallelLinear — replaces MergedColumnParallelLinear
  (vllm/model_executor/layers/linear.py)
- SpyreRowParallelLinear — replaces RowParallelLinear
  (vllm/model_executor/layers/linear.py)

Since tensor_parallel=1 is assumed, both classes are functionally equivalent
to F.linear(input, weight, bias) and share the same implementation pattern.

Spyre Device Constraints:
- Computations are performed in torch.float16: the input (whose dtype is
  defined by the model/user) is converted to torch.float16 for operations on
  Spyre, then converted back to the original dtype on CPU.
- Tensor parallelism: TP=1 is assumed (single Spyre device).

References:
- Upstream linear layers: vllm/model_executor/layers/linear.py
- Pattern reference: vllm_spyre_next/custom_ops/rms_norm.py
"""

from functools import lru_cache

import torch
import torch.nn.functional as F

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    RowParallelLinear,
)
from vllm.utils.torch_utils import direct_register_custom_op

from .utils import _fake_impl, convert, get_layer, register_layer

logger = init_logger(__name__)


class SpyreLinearBase:
    """Shared implementation for Spyre linear layers at TP=1."""

    def _init_spyre_linear(self, layer_prefix: str):
        """Common initialization for Spyre linear layers."""
        if self.tp_size > 1:
            raise NotImplementedError(
                f"{self.__class__.__name__} only supports TP=1, got TP={self.tp_size}"
            )

        logger.debug("Building custom %s", self.__class__.__name__)

        self._target_device = torch.device("spyre")
        self._target_dtype = torch.float16

        # NOTE: Using torch.compile directly here since PluggableLayer (unlike
        # CustomOp) does not provide a maybe_compile method. This should be
        # revisited in the future to align with vLLM's compilation
        # infrastructure once PluggableLayer supports compilation hooks
        # similar to CustomOp.maybe_compile.
        self.maybe_compiled_forward_spyre = torch.compile(
            self.forward_spyre, dynamic=False
        )
        self._layer_name = register_layer(self, layer_prefix)

        logger.warning_once(
            "%s: no dtype promotion (torch-spyre limitation); "
            "expect numerical differences to upstream vLLM.",
            self.__class__.__name__,
        )

    def forward_spyre(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return F.linear(x, weight, bias)

    def _forward_spyre_impl(self, x: torch.Tensor) -> torch.Tensor:
        x_dtype = x.dtype
        x_device = x.device

        # Bias is fused into F.linear only when not skipping the bias add.
        bias = self.bias.data if (self.bias is not None and not self.skip_bias_add) else None

        out = self.maybe_compiled_forward_spyre(
            convert(x, self._target_device, self._target_dtype),
            convert(self.weight.data, self._target_device, self._target_dtype),
            convert(bias, self._target_device, self._target_dtype) if bias is not None else None,
        )

        return convert(out, dtype=x_dtype, device=x_device)


@MergedColumnParallelLinear.register_oot(name="MergedColumnParallelLinear")
class SpyreMergedColumnParallelLinear(SpyreLinearBase, MergedColumnParallelLinear):
    """Spyre MergedColumnParallelLinear (TP=1 only)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._init_spyre_linear("spyre_merged_col_linear")

    # `MergedColumnParallelLinear` is a PluggableLayer and we register a class
    # as OOT; thus, this `forward` method is invoked when the OOT path is
    # triggered.
    def forward(self, input_: torch.Tensor):
        output = input_.new_empty(
            input_.shape[0],
            self.output_size_per_partition,
        )
        torch.ops.vllm.spyre_merged_col_linear(input_, output, self._layer_name)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


@RowParallelLinear.register_oot(name="RowParallelLinear")
class SpyreRowParallelLinear(SpyreLinearBase, RowParallelLinear):
    """Spyre RowParallelLinear (TP=1 only)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._init_spyre_linear("spyre_row_parallel_linear")

    # `RowParallelLinear` is a PluggableLayer and we register a class as OOT;
    # thus, this `forward` method is invoked when the OOT path is triggered.
    def forward(self, input_: torch.Tensor):
        output = input_.new_empty(
            *input_.shape[:-1],
            self.output_size_per_partition,
        )
        torch.ops.vllm.spyre_row_parallel_linear(input_, output, self._layer_name)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


def _make_spyre_linear_op_func(op_name: str):
    def _op_func(
        x: torch.Tensor,
        output: torch.Tensor,
        layer_name: str,
    ) -> None:
        layer = get_layer(layer_name)
        result = layer._forward_spyre_impl(x)
        output.copy_(result)

    _op_func.__name__ = f"_{op_name}_op_func"
    return _op_func


@lru_cache(maxsize=1)
def register():
    """Register Spyre linear custom ops."""
    for op_name in ["spyre_merged_col_linear", "spyre_row_parallel_linear"]:
        direct_register_custom_op(
            op_name=op_name,
            op_func=_make_spyre_linear_op_func(op_name),
            mutates_args=["output"],
            fake_impl=_fake_impl,
        )
        logger.info("Registered custom op: %s", op_name)
```
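The `convert` helper is imported from `.utils` and its body is not part of this diff. The round-trip it is assumed to perform (cast to the target device/dtype, compute, cast back to the caller's dtype) can be sketched in plain PyTorch; here `float32` stands in for `float16` and CPU stands in for `torch.device("spyre")`, and the `convert` signature is a guess at the one used above:

```python
import torch
import torch.nn.functional as F


def convert(t, device=None, dtype=None):
    # Hypothetical stand-in for .utils.convert: move/cast only when requested.
    if t is None:
        return None
    return t.to(
        device=device if device is not None else t.device,
        dtype=dtype if dtype is not None else t.dtype,
    )


# Simulate the Spyre round-trip: float64 input, compute in float32
# (standing in for float16), return in the caller's original dtype.
x = torch.randn(4, 8, dtype=torch.float64)
w = torch.randn(16, 8, dtype=torch.float64)
target_device = torch.device("cpu")  # stand-in for torch.device("spyre")
target_dtype = torch.float32

out = F.linear(
    convert(x, target_device, target_dtype),
    convert(w, target_device, target_dtype),
)
out = convert(out, device=x.device, dtype=x.dtype)

assert out.dtype == torch.float64 and out.shape == (4, 16)
```

Because the compute dtype is narrower than the input dtype, the result matches upstream vLLM only up to that dtype's precision, which is exactly the numerical difference the `warning_once` in `_init_spyre_linear` flags.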
**Review comment:** For all other OOT impls (e.g. rms, silu, ...) I see this line: `_dynamic_arg_dims = {"x": [], "residual": []}`. Is it not needed here? Could it be removed in the other classes too? @bohnstingl

**Follow-up:** Okay, I see we do not have a residual here. Is not specifying anything the same as putting `_dynamic_arg_dims = {"x": []}`?

**Reply:** No, it can't be removed, and in fact we need it here as well. The `_dynamic_arg_dims = {"x": [], "residual": []}` ensures that `maybe_compile` compiles with `dynamic=False`. Here it should be `_dynamic_arg_dims = {"x": [], "weight": [], "bias": []}`, I think. @R3hankhan123 could you please confirm that?

**Reply:** The weight and bias tensors are internal to the layer implementation; they're accessed inside `_forward_spyre_impl`. Since they're not direct arguments to the custom op, I think they don't need to be in `_dynamic_arg_dims`. I think only `{"x": [], "output": []}` are sufficient.

**Reply:** On second thought, I have to take my comment above back. `MergedColumnParallelLinear` and `RowParallelLinear` are `PluggableLayer`, not `CustomOp`. Thus, the compilation path is different and there is no `maybe_compile`. Probably we need to simply invoke `torch.compile` directly for the moment. Please leave a note though that this should be changed in the future. This also means you don't need to define `_dynamic_arg_dims`.
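The thread's conclusion is that, without a `maybe_compile` hook on `PluggableLayer`, the wrapper calls `torch.compile` directly with `dynamic=False` (shape specialization) instead of declaring dynamic dimensions via `_dynamic_arg_dims`. A minimal sketch of that call, outside vLLM; `backend="eager"` is used here only to keep the sketch free of a compiler toolchain, while the real module uses the default backend:

```python
import torch
import torch.nn.functional as F


def forward_spyre(x, weight, bias=None):
    # Same body as SpyreLinearBase.forward_spyre in the diff above.
    return F.linear(x, weight, bias)


# dynamic=False asks Dynamo to specialize on the observed tensor shapes,
# the role _dynamic_arg_dims plays for CustomOp.maybe_compile.
compiled = torch.compile(forward_spyre, dynamic=False, backend="eager")

x = torch.randn(2, 8)
w = torch.randn(4, 8)
out = compiled(x, w)

# The eager backend runs the same ops, so results match exactly.
assert torch.equal(out, F.linear(x, w))
```

A shape change on a later call (e.g. a different batch size) triggers a recompile under `dynamic=False`, which is acceptable here because the Spyre path runs with fixed, pre-padded shapes.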