diff --git a/docs/diffusion/api/cli.md b/docs/diffusion/api/cli.md index 587efeb46450..6652bf964943 100644 --- a/docs/diffusion/api/cli.md +++ b/docs/diffusion/api/cli.md @@ -73,10 +73,13 @@ Use `sglang generate --help` and `sglang serve --help` for the full argument lis - `--model-path {MODEL}`: model path or Hugging Face model ID - `--lora-path {PATH}` and `--lora-nickname {NAME}`: load a LoRA adapter +- `--lora-merge-mode {auto|merge|dynamic}`: choose how LoRA is applied. `auto` statically merges regular weights and uses dynamic LoRA for FSDP-sharded weights to avoid full-gather peaks. - `--num-gpus {N}`: number of GPUs to use +- `--performance-mode {manual|auto|speed|memory}` / `--mode`: preset for latency/throughput and memory defaults. `auto` is the default and keeps safe offload defaults, using FSDP only for validated DiT-offload replacement paths; use `manual` to keep performance-related server args under explicit user control. Explicit offload, FSDP, and parallelism flags take precedence in all modes. - `--tp-size {N}`: tensor parallelism size, mainly for encoders - `--sp-degree {N}`: sequence parallelism size - `--ulysses-degree {N}` and `--ring-degree {N}`: USP parallelism controls +- `--enable-cfg-parallel {true|false}`: enable or explicitly disable CFG parallelism - `--attention-backend {BACKEND}`: attention backend for native SGLang pipelines - `--component-attention-backends {MAP}`: per-component attention backend overrides, for example `text_encoder=torch_sdpa,transformer=fa` - `--attention-backend-config {CONFIG}`: attention backend configuration diff --git a/docs/diffusion/api/openai_api.md b/docs/diffusion/api/openai_api.md index 8d18c49599ba..99bbea056bba 100644 --- a/docs/diffusion/api/openai_api.md +++ b/docs/diffusion/api/openai_api.md @@ -268,13 +268,13 @@ curl -sS -L "http://localhost:30010/v1/videos//content" \ The server supports dynamic loading, merging, and unmerging of LoRA adapters. **Important Notes:** -- Mutual Exclusion: Only one LoRA can be *merged* (active) at a time -- Switching: To switch LoRAs, you must first `unmerge` the current one, then `set` the new one +- Mutual Exclusion: Only one LoRA configuration can be active per target at a time +- Switching: To switch LoRAs, deactivate the current LoRA with `unmerge_lora_weights`, then `set` the new one - Caching: The server caches loaded LoRA weights in memory. Switching back to a previously loaded LoRA (same path) has little cost **Set LoRA Adapter** -Loads one or more LoRA adapters and merges their weights into the model. Supports both single LoRA (backward compatible) and multiple LoRA adapters. +Loads one or more LoRA adapters and applies them to the model. By default, regular weights are statically merged, while FSDP-sharded weights use dynamic LoRA to avoid full-gather memory peaks. **Endpoint:** `POST /v1/set_lora` @@ -287,6 +287,7 @@ Loads one or more LoRA adapters and merges their weights into the model. Support - `"transformer_2"`: Apply only to transformer_2 (low noise for Wan2.2) - `"critic"`: Apply only to the critic model - `strength` (float or list of floats, optional): LoRA strength for merge, default 1.0. If a list, must match the length of `lora_nickname`. Values < 1.0 reduce the effect, values > 1.0 amplify the effect +- `merge_mode` (string, optional): `"auto"` (default server policy), `"merge"` (force static merge), or `"dynamic"` (apply LoRA at forward time) **Single LoRA Example:** @@ -331,7 +332,7 @@ curl -X POST http://localhost:30010/v1/set_lora \ > When using multiple LoRAs: > - All list parameters (`lora_nickname`, `lora_path`, `target`, `strength`) must have the same length > - If `target` or `strength` is a single value, it will be applied to all LoRAs -> - Multiple LoRAs applied to the same target will be merged in order +> - Multiple LoRAs applied to the same target are applied in order **Merge LoRA Weights** @@ -339,7 +340,7 @@ curl -X POST http://localhost:30010/v1/set_lora \ Manually merges the currently set LoRA weights into the base model. > [!NOTE] -> `set_lora` automatically performs a merge, so this is typically only needed if you have manually unmerged but want to re-apply the same LoRA without calling `set_lora` again.* +> With FSDP-sharded weights, manual merge may require a full-gather and can OOM. Use `set_lora` with `merge_mode="auto"` or `"dynamic"` for the lower-peak path. **Endpoint:** `POST /v1/merge_lora_weights` @@ -395,6 +396,7 @@ curl -sS -X GET "http://localhost:30010/v1/list_loras" "nickname": "lora2", "path": "tarn59/pixel_art_style_lora_z_image_turbo", "merged": true, + "mode": "merged", "strength": 1.0 } ] diff --git a/docs/diffusion/performance/deployment_cookbook.md b/docs/diffusion/performance/deployment_cookbook.md new file mode 100644 index 000000000000..d45798b1c56b --- /dev/null +++ b/docs/diffusion/performance/deployment_cookbook.md @@ -0,0 +1,94 @@ +# Deployment Cookbook + +This page gives practical defaults for choosing CPU offload, FSDP, CFG parallelism, SP, and TP. + +## Quick Rule + +Use the simplest setting that fits your memory target: + +| Goal | Recommended setting | +|--------------------------------------------|------------------------------------------------------------------------------------| +| Fastest single-GPU run when the model fits | Disable CPU offload and do not use FSDP. | +| Lower single-GPU memory usage | Use component CPU offload, or layerwise DiT offload for supported Wan/MOVA models. | +| Faster multi-GPU Qwen/Wan CFG generation | Use FSDP with CFG parallelism and disable CPU offload. | +| Sequence length or video-shape scaling | Use SP/Ulysses/Ring when the model benefits from sequence parallelism. | +| TP compatibility or encoder-heavy paths | Set TP explicitly; do not treat TP as the default latency optimization. | + +Base the decision on available memory on the selected GPU(s). + +- For multi-GPU deployment: the least-free selected GPU is the bottleneck. A busy 80GiB GPU can behave like a much smaller GPU. +- For single-GPU deployment: FSDP shards DiT weights across multiple GPUs. It is not useful for keeping a single-GPU deployment on one GPU; for that case use CPU offload. + +## Performance Modes + +`--performance-mode` applies safe presets without overriding explicit offload, FSDP, or parallelism flags. `auto` is the default. Use `manual` when you need to keep performance-related server args under explicit user control. `--mode` is a short alias. + +| Mode | Meaning | +|------------|---------------------------------------------------------------------------------------------------------------------------| +| `manual` | Keeps performance-related server args under explicit user control. | +| `auto` | Default. Keeps legacy safe offload defaults and uses FSDP/CFG only on validated multi-GPU deployments where FSDP can replace DiT offload. | +| `speed` | Favors GPU-resident execution for lower latency and higher throughput. Disables CPU offload when unset; may OOM. | +| `memory` | Favors lower GPU memory. Uses component offload, or Wan/MOVA layerwise DiT offload when supported. | + +`auto` checks selected GPU memory before applying FSDP. In multi-GPU runs it uses the least available memory across selected GPUs, and only turns on FSDP automatically when doing so can replace DiT offload. Text encoder, image encoder, and other component residency still follow the offload policy unless the model marks a high-memory resident path as safe. When the model default uses CFG and the user did not set a parallelism policy, `auto` may also enable CFG parallelism. `speed` intentionally does not check memory; it is the mode for users who prefer latency/throughput and accept OOM risk. + +The modes tune residency for native pipeline components declared to the component residency manager. Today this covers the major DiT, text/image encoder, VAE, vocoder, and upsampler components; DiT can use layerwise offload when supported, while text encoders use either resident execution or component CPU offload. Do not assume text-encoder layerwise offload unless a model implements and validates it. + +NOTE: +The preset is intentionally coarse. A future continuous value such as `0.0` to `1.0` could express the speed-memory tradeoff more precisely, but it would need model-specific memory models and clearer user expectations. Until then, use the preset plus explicit flags for overrides. + +Examples: + +```bash +sglang generate \ + --model-path Qwen/Qwen-Image \ + --num-gpus 2 \ + --performance-mode auto +``` + +```bash +sglang generate \ + --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ + --performance-mode memory +``` + +Explicit flags win over the mode: + +```bash +sglang generate \ + --model-path Qwen/Qwen-Image \ + --num-gpus 2 \ + --performance-mode auto \ + --use-fsdp-inference false +``` + +In this example, `auto` will not re-enable FSDP. The same applies to parallelism; for example, `--enable-cfg-parallel false` keeps CFG parallelism disabled. + +## Interpreting The Levers + +**No offload** keeps model components resident on GPU. It is usually fastest when memory is sufficient. + +**Component CPU offload** lowers GPU memory by moving large components to CPU. It is simple and robust, but it usually trades latency for memory. + +**Layerwise DiT offload** lowers DiT memory further for supported Wan/MOVA models by moving DiT layers between CPU and GPU. It can be the best single-GPU memory mode, but may increase latency and lower throughput. + +**FSDP** shards DiT weights across multiple GPUs and all-gathers weights during forward. It can reduce DiT CPU offload cost on multi-GPU deployments, especially for validated Wan I2V workloads. + +FSDP sharding granularity matters. SGLang Diffusion prefers sharding direct repeated transformer block entries such as `transformer_blocks.0` or `blocks.0`. Coarser sharding lowers wrapper count but can increase all-gather peak memory; finer sharding can reduce transient memory but adds communication and scheduling overhead. If a model does not define an explicit sharding rule, the loader falls back to repeated block class names and common direct numbered block paths. + +**CFG parallelism** splits positive and negative CFG branches across GPUs. For Qwen/Wan workloads with normal step counts, this is the most reliable multi-GPU speedup observed so far. + +**SP/Ulysses/Ring** splits sequence work. It can help video workloads, but validated Qwen/Wan runs showed CFG parallelism outperforming SP for latency. + +**TP** is supported for compatibility and some model structures, but current measurements do not make it the default latency path for Qwen/Wan. + +## Current Benchmark Takeaways + +Observed regular-scale trends: + +- Z-Image: single-GPU no-offload was faster than FSDP/SP in the tested setting; keep FSDP off unless memory or parallelism requires it. +- Qwen-Image: keep the default non-FSDP path unless a specific FSDP/SP/Ring setting has been benchmarked on the target hardware. +- Wan: FSDP can replace DiT offload on validated multi-GPU workloads, while text/image encoders may still need component offload. Keep model-specific precision checks before making FSDP automatic for a path. +- Component offload mainly reduced memory; it did not improve latency in the tested no-offload-vs-offload runs. + +Always benchmark with your actual resolution, frame count, step count, and GPU type before locking production defaults. diff --git a/docs/diffusion/performance/index.md b/docs/diffusion/performance/index.md index 2a2abe54a239..0723e15b9688 100644 --- a/docs/diffusion/performance/index.md +++ b/docs/diffusion/performance/index.md @@ -14,6 +14,7 @@ This section covers the main performance levers for SGLang Diffusion: attention ## Start Here - Use [Attention Backends](attention_backends.md) to choose the best backend for your model and hardware. +- Use [Deployment Cookbook](deployment_cookbook.md) to choose CPU offload, FSDP, CFG parallelism, SP, and TP. - Use [Caching Acceleration](cache/index.md) to reduce denoising cost with Cache-DiT or TeaCache. - Use [Profiling](profiling.md) when you need to diagnose a bottleneck rather than guess. @@ -26,6 +27,7 @@ This section covers the main performance levers for SGLang Diffusion: attention :maxdepth: 1 attention_backends +deployment_cookbook cache/index profiling ``` diff --git a/python/sglang/benchmark/utils.py b/python/sglang/benchmark/utils.py index 7bf6494b5df1..031d38cdf646 100644 --- a/python/sglang/benchmark/utils.py +++ b/python/sglang/benchmark/utils.py @@ -118,13 +118,16 @@ def download_and_cache_file(url: str, filename: Optional[str] = None): chunk_size = 1024 # Download in chunks of 1KB # Use tqdm to display the progress bar - with open(filename, "wb") as f, tqdm( - desc=filename, - total=total_size, - unit="B", - unit_scale=True, - unit_divisor=1024, - ) as bar: + with ( + open(filename, "wb") as f, + tqdm( + desc=filename, + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as bar, + ): for chunk in response.iter_content(chunk_size=chunk_size): f.write(chunk) bar.update(len(chunk)) diff --git a/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/qwen_image.py b/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/qwen_image.py index b89186801b47..8bc84b975f62 100644 --- a/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/qwen_image.py +++ b/python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/executors/qwen_image.py @@ -120,7 +120,7 @@ def forward( ref_latents=None, additional_t_cond=None, transformer_options={}, - **kwargs + **kwargs, ): """Forward pass for QwenImageEdit model.""" latents, orig_shape = self._pack_latents(x) diff --git a/python/sglang/multimodal_gen/configs/models/dits/wanvideo.py b/python/sglang/multimodal_gen/configs/models/dits/wanvideo.py index 544240d8ae82..d13783afdd7e 100644 --- a/python/sglang/multimodal_gen/configs/models/dits/wanvideo.py +++ b/python/sglang/multimodal_gen/configs/models/dits/wanvideo.py @@ -28,6 +28,7 @@ class WanVideoArchConfig(DiTArchConfig): r"^blocks\.(\d+)\.attn1\.norm_q\.(.*)$": r"blocks.\1.norm_q.\2", r"^blocks\.(\d+)\.attn1\.norm_k\.(.*)$": r"blocks.\1.norm_k.\2", r"^blocks\.(\d+)\.attn1\.attn_op\.local_attn\.proj_l\.(.*)$": r"blocks.\1.attn1.local_attn.proj_l.\2", + r"^blocks\.(\d+)\.attn2\.norm_added_q\.(.*)$": "", r"^blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$": r"blocks.\1.attn2.to_out.\2", r"^blocks\.(\d+)\.ffn\.net\.0\.proj\.(.*)$": r"blocks.\1.ffn.fc_in.\2", r"^blocks\.(\d+)\.ffn\.net\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2", diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/base.py b/python/sglang/multimodal_gen/configs/pipeline_configs/base.py index 7005ac70d5dd..ba5c313c31ec 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/base.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/base.py @@ -22,6 +22,9 @@ ) from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput from sglang.multimodal_gen.configs.models.encoders.t5 import T5Config +from sglang.multimodal_gen.configs.pipeline_configs.model_deployment_config import ( + ModelDeploymentConfig, +) from sglang.multimodal_gen.configs.sample.sampling_params import DataType from sglang.multimodal_gen.configs.utils import update_config_from_args from sglang.multimodal_gen.runtime.distributed.cfg_policy import CFGPolicy @@ -240,6 +243,9 @@ class PipelineConfig: # image encoding image_encoder_extra_args: dict = field(default_factory=lambda: {}) + def get_model_deployment_config(self) -> ModelDeploymentConfig: + return ModelDeploymentConfig() + def postprocess_image(self, image): return image.last_hidden_state diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/ltx_2.py b/python/sglang/multimodal_gen/configs/pipeline_configs/ltx_2.py index c32d4bd80845..f425289f251f 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/ltx_2.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/ltx_2.py @@ -16,6 +16,9 @@ ModelTaskType, PipelineConfig, ) +from sglang.multimodal_gen.configs.pipeline_configs.model_deployment_config import ( + ModelDeploymentConfig, +) from sglang.multimodal_gen.runtime.distributed import ( get_sp_parallel_rank, get_sp_world_size, @@ -187,6 +190,12 @@ def vae_scale_factor(self): def vae_temporal_compression(self): return self.vae_config.arch_config.temporal_compression_ratio + def get_model_deployment_config(self) -> ModelDeploymentConfig: + return ModelDeploymentConfig( + auto_disable_component_offload_min_available_memory_gb=70, + auto_disable_component_offload_components=("dit",), + ) + def prepare_latent_shape(self, batch, batch_size, num_frames): """Return unpacked latent shape [B, C, F, H, W].""" height = batch.height // self.vae_scale_factor diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/model_deployment_config.py b/python/sglang/multimodal_gen/configs/pipeline_configs/model_deployment_config.py new file mode 100644 index 000000000000..f8580e65203c --- /dev/null +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/model_deployment_config.py @@ -0,0 +1,26 @@ +""" +ModelDeploymentConfig provides model-specific config on how to deploy a model optimally + +""" + +from dataclasses import dataclass +from typing import Literal + +OffloadComponentName = Literal["dit", "text_encoder", "image_encoder"] + + +@dataclass(frozen=True) +class ModelDeploymentConfig: + auto_dit_layerwise_offload: bool = False + # if the available memory is bigger than this value, keep dit resident instead of apply layerwise-offload + auto_dit_layerwise_offload_high_memory_disable_gb: float | None = None + auto_disable_component_offload_min_available_memory_gb: float | None = None + # keep this explicit because large encoders can OOM even when DiT fits resident + auto_disable_component_offload_components: tuple[OffloadComponentName, ...] = ( + "dit", + "text_encoder", + "image_encoder", + ) + fsdp_auto_min_available_memory_gb: float | None = None + fsdp_auto_requires_cfg: bool = True + fsdp_auto_requires_default_parallelism: bool = True diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/mova.py b/python/sglang/multimodal_gen/configs/pipeline_configs/mova.py index 4612a9eef493..5de1ccf447cd 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/mova.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/mova.py @@ -17,6 +17,9 @@ ModelTaskType, PipelineConfig, ) +from sglang.multimodal_gen.configs.pipeline_configs.model_deployment_config import ( + ModelDeploymentConfig, +) from sglang.multimodal_gen.configs.pipeline_configs.wan import t5_postprocess_text from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -52,6 +55,12 @@ class MOVAPipelineConfig(PipelineConfig): time_division_factor: int = 4 time_division_remainder: int = 1 + def get_model_deployment_config(self) -> ModelDeploymentConfig: + return ModelDeploymentConfig( + auto_dit_layerwise_offload=True, + auto_dit_layerwise_offload_high_memory_disable_gb=130, + ) + def _center_crop_and_resize( self, image: torch.Tensor | Image.Image, target_height: int, target_width: int ) -> torch.Tensor | Image.Image: diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py b/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py index d2b9c5a62c3e..cdfa175f2350 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py @@ -168,6 +168,7 @@ class QwenImagePipelineConfig(QwenImageRolloutPipelineMixin, ImagePipelineConfig postprocess_text_funcs: tuple[Callable[[str], str], ...] = field( default_factory=lambda: (qwen_image_postprocess_text,) ) + text_encoder_extra_args: list[dict] = field( default_factory=lambda: [ dict( diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py b/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py index c08851e93fe6..e5f33027db9f 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/wan.py @@ -18,6 +18,9 @@ ModelTaskType, PipelineConfig, ) +from sglang.multimodal_gen.configs.pipeline_configs.model_deployment_config import ( + ModelDeploymentConfig, +) from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger logger = init_logger(__name__) @@ -91,6 +94,12 @@ def __post_init__(self): self.vae_config.load_encoder = False self.vae_config.load_decoder = True + def get_model_deployment_config(self) -> ModelDeploymentConfig: + return ModelDeploymentConfig( + auto_dit_layerwise_offload=True, + auto_dit_layerwise_offload_high_memory_disable_gb=130, + ) + @dataclass class TurboWanT2V480PConfig(WanT2V480PConfig): @@ -136,6 +145,12 @@ def __post_init__(self) -> None: self.vae_config.load_encoder = True self.vae_config.load_decoder = True + def get_model_deployment_config(self) -> ModelDeploymentConfig: + return ModelDeploymentConfig( + auto_dit_layerwise_offload=True, + auto_dit_layerwise_offload_high_memory_disable_gb=130, + ) + @dataclass class WanI2V720PConfig(WanI2V480PConfig): diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py b/python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py index d26276863457..af9fc49914f1 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/zimage.py @@ -17,6 +17,9 @@ TextConditioningOutput, pad_text_embeddings_with_mask, ) +from sglang.multimodal_gen.configs.pipeline_configs.model_deployment_config import ( + ModelDeploymentConfig, +) from sglang.multimodal_gen.configs.post_training.pipeline_configs import ( ZImageRolloutPipelineMixin, ) @@ -80,6 +83,9 @@ class ZImagePipelineConfig(ZImageRolloutPipelineMixin, ImagePipelineConfig): PATCH_SIZE: int = 2 F_PATCH_SIZE: int = 1 + def get_model_deployment_config(self) -> ModelDeploymentConfig: + return ModelDeploymentConfig(fsdp_auto_min_available_memory_gb=40) + def tokenize_prompt(self, prompts: list[str], tokenizer, tok_kwargs) -> dict: rendered_prompts = [ tokenizer.apply_chat_template( diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py index 19cd5ea097b8..25d0d6f159a0 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py @@ -456,6 +456,7 @@ def set_lora( lora_path: Union[str, None, List[Union[str, None]]] = None, target: Union[str, List[str]] = "all", strength: Union[float, List[float]] = 1.0, + merge_mode: str | None = None, ) -> None: """ Set LoRA adapter(s) for the specified transformer(s). @@ -471,12 +472,14 @@ def set_lora( - "transformer_2": Apply only to transformer_2 (low noise for Wan2.2) - "critic": Apply only to the critic model strength: LoRA strength(s) for merge, default 1.0. Can be a float or a list of floats. + merge_mode: Optional LoRA merge mode: "auto", "merge", or "dynamic". """ req = SetLoraReq( lora_nickname=lora_nickname, lora_path=lora_path, target=target, strength=strength, + merge_mode=merge_mode, ) nickname_str, target_str, strength_str = format_lora_message( lora_nickname, target, strength diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py index 328d9f6f1dd2..5022b411d122 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/common_api.py @@ -66,6 +66,7 @@ async def set_lora( lora_path: Optional[Union[str, List[Optional[str]]]] = Body(None, embed=True), target: Union[str, List[str]] = Body("all", embed=True), strength: Union[float, List[float]] = Body(1.0, embed=True), + merge_mode: Optional[str] = Body(None, embed=True), ): """ Set LoRA adapter(s) for the specified transformer(s). @@ -84,12 +85,14 @@ async def set_lora( strength: LoRA strength(s) for merge, default 1.0. Can be a float or a list of floats. If a list, must match the length of lora_nickname. Values < 1.0 reduce the effect, values > 1.0 amplify the effect. + merge_mode: Optional LoRA merge mode: "auto", "merge", or "dynamic". """ req = SetLoraReq( lora_nickname=lora_nickname, lora_path=lora_path, target=target, strength=strength, + merge_mode=merge_mode, ) nickname_str, target_str, strength_str = format_lora_message( lora_nickname, target, strength diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/utils.py b/python/sglang/multimodal_gen/runtime/entrypoints/utils.py index d65a1287fcbc..cf69ff4c4fcf 100644 --- a/python/sglang/multimodal_gen/runtime/entrypoints/utils.py +++ b/python/sglang/multimodal_gen/runtime/entrypoints/utils.py @@ -48,6 +48,7 @@ class SetLoraReq: lora_path: Optional[Union[str, List[Optional[str]]]] = None target: Union[str, List[str]] = "all" strength: Union[float, List[float]] = 1.0 + merge_mode: Optional[str] = None @dataclass diff --git a/python/sglang/multimodal_gen/runtime/layers/lora/linear.py b/python/sglang/multimodal_gen/runtime/layers/lora/linear.py index f14344edc943..8ca749fd793e 100644 --- a/python/sglang/multimodal_gen/runtime/layers/lora/linear.py +++ b/python/sglang/multimodal_gen/runtime/layers/lora/linear.py @@ -403,6 +403,9 @@ def __init__( super().__init__(base_layer, lora_rank, lora_alpha) def forward(self, input_: torch.Tensor) -> torch.Tensor: + if self.merged or self.disable_lora: + return self.base_layer(input_) + lora_A = self.lora_A lora_B = self.lora_B if isinstance(self.lora_B, DTensor): diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py index e60ee241361d..5982cd07ca1d 100644 --- a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -769,6 +769,7 @@ def set_lora( lora_path: Union[str, None, List[Union[str, None]]] = None, target: Union[str, List[str]] = "all", strength: Union[float, List[float]] = 1.0, + merge_mode: str | None = None, ) -> OutputBatch: """ Set the LoRA adapter(s) for the pipeline. @@ -779,10 +780,13 @@ def set_lora( lora_path: Path(s) to the LoRA adapter(s). Can be a string, None, or a list of strings/None. target: Which transformer(s) to apply the LoRA to. Can be a string or a list of strings. strength: LoRA strength(s) for merge, default 1.0. Can be a float or a list of floats. + merge_mode: Optional per-request LoRA merge mode. """ if not isinstance(self.pipeline, LoRAPipeline): return OutputBatch(error="Lora is not enabled") - self.pipeline.set_lora(lora_nickname, lora_path, target, strength) + self.pipeline.set_lora( + lora_nickname, lora_path, target, strength, merge_mode=merge_mode + ) return OutputBatch() def merge_lora_weights( @@ -868,16 +872,22 @@ def get_weights_checksum( return checksums -OOM_MSG = f""" +OOM_MSG = """ OOM detected. Possible solutions: - If the OOM occurs during loading: - 1. Enable CPU offload for memory-intensive components, or use `--dit-layerwise-offload` for DiT + 1. Check available memory on every selected GPU, not only total capacity. + In multi-GPU runs, the least-free selected GPU is the bottleneck. + 2. For single-GPU deployment, use `--performance-mode memory`, component CPU offload, + or `--dit-layerwise-offload` for supported Wan/MOVA DiTs. + 3. For multi-GPU deployment, keep the default `--performance-mode auto` or set + `--use-fsdp-inference true` to shard DiT weights with FSDP. FSDP is not a + single-GPU substitute for CPU offload. - If the OOM occurs during runtime: - 1. Enable SP and/or TP (in a multi-GPU setup) - 2. Reduce the number of output tokens by lowering resolution or decreasing `--num-frames` - 3. Opt for a sparse-attention backend - 4. Enable FSDP by `--use-fsdp-inference` (in a multi-GPU setup) - 5. Enable quantization (e.g. nunchaku) + 1. Reduce resolution, `--num-frames`, or batch size. + 2. Use `--performance-mode memory` for lower memory usage. + 3. Enable SP/Ulysses/Ring for sequence-heavy workloads in multi-GPU setups. + 4. Use FSDP, with CFG parallelism when supported, for validated multi-GPU workloads. + 5. Use a lower-memory attention backend or quantization when available. Or, open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose """ diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py index 8a1231dcaab6..a51de4f8afaa 100644 --- a/python/sglang/multimodal_gen/runtime/managers/scheduler.py +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -189,7 +189,11 @@ def _handle_set_lora(self, reqs: List[Any]) -> OutputBatch: # TODO: return with SetLoRAResponse or something more appropriate req = reqs[0] return self.worker.set_lora( - req.lora_nickname, req.lora_path, req.target, req.strength + req.lora_nickname, + req.lora_path, + req.target, + req.strength, + req.merge_mode, ) def _handle_merge_lora(self, reqs: List[Any]): diff --git a/python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py b/python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py index 39e2872b4f21..e2e26de70a13 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py +++ b/python/sglang/multimodal_gen/runtime/pipelines/ltx_2_pipeline.py @@ -559,6 +559,25 @@ def prepare_after_request( self.manager._sync_refinement_stage_transformer("stage1") self.manager._active_phase = "stage1" + def finish_request( + self, + module: torch.nn.Module, + use: ComponentUse, + state: ResidencyState, + *, + preferred: bool, + ) -> None: + if ( + preferred + and state.batch_is_warmup + and self._snapshot_low_vram_mode + and self._phase(use) == "stage1" + ): + # keep the text encoder warm, but avoid stage1 DiT overlap before the first real request + self.manager._active_phase = None + return + super().finish_request(module, use, state, preferred=preferred) + def finish_use( self, module: torch.nn.Module, @@ -618,6 +637,13 @@ def prefetch_for_use( if not self.server_args.dit_cpu_offload: return True phase = self._phase(use) + if ( + self._snapshot_low_vram_mode + and phase == "stage1" + and state.current_use is not None + and state.current_use.component_name.startswith("text_encoder") + ): + return False if phase == "stage2": if self._snapshot_strategy.is_ready("transformer_2"): return True diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py b/python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py index a8e13f222d50..82fcc5b31112 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py @@ -10,6 +10,7 @@ import torch import torch.distributed as dist from safetensors.torch import load_file +from torch.distributed.tensor import DTensor from sglang.multimodal_gen.runtime.distributed import get_local_torch_device from sglang.multimodal_gen.runtime.layers.lora.linear import ( @@ -24,7 +25,7 @@ from sglang.multimodal_gen.runtime.pipelines_core.lora_format_adapter import ( normalize_lora_state_dict, ) -from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.server_args import LORA_MERGE_MODES, ServerArgs from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_lora from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -362,6 +363,7 @@ def _check_lora_config_matches( module_name: str, target_nicknames: list[str], target_strengths: list[float], + target_merge_weights: bool, adapter_updated: bool, ) -> bool: """ @@ -376,11 +378,12 @@ def _check_lora_config_matches( Returns: True if the configuration matches exactly (including order and strength), False otherwise. """ - if not self.is_lora_merged.get(module_name, False): - return False if adapter_updated: return False # Adapter was updated, need to reapply + if self.is_lora_merged.get(module_name, False) != target_merge_weights: + return False + stored_config = self.cur_adapter_config.get(module_name) if stored_config is None: return False @@ -392,6 +395,68 @@ def _check_lora_config_matches( and stored_strengths == target_strengths ) + @staticmethod + def _uses_dtensor_weights(lora_layers: dict[str, BaseLayerWithLoRA]) -> bool: + return any(isinstance(layer.weight, DTensor) for layer in lora_layers.values()) + + @staticmethod + def _has_active_unmerged_lora( + lora_layers: dict[str, BaseLayerWithLoRA], + ) -> bool: + return any( + not layer.merged and not layer.disable_lora + for layer in lora_layers.values() + ) + + def _is_lora_effective_for_module( + self, + module_name: str, + lora_layers: dict[str, BaseLayerWithLoRA], + ) -> bool: + return self.is_lora_merged.get( + module_name, False + ) or self._has_active_unmerged_lora(lora_layers) + + def _resolve_lora_merge_mode( + self, + merge_weights: bool | None, + merge_mode: str | None, + ) -> str: + if merge_mode is None: + if merge_weights is not None: + merge_mode = "merge" if merge_weights else "dynamic" + else: + merge_mode = self.server_args.lora_merge_mode + if merge_mode not in LORA_MERGE_MODES: + raise ValueError( + f"Invalid LoRA merge mode: {merge_mode}. Valid modes: {LORA_MERGE_MODES}" + ) + return merge_mode + + def _should_merge_lora_for_layers( + self, + module_name: str, + lora_layers: dict[str, BaseLayerWithLoRA], + merge_mode: str, + ) -> bool: + if merge_mode == "dynamic": + return False + uses_dtensor_weights = self._uses_dtensor_weights(lora_layers) + if merge_mode == "auto": + if uses_dtensor_weights: + logger.info( + "Using dynamic LoRA for %s because FSDP-sharded weights would require a full-gather merge.", + module_name, + ) + return False + return True + if uses_dtensor_weights: + logger.warning( + "Merging LoRA for %s with FSDP-sharded weights may require full-gather and can OOM.", + module_name, + ) + return True + def _apply_lora_to_layers( self, lora_layers: dict[str, BaseLayerWithLoRA], @@ -516,14 +581,18 @@ def _apply_lora_to_layers( def is_lora_effective(self, target: str = "all") -> bool: """ - Check if LoRA is currently effective (merged) for the specified target. + Check if LoRA is currently effective for the specified target. Args: - target: Which transformer to check. "all" returns True if any is merged. + target: Which transformer to check. "all" returns True if any is effective. """ - if target == "all": - return any(self.is_lora_merged.values()) - return self.is_lora_merged.get(target, False) + target_modules, error = self._get_target_lora_layers(target) + if error: + logger.warning("is_lora_effective: %s", error) + return any( + self._is_lora_effective_for_module(module_name, lora_layers_dict) + for module_name, lora_layers_dict in target_modules + ) def is_lora_set(self, target: str = "all") -> bool: """ @@ -622,12 +691,15 @@ def set_lora( lora_path: str | None | list[str | None] = None, target: str | list[str] = "all", strength: float | list[float] = 1.0, - merge_weights: bool = True, + merge_weights: bool | None = None, + merge_mode: str | None = None, ): # type: ignore """ Load LoRA adapter(s) into the pipeline and apply them to the specified transformer(s). Supports both single LoRA (backward compatible) and multiple LoRA adapters. """ + merge_mode = self._resolve_lora_merge_mode(merge_weights, merge_mode) + # Normalize inputs to lists for multi-LoRA support lora_nicknames, lora_paths, strengths, targets = self._normalize_lora_params( lora_nickname, lora_path, strength, target @@ -700,15 +772,34 @@ def set_lora( # Skip if LoRA configuration matches exactly (including order and strength) # Since all modules for the same target apply the same config, checking one is sufficient - first_module_name, _ = target_modules[0] + first_module_name, first_lora_layers_dict = target_modules[0] + first_effective_merge_weights = self._should_merge_lora_for_layers( + first_module_name, first_lora_layers_dict, merge_mode + ) + if not first_effective_merge_weights and len(tgt_nicknames) > 1: + raise ValueError( + "Dynamic LoRA currently supports only one adapter per target. " + "Use merge_mode='merge' for multiple adapters." + ) if self._check_lora_config_matches( - first_module_name, tgt_nicknames, tgt_strengths, adapter_updated + first_module_name, + tgt_nicknames, + tgt_strengths, + first_effective_merge_weights, + adapter_updated, ): logger.info("LoRA configuration matches exactly, skipping") continue # Apply LoRA to modules for this target for module_name, lora_layers_dict in target_modules: + effective_merge_weights = ( + first_effective_merge_weights + if module_name == first_module_name + else self._should_merge_lora_for_layers( + module_name, lora_layers_dict, merge_mode + ) + ) count = self._apply_lora_to_layers( lora_layers_dict, tgt_nicknames, @@ -716,7 +807,7 @@ def set_lora( rank, tgt_strengths, clear_existing=True, - merge_weights=merge_weights, + merge_weights=effective_merge_weights, ) adapted_count += count self.cur_adapter_name[module_name] = merged_name @@ -724,7 +815,7 @@ def set_lora( str(p or self.loaded_adapter_paths.get(n, "")) for n, p in zip(tgt_nicknames, tgt_paths) ) - self.is_lora_merged[module_name] = merge_weights + self.is_lora_merged[module_name] = effective_merge_weights self.cur_adapter_strength[module_name] = tgt_strengths[0] # Store full configuration for multi-LoRA support (preserves order and all strengths) self.cur_adapter_config[module_name] = ( @@ -733,7 +824,7 @@ def set_lora( ) logger.info( - "Rank %d: LoRA adapter(s) %s applied to %d layers (targets: %s, strengths: %s, merge_weights=%s)", + "Rank %d: LoRA adapter(s) %s applied to %d layers (targets: %s, strengths: %s, merge_mode=%s)", rank, ", ".join(map(str, lora_paths)) if lora_paths else None, adapted_count, @@ -743,7 +834,7 @@ def set_lora( if len(strengths) > 1 else f"{strengths[0]:.2f}" ), - merge_weights, + merge_mode, ) def deactivate_lora_weights(self, target: str = "all") -> None: @@ -799,6 +890,24 @@ def merge_lora_weights(self, target: str = "all", strength: float = 1.0) -> None # Disable layerwise offload if enabled: load all layers to GPU with self._temporarily_disable_offload(target_modules=target_modules): for module_name, lora_layers_dict in target_modules: + if not self._should_merge_lora_for_layers( + module_name, lora_layers_dict, self.server_args.lora_merge_mode + ): + for layer in lora_layers_dict.values(): + if layer.lora_A is None: + continue + if layer.merged: + layer.unmerge_lora_weights() + layer.disable_lora = False + layer.strength = strength + self.is_lora_merged[module_name] = False + self.cur_adapter_strength[module_name] = strength + logger.info( + "Dynamic LoRA activated for %s (strength: %s)", + module_name, + strength, + ) + continue if self.is_lora_merged.get(module_name, False): # Check if strength is the same - if so, skip (idempotent) if self.cur_adapter_strength.get(module_name) == strength: @@ -815,13 +924,9 @@ def merge_lora_weights(self, target: str = "all", strength: float = 1.0) -> None ) for name, layer in lora_layers_dict.items(): # Only re-enable LoRA for layers that actually have LoRA weights - has_lora_weights = ( - hasattr(layer, "lora_A") and layer.lora_A is not None - ) - if not has_lora_weights: + if layer.lora_A is None: continue - if hasattr(layer, "disable_lora"): - layer.disable_lora = False + layer.disable_lora = False try: layer.merge_lora_weights(strength=strength) except Exception as e: @@ -854,9 +959,17 @@ def unmerge_lora_weights(self, target: str = "all") -> None: for module_name, lora_layers_dict in target_modules: if not self.is_lora_merged.get(module_name, False): - logger.warning( - "LoRA weights are not merged for %s, skipping", module_name - ) + if self._has_active_unmerged_lora(lora_layers_dict): + for layer in lora_layers_dict.values(): + if not layer.disable_lora: + layer.disable_lora = True + self.cur_adapter_strength.pop(module_name, None) + self.cur_adapter_config.pop(module_name, None) + logger.info("Unmerged LoRA weights deactivated for %s", module_name) + else: + logger.warning( + "LoRA weights are not merged for %s, skipping", module_name + ) continue with self._temporarily_disable_offload(target_modules=target_modules): for name, layer in lora_layers_dict.items(): @@ -897,7 +1010,14 @@ def get_lora_status(self) -> dict[str, Any]: def _module_status(module_name: str) -> list[dict] | None: # return list of dict to support multi-lora in the future - if not self.is_lora_merged.get(module_name, False): + if module_name == "transformer": + lora_layers = self.lora_layers + elif module_name == "transformer_2": + lora_layers = self.lora_layers_transformer_2 + else: + lora_layers = self.lora_layers_critic + + if not self._is_lora_effective_for_module(module_name, lora_layers): return None else: return [ @@ -905,6 +1025,11 @@ def _module_status(module_name: str) -> list[dict] | None: "nickname": self.cur_adapter_name.get(module_name, None), "path": self.cur_adapter_path.get(module_name, None), "merged": self.is_lora_merged.get(module_name, False), + "mode": ( + "merged" + if self.is_lora_merged.get(module_name, False) + else "unmerged" + ), "strength": self.cur_adapter_strength.get(module_name, None), } ] diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index d98cfddd3503..d5043fe27ec2 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -41,6 +41,10 @@ AttentionBackendEnum, current_platform, ) +from sglang.multimodal_gen.runtime.server_args_auto_tune import ( + PERFORMANCE_MODES, + ServerArgsAutoTuner, +) from sglang.multimodal_gen.runtime.utils.common import ( is_port_available, is_valid_ipv6_address, @@ -62,21 +66,11 @@ logger = init_logger(__name__) -# Derived from single-H200 benchmarking (~140.4 GiB total) at the maximum -# supported 720p workloads with dit_layerwise_offload=False and -# num_inference_steps=1: -# - Wan-AI/Wan2.2-T2V-A14B-Diffusers, 1280x720, 81 frames: -# peak_reserved=108076 MB (~105.5 GiB), peak_allocated=97665 MB (~95.4 GiB) -# - OpenMOSS-Team/MOVA-720p, 1280x720, 193 frames: -# peak_reserved=130264 MB (~127.2 GiB), peak_allocated=108819 MB (~106.3 GiB) -# Also, on H200, enabling dit_layerwise_offload regressed latency noticeably on -# our validated Wan/MOVA workloads, so use a 130 GiB cutoff to keep H200-class -# GPUs on the faster no-offload default while preserving some headroom. -WAN_LAYERWISE_OFFLOAD_AUTO_DISABLE_MEM_GB = 130 LTX2_TWO_STAGE_DEVICE_MODES = ("original", "snapshot", "resident") LTX2_TWO_STAGE_PIPELINE_NAMES = ("LTX2TwoStagePipeline", "LTX2TwoStageHQPipeline") # H200-class GPUs (>=130 GiB total) can usually keep both LTX2 DiTs resident. LTX2_RESIDENT_AUTO_ENABLE_MEM_GB = 130 +LORA_MERGE_MODES = ("auto", "merge", "dynamic") def _normalize_ltx2_two_stage_device_mode(mode: str | None) -> str | None: @@ -148,6 +142,7 @@ class ServerArgs(DisaggArgsMixin): # Parallelism num_gpus: int = 1 + performance_mode: str = "auto" tp_size: Optional[int] = None sp_degree: Optional[int] = None # sequence parallelism @@ -179,6 +174,7 @@ class ServerArgs(DisaggArgsMixin): lora_path: str | None = None lora_nickname: str = "default" # for swapping adapters in the pipeline lora_scale: float = 1.0 # LoRA scale for merging (e.g., 0.125 for Hyper-SD) + lora_merge_mode: str = "auto" lora_weight_name: str | None = None # Component path overrides (key = model_index.json component name, value = path) @@ -197,7 +193,7 @@ class ServerArgs(DisaggArgsMixin): text_encoder_cpu_offload: bool | None = None image_encoder_cpu_offload: bool | None = None vae_cpu_offload: bool | None = False - use_fsdp_inference: bool = False + use_fsdp_inference: bool | None = None pin_cpu_memory: bool = True ltx2_two_stage_device_mode: str | None = None @@ -320,8 +316,15 @@ def _adjust_path(self): def _adjust_parameters(self): """set defaults and normalize values.""" - self._adjust_offload() + auto_tuner = ServerArgsAutoTuner(self) + auto_tuner.adjust() + if auto_tuner.could_override_server_args(): + self._adjust_offload() + auto_tuner.maybe_adjust_auto_dit_layerwise_offload() self._adjust_ltx2_two_stage_device_mode() + if auto_tuner.could_override_server_args(): + auto_tuner.maybe_adjust_auto_component_residency_after_offload() + auto_tuner.maybe_adjust_auto_fsdp_with_offload_enabled() self._adjust_path() self._adjust_quant_config() self._adjust_warmup() @@ -331,6 +334,7 @@ def _adjust_parameters(self): self._adjust_attention_backend() self._adjust_platform_specific() self._adjust_autocast() + auto_tuner.finalize_auto_flags() self.adjust_pipeline_config() def _validate_parameters(self): @@ -477,6 +481,12 @@ def _is_ltx23_two_stage_pipeline(self) -> bool: or is_ltx23_native_variant(self.pipeline_config.vae_config.arch_config) ) + def _uses_ltx23_snapshot_two_stage_residency(self) -> bool: + return ( + self.ltx2_two_stage_device_mode == "snapshot" + and self._is_ltx23_two_stage_pipeline() + ) + def _adjust_attention_backend(self): if self.attention_backend in ["fa3", "fa4"]: self.attention_backend = "fa" @@ -667,13 +677,12 @@ def _adjust_network_ports(self): self.master_port = self.settle_port(self.master_port, 37) def _adjust_parallelism(self): - tp_unspecified = self.tp_size is None sp_unspecified = self.sp_degree is None ulysses_unspecified = self.ulysses_degree is None ring_unspecified = self.ring_degree is None cfg_unspecified = self.enable_cfg_parallel is None - if current_platform.is_cpu() and self.tp_size > 1: + if current_platform.is_cpu() and (self.tp_size or 1) > 1: # CPU platform reuse num_gpus to represent num cpu numa nodes as devices self.num_gpus = self.tp_size @@ -698,7 +707,8 @@ def _adjust_parallelism(self): if cfg_unspecified: cfg_group_size = self.dp_size * self.tp_size * 2 if ( - self.num_gpus >= 2 + self.performance_mode != "manual" + and self.num_gpus >= 2 and self.num_gpus % cfg_group_size == 0 and sp_unspecified and ulysses_unspecified @@ -787,38 +797,6 @@ def _adjust_platform_specific(self): self.use_fsdp_inference = False self.dit_layerwise_offload = False - # automatically enable dit_layerwise_offload for Wan/MOVA models if appropriate - if not envs.SGLANG_CACHE_DIT_ENABLED: - pipeline_name_lower = self.pipeline_config.__class__.__name__.lower() - if ( - "wan" in pipeline_name_lower or "mova" in pipeline_name_lower - ) and self.dit_layerwise_offload is None: - auto_enable_layerwise_offload = ( - current_platform.enable_dit_layerwise_offload_for_wan_by_default() - ) - if auto_enable_layerwise_offload and current_platform.is_cuda(): - device_total_memory_gb = ( - current_platform.get_device_total_memory() / BYTES_PER_GB - ) - if ( - device_total_memory_gb - >= WAN_LAYERWISE_OFFLOAD_AUTO_DISABLE_MEM_GB - ): - logger.info( - "Skipping automatic dit_layerwise_offload for %s on a high-memory CUDA GPU (e.g. H200/B200/B300-class, %.2f GiB total)", - self.pipeline_config.__class__.__name__, - device_total_memory_gb, - ) - auto_enable_layerwise_offload = False - self.dit_layerwise_offload = False - - if auto_enable_layerwise_offload: - logger.info( - f"Automatically enable dit_layerwise_offload for {self.pipeline_config.__class__.__name__} " - "for low memory and performance balance" - ) - self.dit_layerwise_offload = True - def _adjust_autocast(self): if self.disable_autocast is None: self.disable_autocast = not self.pipeline_config.enable_autocast @@ -967,6 +945,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=ServerArgs.num_gpus, help="The number of GPUs to use.", ) + parser.add_argument( + "--performance-mode", + "--mode", + type=str, + choices=PERFORMANCE_MODES, + default=ServerArgs.performance_mode, + help=( + "Preset for performance and memory defaults. " + "'manual' keeps performance-related server args under explicit user control; " + "'auto' keeps safe defaults and applies high-confidence FSDP/CFG improvements; " + "'speed' favors GPU-resident execution for lower latency and higher throughput, and may OOM; " + "'memory' favors lower GPU memory usage; " + "Explicit offload/FSDP/parallelism flags take precedence." + ), + ) parser.add_argument( "--tp-size", @@ -994,9 +987,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) parser.add_argument( "--enable-cfg-parallel", - action="store_true", + action=StoreBoolean, default=None, - help="Enable cfg parallel at degree 2. Auto-enabled when num_gpus >= 2 and no SP flags are set.", + help="Enable cfg parallel at degree 2. Auto-enabled when num_gpus >= 2 and no SP flags are set. Use false to disable it explicitly.", ) parser.add_argument( "--cfg-parallel-size", @@ -1106,7 +1099,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( "--use-fsdp-inference", action=StoreBoolean, - help="Use FSDP for inference by sharding the model weights. Latency is very low due to prefetch--enable if run out of memory.", + help="Use FSDP inference to shard DiT weights across GPUs. For single-GPU memory pressure, prefer CPU or layerwise offload.", ) parser.add_argument( "--text-encoder-cpu-offload", @@ -1270,6 +1263,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=ServerArgs.lora_scale, help="LoRA scale for merging (e.g., 0.125 for Hyper-SD). Same as lora_scale in Diffusers", ) + parser.add_argument( + "--lora-merge-mode", + type=str, + choices=LORA_MERGE_MODES, + default=ServerArgs.lora_merge_mode, + help=( + "How LoRA is applied: auto keeps static merge for regular weights " + "and uses dynamic LoRA for FSDP-sharded weights to avoid full-gather; " + "merge always merges into base weights; dynamic always applies LoRA at forward time." + ), + ) parser.add_argument( "--lora-weight-name", type=str, @@ -1562,6 +1566,8 @@ def get_provided_args( # For '--arg=value', this gets 'arg'; for '--arg', this also gets 'arg'. arg_name = arg.split("=", 1)[0].replace("-", "_").lstrip("_") provided_arg_names.add(arg_name) + if "mode" in provided_arg_names: + provided_arg_names.add("performance_mode") # Populate provided_args if the argument from the namespace was on the command line. for k, v in vars(args).items(): diff --git a/python/sglang/multimodal_gen/runtime/server_args_auto_tune.py b/python/sglang/multimodal_gen/runtime/server_args_auto_tune.py new file mode 100644 index 000000000000..d9b67e235990 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/server_args_auto_tune.py @@ -0,0 +1,370 @@ +""" +ServerArgsAutoTuner tunes the ServerArgs based on the desired performance mode +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from sglang.multimodal_gen import envs +from sglang.multimodal_gen.configs.pipeline_configs.model_deployment_config import ( + ModelDeploymentConfig, +) +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +if TYPE_CHECKING: + from sglang.multimodal_gen.runtime.server_args import ServerArgs + +logger = init_logger(__name__) + +PERFORMANCE_MODES = ("manual", "auto", "speed", "memory") + + +class ServerArgsAutoTuner: + """Auto-tunes the server-arg for the given performance-mode, based on practical deployment experience with different model architectures""" + + def __init__(self, server_args: "ServerArgs"): + self.server_args = server_args + self._explicit_memory_policy = self._has_explicit_memory_policy() + + def _deployment_config(self) -> ModelDeploymentConfig: + return self.server_args.pipeline_config.get_model_deployment_config() + + def adjust(self) -> None: + """Adjust the server args based on the performance mode""" + args = self.server_args + args.performance_mode = self._normalize_performance_mode() + + if current_platform.is_cpu(): + return + + if args.performance_mode == "speed": + logger.info("Applying performance_mode=speed") + if args.num_gpus >= 2 and self._can_apply_fsdp_policy( + require_memory_headroom=False + ): + self._set_gpu_resident_defaults(use_fsdp=True) + self._enable_cfg_parallel_if_supported() + else: + self._set_gpu_resident_defaults(use_fsdp=False) + return + + if args.performance_mode == "memory": + logger.info("Applying performance_mode=memory") + if args.use_fsdp_inference: + self._set_gpu_resident_defaults(use_fsdp=True) + return + args.use_fsdp_inference = False + if self._can_apply_dit_layerwise_offload_policy(): + # apply dit layerwise offload to save VRAM during denoising stage + self._set_layerwise_offload_defaults() + else: + self._set_component_offload_defaults() + return + + def maybe_adjust_auto_component_residency_after_offload(self) -> None: + args = self.server_args + if ( + args.performance_mode != "auto" + or self._explicit_memory_policy + or current_platform.is_cpu() + ): + return + + min_available_gb = self._get_min_available_device_memory_gb() + disable_threshold_gb = ( + self._deployment_config().auto_disable_component_offload_min_available_memory_gb + ) + if ( + min_available_gb is not None + and disable_threshold_gb is not None + and min_available_gb >= disable_threshold_gb + ): + changed = [] + components = ( + self._deployment_config().auto_disable_component_offload_components + ) + if args._uses_ltx23_snapshot_two_stage_residency(): + # ltx2 snapshot mode uses DiT offload to release/prefetch stage DiTs between phases + components = tuple( + component for component in components if component != "dit" + ) + if args.dit_cpu_offload and "dit" in components: + args.dit_cpu_offload = False + changed.append("dit_cpu_offload=False") + if args.text_encoder_cpu_offload and "text_encoder" in components: + args.text_encoder_cpu_offload = False + changed.append("text_encoder_cpu_offload=False") + if args.image_encoder_cpu_offload and "image_encoder" in components: + args.image_encoder_cpu_offload = False + changed.append("image_encoder_cpu_offload=False") + if changed: + logger.info( + "Disabling component offload for %s because minimum available memory on selected GPUs is %.2f GiB: %s", + args.pipeline_config.__class__.__name__, + min_available_gb, + ", ".join(changed), + ) + + def maybe_adjust_auto_fsdp_with_offload_enabled(self) -> None: + args = self.server_args + if ( + args.performance_mode == "auto" + and args.num_gpus >= 2 + and not self._explicit_memory_policy + and self._auto_uses_dit_offload() + and self._can_apply_fsdp_policy(require_memory_headroom=True) + ): + logger.info( + "Automatically selecting FSDP defaults for multi-GPU %s to replace DiT offload", + args.pipeline_config.__class__.__name__, + ) + args.use_fsdp_inference = True + if args.dit_cpu_offload: + args.dit_cpu_offload = False + if args.dit_layerwise_offload: + args.dit_layerwise_offload = False + self._enable_cfg_parallel_if_supported() + + def maybe_adjust_auto_dit_layerwise_offload(self) -> None: + args = self.server_args + if not self.could_override_server_args(): + return + if self._explicit_memory_policy: + return + deployment_config = self._deployment_config() + if envs.SGLANG_CACHE_DIT_ENABLED: + return + if ( + not deployment_config.auto_dit_layerwise_offload + or args.dit_layerwise_offload is not None + ): + return + if args.use_fsdp_inference: + # if fsdp is enabled, layerwise-offload is weakened since the parameter has already been sharded + args.dit_layerwise_offload = False + return + + auto_enable_layerwise_offload = ( + current_platform.enable_dit_layerwise_offload_for_wan_by_default() + ) + disable_threshold_gb = ( + deployment_config.auto_dit_layerwise_offload_high_memory_disable_gb + ) + if ( + auto_enable_layerwise_offload + and current_platform.is_cuda() + and disable_threshold_gb is not None + ): + # auto turn off layerwise-offload if we have sufficient VRAM headroom + device_total_memory_gb = current_platform.get_device_total_memory() / ( + 1 << 30 + ) + if device_total_memory_gb >= disable_threshold_gb: + logger.info( + "Skipping automatic dit_layerwise_offload for %s on a high-memory CUDA GPU (e.g. H200/B200/B300-class, %.2f GiB total)", + args.pipeline_config.__class__.__name__, + device_total_memory_gb, + ) + auto_enable_layerwise_offload = False + args.dit_layerwise_offload = False + + if auto_enable_layerwise_offload: + logger.info( + "Automatically enable dit_layerwise_offload for %s for low memory and performance balance", + args.pipeline_config.__class__.__name__, + ) + args.dit_layerwise_offload = True + args.dit_cpu_offload = False + + def finalize_auto_flags(self) -> None: + """if some args are unset after all the adjustment, set them to defaults""" + if not self.could_override_server_args(): + return + args = self.server_args + if args.use_fsdp_inference is None: + args.use_fsdp_inference = False + if args.dit_cpu_offload is None: + args.dit_cpu_offload = False + if args.dit_layerwise_offload is None: + args.dit_layerwise_offload = False + if args.text_encoder_cpu_offload is None: + args.text_encoder_cpu_offload = False + if args.image_encoder_cpu_offload is None: + args.image_encoder_cpu_offload = False + + def _normalize_performance_mode(self) -> str: + args = self.server_args + mode = (args.performance_mode or "auto").lower() + if mode not in PERFORMANCE_MODES: + valid_modes = PERFORMANCE_MODES + raise ValueError( + f"Invalid performance_mode={args.performance_mode!r}. " + f"Expected one of {valid_modes}." + ) + return mode + + def could_override_server_args(self) -> bool: + return self.server_args.performance_mode != "manual" + + def _set_gpu_resident_defaults(self, *, use_fsdp: bool) -> None: + """set all components to be resident""" + args = self.server_args + changed = [] + if args.use_fsdp_inference is None: + args.use_fsdp_inference = use_fsdp + changed.append(f"use_fsdp_inference={use_fsdp}") + if args.dit_cpu_offload is None: + args.dit_cpu_offload = False + changed.append("dit_cpu_offload=False") + if args.dit_layerwise_offload is None: + args.dit_layerwise_offload = False + changed.append("dit_layerwise_offload=False") + if args.text_encoder_cpu_offload is None: + args.text_encoder_cpu_offload = False + changed.append("text_encoder_cpu_offload=False") + if args.image_encoder_cpu_offload is None: + args.image_encoder_cpu_offload = False + changed.append("image_encoder_cpu_offload=False") + + if changed: + logger.debug( + "Applied GPU-resident performance defaults: %s", ", ".join(changed) + ) + + def _set_component_offload_defaults(self) -> None: + args = self.server_args + changed = [] + if args.dit_cpu_offload is None: + args.dit_cpu_offload = True + changed.append("dit_cpu_offload=True") + if args.text_encoder_cpu_offload is None: + args.text_encoder_cpu_offload = True + changed.append("text_encoder_cpu_offload=True") + if args.image_encoder_cpu_offload is None: + args.image_encoder_cpu_offload = True + changed.append("image_encoder_cpu_offload=True") + if args.use_fsdp_inference is None: + args.use_fsdp_inference = False + changed.append("use_fsdp_inference=False") + + if changed: + logger.info( + "Applied low-memory component offload defaults: %s", + ", ".join(changed), + ) + + def _set_layerwise_offload_defaults(self) -> None: + args = self.server_args + if args.dit_layerwise_offload is None: + args.dit_layerwise_offload = True + if args.dit_cpu_offload is None: + args.dit_cpu_offload = False + if args.text_encoder_cpu_offload is None: + args.text_encoder_cpu_offload = True + if args.image_encoder_cpu_offload is None: + args.image_encoder_cpu_offload = True + + def _can_apply_dit_layerwise_offload_policy(self) -> bool: + return ( + self._deployment_config().auto_dit_layerwise_offload + and not envs.SGLANG_CACHE_DIT_ENABLED + and current_platform.enable_dit_layerwise_offload_for_wan_by_default() + ) + + def _auto_uses_dit_offload(self) -> bool: + args = self.server_args + return bool(args.dit_cpu_offload or args.dit_layerwise_offload) + + def _get_min_available_device_memory_gb(self) -> float | None: + args = self.server_args + if current_platform.is_cpu(): + return None + + # Multi-GPU defaults are limited by the least-free selected GPU. + return min( + current_platform.get_available_gpu_memory( + device_id=device_id, + empty_cache=False, + ) + for device_id in range( + args.base_gpu_id, args.base_gpu_id + max(1, args.num_gpus) + ) + ) + + def _has_explicit_memory_policy(self) -> bool: + args = self.server_args + return ( + args.use_fsdp_inference is not None + or args.dit_cpu_offload is not None + or args.dit_layerwise_offload is not None + or args.text_encoder_cpu_offload is not None + or args.image_encoder_cpu_offload is not None + ) + + def _has_explicit_parallel_policy(self) -> bool: + args = self.server_args + return ( + args.tp_size is not None + or args.sp_degree is not None + or args.ulysses_degree is not None + or args.ring_degree is not None + or args.enable_cfg_parallel is not None + ) + + def _enable_cfg_parallel_if_supported(self) -> None: + args = self.server_args + if ( + args.enable_cfg_parallel is None + and not self._has_explicit_parallel_policy() + and args._model_default_uses_cfg() + ): + args.enable_cfg_parallel = True + + def _supports_high_confidence_fsdp(self) -> bool: + deployment_config = self._deployment_config() + return deployment_config.fsdp_auto_min_available_memory_gb is not None and ( + not deployment_config.fsdp_auto_requires_cfg + or self.server_args._model_default_uses_cfg() + ) + + def _has_enough_available_memory_for_fsdp(self) -> bool: + args = self.server_args + min_available_gb = self._get_min_available_device_memory_gb() + if min_available_gb is None: + return True + + required_gb = self._deployment_config().fsdp_auto_min_available_memory_gb + if required_gb is None: + return False + if min_available_gb < required_gb: + logger.info( + "Skipping automatic FSDP defaults: minimum available memory on selected GPUs %.2f GiB is below %.2f GiB for %s", + min_available_gb, + required_gb, + args.pipeline_config.__class__.__name__, + ) + return False + return True + + def _can_apply_fsdp_policy(self, *, require_memory_headroom: bool) -> bool: + args = self.server_args + deployment_config = self._deployment_config() + if not self._supports_high_confidence_fsdp(): + return False + if envs.SGLANG_CACHE_DIT_ENABLED: + logger.info("Skipping automatic FSDP defaults because cache-dit is enabled") + return False + if ( + args.performance_mode == "auto" + and deployment_config.fsdp_auto_requires_default_parallelism + and self._has_explicit_parallel_policy() + ): + logger.info( + "Skipping automatic FSDP defaults because an explicit parallel policy is set" + ) + return False + return ( + not require_memory_headroom or self._has_enough_available_memory_for_fsdp() + ) diff --git a/python/sglang/multimodal_gen/test/server/perf_baselines.json b/python/sglang/multimodal_gen/test/server/perf_baselines.json index 393a73137d46..8463871dde92 100644 --- a/python/sglang/multimodal_gen/test/server/perf_baselines.json +++ b/python/sglang/multimodal_gen/test/server/perf_baselines.json @@ -1105,7 +1105,7 @@ "TimestepPreparationStage": 7.85, "LTX2AVLatentPreparationStage": 0.34, "LTX2ImageEncodingStage": 0.02, - "LTX2AVDenoisingStage": 7744.44, + "LTX2AVDenoisingStage": 6000.0, "LTX2UpsampleStage": 2.98, "LTX2RefinementStage": 666.08, "LTX2AVDecodingStage": 338.87, @@ -1156,9 +1156,9 @@ "41": 223.51, "42": 221.07 }, - "expected_e2e_ms": 10601.1, - "expected_avg_denoise_ms": 195.04, - "expected_median_denoise_ms": 191.34, + "expected_e2e_ms": 8500.0, + "expected_avg_denoise_ms": 150.0, + "expected_median_denoise_ms": 135.0, "estimated_full_test_time_s": 345.4 }, "wan2_2_ti2v_5b": { @@ -2384,7 +2384,7 @@ "LTX2SigmaPreparationStage": 0.26, "TimestepPreparationStage": 13.63, "LTX2AVLatentPreparationStage": 0.12, - "LTX2AVDenoisingStage": 24658.05, + "LTX2AVDenoisingStage": 18500.0, "LTX2AVDecodingStage": 383.25 }, "denoise_step_ms": { @@ -2419,9 +2419,9 @@ "28": 815.91, "29": 803.65 }, - "expected_e2e_ms": 26917.71, - "expected_avg_denoise_ms": 791.24, - "expected_median_denoise_ms": 791.0, + "expected_e2e_ms": 21000.0, + "expected_avg_denoise_ms": 620.0, + "expected_median_denoise_ms": 620.0, "estimated_full_test_time_s": 144.2 }, "ltx_2.3_two_stage_t2v_2gpus": { diff --git a/python/sglang/multimodal_gen/test/server/test_server_common.py b/python/sglang/multimodal_gen/test/server/test_server_common.py index af04cf68871f..fcdc20bd7065 100644 --- a/python/sglang/multimodal_gen/test/server/test_server_common.py +++ b/python/sglang/multimodal_gen/test/server/test_server_common.py @@ -2,7 +2,7 @@ Config-driven diffusion generation test with pytest parametrization. -If the actual run is significantly better than the baseline, the improved cases with their updated baseline will be printed +Each collected request prints a performance log before validation. """ from __future__ import annotations @@ -225,13 +225,11 @@ class DiffusionServerBase: """ _perf_results: list[dict[str, Any]] = [] - _improved_baselines: list[dict[str, Any]] = [] _pytest_config = None # Store pytest config for stash access @classmethod def setup_class(cls): cls._perf_results = [] - cls._improved_baselines = [] @classmethod def teardown_class(cls): @@ -251,20 +249,6 @@ def teardown_class(cls): "[DEBUG teardown_class] No pytest_config available, skipping stash update" ) - if cls._improved_baselines: - import json - - output = """ ---- POTENTIAL BASELINE IMPROVEMENTS DETECTED --- -The following test cases performed significantly better than their baselines. -Consider updating perf_baselines.json with the snippets below: -""" - for item in cls._improved_baselines: - output += ( - f'\n"{item["id"]}": {json.dumps(item["baseline"], indent=4)},\n' - ) - print(output) - @pytest.fixture(autouse=True) def _capture_pytest_config(self, request): """Capture pytest config for use in teardown_class.""" @@ -351,6 +335,7 @@ def _validate_and_record( ) summary = validator.collect_metrics(perf_record) + self._print_performance_log(case, summary, scenario) if case.run_perf_check: if is_baseline_generation_mode: @@ -365,8 +350,6 @@ def _validate_and_record( ) return - self._check_for_improvement(case, summary, scenario) - # only run performance validation if run_perf_check is True try: validator.validate(perf_record, case.sampling_params.num_frames) @@ -400,83 +383,43 @@ def _validate_and_record( f"[DEBUG _validate_and_record] Appended result for {case.id}, class {self.__class__.__name__} now has {len(self.__class__._perf_results)} results" ) - def _check_for_improvement( + def _print_performance_log( self, case: DiffusionTestCase, summary: PerformanceSummary, - scenario: "ScenarioConfig", + scenario: ScenarioConfig | None, ) -> None: - """Check for potential significant performance improvements and record them.""" - is_improved = False - threshold = BASELINE_CONFIG.improvement_threshold - - def is_sig_faster(actual, expected): - if expected == 0 or expected is None: - return False - return actual < expected * (1 - threshold) - - def safe_get_metric(metric_dict, key): - val = metric_dict.get(key) - return val if val is not None else float("inf") - - # Check for any significant improvement - if ( - is_sig_faster(summary.e2e_ms, scenario.expected_e2e_ms) - or is_sig_faster(summary.avg_denoise_ms, scenario.expected_avg_denoise_ms) - or is_sig_faster( - summary.median_denoise_ms, scenario.expected_median_denoise_ms + lines = [ + "", + f"--- Performance Log: {case.id} ---", + ( + f" e2e={summary.e2e_ms:.2f}ms, " + f"avg_denoise={summary.avg_denoise_ms:.2f}ms, " + f"median_denoise={summary.median_denoise_ms:.2f}ms" + ), + ] + if scenario is not None: + lines.append( + " baseline: " + f"e2e={scenario.expected_e2e_ms:.2f}ms, " + f"avg_denoise={scenario.expected_avg_denoise_ms:.2f}ms, " + f"median_denoise={scenario.expected_median_denoise_ms:.2f}ms" ) - ): - is_improved = True - # Combine metrics, always taking the better (lower) value - new_stages = { - stage: min( - safe_get_metric(summary.stage_metrics, stage), - safe_get_metric(scenario.stages_ms, stage), + if summary.stage_metrics: + stages = ", ".join( + f"{name}={duration:.2f}ms" + for name, duration in summary.stage_metrics.items() ) - for stage in set(summary.stage_metrics) | set(scenario.stages_ms) - } - new_denoise_steps = { - step: min( - safe_get_metric(summary.all_denoise_steps, step), - safe_get_metric(scenario.denoise_step_ms, step), + lines.append(f" stages: {stages}") + if summary.all_denoise_steps: + # ci retries need the exact outlier, not only sampled checkpoints + steps = ", ".join( + f"{idx}={duration:.2f}ms" + for idx, duration in sorted(summary.all_denoise_steps.items()) ) - for step in set(summary.all_denoise_steps.keys()) - | set(scenario.denoise_step_ms) - } - - # Check for stage-level improvements - if not is_improved: - for stage, new_val in new_stages.items(): - if is_sig_faster(new_val, scenario.stages_ms.get(stage, float("inf"))): - is_improved = True - break - if not is_improved: - for step, new_val in new_denoise_steps.items(): - if is_sig_faster( - new_val, scenario.denoise_step_ms.get(step, float("inf")) - ): - is_improved = True - break - - if is_improved: - new_baseline = { - "stages_ms": {k: round(v, 2) for k, v in new_stages.items()}, - "denoise_step_ms": { - str(k): round(v, 2) for k, v in new_denoise_steps.items() - }, - "expected_e2e_ms": round( - min(summary.e2e_ms, scenario.expected_e2e_ms), 2 - ), - "expected_avg_denoise_ms": round( - min(summary.avg_denoise_ms, scenario.expected_avg_denoise_ms), 2 - ), - "expected_median_denoise_ms": round( - min(summary.median_denoise_ms, scenario.expected_median_denoise_ms), - 2, - ), - } - self._improved_baselines.append({"id": case.id, "baseline": new_baseline}) + lines.append(f" denoise_steps: {steps}") + lines.append(f"--- End Performance Log: {case.id} ---") + print("\n".join(lines), flush=True) def _dump_baseline_for_testcase( self, diff --git a/python/sglang/multimodal_gen/test/unit/test_resolve_prompts.py b/python/sglang/multimodal_gen/test/unit/test_resolve_prompts.py index 73e5cdb9eb2e..5b92ad281b29 100644 --- a/python/sglang/multimodal_gen/test/unit/test_resolve_prompts.py +++ b/python/sglang/multimodal_gen/test/unit/test_resolve_prompts.py @@ -51,9 +51,10 @@ def test_prompt_path_multi_line(self): os.unlink(path) def test_prompt_path_takes_priority_over_server_args(self): - with tempfile.NamedTemporaryFile( - "w", suffix=".txt", delete=False - ) as f1, tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f2: + with ( + tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f1, + tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f2, + ): f1.write("from prompt_path\n") f2.write("from server_args\n") path1, path2 = f1.name, f2.name diff --git a/python/sglang/multimodal_gen/test/unit/test_server_args.py b/python/sglang/multimodal_gen/test/unit/test_server_args.py index d0550df1dcad..463577b66148 100644 --- a/python/sglang/multimodal_gen/test/unit/test_server_args.py +++ b/python/sglang/multimodal_gen/test/unit/test_server_args.py @@ -12,9 +12,13 @@ ModelTaskType, PipelineConfig, ) +from sglang.multimodal_gen.configs.pipeline_configs.ltx_2 import LTX2PipelineConfig +from sglang.multimodal_gen.configs.pipeline_configs.mova import MOVAPipelineConfig from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import ( QwenImagePipelineConfig, ) +from sglang.multimodal_gen.configs.pipeline_configs.wan import WanT2V480PConfig +from sglang.multimodal_gen.configs.pipeline_configs.zimage import ZImagePipelineConfig from sglang.multimodal_gen.registry import _get_config_info from sglang.multimodal_gen.runtime.models.dits.qwen_image import ( QwenImageTransformer2DModel, @@ -110,12 +114,34 @@ def test_dynamic_component_attention_backend_cli_args(self): "torch_sdpa", ] - with patch.object(sys, "argv", ["sglang"] + argv): - args, unknown_args = parser.parse_known_args(argv) - with patch.object( + with ( + patch.object(sys, "argv", ["sglang"] + argv), + patch.object( PipelineConfig, "from_kwargs", return_value=QwenImagePipelineConfig() - ): - server_args = ServerArgs.from_cli_args(args, unknown_args) + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.is_cpu", + return_value=False, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.is_mps", + return_value=False, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.is_cuda", + return_value=True, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.get_device_total_memory", + return_value=80 * 1024**3, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.get_available_gpu_memory", + return_value=80, + ), + ): + args, unknown_args = parser.parse_known_args(argv) + server_args = ServerArgs.from_cli_args(args, unknown_args) self.assertEqual( server_args.component_attention_backends, {"text_encoder": "torch_sdpa"} @@ -123,6 +149,50 @@ def test_dynamic_component_attention_backend_cli_args(self): class TestOffloadDefaults(unittest.TestCase): + def _from_dict_with_pipeline_config( + self, + pipeline_config, + *, + memory_gb=80, + available_memory_gb=None, + kwargs=None, + ): + def get_available_gpu_memory(device_id=0, **_kwargs): + if isinstance(available_memory_gb, dict): + return available_memory_gb[device_id] + if available_memory_gb is not None: + return available_memory_gb + return memory_gb + + with ( + patch.object(PipelineConfig, "from_kwargs", return_value=pipeline_config), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.is_cpu", + return_value=False, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.is_mps", + return_value=False, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.is_cuda", + return_value=True, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.enable_dit_layerwise_offload_for_wan_by_default", + return_value=True, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.get_device_total_memory", + return_value=memory_gb * 1024**3, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.get_available_gpu_memory", + side_effect=get_available_gpu_memory, + ), + ): + return ServerArgs.from_dict({"model_path": "/fake", **(kwargs or {})}) + def _from_dict_with_task_type( self, task_type, @@ -142,6 +212,10 @@ def _from_dict_with_task_type( "sglang.multimodal_gen.runtime.server_args.current_platform.get_device_total_memory", return_value=memory_gb * 1024**3, ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.get_available_gpu_memory", + return_value=memory_gb, + ), ): return ServerArgs.from_dict({"model_path": "/fake", **(kwargs or {})}) @@ -151,7 +225,11 @@ def test_vae_cpu_offload_defaults_false_for_video_generation(self): self.assertFalse(args.vae_cpu_offload) def test_vae_cpu_offload_defaults_false_on_low_memory_gpu(self): - args = self._from_dict_with_task_type(ModelTaskType.T2V, memory_gb=16) + args = self._from_dict_with_task_type( + ModelTaskType.T2V, + memory_gb=16, + kwargs={"performance_mode": "memory"}, + ) self.assertFalse(args.vae_cpu_offload) self.assertTrue(args.dit_cpu_offload) @@ -166,6 +244,353 @@ def test_explicit_vae_cpu_offload_true_is_preserved(self): self.assertTrue(args.vae_cpu_offload) + def test_pipeline_configs_declare_auto_tune_hints(self): + qwen_deployment = QwenImagePipelineConfig().get_model_deployment_config() + wan_deployment = WanT2V480PConfig().get_model_deployment_config() + mova_deployment = MOVAPipelineConfig().get_model_deployment_config() + zimage_deployment = ZImagePipelineConfig().get_model_deployment_config() + ltx_deployment = LTX2PipelineConfig().get_model_deployment_config() + + self.assertIsNone(qwen_deployment.fsdp_auto_min_available_memory_gb) + self.assertFalse(qwen_deployment.auto_dit_layerwise_offload) + + self.assertIsNone(wan_deployment.fsdp_auto_min_available_memory_gb) + self.assertTrue(wan_deployment.auto_dit_layerwise_offload) + + self.assertIsNone(mova_deployment.fsdp_auto_min_available_memory_gb) + self.assertTrue(mova_deployment.auto_dit_layerwise_offload) + + self.assertEqual(zimage_deployment.fsdp_auto_min_available_memory_gb, 40) + self.assertTrue(zimage_deployment.fsdp_auto_requires_cfg) + self.assertFalse(zimage_deployment.auto_dit_layerwise_offload) + + self.assertEqual( + ltx_deployment.auto_disable_component_offload_min_available_memory_gb, 70 + ) + self.assertEqual( + ltx_deployment.auto_disable_component_offload_components, ("dit",) + ) + + def test_manual_mode_preserves_unset_performance_args(self): + args = self._from_dict_with_pipeline_config( + QwenImagePipelineConfig(), + kwargs={ + "model_path": "Qwen/Qwen-Image", + "num_gpus": 2, + "performance_mode": "manual", + }, + ) + + self.assertEqual(args.performance_mode, "manual") + self.assertIsNone(args.use_fsdp_inference) + self.assertIsNone(args.dit_cpu_offload) + self.assertIsNone(args.dit_layerwise_offload) + self.assertIsNone(args.text_encoder_cpu_offload) + self.assertIsNone(args.image_encoder_cpu_offload) + self.assertFalse(args.enable_cfg_parallel) + + def test_default_auto_keeps_legacy_single_gpu_offload_defaults(self): + args = self._from_dict_with_pipeline_config( + QwenImagePipelineConfig(), + kwargs={"model_path": "Qwen/Qwen-Image"}, + ) + + self.assertEqual(args.performance_mode, "auto") + self.assertFalse(args.use_fsdp_inference) + self.assertTrue(args.dit_cpu_offload) + self.assertFalse(args.dit_layerwise_offload) + self.assertTrue(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + + def test_auto_ltx_snapshot_keeps_dit_offload_with_headroom(self): + args = self._from_dict_with_pipeline_config( + LTX2PipelineConfig(), + available_memory_gb=76, + kwargs={ + "model_path": "Lightricks/LTX-2.3", + "pipeline_class_name": "LTX2TwoStageHQPipeline", + "ltx2_two_stage_device_mode": "snapshot", + "performance_mode": "auto", + }, + ) + + self.assertEqual(args.ltx2_two_stage_device_mode, "snapshot") + self.assertTrue(args.dit_cpu_offload) + self.assertTrue(args.text_encoder_cpu_offload) + self.assertTrue(args.image_encoder_cpu_offload) + + def test_auto_wan_layerwise_offload_is_enabled_without_fsdp(self): + args = self._from_dict_with_pipeline_config( + WanT2V480PConfig(), + kwargs={"performance_mode": "auto"}, + ) + + self.assertTrue(args.dit_layerwise_offload) + self.assertFalse(args.use_fsdp_inference) + + def test_memory_wan_layerwise_offload_is_enabled_without_fsdp(self): + args = self._from_dict_with_pipeline_config( + WanT2V480PConfig(), + kwargs={"performance_mode": "memory"}, + ) + + self.assertTrue(args.dit_layerwise_offload) + self.assertFalse(args.use_fsdp_inference) + + def test_auto_wan_layerwise_offload_does_not_disable_explicit_fsdp(self): + args = self._from_dict_with_pipeline_config( + WanT2V480PConfig(), + kwargs={ + "model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + "num_gpus": 2, + "performance_mode": "auto", + "use_fsdp_inference": True, + }, + ) + + self.assertFalse(args.dit_layerwise_offload) + self.assertTrue(args.use_fsdp_inference) + + def test_auto_multi_gpu_wan_uses_layerwise_offload_without_cfg(self): + with patch.object(ServerArgs, "_model_default_uses_cfg", return_value=False): + args = self._from_dict_with_pipeline_config( + WanT2V480PConfig(), + kwargs={ + "model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + "num_gpus": 2, + "performance_mode": "auto", + }, + ) + + self.assertFalse(args.use_fsdp_inference) + self.assertFalse(args.enable_cfg_parallel) + self.assertFalse(args.dit_cpu_offload) + self.assertTrue(args.dit_layerwise_offload) + + def test_auto_multi_gpu_qwen_keeps_legacy_offload_with_cfg(self): + args = self._from_dict_with_pipeline_config( + QwenImagePipelineConfig(), + kwargs={ + "model_path": "Qwen/Qwen-Image", + "num_gpus": 2, + "performance_mode": "auto", + }, + ) + + self.assertFalse(args.use_fsdp_inference) + self.assertTrue(args.enable_cfg_parallel) + self.assertTrue(args.dit_cpu_offload) + self.assertFalse(args.dit_layerwise_offload) + self.assertTrue(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + + def test_auto_multi_gpu_zimage_base_prefers_fsdp(self): + args = self._from_dict_with_pipeline_config( + ZImagePipelineConfig(), + kwargs={ + "model_path": "Tongyi-MAI/Z-Image", + "num_gpus": 2, + "performance_mode": "auto", + }, + ) + + self.assertTrue(args.use_fsdp_inference) + self.assertTrue(args.enable_cfg_parallel) + + def test_auto_multi_gpu_zimage_turbo_skips_fsdp(self): + args = self._from_dict_with_pipeline_config( + ZImagePipelineConfig(), + kwargs={ + "model_path": "Tongyi-MAI/Z-Image-Turbo", + "num_gpus": 2, + "performance_mode": "auto", + }, + ) + + self.assertFalse(args.use_fsdp_inference) + self.assertFalse(args.enable_cfg_parallel) + + def test_auto_multi_gpu_qwen_preserves_explicit_fsdp_false(self): + args = self._from_dict_with_pipeline_config( + QwenImagePipelineConfig(), + kwargs={ + "model_path": "Qwen/Qwen-Image", + "num_gpus": 2, + "performance_mode": "auto", + "use_fsdp_inference": False, + }, + ) + + self.assertFalse(args.use_fsdp_inference) + self.assertTrue(args.enable_cfg_parallel) + self.assertTrue(args.dit_cpu_offload) + self.assertTrue(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + + def test_auto_multi_gpu_qwen_skips_fsdp_when_available_memory_is_low(self): + args = self._from_dict_with_pipeline_config( + QwenImagePipelineConfig(), + memory_gb=50, + kwargs={ + "model_path": "Qwen/Qwen-Image", + "num_gpus": 2, + "performance_mode": "auto", + }, + ) + + self.assertFalse(args.use_fsdp_inference) + self.assertTrue(args.enable_cfg_parallel) + self.assertTrue(args.dit_cpu_offload) + self.assertTrue(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + + def test_auto_multi_gpu_qwen_uses_selected_gpu_min_available_memory(self): + args = self._from_dict_with_pipeline_config( + QwenImagePipelineConfig(), + available_memory_gb={1: 50, 2: 80}, + kwargs={ + "model_path": "Qwen/Qwen-Image", + "base_gpu_id": 1, + "num_gpus": 2, + "performance_mode": "auto", + }, + ) + + self.assertFalse(args.use_fsdp_inference) + self.assertTrue(args.enable_cfg_parallel) + + def test_auto_multi_gpu_qwen_keeps_legacy_offload_with_headroom(self): + args = self._from_dict_with_pipeline_config( + QwenImagePipelineConfig(), + available_memory_gb={1: 72, 2: 80}, + kwargs={ + "model_path": "Qwen/Qwen-Image", + "base_gpu_id": 1, + "num_gpus": 2, + "performance_mode": "auto", + }, + ) + + self.assertFalse(args.use_fsdp_inference) + self.assertTrue(args.enable_cfg_parallel) + self.assertTrue(args.dit_cpu_offload) + self.assertTrue(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + + def test_speed_mode_single_gpu_disables_offload(self): + args = self._from_dict_with_pipeline_config( + QwenImagePipelineConfig(), + kwargs={ + "model_path": "Qwen/Qwen-Image", + "performance_mode": "speed", + }, + ) + + self.assertEqual(args.performance_mode, "speed") + self.assertFalse(args.use_fsdp_inference) + self.assertFalse(args.dit_cpu_offload) + self.assertFalse(args.dit_layerwise_offload) + self.assertFalse(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + + def test_speed_mode_preserves_explicit_offload(self): + args = self._from_dict_with_pipeline_config( + QwenImagePipelineConfig(), + kwargs={ + "model_path": "Qwen/Qwen-Image", + "performance_mode": "speed", + "dit_cpu_offload": True, + }, + ) + + self.assertEqual(args.performance_mode, "speed") + self.assertTrue(args.dit_cpu_offload) + self.assertFalse(args.text_encoder_cpu_offload) + self.assertFalse(args.image_encoder_cpu_offload) + + def test_memory_mode_wan_uses_layerwise_offload(self): + args = self._from_dict_with_pipeline_config( + WanT2V480PConfig(), + kwargs={ + "model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + "performance_mode": "memory", + }, + ) + + self.assertFalse(args.use_fsdp_inference) + self.assertTrue(args.dit_layerwise_offload) + self.assertFalse(args.dit_cpu_offload) + self.assertTrue(args.text_encoder_cpu_offload) + self.assertTrue(args.image_encoder_cpu_offload) + + def test_memory_mode_preserves_explicit_fsdp(self): + args = self._from_dict_with_pipeline_config( + WanT2V480PConfig(), + kwargs={ + "model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + "num_gpus": 2, + "performance_mode": "memory", + "use_fsdp_inference": True, + }, + ) + + self.assertTrue(args.use_fsdp_inference) + self.assertFalse(args.dit_layerwise_offload) + self.assertFalse(args.dit_cpu_offload) + + def test_invalid_performance_mode_raises(self): + with self.assertRaises(ValueError): + self._from_dict_with_pipeline_config( + QwenImagePipelineConfig(), + kwargs={"performance_mode": "turbo"}, + ) + + def test_cfg_parallel_cli_can_be_disabled_explicitly(self): + parser = FlexibleArgumentParser() + ServerArgs.add_cli_args(parser) + argv = [ + "--model-path", + "Qwen/Qwen-Image", + "--num-gpus", + "2", + "--performance-mode", + "auto", + "--enable-cfg-parallel", + "false", + ] + + with ( + patch.object(sys, "argv", ["sglang"] + argv), + patch.object( + PipelineConfig, "from_kwargs", return_value=QwenImagePipelineConfig() + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.is_cpu", + return_value=False, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.is_mps", + return_value=False, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.is_cuda", + return_value=True, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.get_device_total_memory", + return_value=80 * 1024**3, + ), + patch( + "sglang.multimodal_gen.runtime.server_args.current_platform.get_available_gpu_memory", + return_value=80, + ), + ): + args, unknown_args = parser.parse_known_args(argv) + server_args = ServerArgs.from_cli_args(args, unknown_args) + + self.assertFalse(server_args.use_fsdp_inference) + self.assertFalse(server_args.enable_cfg_parallel) + class TestFSDPShardConditions(unittest.TestCase): def test_helpers_match_only_direct_block_entries(self): diff --git a/python/sglang/srt/configs/chatglm.py b/python/sglang/srt/configs/chatglm.py index 9370c218aab8..dce801196f57 100644 --- a/python/sglang/srt/configs/chatglm.py +++ b/python/sglang/srt/configs/chatglm.py @@ -43,7 +43,7 @@ def __init__( quantization_bit=0, pre_seq_len=None, prefix_projection=False, - **kwargs + **kwargs, ): self.num_layers = num_layers self.vocab_size = padded_vocab_size diff --git a/python/sglang/srt/configs/dots_ocr.py b/python/sglang/srt/configs/dots_ocr.py index 8b0693b8e9cc..031972a95fcc 100644 --- a/python/sglang/srt/configs/dots_ocr.py +++ b/python/sglang/srt/configs/dots_ocr.py @@ -16,7 +16,7 @@ def __init__( video_token_id=151656, vision_config: Optional[dict] = None, *args, - **kwargs + **kwargs, ): super().__init__(*args, **kwargs) self.image_token_id = image_token_id @@ -42,7 +42,7 @@ def __init__( tokenizer=None, video_processor=None, chat_template=None, - **kwargs + **kwargs, ): if video_processor is None: video_processor = DummyVideoProcessor() diff --git a/python/sglang/srt/configs/exaone.py b/python/sglang/srt/configs/exaone.py index f5d91c45cf84..5f8349d23cd2 100644 --- a/python/sglang/srt/configs/exaone.py +++ b/python/sglang/srt/configs/exaone.py @@ -161,7 +161,7 @@ def __init__( bos_token_id=0, eos_token_id=2, tie_word_embeddings=True, - **kwargs + **kwargs, ): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -192,5 +192,5 @@ def __init__( bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, - **kwargs + **kwargs, ) diff --git a/python/sglang/srt/configs/kimi_vl.py b/python/sglang/srt/configs/kimi_vl.py index 3c7d20f5944d..db88e07fa7fc 100644 --- a/python/sglang/srt/configs/kimi_vl.py +++ b/python/sglang/srt/configs/kimi_vl.py @@ -18,7 +18,7 @@ def __init__( ignore_index: int = -100, media_placeholder_token_id: int = 163605, pad_token_id: int = 0, - **kwargs + **kwargs, ): if vision_config is None: vision_config = MoonViTConfig() diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index c42760ce5b72..30d186b7d5cc 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1559,9 +1559,10 @@ def graph_capture(stream: Optional[torch.cuda.Stream] = None): in order to explicitly distinguish the kernels to capture from other kernels possibly launched on background in the default stream. """ - with get_tp_group().graph_capture( - stream=stream - ) as context, get_pp_group().graph_capture(context): + with ( + get_tp_group().graph_capture(stream=stream) as context, + get_pp_group().graph_capture(context), + ): with contextlib.ExitStack() as stack: seen = {id(_TP)} for group in (_MOE_EP, _MOE_TP): diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 670dec46de2a..cc35d2e3c1d9 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -621,8 +621,9 @@ def _launch_scheduler_processes( writer, ), ) - with memory_saver_adapter.configure_subprocess(), numa_utils.configure_subprocess( - server_args, gpu_id + with ( + memory_saver_adapter.configure_subprocess(), + numa_utils.configure_subprocess(server_args, gpu_id), ): proc.start() diff --git a/python/sglang/srt/entrypoints/openai/tool_server.py b/python/sglang/srt/entrypoints/openai/tool_server.py index 269d9e99eae4..71f20c341b3e 100644 --- a/python/sglang/srt/entrypoints/openai/tool_server.py +++ b/python/sglang/srt/entrypoints/openai/tool_server.py @@ -19,9 +19,10 @@ async def list_server_and_tools(server_url: str): - async with sse_client(url=server_url) as streams, ClientSession( - *streams - ) as session: + async with ( + sse_client(url=server_url) as streams, + ClientSession(*streams) as session, + ): initialize_response = await session.initialize() list_tools_response = await session.list_tools() return initialize_response, list_tools_response @@ -131,9 +132,10 @@ def get_tool_description(self, tool_name: str): async def get_tool_session(self, tool_name: str): url = self.urls.get(tool_name) if url: - async with sse_client(url=url) as streams, ClientSession( - *streams - ) as session: + async with ( + sse_client(url=url) as streams, + ClientSession(*streams) as session, + ): await session.initialize() yield session else: diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py index 9a377d5d9ade..f1a4ebf0923b 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py @@ -108,11 +108,14 @@ def _capture_graph(self, graph, pool, stream, run_once_fn): else: skip_guard_context = empty_context() - with skip_guard_context, torch.npu.graph( - graph, - pool=pool, - stream=stream, - auto_dispatch_capture=True, + with ( + skip_guard_context, + torch.npu.graph( + graph, + pool=pool, + stream=stream, + auto_dispatch_capture=True, + ), ): out = run_once_fn() return out diff --git a/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py b/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py index 89b17ce2db15..8cd397e30396 100644 --- a/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +++ b/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py @@ -395,7 +395,7 @@ def _decode_grouped_att_m_fwd_rope( IS_NEOX_STYLE=is_neox_style, num_warps=4, num_stages=num_stages, - **extra_kargs + **extra_kargs, ) diff --git a/python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py index 04f7cd246ac0..f4510def79b3 100644 --- a/python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +++ b/python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -47,7 +47,7 @@ def create_weights( input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, - **kwargs + **kwargs, ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 57632e146d94..a64cbdb01aa8 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -545,8 +545,9 @@ def launch_tensor_parallel_group( writer, ), ) - with memory_saver_adapter.configure_subprocess(), numa_utils.configure_subprocess( - server_args, gpu_id + with ( + memory_saver_adapter.configure_subprocess(), + numa_utils.configure_subprocess(server_args, gpu_id), ): proc.start() self.scheduler_procs.append(proc) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index ad3f650ad7ba..43758f1335cd 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -249,10 +249,13 @@ def __init__( maybe_init_custom_mem_pool(device=self.device) ) - with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE), ( - torch.cuda.use_mem_pool(self.custom_mem_pool) - if self.enable_custom_mem_pool - else nullcontext() + with ( + self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE), + ( + torch.cuda.use_mem_pool(self.custom_mem_pool) + if self.enable_custom_mem_pool + else nullcontext() + ), ): conv_state = [ torch.zeros( diff --git a/python/sglang/srt/model_executor/breakable_cuda_graph_runner.py b/python/sglang/srt/model_executor/breakable_cuda_graph_runner.py index 69426f29f359..08141a90b10b 100644 --- a/python/sglang/srt/model_executor/breakable_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/breakable_cuda_graph_runner.py @@ -308,9 +308,11 @@ def _warmup(self): def _capture_all(self): """Capture breakable CUDA graphs for all token sizes.""" - with freeze_gc( - self.model_runner.server_args.enable_cudagraph_gc - ), graph_capture() as graph_capture_context, enable_breakable_cuda_graph(): + with ( + freeze_gc(self.model_runner.server_args.enable_cudagraph_gc), + graph_capture() as graph_capture_context, + enable_breakable_cuda_graph(), + ): stream = graph_capture_context.stream pool = get_global_graph_memory_pool() diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index 661d1728c848..621fcc395bac 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -488,9 +488,10 @@ def capture(self) -> None: # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. - with freeze_gc( - self.model_runner.server_args.enable_cudagraph_gc - ), graph_capture() as graph_capture_context: + with ( + freeze_gc(self.model_runner.server_args.enable_cudagraph_gc), + graph_capture() as graph_capture_context, + ): stream = graph_capture_context.stream with set_pcg_capture_stream(stream): avail_mem = get_available_gpu_memory( diff --git a/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py b/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py index 3a9edd0e4864..3a08a3c6170a 100644 --- a/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py +++ b/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py @@ -42,7 +42,7 @@ async def process_mm_data_async( request_obj, max_req_input_len, *args, - **kwargs + **kwargs, ): base_output = self.load_mm_data( input_text, diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 10a38a98e6e6..0a50b9182b72 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -218,9 +218,11 @@ def __init__( self.eagle_use_aux_hidden_state = eagle_config.get( "use_aux_hidden_state", True ) - with self.draft_tp_context( - self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_tp_context(self.draft_model_runner.tp_group), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): self.init_attention_backend() self.init_cuda_graphs() if self.adaptive_controller is not None: @@ -459,9 +461,11 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul seq_lens_cpu, can_run_cuda_graph, ) = self.forward_target_extend(batch) - with self.draft_tp_context( - self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_tp_context(self.draft_model_runner.tp_group), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): self.forward_draft_extend( batch, logits_output.hidden_states, @@ -478,9 +482,11 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul else: set_time_batch(batch.reqs, "set_spec_draft_start_time", trace_only=True) - with self.draft_tp_context( - self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_tp_context(self.draft_model_runner.tp_group), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): verify_input = self.draft(batch) set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True) @@ -502,9 +508,11 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul batch.reqs, "set_spec_draft_extend_start_time", trace_only=True ) - with self.draft_tp_context( - self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_tp_context(self.draft_model_runner.tp_group), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): # NOTE: We should use `check_forward_draft_extend_after_decode` # when DP attention is enabled, but it is slow. Skip it for now. draft_extend_input = verify_output.draft_extend_input diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 2403aab4c8fe..e845fde3f2ce 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -181,9 +181,11 @@ def __init__( self.draft_tp_context = ( draft_tp_context if server_args.enable_dp_attention else empty_context ) - with self.draft_tp_context( - self.draft_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_tp_context(self.draft_runner.tp_group), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): self.init_attention_backend() self.init_cuda_graphs() @@ -706,9 +708,13 @@ def __init__( # Build adaptive runtime states (must be after draft worker is fully initialized) if self.adaptive_controller is not None: - with self._draft_worker.draft_tp_context( - self._draft_worker.draft_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self._draft_worker.draft_tp_context( + self._draft_worker.draft_runner.tp_group + ), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): self.adaptive_controller.register( SpecRuntimeState( speculative_num_steps=self.speculative_num_steps, @@ -758,9 +764,13 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): else CaptureHiddenMode.LAST ) model_worker_batch.capture_hidden_mode = draft_capture_mode - with self.draft_worker.draft_tp_context( - self.draft_worker.draft_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_worker.draft_tp_context( + self.draft_worker.draft_runner.tp_group + ), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): batch_output.next_draft_input = ( self.draft_worker._draft_extend_for_prefill( model_worker_batch, @@ -784,9 +794,13 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): topk=self.topk, capture_hidden_mode=capture_mode, ) - with self.draft_worker.draft_tp_context( - self.draft_worker.draft_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_worker.draft_tp_context( + self.draft_worker.draft_runner.tp_group + ), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): verify_input: EagleVerifyInput = self.draft_worker.draft( model_worker_batch ) @@ -800,9 +814,13 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): self._draft_done_event.record() model_worker_batch.spec_info = verify_input batch_output = self.verify(model_worker_batch) - with self.draft_worker.draft_tp_context( - self.draft_worker.draft_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_worker.draft_tp_context( + self.draft_worker.draft_runner.tp_group + ), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): self.draft_worker._draft_extend_for_decode( model_worker_batch, batch_output ) diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index 988797c113d0..02d1454b3282 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -179,9 +179,11 @@ def __init__( self.draft_model_runner.draft_attn_backend = self.draft_attn_backend self.cuda_graph_runner = None - with self.draft_tp_context( - self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_tp_context(self.draft_model_runner.tp_group), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): self.init_cuda_graphs() @property @@ -422,9 +424,11 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul seq_lens_cpu, can_run_cuda_graph, ) = self.forward_target_extend(batch) - with self.draft_tp_context( - self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_tp_context(self.draft_model_runner.tp_group), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): self.forward_draft_extend( batch, logits_output.hidden_states, @@ -440,9 +444,11 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul ) set_time_batch(batch.reqs, "set_spec_draft_start_time", trace_only=True) - with self.draft_tp_context( - self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_tp_context(self.draft_model_runner.tp_group), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): verify_input = self.draft(batch) set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True) set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True) @@ -458,9 +464,11 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul ) set_time_batch(batch.reqs, "set_spec_draft_extend_start_time", trace_only=True) - with self.draft_tp_context( - self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_tp_context(self.draft_model_runner.tp_group), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): draft_extend_input = verify_output.draft_extend_input if ( self.server_args.enable_dp_attention diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index 190012974769..64cb2ff5da74 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -193,9 +193,10 @@ def __init__( self.draft_tp_context = ( draft_tp_context if server_args.enable_dp_attention else empty_context ) - with self.draft_tp_context( - self.mtp_model_runner(0).tp_group - ), speculative_moe_backend_context(): + with ( + self.draft_tp_context(self.mtp_model_runner(0).tp_group), + speculative_moe_backend_context(), + ): self.init_attention_backend() self.init_cuda_graphs() @@ -265,9 +266,10 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul seq_lens_cpu, can_run_cuda_graph, ) = self.forward_target_extend(batch) - with self.draft_tp_context( - self.mtp_model_runner(0).tp_group - ), speculative_moe_backend_context(): + with ( + self.draft_tp_context(self.mtp_model_runner(0).tp_group), + speculative_moe_backend_context(), + ): self.forward_draft_extend( batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu ) @@ -278,16 +280,18 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul can_run_cuda_graph=can_run_cuda_graph, ) else: - with self.draft_tp_context( - self.mtp_model_runner(0).tp_group - ), speculative_moe_backend_context(): + with ( + self.draft_tp_context(self.mtp_model_runner(0).tp_group), + speculative_moe_backend_context(), + ): verify_input = self.draft(batch) batch.spec_info = verify_input logits_output, verify_output, can_run_cuda_graph = self.verify(batch) - with self.draft_tp_context( - self.mtp_model_runner(0).tp_group - ), speculative_moe_backend_context(): + with ( + self.draft_tp_context(self.mtp_model_runner(0).tp_group), + speculative_moe_backend_context(), + ): # NOTE: We should use `check_forward_draft_extend_after_decode` # when DP attention is enabled, but it is slow. Skip it for now. draft_extend_input = verify_output.draft_extend_input diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py index a47d8742632d..4650619e9695 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py @@ -172,9 +172,10 @@ def __init__( self.draft_tp_context = ( draft_tp_context if server_args.enable_dp_attention else empty_context ) - with self.draft_tp_context( - self.draft_runner_list[0].tp_group - ), speculative_moe_backend_context(): + with ( + self.draft_tp_context(self.draft_runner_list[0].tp_group), + speculative_moe_backend_context(), + ): self.init_attention_backend() self.init_cuda_graphs() diff --git a/python/sglang/srt/speculative/standalone_worker.py b/python/sglang/srt/speculative/standalone_worker.py index a67e4196fea0..12d10533f37b 100644 --- a/python/sglang/srt/speculative/standalone_worker.py +++ b/python/sglang/srt/speculative/standalone_worker.py @@ -77,7 +77,11 @@ def __init__( self.hot_token_id = None # Init draft worker - with empty_context(), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + empty_context(), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): TpModelWorker.__init__( self, server_args=server_args, @@ -102,9 +106,11 @@ def __init__( self.draft_tp_context = ( draft_tp_context if server_args.enable_dp_attention else empty_context ) - with self.draft_tp_context( - self.draft_model_runner.tp_group - ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + with ( + self.draft_tp_context(self.draft_model_runner.tp_group), + speculative_moe_backend_context(), + speculative_moe_a2a_backend_context(), + ): self.init_attention_backend() self.init_cuda_graphs() diff --git a/python/sglang/srt/speculative/standalone_worker_v2.py b/python/sglang/srt/speculative/standalone_worker_v2.py index d79fd09a755a..dacf2ae565d0 100644 --- a/python/sglang/srt/speculative/standalone_worker_v2.py +++ b/python/sglang/srt/speculative/standalone_worker_v2.py @@ -116,9 +116,10 @@ def __init__( self.draft_tp_context = ( draft_tp_context if server_args.enable_dp_attention else empty_context ) - with self.draft_tp_context( - self.draft_runner.tp_group - ), speculative_moe_backend_context(): + with ( + self.draft_tp_context(self.draft_runner.tp_group), + speculative_moe_backend_context(), + ): self.init_attention_backend() self.init_cuda_graphs() self.tree_mask_mode = TreeMaskMode.FULL_MASK diff --git a/python/sglang/test/server_fixtures/eagle_fixture.py b/python/sglang/test/server_fixtures/eagle_fixture.py index d3201c08736f..512f6d9767de 100644 --- a/python/sglang/test/server_fixtures/eagle_fixture.py +++ b/python/sglang/test/server_fixtures/eagle_fixture.py @@ -37,9 +37,10 @@ class EagleServerBase(CustomTestCase): @classmethod def setUpClass(cls): cls.base_url = DEFAULT_URL_FOR_TEST - with envs.SGLANG_SPEC_NAN_DETECTION.override( - True - ), envs.SGLANG_SPEC_OOB_DETECTION.override(True): + with ( + envs.SGLANG_SPEC_NAN_DETECTION.override(True), + envs.SGLANG_SPEC_OOB_DETECTION.override(True), + ): cls.process = popen_launch_server( cls.target_model, cls.base_url, diff --git a/python/sglang/test/simple_eval_common.py b/python/sglang/test/simple_eval_common.py index b9e4057fa74c..56f8b5cd341b 100644 --- a/python/sglang/test/simple_eval_common.py +++ b/python/sglang/test/simple_eval_common.py @@ -528,13 +528,16 @@ def download_dataset(path, url): total_size = int(response.headers.get("content-length", 0)) block_size = 8192 - with open(path, "wb") as f, tqdm( - desc="Downloading", - total=total_size, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as progress_bar: + with ( + open(path, "wb") as f, + tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar, + ): for data in response.iter_content(block_size): size = f.write(data) progress_bar.update(size) diff --git a/python/sglang/test/test_mm_utils.py b/python/sglang/test/test_mm_utils.py index bc8fc63de4d7..c526c43246ba 100644 --- a/python/sglang/test/test_mm_utils.py +++ b/python/sglang/test/test_mm_utils.py @@ -30,12 +30,12 @@ def test_materialize_proxy(self): model_specific_data={"image_grid_thw": [[1, 1, 1], [1, 1, 1]]}, ) - with patch.object( - schedule_batch.torch.cuda, "is_available", return_value=True - ), patch.object( - schedule_batch.torch.cuda, "current_device", return_value=0 - ), patch.object( - schedule_batch.envs.SGLANG_MM_BUFFER_SIZE_MB, "get", return_value=0 + with ( + patch.object(schedule_batch.torch.cuda, "is_available", return_value=True), + patch.object(schedule_batch.torch.cuda, "current_device", return_value=0), + patch.object( + schedule_batch.envs.SGLANG_MM_BUFFER_SIZE_MB, "get", return_value=0 + ), ): mm_inputs = MultimodalInputs.from_dict({"mm_items": [mm_item]}) diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 7cc322a23117..33bc175deffc 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -426,13 +426,16 @@ def download_and_cache_file(url: str, filename: Optional[str] = None): chunk_size = 1024 # Download in chunks of 1KB # Use tqdm to display the progress bar - with open(filename, "wb") as f, tqdm( - desc=filename, - total=total_size, - unit="B", - unit_scale=True, - unit_divisor=1024, - ) as bar: + with ( + open(filename, "wb") as f, + tqdm( + desc=filename, + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as bar, + ): for chunk in response.iter_content(chunk_size=chunk_size): f.write(chunk) bar.update(len(chunk)) diff --git a/sgl-model-gateway/e2e_test/infra/simple_eval_common.py b/sgl-model-gateway/e2e_test/infra/simple_eval_common.py index 7be4358172c7..93767f3ef9c0 100644 --- a/sgl-model-gateway/e2e_test/infra/simple_eval_common.py +++ b/sgl-model-gateway/e2e_test/infra/simple_eval_common.py @@ -463,14 +463,17 @@ def download_dataset(path: str, url: str) -> None: total_size = int(response.headers.get("content-length", 0)) block_size = 8192 - with open(path, "wb") as f, tqdm( - desc="Downloading", - total=total_size, - unit="iB", - unit_scale=True, - unit_divisor=1024, - leave=False, - ) as progress_bar: + with ( + open(path, "wb") as f, + tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + leave=False, + ) as progress_bar, + ): for data in response.iter_content(block_size): size = f.write(data) progress_bar.update(size) diff --git a/test/manual/kv_transfer/test_mooncake_transfer_engine_init.py b/test/manual/kv_transfer/test_mooncake_transfer_engine_init.py index 33c92328adf2..64a45e2b0159 100755 --- a/test/manual/kv_transfer/test_mooncake_transfer_engine_init.py +++ b/test/manual/kv_transfer/test_mooncake_transfer_engine_init.py @@ -59,12 +59,15 @@ def _fake_init_mooncake_transfer_engine(*, hostname, gpu_id, ib_device): ib_device=ib_device, ) - with patch( - "sglang.srt.distributed.device_communicators.mooncake_transfer_engine.init_mooncake_transfer_engine", - side_effect=_fake_init_mooncake_transfer_engine, - ), patch( - "sglang.srt.model_executor.model_runner.get_local_ip_auto", - return_value="127.0.0.1", + with ( + patch( + "sglang.srt.distributed.device_communicators.mooncake_transfer_engine.init_mooncake_transfer_engine", + side_effect=_fake_init_mooncake_transfer_engine, + ), + patch( + "sglang.srt.model_executor.model_runner.get_local_ip_auto", + return_value="127.0.0.1", + ), ): ModelRunner.init_shared_mooncake_transfer_engine(dummy_runner) diff --git a/test/manual/test_tokenizer_batch_encode.py b/test/manual/test_tokenizer_batch_encode.py index 8d6e7539d332..31b9dd9c40d8 100644 --- a/test/manual/test_tokenizer_batch_encode.py +++ b/test/manual/test_tokenizer_batch_encode.py @@ -30,11 +30,13 @@ def setUp(self): ) self.port_args = PortArgs.init_new(self.server_args) - with patch("zmq.asyncio.Context"), patch( - "sglang.srt.utils.get_zmq_socket" - ), patch( - "sglang.srt.utils.hf_transformers_utils.get_tokenizer" - ) as mock_tokenizer: + with ( + patch("zmq.asyncio.Context"), + patch("sglang.srt.utils.get_zmq_socket"), + patch( + "sglang.srt.utils.hf_transformers_utils.get_tokenizer" + ) as mock_tokenizer, + ): mock_tokenizer.return_value = Mock(vocab_size=32000) self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args) diff --git a/test/manual/test_tokenizer_manager.py b/test/manual/test_tokenizer_manager.py index d0febc75e530..fb1df645449e 100644 --- a/test/manual/test_tokenizer_manager.py +++ b/test/manual/test_tokenizer_manager.py @@ -37,11 +37,13 @@ def setUp(self): self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST) self.port_args = PortArgs.init_new(self.server_args) - with patch("zmq.asyncio.Context"), patch( - "sglang.srt.utils.network.get_zmq_socket" - ), patch( - "sglang.srt.utils.hf_transformers_utils.get_tokenizer" - ) as mock_tokenizer: + with ( + patch("zmq.asyncio.Context"), + patch("sglang.srt.utils.network.get_zmq_socket"), + patch( + "sglang.srt.utils.hf_transformers_utils.get_tokenizer" + ) as mock_tokenizer, + ): mock_tokenizer.return_value = Mock(vocab_size=32000) self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args) @@ -133,11 +135,13 @@ def setUp(self): self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST) self.port_args = PortArgs.init_new(self.server_args) - with patch("zmq.asyncio.Context"), patch( - "sglang.srt.utils.network.get_zmq_socket" - ), patch( - "sglang.srt.utils.hf_transformers_utils.get_tokenizer" - ) as mock_tokenizer: + with ( + patch("zmq.asyncio.Context"), + patch("sglang.srt.utils.network.get_zmq_socket"), + patch( + "sglang.srt.utils.hf_transformers_utils.get_tokenizer" + ) as mock_tokenizer, + ): mock_tokenizer.return_value = Mock(vocab_size=32000) self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args) @@ -191,11 +195,13 @@ def setUp(self): self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST) self.port_args = PortArgs.init_new(self.server_args) - with patch("zmq.asyncio.Context"), patch( - "sglang.srt.utils.network.get_zmq_socket" - ), patch( - "sglang.srt.utils.hf_transformers_utils.get_tokenizer" - ) as mock_tokenizer: + with ( + patch("zmq.asyncio.Context"), + patch("sglang.srt.utils.network.get_zmq_socket"), + patch( + "sglang.srt.utils.hf_transformers_utils.get_tokenizer" + ) as mock_tokenizer, + ): mock_tokenizer.return_value = Mock(vocab_size=32000) self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args) @@ -313,11 +319,13 @@ def setUp(self): self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST) self.port_args = PortArgs.init_new(self.server_args) - with patch("zmq.asyncio.Context"), patch( - "sglang.srt.utils.network.get_zmq_socket" - ), patch( - "sglang.srt.utils.hf_transformers_utils.get_tokenizer" - ) as mock_tokenizer: + with ( + patch("zmq.asyncio.Context"), + patch("sglang.srt.utils.network.get_zmq_socket"), + patch( + "sglang.srt.utils.hf_transformers_utils.get_tokenizer" + ) as mock_tokenizer, + ): mock_tokenizer.return_value = Mock(vocab_size=32000) self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args) diff --git a/test/registered/bench_fn/test_benchmark_datasets_api.py b/test/registered/bench_fn/test_benchmark_datasets_api.py index d807b50335b0..0b980369829b 100644 --- a/test/registered/bench_fn/test_benchmark_datasets_api.py +++ b/test/registered/bench_fn/test_benchmark_datasets_api.py @@ -405,12 +405,15 @@ def test_dataset_mapping_and_dispatch(self): fake_mmmu_dataset = _FakeMMMUDataset( [{"image_1": Image.new("RGB", (4, 4), color="white"), "question": "q"}] ) - with patch( - "sglang.benchmark.datasets.mmmu.get_processor", - return_value=self.processor, - ), patch( - "sglang.benchmark.datasets.mmmu.load_dataset", - return_value=fake_mmmu_dataset, + with ( + patch( + "sglang.benchmark.datasets.mmmu.get_processor", + return_value=self.processor, + ), + patch( + "sglang.benchmark.datasets.mmmu.load_dataset", + return_value=fake_mmmu_dataset, + ), ): mmmu_args = make_args(dataset_name="mmmu", num_prompts=1) mmmu_rows = get_dataset(mmmu_args, self.tokenizer, model_id="dummy-model") diff --git a/test/registered/distributed/test_parallel_state.py b/test/registered/distributed/test_parallel_state.py index c3610a274b46..0c3f7eb620af 100644 --- a/test/registered/distributed/test_parallel_state.py +++ b/test/registered/distributed/test_parallel_state.py @@ -73,20 +73,16 @@ def test_parallel_group_construction_tp8_attn_cp2(): # Mock the distributed backend # Note: get_rank() returns 0 because we're testing from a single process, # but initialize_model_parallel() still creates all groups for all ranks - with patch.object(parallel_state, "_WORLD", None), patch.object( - parallel_state, "_TP", None - ), patch.object(parallel_state, "_ATTN_CP", None), patch.object( - parallel_state, "_ATTN_TP", None - ), patch.object( - parallel_state, "_PP", None - ), patch( - "torch.distributed.is_initialized", return_value=True - ), patch( - "torch.distributed.get_world_size", return_value=world_size - ), patch( - "torch.distributed.get_rank", return_value=0 - ), patch( - "torch.distributed.get_backend", return_value="nccl" + with ( + patch.object(parallel_state, "_WORLD", None), + patch.object(parallel_state, "_TP", None), + patch.object(parallel_state, "_ATTN_CP", None), + patch.object(parallel_state, "_ATTN_TP", None), + patch.object(parallel_state, "_PP", None), + patch("torch.distributed.is_initialized", return_value=True), + patch("torch.distributed.get_world_size", return_value=world_size), + patch("torch.distributed.get_rank", return_value=0), + patch("torch.distributed.get_backend", return_value="nccl"), ): # Mock init_model_parallel_group to capture the groups being created @@ -101,11 +97,14 @@ def mock_init_model_parallel_group(group_ranks, local_rank, backend, **kwargs): mock_group.device_group = Mock() return mock_group - with patch.object( - parallel_state, - "init_model_parallel_group", - side_effect=mock_init_model_parallel_group, - ), patch.object(parallel_state, "get_world_group") as mock_world_group: + with ( + patch.object( + parallel_state, + "init_model_parallel_group", + side_effect=mock_init_model_parallel_group, + ), + patch.object(parallel_state, "get_world_group") as mock_world_group, + ): # Mock world group mock_world = Mock() @@ -173,22 +172,17 @@ def test_parallel_group_construction_tp8_moe_ep4_cp2(): world_size = 8 # Mock the distributed backend - with patch.object(parallel_state, "_WORLD", None), patch.object( - parallel_state, "_TP", None - ), patch.object(parallel_state, "_MOE_EP", None), patch.object( - parallel_state, "_MOE_DP", None - ), patch.object( - parallel_state, "_MOE_TP", None - ), patch.object( - parallel_state, "_PP", None - ), patch( - "torch.distributed.is_initialized", return_value=True - ), patch( - "torch.distributed.get_world_size", return_value=world_size - ), patch( - "torch.distributed.get_rank", return_value=0 - ), patch( - "torch.distributed.get_backend", return_value="nccl" + with ( + patch.object(parallel_state, "_WORLD", None), + patch.object(parallel_state, "_TP", None), + patch.object(parallel_state, "_MOE_EP", None), + patch.object(parallel_state, "_MOE_DP", None), + patch.object(parallel_state, "_MOE_TP", None), + patch.object(parallel_state, "_PP", None), + patch("torch.distributed.is_initialized", return_value=True), + patch("torch.distributed.get_world_size", return_value=world_size), + patch("torch.distributed.get_rank", return_value=0), + patch("torch.distributed.get_backend", return_value="nccl"), ): # Mock init_model_parallel_group to capture the groups being created @@ -203,11 +197,14 @@ def mock_init_model_parallel_group(group_ranks, local_rank, backend, **kwargs): mock_group.device_group = Mock() return mock_group - with patch.object( - parallel_state, - "init_model_parallel_group", - side_effect=mock_init_model_parallel_group, - ), patch.object(parallel_state, "get_world_group") as mock_world_group: + with ( + patch.object( + parallel_state, + "init_model_parallel_group", + side_effect=mock_init_model_parallel_group, + ), + patch.object(parallel_state, "get_world_group") as mock_world_group, + ): # Mock world group mock_world = Mock() diff --git a/test/registered/lora/test_lora_qwen3_8b_logprob_diff.py b/test/registered/lora/test_lora_qwen3_8b_logprob_diff.py index 4c0e8e1f382d..cd6f59fe7c3b 100644 --- a/test/registered/lora/test_lora_qwen3_8b_logprob_diff.py +++ b/test/registered/lora/test_lora_qwen3_8b_logprob_diff.py @@ -113,11 +113,15 @@ def test_auto_detect_lora_target_modules(self): of internal param names that would break LoRA auto-detection.""" model = _build_qwen3_mock() - with patch("sglang.srt.layers.linear.LinearBase", _MockLinearBase), patch( - "sglang.srt.layers.moe.fused_moe_triton.layer.FusedMoE", _MockFusedMoE - ), patch( - "sglang.srt.layers.vocab_parallel_embedding.ParallelLMHead", - _MockParallelLMHead, + with ( + patch("sglang.srt.layers.linear.LinearBase", _MockLinearBase), + patch( + "sglang.srt.layers.moe.fused_moe_triton.layer.FusedMoE", _MockFusedMoE + ), + patch( + "sglang.srt.layers.vocab_parallel_embedding.ParallelLMHead", + _MockParallelLMHead, + ), ): detected = auto_detect_lora_target_modules(model) diff --git a/test/registered/models/test_generation_models.py b/test/registered/models/test_generation_models.py index 64d5246f1229..437b4d49af77 100644 --- a/test/registered/models/test_generation_models.py +++ b/test/registered/models/test_generation_models.py @@ -161,14 +161,17 @@ def assert_close_logits_and_output_strs( ) as hf_runner: hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens) - with env_ctx, SRTRunner( - model_path, - tp_size=model_case.tp_size, - torch_dtype=torch_dtype, - model_type="generation", - trust_remote_code=model_case.trust_remote_code, - attention_backend=model_case.attention_backend, - ) as srt_runner: + with ( + env_ctx, + SRTRunner( + model_path, + tp_size=model_case.tp_size, + torch_dtype=torch_dtype, + model_type="generation", + trust_remote_code=model_case.trust_remote_code, + attention_backend=model_case.attention_backend, + ) as srt_runner, + ): srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens) check_close_model_outputs( diff --git a/test/registered/rl/test_fp32_lm_head.py b/test/registered/rl/test_fp32_lm_head.py index 8b721cc4e485..63b0481741ba 100644 --- a/test/registered/rl/test_fp32_lm_head.py +++ b/test/registered/rl/test_fp32_lm_head.py @@ -86,8 +86,9 @@ def probe_linear(x, w, bias=None): state.update(called=True, ooperationp="linear", a=x.dtype, b=w.dtype) return original_linear(x, w, bias) - with patch("torch.matmul", new=probe_matmul), patch( - "torch.nn.functional.linear", new=probe_linear + with ( + patch("torch.matmul", new=probe_matmul), + patch("torch.nn.functional.linear", new=probe_linear), ): logits = logprocessor._get_logits(hidden_state, head, meta) self.assertEqual(hidden_state.dtype, hidden_state_dtype) diff --git a/test/registered/scheduler/test_retract_decode.py b/test/registered/scheduler/test_retract_decode.py index 56ed5b0d417d..81858ac6a071 100644 --- a/test/registered/scheduler/test_retract_decode.py +++ b/test/registered/scheduler/test_retract_decode.py @@ -31,9 +31,10 @@ def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST launch_args = ["--chunked-prefill-size", "128"] + cls.other_args - with envs.SGLANG_TEST_RETRACT.override( - True - ), envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(1): + with ( + envs.SGLANG_TEST_RETRACT.override(True), + envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(1), + ): cls.process = popen_launch_server( cls.model, cls.base_url, diff --git a/test/registered/scheduler/test_scheduler_control.py b/test/registered/scheduler/test_scheduler_control.py index ea77c8023d83..aa8351ddbebc 100644 --- a/test/registered/scheduler/test_scheduler_control.py +++ b/test/registered/scheduler/test_scheduler_control.py @@ -331,9 +331,10 @@ class TestAbortWithRunningTimeout(CustomTestCase): def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - with envs.SGLANG_REQ_RUNNING_TIMEOUT.override( - 0.001 - ), envs.SGLANG_ENABLE_HEALTH_ENDPOINT_GENERATION.override(False): + with ( + envs.SGLANG_REQ_RUNNING_TIMEOUT.override(0.001), + envs.SGLANG_ENABLE_HEALTH_ENDPOINT_GENERATION.override(False), + ): cls.process = popen_launch_server( cls.model, cls.base_url, diff --git a/test/registered/sessions/test_streaming_session.py b/test/registered/sessions/test_streaming_session.py index f565f7742ebf..3c6901c006a9 100644 --- a/test/registered/sessions/test_streaming_session.py +++ b/test/registered/sessions/test_streaming_session.py @@ -796,9 +796,10 @@ class TestStreamingSessionRetractMixedChunk(TestStreamingSession): def setUpClass(cls): cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - with envs.SGLANG_TEST_RETRACT.override( - True - ), envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(2): + with ( + envs.SGLANG_TEST_RETRACT.override(True), + envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(2), + ): cls.process = popen_launch_server( cls.model, cls.base_url, @@ -825,9 +826,10 @@ class TestStreamingSessionRetractLargePage(TestStreamingSession): def setUpClass(cls): cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - with envs.SGLANG_TEST_RETRACT.override( - True - ), envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(2): + with ( + envs.SGLANG_TEST_RETRACT.override(True), + envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(2), + ): cls.process = popen_launch_server( cls.model, cls.base_url, @@ -856,9 +858,10 @@ class TestStreamingSessionEagle(TestStreamingSession): def setUpClass(cls): cls.model = DEFAULT_TARGET_MODEL_EAGLE3 cls.base_url = DEFAULT_URL_FOR_TEST - with envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override( - 2 - ), envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.override(True): + with ( + envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(2), + envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.override(True), + ): cls.process = popen_launch_server( cls.model, cls.base_url, @@ -897,12 +900,10 @@ class TestStreamingSessionEagleV2(TestStreamingSession): def setUpClass(cls): cls.model = DEFAULT_TARGET_MODEL_EAGLE3 cls.base_url = DEFAULT_URL_FOR_TEST - with envs.SGLANG_ENABLE_SPEC_V2.override( - True - ), envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override( - 2 - ), envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.override( - True + with ( + envs.SGLANG_ENABLE_SPEC_V2.override(True), + envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(2), + envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.override(True), ): cls.process = popen_launch_server( cls.model, @@ -944,12 +945,10 @@ class TestStreamingSessionEagleRetractLargePage(TestStreamingSession): def setUpClass(cls): cls.model = DEFAULT_TARGET_MODEL_EAGLE3 cls.base_url = DEFAULT_URL_FOR_TEST - with envs.SGLANG_TEST_RETRACT.override( - True - ), envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override( - 2 - ), envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.override( - True + with ( + envs.SGLANG_TEST_RETRACT.override(True), + envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(2), + envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.override(True), ): cls.process = popen_launch_server( cls.model, @@ -991,14 +990,11 @@ class TestStreamingSessionEagleV2RetractLargePage(TestStreamingSession): def setUpClass(cls): cls.model = DEFAULT_TARGET_MODEL_EAGLE3 cls.base_url = DEFAULT_URL_FOR_TEST - with envs.SGLANG_ENABLE_SPEC_V2.override( - True - ), envs.SGLANG_TEST_RETRACT.override( - True - ), envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override( - 2 - ), envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.override( - True + with ( + envs.SGLANG_ENABLE_SPEC_V2.override(True), + envs.SGLANG_TEST_RETRACT.override(True), + envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(2), + envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.override(True), ): cls.process = popen_launch_server( cls.model, diff --git a/test/registered/sessions/test_streaming_session_swa.py b/test/registered/sessions/test_streaming_session_swa.py index 76e124783874..787224cf535c 100644 --- a/test/registered/sessions/test_streaming_session_swa.py +++ b/test/registered/sessions/test_streaming_session_swa.py @@ -70,9 +70,10 @@ class TestStreamingSessionSWARetractLargePage(TestStreamingSession): def setUpClass(cls): cls.model = SWA_MODEL cls.base_url = DEFAULT_URL_FOR_TEST - with envs.SGLANG_TEST_RETRACT.override( - True - ), envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(2): + with ( + envs.SGLANG_TEST_RETRACT.override(True), + envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(2), + ): cls.process = popen_launch_server( cls.model, cls.base_url, @@ -100,9 +101,10 @@ class TestStreamingSessionSWARetractMixedChunk(TestStreamingSession): def setUpClass(cls): cls.model = SWA_MODEL cls.base_url = DEFAULT_URL_FOR_TEST - with envs.SGLANG_TEST_RETRACT.override( - True - ), envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(2): + with ( + envs.SGLANG_TEST_RETRACT.override(True), + envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(2), + ): cls.process = popen_launch_server( cls.model, cls.base_url, diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py index d393e4019cdc..b64605e18c17 100644 --- a/test/registered/spec/dflash/test_dflash.py +++ b/test/registered/spec/dflash/test_dflash.py @@ -53,12 +53,10 @@ def setUpClass(cls): old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" try: - with envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override( - 1 - ), envs.SGLANG_SPEC_NAN_DETECTION.override( - True - ), envs.SGLANG_SPEC_OOB_DETECTION.override( - True + with ( + envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(1), + envs.SGLANG_SPEC_NAN_DETECTION.override(True), + envs.SGLANG_SPEC_OOB_DETECTION.override(True), ): cls.process = popen_launch_server( cls.model, diff --git a/test/registered/spec/eagle/test_deepseek_v3_fp4_mtp_small.py b/test/registered/spec/eagle/test_deepseek_v3_fp4_mtp_small.py index 46a7cf5e55b9..ae6ae80f94c5 100644 --- a/test/registered/spec/eagle/test_deepseek_v3_fp4_mtp_small.py +++ b/test/registered/spec/eagle/test_deepseek_v3_fp4_mtp_small.py @@ -49,9 +49,10 @@ def setUpClass(cls): "--model-loader-extra-config", '{"enable_multithread_load": true,"num_threads": 64}', ] - with envs.SGLANG_SPEC_NAN_DETECTION.override( - True - ), envs.SGLANG_SPEC_OOB_DETECTION.override(True): + with ( + envs.SGLANG_SPEC_NAN_DETECTION.override(True), + envs.SGLANG_SPEC_OOB_DETECTION.override(True), + ): cls.process = popen_launch_server( cls.model, cls.base_url, diff --git a/test/registered/spec/eagle/test_eagle_constrained_decoding.py b/test/registered/spec/eagle/test_eagle_constrained_decoding.py index 05536be5e494..2862d60d3622 100644 --- a/test/registered/spec/eagle/test_eagle_constrained_decoding.py +++ b/test/registered/spec/eagle/test_eagle_constrained_decoding.py @@ -59,12 +59,10 @@ def setUpClass(cls): cls.grammar_backend, ] launch_args.extend(cls.other_launch_args) - with envs.SGLANG_ENABLE_SPEC_V2.override( - cls.spec_v2 - ), envs.SGLANG_SPEC_NAN_DETECTION.override( - True - ), envs.SGLANG_SPEC_OOB_DETECTION.override( - True + with ( + envs.SGLANG_ENABLE_SPEC_V2.override(cls.spec_v2), + envs.SGLANG_SPEC_NAN_DETECTION.override(True), + envs.SGLANG_SPEC_OOB_DETECTION.override(True), ): cls.process = popen_launch_server( cls.model, diff --git a/test/registered/spec/eagle/test_eagle_dp_attention.py b/test/registered/spec/eagle/test_eagle_dp_attention.py index 4fa87f1d542b..33a2b5d91e34 100644 --- a/test/registered/spec/eagle/test_eagle_dp_attention.py +++ b/test/registered/spec/eagle/test_eagle_dp_attention.py @@ -57,9 +57,10 @@ def setUpClass(cls): "--cuda-graph-max-bs", "64", ] - with envs.SGLANG_SPEC_NAN_DETECTION.override( - True - ), envs.SGLANG_SPEC_OOB_DETECTION.override(True): + with ( + envs.SGLANG_SPEC_NAN_DETECTION.override(True), + envs.SGLANG_SPEC_OOB_DETECTION.override(True), + ): cls.process = popen_launch_server( cls.model, cls.base_url, diff --git a/test/registered/spec/eagle/test_eagle_infer_beta.py b/test/registered/spec/eagle/test_eagle_infer_beta.py index 436ac2001e5e..6fff766f4b1b 100644 --- a/test/registered/spec/eagle/test_eagle_infer_beta.py +++ b/test/registered/spec/eagle/test_eagle_infer_beta.py @@ -63,14 +63,11 @@ def setUpClass(cls): *[str(i) for i in range(1, cls.max_running_requests + 1)], ] launch_args.extend(cls.other_launch_args) - with envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override( - 1 - ), envs.SGLANG_SPEC_NAN_DETECTION.override( - True - ), envs.SGLANG_SPEC_OOB_DETECTION.override( - True - ), envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.override( - True + with ( + envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override(1), + envs.SGLANG_SPEC_NAN_DETECTION.override(True), + envs.SGLANG_SPEC_OOB_DETECTION.override(True), + envs.SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN.override(True), ): cls.process = popen_launch_server( cls.model, diff --git a/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention.py b/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention.py index b0dacf4525a7..68c86f8c9199 100644 --- a/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention.py +++ b/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention.py @@ -65,9 +65,10 @@ def setUpClass(cls): "--speculative-num-draft-tokens", "4", ] - with envs.SGLANG_SPEC_NAN_DETECTION.override( - True - ), envs.SGLANG_SPEC_OOB_DETECTION.override(True): + with ( + envs.SGLANG_SPEC_NAN_DETECTION.override(True), + envs.SGLANG_SPEC_OOB_DETECTION.override(True), + ): cls.process = popen_launch_server( cls.model, cls.base_url, diff --git a/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention_large.py b/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention_large.py index c64acb19cddd..0c1d63ec9490 100644 --- a/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention_large.py +++ b/test/registered/spec/eagle/test_eagle_infer_beta_dp_attention_large.py @@ -73,9 +73,10 @@ def setUpClass(cls): "--model-loader-extra-config", '{"enable_multithread_load": true,"num_threads": 64}', ] - with envs.SGLANG_SPEC_NAN_DETECTION.override( - True - ), envs.SGLANG_SPEC_OOB_DETECTION.override(True): + with ( + envs.SGLANG_SPEC_NAN_DETECTION.override(True), + envs.SGLANG_SPEC_OOB_DETECTION.override(True), + ): cls.process = popen_launch_server( cls.model, cls.base_url, diff --git a/test/registered/spec/test_constrained_decoding_spec_reasoning.py b/test/registered/spec/test_constrained_decoding_spec_reasoning.py index 18225dc25411..1cd4dc60f9dd 100644 --- a/test/registered/spec/test_constrained_decoding_spec_reasoning.py +++ b/test/registered/spec/test_constrained_decoding_spec_reasoning.py @@ -51,9 +51,10 @@ def setUpClass(cls): "--speculative-num-draft-tokens=8", ] - with envs.SGLANG_SPEC_NAN_DETECTION.override( - True - ), envs.SGLANG_SPEC_OOB_DETECTION.override(True): + with ( + envs.SGLANG_SPEC_NAN_DETECTION.override(True), + envs.SGLANG_SPEC_OOB_DETECTION.override(True), + ): cls.process = popen_launch_server( cls.model, cls.base_url, diff --git a/test/registered/unit/entrypoints/openai/test_serving_chat.py b/test/registered/unit/entrypoints/openai/test_serving_chat.py index 5e4b9f81e3b3..608d2e76b1ad 100644 --- a/test/registered/unit/entrypoints/openai/test_serving_chat.py +++ b/test/registered/unit/entrypoints/openai/test_serving_chat.py @@ -119,9 +119,12 @@ def setUp(self): # ------------- conversion tests ------------- def test_convert_to_internal_request_single(self): - with patch( - "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv" - ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock: + with ( + patch( + "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv" + ) as conv_mock, + patch.object(self.chat, "_process_messages") as proc_mock, + ): conv_ins = Mock() conv_ins.get_prompt.return_value = "Test prompt" conv_ins.image_data = conv_ins.audio_data = None diff --git a/test/registered/unit/observability/test_trace.py b/test/registered/unit/observability/test_trace.py index 7fce827adfc1..6a03324fc8b7 100644 --- a/test/registered/unit/observability/test_trace.py +++ b/test/registered/unit/observability/test_trace.py @@ -123,32 +123,38 @@ def test_generates_nonzero_ids(self): # __get_host_id class TestGetHostId(unittest.TestCase): def test_from_machine_id_file(self): - with patch("os.path.exists", return_value=True), patch( - "builtins.open", - unittest.mock.mock_open(read_data="abc123\n"), + with ( + patch("os.path.exists", return_value=True), + patch( + "builtins.open", + unittest.mock.mock_open(read_data="abc123\n"), + ), ): self.assertEqual(_get_host_id(), "abc123") def test_from_machine_id_file_error(self): """Falls back to MAC address when file read fails.""" - with patch("os.path.exists", return_value=True), patch( - "builtins.open", side_effect=IOError("read error") + with ( + patch("os.path.exists", return_value=True), + patch("builtins.open", side_effect=IOError("read error")), ): result = _get_host_id() self.assertIsInstance(result, str) self.assertGreater(len(result), 0) def test_from_mac_address(self): - with patch("os.path.exists", return_value=False), patch( - "uuid.getnode", return_value=0x112233445566 + with ( + patch("os.path.exists", return_value=False), + patch("uuid.getnode", return_value=0x112233445566), ): result = _get_host_id() self.assertIsInstance(result, str) self.assertGreater(len(result), 0) def test_unknown_fallback(self): - with patch("os.path.exists", return_value=False), patch( - "uuid.getnode", return_value=0 + with ( + patch("os.path.exists", return_value=False), + patch("uuid.getnode", return_value=0), ): self.assertEqual(_get_host_id(), "unknown") diff --git a/test/registered/unit/tools/test_get_version_tag.py b/test/registered/unit/tools/test_get_version_tag.py index b6114fef115b..71370ad1e02d 100644 --- a/test/registered/unit/tools/test_get_version_tag.py +++ b/test/registered/unit/tools/test_get_version_tag.py @@ -50,11 +50,14 @@ def test_parse_version_tuple_sorts_stable_above_rc_and_post_above_stable(self): ) def test_exact_version_tag_takes_precedence_over_latest_tag(self): - with patch.object( - self.version_helper, "get_exact_version_tag", return_value="v0.5.9" - ), patch.object( - self.version_helper, "get_latest_version_tag_describe" - ) as latest_describe: + with ( + patch.object( + self.version_helper, "get_exact_version_tag", return_value="v0.5.9" + ), + patch.object( + self.version_helper, "get_latest_version_tag_describe" + ) as latest_describe, + ): self.assertEqual(self.version_helper.get_version_describe(), "v0.5.9") latest_describe.assert_not_called() @@ -68,15 +71,16 @@ def test_pyprojects_use_describe_mode_for_setuptools_scm(self): self.assertIn(FALLBACK_VERSION, content) def test_tag_only_cli_mode_remains_available_for_callers_that_need_latest_tag(self): - with patch.object( - sys, "argv", ["get_version_tag.py", "--tag-only"] - ), patch.object( - self.version_helper, "get_latest_version_tag", return_value="v0.5.10" - ), patch.object( - self.version_helper, "get_version_describe" - ) as version_describe, patch( - "builtins.print" - ) as print_mock: + with ( + patch.object(sys, "argv", ["get_version_tag.py", "--tag-only"]), + patch.object( + self.version_helper, "get_latest_version_tag", return_value="v0.5.10" + ), + patch.object( + self.version_helper, "get_version_describe" + ) as version_describe, + patch("builtins.print") as print_mock, + ): self.version_helper.main() version_describe.assert_not_called() diff --git a/test/registered/unit/utils/test_weight_checker.py b/test/registered/unit/utils/test_weight_checker.py index 1ee0442ed61b..643c513694e8 100644 --- a/test/registered/unit/utils/test_weight_checker.py +++ b/test/registered/unit/utils/test_weight_checker.py @@ -466,11 +466,14 @@ def test_passes_after_reset_then_restoring_normal_params(self): class TestHandle(_WeightCheckerTestBase): def test_routes_to_actions(self): - with patch.object(self.checker, "_snapshot") as m_snap, patch.object( - self.checker, "_reset_tensors" - ) as m_reset, patch.object(self.checker, "_compare") as m_compare, patch.object( - self.checker, "_compute_checksum", return_value={"checksums": {}} - ) as m_checksum: + with ( + patch.object(self.checker, "_snapshot") as m_snap, + patch.object(self.checker, "_reset_tensors") as m_reset, + patch.object(self.checker, "_compare") as m_compare, + patch.object( + self.checker, "_compute_checksum", return_value={"checksums": {}} + ) as m_checksum, + ): self.checker.handle("snapshot") self.checker.handle("reset_tensors") self.checker.handle("compare")