Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vllm_spyre_next/vllm_spyre_next/custom_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from . import rms_norm
from . import silu_and_mul
from . import vocab_parallel_embedding
from . import linear
from vllm.logger import init_logger

logger = init_logger(__name__)
Expand All @@ -13,3 +14,4 @@ def register_all():
rms_norm.register()
silu_and_mul.register()
vocab_parallel_embedding.register()
linear.register()
166 changes: 166 additions & 0 deletions vllm_spyre_next/vllm_spyre_next/custom_ops/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Spyre-specific linear layer implementations using out-of-tree (OOT) registration.

This module provides Spyre-device-specific replacements for the parallel linear
layer classes used inside MLP blocks:

- SpyreMergedColumnParallelLinear — replaces MergedColumnParallelLinear
(vllm/model_executor/layers/linear.py)
- SpyreRowParallelLinear — replaces RowParallelLinear
(vllm/model_executor/layers/linear.py)

Since tensor_parallel=1 is assumed, both classes are functionally equivalent
to F.linear(input, weight, bias) and share the same implementation pattern.

Spyre Device Constraints:
- Computations performed in torch.float16:
Input (dtype defined by model / user) converted to torch.float16 for
operations on spyre and then converted back to original dtype for cpu.
- Tensor parallelism: TP=1 assumed (single Spyre device)

References:
- Upstream linear layers: vllm/model_executor/layers/linear.py
- Pattern reference: vllm_spyre_next/custom_ops/rms_norm.py
"""

import torch
import torch.nn.functional as F
from functools import lru_cache

from vllm.logger import init_logger
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
RowParallelLinear,
)

from .utils import convert, register_layer, get_layer, _fake_impl

logger = init_logger(__name__)


class SpyreLinearBase:
"""Shared implementation for Spyre linear layers at TP=1."""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for all other oot impl. (e.g rms, silu ... ) I see this line:
_dynamic_arg_dims = {"x": [], "residual": []}
is it not needed here? could it be removed in the other classes too? @bohnstingl

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, I see we do not have a residual here...
is not specifying anything the same as putting _dynamic_arg_dims = {"x": []} ?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it not needed here? could it be removed in the other classes too? @bohnstingl

No, it can't be removed and in fact we need it here as well. The _dynamic_arg_dims = {"x": [], "residual": []} ensures that maybe_compile compiles with dynamic=False. Here it should be _dynamic_arg_dims = {"x": [], "weight": [], "bias": []}, I think.

@R3hankhan123 could you please confirm that:

  • This path here is followed, i.e., make a breakpoint() there and check that it is triggered.
  • That no shape is marked as dynamic, i.e., this part is never reached.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The weight and bias tensors are internal to the layer implementation, they're accessed inside _forward_spyre_impl. Since they're not direct arguments to the custom op, i think they don't need to be in _dynamic_arg_dims. I think only {"x": [], "output": []}, are sufficient

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, I have to take my comment above back. MergedColumnParallelLinear and RowParallelLinear are PluggableLayer, not CustomOp. Thus, the compilation path is different and there is no maybe_compile. Probably we need to simply invoke torch.compile directly for now:

self.maybe_compiled_forward_spyre = torch.compile(self.forward_spyre, dynamic=False)

Please leave a note though that this should be changed in the future.

This means you also don't need to define _dynamic_arg_dims.

def _init_spyre_linear(self, layer_prefix: str):
"""Common initialization for Spyre linear layers."""
if self.tp_size > 1:
raise NotImplementedError(
f"{self.__class__.__name__} only supports TP=1, got TP={self.tp_size}"
)

logger.debug("Building custom %s", self.__class__.__name__)

self._target_device = torch.device("spyre")
self._target_dtype = torch.float16

# NOTE: Using torch.compile directly here since PluggableLayer (unlike CustomOp)
# does not provide a maybe_compile method. This should be revisited in the future
# to align with vLLM's compilation infrastructure once PluggableLayer supports
# compilation hooks similar to CustomOp.maybe_compile.
self.maybe_compiled_forward_spyre = torch.compile(self.forward_spyre, dynamic=False)
self._layer_name = register_layer(self, layer_prefix)

logger.warning_once(
"%s: no dtype promotion (torch-spyre limitation),"
"expect numerical differences to upstream vLLM.",
self.__class__.__name__,
)

def forward_spyre(
Comment thread
R3hankhan123 marked this conversation as resolved.
self,
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return F.linear(x, weight, bias)

def _forward_spyre_impl(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype
x_device = x.device

# Bias is fused into F.linear only when not skipping bias add
bias = self.bias.data if (self.bias is not None and not self.skip_bias_add) else None

out = self.maybe_compiled_forward_spyre(
convert(x, self._target_device, self._target_dtype),
convert(self.weight.data, self._target_device, self._target_dtype),
convert(bias, self._target_device, self._target_dtype) if bias is not None else None,
)

return convert(out, dtype=x_dtype, device=x_device)


@MergedColumnParallelLinear.register_oot(name="MergedColumnParallelLinear")
class SpyreMergedColumnParallelLinear(SpyreLinearBase, MergedColumnParallelLinear):
    """Spyre MergedColumnParallelLinear (TP=1 only)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._init_spyre_linear("spyre_merged_col_linear")

    # `MergedColumnParallelLinear` is a PluggableLayer and we register a class as OOT,
    # thus, the `forward` method is invoked when the OOT is triggered.
    def forward(self, input_: torch.Tensor):
        """Dispatch through the spyre_merged_col_linear custom op.

        Returns the output tensor, or ``(output, bias)`` when
        ``return_bias`` is set (bias is ``None`` unless ``skip_bias_add``).
        """
        # Preserve all leading (batch/sequence) dimensions, matching
        # SpyreRowParallelLinear; the previous `input_.shape[0]` form only
        # handled 2-D inputs.
        output = input_.new_empty(
            *input_.shape[:-1],
            self.output_size_per_partition,
        )
        torch.ops.vllm.spyre_merged_col_linear(input_, output, self._layer_name)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


