200 changes: 117 additions & 83 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -17,6 +17,7 @@
from vllm.model_executor.custom_op import CustomOp, PluggableLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
@@ -301,94 +302,127 @@ def __init__(
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
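# Worked example with hypothetical sizes: intermediate_size=4096,
# n_groups=8, ssm_state_size=128 gives groups_ssm_state_size = 8 * 128 = 1024
# and conv_dim = 4096 + 2 * 1024 = 6144, i.e. the conv channels cover the
# hidden states plus the B and C group states.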

# Use ColumnParallelLinear with custom weight loaders for both cases:
# - When n_groups % tp_size == 0: standard sharding without duplication
# - When n_groups == 1: groups are duplicated across TP ranks
# The custom weight loader handles both cases correctly.

self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
if n_groups % self.tp_size == 0:
self.conv1d = MergedColumnParallelLinear(
input_size=conv_kernel_size,
output_sizes=[
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
],
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
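# With MergedColumnParallelLinear each entry of output_sizes is sharded on
# its own: for hypothetical tp_size=2, intermediate_size=4096 and
# groups_ssm_state_size=1024, every rank holds [2048, 512, 512] channels
# rather than an arbitrary contiguous half of conv_dim, so the x, B and C
# chunks stay aligned per rank.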

self.in_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
self.in_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[
intermediate_size,
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
self.num_heads,
],
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
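# The five output_sizes correspond to the packed in_proj outputs: gate,
# hidden states, B groups, C groups and the per-head dt parameters, so each
# logical chunk can be sharded (or replicated, in the fallback path below)
# consistently with the conv1d split above.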
else:
# This is the n_groups == 1 case,
# where we need to duplicate groups if TP>1.

self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)

# Configure shard settings for the custom weight loader:
# - group_shard_settings handles group duplication when n_groups == 1
# - When n_groups % tp_size == 0, extra=0 and duplicate_groups=False
group_shard_settings = (
self.groups_ssm_state_size, # expected model size
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
n_groups == 1, # duplicate groups when n_groups == 1
)
intermediate_settings = (intermediate_size, 0, False)
head_settings = (self.num_heads, 0, False)
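# Each settings tuple reads as (logical size in the checkpoint, extra padded
# dims, duplicate across ranks). For a hypothetical n_groups=1, tp_size=4
# with self.n_groups padded up to 4 elsewhere, group_shard_settings becomes
# (ssm_state_size, 3 * ssm_state_size, True): the single group's B/C states
# are copied to every rank instead of being split.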

# Apply custom weight loaders for conv1d (bias and weight)
delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs(
self.conv1d.bias,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)
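# The three settings passed to the loader mirror the conv_dim layout
# computed earlier (intermediate_size plus two groups_ssm_state_size
# blocks), so the conv1d bias and weight are sliced along exactly the same
# channel boundaries as the in_proj chunks that feed them.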
self.in_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)

delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs(
self.conv1d.weight,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)
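# Rough behaviour of the sharded loader configured above (a sketch, not the
# exact implementation): it walks the flat output dimension chunk by chunk;
# entries with duplicate=False copy only the rank-local slice of that chunk,
# while duplicate=True copies the whole chunk to every rank. Because the
# checkpoint stores these projections as one concatenated matrix, sharding
# chunk-wise like this effectively interleaves the shards instead of cutting
# the matrix into contiguous halves.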
# - because in_proj is a concatenation of 3 weights, we
# need to interleave them before sharding
# - use the custom weight loader mamba_v2_sharded_weight_loader
# for conv1d.bias, conv1d.weight and in_proj.weight
# - need to set these settings, to assign the groups
# to the head shards
group_shard_settings = (
self.groups_ssm_state_size, # expected model size
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
n_groups == 1, # if there was only one group
)
intermediate_settings = (intermediate_size, 0, False)
head_settings = (self.num_heads, 0, False)

# - the weight already has a "weight_loader" attribute
# which set_weight_attrs will raise if we do not
# delete before trying to override it
# - ditto for the other two weights below
delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs(
self.conv1d.bias,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)

# Create the custom weight loader for in_proj
mamba_loader = mamba_v2_sharded_weight_loader(
[
intermediate_settings, # for gate
intermediate_settings,
group_shard_settings,
group_shard_settings,
head_settings, # for dt
],
self.tp_size,
tp_rank,
)
delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs(
self.conv1d.weight,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)

# Apply the custom weight loader to in_proj.weight
# Works for both non-quantized (Parameter) and quantized
# (ModelWeightParameter which extends BasevLLMParameter)
if isinstance(self.in_proj.weight, BasevLLMParameter):
# For BasevLLMParameter subclasses (quantized layers like FP8)
self.in_proj.weight.weight_loader = mamba_loader
else:
# For standard Parameter (non-quantized layers)
delattr(self.in_proj.weight, "weight_loader")
set_weight_attrs(self.in_proj.weight, {"weight_loader": mamba_loader})
# Create the custom weight loader for Mamba sharding with group
# replication. This handles the interleaved projections correctly.
mamba_loader = mamba_v2_sharded_weight_loader(
[
intermediate_settings, # for gate
intermediate_settings,
group_shard_settings,
group_shard_settings,
head_settings, # for dt
],
self.tp_size,
tp_rank,
)

# Apply the custom weight loader to in_proj.weight
# Works for both non-quantized (Parameter) and quantized
# (ModelWeightParameter which extends BasevLLMParameter)
if isinstance(self.in_proj.weight, BasevLLMParameter):
# For BasevLLMParameter subclasses (quantized layers like FP8)
# These have a weight_loader property that can be directly set
self.in_proj.weight.weight_loader = mamba_loader
else:
# For standard Parameter (non-quantized layers)
delattr(self.in_proj.weight, "weight_loader")
set_weight_attrs(self.in_proj.weight, {"weight_loader": mamba_loader})
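# The same override pattern could be factored into a small helper; a
# hypothetical sketch, assuming only the names already imported in this
# module:
#
# def _set_custom_weight_loader(param, loader):
#     if isinstance(param, BasevLLMParameter):
#         param.weight_loader = loader  # property-style override
#     else:
#         delattr(param, "weight_loader")  # drop the default loader first
#         set_weight_attrs(param, {"weight_loader": loader})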

# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in