diff --git a/docs/platforms/plugin.md b/docs/platforms/plugin.md new file mode 100644 index 000000000000..8a4c4ee1c64d --- /dev/null +++ b/docs/platforms/plugin.md @@ -0,0 +1,414 @@ +# SGLang Plugin System + +## Overview + +Allows hardware vendors and developers to extend SGLang **without modifying the main repository code**. + +The framework provides two plugin types, both discovered via Python's standard `setuptools` entry_points: + +| Plugin Type | Entry Point Group | Purpose | +|---|---|---| +| **Hardware Platform Plugin** | `sglang.srt.platforms` | Register a custom hardware platform (device operations, KV cache pools, attention backends, graph capture, compilation backends, etc.) | +| **General Plugin** | `sglang.srt.plugins` | Inject hooks (before/after/around/replace) into any function/method, or replace entire classes | + +### Principles + +- **Non-intrusive**: Existing CUDA/ROCm/NPU/XPU code remains unchanged. OOT code paths are added alongside existing hardware-specific logic. +- **Zero configuration**: Plugins are automatically discovered after `pip install`, no sglang code changes required. +- **Environment variable control**: `SGLANG_PLATFORM` selects or validates the active platform plugin; `SGLANG_PLUGINS` (comma-separated) controls which general plugins to load. + +### Current Scope & Future Direction + +The plugin system currently targets **out-of-tree (OOT) hardware platforms** — enabling new devices to integrate with SGLang without any changes to the main repository. The main-repo hardware paths (CUDA, ROCm, NPU, XPU, etc.) continue to use the existing `is_cuda()`/`is_npu()`/… utility functions. + +As the plugin interfaces mature and stabilize, in-tree hardware backends can be gradually migrated to the same plugin architecture. 
This would replace the scattered `if device == "cuda" … elif device == "npu" …` branches throughout the codebase with a single polymorphic dispatch through the platform interface, making each hardware backend self-contained and the core engine hardware-agnostic. + +## Architecture + +### Platform Hierarchy + +The platform hierarchy uses a DeviceMixin pattern to share device operations between SRT (LLM inference) and Multimodal subsystems: + +``` +DeviceMixin (shared device identity + operations) +├── SRTPlatform(DeviceMixin) # + graph runner, KV pool, … +│ └── MySRTPlatform(SRTPlatform, MyDeviceMixin) # OOT plugin +└── MMPlatform(DeviceMixin) # + attention backend, VAE, … (future) + └── MyMMPlatform(MMPlatform, MyDeviceMixin) # OOT plugin +``` + +Key design points: +- **DeviceMixin** provides platform identity queries (`is_cuda()`, `is_npu()`, etc.) and device operations (`set_device()`, `get_device_name()`, etc.) +- **SRTPlatform** adds SRT-specific factory methods, capability flags, and lifecycle hooks +- OOT plugins implement a **device mixin** (vendor-specific operations) and compose it with **SRTPlatform** via multiple inheritance +- All methods are **instance methods** (not classmethods), called through the `current_platform` singleton +- Device operations and factory methods raise `NotImplementedError` by default (fail-fast) +- Capability flags use safe conservative defaults (`False`/`pass`) +- Methods are annotated `[Active]` (called by SGLang core) or `[Planned]` (reserved for future migration) + +### Platform Discovery (`current_platform`) + +`current_platform` is a **lazy singleton** in `sglang.srt.platforms`. 
On first access it resolves the active platform through the following priority chain: + +``` +entry_points("sglang.srt.platforms") → Enumerate ALL plugins by name (metadata only) + │ + ├─ SGLANG_PLATFORM set (front-loading filter): + │ ├─ Name not found in discovered → RuntimeError + │ ├─ activate() returns non-None → load that platform + │ └─ activate() returns None → RuntimeError (hardware unavailable) + │ + └─ SGLANG_PLATFORM unset (auto-discover, activate all): + ├─ 0 activated → fallback base SRTPlatform + ├─ 1 activated → use it + └─ N activated → RuntimeError (must set SGLANG_PLATFORM) +``` + +### Plugin Loading Flow + +`load_plugins()` discovers and executes general plugins, then applies all registered hooks. It is called at four points: + +| Call Site | Process | Timing | +|---------|------|------| +| `cli/serve.py` serve() | Main | Before `prepare_server_args()` | +| `launch_server.py` `__main__` | Main | Before `prepare_server_args()` | +| `engine.py` `_launch_subprocesses()` | Main | Before `server_args.check_server_args()` | +| `scheduler.py` `run_scheduler_process()` | Subprocess | Before `Scheduler()` construction | + +> **Note**: `load_plugins()` is idempotent (guarded by `_plugins_loaded` flag). In spawn'd subprocesses the flag resets, so plugins are correctly re-loaded. + +``` +load_plugins() + ├── _get_excluded_dists() → compute dists to skip (via SGLANG_PLATFORM) + ├── load_plugins_by_group("sglang.srt.plugins", → discover entry_points, filter by SGLANG_PLUGINS + │ excluded_dists=...) skip plugins from unselected platform packages + ├── for each plugin: → set _current_plugin_source context var + │ func() side effects (register hooks with source tracking) + └── HookRegistry.apply_hooks() → monkey-patch targets +``` + +--- + +## Plugin Type 1: Hardware Platform Plugin + +### Description + +A hardware platform plugin registers an `SRTPlatform` subclass that tells SGLang how to interact with a specific hardware backend. + +### Quick Start + +**1. 
Create a minimal package:** + +``` +my_platform_plugin/ +├── pyproject.toml +└── my_platform_plugin/ + ├── __init__.py # activate() function + ├── device.py # MyDeviceMixin + └── platform.py # MySRTPlatform +``` + +**2. `pyproject.toml`:** + +```toml +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "my-platform-plugin" +version = "0.1.0" + +[project.entry-points."sglang.srt.platforms"] +my_device = "my_platform_plugin:activate" +``` + +**3. `__init__.py`** — activation function: + +```python +def activate(): + """Return fully-qualified class name to activate, or None to skip.""" + if _my_device_is_available(): + return "my_platform_plugin.platform.MySRTPlatform" + return None +``` + +**4. `device.py`** — device mixin: + +```python +from sglang.srt.platforms.device_mixin import DeviceMixin, PlatformEnum + +class MyDeviceMixin(DeviceMixin): + _enum = PlatformEnum.OOT + device_name = "my_device" + device_type = "my_device" # torch device type + + def set_device(self, device) -> None: ... + def get_device_name(self, device_id=0) -> str: ... + def get_device_total_memory(self, device_id=0) -> int: ... + def get_current_memory_usage(self, device=None) -> float: ... + def get_device_capability(self, device_id=0): ... + def get_torch_distributed_backend_str(self) -> str: ... +``` + +**5. `platform.py`** — SRT platform: + +```python +from sglang.srt.platforms.interface import SRTPlatform +from my_platform_plugin.device import MyDeviceMixin + +class MySRTPlatform(SRTPlatform, MyDeviceMixin): + def get_default_attention_backend(self) -> str: ... + def support_cuda_graph(self) -> bool: ... + # ... override other methods as needed +``` + +**6. 
Install and verify:** + +```bash +pip install -e my_platform_plugin/ +python -c "from sglang.srt.platforms import current_platform; print(current_platform)" +``` + +### Platform Interface Reference + +#### Identity Queries (from DeviceMixin) + +| Method | Default | Description | +|---|---|---| +| `is_cuda()` | Based on `_enum` | Whether this is an NVIDIA CUDA platform | +| `is_rocm()` | Based on `_enum` | Whether this is an AMD ROCm platform | +| `is_npu()` | Based on `_enum` | Whether this is a Huawei NPU platform | +| `is_cpu()` | Based on `_enum` | Whether this is a CPU-only platform | +| `is_xpu()` | Based on `_enum` | Whether this is an Intel XPU platform | +| `is_musa()` | Based on `_enum` | Whether this is a Moore Threads MUSA platform | +| `is_cuda_alike()` | CUDA+ROCM+MUSA | True if the hardware supports CUDA-like APIs | +| `is_out_of_tree()` | `True` for OOT | Automatically detected based on `_enum = PlatformEnum.OOT` | + +#### Device Operations (from DeviceMixin) + +> Methods annotated **[Active]** are called by SGLang core through `current_platform` — OOT implementations take effect immediately. +> Methods annotated **[Planned]** are reserved interfaces — SGLang core still uses hardcoded calls (e.g. `torch.cuda.empty_cache()`). OOT implementations will NOT take effect until the core is migrated in a future PR. + +| Method | Default | Status | Description | +|---|---|---|---| +| `get_device(local_rank)` | `raise NotImplementedError` | Planned | Return `torch.device` for a given local rank | +| `set_device(device)` | `raise NotImplementedError` | Planned | Set the current device | +| `get_device_name(device_id)` | `raise NotImplementedError` | Planned | Get human-readable device name | +| `get_device_uuid(device_id)` | `raise NotImplementedError` | Planned | Get unique device identifier | +| `get_device_capability(device_id)` | `raise NotImplementedError` | Planned | Get `DeviceCapability(major, minor)`. 
None if N/A | +| `empty_cache()` | `pass` | Planned | Release cached device memory | +| `synchronize()` | `pass` | Planned | Synchronize device operations | +| `get_device_total_memory(device_id)` | `raise NotImplementedError` | **Active** | Get total device memory in bytes | +| `get_available_memory(device_id)` | `raise NotImplementedError` | Planned | Return `(free_bytes, total_bytes)` | +| `get_current_memory_usage(device)` | `raise NotImplementedError` | **Active** | Get current peak memory usage in bytes | +| `get_torch_distributed_backend_str()` | `raise NotImplementedError` | Planned | Distributed backend string (e.g. "nccl", "hccl") | +| `get_communicator_class()` | `None` | Planned | Platform-specific communicator class | +| `inference_mode()` | `torch.inference_mode(True)` | Planned | Return inference mode context manager | +| `seed_everything(seed)` | Set random/np/torch seeds | Planned | Set random seeds for reproducibility | +| `verify_quantization(quant)` | `pass` | Planned | Validate quantization method support | +| `get_cpu_architecture()` | Auto-detect x86/arm | Planned | Detect CPU architecture (`CpuArchEnum`) | + +#### Types (from DeviceMixin) + +| Type | Description | +|---|---| +| `PlatformEnum` | Enumeration of platform types: CUDA, ROCM, CPU, XPU, MUSA, NPU, TPU, MPS, OOT, UNSPECIFIED | +| `CpuArchEnum` | CPU architecture: X86, ARM, UNSPECIFIED | +| `DeviceCapability` | `NamedTuple(major, minor)` with comparison support. 
Methods: `as_version_str()`, `to_int()` | + +#### Capability Flags (from SRTPlatform) + +| Method | Default | Description | +|---|---|---| +| `support_cuda_graph()` | `False` | Whether device graph capture is supported (plain CUDA graph) | +| `support_piecewise_cuda_graph()` | `False` | Whether piecewise CUDA graph (torch.compile backend) is supported | +| `supports_fp8()` | `False` | Whether FP8 quantization is supported | +| `is_pin_memory_available()` | `True` | Whether pinned memory is available | + +#### Subsystem Factory Methods (from SRTPlatform) + +| Method | Default | Description | +|---|---|---| +| `get_default_attention_backend()` | `raise NotImplementedError` | Default attention backend name | +| `get_graph_runner_cls()` | `raise NotImplementedError` | Graph Runner class | +| `get_mha_kv_pool_cls()` | `raise NotImplementedError` | MHA KV cache pool class | +| `get_mla_kv_pool_cls()` | `raise NotImplementedError` | MLA KV cache pool class | +| `get_nsa_kv_pool_cls()` | `raise NotImplementedError` | NSA KV cache pool class (DeepSeek V3.2) | +| `get_paged_allocator_cls()` | `raise NotImplementedError` | Paged allocator class | +| `get_piecewise_backend_cls()` | `raise NotImplementedError` | Piecewise compilation backend class | +| `get_compile_backend(mode)` | `"inductor"` | Compilation backend string | +| `get_dispatch_key_name()` | `"native"` | MultiPlatformOp dispatch key name | + +#### Lifecycle Hooks (from SRTPlatform) + +| Method | Invocation Timing | Purpose | +|---|---|---| +| `apply_server_args_defaults(server_args)` | After ServerArgs parsing, in `__post_init__` | Set platform-specific defaults | +| `init_backend()` | In each worker, before model construction | One-time backend initialization | + +### Environment Variables + +| Variable | Description | +|---|---| +| `SGLANG_PLATFORM` | Select the platform plugin by entry_point name (e.g. `kunlun`, `demo_cuda`). 
When set, **only** the named plugin's `activate()` is called (front-loading filter) — other plugins are not touched. Additionally, general plugins (`sglang.srt.plugins`) from unselected platform packages are automatically skipped to avoid importing their dependencies. Required when multiple plugins would activate. Errors if the name is not found or if the plugin's hardware is unavailable. | +| `SGLANG_PLUGINS` | Comma-separated whitelist of general plugin names to load (group: `sglang.srt.plugins`). If unset, all discovered general plugins are loaded. | + +--- + +## Plugin Type 2: General Plugin + +### Description + +General function plugins inject behavior into sglang **without requiring a custom platform**. Use cases include: + +- **Observability**: Add logging, metrics, and tracing to any function +- **Behavior modification**: Modify function arguments or return values +- **Performance profiling**: Add timing to critical functions +- **A/B testing**: Replace implementations at runtime + +### Quick Start + +**1. Create a minimal package:** + +``` +my_general_plugin/ +├── pyproject.toml +└── my_general_plugin/ + └── __init__.py # register() function +``` + +**2. `pyproject.toml`:** + +```toml +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "my-general-plugin" +version = "0.1.0" + +[project.entry-points."sglang.srt.plugins"] +my_plugin = "my_general_plugin:register" +``` + +**3. `__init__.py`** — register hooks: + +```python +from sglang.srt.plugins.hook_registry import HookRegistry, HookType + +def register(): + """Entry point called by load_plugins().""" + HookRegistry.register( + "sglang.srt.managers.scheduler.Scheduler.__init__", + my_hook, + HookType.AROUND, + ) + +def my_hook(original_fn, self, *args, **kwargs): + result = original_fn(self, *args, **kwargs) + print(f"Scheduler initialized! gpu_id={self.gpu_id}") + return result +``` + +**4. 
Install and run:** + +```bash +pip install -e my_general_plugin/ +sglang serve --model-path [options] +# Look for "Scheduler initialized!" in logs +``` + +### Hook Types + +`HookRegistry` supports four hook types: + +| Hook Type | Signature | Description | +|---|---|---| +| **BEFORE** | `fn(*args, **kwargs) -> (args, kwargs) \| None` | Runs before the original. Return `None` to keep args unchanged, or `(args, kwargs)` to modify. | +| **AFTER** | `fn(result, *args, **kwargs) -> new_result \| None` | Runs after the original. Return `None` to keep result, or a new value to replace. | +| **AROUND** | `fn(original_fn, *args, **kwargs) -> result` | Wraps the original. You must call `original_fn` yourself. Full control over execution. | +| **REPLACE** | `fn(*args, **kwargs) -> result` or `class` | Replace the original function or class entirely. For class targets, pass a replacement class directly — it is substituted via `setattr` preserving `isinstance()`/`issubclass()` semantics. | + +> **Note**: Only `REPLACE` accepts a class as the hook. Passing a class to `BEFORE`/`AFTER`/`AROUND` raises `TypeError` at registration time. 
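The BEFORE/AFTER/AROUND conventions in the table above can be sketched with plain wrapper functions. The helpers below are illustrative stand-ins only — they model the documented semantics, not sglang's actual `HookRegistry` implementation:

```python
# Illustrative stand-ins for the BEFORE/AFTER/AROUND conventions.
# These helpers are NOT part of sglang; they only model the semantics.

def apply_before(fn, hook):
    def wrapped(*args, **kwargs):
        modified = hook(*args, **kwargs)     # None -> keep args unchanged
        if modified is not None:
            args, kwargs = modified
        return fn(*args, **kwargs)
    return wrapped

def apply_after(fn, hook):
    def wrapped(*args, **kwargs):
        result = fn(*args, **kwargs)
        new = hook(result, *args, **kwargs)  # None -> keep original result
        return result if new is None else new
    return wrapped

def apply_around(fn, hook):
    def wrapped(*args, **kwargs):
        return hook(fn, *args, **kwargs)     # hook must call fn itself
    return wrapped

def add(a, b):
    return a + b

# BEFORE: clamp the second argument to at most 10
clamped = apply_before(add, lambda a, b: ((a, min(b, 10)), {}))
# AFTER: double the result
doubled = apply_after(add, lambda result, *a, **kw: result * 2)
# AROUND: negate by calling the original ourselves
negated = apply_around(add, lambda fn, *a, **kw: -fn(*a, **kw))

print(clamped(1, 99))  # 11
print(doubled(2, 3))   # 10
print(negated(2, 3))   # -5
```

REPLACE is the degenerate case: the hook is simply substituted for the target, so no wrapper is needed.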
### Registration API

Hooks can be registered using the **imperative API** or the **decorator API**:

```python
# --- Imperative API ---
import time

from sglang.srt.plugins.hook_registry import HookRegistry, HookType

def my_timer(original_fn, *args, **kwargs):
    start = time.perf_counter()
    result = original_fn(*args, **kwargs)
    print(f"Elapsed: {time.perf_counter() - start:.3f}s")
    return result

HookRegistry.register(
    "sglang.srt.managers.scheduler.Scheduler.get_next_batch_to_run",
    my_timer,
    HookType.AROUND,
)

# --- Decorator API ---
from sglang.srt.plugins.hook_registry import plugin_hook, HookType

@plugin_hook(
    "sglang.srt.managers.scheduler.Scheduler.get_next_batch_to_run",
    type=HookType.AROUND,
)
def my_timer(original_fn, *args, **kwargs):
    start = time.perf_counter()
    result = original_fn(*args, **kwargs)
    print(f"Elapsed: {time.perf_counter() - start:.3f}s")
    return result

# --- Class replacement (REPLACE) ---
from sglang.srt.plugins.hook_registry import plugin_hook, HookType
from sglang.srt.managers.scheduler import Scheduler

@plugin_hook(
    "sglang.srt.managers.scheduler.Scheduler",
    type=HookType.REPLACE,
)
class MyScheduler(Scheduler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        print("Enhanced scheduler initialized!")
```

### Hook Target Resolution

Target paths use fully-qualified dotted notation. 
Both formats are supported: + +- **Dotted**: `sglang.srt.managers.scheduler.Scheduler.__init__` +- **Entry-points style**: `sglang.srt.managers.scheduler:Scheduler.__init__` (colon treated as dot) + +### Common Hook Targets + +| Target | Description | +|---|---| +| `sglang.srt.server_args.ServerArgs.add_cli_args` | Add custom CLI arguments | +| `sglang.srt.server_args.ServerArgs.__post_init__` | Modify ServerArgs after parsing | +| `sglang.srt.server_args.ServerArgs.check_server_args` | Add/relax validation | +| `sglang.srt.managers.scheduler.Scheduler.__init__` | Custom scheduler state | +| `sglang.srt.managers.scheduler.Scheduler.get_next_batch_to_run` | Custom scheduling policy | +| `sglang.srt.managers.scheduler.Scheduler.run_batch` | Profiling / inspection | +| `sglang.srt.managers.scheduler.Scheduler.process_batch_result` | Custom metrics | +| `sglang.srt.managers.tp_worker.TpModelWorker.__init__` | Custom worker state | +| `sglang.srt.managers.tp_worker.TpModelWorker.forward_batch_generation` | Forward pass wrapping | + +--- + +## File Reference + +| File | Description | +|---|---| +| `sglang/srt/platforms/device_mixin.py` | `PlatformEnum` + `DeviceMixin` base class | +| `sglang/srt/platforms/interface.py` | `SRTPlatform` base class (extends DeviceMixin) | +| `sglang/srt/platforms/__init__.py` | `current_platform` lazy singleton + discovery logic | +| `sglang/srt/plugins/__init__.py` | `load_plugins()` + `load_plugins_by_group()` | +| `sglang/srt/plugins/hook_registry.py` | `HookRegistry`, `HookType`, `plugin_hook` decorator | diff --git a/python/sglang/cli/serve.py b/python/sglang/cli/serve.py index a9a1874cc949..0268a11007a1 100644 --- a/python/sglang/cli/serve.py +++ b/python/sglang/cli/serve.py @@ -86,6 +86,10 @@ def serve(args, extra_argv): ) return + from sglang.srt.plugins import load_plugins + + load_plugins() + model_type, dispatch_argv = _extract_model_type_override(extra_argv) model_path = get_model_path(dispatch_argv) try: diff --git 
a/python/sglang/launch_server.py b/python/sglang/launch_server.py index e1d05b1f8803..e5572dcc9ab8 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -56,6 +56,10 @@ def run_server(server_args): stacklevel=1, ) + from sglang.srt.plugins import load_plugins + + load_plugins() + server_args = prepare_server_args(sys.argv[1:]) try: diff --git a/python/sglang/srt/compilation/backend.py b/python/sglang/srt/compilation/backend.py index 201123324068..e46e8a1b3c74 100644 --- a/python/sglang/srt/compilation/backend.py +++ b/python/sglang/srt/compilation/backend.py @@ -22,6 +22,7 @@ from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend from sglang.srt.compilation.pass_manager import PostGradPassManager from sglang.srt.environ import envs +from sglang.srt.platforms import current_platform from sglang.srt.utils.common import is_npu logger = logging.getLogger(__name__) @@ -48,7 +49,12 @@ def make_backend( sglang_backend, ): - backend_cls = CUDAPiecewiseBackend if not is_npu() else NPUPiecewiseBackend + if current_platform.is_out_of_tree(): + backend_cls = current_platform.get_piecewise_backend_cls() + elif is_npu(): + backend_cls = NPUPiecewiseBackend + else: + backend_cls = CUDAPiecewiseBackend return backend_cls( graph, compile_config, diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index a012fde7973c..04c16a33cda7 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -84,6 +84,7 @@ from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.observability.trace import process_tracing_init, trace_set_thread_info +from sglang.srt.plugins import load_plugins from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( MultiprocessingSerializer, @@ -167,6 +168,10 @@ def __init__(self, **kwargs): Please refer to 
`ServerArgs` for the documentation. """ + # Ensure plugins are loaded before ServerArgs construction, + # so hooks on ServerArgs.__post_init__ fire correctly. + load_plugins() + # Parse server_args if "server_args" in kwargs: # Directly load server_args @@ -647,6 +652,11 @@ def _launch_subprocesses( # Configure global environment configure_logger(server_args) _set_envs_and_config(server_args) + + # Defensive: ensure plugins loaded (may already be loaded by + # Engine.__init__ or CLI entry). + load_plugins() + server_args.check_server_args() _set_gc(server_args) diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 52451a9b2187..7b6b2eaa528d 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -554,6 +554,10 @@ class Envs: # Sglang Cache Dir SGLANG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/sglang")) + # Plugin system + SGLANG_PLATFORM = EnvStr("") + SGLANG_PLUGINS = EnvStr("") + envs = Envs() EnvField._allow_set_name = False diff --git a/python/sglang/srt/layers/utils/multi_platform.py b/python/sglang/srt/layers/utils/multi_platform.py index ff4d89914843..d482b4530aef 100644 --- a/python/sglang/srt/layers/utils/multi_platform.py +++ b/python/sglang/srt/layers/utils/multi_platform.py @@ -1,8 +1,9 @@ -from typing import Callable +from typing import Callable, ClassVar from torch import nn from sglang.kernel_api_logging import debug_kernel_api +from sglang.srt.platforms import current_platform from sglang.srt.utils import ( cpu_has_amx_support, is_cpu, @@ -23,6 +24,15 @@ class MultiPlatformOp(nn.Module): + + # OOT forward registry: maps dispatch_key -> {op_cls -> forward_fn} + _oot_forward_registry: ClassVar[dict[str, dict[type, Callable]]] = {} + + @classmethod + def register_oot_forward(cls, op_cls: type, fn: Callable, platform_key: str): + """Register an OOT forward implementation for a specific op class and platform.""" + cls._oot_forward_registry.setdefault(platform_key, {})[op_cls] = fn + def 
__init__(self): super().__init__() self._forward_method: Callable = self.dispatch_forward() @@ -100,6 +110,17 @@ def forward_cpu(self, *args, **kwargs): return self.forward_native(*args, **kwargs) def dispatch_forward(self): + # OOT platform dispatch: check registry then method lookup + if current_platform.is_out_of_tree(): + key = current_platform.get_dispatch_key_name() + oot = self._oot_forward_registry.get(key, {}) + if type(self) in oot: + return oot[type(self)].__get__(self) + method = getattr(self, f"forward_{key}", None) + if method is not None: + return method + return self.forward_native + if _is_cuda: return self.forward_cuda elif _is_hip: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 997903bf60dd..3db97cf73c36 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -204,6 +204,7 @@ ) from sglang.srt.observability.trace import process_tracing_init, trace_set_thread_info from sglang.srt.parser.reasoning_parser import ReasoningParser +from sglang.srt.plugins import load_plugins from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -3742,6 +3743,8 @@ def run_scheduler_process( dp_rank: Optional[int], pipe_writer, ): + # Load plugins so hooks can override Scheduler and its dependencies. 
+ load_plugins() dp_rank = configure_scheduler_process( server_args, gpu_id, diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index def1bdb87572..55605f98f337 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -54,6 +54,7 @@ set_mla_kv_buffer_triton_fp8_quant, set_mla_kv_scale_buffer_triton, ) +from sglang.srt.platforms import current_platform from sglang.srt.utils import ( cpu_has_amx_support, is_cpu, @@ -780,8 +781,12 @@ def __init__( self._create_buffers() self.device_module = torch.get_device_module(self.device) + + _use_alt_stream = _is_cuda or current_platform.is_cuda_alike() self.alt_stream = ( - self.device_module.Stream() if _is_cuda and enable_alt_stream else None + self.device_module.Stream() + if _use_alt_stream and enable_alt_stream + else None ) if enable_kv_cache_copy: @@ -1262,7 +1267,9 @@ def __init__( TokenToKVPoolClass = MHATokenToKVPool - if _is_npu: + if current_platform.is_out_of_tree(): + TokenToKVPoolClass = current_platform.get_mha_kv_pool_cls() + elif _is_npu: from sglang.srt.hardware_backend.npu.memory_pool_npu import ( NPUMHATokenToKVPool, ) @@ -1283,7 +1290,9 @@ def __init__( TokenToKVPoolClass = MLATokenToKVPool - if _is_npu: + if current_platform.is_out_of_tree(): + TokenToKVPoolClass = current_platform.get_mla_kv_pool_cls() + elif _is_npu: from sglang.srt.hardware_backend.npu.memory_pool_npu import ( NPUMLATokenToKVPool, ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ccc92603e0eb..886dbdd93d41 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -150,6 +150,7 @@ ) from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.platforms import current_platform from 
sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ( ServerArgs, @@ -207,6 +208,8 @@ from sglang.srt.hardware_backend.npu.utils import init_npu_backend init_npu_backend() +elif current_platform.is_out_of_tree(): + current_platform.init_backend() MLA_ATTENTION_BACKENDS = [ "aiter", @@ -702,6 +705,7 @@ def initialize(self, pre_model_load_memory: float): # Init routed experts capturer self.init_routed_experts_capturer() + # TODO: Refactor device-specific init branches into platform interface (separate PR). # Must be called BEFORE init_device_graphs() so CUDA graph capture # runs with aux hidden state capture enabled. self.init_aux_hidden_state_capture() @@ -714,6 +718,13 @@ def initialize(self, pre_model_load_memory: float): elif self.device in ["npu", "cpu"]: self.init_attention_backend() self.init_device_graphs() + elif current_platform.is_out_of_tree(): + self.init_attention_backend() + if current_platform.support_cuda_graph(): + self.init_device_graphs() + else: + self.graph_runner = None + self.graph_mem_usage = 0 else: self.graph_runner = None self.graph_mem_usage = 0 @@ -1483,7 +1494,14 @@ def model_load_weights(model, iter): self.server_args.load_format = load_format self.load_config = load_config - if recapture_cuda_graph and (self.device == "cuda" or self.device == "musa"): + if recapture_cuda_graph and ( + self.device == "cuda" + or self.device == "musa" + or ( + current_platform.is_out_of_tree() + and current_platform.support_cuda_graph() + ) + ): self.init_device_graphs() logger.info("Update weights end.") @@ -2532,8 +2550,10 @@ def init_device_graphs(self): tic = time.perf_counter() before_mem = get_available_gpu_memory(self.device, self.gpu_id) graph_backend = defaultdict( - lambda: "cuda graph", + lambda: f"{current_platform.device_name} graph", { + "cuda": "cuda graph", + "musa": "cuda graph", "cpu": "cpu graph", "npu": "npu graph", }, @@ -2541,14 +2561,18 @@ def init_device_graphs(self): logger.info( 
f"Capture {graph_backend[self.device]} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" ) - graph_runners = defaultdict( - lambda: CudaGraphRunner, - { - "cpu": CPUGraphRunner, - "npu": NPUGraphRunner, - }, - ) - self.graph_runner = graph_runners[self.device](self) + if current_platform.is_out_of_tree(): + GraphRunnerCls = current_platform.get_graph_runner_cls() + self.graph_runner = GraphRunnerCls(self) + else: + graph_runners = defaultdict( + lambda: CudaGraphRunner, + { + "cpu": CPUGraphRunner, + "npu": NPUGraphRunner, + }, + ) + self.graph_runner = graph_runners[self.device](self) after_mem = get_available_gpu_memory(self.device, self.gpu_id) self.graph_mem_usage = before_mem - after_mem diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index 8166a392dbf4..74ce00c19db0 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -282,7 +282,63 @@ def _init_pools(self: ModelRunner): # Initialize token_to_kv_pool is_nsa_model = is_deepseek_nsa(self.model_config.hf_config) - if self.server_args.attention_backend == "ascend" and not self.mambaish_config: + + # Check out-of-tree platform (plugin system) first + from sglang.srt.platforms import current_platform + + if current_platform.is_out_of_tree() and not self.mambaish_config: + if self.use_mla_backend and is_nsa_model: + PoolCls = current_platform.get_nsa_kv_pool_cls() + self.token_to_kv_pool = PoolCls( + self.max_total_num_tokens, + page_size=self.page_size, + dtype=self.kv_cache_dtype, + kv_lora_rank=self.model_config.kv_lora_rank, + qk_rope_head_dim=self.model_config.qk_rope_head_dim, + layer_num=self.num_effective_layers, + device=self.device, + kv_cache_dim=self.calculate_mla_kv_cache_dim(), + enable_memory_saver=self.server_args.enable_memory_saver, + start_layer=self.start_layer, + 
end_layer=self.end_layer, + index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config), + ) + elif self.use_mla_backend: + PoolCls = current_platform.get_mla_kv_pool_cls() + self.token_to_kv_pool = PoolCls( + self.max_total_num_tokens, + page_size=self.page_size, + dtype=self.kv_cache_dtype, + kv_lora_rank=self.model_config.kv_lora_rank, + qk_rope_head_dim=self.model_config.qk_rope_head_dim, + index_head_dim=( + self.model_config.index_head_dim if is_nsa_model else None + ), + layer_num=self.num_effective_layers, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + start_layer=self.start_layer, + end_layer=self.end_layer, + ) + else: + PoolCls = current_platform.get_mha_kv_pool_cls() + self.token_to_kv_pool = PoolCls( + self.max_total_num_tokens, + page_size=self.page_size, + dtype=self.kv_cache_dtype, + head_num=self.model_config.get_num_kv_heads( + get_attention_tp_size() + ), + head_dim=self.model_config.head_dim, + layer_num=self.num_effective_layers, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + start_layer=self.start_layer, + end_layer=self.end_layer, + ) + elif ( + self.server_args.attention_backend == "ascend" and not self.mambaish_config + ): if self.is_hybrid_swa: from sglang.srt.hardware_backend.npu.memory_pool_npu import ( NPUMHATokenToKVPool, @@ -513,7 +569,17 @@ def _init_pools(self: ModelRunner): # Initialize token_to_kv_pool_allocator need_sort = self.server_args.disaggregation_mode in ("decode", "prefill") if self.token_to_kv_pool_allocator is None: - if _is_npu and ( + if current_platform.is_out_of_tree(): + AllocatorCls = current_platform.get_paged_allocator_cls() + self.token_to_kv_pool_allocator = AllocatorCls( + self.max_total_num_tokens, + page_size=self.page_size, + dtype=self.kv_cache_dtype, + device=self.device, + kvcache=self.token_to_kv_pool, + need_sort=need_sort, + ) + elif _is_npu and ( self.server_args.attention_backend == "ascend" or self.hybrid_gdn_config 
is not None ): diff --git a/python/sglang/srt/platforms/__init__.py b/python/sglang/srt/platforms/__init__.py new file mode 100644 index 000000000000..731f5b6c48d6 --- /dev/null +++ b/python/sglang/srt/platforms/__init__.py @@ -0,0 +1,125 @@ +""" +SGLang Platform Discovery and Lazy Initialization. + +Provides `current_platform` as a module-level lazy singleton. On first access, +it discovers platform plugins via entry_points and instantiates the appropriate +SRTPlatform subclass. + +Usage: + from sglang.srt.platforms import current_platform + print(current_platform.device_name) +""" + +import logging +import pkgutil +from importlib.metadata import entry_points + +from sglang.srt.environ import envs +from sglang.srt.platforms.interface import SRTPlatform +from sglang.srt.plugins import PLATFORM_PLUGINS_GROUP, load_plugins_by_group + +logger = logging.getLogger(__name__) + +_current_platform: SRTPlatform | None = None + + +def _resolve_platform() -> SRTPlatform: + """ + Discover and instantiate the active platform. + + Discovery flow: + 1. Branch on SGLANG_PLATFORM: + + SGLANG_PLATFORM set (front-loading filter): + - Enumerate entry_points without importing any plugin modules + - Only ep.load() + activate() the named plugin + - Other plugins are never imported (avoids pulling their dependencies) + - Plugin name not found → RuntimeError + - activate() returns None → RuntimeError (hardware unavailable) + + SGLANG_PLATFORM unset (auto-discover): + - Import and activate all discovered plugins + - 0 activated → fallback base SRTPlatform + - 1 activated → use it + - N activated → RuntimeError (must set SGLANG_PLATFORM) + + SGLANG_PLATFORM matches against entry_point names. + """ + selected = envs.SGLANG_PLATFORM.get() + + if selected: + # Front-loading filter: only import and activate the specified plugin. + # Other plugins' modules are never loaded — avoids pulling their deps. 
+ discovered = entry_points(group=PLATFORM_PLUGINS_GROUP) + ep_map = {ep.name: ep for ep in discovered} + + if selected not in ep_map: + available = ", ".join(f"'{n}'" for n in ep_map) if ep_map else "none" + raise RuntimeError( + f"SGLANG_PLATFORM={selected!r} not found in discovered platform plugins " + f"(available: {available}). Install the plugin with 'pip install -e' " + f"to register its entry_points." + ) + + try: + plugin_fn = ep_map[selected].load() + result = plugin_fn() + except Exception: + logger.exception("Failed to activate platform plugin: %s", selected) + raise + + if result is None: + raise RuntimeError( + f"Platform plugin {selected!r} is installed but activate() returned None " + f"(hardware not available on this machine?)." + ) + logger.info("OOT platform plugin activated: %s -> %s", selected, result) + return _load_platform_class(result)() + + # Auto-discover: import and activate all plugins, expect exactly one + all_plugins = load_plugins_by_group(PLATFORM_PLUGINS_GROUP) + + activated: dict[str, str] = {} + for name, (plugin_fn, _dist) in all_plugins.items(): + try: + result = plugin_fn() + if result is not None: + activated[name] = result + logger.info("OOT platform plugin activated: %s -> %s", name, result) + except Exception: + logger.exception("Failed to activate platform plugin: %s", name) + + if len(activated) == 0: + logger.warning("No platform detected. Using base SRTPlatform with defaults.") + return SRTPlatform() + + if len(activated) == 1: + name, qualname = next(iter(activated.items())) + return _load_platform_class(qualname)() + + # Multiple activated without SGLANG_PLATFORM + names_str = ", ".join(f"'{n}'" for n in activated) + raise RuntimeError( + f"Multiple platform plugins activated: {names_str}. " + f"Set SGLANG_PLATFORM to select one." 
+ ) + + +def _load_platform_class(qualname: str) -> type: + """Load an SRTPlatform subclass from its fully-qualified class name.""" + cls = pkgutil.resolve_name(qualname) + if not isinstance(cls, type) or not issubclass(cls, SRTPlatform): + raise TypeError( + f"Expected an SRTPlatform subclass, got {type(cls)}: {qualname}" + ) + return cls + + +def __getattr__(name: str): + """Lazy initialization of current_platform on first access.""" + if name == "current_platform": + global _current_platform + if _current_platform is None: + _current_platform = _resolve_platform() + return _current_platform + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/python/sglang/srt/platforms/device_mixin.py b/python/sglang/srt/platforms/device_mixin.py new file mode 100644 index 000000000000..49bd2f0c513c --- /dev/null +++ b/python/sglang/srt/platforms/device_mixin.py @@ -0,0 +1,244 @@ +""" +Shared device abstraction for SGLang platforms. + +DeviceMixin provides the common device identity queries and operations +shared between the SRT (LLM inference) and Multimodal (diffusion) +platform hierarchies. Concrete per-device mixins (e.g. MyDeviceMixin) +implement the abstract operations; subsystem-specific platforms +(SRTPlatform, MMPlatform) inherit DeviceMixin and add their own methods. + +Hierarchy example (OOT plugin):: + + DeviceMixin + ├── MyDeviceMixin(DeviceMixin) # vendor-specific device operations + ├── SRTPlatform(DeviceMixin) # + graph runner, KV pool, … + │ └── MySRTPlatform(SRTPlatform, MyDeviceMixin) + └── MMPlatform(DeviceMixin) # + attention backend, VAE, … + └── MyMMPlatform(MMPlatform, MyDeviceMixin) + +Method status annotations: + +- ``[Active]`` — SGLang core calls this method through ``current_platform``. + OOT implementations take effect immediately. +- ``[Planned]`` — Reserved interface. SGLang core still uses hardcoded calls + (e.g. ``torch.cuda.empty_cache()``). 
OOT implementations will NOT take + effect until the core is migrated in a future PR. +""" + +import enum +from typing import TYPE_CHECKING, NamedTuple, Optional + +if TYPE_CHECKING: + import torch + + +class PlatformEnum(enum.Enum): + """Enumeration of known platform types. + + Superset of both SRT and MM enums so that a single PlatformEnum can + be shared across subsystems. + """ + + CUDA = enum.auto() + ROCM = enum.auto() + CPU = enum.auto() + XPU = enum.auto() + MUSA = enum.auto() + NPU = enum.auto() + TPU = enum.auto() + MPS = enum.auto() + OOT = enum.auto() # Out-of-tree (external plugin) + UNSPECIFIED = enum.auto() + + +class CpuArchEnum(enum.Enum): + """CPU architecture enumeration.""" + + X86 = enum.auto() + ARM = enum.auto() + UNSPECIFIED = enum.auto() + + +class DeviceCapability(NamedTuple): + """Device compute capability (major, minor). + + Uses NamedTuple for built-in comparison support: + ``DeviceCapability(9, 0) >= DeviceCapability(8, 9)`` works naturally. + """ + + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """Express capability as ``<major><minor>`` (minor is single digit).""" + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class DeviceMixin: + """Mixin providing device identity queries and basic device operations. + + Class-level attributes (override in subclasses): + _enum: PlatformEnum identifying this platform. + device_name: Human-readable short name (e.g. "cuda", "npu"). + device_type: ``torch.device`` type string (e.g. "cuda", "npu").
+ """ + + _enum: PlatformEnum = PlatformEnum.UNSPECIFIED + device_name: str = "unknown" + device_type: str = "cpu" + + # ------------------------------------------------------------------ + # Platform identity queries + # ------------------------------------------------------------------ + + def is_cuda(self) -> bool: + return self._enum == PlatformEnum.CUDA + + def is_rocm(self) -> bool: + return self._enum == PlatformEnum.ROCM + + def is_cpu(self) -> bool: + return self._enum == PlatformEnum.CPU + + def is_xpu(self) -> bool: + return self._enum == PlatformEnum.XPU + + def is_musa(self) -> bool: + return self._enum == PlatformEnum.MUSA + + def is_npu(self) -> bool: + return self._enum == PlatformEnum.NPU + + def is_tpu(self) -> bool: + return self._enum == PlatformEnum.TPU + + def is_mps(self) -> bool: + return self._enum == PlatformEnum.MPS + + def is_cuda_alike(self) -> bool: + """True for CUDA, ROCm, or MUSA (all expose CUDA-like APIs).""" + return self._enum in ( + PlatformEnum.CUDA, + PlatformEnum.ROCM, + PlatformEnum.MUSA, + ) + + def is_out_of_tree(self) -> bool: + """True for externally-registered OOT platforms.""" + return self._enum == PlatformEnum.OOT + + # ------------------------------------------------------------------ + # Active methods — core calls these through current_platform. + # OOT implementations take effect immediately. + # ------------------------------------------------------------------ + + def get_device_total_memory(self, device_id: int = 0) -> int: + """[Active] Get total device memory in bytes.""" + raise NotImplementedError + + def get_current_memory_usage( + self, device: Optional["torch.device"] = None + ) -> float: + """[Active] Get current peak memory usage in bytes.""" + raise NotImplementedError + + # ------------------------------------------------------------------ + # Planned methods — reserved interface. Core still uses hardcoded + # calls (e.g. torch.cuda.*). 
OOT implementations will NOT take + # effect until the core is migrated in a future PR. + # ------------------------------------------------------------------ + + # ---- Device management ---- + + def get_device(self, local_rank: int) -> "torch.device": + """[Planned] Return ``torch.device`` for the given local rank.""" + raise NotImplementedError + + def set_device(self, device: "torch.device") -> None: + """[Planned] Set the current device.""" + raise NotImplementedError + + def get_device_name(self, device_id: int = 0) -> str: + """[Planned] Get human-readable device name.""" + raise NotImplementedError + + def get_device_uuid(self, device_id: int = 0) -> str: + """[Planned] Get unique device identifier string.""" + raise NotImplementedError + + def get_device_capability(self, device_id: int = 0) -> Optional["DeviceCapability"]: + """[Planned] Get device compute capability. None if N/A.""" + raise NotImplementedError + + def empty_cache(self) -> None: + """[Planned] Release cached device memory. No-op for CPU-like platforms.""" + pass + + def synchronize(self) -> None: + """[Planned] Synchronize device operations. No-op for CPU-like platforms.""" + pass + + # ---- Memory ---- + + def get_available_memory(self, device_id: int = 0) -> tuple[int, int]: + """[Planned] Return ``(free_bytes, total_bytes)``.""" + raise NotImplementedError + + # ---- Distributed ---- + + def get_torch_distributed_backend_str(self) -> str: + """[Planned] Return the torch.distributed backend string (e.g. 
"nccl", "hccl").""" + raise NotImplementedError + + def get_communicator_class(self) -> type | None: + """[Planned] Return platform-specific communicator class, or None for default.""" + return None + + # ---- Misc ---- + + @classmethod + def inference_mode(cls): + """[Planned] Return inference mode context manager.""" + import torch + + return torch.inference_mode(mode=True) + + @classmethod + def seed_everything(cls, seed: int | None = None) -> None: + """[Planned] Set random seeds for reproducibility across all libraries.""" + if seed is not None: + import random + + import numpy as np + import torch + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + def verify_quantization(self, quant: str) -> None: + """[Planned] Validate that a quantization method is supported. No-op by default.""" + pass + + @classmethod + def get_cpu_architecture(cls) -> "CpuArchEnum": + """[Planned] Detect CPU architecture.""" + import platform as _platform + + machine = _platform.machine().lower() + if machine in ("x86_64", "amd64", "i386", "i686"): + return CpuArchEnum.X86 + elif machine in ("arm64", "aarch64"): + return CpuArchEnum.ARM + return CpuArchEnum.UNSPECIFIED + + # ------------------------------------------------------------------ + # Dunder helpers + # ------------------------------------------------------------------ + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(device={self.device_name})" diff --git a/python/sglang/srt/platforms/interface.py b/python/sglang/srt/platforms/interface.py new file mode 100644 index 000000000000..b7d4bfdc7b67 --- /dev/null +++ b/python/sglang/srt/platforms/interface.py @@ -0,0 +1,133 @@ +""" +SGLang SRT Hardware Platform Abstraction. + +Defines SRTPlatform — the base class for SRT (LLM inference) platform +backends. SRTPlatform inherits DeviceMixin for shared device operations +and adds SRT-specific subsystem factory methods, capability flags, and +configuration lifecycle hooks. 
+ +Out-of-tree platforms register via setuptools entry_points under the +"sglang.srt.platforms" group and should subclass SRTPlatform. +""" + +from typing import TYPE_CHECKING + +from sglang.srt.platforms.device_mixin import DeviceMixin, PlatformEnum + +if TYPE_CHECKING: + pass + +# Re-export for convenience +__all__ = ["SRTPlatform", "PlatformEnum"] + + +class SRTPlatform(DeviceMixin): + """ + Base class for SRT hardware platform backends. + + Inherits device identity queries and operations from DeviceMixin. + Adds SRT-specific factory methods, capability flags, and lifecycle hooks. + + OOT platforms should subclass SRTPlatform and override the methods + relevant to their hardware. + """ + + # SRT-specific class-level attribute + supported_quantization: list[str] = [] + + # ------------------------------------------------------------------ + # Configuration lifecycle + # ------------------------------------------------------------------ + + def apply_server_args_defaults(self, server_args) -> None: + """Apply platform-specific default values to server arguments. + + Called after ServerArgs is parsed.
+ """ + pass + + # ------------------------------------------------------------------ + # Subsystem factory methods + # ------------------------------------------------------------------ + + def get_default_attention_backend(self) -> str: + """Return the default attention backend name for this platform.""" + raise NotImplementedError + + def get_graph_runner_cls(self) -> type: + """Return the graph runner class for this platform.""" + raise NotImplementedError + + def get_mha_kv_pool_cls(self) -> type: + """Return the MHA KV pool class for this platform.""" + raise NotImplementedError + + def get_mla_kv_pool_cls(self) -> type: + """Return the MLA KV pool class for this platform.""" + raise NotImplementedError + + def get_nsa_kv_pool_cls(self) -> type: + """Return the NSA KV pool class for this platform (DeepSeek V3.2).""" + raise NotImplementedError + + def get_paged_allocator_cls(self) -> type: + """Return the paged allocator class for this platform.""" + raise NotImplementedError + + def get_compile_backend(self, mode: str | None = None) -> str: + """Return the compilation backend identifier. + + ``mode`` is an optional hint for the platform (e.g. "npugraph_ex"). + """ + return "inductor" + + def get_piecewise_backend_cls(self) -> type: + """Return the piecewise compilation backend class for this platform.""" + raise NotImplementedError + + # ------------------------------------------------------------------ + # Capability flags (safe conservative defaults) + # ------------------------------------------------------------------ + + def supports_fp8(self) -> bool: + """Whether this platform supports FP8 quantization.""" + return False + + def is_pin_memory_available(self) -> bool: + """Whether pinned memory is available on this platform.""" + return True + + def support_cuda_graph(self) -> bool: + """Whether this platform supports device graph capture and replay. + Controls CUDA graph (CudaGraphRunner) for the decode path. 
+ OOT platforms that support graph-style capture should return True. + """ + return False + + def support_piecewise_cuda_graph(self) -> bool: + """Whether this platform supports piecewise CUDA graph. + + Controls PiecewiseCudaGraphRunner for the prefill/extend path + (torch.compile backend). + """ + return False + + # ------------------------------------------------------------------ + # Initialization + # ------------------------------------------------------------------ + + def init_backend(self) -> None: + """One-time backend initialization. Called in each worker.""" + pass + + # ------------------------------------------------------------------ + # MultiPlatformOp integration + # ------------------------------------------------------------------ + + def get_dispatch_key_name(self) -> str: + """Return the dispatch key name for MultiPlatformOp. + + Determines which ``forward_<key>()`` method is selected. + E.g. "cuda", "npu", "hip", "xpu", "cpu". + """ + return "native" diff --git a/python/sglang/srt/plugins/__init__.py b/python/sglang/srt/plugins/__init__.py new file mode 100644 index 000000000000..00ae1acd1826 --- /dev/null +++ b/python/sglang/srt/plugins/__init__.py @@ -0,0 +1,141 @@ +""" +SGLang Unified Plugin Framework. + +Supports two types of plugins via setuptools entry_points: +1. Hardware Platform Plugins (sglang.srt.platforms) - register custom hardware platforms +2. General Plugins (sglang.srt.plugins) - inject hooks into functions/methods, replace classes, etc. + +Plugins are discovered automatically when installed via pip. +- Platform plugins: use ``SGLANG_PLATFORM`` to select when multiple are installed. +- General plugins: use ``SGLANG_PLUGINS`` (comma-separated) to restrict which are loaded.
+""" + +import logging +from collections.abc import Callable +from importlib.metadata import entry_points +from typing import Any + +from sglang.srt.environ import envs +from sglang.srt.plugins.hook_registry import ( + HookRegistry, + HookSource, + _current_plugin_source, +) + +logger = logging.getLogger(__name__) + +# Entry point group names +PLATFORM_PLUGINS_GROUP = "sglang.srt.platforms" +GENERAL_PLUGINS_GROUP = "sglang.srt.plugins" + +# Guard against multiple loads in the same process +_plugins_loaded = False + + +def load_plugins_by_group( + group: str, + excluded_dists: set[str] | None = None, +) -> dict[str, tuple[Callable[[], Any], str | None]]: + """ + Discover and load plugins registered under the given entry point group. + + Args: + group: The setuptools entry_point group name. + excluded_dists: Distribution names to skip. Plugins from these + distributions are never ``ep.load()``-ed (avoids importing + their modules and pulling hardware-specific dependencies). + + Returns: + Dictionary mapping plugin name to ``(callable, dist_name)``. 
+ """ + # SGLANG_PLUGINS whitelist (comma-separated plugin names) + allowed_set: set[str] | None = None + allowed_str = envs.SGLANG_PLUGINS.get() + if allowed_str: + allowed_set = {x.strip() for x in allowed_str.split(",") if x.strip()} + + discovered = entry_points(group=group) + if len(discovered) == 0: + logger.debug("No plugins found for group %s.", group) + return {} + + logger.info("Available plugins for group %s:", group) + for ep in discovered: + logger.info(" - %s -> %s", ep.name, ep.value) + + plugins: dict[str, tuple[Callable[[], Any], str | None]] = {} + for ep in discovered: + if allowed_set is not None and ep.name not in allowed_set: + logger.info("Skipping plugin %s (not in SGLANG_PLUGINS)", ep.name) + continue + dist_name = ep.dist.name if ep.dist else None + if excluded_dists and dist_name in excluded_dists: + logger.info( + "Skipping plugin %s (dist %s excluded by SGLANG_PLATFORM)", + ep.name, + dist_name, + ) + continue + try: + func = ep.load() + plugins[ep.name] = (func, dist_name) + logger.info("Loaded plugin %s from group %s", ep.name, group) + except Exception: + logger.exception("Failed to load plugin %s from group %s", ep.name, group) + + return plugins + + +def _get_excluded_dists() -> set[str]: + """Compute dist names to skip when ``SGLANG_PLATFORM`` is set. + + Returns dist names that provide a platform plugin but are NOT the one + selected by ``SGLANG_PLATFORM``. This prevents unselected platform + packages from registering hooks that pull their hardware dependencies. + """ + selected = envs.SGLANG_PLATFORM.get() + if not selected: + return set() + platform_eps = entry_points(group=PLATFORM_PLUGINS_GROUP) + return {ep.dist.name for ep in platform_eps if ep.dist and ep.name != selected} + + +def load_plugins(): + """ + Load and execute all general plugins, then apply registered hooks. + + Idempotent - safe to call multiple times. General plugins are functions + whose side effects (registering hooks, replacing classes, etc.) 
are the + desired behavior. Return values are ignored. + + When ``SGLANG_PLATFORM`` is set, general plugins from unselected platform + packages are automatically skipped (avoids pulling their dependencies). + + After all plugins execute, ``HookRegistry.apply_hooks()`` is called + automatically so callers only need this single function call. + + This should be called early in every process (main, engine core, workers). + """ + global _plugins_loaded + if _plugins_loaded: + return + _plugins_loaded = True + + plugins = load_plugins_by_group( + GENERAL_PLUGINS_GROUP, + excluded_dists=_get_excluded_dists(), + ) + + for name, (func, dist_name) in plugins.items(): + source = HookSource(plugin_name=name, dist_name=dist_name) + token = _current_plugin_source.set(source) + try: + func() + logger.info("Executed general plugin: %s", name) + except Exception: + logger.exception("Failed to execute general plugin: %s", name) + finally: + _current_plugin_source.reset(token) + + # Apply all registered hooks (idempotent — already-patched targets are skipped). + HookRegistry.apply_hooks() diff --git a/python/sglang/srt/plugins/hook_registry.py b/python/sglang/srt/plugins/hook_registry.py new file mode 100644 index 000000000000..c577b5232c06 --- /dev/null +++ b/python/sglang/srt/plugins/hook_registry.py @@ -0,0 +1,430 @@ +""" +Hook registry for SGLang plugins. + +Provides before/after/around/replace hooks that can be applied to any +function, method, or class in the sglang codebase. Hooks are registered +during plugin loading and applied before the engine starts. 
+ +Usage: + from sglang.srt.plugins.hook_registry import HookRegistry, HookType + + def my_timer(original_fn, *args, **kwargs): + start = time.perf_counter() + result = original_fn(*args, **kwargs) + print(f"Elapsed: {time.perf_counter() - start:.3f}s") + return result + + HookRegistry.register( + "sglang.srt.managers.scheduler.Scheduler.schedule", + my_timer, + HookType.AROUND, + ) +""" + +import contextvars +import functools +import logging +import pkgutil +import sys +import types +from collections import defaultdict +from collections.abc import Callable +from enum import Enum +from typing import NamedTuple + +logger = logging.getLogger(__name__) + + +class HookSource(NamedTuple): + """Identifies which plugin registered a hook.""" + + plugin_name: str # entry_point name, e.g. "xpu_hooks" + dist_name: str | None # distribution name, e.g. "sglang_xpu_platform" + + +# Set by load_plugins() around each plugin's func() call, read by register(). +_current_plugin_source: contextvars.ContextVar[HookSource | None] = ( + contextvars.ContextVar("_current_plugin_source", default=None) +) + + +def _format_source(source: HookSource | None) -> str: + """Format source info for log messages.""" + if source is None: + return "unknown" + if source.dist_name: + return f"plugin={source.plugin_name}, dist={source.dist_name}" + return f"plugin={source.plugin_name}" + + +class HookType(Enum): + """Types of hooks that can be applied to functions or classes.""" + + BEFORE = "before" # Execute before original; can modify args + AFTER = "after" # Execute after original; can modify return value + AROUND = "around" # Wrap original; full control over execution + REPLACE = "replace" # Replace the original function or class entirely + + +class HookRegistry: + """ + Global registry for function/method/class hooks. + + Thread safety: All registration should happen during load_plugins() + phase (single-threaded). apply_hooks() should be called once before the + engine starts serving requests. 
+ """ + + _hooks: dict[str, list[tuple[HookType, Callable, HookSource | None]]] = defaultdict( + list + ) + _patched: set[str] = set() + + @classmethod + def register( + cls, + target: str, + hook: Callable, + hook_type: HookType = HookType.AFTER, + *, + source: HookSource | None = None, + ): + """ + Register a hook on a target function, method, or class. + + Args: + target: Fully-qualified dotted path to the target. + e.g. "sglang.srt.managers.scheduler.Scheduler.schedule" + or "sglang.srt.managers.scheduler.Scheduler" (class) + hook: The hook callable (function or class). Signature depends on hook_type: + - BEFORE: fn(*args, **kwargs) -> (args, kwargs) or None + - AFTER: fn(result, *args, **kwargs) -> new_result or None + - AROUND: fn(original_fn, *args, **kwargs) -> result + - REPLACE: fn(*args, **kwargs) -> result (function replacement) + MyClass (class replacement) + hook_type: Type of hook (default: AFTER). + source: Optional source info. If None, auto-read from context var + set by ``load_plugins()``. + + Raises: + TypeError: If a class is passed with a hook_type other than REPLACE. + """ + if isinstance(hook, type) and hook_type != HookType.REPLACE: + raise TypeError( + f"Class {hook.__name__} can only be used with HookType.REPLACE, " + f"got HookType.{hook_type.name}. " + f"Use a function for BEFORE/AFTER/AROUND hooks." + ) + resolved_source = source or _current_plugin_source.get() + # Warn on duplicate REPLACE for the same target + if hook_type == HookType.REPLACE: + existing_replace = [ + (h, src) for ht, h, src in cls._hooks[target] if ht == HookType.REPLACE + ] + if existing_replace: + prev, prev_src = existing_replace[-1] + prev_name = getattr(prev, "__qualname__", None) or repr(prev) + new_name = getattr(hook, "__qualname__", None) or repr(hook) + logger.warning( + "Multiple REPLACE hooks on '%s': previous (%s [%s]) will be " + "overridden by (%s [%s]). 
The last registered REPLACE takes effect.", + target, + prev_name, + _format_source(prev_src), + new_name, + _format_source(resolved_source), + ) + cls._hooks[target].append((hook_type, hook, resolved_source)) + logger.debug( + "Registered %s hook on %s [%s]", + hook_type.value, + target, + _format_source(resolved_source), + ) + + @classmethod + def apply_hooks(cls): + """ + Apply all registered hooks to their target functions/classes. + + This performs the actual monkey-patching. Should be called once after + all plugins have been loaded and before the engine starts. + + Targets with class REPLACE hooks are applied first, so that + subsequent method-level hooks (AROUND, BEFORE, AFTER) on child + attributes resolve against the *replaced* class rather than the + original. + """ + sorted_items = sorted(cls._hooks.items(), key=cls._target_sort_key) + for target, hooks in sorted_items: + if target in cls._patched: + continue + try: + cls._apply_target(target, hooks) + cls._patched.add(target) + except Exception: + logger.exception("Failed to apply hooks to %s", target) + + @staticmethod + def _target_sort_key(item): + """Sort key: class REPLACE targets (tier 0) before all others (tier 1). + + This ensures that when a class is replaced, subsequent method-level + hooks on ``ClassName.method`` resolve against the replacement class. 
+ """ + _target, hooks = item + has_class_replace = any( + isinstance(h, type) and ht == HookType.REPLACE for ht, h, _ in hooks + ) + return (0 if has_class_replace else 1, _target) + + @classmethod + def _apply_target(cls, target: str, hooks: list): + """Resolve target, build wrapper chain, and replace the original.""" + parts = target.rsplit(".", 1) + if len(parts) != 2: + raise ValueError( + f"Invalid target path (need at least module.attr): {target}" + ) + + obj_path, attr_name = parts + obj = pkgutil.resolve_name(obj_path) + + # Check if the original is a classmethod or staticmethod by + # inspecting __dict__ before getattr() triggers the descriptor + # protocol (which would lose the wrapper type for classmethod). + original = getattr(obj, attr_name) + is_classmethod = False + is_staticmethod = False + if isinstance(obj, type): + raw_attr = obj.__dict__.get(attr_name) + if isinstance(raw_attr, classmethod): + is_classmethod = True + original = raw_attr.__func__ + elif isinstance(raw_attr, staticmethod): + is_staticmethod = True + original = raw_attr.__func__ + + # Cross-target conflict detection: if the parent object is a class + # that was already class-REPLACE'd, and the replacement class defines + # its own version of this method, a method REPLACE here will silently + # override the replacement class's implementation. + if isinstance(obj, type) and obj_path in cls._patched: + has_method_replace = any(ht == HookType.REPLACE for ht, _, _ in hooks) + if has_method_replace and attr_name in obj.__dict__: + replace_sources = [ + _format_source(src) + for ht, _, src in hooks + if ht == HookType.REPLACE + ] + logger.warning( + "Method REPLACE on '%s' will override the class REPLACE's " + "own implementation of '%s'. If this is unintended, remove " + "the method REPLACE and modify the replacement class " + "directly, or use AROUND to wrap it. 
(from: %s)", + target, + attr_name, + ", ".join(replace_sources), + ) + + # Guard: if the target is a class, only REPLACE is safe. Wrapping a + # class in a function would break isinstance/issubclass/inheritance. + if isinstance(original, type): + bad = [ht for ht, _, _ in hooks if ht != HookType.REPLACE] + if bad: + raise TypeError( + f"Target '{target}' is a class. Only HookType.REPLACE is " + f"allowed for class targets (got {bad[0].value}). " + f"To hook a method, use '{target}.<method>' instead." + ) + + # Warn about risky hook combinations + hook_types = [ht for ht, _, _ in hooks] + around_count = hook_types.count(HookType.AROUND) + has_replace = HookType.REPLACE in hook_types + has_others = any(ht != HookType.REPLACE for ht in hook_types) + + if around_count > 1: + around_sources = [ + _format_source(src) for ht, _, src in hooks if ht == HookType.AROUND + ] + logger.warning( + "Multiple AROUND hooks on '%s' (%d hooks, from: %s). If any AROUND hook " + "skips calling original_fn, inner hooks will be bypassed.", + target, + around_count, + ", ".join(around_sources), + ) + if has_replace and has_others: + logger.info( + "Target '%s' has both REPLACE and %s hooks. " + "REPLACE will be applied first, then wrapped by other hooks.", + target, + ", ".join( + sorted({ht.value for ht in hook_types if ht != HookType.REPLACE}) + ), + ) + + # Build the wrapper chain. + # Sort: REPLACE hooks first (stable sort preserves registration order + # within the same type). This ensures AROUND/BEFORE/AFTER always wrap + # the replaced function, regardless of registration order. + sorted_hooks = sorted( + hooks, key=lambda h: (0 if h[0] == HookType.REPLACE else 1) + ) + wrapped = original + for hook_type, hook, _src in sorted_hooks: + if isinstance(hook, type) and hook_type == HookType.REPLACE: + # Class replacement: direct substitution to preserve type identity. + # This keeps isinstance(), issubclass(), and inheritance working.
+ wrapped = hook + else: + wrapped = _wrap_fn(wrapped, hook, hook_type) + + # Restore classmethod/staticmethod decorator if the original had one. + if is_classmethod: + wrapped = classmethod(wrapped) + logger.debug("Preserved @classmethod decorator for %s", target) + elif is_staticmethod: + wrapped = staticmethod(wrapped) + logger.debug("Preserved @staticmethod decorator for %s", target) + + setattr(obj, attr_name, wrapped) + + # Propagate the patch to all other modules that imported the original + # via ``from source_module import name``. Python's ``from X import Y`` + # copies the reference at import time; patching X alone leaves + # importers with a stale binding. + if wrapped is not original: + extra = _propagate_patch(original, wrapped, obj) + if extra: + logger.debug( + "Propagated patch for %s to %d additional module(s)", + target, + extra, + ) + + sources = sorted({_format_source(src) for _, _, src in hooks}) + logger.info( + "Applied %d hook(s) to %s (from: %s)", + len(hooks), + target, + ", ".join(sources), + ) + + @classmethod + def reset(cls): + """Reset all hooks and patches. Primarily for testing.""" + cls._hooks.clear() + cls._patched.clear() + + +def _propagate_patch(original: object, wrapped: object, source_module: object) -> int: + """Propagate a monkey-patch to all modules holding a stale ``from X import Y`` binding. + + After ``setattr(source_module, name, wrapped)`` updates the defining module, + other modules that did ``from source_module import name`` still hold a direct + reference to the old *original* object. This walks ``sys.modules`` and + replaces every such stale binding with *wrapped*. + + Returns the number of additional module attributes that were patched. 
+ """ + patched_count = 0 + for mod in list(sys.modules.values()): + if mod is source_module or mod is None: + continue + if not isinstance(mod, types.ModuleType): + continue + try: + mod_vars = vars(mod) + except TypeError: + continue + for attr_name, attr_value in list(mod_vars.items()): + if attr_value is original: + try: + setattr(mod, attr_name, wrapped) + patched_count += 1 + except (AttributeError, TypeError): + pass + return patched_count + + +def _wrap_fn(original_fn: Callable, hook: Callable, hook_type: HookType) -> Callable: + """Create a wrapper function based on the hook type.""" + if hook_type == HookType.REPLACE: + + @functools.wraps(original_fn) + def wrapper(*args, **kwargs): + return hook(*args, **kwargs) + + wrapper.__wrapped__ = original_fn + return wrapper + + elif hook_type == HookType.BEFORE: + + @functools.wraps(original_fn) + def wrapper(*args, **kwargs): + result = hook(*args, **kwargs) + if result is not None: + args, kwargs = result + return original_fn(*args, **kwargs) + + wrapper.__wrapped__ = original_fn + return wrapper + + elif hook_type == HookType.AFTER: + + @functools.wraps(original_fn) + def wrapper(*args, **kwargs): + result = original_fn(*args, **kwargs) + modified = hook(result, *args, **kwargs) + return modified if modified is not None else result + + wrapper.__wrapped__ = original_fn + return wrapper + + elif hook_type == HookType.AROUND: + + @functools.wraps(original_fn) + def wrapper(*args, **kwargs): + return hook(original_fn, *args, **kwargs) + + wrapper.__wrapped__ = original_fn + return wrapper + + else: + raise ValueError(f"Unknown hook type: {hook_type}") + + +def plugin_hook( + target: str, + type: HookType = HookType.AFTER, +) -> Callable: + """Decorator that registers a function or class as a hook on *target*. 
+ + Usage:: + + # Function hook (AROUND) + @plugin_hook("sglang.srt.managers.scheduler.Scheduler.schedule", + type=HookType.AROUND) + def my_timer(original_fn, *args, **kwargs): + start = time.perf_counter() + result = original_fn(*args, **kwargs) + print(f"Elapsed: {time.perf_counter() - start:.3f}s") + return result + + # Class replacement (REPLACE) + @plugin_hook("sglang.srt.managers.scheduler.Scheduler", + type=HookType.REPLACE) + class MyScheduler(Scheduler): + ... + + The decorated function/class is returned unchanged so it can still be + used directly if needed. + """ + + def decorator(hook: Callable) -> Callable: + HookRegistry.register(target, hook, type) + return hook + + return decorator diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 654d4765952d..ce444c0ed2e5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -790,6 +790,11 @@ def __post_init__(self): self._handle_mps_backends() self._handle_xpu_backends() + # Allow OOT platform plugins to apply server args defaults. + from sglang.srt.platforms import current_platform + + current_platform.apply_server_args_defaults(self) + # Handle piecewise CUDA graph. self._handle_piecewise_cuda_graph() @@ -1157,6 +1162,12 @@ def _handle_piecewise_cuda_graph(self): # 5. Non-CUDA hardware (AMD, NPU, CPU, MPS, XPU, etc.) if is_hip() or is_npu() or is_cpu() or is_mps() or is_xpu(): self.disable_piecewise_cuda_graph = True + # 5b. OOT platforms that don't support piecewise cuda graph + from sglang.srt.platforms import current_platform + + if current_platform.is_out_of_tree(): + if not current_platform.support_piecewise_cuda_graph(): + self.disable_piecewise_cuda_graph = True # 6. MoE A2A backend if self.moe_a2a_backend != "none": self.disable_piecewise_cuda_graph = True @@ -2324,6 +2335,12 @@ def _get_default_attn_backend(self, use_mla_backend: bool, model_config): 2.2 We will use Flashinfer backend on blackwell. 
2.3 Otherwise, we will use triton backend. """ + # OOT platforms provide their own default attention backend. + from sglang.srt.platforms import current_platform + + if current_platform.is_out_of_tree(): + return current_platform.get_default_attention_backend() + + # Whisper requires flashinfer for cross-attention CUDA graph support + if "WhisperForConditionalGeneration" in ( + model_config.hf_config.architectures or [] + ) diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index a6be5c606c4f..a91d42ae58a6 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -590,6 +590,18 @@ def get_available_gpu_memory( free_gpu_memory, total_gpu_memory = torch.musa.mem_get_info() elif device == "mps": free_gpu_memory = psutil.virtual_memory().available + else: + from sglang.srt.platforms import current_platform + + if not current_platform.is_out_of_tree(): + raise ValueError( + f"Unsupported device type: {device!r}. " + "If this is an OOT platform, ensure it is properly registered " + "via the 'sglang.srt.platforms' entry point group." + ) + total_mem = current_platform.get_device_total_memory(gpu_id) + used_mem = current_platform.get_current_memory_usage() + free_gpu_memory = total_mem - used_mem if distributed: tensor = torch.tensor(free_gpu_memory, dtype=torch.float32) @@ -1671,6 +1683,14 @@ def get_mtgpu_memory_capacity(): def get_device_memory_capacity(device: str = None): + # OOT platforms provide their own memory query via the platform class. + from sglang.srt.platforms import current_platform + + if current_platform.is_out_of_tree(): + mem_bytes = current_platform.get_device_total_memory() + if mem_bytes: + return mem_bytes / (1 << 20) # bytes -> MiB + return None if is_cuda(): gpu_mem = get_nvgpu_memory_capacity() elif is_hip(): @@ -1913,6 +1933,12 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]: def get_compiler_backend(mode=None) -> str: + # OOT platforms provide their own compile backend. 
+ from sglang.srt.platforms import current_platform + + if current_platform.is_out_of_tree(): + return current_platform.get_compile_backend(mode) + if hasattr(torch, "hpu") and torch.hpu.is_available(): return "hpu_backend" diff --git a/test/registered/unit/platforms/test_platform_interface.py b/test/registered/unit/platforms/test_platform_interface.py new file mode 100644 index 000000000000..e77a83c6fd7a --- /dev/null +++ b/test/registered/unit/platforms/test_platform_interface.py @@ -0,0 +1,478 @@ +""" +Unit tests for SGLang platform abstraction layer. + +Tests DeviceMixin, SRTPlatform, PlatformEnum, CpuArchEnum, DeviceCapability, +and the platform discovery / lazy initialization mechanism. +""" + +from unittest.mock import MagicMock, patch + +from sglang.srt.platforms import _load_platform_class, _resolve_platform +from sglang.srt.platforms.device_mixin import ( + CpuArchEnum, + DeviceCapability, + DeviceMixin, + PlatformEnum, +) +from sglang.srt.platforms.interface import SRTPlatform +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-a-test-cpu") + + +# --------------------------------------------------------------------------- +# Helpers: factory functions to reduce boilerplate +# --------------------------------------------------------------------------- + + +def _make_device_mixin(enum, name, dtype): + """Create a concrete DeviceMixin subclass for testing.""" + + class M(DeviceMixin): + _enum = enum + device_name = name + device_type = dtype + + def get_device_total_memory(self, device_id=0): + return 10**9 + + def get_current_memory_usage(self, device=None): + return 5 * 10**8 + + return M() + + +class _StubPlatform(SRTPlatform): + """Concrete SRTPlatform with minimal defaults for testing overrides.""" + + _enum = PlatformEnum.CUDA + device_name = "cuda" + device_type = "cuda" + + def get_device_total_memory(self, device_id=0): + return 10**9 + + def 
get_current_memory_usage(self, device=None): + return 5 * 10**8 + + def get_default_attention_backend(self): + return "flashinfer" + + def get_graph_runner_cls(self): + return object + + def get_mha_kv_pool_cls(self): + return object + + def get_mla_kv_pool_cls(self): + return object + + def get_nsa_kv_pool_cls(self): + return object + + def get_paged_allocator_cls(self): + return object + + def get_piecewise_backend_cls(self): + return object + + +def _make_platform_ep(name, load_fn=None): + """Create a mock entry point for platform plugins.""" + ep = MagicMock() + ep.name = name + if load_fn is not None: + ep.load.return_value = load_fn + else: + ep.load.return_value = MagicMock() + return ep + + +# --------------------------------------------------------------------------- +# PlatformEnum & CpuArchEnum +# --------------------------------------------------------------------------- + + +class TestPlatformEnum(CustomTestCase): + """Tests for PlatformEnum enumeration.""" + + def test_all_expected_values_exist(self): + expected = { + "CUDA", + "ROCM", + "CPU", + "XPU", + "MUSA", + "NPU", + "TPU", + "MPS", + "OOT", + "UNSPECIFIED", + } + actual = {member.name for member in PlatformEnum} + self.assertEqual(actual, expected) + + +class TestCpuArchEnum(CustomTestCase): + """Tests for CpuArchEnum enumeration.""" + + def test_all_expected_values_exist(self): + expected = {"X86", "ARM", "UNSPECIFIED"} + actual = {member.name for member in CpuArchEnum} + self.assertEqual(actual, expected) + + +# --------------------------------------------------------------------------- +# DeviceCapability +# --------------------------------------------------------------------------- + + +class TestDeviceCapability(CustomTestCase): + """Tests for DeviceCapability custom logic (formatting, conversion).""" + + def test_as_version_str(self): + self.assertEqual(DeviceCapability(major=9, minor=0).as_version_str(), "9.0") + self.assertEqual(DeviceCapability(major=8, minor=9).as_version_str(), 
"8.9") + + def test_to_int(self): + self.assertEqual(DeviceCapability(major=9, minor=0).to_int(), 90) + self.assertEqual(DeviceCapability(major=8, minor=9).to_int(), 89) + self.assertEqual(DeviceCapability(major=0, minor=0).to_int(), 0) + + +# --------------------------------------------------------------------------- +# DeviceMixin +# --------------------------------------------------------------------------- + +# Platform identity test data: (enum, name, dtype, true_method) +_PLATFORM_IDENTITY = [ + (PlatformEnum.CUDA, "cuda", "cuda", "is_cuda"), + (PlatformEnum.ROCM, "rocm", "hip", "is_rocm"), + (PlatformEnum.CPU, "cpu", "cpu", "is_cpu"), + (PlatformEnum.XPU, "xpu", "xpu", "is_xpu"), + (PlatformEnum.MUSA, "musa", "musa", "is_musa"), + (PlatformEnum.NPU, "npu", "npu", "is_npu"), + (PlatformEnum.TPU, "tpu", "tpu", "is_tpu"), + (PlatformEnum.MPS, "mps", "mps", "is_mps"), +] + +# is_cuda_alike test data: (enum, name, dtype, expected) +_CUDA_ALIKE = [ + (PlatformEnum.CUDA, "cuda", "cuda", True), + (PlatformEnum.ROCM, "rocm", "hip", True), + (PlatformEnum.MUSA, "musa", "musa", True), + (PlatformEnum.CPU, "cpu", "cpu", False), + (PlatformEnum.NPU, "npu", "npu", False), +] + + +class TestDeviceMixin(CustomTestCase): + """Tests for DeviceMixin base class.""" + + def test_platform_identity_methods(self): + """Each platform type returns True for its identity method.""" + for enum_val, name, dtype, method in _PLATFORM_IDENTITY: + with self.subTest(method=method, enum=enum_val.name): + mixin = _make_device_mixin(enum_val, name, dtype) + self.assertTrue(getattr(mixin, method)()) + + def test_is_cuda_alike(self): + """is_cuda_alike is True for CUDA/ROCM/MUSA, False otherwise.""" + for enum_val, name, dtype, expected in _CUDA_ALIKE: + with self.subTest(enum=enum_val.name): + mixin = _make_device_mixin(enum_val, name, dtype) + self.assertEqual(mixin.is_cuda_alike(), expected) + + def test_is_out_of_tree(self): + oot = _make_device_mixin(PlatformEnum.OOT, "custom", "custom") + 
self.assertTrue(oot.is_out_of_tree()) + cuda = _make_device_mixin(PlatformEnum.CUDA, "cuda", "cuda") + self.assertFalse(cuda.is_out_of_tree()) + + @patch("platform.machine") + def test_get_cpu_architecture(self, mock_machine): + """get_cpu_architecture maps common strings to CpuArchEnum.""" + cases = [ + ("x86_64", CpuArchEnum.X86), + ("amd64", CpuArchEnum.X86), + ("i386", CpuArchEnum.X86), + ("i686", CpuArchEnum.X86), + ("X86_64", CpuArchEnum.X86), # case insensitive + ("arm64", CpuArchEnum.ARM), + ("aarch64", CpuArchEnum.ARM), + ("unknown_arch", CpuArchEnum.UNSPECIFIED), + ] + for machine_str, expected in cases: + with self.subTest(machine=machine_str): + mock_machine.return_value = machine_str + self.assertEqual(DeviceMixin.get_cpu_architecture(), expected) + + +# --------------------------------------------------------------------------- +# SRTPlatform +# --------------------------------------------------------------------------- + + +class TestSRTPlatform(CustomTestCase): + """Tests for SRTPlatform base class and default behaviors.""" + + def test_compile_backend_signature_compatibility(self): + """get_compile_backend accepts mode keyword arg without error.""" + base = SRTPlatform() + self.assertEqual(base.get_compile_backend(mode="npugraph_ex"), "inductor") + + +class TestSRTPlatformOverrides(CustomTestCase): + """Tests for SRTPlatform method overrides via plugins.""" + + def test_custom_get_dispatch_key_name(self): + class P(_StubPlatform): + _enum = PlatformEnum.NPU + device_name = "npu" + device_type = "npu" + + def get_dispatch_key_name(self): + return "npu" + + self.assertEqual(P().get_dispatch_key_name(), "npu") + + def test_custom_get_compile_backend(self): + class P(_StubPlatform): + _enum = PlatformEnum.NPU + device_name = "npu" + device_type = "npu" + + def get_compile_backend(self, mode=None): + return "inductor" + + self.assertEqual(P().get_compile_backend(mode="npugraph_ex"), "inductor") + + +# 
--------------------------------------------------------------------------- +# Platform Discovery: _resolve_platform +# --------------------------------------------------------------------------- + + +class TestResolvePlatformWithEnv(CustomTestCase): + """Tests for _resolve_platform when SGLANG_PLATFORM is set.""" + + @patch("sglang.srt.platforms.entry_points") + @patch("sglang.srt.platforms.envs") + def test_selected_plugin_activates(self, mock_envs, mock_ep): + """When SGLANG_PLATFORM matches an entry point, it activates that plugin.""" + mock_envs.SGLANG_PLATFORM.get.return_value = "my_hardware" + plugin_fn = MagicMock(return_value="pkg.Mod:MyPlatform") + mock_ep.return_value = [_make_platform_ep("my_hardware", plugin_fn)] + with patch("sglang.srt.platforms._load_platform_class") as mock_load: + mock_instance = MagicMock() + mock_load.return_value = MagicMock(return_value=mock_instance) + result = _resolve_platform() + mock_load.assert_called_once_with("pkg.Mod:MyPlatform") + self.assertEqual(result, mock_instance) + + @patch("sglang.srt.platforms.entry_points") + @patch("sglang.srt.platforms.envs") + def test_selected_plugin_not_found(self, mock_envs, mock_ep): + """When SGLANG_PLATFORM names a nonexistent plugin, raise RuntimeError.""" + mock_envs.SGLANG_PLATFORM.get.return_value = "nonexistent" + mock_ep.return_value = [] + with self.assertRaises(RuntimeError): + _resolve_platform() + + @patch("sglang.srt.platforms.entry_points") + @patch("sglang.srt.platforms.envs") + def test_selected_plugin_hardware_unavailable(self, mock_envs, mock_ep): + """When activate() returns None, hardware is not available.""" + mock_envs.SGLANG_PLATFORM.get.return_value = "my_hardware" + plugin_fn = MagicMock(return_value=None) + mock_ep.return_value = [_make_platform_ep("my_hardware", plugin_fn)] + with self.assertRaises(RuntimeError): + _resolve_platform() + + @patch("sglang.srt.platforms.entry_points") + @patch("sglang.srt.platforms.envs") + def 
test_selected_plugin_load_exception(self, mock_envs, mock_ep): + """When ep.load() or activate() throws, exception is re-raised.""" + mock_envs.SGLANG_PLATFORM.get.return_value = "my_hardware" + plugin_fn = MagicMock(side_effect=ImportError("missing dep")) + mock_ep.return_value = [_make_platform_ep("my_hardware", plugin_fn)] + with self.assertRaises(ImportError): + _resolve_platform() + + @patch("sglang.srt.platforms.entry_points") + @patch("sglang.srt.platforms.envs") + def test_other_plugins_not_loaded(self, mock_envs, mock_ep): + """When SGLANG_PLATFORM is set, other plugins are not imported.""" + mock_envs.SGLANG_PLATFORM.get.return_value = "target_hw" + target_fn = MagicMock(return_value="pkg.Mod:TargetPlatform") + other_ep = _make_platform_ep("other_hw") # default load returns MagicMock + target_ep = _make_platform_ep("target_hw", target_fn) + mock_ep.return_value = [other_ep, target_ep] + with patch("sglang.srt.platforms._load_platform_class") as mock_load: + mock_load.return_value = MagicMock(return_value=MagicMock()) + _resolve_platform() + # Only the target entry point should be loaded + target_ep.load.assert_called_once() + other_ep.load.assert_not_called() + + +class TestResolvePlatformAutoDiscover(CustomTestCase): + """Tests for _resolve_platform auto-discovery when SGLANG_PLATFORM is not set.""" + + @patch("sglang.srt.platforms.load_plugins_by_group") + @patch("sglang.srt.platforms.envs") + def test_single_plugin_activates(self, mock_envs, mock_load): + """When exactly one plugin activates, return its platform instance.""" + mock_envs.SGLANG_PLATFORM.get.return_value = "" + plugin_fn = MagicMock(return_value="pkg.Mod:MyPlatform") + mock_load.return_value = {"my_hw": (plugin_fn, "my-hw-dist")} + with patch("sglang.srt.platforms._load_platform_class") as mock_resolve: + mock_instance = MagicMock() + mock_resolve.return_value = MagicMock(return_value=mock_instance) + result = _resolve_platform() + 
mock_resolve.assert_called_once_with("pkg.Mod:MyPlatform") + self.assertEqual(result, mock_instance) + + @patch("sglang.srt.platforms.load_plugins_by_group") + @patch("sglang.srt.platforms.envs") + def test_no_plugin_activates_fallback(self, mock_envs, mock_load): + """When no plugin activates, return base SRTPlatform with warning.""" + mock_envs.SGLANG_PLATFORM.get.return_value = "" + mock_load.return_value = {} + result = _resolve_platform() + self.assertIsInstance(result, SRTPlatform) + + @patch("sglang.srt.platforms.load_plugins_by_group") + @patch("sglang.srt.platforms.envs") + def test_multiple_plugins_activate_raises(self, mock_envs, mock_load): + """When multiple plugins activate, raise RuntimeError.""" + mock_envs.SGLANG_PLATFORM.get.return_value = "" + fn1 = MagicMock(return_value="pkg1.Mod:Platform1") + fn2 = MagicMock(return_value="pkg2.Mod:Platform2") + mock_load.return_value = {"hw1": (fn1, "hw1-dist"), "hw2": (fn2, "hw2-dist")} + with self.assertRaises(RuntimeError): + _resolve_platform() + + @patch("sglang.srt.platforms.load_plugins_by_group") + @patch("sglang.srt.platforms.envs") + def test_plugin_exception_does_not_crash(self, mock_envs, mock_load): + """When a plugin's activate() throws, it is skipped, others continue.""" + mock_envs.SGLANG_PLATFORM.get.return_value = "" + bad_fn = MagicMock(side_effect=RuntimeError("broken")) + good_fn = MagicMock(return_value="pkg.Mod:GoodPlatform") + mock_load.return_value = { + "bad": (bad_fn, "bad-dist"), + "good": (good_fn, "good-dist"), + } + with patch("sglang.srt.platforms._load_platform_class") as mock_resolve: + mock_instance = MagicMock() + mock_resolve.return_value = MagicMock(return_value=mock_instance) + result = _resolve_platform() + mock_resolve.assert_called_once_with("pkg.Mod:GoodPlatform") + self.assertEqual(result, mock_instance) + + @patch("sglang.srt.platforms.load_plugins_by_group") + @patch("sglang.srt.platforms.envs") + def test_plugin_returns_none_is_skipped(self, mock_envs, mock_load): 
+ """When a plugin's activate() returns None, it is skipped (hardware unavailable).""" + mock_envs.SGLANG_PLATFORM.get.return_value = "" + none_fn = MagicMock(return_value=None) + good_fn = MagicMock(return_value="pkg.Mod:GoodPlatform") + mock_load.return_value = { + "unavailable": (none_fn, "unavail-dist"), + "good": (good_fn, "good-dist"), + } + with patch("sglang.srt.platforms._load_platform_class") as mock_resolve: + mock_instance = MagicMock() + mock_resolve.return_value = MagicMock(return_value=mock_instance) + result = _resolve_platform() + # Only the good plugin activated; single activation succeeds + mock_resolve.assert_called_once_with("pkg.Mod:GoodPlatform") + + +# --------------------------------------------------------------------------- +# Platform Discovery: _load_platform_class +# --------------------------------------------------------------------------- + + +class TestLoadPlatformClass(CustomTestCase): + """Tests for _load_platform_class qualname resolution.""" + + @patch("sglang.srt.platforms.pkgutil.resolve_name") + def test_valid_subclass(self, mock_resolve): + """Valid SRTPlatform subclass resolves successfully.""" + mock_resolve.return_value = type("MyPlatform", (SRTPlatform,), {}) + result = _load_platform_class("pkg.Mod:MyPlatform") + self.assertTrue(issubclass(result, SRTPlatform)) + + @patch("sglang.srt.platforms.pkgutil.resolve_name") + def test_non_subclass_raises_type_error(self, mock_resolve): + """Non-SRTPlatform class raises TypeError.""" + mock_resolve.return_value = str + with self.assertRaises(TypeError): + _load_platform_class("builtins.str") + + @patch("sglang.srt.platforms.pkgutil.resolve_name") + def test_non_type_raises_type_error(self, mock_resolve): + """Non-type object raises TypeError.""" + mock_resolve.return_value = "not a class" + with self.assertRaises(TypeError): + _load_platform_class("something") + + +# --------------------------------------------------------------------------- +# Platform Discovery: 
current_platform lazy init +# --------------------------------------------------------------------------- + + +class TestCurrentPlatformLazyInit(CustomTestCase): + """Tests for current_platform lazy initialization via module __getattr__.""" + + def setUp(self): + """Reset module-level cache before each test.""" + import sglang.srt.platforms as plat_mod + + self._saved_platform = plat_mod._current_platform + plat_mod._current_platform = None + + def tearDown(self): + """Restore original _current_platform after each test.""" + import sglang.srt.platforms as plat_mod + + plat_mod._current_platform = self._saved_platform + + @patch("sglang.srt.platforms._resolve_platform") + def test_first_access_triggers_resolve(self, mock_resolve): + """First access to current_platform calls _resolve_platform.""" + mock_instance = MagicMock(spec=SRTPlatform) + mock_resolve.return_value = mock_instance + import sglang.srt.platforms as plat_mod + + result = plat_mod.current_platform + mock_resolve.assert_called_once() + self.assertEqual(result, mock_instance) + + @patch("sglang.srt.platforms._resolve_platform") + def test_subsequent_access_uses_cache(self, mock_resolve): + """Subsequent accesses return cached instance without re-resolving.""" + mock_instance = MagicMock(spec=SRTPlatform) + mock_resolve.return_value = mock_instance + import sglang.srt.platforms as plat_mod + + _ = plat_mod.current_platform + _ = plat_mod.current_platform + mock_resolve.assert_called_once() + + def test_other_attribute_raises_error(self): + """Accessing non-existent module attribute raises AttributeError.""" + import sglang.srt.platforms as plat_mod + + with self.assertRaises(AttributeError): + _ = plat_mod.nonexistent_attribute + + +if __name__ == "__main__": + import unittest + + unittest.main() diff --git a/test/registered/unit/plugins/__init__.py b/test/registered/unit/plugins/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git 
a/test/registered/unit/plugins/test_hook_registry.py b/test/registered/unit/plugins/test_hook_registry.py new file mode 100644 index 000000000000..4f066189aa0a --- /dev/null +++ b/test/registered/unit/plugins/test_hook_registry.py @@ -0,0 +1,448 @@ +""" +Unit tests for the hook registry system. + +Covers: basic hooks (AROUND/BEFORE/AFTER/REPLACE), descriptor preservation +(classmethod/staticmethod), hook ordering, cross-target conflict detection, +patch propagation, and edge cases. + +Run: python -m pytest test/registered/unit/plugins/test_hook_registry.py -v +""" + +import sys +import types +import uuid + +from sglang.srt.plugins.hook_registry import HookRegistry, HookType, plugin_hook +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-a-test-cpu") + +# --------------------------------------------------------------------------- +# Helpers: synthetic module creation +# --------------------------------------------------------------------------- + +_SYNTH_MODULE_PREFIX = "_synth_hook_test_" + + +def _make_module(**attrs): + """Create a throwaway module registered in sys.modules.""" + name = f"{_SYNTH_MODULE_PREFIX}{uuid.uuid4().hex[:8]}" + mod = types.ModuleType(name) + for k, v in attrs.items(): + setattr(mod, k, v) + sys.modules[name] = mod + return mod, name + + +def _cleanup_synth_modules(): + """Remove all synthetic modules from sys.modules.""" + to_del = [k for k in sys.modules if k.startswith(_SYNTH_MODULE_PREFIX)] + for k in to_del: + del sys.modules[k] + + +# --------------------------------------------------------------------------- +# Base class for hook tests (shared setUp/tearDown) +# --------------------------------------------------------------------------- + + +class _HookTestCase(CustomTestCase): + """Base class that resets HookRegistry and cleans up synth modules.""" + + def setUp(self): + HookRegistry.reset() + _cleanup_synth_modules() + + def 
tearDown(self): + HookRegistry.reset() + _cleanup_synth_modules() + + +# =========================================================================== +# TestBasicHooks +# =========================================================================== + + +class TestBasicHooks(_HookTestCase): + """AROUND / BEFORE / AFTER / REPLACE on plain functions, class REPLACE, + and the @plugin_hook decorator.""" + + def test_around_function(self): + def orig(x): + return x * 2 + + mod, name = _make_module(orig=orig) + + def add_one(original_fn, x): + return original_fn(x) + 1 + + HookRegistry.register(f"{name}.orig", add_one, HookType.AROUND) + HookRegistry.apply_hooks() + self.assertEqual(mod.orig(3), 7) # 3*2 + 1 + + def test_before_modifies_args(self): + """BEFORE hook returns (args, kwargs) to modify arguments.""" + + def orig(x, y=0): + return x + y + + mod, name = _make_module(orig=orig) + + def double_x(x, y=0): + return (x * 2,), {"y": y + 1} + + HookRegistry.register(f"{name}.orig", double_x, HookType.BEFORE) + HookRegistry.apply_hooks() + self.assertEqual(mod.orig(3), 7) # x=3*2=6, y=0+1=1, 6+1=7 + + def test_before_returning_none(self): + """BEFORE hook returning None leaves arguments unchanged.""" + + def orig(x): + return x * 2 + + mod, name = _make_module(orig=orig) + + def before_noop(x): + return None # leave args unchanged + + HookRegistry.register(f"{name}.orig", before_noop, HookType.BEFORE) + HookRegistry.apply_hooks() + self.assertEqual(mod.orig(3), 6) # args unchanged + + def test_after_function(self): + def orig(x): + return x * 2 + + mod, name = _make_module(orig=orig) + + def add_ten(result, x): + return result + 10 + + HookRegistry.register(f"{name}.orig", add_ten, HookType.AFTER) + HookRegistry.apply_hooks() + self.assertEqual(mod.orig(3), 16) # 3*2 + 10 + + def test_replace_function(self): + def orig(x): + return x * 2 + + mod, name = _make_module(orig=orig) + + def replacement(x): + return x * 100 + + HookRegistry.register(f"{name}.orig", replacement, 
HookType.REPLACE) + HookRegistry.apply_hooks() + self.assertEqual(mod.orig(3), 300) + + def test_class_replace(self): + class Original: + def greet(self): + return "original" + + mod, name = _make_module(Original=Original) + + class Replacement(Original): + def greet(self): + return "replaced" + + HookRegistry.register(f"{name}.Original", Replacement, HookType.REPLACE) + HookRegistry.apply_hooks() + + self.assertIs(mod.Original, Replacement) + self.assertIsInstance(mod.Original(), Replacement) + self.assertEqual(mod.Original().greet(), "replaced") + + def test_plugin_hook_decorator(self): + def orig(x): + return x + + mod, name = _make_module(orig=orig) + + @plugin_hook(f"{name}.orig", type=HookType.REPLACE) + def my_replace(x): + return x + 42 + + HookRegistry.apply_hooks() + self.assertEqual(mod.orig(0), 42) + + +# =========================================================================== +# TestDescriptorPreservation (Bug B regression tests) +# =========================================================================== + + +class TestDescriptorPreservation(_HookTestCase): + """Hooks on classmethod/staticmethod must preserve descriptor semantics.""" + + def _make_cls_module(self): + class MyClass: + @classmethod + def cm(cls, x): + return ("cm", cls.__name__, x) + + @staticmethod + def sm(x): + return ("sm", x) + + mod, name = _make_module(MyClass=MyClass) + return mod, name, MyClass + + def test_around_classmethod(self): + mod, name, MyClass = self._make_cls_module() + + def add_tag(original_fn, cls, x): + return original_fn(cls, x) + ("around",) + + HookRegistry.register(f"{name}.MyClass.cm", add_tag, HookType.AROUND) + HookRegistry.apply_hooks() + + result = mod.MyClass.cm(1) + self.assertEqual(result, ("cm", "MyClass", 1, "around")) + + def test_replace_classmethod(self): + mod, name, MyClass = self._make_cls_module() + + def new_cm(cls, x): + return ("replaced_cm", cls.__name__, x) + + HookRegistry.register(f"{name}.MyClass.cm", new_cm, HookType.REPLACE) + 
HookRegistry.apply_hooks() + + result = mod.MyClass.cm(1) + self.assertEqual(result, ("replaced_cm", "MyClass", 1)) + + def test_around_staticmethod(self): + mod, name, MyClass = self._make_cls_module() + + def wrap_sm(original_fn, x): + return original_fn(x) + ("around",) + + HookRegistry.register(f"{name}.MyClass.sm", wrap_sm, HookType.AROUND) + HookRegistry.apply_hooks() + + result = mod.MyClass.sm(1) + self.assertEqual(result, ("sm", 1, "around")) + + def test_replace_staticmethod(self): + mod, name, MyClass = self._make_cls_module() + + def new_sm(x): + return ("replaced_sm", x) + + HookRegistry.register(f"{name}.MyClass.sm", new_sm, HookType.REPLACE) + HookRegistry.apply_hooks() + + result = mod.MyClass.sm(1) + self.assertEqual(result, ("replaced_sm", 1)) + + def test_classmethod_subclass_cls(self): + mod, name, MyClass = self._make_cls_module() + + def add_tag(original_fn, cls, x): + return original_fn(cls, x) + ("around",) + + HookRegistry.register(f"{name}.MyClass.cm", add_tag, HookType.AROUND) + HookRegistry.apply_hooks() + + class Sub(mod.MyClass): + pass + + result = Sub.cm(1) + self.assertEqual(result, ("cm", "Sub", 1, "around")) + + +# =========================================================================== +# TestHookOrdering +# =========================================================================== + + +class TestHookOrdering(_HookTestCase): + """Verify REPLACE is applied first, then other hooks wrap it.""" + + def test_replace_then_around(self): + def orig(x): + return x + + mod, name = _make_module(orig=orig) + + def repl(x): + return x * 10 + + def add_one(original_fn, x): + return original_fn(x) + 1 + + HookRegistry.register(f"{name}.orig", repl, HookType.REPLACE) + HookRegistry.register(f"{name}.orig", add_one, HookType.AROUND) + HookRegistry.apply_hooks() + # REPLACE first: x*10, then AROUND: +1 => 31 + self.assertEqual(mod.orig(3), 31) + + def test_replace_before_after(self): + def orig(x): + return x + + mod, name = 
_make_module(orig=orig) + + def repl(x): + return x * 10 + + def double_arg(x): + return (x * 2,), {} + + def add_hundred(result, x): + return result + 100 + + HookRegistry.register(f"{name}.orig", repl, HookType.REPLACE) + HookRegistry.register(f"{name}.orig", double_arg, HookType.BEFORE) + HookRegistry.register(f"{name}.orig", add_hundred, HookType.AFTER) + HookRegistry.apply_hooks() + # BEFORE doubles x: 3*2=6 → REPLACE: 6*10=60 → AFTER: 60+100=160 + self.assertEqual(mod.orig(3), 160) + + +# =========================================================================== +# TestCrossTargetConflict +# =========================================================================== + + +class TestCrossTargetConflict(_HookTestCase): + """Verify warning for class REPLACE + method REPLACE combo.""" + + def test_class_replace_then_method_replace_warns(self): + class Original: + def foo(self): + return "orig" + + mod, name = _make_module(Original=Original) + + class Replacement(Original): + def foo(self): + return "class_replaced" + + HookRegistry.register(f"{name}.Original", Replacement, HookType.REPLACE) + + def method_repl(self): + return "method_replaced" + + HookRegistry.register(f"{name}.Original.foo", method_repl, HookType.REPLACE) + + with self.assertLogs("sglang.srt.plugins.hook_registry", level="WARNING") as cm: + HookRegistry.apply_hooks() + + self.assertTrue(any("will override" in msg for msg in cm.output)) + + +# =========================================================================== +# TestPatchPropagation +# =========================================================================== + + +class TestPatchPropagation(_HookTestCase): + """Verify that patches propagate to other modules that imported the target.""" + + def test_same_reference_propagates(self): + def orig(x): + return x * 2 + + source_mod, source_name = _make_module(orig=orig) + importer_mod, _ = _make_module(orig=orig) # same reference + + def add_one(fn, x): + return fn(x) + 1 + + 
HookRegistry.register(f"{source_name}.orig", add_one, HookType.AROUND) + HookRegistry.apply_hooks() + + self.assertEqual(source_mod.orig(3), 7) + self.assertEqual(importer_mod.orig(3), 7) + + +# =========================================================================== +# TestEdgeCases +# =========================================================================== + + +class TestEdgeCases(_HookTestCase): + """Reset, type validation, multi-AROUND onion, idempotent apply.""" + + def test_reset(self): + def orig(x): + return x + + mod, name = _make_module(orig=orig) + + def noop(fn, x): + return fn(x) + + HookRegistry.register(f"{name}.orig", noop, HookType.AROUND) + HookRegistry.reset() + + HookRegistry.apply_hooks() + self.assertEqual(mod.orig(3), 3) + + def test_register_class_with_wrong_type(self): + class BadHook: + pass + + for ht in (HookType.BEFORE, HookType.AFTER, HookType.AROUND): + with self.assertRaises(TypeError): + HookRegistry.register("some.target", BadHook, ht) + + def test_multi_around_onion(self): + call_order = [] + + def orig(x): + call_order.append("orig") + return x + + mod, name = _make_module(orig=orig) + + def around1(fn, x): + call_order.append("a1_before") + result = fn(x) + call_order.append("a1_after") + return result + 1 + + def around2(fn, x): + call_order.append("a2_before") + result = fn(x) + call_order.append("a2_after") + return result + 10 + + HookRegistry.register(f"{name}.orig", around1, HookType.AROUND) + HookRegistry.register(f"{name}.orig", around2, HookType.AROUND) + HookRegistry.apply_hooks() + + result = mod.orig(0) + self.assertEqual(result, 11) + self.assertEqual( + call_order, ["a2_before", "a1_before", "orig", "a1_after", "a2_after"] + ) + + def test_apply_idempotent(self): + call_count = [0] + + def orig(x): + return x + + mod, name = _make_module(orig=orig) + + def counter(fn, x): + call_count[0] += 1 + return fn(x) + + HookRegistry.register(f"{name}.orig", counter, HookType.AROUND) + HookRegistry.apply_hooks() + 
HookRegistry.apply_hooks() # second apply should be no-op + + mod.orig(1) + self.assertEqual(call_count[0], 1) + + +if __name__ == "__main__": + import unittest + + unittest.main() diff --git a/test/registered/unit/plugins/test_load_plugins.py b/test/registered/unit/plugins/test_load_plugins.py new file mode 100644 index 000000000000..218f0fc276bb --- /dev/null +++ b/test/registered/unit/plugins/test_load_plugins.py @@ -0,0 +1,187 @@ +""" +Unit tests for the plugin loading flow. + +Covers: idempotency, apply_hooks invocation, exception resilience, +SGLANG_PLUGINS whitelist, SGLANG_PLATFORM exclusion logic, +and _current_plugin_source context var reset. + +Run: python -m pytest test/registered/unit/plugins/test_load_plugins.py -v +""" + +from unittest.mock import MagicMock, patch + +from sglang.srt.plugins import ( + _current_plugin_source, + _get_excluded_dists, + load_plugins, + load_plugins_by_group, +) +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=10, suite="stage-a-test-cpu") + + +def _make_ep(name, dist_name=None, load_fn=None): + """Create a mock entry point.""" + ep = MagicMock() + ep.name = name + ep.value = f"fake_module:{name}" + ep.dist = MagicMock() + ep.dist.name = dist_name or f"{name}-dist" + if load_fn is not None: + ep.load.return_value = load_fn + else: + ep.load.return_value = MagicMock() + return ep + + +def _reset_plugins_loaded(): + """Reset the _plugins_loaded flag so load_plugins() can run again.""" + import sglang.srt.plugins as plugins_mod + + plugins_mod._plugins_loaded = False + + +class TestLoadPlugins(CustomTestCase): + """Tests for load_plugins() and related helpers.""" + + def setUp(self): + _reset_plugins_loaded() + + def tearDown(self): + _reset_plugins_loaded() + + @patch("sglang.srt.plugins.HookRegistry") + @patch("sglang.srt.plugins.envs") + @patch("sglang.srt.plugins.entry_points", return_value=[]) + def 
test_load_plugins_idempotent_and_calls_apply( + self, mock_eps, mock_envs, mock_registry + ): + """Second call is a no-op; first call invokes apply_hooks.""" + mock_envs.SGLANG_PLATFORM.get.return_value = "" + mock_envs.SGLANG_PLUGINS.get.return_value = "" + + load_plugins() + self.assertEqual(mock_registry.apply_hooks.call_count, 1) + + load_plugins() # should be skipped + self.assertEqual(mock_registry.apply_hooks.call_count, 1) + + @patch("sglang.srt.plugins.HookRegistry") + @patch("sglang.srt.plugins.envs") + @patch("sglang.srt.plugins.entry_points") + def test_plugin_exception_does_not_crash(self, mock_eps, mock_envs, mock_registry): + """A failing plugin should not prevent others from loading.""" + mock_envs.SGLANG_PLATFORM.get.return_value = "" + mock_envs.SGLANG_PLUGINS.get.return_value = "" + + def bad_plugin(): + raise RuntimeError("boom") + + good_call_log = [] + + def good_plugin(): + good_call_log.append("ok") + + eps = [ + _make_ep("bad", load_fn=bad_plugin), + _make_ep("good", load_fn=good_plugin), + ] + mock_eps.return_value = eps + + with self.assertLogs("sglang.srt.plugins", level="ERROR") as cm: + load_plugins() + + self.assertTrue(any("boom" in msg for msg in cm.output)) + self.assertEqual(good_call_log, ["ok"]) + mock_registry.apply_hooks.assert_called_once() + + @patch("sglang.srt.plugins.entry_points") + @patch("sglang.srt.plugins.envs") + def test_sglang_plugins_whitelist(self, mock_envs, mock_eps): + """Only plugins named in SGLANG_PLUGINS should be loaded.""" + mock_envs.SGLANG_PLUGINS.get.return_value = "alpha,gamma" + mock_envs.SGLANG_PLATFORM.get.return_value = "" + + alpha_fn = MagicMock() + beta_fn = MagicMock() + gamma_fn = MagicMock() + + eps = [ + _make_ep("alpha", load_fn=alpha_fn), + _make_ep("beta", load_fn=beta_fn), + _make_ep("gamma", load_fn=gamma_fn), + ] + mock_eps.return_value = eps + + result = load_plugins_by_group("test.group") + self.assertIn("alpha", result) + self.assertNotIn("beta", result) + self.assertIn("gamma", 
result) + + @patch("sglang.srt.plugins.entry_points") + @patch("sglang.srt.plugins.envs") + def test_excluded_dists(self, mock_envs, mock_eps): + """SGLANG_PLATFORM excludes other platform dists; empty when unset.""" + # Case 1: no env set → empty + mock_envs.SGLANG_PLATFORM.get.return_value = "" + self.assertEqual(_get_excluded_dists(), set()) + + # Case 2: env set → exclude other dists + mock_envs.SGLANG_PLATFORM.get.return_value = "kunlun" + ep_kunlun = _make_ep("kunlun", dist_name="kunlun-pkg") + ep_other = _make_ep("other_hw", dist_name="other-pkg") + mock_eps.return_value = [ep_kunlun, ep_other] + + excluded = _get_excluded_dists() + self.assertNotIn("kunlun-pkg", excluded) + self.assertIn("other-pkg", excluded) + + @patch("sglang.srt.plugins.HookRegistry") + @patch("sglang.srt.plugins.envs") + @patch("sglang.srt.plugins.entry_points") + def test_current_plugin_source_set_during_and_reset_after( + self, mock_eps, mock_envs, mock_registry + ): + """_current_plugin_source is set during plugin execution, reset after.""" + sources_seen = [] + + def spy_plugin(): + sources_seen.append(_current_plugin_source.get()) + + mock_eps.return_value = [_make_ep("spy", load_fn=spy_plugin)] + mock_envs.SGLANG_PLATFORM.get.return_value = "" + mock_envs.SGLANG_PLUGINS.get.return_value = "" + + load_plugins() + # During execution: source was set (not None) + self.assertEqual(len(sources_seen), 1) + self.assertIsNotNone(sources_seen[0]) + self.assertEqual(sources_seen[0].plugin_name, "spy") + # After execution: source is back to None + self.assertIsNone(_current_plugin_source.get()) + + @patch("sglang.srt.plugins.HookRegistry") + @patch("sglang.srt.plugins.envs") + @patch("sglang.srt.plugins.entry_points") + def test_current_plugin_source_reset_after_exception( + self, mock_eps, mock_envs, mock_registry + ): + """_current_plugin_source is reset to None even when a plugin raises.""" + mock_envs.SGLANG_PLATFORM.get.return_value = "" + mock_envs.SGLANG_PLUGINS.get.return_value = "" 
+
+        def bad_plugin():
+            raise RuntimeError("boom")
+
+        mock_eps.return_value = [_make_ep("bad", load_fn=bad_plugin)]
+
+        # Capture the expected ERROR log (as in test_plugin_exception_does_not_crash)
+        # so it is asserted rather than leaking into test output.
+        with self.assertLogs("sglang.srt.plugins", level="ERROR"):
+            load_plugins()
+        self.assertIsNone(_current_plugin_source.get())
+
+
+if __name__ == "__main__":
+    import unittest
+
+    unittest.main()
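For reviewers, the ordering semantics pinned down by `test_replace_then_around` and `test_multi_around_onion` can be summarized in a standalone sketch. This is NOT sglang's `HookRegistry` implementation — `HookType` and `compose` below are simplified stand-ins — it only models the contract the tests assert: REPLACE fully substitutes the target first, then each AROUND hook wraps the previous layer, so the last-registered AROUND runs outermost.

```python
# Standalone model of the hook-ordering contract verified by the tests above.
# Not sglang's implementation; names here (HookType, compose) are illustrative.
from enum import Enum, auto


class HookType(Enum):
    REPLACE = auto()
    AROUND = auto()


def compose(original, hooks):
    """Build the patched callable from (hook_type, fn) pairs, in registration order."""
    fn = original
    # Pass 1: REPLACE is applied first and fully substitutes the target.
    for ht, hook in hooks:
        if ht is HookType.REPLACE:
            fn = hook
    # Pass 2: each AROUND wraps the previous layer (onion: last registered = outermost).
    for ht, hook in hooks:
        if ht is HookType.AROUND:
            def wrapped(*args, _inner=fn, _hook=hook, **kwargs):
                return _hook(_inner, *args, **kwargs)
            fn = wrapped
    return fn


def orig(x):
    return x


patched = compose(
    orig,
    [
        (HookType.REPLACE, lambda x: x * 10),        # substitute: x -> x*10
        (HookType.AROUND, lambda fn, x: fn(x) + 1),  # wrap: add 1 to the result
    ],
)
print(patched(3))  # mirrors test_replace_then_around: 3*10 + 1 = 31
```

With two AROUND hooks the same composition yields the call order asserted in `test_multi_around_onion`: the second-registered hook's "before" side runs first and its "after" side runs last.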