171 changes: 171 additions & 0 deletions tests/test_routing_simulator.py
@@ -0,0 +1,171 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test script for the token-to-expert routing simulator.

This script demonstrates how to use the routing simulator to test
different routing strategies and analyze their performance, including
integration tests with the FusedMoE layer.
"""

import pytest
import torch

from vllm.model_executor.layers.fused_moe.routing_simulator import (
    DistributionBasedRouting, RoutingSimulator)


@pytest.fixture
def device():
    """Fixture to provide the appropriate device for testing."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

@pytest.mark.parametrize("num_tokens", [1, 16, 256])
@pytest.mark.parametrize("hidden_size", [64, 1024])
@pytest.mark.parametrize("num_experts", [16, 128])
@pytest.mark.parametrize("top_k", [1, 4])
def test_basic_functionality(
    num_tokens: int,
    hidden_size: int,
    num_experts: int,
    top_k: int,
    device,
):
    """Test basic functionality of the routing simulator."""
    # Test each routing strategy
    strategies = RoutingSimulator.get_available_strategies()

    hidden_states = torch.randn(num_tokens, hidden_size, device=device)
    router_logits = torch.randn(num_tokens, num_experts, device=device)

    for strategy in strategies:
        # Simulate routing
        topk_weights, topk_ids = RoutingSimulator.simulate_routing(
            hidden_states=hidden_states,
            router_logits=router_logits,
            strategy_name=strategy,
            top_k=top_k,
        )

        # Check output shapes
        assert topk_weights.shape == (num_tokens, top_k), \
            f"Wrong weights shape for {strategy}"
        assert topk_ids.shape == (num_tokens, top_k), \
            f"Wrong ids shape for {strategy}"

        # Check that expert IDs are valid
        assert topk_ids.min() >= 0, \
            f"Invalid expert ID (negative) for {strategy}"
        assert topk_ids.max() < num_experts, \
            f"Invalid expert ID (too large) for {strategy}"

def test_routing_strategy_integration(monkeypatch, device):
    """Test that the routing strategy environment variable works with
    FusedMoE."""
    pytest.importorskip("vllm.model_executor.layers.fused_moe.layer")

    import vllm.envs as envs
    from vllm.model_executor.layers.fused_moe.layer import FusedMoE

    # Test parameters
    num_tokens = 32
    hidden_size = 16
    num_experts = 4
    top_k = 2

    # Create test data
    hidden_states = torch.randn(num_tokens, hidden_size, device=device)
    router_logits = torch.randn(num_tokens, num_experts, device=device)

    # Test different routing strategies
    strategies = RoutingSimulator.get_available_strategies()

    for strategy in strategies:
        # Set environment variable
        env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
        monkeypatch.setenv(env_name, strategy)

        # Force reload of environment variable
        envs.environment_variables[env_name] = lambda s=strategy: s

        # Test the select_experts method
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=False,
            renormalize=True,
            indices_type=torch.long)

        # Verify output shapes
        assert topk_weights.shape == (num_tokens, top_k), \
            f"Wrong weights shape for {strategy}"
        assert topk_ids.shape == (num_tokens, top_k), \
            f"Wrong ids shape for {strategy}"

        # Verify expert IDs are valid
        assert topk_ids.min() >= 0, \
            f"Invalid expert ID (negative) for {strategy}"
        assert topk_ids.max() < num_experts, \
            f"Invalid expert ID (too large) for {strategy}"

def test_distribution_based_routing_with_custom_strategy():
    """Test registering and using DistributionBasedRouting with custom
    parameters."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Register custom distribution-based strategy
    custom_strategy = DistributionBasedRouting(distribution="normal",
                                               mean=2.0,
                                               std=0.5)
    RoutingSimulator.register_strategy("custom_normal", custom_strategy)

    # Test data
    num_tokens = 60
    hidden_size = 48
    num_experts = 6
    top_k = 3

    hidden_states = torch.randn(num_tokens, hidden_size, device=device)
    router_logits = torch.randn(num_tokens, num_experts, device=device)

    # Use the custom strategy
    topk_weights, topk_ids = RoutingSimulator.simulate_routing(
        hidden_states=hidden_states,
        router_logits=router_logits,
        strategy_name="custom_normal",
        top_k=top_k)

    # Check output shapes
    assert topk_weights.shape == (num_tokens, top_k)
    assert topk_ids.shape == (num_tokens, top_k)

    # Check that expert IDs are valid
    assert topk_ids.min() >= 0
    assert topk_ids.max() < num_experts


def test_instance_compatibility():
    """Test that static methods work correctly."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Test static method directly
    hidden_states = torch.randn(10, 8, device=device)
    router_logits = torch.randn(10, 4, device=device)

    topk_weights, topk_ids = RoutingSimulator.simulate_routing(
        hidden_states=hidden_states,
        router_logits=router_logits,
        strategy_name="uniform_random",
        top_k=2)

    assert topk_weights.shape == (10, 2)
    assert topk_ids.shape == (10, 2)
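
Outside of pytest, the same API exercised above can also be used to compare how different strategies spread tokens across experts, which is the "analyze their performance" part of the docstring. The following is a small illustrative sketch, not part of the PR; it only relies on the calls visible in the tests (`get_available_strategies()` and `simulate_routing()`) plus `torch.bincount` for a per-expert load histogram:

# Sketch: compare expert-load balance across the registered strategies.
import torch

from vllm.model_executor.layers.fused_moe.routing_simulator import (
    RoutingSimulator)

num_tokens, hidden_size, num_experts, top_k = 256, 64, 16, 2
hidden_states = torch.randn(num_tokens, hidden_size)
router_logits = torch.randn(num_tokens, num_experts)

for name in RoutingSimulator.get_available_strategies():
    _, topk_ids = RoutingSimulator.simulate_routing(
        hidden_states=hidden_states,
        router_logits=router_logits,
        strategy_name=name,
        top_k=top_k,
    )
    # Tokens assigned to each expert; a flat histogram means balanced load.
    load = torch.bincount(topk_ids.flatten().long(), minlength=num_experts)
    print(name, load.tolist())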
9 changes: 9 additions & 0 deletions vllm/envs.py
@@ -964,6 +964,15 @@ def get_vllm_port() -> Optional[int]:
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),

# MoE routing strategy selector.
# See `RoutingSimulator.get_available_strategies()` # for available
# strategies.
# Cutstom routing strategies can be registered by
# RoutingSimulator.register_strategy()
# Note: custom strategies may not produce correct model outputs
"VLLM_MOE_ROUTING_SIMULATION_STRATEGY":
lambda: os.environ.get("VLLM_MOE_ROUTING_SIMULATION_STRATEGY", "").lower(),

# Regex timeout for use by the vLLM tool parsing plugins.
"VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS":
lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")),
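
Because the variable is read through `vllm.envs` (as in the test above), a run can opt into simulated routing without touching model code, and a custom strategy registered beforehand becomes selectable by name. A minimal sketch, assuming `uniform_random` is one of the built-in strategy names that appears in the tests; this is illustrative, not part of the diff:

# Sketch: opt a run into simulated MoE routing via the new variable.
import os

# Must be set before FusedMoE.select_experts runs; vllm.envs reads it lazily.
os.environ["VLLM_MOE_ROUTING_SIMULATION_STRATEGY"] = "uniform_random"

# Optional: register a custom strategy first, then select it by name.
from vllm.model_executor.layers.fused_moe.routing_simulator import (
    DistributionBasedRouting, RoutingSimulator)

RoutingSimulator.register_strategy(
    "custom_normal",
    DistributionBasedRouting(distribution="normal", mean=2.0, std=0.5))
# os.environ["VLLM_MOE_ROUTING_SIMULATION_STRATEGY"] = "custom_normal"

# Per the comment in envs.py: custom strategies may not produce correct
# model outputs; this is intended for routing experiments only.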
12 changes: 12 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -28,6 +28,8 @@
    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.fused_moe.routing_simulator import (
    RoutingSimulator)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
@@ -1275,6 +1277,16 @@ def select_experts(
"""
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

# Check if we should use a routing simulation strategy
routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
if routing_strategy != "":
return RoutingSimulator.simulate_routing(
hidden_states=hidden_states,
router_logits=router_logits,
strategy_name=routing_strategy,
top_k=top_k,
indices_type=indices_type)

# DeepSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
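
The hook above returns early from `select_experts` only when the environment variable is non-empty, so the simulator stays out of the hot path by default. The public surface used in this PR (`register_strategy`, `get_available_strategies`, `simulate_routing`) suggests a simple name-to-strategy registry; the sketch below is an assumption about that shape for readers of the diff, not the actual implementation, and the `route` method name is invented for illustration:

# Assumed shape of the simulator's dispatch; illustrative only.
class _RoutingSimulatorSketch:
    _strategies: dict[str, object] = {}

    @classmethod
    def register_strategy(cls, name, strategy):
        # Make a strategy selectable by name (e.g. via the env var).
        cls._strategies[name] = strategy

    @classmethod
    def get_available_strategies(cls):
        return list(cls._strategies)

    @classmethod
    def simulate_routing(cls, hidden_states, router_logits, strategy_name,
                         top_k, indices_type=None):
        strategy = cls._strategies[strategy_name]
        # Each strategy decides how to pick top-k experts per token and
        # returns (topk_weights, topk_ids), matching the shapes the tests
        # assert on.
        return strategy.route(hidden_states, router_logits, top_k,
                              indices_type)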