diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index c325a0381755..775c60c86574 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -17,6 +17,7 @@ from vllm.model_executor.custom_op import CustomOp, PluggableLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, + MergedColumnParallelLinear, RowParallelLinear, ) from vllm.model_executor.layers.mamba.abstract import MambaBase @@ -301,94 +302,127 @@ def __init__( self.groups_ssm_state_size = self.n_groups * self.ssm_state_size self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size - # Use ColumnParallelLinear with custom weight loaders for both cases: - # - When n_groups % tp_size == 0: standard sharding without duplication - # - When n_groups == 1: groups are duplicated across TP ranks - # The custom weight loader handles both cases correctly. - - self.conv1d = ColumnParallelLinear( - input_size=conv_kernel_size, - output_size=self.conv_dim, - bias=use_conv_bias, - quant_config=None, - prefix=f"{prefix}.conv1d", - ) + if n_groups % self.tp_size == 0: + self.conv1d = MergedColumnParallelLinear( + input_size=conv_kernel_size, + output_sizes=[ + intermediate_size, + self.groups_ssm_state_size, + self.groups_ssm_state_size, + ], + bias=use_conv_bias, + quant_config=None, + prefix=f"{prefix}.conv1d", + ) - self.in_proj = ColumnParallelLinear( - input_size=hidden_size, - output_size=intermediate_size + self.conv_dim + self.num_heads, - bias=use_bias, - quant_config=quant_config, - prefix=f"{prefix}.in_proj", - ) + self.in_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[ + intermediate_size, + intermediate_size, + self.groups_ssm_state_size, + self.groups_ssm_state_size, + self.num_heads, + ], + bias=use_bias, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) + else: + # This is the n_groups == 1 case, + # where we 
need to duplicate groups if TP>1. + + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=self.conv_dim, + bias=use_conv_bias, + quant_config=None, + prefix=f"{prefix}.conv1d", + ) - # Configure shard settings for the custom weight loader: - # - group_shard_settings handles group duplication when n_groups == 1 - # - When n_groups % tp_size == 0, extra=0 and duplicate_groups=False - group_shard_settings = ( - self.groups_ssm_state_size, # expected model size - (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned - n_groups == 1, # duplicate groups when n_groups == 1 - ) - intermediate_settings = (intermediate_size, 0, False) - head_settings = (self.num_heads, 0, False) - - # Apply custom weight loaders for conv1d (bias and weight) - delattr(self.conv1d.bias, "weight_loader") - set_weight_attrs( - self.conv1d.bias, - { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - intermediate_settings, - group_shard_settings, - group_shard_settings, - ], - self.tp_size, - tp_rank, - ) - }, - ) + self.in_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) - delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs( - self.conv1d.weight, - { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - intermediate_settings, - group_shard_settings, - group_shard_settings, - ], - self.tp_size, - tp_rank, - ) - }, - ) + # - because in_proj is a concatenation of 3 weights, we + # need to interleave them before sharding + # - use the custom weight loader mamba_v2_sharded_weight_loader + # for conv1d.bias, conv1d.weight and in_proj.weight + # - need to set these settings, to assign the groups + # to the head shards + group_shard_settings = ( + self.groups_ssm_state_size, # expected model size + (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned + n_groups == 1, # 
if there was only one group + ) + intermediate_settings = (intermediate_size, 0, False) + head_settings = (self.num_heads, 0, False) + + # - the weight already has a "weight_loader" attribute + # which set_weight_attrs will raise if we do not + # delete before trying to override it + # - ditto for the other two weights below + delattr(self.conv1d.bias, "weight_loader") + set_weight_attrs( + self.conv1d.bias, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }, + ) - # Create the custom weight loader for in_proj - mamba_loader = mamba_v2_sharded_weight_loader( - [ - intermediate_settings, # for gate - intermediate_settings, - group_shard_settings, - group_shard_settings, - head_settings, # for dt - ], - self.tp_size, - tp_rank, - ) + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }, + ) - # Apply the custom weight loader to in_proj.weight - # Works for both non-quantized (Parameter) and quantized - # (ModelWeightParameter which extends BasevLLMParameter) - if isinstance(self.in_proj.weight, BasevLLMParameter): - # For BasevLLMParameter subclasses (quantized layers like FP8) - self.in_proj.weight.weight_loader = mamba_loader - else: - # For standard Parameter (non-quantized layers) - delattr(self.in_proj.weight, "weight_loader") - set_weight_attrs(self.in_proj.weight, {"weight_loader": mamba_loader}) + # Create the custom weight loader for Mamba sharding with group + # replication. This handles the interleaved projections correctly. 
+ mamba_loader = mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + intermediate_settings, + group_shard_settings, + group_shard_settings, + head_settings, # for dt + ], + self.tp_size, + tp_rank, + ) + + # Apply the custom weight loader to in_proj.weight + # Works for both non-quantized (Parameter) and quantized + # (ModelWeightParameter which extends BasevLLMParameter) + if isinstance(self.in_proj.weight, BasevLLMParameter): + # For BasevLLMParameter subclasses (quantized layers like FP8) + # These have a weight_loader property that can be directly set + self.in_proj.weight.weight_loader = mamba_loader + else: + # For standard Parameter (non-quantized layers) + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs(self.in_proj.weight, {"weight_loader": mamba_loader}) # unsqueeze to fit conv1d weights shape into the linear weights shape. # Can't do this in `weight_loader` since it already exists in