Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/mamba/mamba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ def linear_attention_state_dtype(
model_dtype: ModelDType | torch.dtype,
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
# TODO (tdoublep) requires testing
if mamba_cache_dtype == "float32":
raise ValueError("fp32 state for minimax is not yet supported")
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype,)

Expand Down
28 changes: 15 additions & 13 deletions vllm/model_executor/models/bailing_moe_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.fla.ops.layernorm_guard import (
RMSNormGated,
layernorm_fn,
Expand Down Expand Up @@ -211,7 +212,6 @@ def __init__(
max_position=max_position,
is_neox_style=False,
rope_parameters=rope_parameters or None,
dtype=torch.float32,
)

# Build MLAModules for MultiHeadLatentAttentionWrapper
Expand Down Expand Up @@ -425,14 +425,18 @@ def _weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> No
param.data.copy_(loaded_weight[shard].contiguous())


class BailingMoELinearAttention(nn.Module, MambaBase):
"""
Bailing MoE Linear Attention implementation using minimax backend.
# --8<-- [start:bailing_moe_linear_attention]
@PluggableLayer.register("bailing_moe_linear_attention")
class BailingMoELinearAttention(PluggableLayer, MambaBase):
"""Pluggable Bailing MoE Linear Attention layer which allows OOT backends
to add custom implementations.

This implements the linear attention mechanism from sglang, adapted for vLLM's
v1 engine with MambaBase interface support.
This implements the linear attention mechanism from sglang, adapted for
vLLM's v1 engine with MambaBase interface support.
"""

# --8<-- [end:bailing_moe_linear_attention]

@property
def mamba_type(self) -> str:
return "linear_attention"
Expand Down Expand Up @@ -569,7 +573,6 @@ def __init__(
self.head_dim,
max_position=self.max_position_embeddings,
is_neox_style=True,
dtype=torch.float32,
rope_parameters=rope_parameters or None,
)

Expand Down Expand Up @@ -754,19 +757,17 @@ def _prefill_and_mix_infer(

def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
"""Handle decode (single token per sequence)."""
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_prefills = attn_metadata.num_prefills
hidden = linear_attention_decode(
q,
k,
v,
kv_cache,
self.tp_slope,
state_indices_tensor,
q_start=num_prefill_tokens,
q_end=None,
slot_start=num_prefills,
slot_end=None,
q_start=0,
q_end=attn_metadata.num_decode_tokens,
slot_start=0,
slot_end=attn_metadata.num_decodes,
block_size=32,
)
return hidden
Expand Down Expand Up @@ -1149,6 +1150,7 @@ def __init__(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(config.vocab_size)
else:
Expand Down
Loading