Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 84 additions & 118 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
Expand All @@ -40,6 +39,7 @@
composed_weight_loader,
sharded_weight_loader,
)
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
Expand Down Expand Up @@ -280,13 +280,6 @@ def __init__(
"then num_groups must equal 1."
)

assert (
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
), (
"Tensor parallel currently supported for quantized models only "
"if tensor parallel world size divides num groups."
)

self.ssm_state_size = ssm_state_size
self.conv_kernel_size = conv_kernel_size
self.activation = activation
Expand All @@ -308,121 +301,94 @@ def __init__(
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size

if n_groups % self.tp_size == 0:
self.conv1d = MergedColumnParallelLinear(
input_size=conv_kernel_size,
output_sizes=[
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
],
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
# Use ColumnParallelLinear with custom weight loaders for both cases:
# - When n_groups % tp_size == 0: standard sharding without duplication
# - When n_groups == 1: groups are duplicated across TP ranks
# The custom weight loader handles both cases correctly.

self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)

self.in_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[
intermediate_size,
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
self.num_heads,
],
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
else:
# This is the n_groups == 1 case,
# where we need to duplicate groups if TP>1.

self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
self.in_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)

self.in_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
# Configure shard settings for the custom weight loader:
# - group_shard_settings handles group duplication when n_groups == 1
# - When n_groups % tp_size == 0, extra=0 and duplicate_groups=False
group_shard_settings = (
self.groups_ssm_state_size, # expected model size
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
n_groups == 1, # duplicate groups when n_groups == 1
)
intermediate_settings = (intermediate_size, 0, False)
head_settings = (self.num_heads, 0, False)

# Apply custom weight loaders for conv1d (bias and weight)
delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs(
self.conv1d.bias,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)

# - because in_proj is a concatenation of 3 weights, we
# need to interleave them before sharding
# - use the custom weight loader mamba_v2_sharded_weight_loader
# for conv1d.bias, conv1d.weight and in_proj.weight
# - need to set these settings, to assign the groups
# to the head shards
group_shard_settings = (
self.groups_ssm_state_size, # expected model size
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
n_groups == 1, # if there was only one group
)
intermediate_settings = (intermediate_size, 0, False)
head_settings = (self.num_heads, 0, False)

# - the weight already has a "weight_loader" attribute, and
# set_weight_attrs will raise unless we delete it
# before trying to override it
# - ditto for the other two weights below
delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs(
self.conv1d.bias,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)
delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs(
self.conv1d.weight,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)

delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs(
self.conv1d.weight,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
},
)
# Create the custom weight loader for in_proj
mamba_loader = mamba_v2_sharded_weight_loader(
[
intermediate_settings, # for gate
intermediate_settings,
group_shard_settings,
group_shard_settings,
head_settings, # for dt
],
self.tp_size,
tp_rank,
)

if quant_config is None:
# - quant layers do not have a weight loader
delattr(self.in_proj.weight, "weight_loader")
set_weight_attrs(
self.in_proj.weight,
{
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings, # for gate
intermediate_settings,
group_shard_settings,
group_shard_settings,
head_settings, # for dt
],
self.tp_size,
tp_rank,
)
},
)
# Apply the custom weight loader to in_proj.weight
# Works for both non-quantized (Parameter) and quantized
# (ModelWeightParameter which extends BasevLLMParameter)
if isinstance(self.in_proj.weight, BasevLLMParameter):
# For BasevLLMParameter subclasses (quantized layers like FP8)
self.in_proj.weight.weight_loader = mamba_loader
else:
# For standard Parameter (non-quantized layers)
delattr(self.in_proj.weight, "weight_loader")
set_weight_attrs(self.in_proj.weight, {"weight_loader": mamba_loader})

# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
Expand Down