diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index a620495b7a5c..f602d9b62191 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -17,7 +17,6 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
-    MergedColumnParallelLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.mamba.abstract import MambaBase
@@ -40,6 +39,7 @@
     composed_weight_loader,
     sharded_weight_loader,
 )
+from vllm.model_executor.parameter import BasevLLMParameter
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
@@ -280,13 +280,6 @@ def __init__(
             "then num_groups must equal 1."
         )

-        assert (
-            (n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
-        ), (
-            "Tensor parallel currently supported for quantized models only "
-            "if tensor parallel world size divides num groups."
-        )
-
         self.ssm_state_size = ssm_state_size
         self.conv_kernel_size = conv_kernel_size
         self.activation = activation
@@ -308,121 +301,94 @@ def __init__(
         self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
         self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size

-        if n_groups % self.tp_size == 0:
-            self.conv1d = MergedColumnParallelLinear(
-                input_size=conv_kernel_size,
-                output_sizes=[
-                    intermediate_size,
-                    self.groups_ssm_state_size,
-                    self.groups_ssm_state_size,
-                ],
-                bias=use_conv_bias,
-                quant_config=None,
-                prefix=f"{prefix}.conv1d",
-            )
+        # Use ColumnParallelLinear with custom weight loaders for both cases:
+        # - When n_groups % tp_size == 0: standard sharding without duplication
+        # - When n_groups == 1: groups are duplicated across TP ranks
+        # The custom weight loader handles both cases correctly.
+
+        self.conv1d = ColumnParallelLinear(
+            input_size=conv_kernel_size,
+            output_size=self.conv_dim,
+            bias=use_conv_bias,
+            quant_config=None,
+            prefix=f"{prefix}.conv1d",
+        )

-            self.in_proj = MergedColumnParallelLinear(
-                input_size=hidden_size,
-                output_sizes=[
-                    intermediate_size,
-                    intermediate_size,
-                    self.groups_ssm_state_size,
-                    self.groups_ssm_state_size,
-                    self.num_heads,
-                ],
-                bias=use_bias,
-                quant_config=quant_config,
-                prefix=f"{prefix}.in_proj",
-            )
-        else:
-            # This is the n_groups == 1 case,
-            # where we need to duplicate groups if TP>1.
-
-            self.conv1d = ColumnParallelLinear(
-                input_size=conv_kernel_size,
-                output_size=self.conv_dim,
-                bias=use_conv_bias,
-                quant_config=None,
-                prefix=f"{prefix}.conv1d",
-            )
+        self.in_proj = ColumnParallelLinear(
+            input_size=hidden_size,
+            output_size=intermediate_size + self.conv_dim + self.num_heads,
+            bias=use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj",
+        )

-            self.in_proj = ColumnParallelLinear(
-                input_size=hidden_size,
-                output_size=intermediate_size + self.conv_dim + self.num_heads,
-                bias=use_bias,
-                quant_config=quant_config,
-                prefix=f"{prefix}.in_proj",
-            )
+        # Configure shard settings for the custom weight loader:
+        # - group_shard_settings handles group duplication when n_groups == 1
+        # - When n_groups % tp_size == 0, extra=0 and duplicate_groups=False
+        group_shard_settings = (
+            self.groups_ssm_state_size,  # expected model size
+            (self.n_groups - n_groups) * self.ssm_state_size,  # extra dims assigned
+            n_groups == 1,  # duplicate groups when n_groups == 1
+        )
+        intermediate_settings = (intermediate_size, 0, False)
+        head_settings = (self.num_heads, 0, False)
+
+        # Apply custom weight loaders for conv1d (bias and weight)
+        delattr(self.conv1d.bias, "weight_loader")
+        set_weight_attrs(
+            self.conv1d.bias,
+            {
+                "weight_loader": mamba_v2_sharded_weight_loader(
+                    [
+                        intermediate_settings,
+                        group_shard_settings,
+                        group_shard_settings,
+                    ],
+                    self.tp_size,
+                    tp_rank,
+                )
+            },
+        )

-            # - because in_proj is a concatenation of 3 weights, we
-            #   need to interleave them before sharding
-            # - use the custom weight loader mamba_v2_sharded_weight_loader
-            #   for conv1d.bias, covn1d.weight and in_proj.weight
-            # - need to set these settings, to assign the groups
-            #   to the head shards
-            group_shard_settings = (
-                self.groups_ssm_state_size,  # expected model size
-                (self.n_groups - n_groups) * self.ssm_state_size,  # extra dims assigned
-                n_groups == 1,  # if there was only one group
-            )
-            intermediate_settings = (intermediate_size, 0, False)
-            head_settings = (self.num_heads, 0, False)
-
-            # - the weight already has a "weight_loader" attribute
-            #   which set_weight_attrs will raise if we do not
-            #   delete before trying to override it
-            # - ditto for the other two weights below
-            delattr(self.conv1d.bias, "weight_loader")
-            set_weight_attrs(
-                self.conv1d.bias,
-                {
-                    "weight_loader": mamba_v2_sharded_weight_loader(
-                        [
-                            intermediate_settings,
-                            group_shard_settings,
-                            group_shard_settings,
-                        ],
-                        self.tp_size,
-                        tp_rank,
-                    )
-                },
-            )
+        delattr(self.conv1d.weight, "weight_loader")
+        set_weight_attrs(
+            self.conv1d.weight,
+            {
+                "weight_loader": mamba_v2_sharded_weight_loader(
+                    [
+                        intermediate_settings,
+                        group_shard_settings,
+                        group_shard_settings,
+                    ],
+                    self.tp_size,
+                    tp_rank,
+                )
+            },
+        )

-            delattr(self.conv1d.weight, "weight_loader")
-            set_weight_attrs(
-                self.conv1d.weight,
-                {
-                    "weight_loader": mamba_v2_sharded_weight_loader(
-                        [
-                            intermediate_settings,
-                            group_shard_settings,
-                            group_shard_settings,
-                        ],
-                        self.tp_size,
-                        tp_rank,
-                    )
-                },
-            )
+        # Create the custom weight loader for in_proj
+        mamba_loader = mamba_v2_sharded_weight_loader(
+            [
+                intermediate_settings,  # for gate
+                intermediate_settings,
+                group_shard_settings,
+                group_shard_settings,
+                head_settings,  # for dt
+            ],
+            self.tp_size,
+            tp_rank,
+        )

-            if quant_config is None:
-                # - quant layers do not have a weight loader
-                delattr(self.in_proj.weight, "weight_loader")
-                set_weight_attrs(
-                    self.in_proj.weight,
-                    {
-                        "weight_loader": mamba_v2_sharded_weight_loader(
-                            [
-                                intermediate_settings,  # for gate
-                                intermediate_settings,
-                                group_shard_settings,
-                                group_shard_settings,
-                                head_settings,  # for dt
-                            ],
-                            self.tp_size,
-                            tp_rank,
-                        )
-                    },
-                )
+        # Apply the custom weight loader to in_proj.weight
+        # Works for both non-quantized (Parameter) and quantized
+        # (ModelWeightParameter which extends BasevLLMParameter)
+        if isinstance(self.in_proj.weight, BasevLLMParameter):
+            # For BasevLLMParameter subclasses (quantized layers like FP8)
+            self.in_proj.weight.weight_loader = mamba_loader
+        else:
+            # For standard Parameter (non-quantized layers)
+            delattr(self.in_proj.weight, "weight_loader")
+            set_weight_attrs(self.in_proj.weight, {"weight_loader": mamba_loader})

         # unsqueeze to fit conv1d weights shape into the linear weights shape.
         # Can't do this in `weight_loader` since it already exists in
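
Reviewer note (not part of the patch): the (size, extra, duplicate) shard-setting triples used above may be easier to follow with a toy example. The sketch below is a simplified illustration of the sharding idea behind mamba_v2_sharded_weight_loader, not the actual vLLM implementation; toy_shard_slices is a hypothetical helper, and the `extra` field is ignored for brevity.

import torch


def toy_shard_slices(shard_settings, tp_size, tp_rank):
    """Return one (start, length) slice of the checkpoint's output dim per
    logical segment, describing what this TP rank should load."""
    slices, offset = [], 0
    for full_size, _extra, duplicate in shard_settings:
        if duplicate:
            # n_groups == 1 case: every rank keeps the whole segment
            slices.append((offset, full_size))
        else:
            # ordinary column-parallel split of the segment across ranks
            shard = full_size // tp_size
            slices.append((offset + tp_rank * shard, shard))
        offset += full_size
    return slices


if __name__ == "__main__":
    # Two 8-row segments; the second is duplicated on every rank (tp_size=2),
    # mimicking an intermediate setting followed by a duplicated group setting.
    settings = [(8, 0, False), (8, 0, True)]
    checkpoint = torch.arange(16.0)
    for rank in range(2):
        parts = [checkpoint[s : s + n] for s, n in toy_shard_slices(settings, 2, rank)]
        print(f"rank {rank} loads rows:", torch.cat(parts).tolist())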