Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f9af118
step3p5-turbo-support
csy0225 Jan 27, 2026
4c60890
fix: resolve diff from 014
csy0225 Jan 27, 2026
28eecba
fix: fp8 activation should support swigluoai_step
csy0225 Jan 27, 2026
17494c7
format: fix review comments for step3p5, remove groupwise-quant code
csy0225 Jan 28, 2026
9b0cbdb
feat: add step3p5_tool_parser
csy0225 Jan 28, 2026
3268d78
fix: attn module import error
csy0225 Jan 28, 2026
b020ee2
feat: support mtp3
Jan 28, 2026
b247624
fix: FlashInferExperts.apply() got an unexpected keyword argument 'ac…
csy0225 Jan 28, 2026
45eeaf8
fix: step3p5 reasoning parser error
csy0225 Jan 29, 2026
f911ab0
Revert "feat: support mtp3"
Jan 29, 2026
ede83c9
revert routed_scaling_factor passthrough in fused moe
csy0225 Jan 29, 2026
03c65f8
refactor: revert activation_limit for swiglustep
csy0225 Jan 29, 2026
952026e
fix step3p5 reasoning parser
Jan 29, 2026
285a3ab
fix: fix comments about swiglustep, default limit=7.0
csy0225 Jan 29, 2026
864602c
format: remove useless field in step3p5 config
csy0225 Jan 29, 2026
5389b7d
format: simplified rope scaling params update
csy0225 Jan 29, 2026
459448f
fix: mtp
csy0225 Jan 29, 2026
a5ab370
format: fix config.json default value
csy0225 Jan 30, 2026
4086e80
format: pre-commit tool fix
csy0225 Jan 30, 2026
67ced2d
fix: remove moe_dynamic_exp_p from config
csy0225 Jan 30, 2026
9fded69
fix: mtp3 weights load error
csy0225 Jan 30, 2026
0689253
format: some review comments fix
csy0225 Jan 31, 2026
b2b55de
Refactor moe
jeejeelee Jan 31, 2026
128c8c7
fix shared moe
jeejeelee Jan 31, 2026
d66b2de
NIT
jeejeelee Jan 31, 2026
82631af
NIT
jeejeelee Jan 31, 2026
f35951b
feat: keep router logits in fp32 precision
csy0225 Jan 31, 2026
2faeea8
gate
jeejeelee Jan 31, 2026
5147782
Fix MTP
jeejeelee Feb 1, 2026
1c4d28c
triton act
jeejeelee Feb 1, 2026
b4264eb
format: delete print log
csy0225 Feb 1, 2026
809d69d
fix: remove
csy0225 Feb 1, 2026
f86b755
CI: add step3p5 mtp hf examples info
csy0225 Feb 2, 2026
c7b05eb
CI: add step3p5 mtp hf examples info
csy0225 Feb 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ th {
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | |
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ |
| `Step1ForCausalLM` | Step-Audio | `stepfun-ai/Step-Audio-EditX`, etc. | ✅︎ | ✅︎ |
| `Step3p5ForCausalLM` | Step-3.5-flash | `stepfun-ai/step-3.5-flash`, etc. | | ✅︎ |
| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ |
| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ |
| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ |
Expand Down
10 changes: 8 additions & 2 deletions tests/kernels/core/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
QuickGELU,
SiluAndMul,
SwigluOAIAndMul,
SwigluStepAndMul,
swiglustep_and_mul_triton,
)
from vllm.utils.torch_utils import set_random_seed

Expand All @@ -36,6 +38,7 @@
"gelu_tanh",
"fatrelu",
"swigluoai_and_mul",
"swiglustep_and_mul",
],
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
Expand Down Expand Up @@ -75,9 +78,12 @@ def test_act_and_mul(
elif activation == "swigluoai_and_mul":
layer = SwigluOAIAndMul()
fn = torch.ops._C.swigluoai_and_mul
elif activation == "swiglustep_and_mul":
layer = SwigluStepAndMul()
fn = swiglustep_and_mul_triton
out = layer(x)
ref_out = layer.forward_native(x)
if activation == "swigluoai_and_mul":
if activation in ["swigluoai_and_mul", "swiglustep_and_mul"]:
rtol = {
# For fp16, change the relative tolerance from 1e-3 to 2e-3
torch.float16: 2e-3,
Expand All @@ -104,7 +110,7 @@ def _get_rtol(output) -> float:
opcheck(fn, (out, x, threshold))
elif activation == "swigluoai_and_mul":
opcheck(fn, (out, x, layer.alpha, layer.limit))
else:
elif activation != "swiglustep_and_mul":
opcheck(fn, (out, x))


Expand Down
9 changes: 9 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,9 @@ def check_available_online(
"Step1ForCausalLM": _HfExamplesInfo(
"stepfun-ai/Step-Audio-EditX", trust_remote_code=True
),
"Step3p5ForCausalLM": _HfExamplesInfo(
"stepfun-ai/step-3.5-flash", is_available_online=False
),
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
Expand Down Expand Up @@ -1113,6 +1116,12 @@ def check_available_online(
"Qwen3NextMTP": _HfExamplesInfo(
"Qwen/Qwen3-Next-80B-A3B-Instruct", min_transformers_version="4.56.3"
),
"Step3p5MTP": _HfExamplesInfo(
"stepfun-ai/Step-3.5-Flash",
trust_remote_code=True,
speculative_model="stepfun-ai/Step-3.5-Flash",
is_available_online=False,
),
}

_TRANSFORMERS_BACKEND_MODELS = {
Expand Down
6 changes: 6 additions & 0 deletions vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp",
"step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
Expand Down Expand Up @@ -264,6 +265,11 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
)

if hf_config.model_type == "step3p5":
hf_config.model_type = "step3p5_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})

if initial_architecture == "MistralLarge3ForCausalLM":
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})

Expand Down
90 changes: 90 additions & 0 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,63 @@
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.collection_utils import LazyDict

logger = init_logger(__name__)


@triton.jit
def _swiglustep_and_mul_kernel(
    o_ptr,
    o_stride,
    x_ptr,
    x_stride,
    limit: tl.constexpr,
    d: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    # Fused SwiGLU-with-clamp kernel:
    #   out[i, c] = min(silu(x[i, c]), limit) * clamp(x[i, c + d], -limit, limit)
    # Launch grid is (num_rows, ceil(d / BLOCK_SIZE)): each program instance
    # handles one BLOCK_SIZE-wide column slice of one row.
    # NOTE(review): `limit` is a tl.constexpr, so each distinct runtime limit
    # value triggers a fresh kernel compilation — confirm this is intended.
    i = tl.program_id(axis=0).to(tl.int64)  # row index; int64 so i * stride cannot overflow int32
    j = tl.program_id(axis=1)  # column-block index within the row
    o_row_ptr = o_ptr + o_stride * i
    x_row_ptr = x_ptr + x_stride * i
    offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < d  # guard the ragged final block when d % BLOCK_SIZE != 0

    # Gate is the first half of the row, up the second half; compute in fp32
    # regardless of storage dtype.
    gate = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32)
    up = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32)

    gate_silu = tl.sigmoid(gate) * gate  # silu(gate)
    gate_clamped = tl.minimum(gate_silu, limit)  # clamp from above only
    up_clamped = tl.minimum(tl.maximum(up, -limit), limit)  # symmetric clamp

    result = gate_clamped * up_clamped
    # Cast back to the tensor's storage dtype before writing out.
    result = result.to(x_ptr.dtype.element_ty)
    tl.store(o_row_ptr + offsets, result, mask=mask)


