diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index eaed6e2265cd..f49a3fcbb941 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -9,6 +9,7 @@ from vllm.config.lora import LoRAConfig from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.utils import divide +from vllm.model_executor.custom_op import maybe_get_oot_by_class from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, @@ -155,9 +156,9 @@ def can_replace_layer( packed_modules_list: list, model_config: PretrainedConfig | None = None, ) -> bool: - if type(source_layer) is ColumnParallelLinear: + if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear): return True - if type(source_layer) is MergedColumnParallelLinear: + if type(source_layer) is maybe_get_oot_by_class(MergedColumnParallelLinear): if len(packed_modules_list) != 1: return False # Exclude layers with 3+ output sizes - those are handled by @@ -606,7 +607,7 @@ def can_replace_layer( ) -> bool: # Support MergedColumnParallelLinear with 3 or more slices # (2 slices are handled by MergedColumnParallelLinearWithLoRA) - if type(source_layer) is not MergedColumnParallelLinear: + if type(source_layer) is not maybe_get_oot_by_class(MergedColumnParallelLinear): return False # If packed_modules_list has 3+ items, use this class diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py index 62bac546ccd1..f1f499b841ba 100644 --- a/vllm/lora/layers/replicated_linear.py +++ b/vllm/lora/layers/replicated_linear.py @@ -7,6 +7,7 @@ from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig +from vllm.model_executor.custom_op import maybe_get_oot_by_class from vllm.model_executor.layers.linear import ReplicatedLinear from .base_linear import BaseLinearLayerWithLoRA @@ -55,7 +56,7 @@ def can_replace_layer( packed_modules_list: list, model_config: PretrainedConfig | None = None, ) -> bool: - return type(source_layer) is ReplicatedLinear + return type(source_layer) is maybe_get_oot_by_class(ReplicatedLinear) def slice_lora_a( self, lora_a: torch.Tensor | list[torch.Tensor | None] diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py index 8de5822db4d1..9460b687f1af 100644 --- a/vllm/lora/layers/row_parallel_linear.py +++ b/vllm/lora/layers/row_parallel_linear.py @@ -11,6 +11,7 @@ split_tensor_along_last_dim, tensor_model_parallel_all_reduce, ) +from vllm.model_executor.custom_op import maybe_get_oot_by_class from vllm.model_executor.layers.linear import RowParallelLinear from vllm.platforms import current_platform @@ -89,7 +90,7 @@ def can_replace_layer( packed_modules_list: list, model_config: PretrainedConfig | None = None, ) -> bool: - return type(source_layer) is RowParallelLinear + return type(source_layer) is maybe_get_oot_by_class(RowParallelLinear) # The following layer is based on the tensor parallelism strategy given in diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py index efc5a1771514..05e7cfa06c85 100644 --- a/vllm/lora/layers/vocal_parallel_embedding.py +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -7,6 +7,7 @@ from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig +from vllm.model_executor.custom_op import maybe_get_oot_by_class from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.platforms import current_platform @@ -132,7 +133,7 @@ def can_replace_layer( packed_modules_list: list, model_config: PretrainedConfig | None = None, ) -> bool: - return type(source_layer) is VocabParallelEmbedding + return type(source_layer) is maybe_get_oot_by_class(VocabParallelEmbedding) @property def weight(self): diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index b8e372e88e6f..a1514c9206be 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -22,10 +22,11 @@ op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} -def get_oot_class_by_name(class_name: str) -> type | None: +def maybe_get_oot_by_class(class_type: type) -> type: + class_name = class_type.__name__ if class_name in op_registry_oot: return op_registry_oot[class_name] - return None + return class_type class PluggableLayer(nn.Module): diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index bc0687ed2701..46d461c38b3f 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -6,7 +6,7 @@ import torch from vllm.logger import init_logger -from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name +from vllm.model_executor.custom_op import CustomOp, maybe_get_oot_by_class from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.utils.math_utils import round_up from vllm.v1.attention.backends.fa_utils import get_flash_attn_version @@ -125,7 +125,7 @@ def maybe_compute_seq_lens( cu_seqlens: np.ndarray, device: torch.device, ) -> torch.Tensor | None: - if (oot_class := get_oot_class_by_name(cls.__name__)) is not None: + if (oot_class := maybe_get_oot_by_class(cls)) is not cls: return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device) # type: ignore[attr-defined] if attn_backend != AttentionBackendEnum.FLASHINFER: @@ -149,7 +149,7 @@ def maybe_recompute_cu_seqlens( tp_size: int, device: torch.device, ) -> torch.Tensor: - if (oot_class := get_oot_class_by_name(cls.__name__)) is not None: + if (oot_class := maybe_get_oot_by_class(cls)) is not cls: return oot_class.maybe_recompute_cu_seqlens( # type: ignore[attr-defined] attn_backend, cu_seqlens, hidden_size, tp_size, device )