From 2dfa2148bc65b1fa52c3bad6b6e02832ea67029e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 30 Jul 2025 07:42:59 +0000 Subject: [PATCH 1/5] [Misc] Use config definitions from Transformers library Signed-off-by: DarkLight1337 --- vllm/model_executor/models/aimv2.py | 3 +++ vllm/model_executor/models/dbrx.py | 14 +++++++------- vllm/model_executor/models/deepseek_v2.py | 14 ++++++++------ vllm/model_executor/models/dots1.py | 8 ++++---- vllm/model_executor/models/minimax_text_01.py | 6 +++--- vllm/model_executor/models/mpt.py | 8 ++++---- vllm/model_executor/models/olmoe.py | 4 ++-- vllm/model_executor/models/qwen2_moe.py | 6 +++--- vllm/model_executor/models/qwen3_moe.py | 6 +++--- 9 files changed, 37 insertions(+), 32 deletions(-) diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py index d2307bb464bd..30e02fed4fe7 100644 --- a/vllm/model_executor/models/aimv2.py +++ b/vllm/model_executor/models/aimv2.py @@ -23,6 +23,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader +# NOTE: The Aimv2Config used here is defined by Ovis +# (https://huggingface.co/AIDC-AI/Ovis2-1B/tree/main) +# It is different from the one inside Transformers library class AIMv2SwiGLUFFN(nn.Module): def __init__(self, config: PretrainedConfig, diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 360c7e66bf5c..e74d90e0b1d7 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from transformers import PretrainedConfig +from transformers import DbrxConfig from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig @@ -39,7 +39,7 @@ class DbrxRouter(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DbrxConfig, params_dtype: Optional[torch.dtype] = None, ): super().__init__() @@ -63,7 +63,7 @@ class DbrxExperts(FusedMoE): def __init__( self, - config: PretrainedConfig, + config: DbrxConfig, quant_config: Optional[QuantizationConfig] = None, params_dtype: Optional[torch.dtype] = None, prefix: str = "", @@ -138,7 +138,7 @@ class DbrxMoE(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DbrxConfig, quant_config: Optional[QuantizationConfig] = None, params_dtype: Optional[torch.dtype] = None, prefix: str = "", @@ -169,7 +169,7 @@ class DbrxAttention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -249,7 +249,7 @@ class DbrxFusedNormAttention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -284,7 +284,7 @@ class DbrxBlock(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 79ddd3d0f627..7be6574c323a 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -29,7 +29,7 @@ import torch from torch import nn -from transformers import PretrainedConfig +from transformers import DeepseekV3Config from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile 
@@ -100,7 +100,7 @@ class DeepseekV2MoE(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DeepseekV3Config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", enable_eplb: bool = False, @@ -221,7 +221,7 @@ class DeepseekV2Attention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DeepseekV3Config, hidden_size: int, num_heads: int, qk_nope_head_dim: int, @@ -373,7 +373,7 @@ class DeepseekV2MLAAttention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DeepseekV3Config, hidden_size: int, num_heads: int, qk_nope_head_dim: int, @@ -538,7 +538,7 @@ class DeepseekV2DecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: DeepseekV3Config, prefix: str, model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, @@ -971,7 +971,9 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass -def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, +# Compatibility with +# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py +def get_spec_layer_idx_from_weight_name(config: DeepseekV3Config, weight_name: str) -> Optional[int]: if (hasattr(config, "num_nextn_predict_layers") and config.num_nextn_predict_layers > 0): diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py index 9b21a7944613..5f410c0ae5fb 100644 --- a/vllm/model_executor/models/dots1.py +++ b/vllm/model_executor/models/dots1.py @@ -29,7 +29,7 @@ import torch from torch import nn -from transformers import PretrainedConfig +from transformers import Dots1Config from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -99,7 +99,7 @@ class Dots1MoE(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Dots1Config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -174,7 +174,7 @@ def __init__( hidden_size: int, num_heads: int, num_kv_heads: int, - config: PretrainedConfig, + config: Dots1Config, rope_theta: float = 10000, rope_scaling: Optional[dict[str, Any]] = None, max_position_embeddings: int = 8192, @@ -260,7 +260,7 @@ class Dots1DecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Dots1Config, prefix: str, model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index f2773af490c5..5b7246453b1a 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -12,7 +12,7 @@ import torch.nn.functional as F from einops import rearrange from torch import nn -from transformers.configuration_utils import PretrainedConfig +from transformers import MiniMaxConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, VllmConfig @@ -585,7 +585,7 @@ class MiniMaxText01DecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: MiniMaxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, expert_num: int = 1, @@ -788,7 +788,7 @@ class MiniMaxText01Model(nn.Module): def __init__( self, - config: PretrainedConfig, + config: MiniMaxConfig, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, scheduler_config=None, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index c243f575ae54..c9d2db4052c6 100644 
--- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn -from transformers import PretrainedConfig +from transformers import MptConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -50,7 +50,7 @@ class MPTAttention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: MptConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -144,7 +144,7 @@ class MPTMLP(nn.Module): def __init__( self, - config: PretrainedConfig, + config: MptConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -176,7 +176,7 @@ class MPTBlock(nn.Module): def __init__( self, - config: PretrainedConfig, + config: MptConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 7552f64c423e..a47c3bd41645 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -19,7 +19,7 @@ import torch from torch import nn -from transformers import PretrainedConfig +from transformers import OlmoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -205,7 +205,7 @@ class OlmoeDecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: OlmoeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index b061e2f69a6c..5c4ad34246d6 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -30,7 +30,7 @@ import torch import torch.nn.functional as F from torch import nn -from transformers import PretrainedConfig +from transformers import Qwen2MoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -98,7 +98,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Qwen2MoeConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -256,7 +256,7 @@ class Qwen2MoeDecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Qwen2MoeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 12899c28016b..c4beb076daf5 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -27,7 +27,7 @@ import torch from torch import nn -from transformers import PretrainedConfig +from transformers import Qwen3MoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -98,7 +98,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Qwen3MoeConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -242,7 +242,7 @@ class Qwen3MoeDecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Qwen3MoeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", From 235bffb54991dd9c6b59c45ba8562b94bcf81266 Mon Sep 17 
00:00:00 2001 From: DarkLight1337 Date: Thu, 7 Aug 2025 06:12:58 +0000 Subject: [PATCH 2/5] Use V2 config Signed-off-by: DarkLight1337 --- vllm/model_executor/models/deepseek_v2.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index ade3e8689622..cbe65eaee488 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -29,7 +29,7 @@ import torch from torch import nn -from transformers import DeepseekV3Config +from transformers import DeepseekV2Config from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -100,7 +100,7 @@ class DeepseekV2MoE(nn.Module): def __init__( self, - config: DeepseekV3Config, + config: DeepseekV2Config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", enable_eplb: bool = False, @@ -221,7 +221,7 @@ class DeepseekV2Attention(nn.Module): def __init__( self, - config: DeepseekV3Config, + config: DeepseekV2Config, hidden_size: int, num_heads: int, qk_nope_head_dim: int, @@ -373,7 +373,7 @@ class DeepseekV2MLAAttention(nn.Module): def __init__( self, - config: DeepseekV3Config, + config: DeepseekV2Config, hidden_size: int, num_heads: int, qk_nope_head_dim: int, @@ -538,7 +538,7 @@ class DeepseekV2DecoderLayer(nn.Module): def __init__( self, - config: DeepseekV3Config, + config: DeepseekV2Config, prefix: str, model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, @@ -959,7 +959,7 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): # Compatibility with # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py -def get_spec_layer_idx_from_weight_name(config: DeepseekV3Config, +def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config, weight_name: str) -> Optional[int]: if (hasattr(config, "num_nextn_predict_layers") and config.num_nextn_predict_layers > 0): From c7df154b1219cda0fc276e418a78e72970bf3485 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 Aug 2025 06:14:40 +0000 Subject: [PATCH 3/5] Check both Signed-off-by: DarkLight1337 --- vllm/model_executor/models/deepseek_v2.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index cbe65eaee488..ce5ab5a4bbd6 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -29,7 +29,7 @@ import torch from torch import nn -from transformers import DeepseekV2Config +from transformers import DeepseekV2Config, DeepseekV3Config from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -100,7 +100,7 @@ class DeepseekV2MoE(nn.Module): def __init__( self, - config: DeepseekV2Config, + config: Union[DeepseekV2Config, DeepseekV3Config], quant_config: Optional[QuantizationConfig] = None, prefix: str = "", enable_eplb: bool = False, @@ -221,7 +221,7 @@ class DeepseekV2Attention(nn.Module): def __init__( self, - config: DeepseekV2Config, + config: Union[DeepseekV2Config, DeepseekV3Config], hidden_size: int, num_heads: int, qk_nope_head_dim: int, @@ -373,7 +373,7 @@ class DeepseekV2MLAAttention(nn.Module): def __init__( self, - config: DeepseekV2Config, + config: Union[DeepseekV2Config, DeepseekV3Config], hidden_size: int, num_heads: int, qk_nope_head_dim: int, @@ -538,7 +538,7 @@ class DeepseekV2DecoderLayer(nn.Module): def __init__( self, - 
config: DeepseekV2Config, + config: Union[DeepseekV2Config, DeepseekV3Config], prefix: str, model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, @@ -959,7 +959,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): # Compatibility with # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py -def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config, +def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config, + DeepseekV3Config], weight_name: str) -> Optional[int]: if (hasattr(config, "num_nextn_predict_layers") and config.num_nextn_predict_layers > 0): From a7700d2a214c1ca330fde6c4f9fbd248824874f9 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 Aug 2025 06:21:10 +0000 Subject: [PATCH 4/5] Update Signed-off-by: DarkLight1337 --- vllm/model_executor/models/aimv2.py | 25 +++++++++++-------------- vllm/model_executor/models/exaone4.py | 6 +++--- vllm/model_executor/models/glm4_moe.py | 10 +++++----- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py index 30e02fed4fe7..b13d863ebb74 100644 --- a/vllm/model_executor/models/aimv2.py +++ b/vllm/model_executor/models/aimv2.py @@ -8,7 +8,6 @@ import torch import torch.nn as nn -from transformers import PretrainedConfig from vllm.attention.layer import MultiHeadAttention from vllm.distributed import get_tensor_model_parallel_world_size @@ -21,15 +20,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.transformers_utils.configs.ovis import AIMv2Config -# NOTE: The Aimv2Config used here is defined by Ovis -# (https://huggingface.co/AIDC-AI/Ovis2-1B/tree/main) -# It is different from the one inside Transformers library class AIMv2SwiGLUFFN(nn.Module): - def __init__(self, config: PretrainedConfig, - quant_config: QuantizationConfig, prefix: str): + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + prefix: str): super().__init__() hidden_features = config.intermediate_size in_features = config.hidden_size @@ -60,7 +57,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AIMv2PatchEmbed(nn.Module): - def __init__(self, config: PretrainedConfig): + def __init__(self, config: AIMv2Config): super().__init__() self.proj = nn.Conv2d( config.num_channels, @@ -78,7 +75,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AIMv2ViTPreprocessor(nn.Module): - def __init__(self, config: PretrainedConfig): + def __init__(self, config: AIMv2Config): super().__init__() num_patches = (config.image_size // config.patch_size)**2 @@ -96,8 +93,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AIMv2Attention(nn.Module): - def __init__(self, config: PretrainedConfig, - quant_config: QuantizationConfig, prefix: str): + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + prefix: str): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -144,8 +141,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AIMv2Block(nn.Module): - def __init__(self, config: PretrainedConfig, - quant_config: QuantizationConfig, prefix: str): + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + prefix: str): super().__init__() self.attn = AIMv2Attention(config, quant_config=quant_config, @@ -166,7 +163,7 @@ class AIMv2Transformer(nn.Module): def __init__( self, - 
config: PretrainedConfig, + config: AIMv2Config, quant_config: QuantizationConfig, *, require_post_norm: Optional[bool] = None, @@ -196,7 +193,7 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor: class AIMv2Model(torch.nn.Module): def __init__(self, - config: PretrainedConfig, + config: AIMv2Config, quant_config: QuantizationConfig, *, require_post_norm: Optional[bool] = None, diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index 3d6ce3e8895f..ecd942a76ace 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -26,7 +26,7 @@ import torch from torch import nn -from transformers import PretrainedConfig +from transformers import Exaone4Config from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -96,7 +96,7 @@ class Exaone4Attention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Exaone4Config, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -224,7 +224,7 @@ class Exaone4DecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Exaone4Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index bd3e27662ee7..a419945e326e 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -28,7 +28,7 @@ import torch from torch import nn -from transformers import PretrainedConfig +from transformers.models.glm4_moe import Glm4MoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -100,7 +100,7 @@ class Glm4MoE(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Glm4MoeConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", enable_eplb: bool = False, @@ -198,7 +198,7 @@ class Glm4MoeAttention(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Glm4MoeConfig, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -297,7 +297,7 @@ class Glm4MoeDecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: Glm4MoeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -683,7 +683,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() -def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, +def get_spec_layer_idx_from_weight_name(config: Glm4MoeConfig, weight_name: str) -> Optional[int]: if hasattr(config, "num_nextn_predict_layers") and (config.num_nextn_predict_layers From 302ae63a6ba993c1f98832b522526757d605e743 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 7 Aug 2025 07:24:39 +0000 Subject: [PATCH 5/5] Update Signed-off-by: DarkLight1337 --- vllm/model_executor/models/commandr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index c4f6144ed91f..69281abf730a 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -27,7 +27,7 @@ import torch from torch import nn -from transformers import CohereConfig +from transformers import Cohere2Config, CohereConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile @@ -89,7 +89,7 @@ class CohereMLP(nn.Module): 
def __init__( self, - config: CohereConfig, + config: Union[CohereConfig, Cohere2Config], quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -124,7 +124,7 @@ class CohereAttention(nn.Module): def __init__( self, - config: CohereConfig, + config: Union[CohereConfig, Cohere2Config], cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -242,7 +242,7 @@ def forward( class CohereDecoderLayer(nn.Module): def __init__(self, - config: CohereConfig, + config: Union[CohereConfig, Cohere2Config], cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""):
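
Illustration of the typing pattern this series applies (a minimal sketch, not code from the patch: the Toy* class names are hypothetical, and it assumes a transformers release that already exports Qwen2MoeConfig, CohereConfig, and Cohere2Config). Annotating with the concrete config class documents which attributes a module actually reads, and a Union covers modules shared by two config families, as in the DeepSeek V2/V3 and Cohere/Cohere2 changes above.

# Sketch only: concrete config classes instead of PretrainedConfig, plus a
# Union where one module implementation serves two config families.
from typing import Union

import torch.nn as nn
from transformers import Cohere2Config, CohereConfig, Qwen2MoeConfig


class ToyQwen2MoeBlock(nn.Module):
    """Hypothetical block typed against the concrete Qwen2MoeConfig, so the
    attributes it relies on (hidden_size, num_experts) are checkable."""

    def __init__(self, config: Qwen2MoeConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts


class ToyCohereMLP(nn.Module):
    """CohereConfig and Cohere2Config share the fields used here, so a Union
    annotation lets one implementation cover both model families."""

    def __init__(self, config: Union[CohereConfig, Cohere2Config]) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size,
                                   config.intermediate_size,
                                   bias=False)


if __name__ == "__main__":
    # Small dimensions keep the demo cheap; both config families are accepted.
    ToyQwen2MoeBlock(Qwen2MoeConfig(hidden_size=8, num_experts=4))
    ToyCohereMLP(CohereConfig(hidden_size=8, intermediate_size=16))
    ToyCohereMLP(Cohere2Config(hidden_size=8, intermediate_size=16))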