7 changes: 4 additions & 3 deletions vllm/lora/layers/column_parallel_linear.py
@@ -9,6 +9,7 @@
 from vllm.config.lora import LoRAConfig
 from vllm.distributed import tensor_model_parallel_all_gather
 from vllm.distributed.utils import divide
+from vllm.model_executor.custom_op import maybe_get_oot_by_class
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
@@ -155,9 +156,9 @@ def can_replace_layer(
         packed_modules_list: list,
         model_config: PretrainedConfig | None = None,
     ) -> bool:
-        if type(source_layer) is ColumnParallelLinear:
+        if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear):
             return True
-        if type(source_layer) is MergedColumnParallelLinear:
+        if type(source_layer) is maybe_get_oot_by_class(MergedColumnParallelLinear):
             if len(packed_modules_list) != 1:
                 return False
             # Exclude layers with 3+ output sizes - those are handled by
@@ -606,7 +607,7 @@ def can_replace_layer(
     ) -> bool:
         # Support MergedColumnParallelLinear with 3 or more slices
         # (2 slices are handled by MergedColumnParallelLinearWithLoRA)
-        if type(source_layer) is not MergedColumnParallelLinear:
+        if type(source_layer) is not maybe_get_oot_by_class(MergedColumnParallelLinear):
             return False

         # If packed_modules_list has 3+ items, use this class
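
Note: the changed checks compare `type(source_layer)` against the registry-resolved class rather than the in-tree class directly, so layers instantiated from an out-of-tree (OOT) replacement still match in `can_replace_layer`. A minimal sketch of the effect, using a hypothetical plugin class and a direct write to `op_registry_oot` for illustration (a real plugin would register through the plugin machinery):

    from vllm.model_executor.custom_op import maybe_get_oot_by_class, op_registry_oot
    from vllm.model_executor.layers.linear import ColumnParallelLinear

    # Hypothetical out-of-tree replacement, registered under the base class name.
    class MyColumnParallelLinear(ColumnParallelLinear):
        pass

    op_registry_oot["ColumnParallelLinear"] = MyColumnParallelLinear

    # With the plugin active, model layers are MyColumnParallelLinear instances:
    # the old check `type(source_layer) is ColumnParallelLinear` would fail,
    # while the new check resolves to the registered override and matches.
    assert maybe_get_oot_by_class(ColumnParallelLinear) is MyColumnParallelLinear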
3 changes: 2 additions & 1 deletion vllm/lora/layers/replicated_linear.py
@@ -7,6 +7,7 @@
 from transformers import PretrainedConfig

 from vllm.config.lora import LoRAConfig
+from vllm.model_executor.custom_op import maybe_get_oot_by_class
 from vllm.model_executor.layers.linear import ReplicatedLinear

 from .base_linear import BaseLinearLayerWithLoRA
@@ -55,7 +56,7 @@ def can_replace_layer(
         packed_modules_list: list,
         model_config: PretrainedConfig | None = None,
     ) -> bool:
-        return type(source_layer) is ReplicatedLinear
+        return type(source_layer) is maybe_get_oot_by_class(ReplicatedLinear)

     def slice_lora_a(
         self, lora_a: torch.Tensor | list[torch.Tensor | None]
3 changes: 2 additions & 1 deletion vllm/lora/layers/row_parallel_linear.py
@@ -11,6 +11,7 @@
     split_tensor_along_last_dim,
     tensor_model_parallel_all_reduce,
 )
+from vllm.model_executor.custom_op import maybe_get_oot_by_class
 from vllm.model_executor.layers.linear import RowParallelLinear
 from vllm.platforms import current_platform

@@ -89,7 +90,7 @@ def can_replace_layer(
         packed_modules_list: list,
         model_config: PretrainedConfig | None = None,
     ) -> bool:
-        return type(source_layer) is RowParallelLinear
+        return type(source_layer) is maybe_get_oot_by_class(RowParallelLinear)


 # The following layer is based on the tensor parallelism strategy given in
3 changes: 2 additions & 1 deletion vllm/lora/layers/vocal_parallel_embedding.py
@@ -7,6 +7,7 @@
 from transformers import PretrainedConfig

 from vllm.config.lora import LoRAConfig
+from vllm.model_executor.custom_op import maybe_get_oot_by_class
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.platforms import current_platform

@@ -132,7 +133,7 @@ def can_replace_layer(
         packed_modules_list: list,
         model_config: PretrainedConfig | None = None,
     ) -> bool:
-        return type(source_layer) is VocabParallelEmbedding
+        return type(source_layer) is maybe_get_oot_by_class(VocabParallelEmbedding)

     @property
     def weight(self):
5 changes: 3 additions & 2 deletions vllm/model_executor/custom_op.py
@@ -22,10 +22,11 @@
 op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}


-def get_oot_class_by_name(class_name: str) -> type | None:
+def maybe_get_oot_by_class(class_type: type) -> type:
+    class_name = class_type.__name__
     if class_name in op_registry_oot:
         return op_registry_oot[class_name]
-    return None
+    return class_type


 class PluggableLayer(nn.Module):
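
Note: the helper is renamed from `get_oot_class_by_name` to `maybe_get_oot_by_class`, and its contract changes: it now takes the class itself and falls back to returning that class instead of `None`, which lets callers use it inline in identity checks. A standalone sketch of the new contract (mirroring, not importing, the registry above):

    op_registry_oot: dict[str, type] = {}

    def maybe_get_oot_by_class(class_type: type) -> type:
        # Return the registered out-of-tree override for this class name,
        # or the class itself when no override is registered.
        return op_registry_oot.get(class_type.__name__, class_type)

    class Base: ...
    class OOTBase(Base): ...

    assert maybe_get_oot_by_class(Base) is Base        # no override registered
    op_registry_oot["Base"] = OOTBase                  # plugin registers override
    assert maybe_get_oot_by_class(Base) is OOTBase     # lookup resolves to override
    assert maybe_get_oot_by_class(OOTBase) is OOTBase  # different __name__, falls back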
@@ -6,7 +6,7 @@
 import torch

 from vllm.logger import init_logger
-from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name
+from vllm.model_executor.custom_op import CustomOp, maybe_get_oot_by_class
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.utils.math_utils import round_up
 from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
@@ -125,7 +125,7 @@ def maybe_compute_seq_lens(
         cu_seqlens: np.ndarray,
         device: torch.device,
     ) -> torch.Tensor | None:
-        if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
+        if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
             return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device)  # type: ignore[attr-defined]

         if attn_backend != AttentionBackendEnum.FLASHINFER:
@@ -149,7 +149,7 @@ def maybe_recompute_cu_seqlens(
         tp_size: int,
         device: torch.device,
     ) -> torch.Tensor:
-        if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
+        if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
             return oot_class.maybe_recompute_cu_seqlens(  # type: ignore[attr-defined]
                 attn_backend, cu_seqlens, hidden_size, tp_size, device
             )
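
Note: in these classmethods the `is not None` guard becomes `is not cls`: delegation to the OOT class happens only when the registry resolves to a class other than the one the method was invoked on, so an override that inherits the classmethod falls through to the shared implementation instead of delegating to itself. A sketch with hypothetical names (direct registry write for illustration only):

    from vllm.model_executor.custom_op import maybe_get_oot_by_class, op_registry_oot

    class VisionBlock:
        @classmethod
        def compute(cls, x: int) -> int:
            # Delegate once if an OOT override is registered for `cls`.
            if (oot_class := maybe_get_oot_by_class(cls)) is not cls:
                return oot_class.compute(x)
            return x + 1  # shared base implementation

    class OOTVisionBlock(VisionBlock):
        pass  # inherits compute(); the `is not cls` guard stops self-delegation

    op_registry_oot["VisionBlock"] = OOTVisionBlock
    assert VisionBlock.compute(1) == 2  # delegates to OOTVisionBlock, runs base code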