Skip to content
Merged
1 change: 1 addition & 0 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ th {
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ |
| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ |
| `MiMoV2FlashForCausalLM` | MiMoV2Flash | `XiaomiMiMo/MiMo-V2-Flash`, etc. | | ✅︎ |
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ |
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ |
| `MiniMaxM2ForCausalLM` | MiniMax-M2 |`MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ |
Expand Down
3 changes: 3 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,9 @@ def check_available_online(
),
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
"MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True),
"MiMoV2FlashForCausalLM": _HfExamplesInfo(
"XiaomiMiMo/MiMo-V2-Flash", trust_remote_code=True
),
"Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"),
}

Expand Down
2 changes: 2 additions & 0 deletions vllm/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from vllm.config.model import (
ModelConfig,
iter_architecture_defaults,
str_dtype_to_torch_dtype,
try_match_architecture_defaults,
)
from vllm.config.multimodal import MultiModalConfig
Expand Down Expand Up @@ -72,6 +73,7 @@
# From vllm.config.model
"ModelConfig",
"iter_architecture_defaults",
"str_dtype_to_torch_dtype",
"try_match_architecture_defaults",
# From vllm.config.multimodal
"MultiModalConfig",
Expand Down
5 changes: 5 additions & 0 deletions vllm/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1849,6 +1849,11 @@ def try_match_architecture_defaults(
"bfloat16": torch.bfloat16,
}


def str_dtype_to_torch_dtype(type: str) -> torch.dtype | None:
    # NOTE(review): parameter name shadows the builtin `type`; kept as-is since
    # renaming would break keyword-argument callers.
    """Map a dtype name string (e.g. "bfloat16") to its ``torch.dtype``.

    Returns None when the name is not present in _STR_DTYPE_TO_TORCH_DTYPE.
    """
    return _STR_DTYPE_TO_TORCH_DTYPE.get(type)


# model_type -> reason
_FLOAT16_NOT_SUPPORTED_MODELS = {
"gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
Expand Down
62 changes: 49 additions & 13 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def __init__(
self.params_dtype = params_dtype
self.quant_config = quant_config
self.prefix = prefix
self.allow_fp8_block_shape_mismatch = False
if quant_config is None:
self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
else:
Expand Down Expand Up @@ -475,6 +476,7 @@ def __init__(
disable_tp=disable_tp,
)

self._maybe_allow_fp8_block_shape_mismatch()
self.gather_output = gather_output

if output_sizes is None:
Expand Down Expand Up @@ -509,6 +511,33 @@ def __init__(
self.register_parameter("bias", None)
self.update_param_tp_status()

def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
quant_config = getattr(self, "quant_config", None)
weight_block = getattr(quant_config, "weight_block_size", None)
if (
weight_block is None
or len(weight_block) < 1
or len(self.output_partition_sizes) <= 1
):
return

try:
block_n = int(weight_block[0])
except (ValueError, TypeError):
return

if block_n <= 0:
return

if any(size % block_n != 0 for size in self.output_partition_sizes):
self.allow_fp8_block_shape_mismatch = True
logger.debug(
"Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
getattr(self, "prefix", "<unknown>"),
block_n,
self.output_partition_sizes,
)

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
output_dim = getattr(param, "output_dim", None)

Expand Down Expand Up @@ -906,9 +935,11 @@ def __init__(
*,
return_bias: bool = True,
disable_tp: bool = False,
v_head_size: int | None = None,
):
self.hidden_size = hidden_size
self.head_size = head_size
self.v_head_size = v_head_size if v_head_size is not None else head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
Expand All @@ -924,12 +955,14 @@ def __init__(
self.num_kv_head_replicas = 1
input_size = self.hidden_size
output_size = (
(self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
)
self.num_heads * self.head_size
+ self.num_kv_heads * self.head_size
+ self.num_kv_heads * self.v_head_size
) * tp_size
self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
self.num_kv_heads * self.v_head_size * tp_size, # v_proj
]

super().__init__(
Expand All @@ -950,15 +983,16 @@ def _get_shard_offset_mapping(self, loaded_shard_id: str):
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
"total": (self.num_heads + self.num_kv_heads) * self.head_size
+ self.num_kv_heads * self.v_head_size,
}
return shard_offset_mapping.get(loaded_shard_id)

def _get_shard_size_mapping(self, loaded_shard_id: str):
shard_size_mapping = {
"q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.v_head_size,
}
return shard_size_mapping.get(loaded_shard_id)

Expand All @@ -985,7 +1019,7 @@ def _load_fused_module_from_checkpoint(
(
"v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size,
self.total_num_kv_heads * self.v_head_size,
),
]

Expand Down Expand Up @@ -1110,7 +1144,7 @@ def weight_loader(
(
"v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size,
self.total_num_kv_heads * self.v_head_size,
),
]
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
Expand Down Expand Up @@ -1139,11 +1173,12 @@ def weight_loader(
"v": (
(self.total_num_heads + self.total_num_kv_heads)
* self.head_size,
self.total_num_kv_heads * self.head_size,
self.total_num_kv_heads * self.v_head_size,
),
"total": (
(self.total_num_heads + 2 * self.total_num_kv_heads)
* self.head_size,
(self.total_num_heads + self.total_num_kv_heads)
* self.head_size
+ self.total_num_kv_heads * self.v_head_size,
0,
),
}
Expand All @@ -1170,7 +1205,7 @@ def weight_loader(
shard_size = self.num_kv_heads * self.head_size
elif loaded_shard_id == "v":
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size
shard_size = self.num_kv_heads * self.v_head_size
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
Expand Down Expand Up @@ -1199,10 +1234,11 @@ def weight_loader(
),
"v": (
(self.num_heads + self.num_kv_heads) * self.head_size,
self.num_kv_heads * self.head_size,
self.num_kv_heads * self.v_head_size,
),
"total": (
(self.num_heads + 2 * self.num_kv_heads) * self.head_size,
(self.num_heads + self.num_kv_heads) * self.head_size
+ self.num_kv_heads * self.v_head_size,
0,
),
}
Expand Down
8 changes: 8 additions & 0 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,14 @@ def validate_fp8_block_shape(
"""Validate block quantization shapes for tensor parallelism."""
from vllm.distributed import get_tensor_model_parallel_world_size

if getattr(layer, "allow_fp8_block_shape_mismatch", False):
logger.debug(
"Skipping FP8 block shape validation for layer %s due to detected"
" mismatch allowance.",
getattr(layer, "prefix", "<unknown>"),
)
return
Comment on lines +1255 to +1261
Copy link
Member

@Isotr0py Isotr0py Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit worried that this will cause unexpected behavior for FP8 kernel if we disabled block shape check.

Perhaps we should improve the block shape check for Mimo-V2's edge case instead of just skipping it.

Perhaps @mgoin can give more insights?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Isotr0py We tried removing the code above and got the following error.

ValueError: Weight output_partition_size = 192 is not divisible by weight quantization block_n = 128.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm we support weights that aren't divisible by 128 for other block fp8 models fine, such as kv_a_proj in deepseek, I wonder if it is a specific fused layer


tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size())
block_n, block_k = block_size[0], block_size[1]

Expand Down
Loading