@RowParallelLinear.register_oot(name="RowParallelLinear")
class SpyreRowParallelLinear(SpyreLinearBase, RowParallelLinear):
    """Spyre RowParallelLinear (TP=1 only)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._init_spyre_linear("spyre_row_parallel_linear")

    # `SpyreRowParallelLinear` is a PluggableLayer registered as an OOT class,
    # so this `forward` runs whenever the OOT path is taken.
    def forward(self, input_: torch.Tensor):
        """Dispatch through the spyre_row_parallel_linear custom op.

        Returns the output tensor, or ``(output, bias)`` when
        ``return_bias`` is set (bias is ``None`` unless ``skip_bias_add``).
        """
        out_shape = (*input_.shape[:-1], self.output_size_per_partition)
        output = input_.new_empty(out_shape)

        torch.ops.vllm.spyre_row_parallel_linear(input_, output, self._layer_name)

        if self.return_bias:
            return output, (self.bias if self.skip_bias_add else None)
        return output


def _make_spyre_linear_op_func(op_name: str):
def _op_func(
x: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
layer = get_layer(layer_name)
result = layer._forward_spyre_impl(x)
output.copy_(result)

_op_func.__name__ = f"_{op_name}_op_func"
return _op_func


@lru_cache(maxsize=1)
def register():
    """Register Spyre linear custom ops (lru_cache makes this idempotent)."""
    op_names = ("spyre_merged_col_linear", "spyre_row_parallel_linear")
    for name in op_names:
        direct_register_custom_op(
            op_name=name,
            op_func=_make_spyre_linear_op_func(name),
            mutates_args=["output"],
            fake_impl=_fake_impl,
        )
        logger.info("Registered custom op: %s", name)
Loading