diff --git a/tests/test_routing_simulator.py b/tests/test_routing_simulator.py new file mode 100644 index 000000000000..8324b225a8ce --- /dev/null +++ b/tests/test_routing_simulator.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test script for the token-to-expert routing simulator. + +This script demonstrates how to use the routing simulator to test +different routing strategies and analyze their performance, including +integration tests with FusedMoE layer. +""" + +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.routing_simulator import ( + DistributionBasedRouting, RoutingSimulator) + + +@pytest.fixture +def device(): + """Fixture to provide the appropriate device for testing.""" + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +@pytest.mark.parametrize("num_tokens", [1, 16, 256]) +@pytest.mark.parametrize("hidden_size", [64, 1024]) +@pytest.mark.parametrize("num_experts", [16, 128]) +@pytest.mark.parametrize("top_k", [1, 4]) +def test_basic_functionality( + num_tokens: int, + hidden_size: int, + num_experts: int, + top_k: int, + device, +): + """Test basic functionality of the routing simulator.""" + # Test each routing strategy + strategies = RoutingSimulator.get_available_strategies() + + hidden_states = torch.randn(num_tokens, hidden_size, device=device) + router_logits = torch.randn(num_tokens, num_experts, device=device) + + for strategy in strategies: + # Simulate routing + topk_weights, topk_ids = RoutingSimulator.simulate_routing( + hidden_states=hidden_states, + router_logits=router_logits, + strategy_name=strategy, + top_k=top_k, + ) + + # Check output shapes + assert topk_weights.shape == ( + num_tokens, + top_k, + ), f"Wrong weights shape for {strategy}" + assert topk_ids.shape == ( + num_tokens, + top_k, + ), f"Wrong ids shape for {strategy}" + + # Check that expert IDs are valid + assert (topk_ids.min() + >= 0), f"Invalid expert ID (negative) for {strategy}" + assert (topk_ids.max() + < num_experts), f"Invalid expert ID (too large) for {strategy}" + + +def test_routing_strategy_integration(monkeypatch, device): + """Test that the routing strategy environment variable works with + FusedMoE.""" + pytest.importorskip("vllm.model_executor.layers.fused_moe.layer") + + import vllm.envs as envs + from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + # Test parameters + num_tokens = 32 + hidden_size = 16 + num_experts = 4 + top_k = 2 + + # Create test data + hidden_states = torch.randn(num_tokens, hidden_size, device=device) + router_logits = torch.randn(num_tokens, num_experts, device=device) + + # Test different routing strategies + strategies = RoutingSimulator.get_available_strategies() + + for strategy in strategies: + # Set environment variable + env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY" + monkeypatch.setenv(env_name, strategy) + + # Force reload of environment variable + envs.environment_variables[env_name] = lambda s=strategy: s + + # Test the select_experts method + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=False, + renormalize=True, + indices_type=torch.long) + + # Verify output shapes + assert topk_weights.shape == ( + num_tokens, top_k), f"Wrong weights shape for {strategy}" + assert topk_ids.shape == (num_tokens, + top_k), f"Wrong ids shape for {strategy}" + + # Verify expert IDs are valid + assert topk_ids.min( + ) >= 0, f"Invalid expert ID (negative) for {strategy}" + assert topk_ids.max( + ) < num_experts, f"Invalid expert ID (too large) for {strategy}" + + +def test_distribution_based_routing_with_custom_strategy(): + """Test registering and using DistributionBasedRouting with custom + parameters.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Register custom distribution-based strategy + custom_strategy = DistributionBasedRouting(distribution="normal", + mean=2.0, + std=0.5) + RoutingSimulator.register_strategy("custom_normal", custom_strategy) + + # Test data + num_tokens = 60 + hidden_size = 48 + num_experts = 6 + top_k = 3 + + hidden_states = torch.randn(num_tokens, hidden_size, device=device) + router_logits = torch.randn(num_tokens, num_experts, device=device) + + # Use the custom strategy + topk_weights, topk_ids = RoutingSimulator.simulate_routing( + hidden_states=hidden_states, + router_logits=router_logits, + strategy_name="custom_normal", + top_k=top_k) + + # Check output shapes + assert topk_weights.shape == (num_tokens, top_k) + assert topk_ids.shape == (num_tokens, top_k) + + # Check that expert IDs are valid + assert topk_ids.min() >= 0 + assert topk_ids.max() < num_experts + + +def test_instance_compatibility(): + """Test that static methods work correctly.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Test static method directly + hidden_states = torch.randn(10, 8, device=device) + router_logits = torch.randn(10, 4, device=device) + + topk_weights, topk_ids = RoutingSimulator.simulate_routing( + hidden_states=hidden_states, + router_logits=router_logits, + strategy_name="uniform_random", + top_k=2) + + assert topk_weights.shape == (10, 2) + assert topk_ids.shape == (10, 2) diff --git a/vllm/envs.py b/vllm/envs.py index 2fda2903179b..8c5defee73ca 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -964,6 +964,15 @@ def get_vllm_port() -> Optional[int]: "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")), + # MoE routing strategy selector. + # See `RoutingSimulator.get_available_strategies()` # for available + # strategies. + # Cutstom routing strategies can be registered by + # RoutingSimulator.register_strategy() + # Note: custom strategies may not produce correct model outputs + "VLLM_MOE_ROUTING_SIMULATION_STRATEGY": + lambda: os.environ.get("VLLM_MOE_ROUTING_SIMULATION_STRATEGY", "").lower(), + # Regex timeout for use by the vLLM tool parsing plugins. "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")), diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e16fc13c945c..fbe2f5059499 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -28,6 +28,8 @@ FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) +from vllm.model_executor.layers.fused_moe.routing_simulator import ( + RoutingSimulator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs @@ -1275,6 +1277,16 @@ def select_experts( """ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk + # Check if we should use a routing simulation strategy + routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY + if routing_strategy != "": + return RoutingSimulator.simulate_routing( + hidden_states=hidden_states, + router_logits=router_logits, + strategy_name=routing_strategy, + top_k=top_k, + indices_type=indices_type) + # DeepSeekv2 uses grouped_top_k if use_grouped_topk: assert topk_group is not None diff --git a/vllm/model_executor/layers/fused_moe/routing_simulator.py b/vllm/model_executor/layers/fused_moe/routing_simulator.py new file mode 100644 index 000000000000..c8b107f13cd0 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/routing_simulator.py @@ -0,0 +1,289 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Token-to-Expert Routing Simulator + +This module provides a framework for simulating and testing different +token-to-expert routing strategies for Mixture of Experts (MoE) models. +It supports routing logic customization and includes example implementations +like uniform random routing. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class RoutingStrategy(ABC): + """Base class for token-to-expert routing strategies.""" + + @abstractmethod + def route_tokens( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + indices_type: Optional[torch.dtype] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Route tokens to experts. + + Args: + hidden_states: Input hidden states [num_tokens, hidden_size] + router_logits: Router logits [num_tokens, num_experts] + top_k: Number of experts to select per token + indices_type: Data type for expert indices + + Returns: + tuple of (topk_weights, topk_ids) + """ + pass + + +class DistributionBasedRouting(RoutingStrategy): + """ + Distribution-based random routing strategy with configurable distributions. + + This routing strategy randomly selects experts for each token based on + different probability distributions. Currently supports uniform and normal + distributions for testing different routing patterns. + """ + + def __init__(self, distribution: str = "uniform", **distribution_params): + """ + Initialize distribution-based routing. + + Args: + distribution: Type of distribution to use for sampling + - "uniform": Uniform distribution (default) + - "normal": Normal/Gaussian distribution + **distribution_params: Parameters specific to the + chosen distribution + For "uniform": No additional parameters needed + For "normal": mean (default: 0.0), std (default: 1.0) + """ + self.distribution = distribution.lower() + self.distribution_params = distribution_params + + # Validate distribution and parameters + self._validate_distribution_params() + + def _validate_distribution_params(self): + """Validate distribution type and parameters.""" + valid_distributions = ["uniform", "normal"] + + if self.distribution not in valid_distributions: + raise ValueError(f"Unsupported distribution: {self.distribution}. " + f"Supported distributions: {valid_distributions}") + + # Set default parameters if not provided + if self.distribution == "normal": + self.distribution_params.setdefault("mean", 0.0) + self.distribution_params.setdefault("std", 1.0) + + def route_tokens( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + indices_type: Optional[torch.dtype] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Randomly select experts for each token using the specified distribution. + + Args: + hidden_states: Input hidden states [num_tokens, hidden_size] + router_logits: Router logits [num_tokens, num_experts] + top_k: Number of experts to select per token + indices_type: Data type for expert indices + + Returns: + tuple of (topk_weights, topk_ids) where: + - topk_weights: Weights based on distribution sampling + - topk_ids: Expert indices sampled from the distribution + """ + num_tokens = hidden_states.shape[0] + num_experts = router_logits.shape[-1] + + if indices_type is None: + indices_type = torch.long + + # Generate expert IDs based on the specified distribution + topk_ids = self._sample_expert_ids(num_tokens, num_experts, top_k, + hidden_states.device, indices_type) + + # Generate weights based on the distribution + topk_weights = self._generate_weights(num_tokens, top_k, + hidden_states.device) + + return topk_weights, topk_ids + + def _sample_expert_ids( + self, + num_tokens: int, + num_experts: int, + top_k: int, + device: torch.device, + indices_type: torch.dtype, + ) -> torch.Tensor: + """Sample expert IDs based on the specified distribution.""" + + if self.distribution == "uniform": + # Uniform random sampling + return torch.randint( + low=0, + high=num_experts, + size=(num_tokens, top_k), + dtype=indices_type, + device=device, + ) + + elif self.distribution == "normal": + # For normal distribution, sample continuous values and map to + # expert IDs + continuous_samples = self._sample_continuous_distribution( + num_tokens, top_k, device) + + # Map continuous samples to expert indices + # Normalize to [0, 1] range and scale to [0, num_experts) + normalized_samples = self._normalize_samples(continuous_samples) + expert_ids = (normalized_samples * num_experts).long() + expert_ids = torch.clamp(expert_ids, 0, num_experts - 1) + + return expert_ids.to(dtype=indices_type) + + else: + raise ValueError(f"Unsupported distribution: {self.distribution}") + + def _sample_continuous_distribution(self, num_tokens: int, top_k: int, + device: torch.device) -> torch.Tensor: + """Sample from continuous distributions.""" + shape = (num_tokens, top_k) + + if self.distribution == "normal": + mean = self.distribution_params["mean"] + std = self.distribution_params["std"] + return torch.normal(mean, std, size=shape, device=device) + + else: + raise ValueError( + f"Unsupported continuous distribution: {self.distribution}") + + def _normalize_samples(self, samples: torch.Tensor) -> torch.Tensor: + """Normalize samples to [0, 1] range.""" + if self.distribution == "normal": + # Use sigmoid to map normal distribution to [0, 1] + return torch.sigmoid(samples) + + else: + raise ValueError(f"Unsupported distribution for normalization: " + f"{self.distribution}") + + def _generate_weights(self, num_tokens: int, top_k: int, + device: torch.device) -> torch.Tensor: + """Generate weights based on the distribution.""" + if self.distribution == "uniform": + # All-ones weights for uniform distribution + return torch.ones( + (num_tokens, top_k), + dtype=torch.float32, + device=device, + ) + + elif self.distribution == "normal": + # For normal distribution, generate weights from the same + # distribution + continuous_weights = self._sample_continuous_distribution( + num_tokens, top_k, device) + # Normalize to positive values and sum to 1 + weights = torch.abs(continuous_weights) + weights = weights / weights.sum(dim=-1, keepdim=True) + return weights + + else: + raise ValueError( + f"Unsupported distribution for weight generation: " + f"{self.distribution}") + + def get_distribution_info(self) -> dict: + """Get information about the current distribution configuration.""" + return { + "distribution": self.distribution, + "parameters": self.distribution_params.copy() + } + + +class RoutingSimulator: + """ + Token-to-Expert Routing Simulator. + + This class provides a framework for testing and comparing different + routing strategies for MoE models. It can simulate routing behavior + and collect statistics for analysis. + """ + + # Class-level registry of routing strategies + _routing_strategies: dict[str, RoutingStrategy] = { + # Basic routing strategies + "uniform_random": + DistributionBasedRouting(distribution="uniform", mean=0.0, std=1.0), + "normal_routing": + DistributionBasedRouting(distribution="normal", mean=0.0, std=1.0), + } + + @classmethod + def register_strategy(cls, name: str, strategy: RoutingStrategy): + """ + Register a custom routing strategy. + + Args: + name: Name of the strategy + strategy: RoutingStrategy instance + """ + cls._routing_strategies[name] = strategy + + @classmethod + def get_available_strategies(cls): + """ + Get list of available routing strategy names. + + Returns: + List of available strategy names + """ + return list(cls._routing_strategies.keys()) + + @staticmethod + def simulate_routing( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + strategy_name: str, + top_k: int, + indices_type: Optional[torch.dtype] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Simulate token-to-expert routing using the specified strategy. + + Args: + hidden_states: Input hidden states [num_tokens, hidden_size] + router_logits: Router logits [num_tokens, num_experts] + strategy_name: Name of the routing strategy to use + top_k: Number of experts to select per token + indices_type: Data type for expert indices + + Returns: + tuple of (topk_weights, topk_ids) + """ + if strategy_name not in RoutingSimulator._routing_strategies: + raise ValueError( + f"Unknown routing strategy: {strategy_name}. " + f"Available strategies: " + f"{list(RoutingSimulator._routing_strategies.keys())}") + + strategy = RoutingSimulator._routing_strategies[strategy_name] + return strategy.route_tokens( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + indices_type=indices_type, + )