def swiglustep_and_mul_triton(
    output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
):
    """Launch the fused SwigluStep-and-mul Triton kernel.

    Writes ``silu(input[:, :d]).clamp(max=limit) * input[:, d:].clamp(-limit, limit)``
    into ``output`` in place, where ``d = input.shape[-1] // 2``.

    Args:
        output: Preallocated ``(b, d)`` tensor that receives the result.
        input: ``(b, 2 * d)`` concatenated gate/up tensor; last dim must be even.
        limit: Clamp bound applied to both halves (upper bound for the silu'd
            gate, symmetric bound for up).

    Raises:
        ValueError: If ``input`` is not 2-D or its last dimension is odd.
    """
    # Validate BEFORE unpacking the shape: the original order unpacked first,
    # so a non-2D tensor raised an opaque tuple-unpack error instead of a
    # clear message. ValueError (not assert) so the checks survive `-O`.
    if input.ndim != 2:
        raise ValueError(f"expected a 2-D input, got ndim={input.ndim}")
    b, n = input.shape
    if n % 2 != 0:
        raise ValueError(f"input last dim must be even, got {n}")
    d = n // 2

    def grid(meta):
        # One program per (row, BLOCK_SIZE-wide column slice).
        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))

    _swiglustep_and_mul_kernel[grid](
        output,
        output.stride(0),
        input,
        input.stride(0),
        limit=limit,
        d=d,
        BLOCK_SIZE=1024,
    )


# --8<-- [start:fatrelu_and_mul]
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
Expand Down Expand Up @@ -304,6 +356,44 @@ def extra_repr(self) -> str:
return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"


# --8<-- [start:swiglustep_and_mul]
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
    """SwiGLU activation with clamping on both halves.

    Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
    where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, limit: float = 7.0):
        super().__init__()
        # A concrete clamp bound is mandatory; None would silently disable it.
        if limit is None:
            raise ValueError("SwigluStepAndMul requires limit to be set.")
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Pure-PyTorch reference implementation (matches the Triton path)."""
        gate, up = x.chunk(2, dim=-1)
        activated = F.silu(gate).clamp(max=self.limit)
        bounded = up.clamp(min=-self.limit, max=self.limit)
        return activated * bounded

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """CUDA path: allocate the output here and fill it via the fused kernel."""
        half = x.shape[-1] // 2
        result_shape = x.shape[:-1] + (half,)
        result = torch.empty(result_shape, dtype=x.dtype, device=x.device)
        swiglustep_and_mul_triton(result, x, self.limit)
        return result

    def extra_repr(self) -> str:
        return f"limit={repr(self.limit)}"


# --8<-- [start:gelu_new]
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _supports_quant_scheme(

@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu"]
return activation in ["silu", "swiglustep"]

@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1939,7 +1939,7 @@ def _supports_quant_scheme(

@staticmethod
def _supports_activation(activation: str) -> bool:
return activation in ["silu", "gelu", "swigluoai"]
return activation in ["silu", "gelu", "swigluoai", "swiglustep"]

@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
Expand Down
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/fused_moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,11 @@ def apply_moe_activation(
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
torch.ops._C.swigluoai_and_mul(output, input)
elif activation == "swiglustep":
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton

swiglustep_and_mul_triton(output, input)

# Activations without gated multiplication
elif activation == SILU_NO_MUL:
output.copy_(F.silu(input))
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@
"SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
"Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
"Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
Expand Down Expand Up @@ -495,6 +496,7 @@
"MedusaModel": ("medusa", "Medusa"),
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
"Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
Expand Down
Loading