From ebff5fcb068b4d6e7b39014d15d69e04a56d4a2c Mon Sep 17 00:00:00 2001 From: kozo <87003759+kozoy@users.noreply.github.com> Date: Wed, 16 Jul 2025 09:17:34 +0800 Subject: [PATCH 001/396] feat: replace Decord with video_reader-rs (#5163) Signed-off-by: Xinyuan Tong Co-authored-by: Xinyuan Tong --- python/pyproject.toml | 1 + python/sglang/check_env.py | 2 +- .../multimodal/processors/base_processor.py | 4 ++-- .../srt/multimodal/processors/internvl.py | 4 ++-- .../srt/multimodal/processors/qwen_vl.py | 4 ++-- python/sglang/srt/utils.py | 22 +++++++------------ 6 files changed, 16 insertions(+), 21 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 86467457a78e..3d72566f71fc 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -21,6 +21,7 @@ runtime_common = [ "build", "compressed-tensors", "datasets", + "video-reader-rs", "fastapi", "hf_transfer", "huggingface_hub", diff --git a/python/sglang/check_env.py b/python/sglang/check_env.py index 1870e3207ae7..ba42c17beb2b 100644 --- a/python/sglang/check_env.py +++ b/python/sglang/check_env.py @@ -47,7 +47,7 @@ def is_cuda_v2(): "tiktoken", "anthropic", "litellm", - "decord", + "video-reader-rs", ] diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 91aaa19090cf..7d7784c18f38 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -206,7 +206,7 @@ def get_estimated_frames_list(self, image_data): estimate the total frame count from all visual input """ # Lazy import because decord is not available on some arm platforms. 
- from decord import VideoReader, cpu + from video_reader import PyVideoReader, cpu # Before processing inputs if not image_data or len(image_data) == 0: @@ -216,7 +216,7 @@ def get_estimated_frames_list(self, image_data): if isinstance(image, str) and image.startswith("video:"): path = image[len("video:") :] # Estimate frames for the video - vr = VideoReader(path, ctx=cpu(0)) + vr = PyVideoReader(path, threads=0) num_frames = len(vr) else: # For images, each contributes one frame diff --git a/python/sglang/srt/multimodal/processors/internvl.py b/python/sglang/srt/multimodal/processors/internvl.py index df9b67aadeae..4b27a91a374c 100644 --- a/python/sglang/srt/multimodal/processors/internvl.py +++ b/python/sglang/srt/multimodal/processors/internvl.py @@ -150,7 +150,7 @@ def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32): vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) max_frame = len(vr) - 1 - fps = float(vr.get_avg_fps()) + fps = float(vr.get_fps()) pixel_values_list, num_patches_list = [], [] transform = InternVLImageProcessor.build_transform(input_size=input_size) @@ -158,7 +158,7 @@ def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=3 bound, fps, max_frame, first_idx=0, num_segments=num_segments ) for frame_index in frame_indices: - img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB") + img = Image.fromarray(vr[frame_index]).convert("RGB") img = InternVLImageProcessor.dynamic_preprocess( img, image_size=input_size, use_thumbnail=True, max_num=max_num ) diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index 1ecb4e119ac3..68381dbec639 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -156,10 +156,10 @@ async def preprocess_video( # vr: VideoReader, image_factor: int = 
IMAGE_FACTOR ) -> torch.Tensor: ele = {} - total_frames, video_fps = len(vr), vr.get_avg_fps() + total_frames, video_fps = len(vr), vr.get_fps() nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps) idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() - video = vr.get_batch(idx).asnumpy() + video = vr.get_batch(idx) video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format nframes, _, height, width = video.shape min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index ce159a4da77b..377fa90c8367 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -84,6 +84,7 @@ from torch.profiler import ProfilerActivity, profile, record_function from torch.utils._contextlib import _DecoratorContextManager from triton.runtime.cache import FileCacheManager +from video_reader import PyVideoReader logger = logging.getLogger(__name__) @@ -757,16 +758,9 @@ def load_image( def load_video(video_file: Union[str, bytes], use_gpu: bool = True): # We import decord here to avoid a strange Segmentation fault (core dumped) issue. 
- from decord import VideoReader, cpu, gpu - - try: - from decord.bridge import decord_bridge - - ctx = gpu(0) - _ = decord_bridge.get_ctx_device(ctx) - except Exception: - ctx = cpu(0) + from video_reader import PyVideoReader + device = "cuda" if use_gpu and torch.cuda.is_available() else None tmp_file = None vr = None try: @@ -774,7 +768,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True): tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") tmp_file.write(video_file) tmp_file.close() - vr = VideoReader(tmp_file.name, ctx=ctx) + vr = PyVideoReader(tmp_file.name, device=device, threads=0) elif isinstance(video_file, str): if video_file.startswith(("http://", "https://")): timeout = int(os.getenv("REQUEST_TIMEOUT", "10")) @@ -784,22 +778,22 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True): for chunk in response.iter_content(chunk_size=8192): tmp_file.write(chunk) tmp_file.close() - vr = VideoReader(tmp_file.name, ctx=ctx) + vr = PyVideoReader(tmp_file.name, device=device, threads=0) elif video_file.startswith("data:"): _, encoded = video_file.split(",", 1) video_bytes = base64.b64decode(encoded) tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") tmp_file.write(video_bytes) tmp_file.close() - vr = VideoReader(tmp_file.name, ctx=ctx) + vr = PyVideoReader(tmp_file.name, device=device, threads=0) elif os.path.isfile(video_file): - vr = VideoReader(video_file, ctx=ctx) + vr = PyVideoReader(video_file, device=device, threads=0) else: video_bytes = base64.b64decode(video_file) tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") tmp_file.write(video_bytes) tmp_file.close() - vr = VideoReader(tmp_file.name, ctx=ctx) + vr = PyVideoReader(tmp_file.name, device=device, threads=0) else: raise ValueError(f"Unsupported video input type: {type(video_file)}") From 194841e3292ea918aae8389b1d6716ee1dab6653 Mon Sep 17 00:00:00 2001 From: strgrb Date: Wed, 16 Jul 2025 09:20:41 +0800 Subject: 
[PATCH 002/396] remove kv_a.contiguous in DeepseekV2AttentionMLA (#8058) Co-authored-by: Zhang Kaihong --- python/sglang/srt/models/deepseek_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2b87d91d475f..bb1efde2941e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1154,7 +1154,7 @@ def forward_normal_prepare( _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) - kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv_a = self.kv_a_layernorm(kv_a) kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] @@ -1693,7 +1693,7 @@ def forward_normal_chunked_kv_prepare( _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) - kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv_a = self.kv_a_layernorm(kv_a) kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope = kv[..., : self.qk_nope_head_dim] From 7498522f7d296f9fbfe6534aec511674d0786dc4 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Date: Tue, 15 Jul 2025 18:24:39 -0700 Subject: [PATCH 003/396] update transformers to 4.53.2 (#8029) Signed-off-by: Xinyuan Tong --- python/pyproject.toml | 2 +- test/srt/test_vlm_accuracy.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 3d72566f71fc..c538c4bcb3e0 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -46,7 +46,7 @@ runtime_common = [ 
"soundfile==0.13.1", "scipy", "torchao==0.9.0", - "transformers==4.53.0", + "transformers==4.53.2", "timm==1.0.16", "uvicorn", "uvloop", diff --git a/test/srt/test_vlm_accuracy.py b/test/srt/test_vlm_accuracy.py index a699a36feef4..ea83f3eef755 100644 --- a/test/srt/test_vlm_accuracy.py +++ b/test/srt/test_vlm_accuracy.py @@ -161,7 +161,6 @@ def get_sglang_model(self): return self.model_runner.model -# TODO: MiniCPMV is not compatible with transformers==4.52.3, temporarily disabled class TestMiniCPMVLogits(VisionLLMLogitsBase): @classmethod def setUpClass(cls): From 3bc43c683e6297cbc6b01e2d2468ca2b25052710 Mon Sep 17 00:00:00 2001 From: Qiaolin Yu Date: Tue, 15 Jul 2025 19:37:14 -0700 Subject: [PATCH 004/396] Fix different device type adjustment in PP (#7760) --- .../sglang/srt/distributed/parallel_state.py | 12 +++---- python/sglang/srt/managers/scheduler.py | 5 +++ python/sglang/srt/managers/tp_worker.py | 1 + python/sglang/srt/utils.py | 34 ++++++++----------- 4 files changed, 25 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 509c71531062..5ab2e3758115 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -699,14 +699,14 @@ def send_object(self, obj: Any, dst: int) -> None: ) # Serialize object to tensor and get the size as well - object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda( - device=torch.cuda.current_device() + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).to( + device=self.device ) size_tensor = torch.tensor( [object_tensor.numel()], dtype=torch.long, - device=torch.cuda.current_device(), + device=self.device, ) # Send object size @@ -731,9 +731,7 @@ def recv_object(self, src: int) -> Any: src != self.rank_in_group ), "Invalid source rank. Source rank is the same as the current rank." 
- size_tensor = torch.empty( - 1, dtype=torch.long, device=torch.cuda.current_device() - ) + size_tensor = torch.empty(1, dtype=torch.long, device=self.device) # Receive object size rank_size = torch.distributed.recv( @@ -744,7 +742,7 @@ def recv_object(self, src: int) -> Any: object_tensor = torch.empty( # type: ignore[call-overload] size_tensor.item(), # type: ignore[arg-type] dtype=torch.uint8, - device=torch.cuda.current_device(), + device=self.device, ) rank_object = torch.distributed.recv( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index afb4b870d34d..9a1654343603 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -962,6 +962,7 @@ def event_loop_pp(self): self.world_group.device_group, self.pp_rank * self.tp_size + dp_offset, (self.pp_rank + 1) * self.tp_size + dp_offset, + device=self.device, ) # send out proxy tensors to the next stage @@ -1010,6 +1011,7 @@ def recv_requests(self) -> List[Req]: self.world_group.device_group, (self.pp_rank - 1) * self.tp_size + dp_offset, self.pp_rank * self.tp_size + dp_offset, + device=self.device, ) else: recv_reqs = None @@ -1040,6 +1042,7 @@ def recv_requests(self) -> List[Req]: self.attn_tp_group.rank, self.attn_tp_cpu_group, src=self.attn_tp_group.ranks[0], + device=self.device, ) if self.tp_size != 1: control_reqs = broadcast_pyobj( @@ -1047,6 +1050,7 @@ def recv_requests(self) -> List[Req]: self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0], + device=self.device, ) recv_reqs = work_reqs + control_reqs elif self.tp_size != 1: @@ -1055,6 +1059,7 @@ def recv_requests(self) -> List[Req]: self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0], + device=self.device, ) return recv_reqs diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index ff20ea01e4d3..daeed4faff7c 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ 
b/python/sglang/srt/managers/tp_worker.py @@ -144,6 +144,7 @@ def __init__( self.tp_size * self.pp_rank + tp_rank, self.world_group.cpu_group, src=self.world_group.ranks[0], + device=self.device, )[0] set_random_seed(self.random_seed) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 377fa90c8367..d055aab5b9cf 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1094,15 +1094,15 @@ def broadcast_pyobj( rank: int, dist_group: Optional[torch.distributed.ProcessGroup] = None, src: int = 0, - force_cpu_device: bool = True, + device: Optional[str] = None, ): """Broadcast inputs from src rank to all other ranks with torch.dist backend. The `rank` here refer to the source rank on global process group (regardless of dist_group argument). """ - device = torch.device( - "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu" - ) + + if device is None: + device = get_device() if rank == src: if len(data) == 0: @@ -1142,44 +1142,38 @@ def point_to_point_pyobj( group: Optional[torch.distributed.ProcessGroup] = None, src: int = 0, dst: int = 1, + device: Optional[str] = None, ): """Send data from src to dst in group using DeviceToDevice communication.""" - + if device is None: + device = get_device() if rank == src: if len(data) == 0: - tensor_size = torch.tensor( - [0], dtype=torch.long, device=torch.cuda.current_device() - ) + tensor_size = torch.tensor([0], dtype=torch.long, device=device) dist.send(tensor_size, dst=dst, group=group) else: serialized_data = pickle.dumps(data) size = len(serialized_data) tensor_data = torch.ByteTensor( np.frombuffer(serialized_data, dtype=np.uint8) - ).cuda( - device=torch.cuda.current_device() - ) # Move to GPU - tensor_size = torch.tensor( - [size], dtype=torch.long, device=torch.cuda.current_device() - ) + ).to( + device=device + ) # Move to Device + tensor_size = torch.tensor([size], dtype=torch.long, device=device) dist.send(tensor_size, dst=dst, group=group) 
dist.send(tensor_data, dst=dst, group=group) return data elif rank == dst: - tensor_size = torch.tensor( - [0], dtype=torch.long, device=torch.cuda.current_device() - ) + tensor_size = torch.tensor([0], dtype=torch.long, device=device) dist.recv(tensor_size, src=src, group=group) size = tensor_size.item() if size == 0: return [] - tensor_data = torch.empty( - size, dtype=torch.uint8, device=torch.cuda.current_device() - ) + tensor_data = torch.empty(size, dtype=torch.uint8, device=device) dist.recv(tensor_data, src=src, group=group) serialized_data = bytes( From 69f453e5a446a2fec28a106252836d644c33c2c6 Mon Sep 17 00:00:00 2001 From: Qiaolin Yu Date: Tue, 15 Jul 2025 19:38:58 -0700 Subject: [PATCH 005/396] Use device_group for all_gather when disabling overlap scheduling (#8001) --- python/sglang/bench_one_batch.py | 3 ++- python/sglang/srt/managers/scheduler.py | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index cca7d5a495fa..4a027ae99721 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -271,12 +271,13 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner): batch, dp_size=model_runner.server_args.dp_size, attn_tp_size=1, - tp_cpu_group=model_runner.tp_group.cpu_group, + tp_group=model_runner.tp_group, get_idle_batch=None, disable_cuda_graph=model_runner.server_args.disable_cuda_graph, spec_algorithm=SpeculativeAlgorithm.NONE, speculative_num_draft_tokens=None, require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args), + disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule, ) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9a1654343603..a7f893253637 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1945,7 +1945,7 @@ def prepare_mlp_sync_batch(self, local_batch: 
ScheduleBatch): local_batch, dp_size=self.server_args.dp_size, attn_tp_size=self.attn_tp_size, - tp_cpu_group=self.tp_cpu_group, + tp_group=self.tp_group, get_idle_batch=self.get_idle_batch, disable_cuda_graph=self.server_args.disable_cuda_graph, spec_algorithm=self.spec_algorithm, @@ -1954,6 +1954,7 @@ def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch): enable_deepep_moe=self.server_args.enable_deepep_moe, deepep_mode=DeepEPMode[self.server_args.deepep_mode], require_mlp_tp_gather=require_mlp_tp_gather(self.server_args), + disable_overlap_schedule=self.server_args.disable_overlap_schedule, ) @staticmethod @@ -1961,7 +1962,7 @@ def prepare_mlp_sync_batch_raw( local_batch: ScheduleBatch, dp_size, attn_tp_size: int, - tp_cpu_group, + tp_group, get_idle_batch, disable_cuda_graph: bool, spec_algorithm, @@ -1970,6 +1971,7 @@ def prepare_mlp_sync_batch_raw( enable_deepep_moe: bool, deepep_mode: DeepEPMode, require_mlp_tp_gather: bool, + disable_overlap_schedule: bool, ): # Check if other DP workers have running batches if local_batch is None: @@ -2000,6 +2002,12 @@ def prepare_mlp_sync_batch_raw( ) tbo_preparer = TboDPAttentionPreparer() + if disable_overlap_schedule: + group = tp_group.device_group + device = tp_group.device + else: + group = tp_group.cpu_group + device = "cpu" local_info = torch.tensor( [ @@ -2015,15 +2023,17 @@ def prepare_mlp_sync_batch_raw( ), ], dtype=torch.int64, + device=device, ) global_info = torch.empty( (dp_size, attn_tp_size, 6), dtype=torch.int64, + device=device, ) torch.distributed.all_gather_into_tensor( global_info.flatten(), local_info, - group=tp_cpu_group, + group=group, ) global_num_tokens = global_info[:, 0, 0].tolist() can_cuda_graph = min(global_info[:, 0, 1].tolist()) From 497efe747d1f1cbcb6721f9d1721901e978956b4 Mon Sep 17 00:00:00 2001 From: Mick Date: Wed, 16 Jul 2025 11:04:56 +0800 Subject: [PATCH 006/396] Revert "feat: replace Decord with video_reader-rs" (#8077) --- python/pyproject.toml | 1 - 
python/sglang/check_env.py | 2 +- .../multimodal/processors/base_processor.py | 4 ++-- .../srt/multimodal/processors/internvl.py | 4 ++-- .../srt/multimodal/processors/qwen_vl.py | 4 ++-- python/sglang/srt/utils.py | 22 ++++++++++++------- 6 files changed, 21 insertions(+), 16 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index c538c4bcb3e0..7afb3581a3b5 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -21,7 +21,6 @@ runtime_common = [ "build", "compressed-tensors", "datasets", - "video-reader-rs", "fastapi", "hf_transfer", "huggingface_hub", diff --git a/python/sglang/check_env.py b/python/sglang/check_env.py index ba42c17beb2b..1870e3207ae7 100644 --- a/python/sglang/check_env.py +++ b/python/sglang/check_env.py @@ -47,7 +47,7 @@ def is_cuda_v2(): "tiktoken", "anthropic", "litellm", - "video-reader-rs", + "decord", ] diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 7d7784c18f38..91aaa19090cf 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -206,7 +206,7 @@ def get_estimated_frames_list(self, image_data): estimate the total frame count from all visual input """ # Lazy import because decord is not available on some arm platforms. 
- from video_reader import PyVideoReader, cpu + from decord import VideoReader, cpu # Before processing inputs if not image_data or len(image_data) == 0: @@ -216,7 +216,7 @@ def get_estimated_frames_list(self, image_data): if isinstance(image, str) and image.startswith("video:"): path = image[len("video:") :] # Estimate frames for the video - vr = PyVideoReader(path, threads=0) + vr = VideoReader(path, ctx=cpu(0)) num_frames = len(vr) else: # For images, each contributes one frame diff --git a/python/sglang/srt/multimodal/processors/internvl.py b/python/sglang/srt/multimodal/processors/internvl.py index 4b27a91a374c..df9b67aadeae 100644 --- a/python/sglang/srt/multimodal/processors/internvl.py +++ b/python/sglang/srt/multimodal/processors/internvl.py @@ -150,7 +150,7 @@ def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32): vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) max_frame = len(vr) - 1 - fps = float(vr.get_fps()) + fps = float(vr.get_avg_fps()) pixel_values_list, num_patches_list = [], [] transform = InternVLImageProcessor.build_transform(input_size=input_size) @@ -158,7 +158,7 @@ def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=3 bound, fps, max_frame, first_idx=0, num_segments=num_segments ) for frame_index in frame_indices: - img = Image.fromarray(vr[frame_index]).convert("RGB") + img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB") img = InternVLImageProcessor.dynamic_preprocess( img, image_size=input_size, use_thumbnail=True, max_num=max_num ) diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index 68381dbec639..1ecb4e119ac3 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -156,10 +156,10 @@ async def preprocess_video( # vr: VideoReader, image_factor: int = 
IMAGE_FACTOR ) -> torch.Tensor: ele = {} - total_frames, video_fps = len(vr), vr.get_fps() + total_frames, video_fps = len(vr), vr.get_avg_fps() nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps) idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() - video = vr.get_batch(idx) + video = vr.get_batch(idx).asnumpy() video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format nframes, _, height, width = video.shape min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index d055aab5b9cf..37e06b8dcc72 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -84,7 +84,6 @@ from torch.profiler import ProfilerActivity, profile, record_function from torch.utils._contextlib import _DecoratorContextManager from triton.runtime.cache import FileCacheManager -from video_reader import PyVideoReader logger = logging.getLogger(__name__) @@ -758,9 +757,16 @@ def load_image( def load_video(video_file: Union[str, bytes], use_gpu: bool = True): # We import decord here to avoid a strange Segmentation fault (core dumped) issue. 
- from video_reader import PyVideoReader + from decord import VideoReader, cpu, gpu + + try: + from decord.bridge import decord_bridge + + ctx = gpu(0) + _ = decord_bridge.get_ctx_device(ctx) + except Exception: + ctx = cpu(0) - device = "cuda" if use_gpu and torch.cuda.is_available() else None tmp_file = None vr = None try: @@ -768,7 +774,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True): tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") tmp_file.write(video_file) tmp_file.close() - vr = PyVideoReader(tmp_file.name, device=device, threads=0) + vr = VideoReader(tmp_file.name, ctx=ctx) elif isinstance(video_file, str): if video_file.startswith(("http://", "https://")): timeout = int(os.getenv("REQUEST_TIMEOUT", "10")) @@ -778,22 +784,22 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True): for chunk in response.iter_content(chunk_size=8192): tmp_file.write(chunk) tmp_file.close() - vr = PyVideoReader(tmp_file.name, device=device, threads=0) + vr = VideoReader(tmp_file.name, ctx=ctx) elif video_file.startswith("data:"): _, encoded = video_file.split(",", 1) video_bytes = base64.b64decode(encoded) tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") tmp_file.write(video_bytes) tmp_file.close() - vr = PyVideoReader(tmp_file.name, device=device, threads=0) + vr = VideoReader(tmp_file.name, ctx=ctx) elif os.path.isfile(video_file): - vr = PyVideoReader(video_file, device=device, threads=0) + vr = VideoReader(video_file, ctx=ctx) else: video_bytes = base64.b64decode(video_file) tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") tmp_file.write(video_bytes) tmp_file.close() - vr = PyVideoReader(tmp_file.name, device=device, threads=0) + vr = VideoReader(tmp_file.name, ctx=ctx) else: raise ValueError(f"Unsupported video input type: {type(video_file)}") From b188a89a5d09ba634c77c34a2407e95dea5826b8 Mon Sep 17 00:00:00 2001 From: YanbingJiang Date: Wed, 16 Jul 2025 17:12:23 +0800 
Subject: [PATCH 007/396] Fix CI xeon test with triton 3.3.1 (#8086) --- python/sglang/srt/layers/quantization/fp8_kernel.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 7d73c5bc2b1e..79504265c299 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -29,6 +29,7 @@ direct_register_custom_op, get_device_core_count, get_device_name, + is_cpu, is_cuda, is_hip, log_info_on_rank0, @@ -37,6 +38,7 @@ _is_hip = is_hip() _is_cuda = is_cuda() +_is_cpu = is_cpu() if _is_cuda: from sgl_kernel import ( @@ -1168,7 +1170,7 @@ def scaled_fp8_quant( return output, scale -@triton.autotune( +fp8_autotune = triton.autotune( configs=[ triton.Config({"BLOCK_M": block_m}, num_warps=num_warps) for block_m in [16, 32, 64, 128] @@ -1176,6 +1178,8 @@ def scaled_fp8_quant( ], key=["K", "BLOCK_K", "M_ALIGNMENT"], ) + + @triton.jit def _per_token_group_quant_fp8_hopper_moe_mn_major( a, # (M, K):(K, 1) @@ -1221,6 +1225,12 @@ def _per_token_group_quant_fp8_hopper_moe_mn_major( tl.store(sfa_ptrs, inp_amax / 448.0, mask=coord_m < m) +if not _is_cpu: + _per_token_group_quant_fp8_hopper_moe_mn_major = fp8_autotune( + _per_token_group_quant_fp8_hopper_moe_mn_major + ) + + def per_token_group_quant_fp8_hopper_moe_mn_major( A: torch.Tensor, expert_offsets: torch.Tensor, From 6dc4af49377d25fb9d745c5dd14f13a04f9ffbdd Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Wed, 16 Jul 2025 22:08:46 +0800 Subject: [PATCH 008/396] fix greenctx stream compability (#8090) --- sgl-kernel/csrc/spatial/greenctx_stream.cu | 65 ++++++++++++++++------ 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/sgl-kernel/csrc/spatial/greenctx_stream.cu b/sgl-kernel/csrc/spatial/greenctx_stream.cu index b549aea5fa00..8c2e6d813c95 100644 --- a/sgl-kernel/csrc/spatial/greenctx_stream.cu +++ 
b/sgl-kernel/csrc/spatial/greenctx_stream.cu @@ -7,52 +7,83 @@ #include "cuda_utils.h" #include "greenctx_stream.h" +std::vector create_greenctx_stream_fallback(CUgreenCtx gctx[2]) { + CUstream streamA, streamB; + CUcontext ctx; + + // Stream A + CUDA_DRV(cuCtxFromGreenCtx(&ctx, gctx[0])); + CUDA_DRV(cuCtxPushCurrent(ctx)); + CUDA_DRV(cuStreamCreate(&streamA, CU_STREAM_NON_BLOCKING)); + CUDA_DRV(cuCtxPopCurrent(nullptr)); + + // Stream B + CUDA_DRV(cuCtxFromGreenCtx(&ctx, gctx[1])); + CUDA_DRV(cuCtxPushCurrent(ctx)); + CUDA_DRV(cuStreamCreate(&streamB, CU_STREAM_NON_BLOCKING)); + CUDA_DRV(cuCtxPopCurrent(nullptr)); + + return {(int64_t)streamA, (int64_t)streamB}; +} + +#if CUDA_VERSION >= 12050 +std::vector create_greenctx_stream_direct(CUgreenCtx gctx[2]) { + CUstream streamA; + CUstream streamB; + + CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0)); + CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0)); + + std::vector vec = {(int64_t)streamA, (int64_t)streamB}; + return vec; +} +#endif + std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) { + TORCH_CHECK(CUDA_VERSION >= 12040, "Green Contexts feature requires CUDA Toolkit 12.4 or newer."); + CUgreenCtx gctx[3]; CUdevResourceDesc desc[3]; CUdevResource input; CUdevResource resources[4]; - CUstream streamA; - CUstream streamB; - unsigned int nbGroups = 1; if (smA <= 0 || smB <= 0) { TORCH_CHECK(false, "SM counts must be positive"); } - // Initialize device - CUDA_RT(cudaInitDevice(device, 0, 0)); - - // Query input SMs CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM)); - // We want 3/4 the device for our green context unsigned int minCount = (unsigned int)(smA + smB); unsigned int minCountA = (unsigned int)(smA); - TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration"); - // Split resources CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], 
&nbGroups, &input, &resources[3], 0, minCount)); - CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1)); CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM)); CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA)); - CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1)); CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1)); CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); - - CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0)); - CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0)); - int smCountA = resources[0].sm.smCount; int smCountB = resources[1].sm.smCount; + std::vector stream_handles; + +#if CUDA_VERSION >= 12050 + stream_handles = create_greenctx_stream_direct(gctx); +#else + stream_handles = create_greenctx_stream_fallback(gctx); +#endif + CUDA_DRV(cuGreenCtxDestroy(gctx[2])); - std::vector vec = {(int64_t)streamA, (int64_t)streamB, smCountA, smCountB}; + std::vector vec = { + stream_handles[0], // streamA + stream_handles[1], // streamB + (int64_t)smCountA, + (int64_t)smCountB}; + return vec; } From d9eb5efc71b1a8eabf7a6f1765fcf3e73736d63d Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 16 Jul 2025 08:54:55 -0700 Subject: [PATCH 009/396] [misc] update nvshmem and pin deepEP commit hash (#8098) --- docker/Dockerfile | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index f998bddbc821..349873da4acf 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -2,6 +2,7 @@ ARG CUDA_VERSION=12.6.1 FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 ARG BUILD_TYPE=all +ARG 
DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58 ENV DEBIAN_FRONTEND=noninteractive \ CUDA_HOME=/usr/local/cuda \ GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/ \ @@ -14,7 +15,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ tzdata \ software-properties-common netcat-openbsd kmod unzip openssh-server \ curl wget lsof zsh ccache tmux htop git-lfs tree \ - python3 python3-pip python3-dev libpython3-dev \ + python3 python3-pip python3-dev libpython3-dev python3-venv \ build-essential cmake \ libopenmpi-dev libnuma1 libnuma-dev \ libibverbs-dev libibverbs1 libibumad3 \ @@ -62,13 +63,12 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel html5li fi # Build and install NVSHMEM + DeepEP -RUN wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz \ +RUN wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.3.9/source/nvshmem_src_cuda12-all-all-3.3.9.tar.gz \ && git clone https://github.com/deepseek-ai/DeepEP.git \ - && tar -xf nvshmem_src_3.2.5-1.txz && mv nvshmem_src nvshmem \ + && cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd .. 
\ + && tar -xf nvshmem_src_cuda12-all-all-3.3.9.tar.gz && mv nvshmem_src nvshmem \ && cd nvshmem \ - && git apply /sgl-workspace/DeepEP/third-party/nvshmem.patch \ - && sed -i '1i#include ' examples/moe_shuffle.cu \ - && rm -f /sgl-workspace/nvshmem_src_3.2.5-1.txz \ + && rm -f /sgl-workspace/nvshmem_src_cuda12-all-all-3.3.9.tar.gz \ && NVSHMEM_SHMEM_SUPPORT=0 \ NVSHMEM_UCX_SUPPORT=0 \ NVSHMEM_USE_NCCL=0 \ From 570d33437bf0b4ac42e00ad468ddc43f9e0b376f Mon Sep 17 00:00:00 2001 From: Xiaoze Fan Date: Thu, 17 Jul 2025 01:57:46 +0800 Subject: [PATCH 010/396] [Feature] Layer-wise Prefill (#7634) Signed-off-by: jason-fxz Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/sglang/srt/managers/schedule_batch.py | 5 ++ .../srt/model_executor/forward_batch_info.py | 13 ++++ .../sglang/srt/model_executor/model_runner.py | 37 ++++++++++- python/sglang/srt/models/gemma.py | 48 ++++++++++++++ python/sglang/srt/models/gemma2.py | 51 +++++++++++++++ python/sglang/srt/models/gemma3_causal.py | 63 +++++++++++++++++++ python/sglang/srt/models/llama.py | 41 ++++++++++++ python/sglang/srt/models/qwen.py | 37 +++++++++++ python/sglang/srt/models/qwen2.py | 41 ++++++++++++ python/sglang/srt/models/qwen2_moe.py | 44 +++++++++++++ python/sglang/srt/models/qwen3.py | 42 ++++++++++++- python/sglang/srt/models/qwen3_moe.py | 43 +++++++++++++ python/sglang/srt/two_batch_overlap.py | 1 + 13 files changed, 464 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 1a48b055369f..c2750d072457 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1328,6 +1328,11 @@ def prepare_for_extend(self): self.model_config.vocab_size, ) + def prepare_for_split_prefill(self): + self.prepare_for_extend() + # For split prefill, we need to set the forward mode to SPLIT_PREFILL + self.forward_mode = 
ForwardMode.SPLIT_PREFILL + def mix_with_running(self, running_batch: "ScheduleBatch"): self.forward_mode = ForwardMode.MIXED running_bs = running_batch.batch_size() diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 7ed8eb1d47bd..fde60e0e5012 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -68,6 +68,8 @@ class ForwardMode(IntEnum): MIXED = auto() # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence are allocated. IDLE = auto() + # Split Prefill for PD multiplexing + SPLIT_PREFILL = auto() # Used in speculative decoding: verify a batch in the target model. TARGET_VERIFY = auto() @@ -95,6 +97,9 @@ def is_decode(self): def is_mixed(self): return self == ForwardMode.MIXED + def is_split_prefill(self): + return self == ForwardMode.SPLIT_PREFILL + def is_idle(self): return self == ForwardMode.IDLE @@ -194,6 +199,14 @@ class ForwardBatch: extend_logprob_start_lens_cpu: Optional[List[int]] = None extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None + # For split prefill + # intermediate values for split prefill + hidden_states: torch.Tensor = None + residual: torch.Tensor = None + model_specific_states: Dict[str, any] = None + # current split index of layer + split_index: int = 0 + # For MLA chunked prefix cache used in chunked prefill # Tell attention backend whether the kv cache needs to be attended in current pass attn_attend_prefix_cache: Optional[bool] = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a7885a5e367c..12db1d0559f3 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1513,11 +1513,34 @@ def forward_idle( **kwargs, ) + def forward_split_prefill( + self, + forward_batch: ForwardBatch, + 
reinit_attn_backend: bool = False, + forward_count: int = 1, + ) -> LogitsProcessorOutput: + if forward_batch.split_index == 0 or reinit_attn_backend: + self.attn_backend.init_forward_metadata(forward_batch) + next_split_index = min( + forward_batch.split_index + forward_count, + self.model_config.num_hidden_layers, + ) + ret = self.model.forward_split_prefill( + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + (forward_batch.split_index, next_split_index), + ) + forward_batch.split_index = next_split_index + return ret + def forward( self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False, pp_proxy_tensors: Optional[PPProxyTensors] = None, + reinit_attn_backend: bool = False, + split_forward_count: int = 1, ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: self.forward_pass_id += 1 @@ -1526,7 +1549,11 @@ def forward( forward_batch, ): output = self._forward_raw( - forward_batch, skip_attn_backend_init, pp_proxy_tensors + forward_batch, + skip_attn_backend_init, + pp_proxy_tensors, + reinit_attn_backend, + split_forward_count, ) if self.eplb_manager is not None: @@ -1539,6 +1566,8 @@ def _forward_raw( forward_batch: ForwardBatch, skip_attn_backend_init: bool, pp_proxy_tensors: Optional[PPProxyTensors], + reinit_attn_backend: bool = False, + split_forward_count: int = 1, ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: can_run_cuda_graph = bool( forward_batch.forward_mode.is_cuda_graph() @@ -1559,6 +1588,12 @@ def _forward_raw( skip_attn_backend_init=skip_attn_backend_init, pp_proxy_tensors=pp_proxy_tensors, ) + elif forward_batch.forward_mode.is_split_prefill(): + ret = self.forward_split_prefill( + forward_batch, + reinit_attn_backend=reinit_attn_backend, + forward_count=split_forward_count, + ) elif forward_batch.forward_mode.is_idle(): ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors) else: diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py 
index d8074487cb67..1ecb5011f71c 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -318,6 +318,54 @@ def forward( input_ids, hidden_states, self.model.embed_tokens, forward_batch ) + @torch.no_grad() + def forward_split_prefill( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + split_interval: Tuple[int, int], # [start, end) 0-based + input_embeds: torch.Tensor = None, + ): + start, end = split_interval + # embed + if start == 0: + if input_embeds is None: + forward_batch.hidden_states = self.model.embed_tokens(input_ids) + else: + forward_batch.hidden_states = input_embeds + + # Normalize the embedding by sqrt(hidden_size) + forward_batch.hidden_states *= self.model.config.hidden_size**0.5 + + # decoder layer + for i in range(start, end): + layer = self.model.layers[i] + forward_batch.hidden_states, forward_batch.residual = layer( + positions, + forward_batch.hidden_states, + forward_batch, + forward_batch.residual, + ) + + if end == self.model.config.num_hidden_layers: + # norm + forward_batch.hidden_states, _ = self.model.norm( + forward_batch.hidden_states, forward_batch.residual + ) + + # logits process + result = self.logits_processor( + input_ids, + forward_batch.hidden_states, + self.model.embed_tokens, + forward_batch, + ) + else: + result = None + + return result + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 9ee892bb79fa..ee490d083d1b 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -381,6 +381,57 @@ def forward( input_ids, hidden_states, self.model.embed_tokens, forward_batch ) + @torch.no_grad() + def forward_split_prefill( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + split_interval: Tuple[int, int], # 
[start, end) 0-based + input_embeds: torch.Tensor = None, + ): + start, end = split_interval + # embed + if start == 0: + if input_embeds is None: + forward_batch.hidden_states = self.model.embed_tokens(input_ids) + else: + forward_batch.hidden_states = input_embeds + + # Normalize + normalizer = torch.tensor( + self.model.config.hidden_size**0.5, dtype=torch.float16 + ) + forward_batch.hidden_states *= normalizer + + # decoder layer + for i in range(start, end): + layer = self.model.layers[i] + forward_batch.hidden_states, forward_batch.residual = layer( + positions, + forward_batch.hidden_states, + forward_batch, + forward_batch.residual, + ) + + if end == self.model.config.num_hidden_layers: + # norm + forward_batch.hidden_states, _ = self.model.norm( + forward_batch.hidden_states, forward_batch.residual + ) + + # logits process + result = self.logits_processor( + input_ids, + forward_batch.hidden_states, + self.model.embed_tokens, + forward_batch, + ) + else: + result = None + + return result + def get_hidden_dim(self, module_name): # return input_dim, output_dim if module_name in ["q_proj", "qkv_proj"]: diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py index f5bff8fc4f57..5b6145affacc 100644 --- a/python/sglang/srt/models/gemma3_causal.py +++ b/python/sglang/srt/models/gemma3_causal.py @@ -647,6 +647,69 @@ def forward( input_ids, hidden_states, self.model.embed_tokens, forward_batch ) + @torch.no_grad() + def forward_split_prefill( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + split_interval: Tuple[int, int], # [start, end) 0-based + input_embeds: torch.Tensor = None, + ): + start, end = split_interval + # embed + if start == 0: + if input_embeds is None: + hidden_states = self.model.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + if positions.dim() == 1: + positions = einops.rearrange(positions, "s -> 1 s") + position_embeddings_global = 
self.model.rotary_emb(hidden_states, positions) + position_embeddings_local = self.model.rotary_emb_local( + hidden_states, positions + ) + + forward_batch.hidden_states = hidden_states + forward_batch.model_specific_states = { + "positions": positions, + "position_embeddings_global": position_embeddings_global, + "position_embeddings_local": position_embeddings_local, + } + + # decoder layer + for i in range(start, end): + layer = self.model.layers[i] + layer_output = layer( + positions=forward_batch.model_specific_states["positions"], + position_embeddings_global=forward_batch.model_specific_states[ + "position_embeddings_global" + ], + position_embeddings_local=forward_batch.model_specific_states[ + "position_embeddings_local" + ], + hidden_states=forward_batch.hidden_states, + forward_batch=forward_batch, + ) + forward_batch.hidden_states = layer_output[0] + + if end == self.model.config.num_hidden_layers: + # norm + forward_batch.hidden_states = self.model.norm(forward_batch.hidden_states) + + # logits process + result = self.logits_processor( + input_ids, + forward_batch.hidden_states, + self.model.embed_tokens, + forward_batch, + ) + else: + result = None + + return result + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index f8cfe859b2ba..d1614935bb18 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -480,6 +480,47 @@ def forward( else: return hidden_states + @torch.no_grad() + def forward_split_prefill( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + split_interval: Tuple[int, int], # [start, end) 0-based + input_embeds: torch.Tensor = None, + ) -> Optional[LogitsProcessorOutput]: + start, end = split_interval + # embed + if start == 0: + if input_embeds is None: + forward_batch.hidden_states = 
self.model.embed_tokens(input_ids) + else: + forward_batch.hidden_states = input_embeds + # decoder layer + for i in range(start, end): + layer = self.model.layers[i] + forward_batch.hidden_states, forward_batch.residual = layer( + positions, + forward_batch.hidden_states, + forward_batch, + forward_batch.residual, + ) + + if end == self.model.config.num_hidden_layers: + # norm + hidden_states, _ = self.model.norm( + forward_batch.hidden_states, forward_batch.residual + ) + forward_batch.hidden_states = hidden_states + # logits process + result = self.logits_processor( + input_ids, forward_batch.hidden_states, self.lm_head, forward_batch + ) + else: + result = None + + return result + @property def start_layer(self): return self.model.start_layer diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index f0660f62da6d..009650411e3d 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -15,6 +15,7 @@ # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1 +import time from typing import Any, Dict, Iterable, Optional, Tuple import torch @@ -286,6 +287,42 @@ def forward( input_ids, hidden_states, self.lm_head, forward_batch ) + @torch.no_grad() + def forward_split_prefill( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + split_interval: Tuple[int, int], # [start, end) 0-based + ): + start, end = split_interval + # embed + if start == 0: + forward_batch.hidden_states = self.transformer.wte(input_ids) + + # decoder layer + for i in range(start, end): + layer = self.transformer.h[i] + forward_batch.hidden_states = layer( + positions, + forward_batch.hidden_states, + forward_batch, + ) + + if end == self.transformer.config.num_hidden_layers: + # norm + forward_batch.hidden_states = self.transformer.ln_f( + forward_batch.hidden_states + ) + # logits process + result = 
self.logits_processor( + input_ids, forward_batch.hidden_states, self.lm_head, forward_batch + ) + else: + result = None + + return result + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index e3670bb552e8..1696bdfa9177 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -481,6 +481,47 @@ def forward( else: return hidden_states + @torch.no_grad() + def forward_split_prefill( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + split_interval: Tuple[int, int], # [start, end) 0-based + input_embeds: torch.Tensor = None, + ): + start, end = split_interval + # embed + if start == 0: + if input_embeds is None: + forward_batch.hidden_states = self.model.embed_tokens(input_ids) + else: + forward_batch.hidden_states = input_embeds + # decoder layer + for i in range(start, end): + layer = self.model.layers[i] + forward_batch.hidden_states, forward_batch.residual = layer( + positions, + forward_batch.hidden_states, + forward_batch, + forward_batch.residual, + ) + + if end == self.model.config.num_hidden_layers: + # norm + hidden_states, _ = self.model.norm( + forward_batch.hidden_states, forward_batch.residual + ) + forward_batch.hidden_states = hidden_states + # logits process + result = self.logits_processor( + input_ids, forward_batch.hidden_states, self.lm_head, forward_batch + ) + else: + result = None + + return result + @property def start_layer(self): return self.model.start_layer diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 92637d73b76f..fe2636ab74e8 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -406,6 +406,7 @@ def __init__( alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() + self.config = 
config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.pp_group = get_pp_group() @@ -554,6 +555,49 @@ def forward( else: return hidden_states + @torch.no_grad() + def forward_split_prefill( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + split_interval: Tuple[int, int], # [start, end) 0-based + input_embeds: torch.Tensor = None, + ): + start, end = split_interval + # embed + if start == 0: + if input_embeds is None: + forward_batch.hidden_states = self.model.embed_tokens(input_ids) + else: + forward_batch.hidden_states = input_embeds + + # decoder layer + for i in range(start, end): + with get_global_expert_distribution_recorder().with_current_layer(i): + layer = self.model.layers[i] + forward_batch.hidden_states, forward_batch.residual = layer( + positions, + forward_batch.hidden_states, + forward_batch, + forward_batch.residual, + ) + + if end == self.model.config.num_hidden_layers: + # norm + hidden_states, _ = self.model.norm( + forward_batch.hidden_states, forward_batch.residual + ) + forward_batch.hidden_states = hidden_states + # logits process + result = self.logits_processor( + input_ids, forward_batch.hidden_states, self.lm_head, forward_batch + ) + else: + result = None + + return result + @property def start_layer(self): return self.model.start_layer diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 9c36598397fb..6289e61e7a72 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -1,5 +1,4 @@ # Adapted from qwen2.py - import logging from functools import partial from typing import Any, Dict, Iterable, List, Optional, Tuple @@ -367,6 +366,47 @@ def forward( else: return hidden_states + @torch.no_grad() + def forward_split_prefill( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + split_interval: Tuple[int, int], # [start, end) 0-based + input_embeds: 
torch.Tensor = None, + ): + start, end = split_interval + # embed + if start == 0: + if input_embeds is None: + forward_batch.hidden_states = self.model.embed_tokens(input_ids) + else: + forward_batch.hidden_states = input_embeds + # decoder layer + for i in range(start, end): + layer = self.model.layers[i] + forward_batch.hidden_states, forward_batch.residual = layer( + positions, + forward_batch.hidden_states, + forward_batch, + forward_batch.residual, + ) + + if end == self.model.config.num_hidden_layers: + # norm + hidden_states, _ = self.model.norm( + forward_batch.hidden_states, forward_batch.residual + ) + forward_batch.hidden_states = hidden_states + # logits process + result = self.logits_processor( + input_ids, forward_batch.hidden_states, self.lm_head, forward_batch + ) + else: + result = None + + return result + @property def start_layer(self): return self.model.start_layer diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 7c7c7551be78..75d3b475cb0e 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -745,6 +745,49 @@ def forward( else: return hidden_states + @torch.no_grad() + def forward_split_prefill( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + split_interval: Tuple[int, int], # [start, end) 0-based + input_embeds: torch.Tensor = None, + ): + start, end = split_interval + # embed + if start == 0: + if input_embeds is None: + forward_batch.hidden_states = self.model.embed_tokens(input_ids) + else: + forward_batch.hidden_states = input_embeds + + # decoder layer + for i in range(start, end): + with get_global_expert_distribution_recorder().with_current_layer(i): + layer = self.model.layers[i] + forward_batch.hidden_states, forward_batch.residual = layer( + positions, + forward_batch.hidden_states, + forward_batch, + forward_batch.residual, + ) + + if end == self.model.config.num_hidden_layers: + # norm + 
hidden_states, _ = self.model.norm( + forward_batch.hidden_states, forward_batch.residual + ) + forward_batch.hidden_states = hidden_states + # logits process + result = self.logits_processor( + input_ids, forward_batch.hidden_states, self.lm_head, forward_batch + ) + else: + result = None + + return result + @property def start_layer(self): return self.model.start_layer diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index fc419b03c298..3fdf2a1f77a6 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -500,6 +500,7 @@ def filter_batch( "capture_hidden_mode", "padded_static_len", "mrope_positions", # only used by qwen2-vl, thus not care + "split_index", # for split prefill ]: output_dict[key] = getattr(batch, key) if not batch.forward_mode.is_target_verify(): From c28ad1990d29f3993c1eebff06673e819ac4b032 Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Thu, 17 Jul 2025 06:56:26 +0800 Subject: [PATCH 011/396] [1/n] chore: decouple quantization implementation from vLLM dependency (#7992) --- .../layers/moe/fused_moe_triton/__init__.py | 5 +- .../srt/layers/quantization/__init__.py | 6 +- python/sglang/srt/layers/quantization/gptq.py | 610 +++++++++++--- .../srt/layers/quantization/marlin_utils.py | 781 ++++++++++++++++++ .../srt/layers/quantization/moe_wna16.py | 30 + .../srt/layers/quantization/quant_utils.py | 166 ---- .../srt/layers/quantization}/scalar_type.py | 0 .../sglang/srt/layers/quantization/utils.py | 163 +++- sgl-kernel/python/sgl_kernel/fused_moe.py | 3 +- sgl-kernel/tests/test_marlin_repack.py | 6 +- test/srt/test_gptqmodel_dynamic.py | 9 +- test/srt/test_int4_kernel.py | 301 ------- test/srt/test_w4a8.py | 14 - 13 files changed, 1478 insertions(+), 616 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/marlin_utils.py delete mode 100644 python/sglang/srt/layers/quantization/quant_utils.py rename {sgl-kernel/python/sgl_kernel => 
python/sglang/srt/layers/quantization}/scalar_type.py (100%) delete mode 100644 test/srt/test_int4_kernel.py delete mode 100644 test/srt/test_w4a8.py diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py index b68961931d54..839b659fe31b 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py @@ -1,10 +1,11 @@ from contextlib import contextmanager from typing import Any, Dict, Optional -import sglang.srt.layers.moe.fused_moe_triton.fused_moe # noqa from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( fused_experts, get_config_file_name, + moe_align_block_size, + try_get_optimal_moe_config, ) from sglang.srt.layers.moe.fused_moe_triton.layer import ( FusedMoE, @@ -37,4 +38,6 @@ def get_config() -> Optional[Dict[str, Any]]: "fused_moe", "fused_experts", "get_config_file_name", + "moe_align_block_size", + "try_get_optimal_moe_config", ] diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 4ee498169baa..7507a5b62893 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -22,10 +22,6 @@ from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from vllm.model_executor.layers.quantization.gguf import GGUFConfig - from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod - from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinLinearMethod, - ) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config, ) @@ -59,7 +55,9 @@ def override_quantization_method(self, *args, **kwargs): from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.gptq import ( GPTQConfig, + 
GPTQLinearMethod, GPTQMarlinConfig, + GPTQMarlinLinearMethod, GPTQMarlinMoEMethod, ) from sglang.srt.layers.quantization.modelopt_quant import ( diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index 9e2b3e0630bf..3658d0b85793 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -1,48 +1,56 @@ import logging +from dataclasses import dataclass from fractions import Fraction from typing import Any, Callable, Dict, List, Optional, Union import torch -from sglang.srt.layers.linear import LinearBase, set_weight_attrs +from sglang.srt.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs +from sglang.srt.layers.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, + permute_param_layout_, +) from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.utils import replace_parameter -from sglang.srt.utils import is_cuda - -_is_cuda = is_cuda() +from sglang.srt.layers.quantization.marlin_utils import ( + apply_gptq_marlin_linear, + check_marlin_supported, + check_marlin_supports_shape, + marlin_is_k_full, + marlin_make_empty_g_idx, + marlin_make_workspace, + marlin_moe_permute_scales, + marlin_permute_scales, + marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, + marlin_zero_points, + verify_marlin_supported, +) +from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types +from sglang.srt.layers.quantization.utils import replace_parameter, unpack_cols try: from vllm import _custom_ops as ops - from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod - from vllm.model_executor.layers.quantization.gptq_marlin import ( - FusedMoE, - FusedMoEMethodBase, - FusedMoeWeightScaleSupported, - GPTQMarlinLinearMethod, - 
marlin_moe_permute_scales, - ) - from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, - ) - from vllm.scalar_type import scalar_types - - VLLM_AVAILABLE = True except ImportError: - VLLM_AVAILABLE = False + ops = None - GPTQLinearMethod = MarlinLinearMethod = Any +from sglang.srt.utils import is_cuda - FusedMoEMethodBase = QuantizeMethodBase +_is_cuda = is_cuda() - class scalar_types: - uint4b8 = "uint4b8" - uint8b128 = "uint8b128" +if _is_cuda: + from sgl_kernel import fused_marlin_moe +FusedMoEMethodBase = QuantizeMethodBase + logger = logging.getLogger(__name__) @@ -54,6 +62,38 @@ def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool: ) +def gptq_marlin_moe_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty( + (num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype, + ) + for e in range(num_experts): + output[e] = torch.ops.sgl_kernel.gptq_marlin_repack( + b_q_weight[e], perm[e], size_k, size_n, num_bits + ) + return output + + +@dataclass +class MarlinLinearLayerConfig: + full_weight_shape: tuple[int, int] # [in, out] + partition_weight_shape: tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + has_g_idx: bool + + class GPTQConfig(QuantizationConfig): """Config class for GPTQ. 
@@ -151,11 +191,16 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional[GPTQLinearMethod]: + ) -> Optional["LinearMethodBase"]: # Delay the import to avoid circular dependency + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization import get_linear_quant_method - return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) + if isinstance(layer, LinearBase): + return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) + elif isinstance(layer, FusedMoE): + raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin") + return None class GPTQMarlinConfig(QuantizationConfig): @@ -313,14 +358,6 @@ def get_quant_method( if isinstance(layer, FusedMoE): return GPTQMarlinMoEMethod(self) - # TODO: re-enable after SGLang syncs with vllm >= 0.7.3 - # if layer.num_experts > 32: - # # For MoEs with many experts the moe_wna16 kernel is faster - # return MoeWNA16Config.from_config(self.full_config).get_quant_method( - # layer, prefix - # ) - # else: - # return GPTQMarlinMoEMethod(self) return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) @classmethod @@ -344,112 +381,439 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): if (num_bits, sym) not in cls.TYPE_MAP: return False - assert ( - VLLM_AVAILABLE - ), "vllm is not installed, to use gptq_marlin, please install vllm" - return check_marlin_supported( quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size ) -class MarlinConfig(QuantizationConfig): - """Config class for Marlin. +class GPTQLinearMethod(LinearMethodBase): + """Linear method for GPTQ. - Reference: https://github.com/IST-DASLab/marlin/tree/master + Args: + quant_config: The GPTQ quantization config. 
""" - def __init__( + def __init__(self, quant_config: GPTQConfig): + self.quant_config = quant_config + + def create_weights( self, - group_size: int, - lm_head_quantized: bool, - ) -> None: - # Group size for the quantization. - self.group_size = group_size - self.lm_head_quantized = lm_head_quantized - if self.group_size != 128 and self.group_size != -1: + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + weight_loader = extra_weight_attrs.get("weight_loader") + if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( - "Currently, only group size 128 and -1 (channelwise) " - "is supported for Marlin, but got group_size of " - f"{self.group_size}" + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size." + ) + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.pack_factor.numerator != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size." ) - # 4 Bits packed into 32 bit datatype. 
- self.pack_factor = 32 // 4 + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + self.use_shuffle = True + scale_and_zero_size = input_size // group_size + scale_and_zero_input_dim = None + if ( + input_size != input_size_per_partition + and self.quant_config.group_size != -1 + ): + if self.quant_config.desc_act: + self.use_shuffle = False + else: + # we need to partition qzeros and scales for exllama kernel + scale_and_zero_size = input_size_per_partition // group_size + scale_and_zero_input_dim = 0 + + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) - # Tile size used by marlin kernels. - self.tile_size = 16 + g_idx = RowvLLMParameter( + data=torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) + qzeros_args = { + "data": torch.empty( + scale_and_zero_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": weight_loader, + } + weight_scale_args = { + "data": torch.empty( + scale_and_zero_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": weight_loader, + } + if scale_and_zero_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args, + ) - # Min out_features dim - self.min_n_threads = 64 + else: + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + 
packed_factor=self.quant_config.pack_factor, + **qzeros_args, + ) - # Min in_features dim - self.min_k_threads = 128 + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) - # Max parallel problems to solve at once (improves large - # batch performance) - self.max_parallel = 16 + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # for torch.compile + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False) + layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False) + layer.g_idx = torch.nn.Parameter(layer.g_idx.data, requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) + + # exllama needs to shuffle the weight after the weight is loaded + # here we do the shuffle on first forward pass + if self.use_shuffle: + if self.quant_config.desc_act: + layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) + else: + layer.g_idx.data = torch.empty( + (0,), dtype=torch.int, device=layer.g_idx.device + ) + ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) - # Permutation length used by the marlin kernels. - self.perm_len = 1024 + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + out_shape = x.shape[:-1] + (layer.qweight.shape[-1],) + reshaped_x = x.reshape(-1, x.shape[-1]) + + output = ops.gptq_gemm( + reshaped_x, + layer.qweight, + layer.qzeros, + layer.scales, + layer.g_idx, + self.use_shuffle, + self.quant_config.weight_bits, + ) + if bias is not None: + output.add_(bias) + return output.reshape(out_shape) - def __repr__(self) -> str: - return ( - f"MarlinConfig(group_size={self.group_size}, " - f"lm_head_quantized={self.lm_head_quantized})" + +class GPTQMarlinLinearMethod(LinearMethodBase): + """Linear method for GPTQ Marlin. 
+ + Args: + quant_config: The GPTQ Marlin quantization config. + """ + + _kernel_backends_being_used: set[str] = set() + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + # Verify supported on platform. + verify_marlin_supported( + quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size, ) - @classmethod - def get_name(cls) -> str: - return "marlin" + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + is_row_parallel = input_size != input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") + + self.kernel_config = MarlinLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act, + ) + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size - @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half] + # Determine sharding + if marlin_repeat_scales_on_all_ranks( + self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel + ): + # By setting scale_dim == None, weight_loader will + # repeat the scales on each GPU in TP>1 case. + scales_and_zp_input_dim = None + scales_and_zp_size = input_size // group_size + else: + # By setting scale_dim == 0, weight_loader will + # shard the scales in TP>1 case. 
+ scales_and_zp_input_dim = 0 + scales_and_zp_size = input_size_per_partition // group_size + + # Quantized weights + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) - @classmethod - # Need to figure it out - def get_min_capability(cls) -> int: - return 80 + # Activation order + g_idx = RowvLLMParameter( + data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) - @classmethod - def get_config_filenames(cls) -> List[str]: - return ["quantize_config.json"] + qzeros_args = { + "data": torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": weight_loader, + } + weight_scale_args = { + "data": torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": weight_loader, + } + + if scales_and_zp_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args, + ) - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": - group_size = cls.get_from_keys(config, ["group_size"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) - return cls(group_size, lm_head_quantized) + else: + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args, + ) - @classmethod - def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - is_marlin_format = 
check_marlin_format(hf_quant_cfg) + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) - is_valid_user_quant = ( - user_quant is None or user_quant == "gptq" or user_quant == "marlin" + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, "qweight").device + c = self.kernel_config + + check_marlin_supports_shape( + c.partition_weight_shape[1], # out_features + c.partition_weight_shape[0], # in_features + c.full_weight_shape[0], # in_features + c.group_size, ) - if is_marlin_format and is_valid_user_quant: - msg = "The model is serialized in {} format. Using {} kernel.".format( - cls.get_name(), cls.get_name() + row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0] + self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) + + # Allocate marlin workspace. + self.workspace = marlin_make_workspace(device) + + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + self.w_q_name = "qweight" + self.w_s_name = "scales" + self.w_zp_name = "qzeros" + self.w_gidx_name = "g_idx" + + def _transform_param( + layer: torch.nn.Module, name: Optional[str], fn: Callable + ) -> None: + if name is not None and getattr(layer, name, None) is not None: + + old_param = getattr(layer, name) + new_param = fn(old_param) + # replace the parameter with torch.nn.Parameter for TorchDynamo + # compatibility + replace_parameter( + layer, name, torch.nn.Parameter(new_param.data, requires_grad=False) + ) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = torch.ops.sgl_kernel.gptq_marlin_repack( + x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + 
num_bits=c.weight_type.size_bits, ) - logger.info(msg) - return cls.get_name() + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales( + x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size, + ) + return x - return None + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name) + ) + _transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> Optional[MarlinLinearMethod]: - # Delay the import to avoid circular dependency - from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead + if c.zero_points: + grouped_k = ( + c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1 + ) + _transform_param( + layer, + self.w_zp_name, + lambda x: marlin_zero_points( + unpack_cols( + x.t(), + c.weight_type.size_bits, + grouped_k, + c.partition_weight_shape[1], + ), + size_k=grouped_k, + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits, + ), + ) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + _transform_param(layer, self.w_q_name, transform_w_q) + _transform_param(layer, self.w_s_name, transform_w_s) - if isinstance(layer, LinearBase) or ( - isinstance(layer, ParallelLMHead) and self.lm_head_quantized - ): - return MarlinLinearMethod(self) - return None + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + c = self.kernel_config + + def _get_weight_params( + layer: torch.nn.Module, + ) -> tuple[ + torch.Tensor, # w_q + torch.Tensor, # w_s + 
Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor], # w_gidx + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) + + w_q, w_s, w_zp, w_gidx = _get_weight_params(layer) + + # `process_weights_after_loading` will ensure w_zp and w_gidx are not + # None for marlin + return apply_gptq_marlin_linear( + input=x, + weight=w_q, + weight_scale=w_s, + weight_zp=w_zp, # type: ignore + g_idx=w_gidx, # type: ignore + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=self.workspace, + wtype=c.weight_type, + input_size_per_partition=c.partition_weight_shape[0], + output_size_per_partition=c.partition_weight_shape[1], + is_k_full=self.is_k_full, + bias=bias, + ) class GPTQMarlinMoEMethod(FusedMoEMethodBase): @@ -467,6 +831,9 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + # Delay the import to avoid circular dependency + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + intermediate_size = extra_weight_attrs.pop("intermediate_size") self.is_k_full = (not self.quant_config.desc_act) or ( @@ -644,20 +1011,20 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: requires_grad=False, ) # Repack weights - marlin_w13_qweight = ops.gptq_marlin_moe_repack( + marlin_w13_qweight = gptq_marlin_moe_repack( layer.w13_qweight, layer.w13_g_idx_sort_indices, layer.w13_qweight.shape[1] * self.quant_config.pack_factor, layer.w13_qweight.shape[2], - self.quant_config.quant_type.size_bits, + self.quant_config.weight_bits, ) replace_parameter(layer, "w13_qweight", marlin_w13_qweight) - marlin_w2_qweight = ops.gptq_marlin_moe_repack( + marlin_w2_qweight = gptq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, layer.w2_qweight.shape[1] * self.quant_config.pack_factor, layer.w2_qweight.shape[2], - self.quant_config.quant_type.size_bits, + self.quant_config.weight_bits, 
) replace_parameter(layer, "w2_qweight", marlin_w2_qweight) # Repack scales @@ -698,13 +1065,19 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", ) -> torch.Tensor: + # Delay the import to avoid circular dependency + from sglang.srt.layers.moe.topk import select_experts + assert activation == "silu", "Only SiLU activation is supported." + assert ( + scoring_func == "softmax" + ), "Only softmax score func is supported for now." # The input must currently be float16 orig_dtype = x.dtype x = x.half() - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -713,11 +1086,10 @@ def apply( topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, + correction_bias=e_score_correction_bias, ) - return torch.ops.vllm.fused_marlin_moe( + return fused_marlin_moe( x, layer.w13_qweight, layer.w2_qweight, @@ -730,6 +1102,6 @@ def apply( g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, - quant_type_id=self.quant_config.quant_type.id, + num_bits=self.quant_config.weight_bits, is_k_full=self.is_k_full, ).to(orig_dtype) diff --git a/python/sglang/srt/layers/quantization/marlin_utils.py b/python/sglang/srt/layers/quantization/marlin_utils.py new file mode 100644 index 000000000000..503c3d003632 --- /dev/null +++ b/python/sglang/srt/layers/quantization/marlin_utils.py @@ -0,0 +1,781 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py + +import logging +from typing import Any, Optional + +import numpy +import torch + +from sglang.srt.layers.linear import LinearBase, LinearMethodBase +from sglang.srt.layers.parameter 
import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types +from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.utils import get_device_capability + +try: + from vllm import _custom_ops as ops +except ImportError: + ops = None + +logger = logging.getLogger(__name__) + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: Optional[bool] = None, + include_fp_type: bool = True, + device_capability: Optional[int] = None, +): + if device_capability is None: + major, minor = get_device_capability() + capability = major * 10 + minor + device_capability = -1 if capability is None else capability + + if device_capability < 80: + return [] + + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types( + False, include_fp_type, device_capability + ) + types1 = query_marlin_supported_quant_types( + True, include_fp_type, device_capability + ) + return types0 + types1 + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4] + else: + # GPTQ style, unsigned + symmetric bias + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] + return res + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> tuple[bool, Optional[str]]: + + if device_capability is None: + major, minor = get_device_capability() + capability = major * 10 + minor + device_capability = -1 if capability is None else capability + + supported_types = query_marlin_supported_quant_types( + has_zp, True, device_capability + ) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. 
" + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: + output_size_per_partition = ( + getattr(layer, "output_size_per_partition", None) or layer.output_size + ) + input_size_per_partition = ( + getattr(layer, "input_size_per_partition", None) or layer.input_size + ) + + return check_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=layer.input_size, + group_size=group_size, + )[0] + + +def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: + hidden_size = layer.hidden_size + intermediate_size_per_partition = layer.intermediate_size_per_partition + # apply_router_weight_on_input is not supported for moe marlin + supports_router_weight = not layer.apply_router_weight_on_input + # moe marlin requires the activation to be silu + supports_activation = layer.activation == "silu" + + # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) + # down: (n, k) = (hidden_size, intermediate_size_per_partition) + # moe marlin requires n % 128 == 0 and k % 64 == 0 + supports_shape = ( + hidden_size % 128 == 0 + and intermediate_size_per_partition % max(64, group_size) == 0 + ) + supports_group_size = group_size in [-1, 32, 64, 128] + return ( + 
supports_shape + and supports_group_size + and supports_router_weight + and supports_activation + ) + + +def marlin_make_workspace( + device: torch.device, max_blocks_per_sm: int = 1 +) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros( + sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: list[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: list[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if 
group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + + +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) 
to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running Marlin kernel with bf16 on GPUs before SM90. " + "You can consider change to fp16 to achieve better performance " + "if possible." + ) + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + # TODO(yiyun): Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False + if True: + return + # if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + # return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. " + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible." 
+ ) + + +def should_use_atomic_add_reduce( + m: int, n: int, k: int, device: torch.device, dtype: torch.dtype +) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + # TODO: Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False + if not True: + maybe_warn_marlin_atomic_add_env() + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) + return False + + return True + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype, + ) + + output = ops.gptq_marlin_gemm( + reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def 
apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype, + ) + + output = ops.gptq_marlin_gemm( + reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +class MarlinConfig(QuantizationConfig): + """Config class for Marlin. + + Reference: https://github.com/IST-DASLab/marlin/tree/master + """ + + def __init__( + self, + group_size: int, + lm_head_quantized: bool, + ) -> None: + super().__init__() + + # Group size for the quantization. + self.group_size = group_size + self.lm_head_quantized = lm_head_quantized + if self.group_size != 128 and self.group_size != -1: + raise ValueError( + "Currently, only group size 128 and -1 (channelwise) " + "is supported for Marlin, but got group_size of " + f"{self.group_size}" + ) + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // 4 + + # Tile size used by marlin kernels. 
+ self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = 64 + + # Min in_features dim + self.min_k_threads = 128 + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = 16 + + # Permutation length used by the marlin kernels. + self.perm_len = 1024 + + def __repr__(self) -> str: + return ( + f"MarlinConfig(group_size={self.group_size}, " + f"lm_head_quantized={self.lm_head_quantized})" + ) + + @classmethod + def get_name(cls) -> str: + return "marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "MarlinConfig": + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + return cls(group_size, lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_marlin_format: bool + is_marlin_format = hf_quant_cfg.get( + "checkpoint_format" + ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False) + + is_valid_user_quant = ( + user_quant is None or user_quant == "gptq" or user_quant == "marlin" + ) + + if is_marlin_format and is_valid_user_quant: + msg = "The model is serialized in {} format. 
Using {} kernel.".format( + cls.get_name(), cls.get_name() + ) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["MarlinLinearMethod"]: + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): + return MarlinLinearMethod(self) + return None + + +class MarlinLinearMethod(LinearMethodBase): + """Linear method for Marlin. + + Args: + quant_config: The Marlin quantization config. + """ + + def __init__(self, quant_config: MarlinConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + weight_loader = extra_weight_attrs["weight_loader"] + + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}" + ) + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}." + ) + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}." + ) + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}." 
+ ) + if ( + self.quant_config.group_size != -1 + and input_size_per_partition % self.quant_config.group_size != 0 + ): + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}." + ) + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2 + ) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError("Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition + * self.quant_config.tile_size + // self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader, + ) + + # Determine if channelwise or not + input_groups = ( + 1 + if self.quant_config.group_size == -1 + else input_size_per_partition // self.quant_config.group_size + ) + + weight_scale_args = { + "data": torch.empty( + input_groups, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + "weight_loader": weight_loader, + } + if input_groups == 1: + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) + else: + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // self.quant_config.min_n_threads + ) * self.quant_config.max_parallel + + workspace = BasevLLMParameter( + data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int), + weight_loader=weight_loader, + ) + + layer.register_parameter("B", qweight) + layer.register_parameter("s", 
scales) + layer.register_parameter("workspace", workspace) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B = torch.nn.Parameter(layer.B.data, requires_grad=False) + layer.s = torch.nn.Parameter(layer.s.data, requires_grad=False) + layer.workspace = torch.nn.Parameter(layer.workspace.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B + scales = layer.s + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.marlin_gemm( + x_2d, qweight, scales, workspace, size_m, size_n, size_k + ) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],)) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py index 0bae43435f07..fe812595a80b 100644 --- a/python/sglang/srt/layers/quantization/moe_wna16.py +++ b/python/sglang/srt/layers/quantization/moe_wna16.py @@ -19,6 +19,36 @@ logger = logging.getLogger(__name__) +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + class 
MoeWNA16Config(QuantizationConfig): """Config class for MOE WNA16 (W8A16/W4A16) quantization.""" diff --git a/python/sglang/srt/layers/quantization/quant_utils.py b/python/sglang/srt/layers/quantization/quant_utils.py deleted file mode 100644 index 59a1b1fdcfa6..000000000000 --- a/python/sglang/srt/layers/quantization/quant_utils.py +++ /dev/null @@ -1,166 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py - -from typing import Optional - -import numpy -import torch -from sgl_kernel.scalar_type import ScalarType - - -def get_pack_factor(num_bits): - assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - 
q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: Optional[int], - zero_points: bool = False, - ref_zero_points_after_scales: bool = False, -): - assert ( - quant_type.is_integer() - ), "Floating point quantization may work but has not been tested" - assert not zero_points or group_size is not None, ( - "to have group zero points, group_size must be provided " - "(-1 group_size is channelwise)" - ) - - orig_device = w.device - orig_type = w.dtype - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - - if group_size == -1: - group_size = size_k - - # Reshape to [groupsize, -1] - if group_size is not None and group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max_val = torch.max(w, 0, keepdim=True).values - min_val = torch.min(w, 0, keepdim=True).values - - max_q_val = quant_type.max() - min_q_val = quant_type.min() - - w_s = torch.Tensor([1.0]).to(w.device) # unscaled case - maybe_w_zp = None - if group_size is not None: - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = ( - torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() - ) - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), - ) - - # Quantize - w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) - w_q = torch.clamp(w_q, min_q_val, max_q_val) - - # Compute ref (dequantized) - # For some kernels (namely Machete) the zero-points are 
applied after the - # scales are applied, for this case computing the reference in similar way - # allows us to use tighter error tolerances in our unit tests. - if ref_zero_points_after_scales and maybe_w_zp is not None: - w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s - else: - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s - - if quant_type.has_bias(): - w_q += quant_type.bias - - # Restore original shapes - if group_size is not None and group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - w_q = reshape_w(w_q) - w_ref = reshape_w(w_ref) - w_s = w_s.reshape((-1, size_n)).contiguous() - - if maybe_w_zp is not None: - maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() - maybe_w_zp = maybe_w_zp.to(device=orig_device) - - return ( - w_ref.to(device=orig_device), - w_q.to(device=orig_device), - w_s if group_size is not None else None, - maybe_w_zp, - ) diff --git a/sgl-kernel/python/sgl_kernel/scalar_type.py b/python/sglang/srt/layers/quantization/scalar_type.py similarity index 100% rename from sgl-kernel/python/sgl_kernel/scalar_type.py rename to python/sglang/srt/layers/quantization/scalar_type.py diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 40a381f3b9f8..2371208f7895 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -1,11 +1,13 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py from types import MappingProxyType -from typing import List, Mapping, Tuple, Union +from typing import List, Mapping, Optional, Tuple, Union +import numpy import torch from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant +from sglang.srt.layers.quantization.scalar_type import ScalarType from 
sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu _is_cuda = is_cuda() @@ -143,3 +145,162 @@ def replace_parameter( if not isinstance(new, torch.nn.Parameter): new = torch.nn.Parameter(new, requires_grad=False) mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False)) + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py +def quantize_weights( + w: torch.Tensor, + quant_type: 
ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) diff --git a/sgl-kernel/python/sgl_kernel/fused_moe.py b/sgl-kernel/python/sgl_kernel/fused_moe.py index f9322e22824a..f825131ac254 100644 --- a/sgl-kernel/python/sgl_kernel/fused_moe.py +++ b/sgl-kernel/python/sgl_kernel/fused_moe.py @@ -2,10 +2,11 @@ from typing import Optional import torch -from sgl_kernel.scalar_type import scalar_types def get_scalar_type(num_bits: int, has_zp: bool): + from sglang.srt.layers.quantization.scalar_type import scalar_types + if has_zp: assert num_bits == 4 return scalar_types.uint4 diff --git a/sgl-kernel/tests/test_marlin_repack.py b/sgl-kernel/tests/test_marlin_repack.py index c0f13f46bea0..c229ae1cd01e 100644 --- a/sgl-kernel/tests/test_marlin_repack.py +++ b/sgl-kernel/tests/test_marlin_repack.py @@ -1,12 +1,10 @@ -import math - import numpy as np import pytest import torch from sgl_kernel import awq_marlin_repack -from sgl_kernel.scalar_type import scalar_types -from sglang.srt.layers.quantization.quant_utils import ( +from sglang.srt.layers.quantization.scalar_type import scalar_types +from sglang.srt.layers.quantization.utils import ( get_pack_factor, pack_cols, 
quantize_weights, diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py index 284465b8b39e..feda8693459e 100644 --- a/test/srt/test_gptqmodel_dynamic.py +++ b/test/srt/test_gptqmodel_dynamic.py @@ -51,13 +51,12 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool): model_config=model_config, load_config=load_config, device_config=device_config ) - from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod - from vllm.model_executor.layers.quantization.gptq_marlin import ( + from sglang.srt.layers.linear import UnquantizedLinearMethod + from sglang.srt.layers.quantization.gptq import ( + GPTQLinearMethod, GPTQMarlinLinearMethod, ) - from sglang.srt.layers.linear import UnquantizedLinearMethod - linear_method_cls = ( GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod) ) @@ -162,7 +161,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--dtype", "float16"], + other_args=["--dtype", "bfloat16"], ) @classmethod diff --git a/test/srt/test_int4_kernel.py b/test/srt/test_int4_kernel.py deleted file mode 100644 index 0665f9b91a56..000000000000 --- a/test/srt/test_int4_kernel.py +++ /dev/null @@ -1,301 +0,0 @@ -import itertools -import sys -import unittest - -import torch - -sys.path.insert(0, "/home/hadoop-hmart-waimai-rank/vllm") - -# from sglang.srt.layers.moe.topk import select_experts -from sgl_kernel import fused_marlin_moe -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk - -# from vllm.model_executor.layers. 
import select_experts -from vllm.model_executor.layers.fused_moe.layer import FusedMoE -from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - marlin_quantize, -) -from vllm.scalar_type import scalar_types - - -def stack_and_dev(tensors: list[torch.Tensor]): - dev = tensors[0].device - return torch.stack(tensors, dim=0).to(dev) - - -def torch_moe(a, w1, w2, score, topk, expert_map): - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - if expert_map is not None: - topk_ids = expert_map[topk_ids] - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose( - 0, 1 - ) - return ( - out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) - ).sum(dim=1) - - -def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): - """Matrix multiplication function that supports per-token input quantization and per-column weight quantization""" - A = A.to(torch.float32) - B = B.to(torch.float32) - - assert A.shape[-1] == B.shape[-1], "Dimension mismatch" - assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" - - # Reshape input - M = A.numel() // A.shape[-1] - B = B.t() # Transpose weight matrix - N, K = B.shape - origin_C_shape = A.shape[:-1] + (K,) - A = A.reshape(M, N) - # As is per-token [M, 1], Bs is per-column [1, K] - C = torch.matmul(A, B) # [M, K] - C = As * C * Bs.view(1, -1) # Broadcast per-column scale - - return C.reshape(origin_C_shape).to(output_dtype) - - -def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): - """This function performs fused moe with per-column int8 quantization using native torch.""" - - B, D = 
a.shape - # Perform per-token quantization - a_q, a_s = per_token_quant_int8(a) - # Repeat tokens to match topk - a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - # Also repeat the scale - a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1] - - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - - # Calculate routing - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - # Process each expert - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - # First MLP layer: note that a_s is now per-token - inter_out = native_w8a8_per_token_matmul( - a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype - ) - # Activation function - act_out = SiluAndMul().forward_native(inter_out) - # Quantize activation output with per-token - act_out_q, act_out_s = per_token_quant_int8(act_out) - - # Second MLP layer - out[mask] = native_w8a8_per_token_matmul( - act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype - ) - # Apply routing weights and sum - return ( - out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) - ).sum(dim=1) - - -def marlin_fused_moe( - N, E, K, a, w1, w2, num_bits, group_size, act_order, score, topk, ep_size -): - quant_type = scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 - if ep_size > 1: - local_e = E // ep_size - e_ids = torch.randperm(E, device="cuda", dtype=torch.int32)[:local_e] - e_map = torch.full((E,), -1, device="cuda", dtype=torch.int32) - e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) - w1 = w1[e_ids] - w2 = w2[e_ids] - else: - e_map = None - w_ref1_l = [] - qweight1_l = [] - scales1_l = [] - zeros1_l = [] - g_idx1_l = [] - sort_indices1_l = [] - s1_l = [] - for i in range(w1.shape[0]): - test_perm = torch.randperm(n=K) - quant_res = marlin_quantize( - w1[i].transpose(1, 0), quant_type, 
group_size, act_order, test_perm - ) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - g_idx1_l.append(g_idx1) - sort_indices1_l.append(sort_indices1) - w_ref1 = stack_and_dev(w_ref1_l) - qweight1 = stack_and_dev(qweight1_l).contiguous() - scales1 = stack_and_dev(scales1_l) - g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None - zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None - sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None - - w_ref2_l = [] - qweight2_l = [] - scales2_l = [] - zeros2_l = [] - g_idx2_l = [] - sort_indices2_l = [] - for i in range(w2.shape[0]): - test_perm = torch.randperm(n=N) - quant_res = marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm - ) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res - - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - g_idx2_l.append(g_idx2) - sort_indices2_l.append(sort_indices2) - - w_ref2 = stack_and_dev(w_ref2_l) - qweight2 = stack_and_dev(qweight2_l).contiguous() - scales2 = stack_and_dev(scales2_l) - g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None - zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None - sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None - - topk_weights, topk_ids = fused_topk(a, score, topk, False) - # topk_weights, topk_ids = FusedMoE.select_experts( - # hidden_states=a, - # router_logits=score, - # top_k=topk, - # num_expert_group=E, - # use_grouped_topk=False, - # renormalize=False, - # topk_group=None, - # ) - - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) - marlin_output = fused_marlin_moe( - a, - qweight1, - qweight2, - scales1, - scales2, - score, - topk_weights, - topk_ids, - global_num_experts=E, - expert_map=e_map, - g_idx1=g_idx1, - g_idx2=g_idx2, - sort_indices1=sort_indices1, - sort_indices2=sort_indices2, 
- w1_zeros=zeros1, - w2_zeros=zeros2, - num_bits=num_bits, - is_k_full=True, - ) - return marlin_output, torch_output - - -class TestW8A8Int8FusedMoE(unittest.TestCase): - DTYPES = [torch.float16] - M = [1, 16] - N = [128] - K = [256] - E = [4, 10] - TOP_KS = [2, 4] - BLOCK_SIZE = [[128, 128]] - SEEDS = [0] - NUM_BITS = [4] - EP_SIZE = [1, 4] - - @classmethod - def setUpClass(cls): - if not torch.cuda.is_available(): - raise unittest.SkipTest("CUDA is not available") - torch.set_default_device("cuda") - - def _w4a8_int8_fused_moe( - self, M, N, K, E, topk, block_size, dtype, seed, num_bits, ep_size - ): - torch.manual_seed(seed) - a = torch.randn((M, K), dtype=dtype) / 10 - - # Generate int8 weights - w1_fp16 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 - w2_fp16 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 - - score = torch.randn((M, E), dtype=dtype) - - with torch.inference_mode(): - marlin_out, ref_out = marlin_fused_moe( - N=N, - E=E, - K=K, - a=a, - w1=w1_fp16, - w2=w2_fp16, - num_bits=num_bits, - group_size=-1, - act_order=False, - score=score, - topk=topk, - ep_size=ep_size, - ) - # Check results - if ( - torch.mean( - torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32)) - ) - / torch.mean(torch.abs(ref_out.to(torch.float32))) - > 0.1 - ): - print(f"marlin_out: {marlin_out}") - print(f"ref_out: {ref_out}") - print( - torch.mean( - torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32)) - ) - / torch.mean(torch.abs(ref_out.to(torch.float32))) - ) - torch.testing.assert_close(marlin_out, ref_out, atol=2e-2, rtol=0) - - def test_w4a8_int8_fused_moe(self): - for params in itertools.product( - self.M, - self.N, - self.K, - self.E, - self.TOP_KS, - self.BLOCK_SIZE, - self.DTYPES, - self.SEEDS, - self.NUM_BITS, - self.EP_SIZE, - ): - with self.subTest( - M=params[0], - N=params[1], - K=params[2], - E=params[3], - topk=params[4], - block_size=params[5], - dtype=params[6], - seed=params[7], - num_bits=params[8], - 
ep_size=params[9], - ): - self._w4a8_int8_fused_moe(*params) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/test/srt/test_w4a8.py b/test/srt/test_w4a8.py deleted file mode 100644 index 75d41ee5f8e2..000000000000 --- a/test/srt/test_w4a8.py +++ /dev/null @@ -1,14 +0,0 @@ -import sgl_kernel -import torch - -x = torch.randn(10, 10, device="cuda") -qweight = torch.randn(10, 10, device="cuda") -s1_scales = torch.randn(10, device="cuda") -input_scales = torch.randn(10, device="cuda") -s1_szeros = torch.randn(10, device="cuda") -input_sum = torch.randn(10, device="cuda") -output_buffer = torch.randn(10, device="cuda") - -torch.ops.sgl_kernel.gemm_forward_cuda.default( - x, qweight, s1_scales, input_scales, s1_szeros, input_sum, output_buffer -) From 4395c87a9b831672823c40884348620e641f6559 Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 17 Jul 2025 08:52:38 +0800 Subject: [PATCH 012/396] refactor: unify names of the feature field of MultimodalDataItem (#8075) --- .../multimodal_processors/qwen_audio.py | 2 +- python/sglang/srt/managers/schedule_batch.py | 41 +++++++------------ python/sglang/srt/models/clip.py | 2 +- .../sglang/srt/models/deepseek_janus_pro.py | 2 +- python/sglang/srt/models/deepseek_vl2.py | 4 +- python/sglang/srt/models/gemma3_mm.py | 2 +- python/sglang/srt/models/gemma3n_mm.py | 6 +-- python/sglang/srt/models/internvl.py | 2 +- python/sglang/srt/models/kimi_vl.py | 2 +- python/sglang/srt/models/llava.py | 4 +- python/sglang/srt/models/llavavid.py | 2 +- python/sglang/srt/models/minicpmo.py | 10 ++--- python/sglang/srt/models/minicpmv.py | 2 +- python/sglang/srt/models/mistral.py | 2 +- python/sglang/srt/models/mllama.py | 8 ++-- python/sglang/srt/models/mllama4.py | 2 +- python/sglang/srt/models/phi4mm.py | 4 +- python/sglang/srt/models/qwen2_5_vl.py | 8 ++-- python/sglang/srt/models/qwen2_audio.py | 2 +- python/sglang/srt/models/qwen2_vl.py | 8 ++-- python/sglang/srt/models/vila.py | 2 +- 
.../multimodal/processors/base_processor.py | 8 +++- .../sglang/srt/multimodal/processors/clip.py | 2 +- .../multimodal/processors/deepseek_vl_v2.py | 2 +- .../srt/multimodal/processors/internvl.py | 2 +- .../srt/multimodal/processors/janus_pro.py | 2 +- .../sglang/srt/multimodal/processors/llava.py | 2 +- .../srt/multimodal/processors/minicpm.py | 4 +- .../sglang/srt/multimodal/processors/mlama.py | 2 +- .../srt/multimodal/processors/mllama4.py | 2 +- .../srt/multimodal/processors/phi4mm.py | 2 +- .../srt/multimodal/processors/pixtral.py | 2 +- test/srt/test_vlm_accuracy.py | 2 +- 33 files changed, 66 insertions(+), 83 deletions(-) diff --git a/python/sglang/srt/managers/multimodal_processors/qwen_audio.py b/python/sglang/srt/managers/multimodal_processors/qwen_audio.py index 0558b5f5a4b8..23b7de5cfd96 100644 --- a/python/sglang/srt/managers/multimodal_processors/qwen_audio.py +++ b/python/sglang/srt/managers/multimodal_processors/qwen_audio.py @@ -78,7 +78,7 @@ async def process_mm_data_async( output_lengths = (input_lengths - 2) // 2 + 1 item = MultimodalDataItem( - audio_features=res["input_features"], + feature=res["input_features"], audio_feature_lens=output_lengths, audio_offsets=audio_offsets, modality=Modality.AUDIO, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c2750d072457..01da558b7bf9 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -207,13 +207,12 @@ class MultimodalDataItem: modality: Modality hash: int = None pad_value: int = None - image_sizes: Tuple[int, int] = None offsets: Optional[list] = None + # the raw features returned by processor, e.g. 
pixel_values or audio_features + feature: Union[torch.Tensor, np.ndarray] = None + + image_sizes: Tuple[int, int] = None - # the real data, pixel_values or audio_features - # data: Union[List[torch.Tensor], List[np.ndarray]] - pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None - audio_features: Union[torch.Tensor, np.ndarray] = None audio_feature_lens: Optional[List[torch.Tensor]] = None audio_offsets: Optional[List[Tuple[int, int]]] = None precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None @@ -238,7 +237,6 @@ class MultimodalDataItem: image_grid_hws: Optional[List[torch.Tensor]] = None # For gemma3n - input_features: Optional[torch.Tensor] = None input_features_mask: Optional[torch.Tensor] = None @staticmethod @@ -254,18 +252,11 @@ def set_pad_value(self): from sglang.srt.managers.mm_utils import hash_feature if self.hash is None: - if self.precomputed_features is not None: - self.hash = hash_feature(self.precomputed_features) - elif self.is_audio(): - if self.audio_features is not None: - self.hash = hash_feature(self.audio_features) - elif self.input_features is not None: - self.hash = hash_feature(self.input_features) - elif self.is_video(): - self.hash = hash_feature(self.pixel_values_videos) + if self.feature is not None: + hashed_feature = self.feature else: - self.hash = hash_feature(self.pixel_values) - + hashed_feature = self.precomputed_features + self.hash = hash_feature(hashed_feature) assert self.hash is not None self.pad_value = self.hash % (1 << 30) @@ -275,8 +266,7 @@ def is_modality(self, modality: Modality) -> bool: def is_audio(self): return (self.modality == Modality.AUDIO) and ( self.precomputed_features is not None - or not MultimodalDataItem.is_empty_list(self.audio_features) - or not MultimodalDataItem.is_empty_list(self.input_features) + or not MultimodalDataItem.is_empty_list(self.feature) ) def is_image(self): @@ -284,13 +274,13 @@ def is_image(self): self.is_modality(Modality.IMAGE) or 
self.is_modality(Modality.MULTI_IMAGES) ) and ( self.precomputed_features is not None - or not MultimodalDataItem.is_empty_list(self.pixel_values) + or not MultimodalDataItem.is_empty_list(self.feature) ) def is_video(self): return (self.modality == Modality.VIDEO) and ( self.precomputed_features is not None - or not MultimodalDataItem.is_empty_list(self.pixel_values_videos) + or not MultimodalDataItem.is_empty_list(self.feature) ) def is_valid(self) -> bool: @@ -311,7 +301,7 @@ def from_dict(obj: dict): return ret def merge(self, other): - self.pixel_values += other.pixel_values + self.feature += other.feature self.image_sizes += other.image_sizes self.image_offsets += other.image_offsets self.hash = hash((self.hash, other.hash)) @@ -354,7 +344,6 @@ def from_dict(obj: dict): assert isinstance(ret.mm_items, list) ret.mm_items = [item for item in ret.mm_items if item.is_valid()] - for item in ret.mm_items: item.set_pad_value() @@ -1278,11 +1267,9 @@ def prepare_for_extend(self): if mm_input is None: continue for mm_item in mm_input.mm_items: - pixel_values = getattr(mm_item, "pixel_values", None) + pixel_values = getattr(mm_item, "feature", None) if isinstance(pixel_values, torch.Tensor): - mm_item.pixel_values = pixel_values.to( - self.device, non_blocking=True - ) + mm_item.feature = pixel_values.to(self.device, non_blocking=True) self.multimodal_inputs = multimodal_inputs self.token_type_ids = token_type_ids_tensor self.seq_lens_sum = sum(seq_lens) diff --git a/python/sglang/srt/models/clip.py b/python/sglang/srt/models/clip.py index f271b45a4d11..ea9fee9ac29e 100644 --- a/python/sglang/srt/models/clip.py +++ b/python/sglang/srt/models/clip.py @@ -463,7 +463,7 @@ def forward( if forward_batch.mm_inputs is not None: mm_inputs = forward_batch.mm_inputs pixel_values_list = [ - item.pixel_values + item.feature for item in flatten_nested_list( [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None] ) diff --git 
a/python/sglang/srt/models/deepseek_janus_pro.py b/python/sglang/srt/models/deepseek_janus_pro.py index 8d266a3be6d8..fe1c833f7224 100644 --- a/python/sglang/srt/models/deepseek_janus_pro.py +++ b/python/sglang/srt/models/deepseek_janus_pro.py @@ -1960,7 +1960,7 @@ def __init__( self.logits_processor = LogitsProcessor(config) def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: - pixel_values = torch.concat([item.pixel_values for item in items], dim=0) + pixel_values = torch.concat([item.feature for item in items], dim=0) bs, n = pixel_values.shape[0:2] pixel_values = pixel_values.to( device=self.vision_model.device, dtype=self.vision_model.dtype diff --git a/python/sglang/srt/models/deepseek_vl2.py b/python/sglang/srt/models/deepseek_vl2.py index 9941927cd65e..cf4988b5201b 100644 --- a/python/sglang/srt/models/deepseek_vl2.py +++ b/python/sglang/srt/models/deepseek_vl2.py @@ -268,9 +268,9 @@ def get_image_feature(self, items: List[MultimodalDataItem]): # TODO: can it be batched ? images_in_this_batch = [] for item in items: - assert item.pixel_values.dim() == 4 + assert item.feature.dim() == 4 image_feature = self.vision.forward_features( - item.pixel_values.type(next(self.vision.parameters()).dtype).to( + item.feature.type(next(self.vision.parameters()).dtype).to( device=next(self.vision.parameters()).device ) ) diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py index 93c145e1b54b..527a11b691e2 100644 --- a/python/sglang/srt/models/gemma3_mm.py +++ b/python/sglang/srt/models/gemma3_mm.py @@ -283,7 +283,7 @@ def get_image_feature(self, items: List[MultimodalDataItem]): image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
""" # Process images one by one to handle flatten_batch=True constraint in vision_tower - all_pixel_values = flatten_nested_list([item.pixel_values for item in items]) + all_pixel_values = flatten_nested_list([item.feature for item in items]) vision_outputs_list = [] for pixel_values_batch in all_pixel_values: diff --git a/python/sglang/srt/models/gemma3n_mm.py b/python/sglang/srt/models/gemma3n_mm.py index 3bc327ea3e97..5139a9c2ded5 100644 --- a/python/sglang/srt/models/gemma3n_mm.py +++ b/python/sglang/srt/models/gemma3n_mm.py @@ -265,7 +265,7 @@ def get_image_feature(self, items: List[MultimodalDataItem]): image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ # Process images one by one to handle flatten_batch=True constraint in vision_tower - all_pixel_values = flatten_nested_list([item.pixel_values for item in items]) + all_pixel_values = flatten_nested_list([item.feature for item in items]) vision_outputs_list = [] for pixel_values_batch in all_pixel_values: @@ -316,9 +316,7 @@ def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`). 
""" # Extract audio features and masks from items - all_input_features = flatten_nested_list( - [item.input_features for item in items] - ) + all_input_features = flatten_nested_list([item.feature for item in items]) all_input_features_mask = flatten_nested_list( [~item.input_features_mask for item in items] ) # Note(Xinyuan): reverse the mask according to the HF implementation diff --git a/python/sglang/srt/models/internvl.py b/python/sglang/srt/models/internvl.py index 732752317400..056797cbfe00 100644 --- a/python/sglang/srt/models/internvl.py +++ b/python/sglang/srt/models/internvl.py @@ -510,7 +510,7 @@ def get_image_feature(self, items: List[MultimodalDataItem]): Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ - pixel_values = torch.cat([item.pixel_values for item in items]) + pixel_values = torch.cat([item.feature for item in items]) image_features = self.extract_feature(pixel_values) return image_features diff --git a/python/sglang/srt/models/kimi_vl.py b/python/sglang/srt/models/kimi_vl.py index f4386a80882f..68ed47b2ef0f 100644 --- a/python/sglang/srt/models/kimi_vl.py +++ b/python/sglang/srt/models/kimi_vl.py @@ -144,7 +144,7 @@ def __init__( def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: pixel_values = ( - torch.cat([item.pixel_values for item in items], dim=0) + torch.cat([item.feature for item in items], dim=0) .type(self.vision_tower.dtype) .to(self.vision_tower.device) ) diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index b0b82a82b770..6375657e77a6 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -186,7 +186,7 @@ def forward( bs = forward_batch.batch_size pixel_values = flatten_nested_list( [ - [item.pixel_values for item in image_inputs[i].mm_items] + [item.feature for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i] ] @@ -753,7 +753,7 @@ def 
get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: features = [] for item in items: # in each item, we assume pixel_values is always batched - pixel_values, image_sizes = item.pixel_values, item.image_sizes + pixel_values, image_sizes = item.feature, item.image_sizes image_outputs = self.vision_tower( pixel_values, image_sizes, output_hidden_states=True ) diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 22a007e128ad..e5d6aa72ba9a 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -135,7 +135,7 @@ def forward( if need_vision.any(): pixel_values = flatten_nested_list( [ - [item.pixel_values for item in image_inputs[i].mm_items] + [item.feature for item in image_inputs[i].mm_items] for i in range(bs) if need_vision[i] ] diff --git a/python/sglang/srt/models/minicpmo.py b/python/sglang/srt/models/minicpmo.py index a5234772eaf1..786738ecaa21 100644 --- a/python/sglang/srt/models/minicpmo.py +++ b/python/sglang/srt/models/minicpmo.py @@ -1552,9 +1552,7 @@ def get_audio_embedding_streaming(self, items: List[MultimodalDataItem]): Returns: List[List[torch.Tensor]]: audio embeddings """ - wavforms = flatten_nested_list( - [item.audio_features for item in items if item.audio_features] - ) + wavforms = flatten_nested_list([item.feature for item in items if item.feature]) # list, [[x1, x2], [y1], [z1]] audio_feature_lens_raw = flatten_nested_list( [item.audio_feature_lens for item in items if item.audio_feature_lens] @@ -1659,9 +1657,7 @@ def get_audio_embedding(self, items: List[MultimodalDataItem], chunk_length=-1): List[List[torch.Tensor]]: audio embeddings """ # (bs, 80, frames) or [], multi audios need filled in advance - wavforms = flatten_nested_list( - [item.audio_features for item in items if item.audio_features] - ) + wavforms = flatten_nested_list([item.feature for item in items if item.feature]) # list, [[x1, x2], [y1], [z1]] audio_feature_lens_raw = 
flatten_nested_list( [item.audio_feature_lens for item in items if item.audio_feature_lens] @@ -1778,7 +1774,7 @@ def get_omni_embedding( def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: # list of tensors - pixel_values = flatten_nested_list([item.pixel_values for item in items]) + pixel_values = flatten_nested_list([item.feature for item in items]) tgt_sizes = torch.stack( flatten_nested_list([item.tgt_size for item in items]), dim=0 ) diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py index 0c6d4297fb9d..8166d1646ad9 100644 --- a/python/sglang/srt/models/minicpmv.py +++ b/python/sglang/srt/models/minicpmv.py @@ -724,7 +724,7 @@ def get_vision_embedding( def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: # list of tensors - pixel_values = flatten_nested_list([item.pixel_values for item in items]) + pixel_values = flatten_nested_list([item.feature for item in items]) tgt_sizes = torch.stack( flatten_nested_list([item.tgt_size for item in items]), dim=0 ) diff --git a/python/sglang/srt/models/mistral.py b/python/sglang/srt/models/mistral.py index d3d2efcaee94..632e857c280b 100644 --- a/python/sglang/srt/models/mistral.py +++ b/python/sglang/srt/models/mistral.py @@ -56,7 +56,7 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: features = [] for item in items: # in each item, we assume pixel_values is always batched - pixel_values, image_sizes = item.pixel_values, item.image_sizes + pixel_values, image_sizes = item.feature, item.image_sizes image_outputs = self.vision_tower( pixel_values, image_sizes, output_hidden_states=True ) diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index fed9e4b59a16..fa294ddcd0c4 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -838,9 +838,7 @@ def __init__( self.logits_processor = LogitsProcessor(config.text_config) def 
pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): - pixel_values = torch.cat( - [item.pixel_values for item in mm_inputs.mm_items], dim=0 - ) + pixel_values = torch.cat([item.feature for item in mm_inputs.mm_items], dim=0) pad_values = [item.pad_value for item in mm_inputs.mm_items] num_concurrent_media, num_tiles = pixel_values.shape[1:3] @@ -862,7 +860,7 @@ def _batch_image_inputs(self, forward_batch: ForwardBatch): if not forward_batch.encoder_cached[i] and mm_input is not None: pixel_values = torch.cat( - [item.pixel_values for item in mm_input.mm_items], dim=0 + [item.feature for item in mm_input.mm_items], dim=0 ) max_num_images = max(max_num_images, pixel_values.shape[1]) @@ -897,7 +895,7 @@ def _batch_image_inputs(self, forward_batch: ForwardBatch): encoder_lens_need.append(forward_batch.encoder_lens[k]) pixel_values = torch.cat( - [item.pixel_values for item in mm_input.mm_items], dim=0 + [item.feature for item in mm_input.mm_items], dim=0 ) for j in range(pixel_values.shape[1]): img = pixel_values[0, j] diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index 55e793247584..18b7e57e5872 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -147,7 +147,7 @@ def get_image_feature( raise ValueError("Vision model not available for text-only checkpoint") pixel_values = ( - torch.concat([item.pixel_values for item in items]) + torch.concat([item.feature for item in items]) .to(next(self.vision_model.parameters()).device) .type(next(self.vision_model.parameters()).dtype) ) diff --git a/python/sglang/srt/models/phi4mm.py b/python/sglang/srt/models/phi4mm.py index 44bcad97a81e..8a74888ac9c5 100644 --- a/python/sglang/srt/models/phi4mm.py +++ b/python/sglang/srt/models/phi4mm.py @@ -422,9 +422,7 @@ def __init__( def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: dtype = next(self.vision_encoder.parameters()).dtype - pixel_values = 
torch.cat([item.pixel_values for item in items], dim=0).type( - dtype - ) + pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype) image_attention_mask = torch.cat([item.image_emb_mask for item in items], dim=0) image_sizes = torch.cat([item.image_sizes for item in items], dim=0) image_embeds = self.vision_encoder( diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index d4f412e49582..d2a92217a315 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -497,7 +497,7 @@ def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: # in qwen-vl, last dim is the same - pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type( + pixel_values = torch.cat([item.feature for item in items], dim=0).type( self.visual.dtype ) image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0) @@ -508,9 +508,9 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: # in qwen-vl, last dim is the same - pixel_values = torch.cat( - [getattr(item, "pixel_values_videos") for item in items], dim=0 - ).type(self.visual.dtype) + pixel_values = torch.cat([item.feature for item in items], dim=0).type( + self.visual.dtype + ) video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0) assert pixel_values.dim() == 2, pixel_values.dim() assert video_grid_thw.dim() == 2, video_grid_thw.dim() diff --git a/python/sglang/srt/models/qwen2_audio.py b/python/sglang/srt/models/qwen2_audio.py index 53e087496242..bc232f0bee15 100644 --- a/python/sglang/srt/models/qwen2_audio.py +++ b/python/sglang/srt/models/qwen2_audio.py @@ -118,7 +118,7 @@ def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): def get_audio_feature(self, items: 
List[MultimodalDataItem]) -> torch.Tensor: # Extract audio features from input items - input_features = torch.cat([item.audio_features for item in items], dim=0).type( + input_features = torch.cat([item.feature for item in items], dim=0).type( self.audio_tower.dtype ) diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 59179752a7e2..55f325813782 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -484,7 +484,7 @@ def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: # in qwen-vl, last dim is the same - pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type( + pixel_values = torch.cat([item.feature for item in items], dim=0).type( self.visual.dtype ) image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0) @@ -495,9 +495,9 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: # in qwen-vl, last dim is the same - pixel_values = torch.cat( - [item.pixel_values_videos for item in items], dim=0 - ).type(self.visual.dtype) + pixel_values = torch.cat([item.feature for item in items], dim=0).type( + self.visual.dtype + ) video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0) assert pixel_values.dim() == 2, pixel_values.dim() assert video_grid_thw.dim() == 2, video_grid_thw.dim() diff --git a/python/sglang/srt/models/vila.py b/python/sglang/srt/models/vila.py index 752217d674f6..2bb0b2d35d9e 100644 --- a/python/sglang/srt/models/vila.py +++ b/python/sglang/srt/models/vila.py @@ -237,7 +237,7 @@ def forward( return cast(LogitsProcessorOutput, output) def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor: - pixel_values = cast(Tensor, mm_input[0].pixel_values) + pixel_values = cast(Tensor, 
mm_input[0].feature) ##### BEGIN COPY modeling_vila.py ##### diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 91aaa19090cf..44e22885caec 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -5,7 +5,6 @@ import os import re from abc import ABC, abstractmethod -from functools import lru_cache from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -156,6 +155,10 @@ def __init__(self, hf_config, server_args, _processor): # "precomputed_features" - handled specially as it can be any modality } + # name of the feature filed + # TODO: pass from processors + self.FEATURE_NAMES = ["pixel_values", "pixel_values_videos", "audio_features"] + def process_mm_data( self, input_text, images=None, videos=None, audios=None, **kwargs ): @@ -524,6 +527,9 @@ def collect_mm_items_from_processor_output( if modality not in items: items[modality] = MultimodalDataItem(modality=modality) + if attr_name in self.FEATURE_NAMES: + attr_name = "feature" + # Set attribute setattr(items[modality], attr_name, value) diff --git a/python/sglang/srt/multimodal/processors/clip.py b/python/sglang/srt/multimodal/processors/clip.py index cda5edf89525..a36269819c42 100644 --- a/python/sglang/srt/multimodal/processors/clip.py +++ b/python/sglang/srt/multimodal/processors/clip.py @@ -26,7 +26,7 @@ async def process_mm_data_async( image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] image_inputs["mm_items"] = [ MultimodalDataItem( - pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE + feature=image_inputs["pixel_values"], modality=Modality.IMAGE ) ] diff --git a/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py b/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py index 0ffd91dc3237..50547ad2d714 100644 --- a/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py +++ 
b/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py @@ -68,7 +68,7 @@ async def process_mm_data_async( input_ids=input_ids, mm_token_id=self._processor.image_token_id ) item = MultimodalDataItem( - pixel_values=res["images"], + feature=res["images"], offsets=image_offsets, modality=Modality.IMAGE, image_emb_mask=images_seq_mask, diff --git a/python/sglang/srt/multimodal/processors/internvl.py b/python/sglang/srt/multimodal/processors/internvl.py index df9b67aadeae..f9ed9ba76d86 100644 --- a/python/sglang/srt/multimodal/processors/internvl.py +++ b/python/sglang/srt/multimodal/processors/internvl.py @@ -223,7 +223,7 @@ def process_image_internvl(image, input_size=448, max_num=12): ) items = [ MultimodalDataItem( - pixel_values=pixel_values, + feature=pixel_values, modality=Modality.IMAGE, offsets=image_offsets, ) diff --git a/python/sglang/srt/multimodal/processors/janus_pro.py b/python/sglang/srt/multimodal/processors/janus_pro.py index 36be9ded80ab..8ea013d29aae 100644 --- a/python/sglang/srt/multimodal/processors/janus_pro.py +++ b/python/sglang/srt/multimodal/processors/janus_pro.py @@ -47,7 +47,7 @@ async def process_mm_data_async( return { "mm_items": [ MultimodalDataItem( - pixel_values=res["pixel_values"], + feature=res["pixel_values"], image_emb_mask=res["images_emb_mask"], offsets=image_offsets, modality=Modality.IMAGE, diff --git a/python/sglang/srt/multimodal/processors/llava.py b/python/sglang/srt/multimodal/processors/llava.py index d32398d85b9a..03c4bf5ec634 100644 --- a/python/sglang/srt/multimodal/processors/llava.py +++ b/python/sglang/srt/multimodal/processors/llava.py @@ -158,7 +158,7 @@ async def process_mm_data_async( return { "mm_items": [ MultimodalDataItem( - pixel_values=pixel_values, + feature=pixel_values, image_sizes=image_sizes, modality=modality, ) diff --git a/python/sglang/srt/multimodal/processors/minicpm.py b/python/sglang/srt/multimodal/processors/minicpm.py index 7945f20b5f50..369971ccbe53 100644 --- 
a/python/sglang/srt/multimodal/processors/minicpm.py +++ b/python/sglang/srt/multimodal/processors/minicpm.py @@ -114,7 +114,7 @@ async def process_mm_data_async( if len(pixel_values) != 0: item = MultimodalDataItem( - pixel_values=pixel_values, + feature=pixel_values, offsets=image_offsets, tgt_size=tgt_sizes_flat, modality=Modality.IMAGE, @@ -135,7 +135,7 @@ async def process_mm_data_async( else: audio_offsets = None item = MultimodalDataItem( - audio_features=[res["audio_features"]], + feature=[res["audio_features"]], audio_feature_lens=res["audio_feature_lens"], offsets=audio_offsets, modality=Modality.AUDIO, diff --git a/python/sglang/srt/multimodal/processors/mlama.py b/python/sglang/srt/multimodal/processors/mlama.py index aeb227be2f70..783145027b79 100644 --- a/python/sglang/srt/multimodal/processors/mlama.py +++ b/python/sglang/srt/multimodal/processors/mlama.py @@ -24,7 +24,7 @@ async def process_mm_data_async( image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] image_inputs["mm_items"] = [ MultimodalDataItem( - pixel_values=image_inputs["pixel_values"], + feature=image_inputs["pixel_values"], aspect_ratio_id=image_inputs["aspect_ratio_ids"], aspect_ratio_mask=image_inputs["aspect_ratio_mask"], modality=Modality.IMAGE, diff --git a/python/sglang/srt/multimodal/processors/mllama4.py b/python/sglang/srt/multimodal/processors/mllama4.py index a7988c3557f9..ccf70adc8766 100644 --- a/python/sglang/srt/multimodal/processors/mllama4.py +++ b/python/sglang/srt/multimodal/processors/mllama4.py @@ -142,7 +142,7 @@ async def process_mm_data_async( # Add metadata for image processing processor_output["mm_items"] = [ MultimodalDataItem( - pixel_values=processor_output["pixel_values"], + feature=processor_output["pixel_values"], modality=Modality.IMAGE, offsets=image_offsets, ) diff --git a/python/sglang/srt/multimodal/processors/phi4mm.py b/python/sglang/srt/multimodal/processors/phi4mm.py index fbf2cccb590f..d2e009d27f3e 100644 --- 
a/python/sglang/srt/multimodal/processors/phi4mm.py +++ b/python/sglang/srt/multimodal/processors/phi4mm.py @@ -62,7 +62,7 @@ async def process_mm_data_async( items = [ MultimodalDataItem( - pixel_values=res["input_image_embeds"], + feature=res["input_image_embeds"], image_sizes=res["image_sizes"], image_emb_mask=res["image_attention_mask"], offsets=image_offsets, diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index 9be08cdcc99a..8b741d6279c0 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -103,7 +103,7 @@ async def process_mm_data_async( ) mm_items = [ MultimodalDataItem( - pixel_values=processor_output["pixel_values"], + feature=processor_output["pixel_values"], image_sizes=processor_output["image_sizes"], modality=Modality.IMAGE, offsets=image_offsets, diff --git a/test/srt/test_vlm_accuracy.py b/test/srt/test_vlm_accuracy.py index ea83f3eef755..2f2e294fa0c3 100644 --- a/test/srt/test_vlm_accuracy.py +++ b/test/srt/test_vlm_accuracy.py @@ -245,7 +245,7 @@ async def test_vlm_embedding_output(self): MultimodalInputs( mm_items=[ MultimodalDataItem( - pixel_values=pixel_values_flat, + feature=pixel_values_flat, offsets=image_offsets, tgt_size=tgt_sizes_flat, modality=Modality.IMAGE, From 795668dc73eecc09907b7f25161c53b0bdc3cc43 Mon Sep 17 00:00:00 2001 From: Yingchun Lai Date: Thu, 17 Jul 2025 08:55:59 +0800 Subject: [PATCH 013/396] feat: add tp_rank, pp_rank and dp_rank labels for scheduler metrics (#7597) Co-authored-by: Stefan He --- python/sglang/srt/managers/scheduler.py | 38 ++++++++++++------- .../scheduler_output_processor_mixin.py | 2 +- python/sglang/srt/server_args.py | 8 ++++ 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a7f893253637..ab966f924cc6 100644 --- a/python/sglang/srt/managers/scheduler.py +++ 
b/python/sglang/srt/managers/scheduler.py @@ -252,6 +252,9 @@ def __init__( self.enable_overlap = not server_args.disable_overlap_schedule self.skip_tokenizer_init = server_args.skip_tokenizer_init self.enable_metrics = server_args.enable_metrics + self.enable_metrics_for_all_schedulers = ( + server_args.enable_metrics_for_all_schedulers + ) self.enable_kv_cache_events = server_args.kv_events_config is not None self.stream_interval = server_args.stream_interval self.spec_algorithm = SpeculativeAlgorithm.from_string( @@ -281,9 +284,6 @@ def __init__( self.send_to_tokenizer = get_zmq_socket( context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) - self.send_metrics_from_scheduler = get_zmq_socket( - context, zmq.PUSH, port_args.metrics_ipc_name, False - ) if server_args.skip_tokenizer_init: # Directly send to the TokenizerManager @@ -309,10 +309,14 @@ def __init__( else: self.recv_from_tokenizer = None self.recv_from_rpc = None - self.send_metrics_from_scheduler = None self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None) self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None) + if self.current_scheduler_metrics_enabled(): + self.send_metrics_from_scheduler = get_zmq_socket( + context, zmq.PUSH, port_args.metrics_ipc_name, False + ) + # Init tokenizer self.init_tokenizer() @@ -495,7 +499,7 @@ def __init__( self.init_profier() # Init metrics stats - self.init_metrics() + self.init_metrics(tp_rank, pp_rank, dp_rank) self.init_kv_events(server_args.kv_events_config) # Init request dispatcher @@ -537,6 +541,9 @@ def __init__( if get_bool_env_var("SGLANG_GC_LOG"): configure_gc_logger() + def current_scheduler_metrics_enabled(self): + return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers + def maybe_sleep_on_idle(self): if self.idle_sleeper is not None: self.idle_sleeper.maybe_sleep() @@ -660,7 +667,7 @@ def init_profier(self): self.profile_in_progress: bool = False self.rpd_profiler = None - def init_metrics(self): + def 
init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]): self.last_gen_throughput: float = 0.0 self.last_input_throughput: float = 0.0 self.step_time_dict = defaultdict(list) # Dict[batch size -> step time] @@ -671,12 +678,15 @@ def init_metrics(self): self.stats = SchedulerStats() if self.enable_metrics: engine_type = "unified" - self.metrics_collector = SchedulerMetricsCollector( - labels={ - "model_name": self.server_args.served_model_name, - "engine_type": engine_type, - }, - ) + labels = { + "model_name": self.server_args.served_model_name, + "engine_type": engine_type, + "tp_rank": tp_rank, + "pp_rank": pp_rank, + } + if dp_rank is not None: + labels["dp_rank"] = dp_rank + self.metrics_collector = SchedulerMetricsCollector(labels=labels) def init_kv_events(self, kv_events_config: Optional[str]): if self.enable_kv_cache_events: @@ -1519,7 +1529,7 @@ def check_memory(self): if ( self.enable_metrics - and self.attn_tp_rank == 0 + and self.current_scheduler_metrics_enabled() and time.perf_counter() > self.metrics_collector.last_log_time + 30 ): # During idle time, also collect metrics every 30 seconds. 
@@ -1755,7 +1765,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.chunked_req.is_chunked += 1 # Print stats - if self.attn_tp_rank == 0: + if self.current_scheduler_metrics_enabled(): self.log_prefill_stats(adder, can_run_list, running_bs) # Create a new batch diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 75bc4427a7e5..635121920479 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -290,7 +290,7 @@ def process_batch_result_decode( self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30) if ( - self.attn_tp_rank == 0 + self.current_scheduler_metrics_enabled() and self.forward_ct_decode % self.server_args.decode_log_interval == 0 ): self.log_decode_stats(can_run_cuda_graph, running_batch=batch) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 95ba9bee69e5..e475039d7380 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -105,6 +105,7 @@ class ServerArgs: crash_dump_folder: Optional[str] = None show_time_cost: bool = False enable_metrics: bool = False + enable_metrics_for_all_schedulers: bool = False bucket_time_to_first_token: Optional[List[float]] = None bucket_e2e_request_latency: Optional[List[float]] = None bucket_inter_token_latency: Optional[List[float]] = None @@ -1002,6 +1003,13 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable log prometheus metrics.", ) + parser.add_argument( + "--enable-metrics-for-all-schedulers", + action="store_true", + help="Enable --enable-metrics-for-all-schedulers when you want schedulers on all TP ranks (not just TP 0) " + "to record request metrics separately. 
This is especially useful when dp_attention is enabled, as " + "otherwise all metrics appear to come from TP 0.", + ) parser.add_argument( "--bucket-time-to-first-token", type=float, From 8a7a7770e58b2dfaa67aa49b2e24fc98ddcfd731 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 16 Jul 2025 18:09:28 -0700 Subject: [PATCH 014/396] [ci] limit cmake build nproc (#8100) --- .github/workflows/release-docker-dev.yml | 2 +- docker/Dockerfile | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release-docker-dev.yml b/.github/workflows/release-docker-dev.yml index fcc1d8866bf3..f75f64683bb7 100644 --- a/.github/workflows/release-docker-dev.yml +++ b/.github/workflows/release-docker-dev.yml @@ -41,5 +41,5 @@ jobs: - name: Build and Push Dev Image run: | - docker buildx build --output type=image,compression=zstd . -f docker/Dockerfile --build-arg CUDA_VERSION=${{ matrix.variant.version }} --build-arg BUILD_TYPE=${{ matrix.variant.type }} -t lmsysorg/sglang:${{ matrix.variant.tag }} --no-cache + docker buildx build --output type=image,compression=zstd . -f docker/Dockerfile --build-arg CUDA_VERSION=${{ matrix.variant.version }} --build-arg BUILD_TYPE=${{ matrix.variant.type }} --build-arg CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) -t lmsysorg/sglang:${{ matrix.variant.tag }} --no-cache docker push lmsysorg/sglang:${{ matrix.variant.tag }} diff --git a/docker/Dockerfile b/docker/Dockerfile index 349873da4acf..eac2c8a4c446 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -3,6 +3,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 ARG BUILD_TYPE=all ARG DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58 +ARG CMAKE_BUILD_PARALLEL_LEVEL=2 ENV DEBIAN_FRONTEND=noninteractive \ CUDA_HOME=/usr/local/cuda \ GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/ \ @@ -78,7 +79,7 @@ RUN wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.3.9/sour NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ NVSHMEM_USE_GDRCOPY=1 \ cmake -S . 
-B build/ -DCMAKE_INSTALL_PREFIX=${NVSHMEM_DIR} -DCMAKE_CUDA_ARCHITECTURES=90 \ - && cmake --build build --target install -j \ + && cmake --build build --target install -j${CMAKE_BUILD_PARALLEL_LEVEL} \ && cd /sgl-workspace/DeepEP \ && NVSHMEM_DIR=${NVSHMEM_DIR} pip install . From 9069884b5140f95fc4a381b5c98114717744e110 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Wed, 16 Jul 2025 20:41:47 -0700 Subject: [PATCH 015/396] [ci] disable memory imbalance check for draft worker (#8108) --- python/sglang/srt/model_executor/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 12db1d0559f3..923b4d02b543 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -561,7 +561,7 @@ def init_torch_distributed(self): # Check memory for tensor parallelism local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id) - if self.tp_size > 1: + if self.tp_size > 1 and not self.is_draft_worker: if min_per_gpu_memory < local_gpu_memory * 0.9: if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"): logger.warning( From 5c08a36cbfaeefab461ef7c42d897acae568b97a Mon Sep 17 00:00:00 2001 From: hzh0425 Date: Thu, 17 Jul 2025 12:33:29 +0800 Subject: [PATCH 016/396] [Fix] ensure DeepGEMM is only enabled for FP8_W8A8 models (#8110) --- python/sglang/srt/layers/moe/ep_moe/layer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 353f131c91a3..e8bfadfb65fe 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -1272,6 +1272,12 @@ def __init__( routed_scaling_factor=routed_scaling_factor, ) self.deepep_mode = deepep_mode + if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM: + assert 
self.use_fp8_w8a8, ( + "DeepGEMM requires an fp8_w8a8 model; " + "alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable." + ) + if self.deepep_mode.enable_low_latency(): assert ( deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM From 02404a1e35d9e53b6ed28f0707f4eaa5a431d3a1 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Thu, 17 Jul 2025 00:46:40 -0700 Subject: [PATCH 017/396] [ci] recover 8-gpu deepep test (#8105) --- .github/workflows/pr-test.yml | 42 +++++++++++++++++------------------ scripts/ci_install_deepep.sh | 29 +++++++++--------------- test/srt/test_deepep_large.py | 20 +++++++++-------- test/srt/test_deepep_small.py | 20 ++++++++--------- 4 files changed, 52 insertions(+), 59 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index b29bf63f4576..2378695e21ee 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -324,33 +324,33 @@ jobs: cd test/srt python3 run_suite.py --suite per-commit-4-gpu-deepep - # unit-test-deepep-8-gpu: - # if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && - # github.event.pull_request.draft == false - # runs-on: 8-gpu-runner - # needs: [ - # unit-test-deepep-4-gpu, - # ] - # steps: - # - name: Checkout code - # uses: actions/checkout@v4 - # - # - name: Install dependencies - # run: | - # bash scripts/ci_install_deepep.sh - # - # - name: Run test - # timeout-minutes: 20 - # run: | - # cd test/srt - # python3 run_suite.py --suite per-commit-8-gpu-deepep + unit-test-deepep-8-gpu: + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && + github.event.pull_request.draft == false + runs-on: 8-gpu-runner + needs: [ + unit-test-deepep-4-gpu, + ] + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + run: | + bash scripts/ci_install_deepep.sh + + - name: Run test + timeout-minutes: 20 
+ run: | + cd test/srt + python3 run_suite.py --suite per-commit-8-gpu-deepep finish: if: always() needs: [ unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu, unit-test-backend-4-gpu, unit-test-backend-8-gpu, performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu, - accuracy-test-1-gpu, accuracy-test-2-gpu, unit-test-deepep-4-gpu, # unit-test-deepep-8-gpu, + accuracy-test-1-gpu, accuracy-test-2-gpu, unit-test-deepep-4-gpu, unit-test-deepep-8-gpu, ] runs-on: ubuntu-latest steps: diff --git a/scripts/ci_install_deepep.sh b/scripts/ci_install_deepep.sh index aa4dab097bb6..e743bddaf6a6 100755 --- a/scripts/ci_install_deepep.sh +++ b/scripts/ci_install_deepep.sh @@ -4,30 +4,30 @@ set -euxo pipefail bash scripts/ci_install_dependency.sh -if python3 -c "import deep_ep" >/dev/null 2>&1; then - echo "deep_ep is already installed or importable. Skipping installation." - exit 0 -fi - export GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/ export NVSHMEM_DIR=/opt/nvshmem/install export LD_LIBRARY_PATH="${NVSHMEM_DIR}/lib:$LD_LIBRARY_PATH" export PATH="${NVSHMEM_DIR}/bin:$PATH" export CUDA_HOME=/usr/local/cuda +if python3 -c "import deep_ep" >/dev/null 2>&1; then + echo "deep_ep is already installed or importable. Skipping installation." + exit 0 +fi + # Install system dependencies apt install -y curl wget git sudo libibverbs-dev rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 build-essential cmake # Install GDRCopy rm -rf /opt/gdrcopy && mkdir -p /opt/gdrcopy -mkdir -p /opt/nvshmem +rm -rf /opt/nvshmem && mkdir -p /opt/nvshmem cd /opt/gdrcopy git clone https://github.com/NVIDIA/gdrcopy.git . 
git checkout v2.4.4 apt update apt install -y nvidia-dkms-535 apt install -y build-essential devscripts debhelper fakeroot pkg-config dkms -apt install -y check libsubunit0 libsubunit-dev +apt install -y check libsubunit0 libsubunit-dev python3-venv cd packages CUDA=/usr/local/cuda ./build-deb-packages.sh dpkg -i gdrdrv-dkms_*.deb @@ -40,16 +40,11 @@ if [ ! -e "/usr/lib/x86_64-linux-gnu/libmlx5.so" ]; then fi apt-get update && apt-get install -y libfabric-dev -# Clone DeepEP -rm -rf /root/.cache/deepep && git clone https://github.com/deepseek-ai/DeepEP.git /root/.cache/deepep && cd /root/.cache/deepep && git checkout eef7ab50fa5cf0ab1dd3fce4c6493c90bdf290ac - # Install NVSHMEM cd /opt/nvshmem -wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz -tar -xf nvshmem_src_3.2.5-1.txz -rm -rf nvshmem && mv nvshmem_src nvshmem -cd nvshmem -git apply /root/.cache/deepep/third-party/nvshmem.patch +wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.3.9/source/nvshmem_src_cuda12-all-all-3.3.9.tar.gz +tar -xf nvshmem_src_cuda12-all-all-3.3.9.tar.gz +mv nvshmem_src nvshmem && cd nvshmem NVSHMEM_SHMEM_SUPPORT=0 \ NVSHMEM_UCX_SUPPORT=0 \ NVSHMEM_USE_NCCL=0 \ @@ -63,12 +58,10 @@ cd build make -j$(nproc) install # Install DeepEP +rm -rf /root/.cache/deepep && git clone https://github.com/deepseek-ai/DeepEP.git /root/.cache/deepep && cd /root/.cache/deepep && git checkout b6ce310bb0b75079682d09bc2ebc063a074fbd58 cd /root/.cache/deepep && python3 setup.py install # Verify configuration -echo "=== NCCL Configuration ===" -nvidia-smi topo -m -nvidia-smi nvlink -s echo "=== Verify GDRCOPY ===" gdrcopy_copybw echo "=== Verify NVSHMEM ===" diff --git a/test/srt/test_deepep_large.py b/test/srt/test_deepep_large.py index 8afb2896f8f8..703eb7789316 100644 --- a/test/srt/test_deepep_large.py +++ b/test/srt/test_deepep_large.py @@ -45,6 +45,7 @@ def setUpClass(cls): "256", "--max-running-requests", "2048", + 
"--disable-radix-cache", ], ) @@ -54,10 +55,10 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=8, + num_shots=5, data_path=None, - num_questions=1250, - parallel=1250, + num_questions=1200, + parallel=1200, max_new_tokens=512, host="http://127.0.0.1", port=int(self.base_url.split(":")[-1]), @@ -65,7 +66,7 @@ def test_gsm8k(self): metrics = run_eval_few_shot_gsm8k(args) print(f"Eval accuracy of GSM8K: {metrics=}") - self.assertGreater(metrics["accuracy"], 0.93) + self.assertGreater(metrics["accuracy"], 0.92) class TestDeepseekMTP(CustomTestCase): @@ -107,6 +108,7 @@ def setUpClass(cls): "1", "--speculative-num-draft-tokens", "2", + "--disable-radix-cache", ], ) @@ -116,10 +118,10 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=8, + num_shots=5, data_path=None, - num_questions=1250, - parallel=1250, + num_questions=1200, + parallel=1200, max_new_tokens=512, host="http://127.0.0.1", port=int(self.base_url.split(":")[-1]), @@ -127,7 +129,7 @@ def test_gsm8k(self): metrics = run_eval_few_shot_gsm8k(args) print(f"Eval accuracy of GSM8K: {metrics=}") - self.assertGreater(metrics["accuracy"], 0.93) + self.assertGreater(metrics["accuracy"], 0.92) server_info = requests.get(self.base_url + "/get_server_info") avg_spec_accept_length = server_info.json()["internal_states"][0][ @@ -138,7 +140,7 @@ def test_gsm8k(self): f"accuracy={metrics['accuracy']=:.3f}\n" f"{avg_spec_accept_length=:.3f}\n" ) - self.assertGreater(avg_spec_accept_length, 1.9) + self.assertGreater(avg_spec_accept_length, 1.85) if __name__ == "__main__": diff --git a/test/srt/test_deepep_small.py b/test/srt/test_deepep_small.py index 9724ae735f97..e26017ade608 100644 --- a/test/srt/test_deepep_small.py +++ b/test/srt/test_deepep_small.py @@ -36,6 +36,8 @@ def setUpClass(cls): "128", "--max-running-requests", "128", + "--mem-fraction-static", + "0.5", ], ) @@ -56,7 +58,7 @@ def test_gsm8k(self): metrics = run_eval_few_shot_gsm8k(args) 
print(metrics) - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["accuracy"], 0.60) class TestHybridDPTP(CustomTestCase): @@ -100,7 +102,7 @@ def test_gsm8k(self): metrics = run_eval_few_shot_gsm8k(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["accuracy"], 0.60) class TestTP(CustomTestCase): @@ -141,10 +143,10 @@ def test_gsm8k(self): metrics = run_eval_few_shot_gsm8k(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["accuracy"], 0.60) -# @unittest.skip("covered in test_deepep_large.py") +@unittest.skip("covered in test_deepep_large.py") class TestNoGatherdBuffer(CustomTestCase): @classmethod def setUpClass(cls): @@ -189,7 +191,7 @@ def test_gsm8k(self): metrics = run_eval_few_shot_gsm8k(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["accuracy"], 0.60) class TestTBO(CustomTestCase): @@ -236,10 +238,10 @@ def test_gsm8k(self): metrics = run_eval_few_shot_gsm8k(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.62) + self.assertGreater(metrics["accuracy"], 0.60) -# @unittest.skip("covered in TestMTPWithTBO") +@unittest.skip("covered in TestMTPWithTBO") class TestMTP(CustomTestCase): @classmethod def setUpClass(cls): @@ -280,8 +282,6 @@ def tearDownClass(cls): kill_process_tree(cls.process.pid) def test_gsm8k(self): - requests.get(self.base_url + "/flush_cache") - args = SimpleNamespace( num_shots=5, data_path=None, @@ -352,8 +352,6 @@ def tearDownClass(cls): kill_process_tree(cls.process.pid) def test_gsm8k(self): - requests.get(self.base_url + "/flush_cache") - args = SimpleNamespace( num_shots=5, data_path=None, From 49b8777460b707809c60584b7a801fac5e0426b4 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Thu, 17 Jul 2025 00:47:07 -0700 Subject: [PATCH 018/396] Refactor: move all quantization-related code to `srt/layer/quantization` 
(#7989) --- python/sglang/srt/layers/linear.py | 116 +--- python/sglang/srt/layers/moe/ep_moe/layer.py | 327 +---------- .../layers/moe/fused_moe_triton/__init__.py | 3 - .../srt/layers/moe/fused_moe_triton/layer.py | 375 +------------ python/sglang/srt/layers/moe/topk.py | 6 +- .../srt/layers/quantization/__init__.py | 100 +--- python/sglang/srt/layers/quantization/awq.py | 16 +- .../srt/layers/quantization/base_config.py | 86 ++- .../srt/layers/quantization/blockwise_int8.py | 37 +- .../compressed_tensors/compressed_tensors.py | 19 +- python/sglang/srt/layers/quantization/fp8.py | 292 +++++++++- python/sglang/srt/layers/quantization/gptq.py | 23 +- .../srt/layers/quantization/marlin_utils.py | 19 +- .../srt/layers/quantization/modelopt_quant.py | 71 +-- .../srt/layers/quantization/moe_wna16.py | 30 +- python/sglang/srt/layers/quantization/qoq.py | 13 +- .../sglang/srt/layers/quantization/unquant.py | 515 ++++++++++++++++++ .../sglang/srt/layers/quantization/utils.py | 97 +++- .../sglang/srt/layers/quantization/w4afp8.py | 12 +- .../srt/layers/quantization/w8a8_fp8.py | 31 +- .../srt/layers/quantization/w8a8_int8.py | 40 +- .../srt/layers/vocab_parallel_embedding.py | 40 +- 22 files changed, 1094 insertions(+), 1174 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/unquant.py diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 0cc44be55321..1c770193fccb 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -1,12 +1,12 @@ """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py""" +from __future__ import annotations + import itertools import logging -from abc import abstractmethod -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import torch -import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter from 
sglang.srt.distributed import ( @@ -17,7 +17,6 @@ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) -from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading from sglang.srt.layers.parameter import ( BasevLLMParameter, BlockQuantScaleParameter, @@ -27,17 +26,14 @@ RowvLLMParameter, _ColumnvLLMParameter, ) -from sglang.srt.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase, -) -from sglang.srt.utils import ( - cpu_has_amx_support, - is_cpu, - is_npu, - set_weight_attrs, - use_intel_amx_backend, -) +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs + +if TYPE_CHECKING: + from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, + ) logger = logging.getLogger(__name__) @@ -59,7 +55,6 @@ "IPEXAWQLinearMethod", ] -_is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() _is_npu = is_npu() @@ -110,91 +105,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): return param[shard_id], loaded_weight -class LinearMethodBase(QuantizeMethodBase): - """Base class for different (maybe quantized) linear methods.""" - - @abstractmethod - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - """Create weights for a linear layer. - The weights will be set as attributes of the layer. - - Args: - layer: The layer that is using the LinearMethodBase factory. - input_size_per_partition: Size of the weight input dim on rank X. - output_partition_sizes: Sizes of the output dim of each logical - weight on rank X. E.g., output_partition_sizes for QKVLinear - is a list contains the width of Wq, Wk, Wv on rank X. - input_size: Size of the input dim of the weight across all ranks. 
- output_size: Size of the output dim of the weight across all ranks. - params_dtype: Datatype of the parameters. - """ - raise NotImplementedError - - @abstractmethod - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Apply the weights in layer to the input tensor. - Expects create_weights to have been called before on the layer.""" - raise NotImplementedError - - -class UnquantizedLinearMethod(LinearMethodBase): - """Linear method without quantization.""" - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - weight = Parameter( - torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - layer.register_parameter("weight", weight) - set_weight_attrs(weight, extra_weight_attrs) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if _is_cpu and _is_cpu_amx_available: - _amx_process_weight_after_loading(layer, ["weight"]) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - - if use_intel_amx_backend(layer): - return torch.ops.sgl_kernel.weight_packed_linear( - x, layer.weight, bias, True # is_vnni - ) - - return F.linear(x, layer.weight, bias) - - class LinearBase(torch.nn.Module): """Base linear layer. 
@@ -310,7 +220,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None output = self.quant_method.apply(self, x, bias) @@ -845,7 +755,7 @@ def __init__( bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, + quant_config: Optional["QuantizationConfig"] = None, prefix: str = "", tp_rank: Optional[int] = None, tp_size: Optional[int] = None, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index e8bfadfb65fe..a839b47febed 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -27,22 +27,20 @@ silu_and_mul_triton_kernel, tma_align_input_scale, ) -from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported -from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod +from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod from sglang.srt.layers.quantization.fp8_kernel import ( is_fp8_fnuz, - scaled_fp8_quant, sglang_per_token_group_quant_fp8, sglang_per_token_quant_fp8, ) -from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod from 
sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -53,7 +51,6 @@ get_bool_env_var, is_hip, is_npu, - set_weight_attrs, ) _is_hip = is_hip() @@ -904,324 +901,6 @@ def _load_fp8_scale( param_data[expert_id] = loaded_weight -class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): - - def create_weights( - self, - layer: torch.nn.Module, - num_experts_per_partition: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts_per_partition, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - # down_proj (row parallel) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts_per_partition, - hidden_size, - intermediate_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # scale - layer.register_parameter("w13_input_scale", None) - layer.register_parameter("w13_weight_scale", None) - - ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32) - - w2_input_scale = torch.nn.Parameter( - ones_tensor, - requires_grad=False, - ) - layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, extra_weight_attrs) - - w2_weight_scale = torch.nn.Parameter( - ones_tensor, - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - 
use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - ) -> torch.Tensor: - raise NotImplementedError - - -class Fp8EPMoEMethod(Fp8MoEMethod): - """MoE method for FP8. - Supports loading FP8 checkpoints with static weight scale and - dynamic/static activation scale. - - Args: - quant_config: The quantization config. - """ - - def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config - self.block_quant = self.quant_config.weight_block_size is not None - - def create_weights( - self, - layer: Module, - num_experts_per_partition: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - tp_size = get_tensor_model_parallel_world_size() - if self.block_quant: - block_n, block_k = ( - self.quant_config.weight_block_size[0], - self.quant_config.weight_block_size[1], - ) - # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. - # Required by column parallel or enabling merged weights - if intermediate_size % block_n != 0: - raise ValueError( - f"The output_size of gate's and up's weight = " - f"{intermediate_size} is not divisible by " - f"weight quantization block_n = {block_n}." - ) - if tp_size > 1: - # Required by row parallel - if intermediate_size % block_k != 0: - raise ValueError( - f"The input_size of down's weight = " - f"{intermediate_size} is not divisible by " - f"weight quantization block_k = {block_k}." 
- ) - - # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts_per_partition, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts_per_partition, - hidden_size, - intermediate_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - if self.block_quant: - w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts_per_partition, - 2 * ((intermediate_size + block_n - 1) // block_n), - (hidden_size + block_k - 1) // block_k, - dtype=torch.float32, - ), - requires_grad=False, - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts_per_partition, - (hidden_size + block_n - 1) // block_n, - (intermediate_size + block_k - 1) // block_k, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) - layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) - assert self.quant_config.activation_scheme == "dynamic" - else: - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. 
- w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, 2, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} - if self.block_quant - else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # INPUT_SCALES - if self.quant_config.activation_scheme == "static": - if not self.quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8." - ) - - w13_input_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w13_input_scale", w13_input_scale) - set_weight_attrs(w13_input_scale, extra_weight_attrs) - - w2_input_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, extra_weight_attrs) - - else: - layer.w13_input_scale = None - layer.w2_input_scale = None - - def process_weights_after_loading(self, layer: Module) -> None: - - # If checkpoint is fp16, quantize in place. 
- if not self.quant_config.is_checkpoint_fp8_serialized: - # If rocm, use float8_e4m3fnuz as dtype - fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn - w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) - w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) - - layer.w13_weight_scale = torch.nn.Parameter( - torch.ones( - layer.num_experts_per_partition, - dtype=torch.float32, - device=w13_weight.device, - ), - requires_grad=False, - ) - - for expert in range(layer.num_experts_per_partition): - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - return - - # If checkpoint is fp8, we need to handle that the - # MoE kernels require single activation scale and single weight - # scale for w13 per expert. - else: - if self.quant_config.activation_scheme == "static": - if layer.w13_input_scale is None or layer.w2_input_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." 
- ) - layer.w13_weight_scale = torch.nn.Parameter( - torch.max(layer.w13_weight_scale, dim=1).values, - requires_grad=False, - ) - if self.block_quant: - # If ROCm, normalize the weights and scales to e4m3fnuz - if _is_fp8_fnuz: - # activation_scheme: dynamic - w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=layer.w13_weight, - weight_scale=layer.w13_weight_scale_inv, - input_scale=None, - ) - w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=layer.w2_weight, - weight_scale=layer.w2_weight_scale_inv, - input_scale=None, - ) - # Reset the parameter - layer.w13_weight = torch.nn.Parameter( - w13_weight, requires_grad=False - ) - layer.w13_weight_scale_inv = torch.nn.Parameter( - w13_weight_scale, requires_grad=False - ) - layer.w13_input_scale = None - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale_inv = torch.nn.Parameter( - w2_weight_scale, requires_grad=False - ) - layer.w2_input_scale = None - if _use_aiter: - layer.w13_weight = torch.nn.Parameter( - shuffle_weight(layer.w13_weight.data, (16, 16)), - requires_grad=False, - ) - layer.w2_weight = torch.nn.Parameter( - shuffle_weight(layer.w2_weight.data, (16, 16)), - requires_grad=False, - ) - return - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - ) -> torch.Tensor: - raise NotImplementedError - - class DeepEPMoE(EPMoE): """ MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py index 839b659fe31b..6d8aee85293d 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py +++ 
b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py @@ -9,7 +9,6 @@ ) from sglang.srt.layers.moe.fused_moe_triton.layer import ( FusedMoE, - FusedMoEMethodBase, FusedMoeWeightScaleSupported, ) @@ -31,11 +30,9 @@ def get_config() -> Optional[Dict[str, Any]]: __all__ = [ "FusedMoE", - "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", "override_config", "get_config", - "fused_moe", "fused_experts", "get_config_file_name", "moe_align_block_size", diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index ad495d5953cf..41ae6274b087 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1,60 +1,28 @@ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py -import importlib -from abc import abstractmethod +import logging from enum import Enum from typing import Callable, List, Optional, Tuple import torch -from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading -from sglang.srt.layers.moe.fused_moe_native import moe_forward_native -from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight -from sglang.srt.utils import ( - cpu_has_amx_support, - get_bool_env_var, - is_cpu, - is_hip, - set_weight_attrs, - use_intel_amx_backend, -) - -has_triton_kernels = importlib.util.find_spec("triton_kernels") is not 
None - -if torch.cuda.is_available(): - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - - if has_triton_kernels: - from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( - triton_kernel_moe_forward, - ) -else: - fused_experts = None # type: ignore - -import logging +from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip _is_hip = is_hip() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() -_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip - -if _use_aiter: - from aiter import ActivationType - from aiter.fused_moe import fused_moe - from aiter.fused_moe_bf16_asm import ck_moe_2stages - from aiter.ops.shuffle import shuffle_weight logger = logging.getLogger(__name__) @@ -66,333 +34,6 @@ class FusedMoeWeightScaleSupported(Enum): BLOCK = "block" -class FusedMoEMethodBase(QuantizeMethodBase): - - @abstractmethod - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - raise NotImplementedError - - @abstractmethod - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - ) -> torch.Tensor: - raise NotImplementedError - - -class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): - """MoE method without quantization.""" - - def __init__(self, use_triton_kernels: bool = False): - super().__init__() - self.use_triton_kernels = use_triton_kernels - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - # Fused gate_up_proj (column parallel) - w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size - if self.use_triton_kernels: - w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n - w13_weight = 
torch.nn.Parameter( - torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - # down_proj (row parallel) - w2_weight_n, w2_weight_k = ( - hidden_size, - intermediate_size, - ) - if self.use_triton_kernels: - w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n - w2_weight = torch.nn.Parameter( - torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if _use_aiter: - layer.w13_weight = torch.nn.Parameter( - shuffle_weight(layer.w13_weight.data, (16, 16)), - requires_grad=False, - ) - torch.cuda.empty_cache() - layer.w2_weight = torch.nn.Parameter( - shuffle_weight(layer.w2_weight.data, (16, 16)), - requires_grad=False, - ) - torch.cuda.empty_cache() - - # Pack weight for get better performance on CPU - if _is_cpu and _is_cpu_amx_available: - _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) - - return - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - ) -> torch.Tensor: - return self.forward( - x=x, - layer=layer, - router_logits=router_logits, - top_k=top_k, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - num_expert_group=num_expert_group, 
- num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - inplace=inplace, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, - ) - - def forward_cuda( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - ) -> torch.Tensor: - - if self.use_triton_kernels: - return triton_kernel_moe_forward( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - ) - else: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) - - if _use_aiter: - assert not no_combine, "unsupported" - if apply_router_weight_on_input: - assert ( - topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - x = x * topk_weights.to(x.dtype) - topk_weights = torch.ones_like( - topk_weights, dtype=torch.float32 - ) # topk_weights must be FP32 
(float32) - - return fused_moe( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=( - ActivationType.Silu - if activation == "silu" - else ActivationType.Gelu - ), - ) - else: - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=inplace and not no_combine, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, - ) - - def forward_cpu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - ) -> torch.Tensor: - assert activation == "silu", f"activation = {activation} is not supported." - - if use_intel_amx_backend(layer) and not apply_router_weight_on_input: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) - - # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel - return torch.ops.sgl_kernel.fused_experts_cpu( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - False, # inplace # See [Note] inplace should be False in fused_experts. 
- False, # use_int8_w8a8 - False, # use_fp8_w8a16 - None, # w1_scale - None, # w2_scale - None, # block_size - None, # a1_scale - None, # a2_scale - True, # is_vnni - ) - else: - return moe_forward_native( - layer, - x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - num_fused_shared_experts, - custom_routing_function, - correction_bias, - activation, - apply_router_weight_on_input, - inplace, - no_combine, - routed_scaling_factor, - ) - - def forward_npu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - ) -> torch.Tensor: - return moe_forward_native( - layer, - x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - num_fused_shared_experts, - custom_routing_function, - correction_bias, - activation, - apply_router_weight_on_input, - inplace, - no_combine, - routed_scaling_factor, - ) - - def forward_tpu(self, *args, **kwargs) -> torch.Tensor: - raise NotImplementedError("The TPU backend currently does not support MoE.") - - forward_native = forward_cpu - - class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. 
@@ -553,7 +194,7 @@ def _load_model_weight_or_group_weight_scale( shard_dim: int, expert_data: torch.Tensor, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int, ): # Load grouped weight scales for group quantization @@ -580,7 +221,7 @@ def _load_per_channel_weight_scale( expert_data: torch.Tensor, shard_dim: int, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int, ): # for per channel weight quantization @@ -600,7 +241,7 @@ def _load_w13( expert_data: torch.Tensor, shard_dim: int, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int, ): @@ -645,7 +286,7 @@ def _load_w2( expert_data: torch.Tensor, shard_dim: int, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int, ): """Load w2 weights for down projection. @@ -717,7 +358,7 @@ def _load_g_idx( shard_id: str, expert_data: torch.Tensor, shard_dim: int, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int, ): diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 18f3dea8dffa..1c8d219e4ec0 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -19,15 +19,11 @@ import torch.nn.functional as F from sglang.srt.eplb import expert_location_dispatch -from sglang.srt.eplb.expert_distribution import ( - ExpertDistributionRecorder, - get_global_expert_distribution_recorder, -) +from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location_dispatch import ( ExpertLocationDispatchInfo, topk_ids_logical_to_physical, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 7507a5b62893..e0f4363437b3 100644 --- 
a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -1,8 +1,6 @@ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py import builtins import inspect -import re -from copy import deepcopy from typing import Callable, Dict, Optional, Type, Union import torch @@ -45,7 +43,6 @@ def override_quantization_method(self, *args, **kwargs): ) = QQQConfig = Int8TpuConfig = DummyConfig -from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod from sglang.srt.layers.quantization.awq import AWQConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config @@ -66,6 +63,10 @@ def override_quantization_method(self, *args, **kwargs): ) from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config from sglang.srt.layers.quantization.qoq import QoQConfig +from sglang.srt.layers.quantization.utils import ( + get_dynamic_override, + get_linear_quant_method, +) from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config @@ -120,99 +121,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: return QUANTIZATION_METHODS[quantization] -# Match dynamic rules with module name (prefix) and override quantize -# config if module (prefix) matches a rule -def override_config(config: QuantizationConfig, prefix: str): - weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits) - if isinstance(weight_bits, int): - config.weight_bits = weight_bits - group_size = get_dynamic_override(config, prefix, "group_size", config.group_size) - if isinstance(group_size, int): - config.group_size = group_size - desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act) - if 
isinstance(desc_act, bool): - config.desc_act = desc_act - - config.pack_factor = 32 // config.weight_bits # packed into int32 - if config.get_name() == "gptq_marlin": - is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) - if isinstance(is_sym, bool): - config.is_sym = is_sym - - if (config.weight_bits, config.is_sym) not in config.TYPE_MAP: - raise ValueError( - "Unsupported quantization config: " - f"bits={config.weight_bits}, sym={config.is_sym}" - ) - - config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)] - elif config.get_name() == "gptq": - if config.weight_bits not in [2, 3, 4, 8]: - raise ValueError( - "Currently, only 2/3/4/8-bit weight quantization is " - f"supported for GPTQ, but got {config.weight_bits} bits." - ) - - -def get_dynamic_override( - config: QuantizationConfig, - layer_name: str, - key: Optional[str] = None, - default_value: Union[int, bool, None] = None, -) -> Union[Dict, int, bool, None]: - for pattern, pattern_dict in config.dynamic.items(): - # Negative match: matched modules are excluded from quantized init - if pattern.startswith("-:"): - if re.match(pattern.removeprefix("-:"), layer_name): - return False - # Positive match: matched modules have quant properties overrides - # base quant config - elif re.match(pattern.removeprefix("+:"), layer_name): - if key is None: - return pattern_dict - else: - return pattern_dict.get(key, default_value) - return default_value - - -def get_linear_quant_method( - config: QuantizationConfig, - layer: torch.nn.Module, - prefix: str, - linear_method_cls: type, -): - # Move import here to avoid circular import. This is only used in monkey patching - # of vllm's QuantizationConfig. 
- from sglang.srt.layers.vocab_parallel_embedding import ( - ParallelLMHead, - UnquantizedEmbeddingMethod, - ) - - cloned_config = deepcopy(config) - parallel_lm_head_quantized = ( - isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized - ) - - if isinstance(layer, LinearBase) or parallel_lm_head_quantized: - # False = skip module, None = no override, else = Positive match - if ( - get_dynamic_override( # noqa: E712 - cloned_config, layer_name=prefix # noqa: E712 - ) - == False - ): # noqa: E712 - if parallel_lm_head_quantized: - return UnquantizedEmbeddingMethod() - return UnquantizedLinearMethod() - - if prefix: - # Dynamic per module/layer rules may override base config - override_config(cloned_config, prefix=prefix) - - return linear_method_cls(cloned_config) - return None - - def gptq_get_quant_method(self, layer, prefix): from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 9f14ac4c1cac..6265f2217d79 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -1,16 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import logging from typing import Any, Dict, List, Optional import torch -from sglang.srt.layers.linear import ( - LinearBase, +from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter +from sglang.srt.layers.quantization.base_config import ( LinearMethodBase, - UnquantizedLinearMethod, + QuantizationConfig, ) -from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter -from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.utils import is_cuda _is_cuda = is_cuda() @@ -81,7 +82,7 @@ def get_config_filenames() -> List[str]: ] @classmethod - def from_config(cls, 
config: Dict[str, Any]) -> "AWQConfig": + def from_config(cls, config: Dict[str, Any]) -> AWQConfig: weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) zero_point = cls.get_from_keys(config, ["zero_point"]) @@ -92,7 +93,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["LinearMethodBase"]: + ) -> Optional[LinearMethodBase]: + from sglang.srt.layers.linear import LinearBase if isinstance(layer, LinearBase): if is_layer_skipped_awq(prefix, self.modules_to_not_convert): diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py index 6058702c9a10..607151671bff 100644 --- a/python/sglang/srt/layers/quantization/base_config.py +++ b/python/sglang/srt/layers/quantization/base_config.py @@ -18,14 +18,14 @@ def create_weights( """Create weights for a layer. The weights will be set as attributes of the layer.""" - raise NotImplementedError + raise NotImplementedError() @abstractmethod def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: """Apply the weights in layer to the input tensor. Expects create_weights to have been called before on the layer.""" - raise NotImplementedError + raise NotImplementedError() def process_weights_after_loading(self, layer: nn.Module) -> None: """Process the weight after loading. @@ -35,6 +35,74 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: return +class LinearMethodBase(QuantizeMethodBase): + """Base class for different (maybe quantized) linear methods.""" + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Create weights for a linear layer. 
+ The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. + """ + raise NotImplementedError() + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError() + + +class FusedMoEMethodBase(QuantizeMethodBase): + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError() + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + ) -> torch.Tensor: + raise NotImplementedError() + + class QuantizationConfig(ABC): """Base class for quantization configs.""" @@ -46,12 +114,12 @@ def __init__(self): @abstractmethod def get_name(self) -> str: """Name of the quantization method.""" - raise NotImplementedError + raise NotImplementedError() @abstractmethod def get_supported_act_dtypes(self) -> List[torch.dtype]: """List of supported activation dtypes.""" - raise NotImplementedError + raise NotImplementedError() @classmethod @abstractmethod @@ -62,19 +130,19 @@ def get_min_capability(cls) -> int: This requirement is due to the custom CUDA kernels used by the 
quantization method. """ - raise NotImplementedError + raise NotImplementedError() @staticmethod @abstractmethod def get_config_filenames() -> List[str]: """List of filenames to search for in the model directory.""" - raise NotImplementedError + raise NotImplementedError() @classmethod @abstractmethod def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": """Create a config class from the model's quantization config.""" - raise NotImplementedError + raise NotImplementedError() @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: @@ -117,7 +185,7 @@ def get_quant_method( The quantize method. None if the given layer doesn't support quant method. """ - raise NotImplementedError + raise NotImplementedError() @abstractmethod def get_scaled_act_names(self) -> List[str]: @@ -125,7 +193,7 @@ def get_scaled_act_names(self) -> List[str]: For now, this is only used by AWQ. """ - raise NotImplementedError + raise NotImplementedError() def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool: diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py index f38857595580..a1da999b3af1 100644 --- a/python/sglang/srt/layers/quantization/blockwise_int8.py +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -1,5 +1,7 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py +from __future__ import annotations + import logging from typing import Any, Callable, Dict, List, Optional @@ -7,17 +9,15 @@ from torch.nn import Module from sglang.srt.distributed import get_tensor_model_parallel_world_size -from sglang.srt.layers.linear import ( - LinearBase, - LinearMethodBase, - UnquantizedLinearMethod, -) from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, 
+ LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import set_weight_attrs @@ -78,7 +78,7 @@ def get_config_filenames(cls) -> List[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config": + def from_config(cls, config: Dict[str, Any]) -> BlockInt8Config: quant_method = cls.get_from_keys(config, ["quant_method"]) is_checkpoint_int8_serialized = "int8" in quant_method activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) @@ -93,7 +93,8 @@ def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config": def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): @@ -230,7 +231,7 @@ def apply( ) -class BlockInt8MoEMethod: +class BlockInt8MoEMethod(FusedMoEMethodBase): """MoE method for INT8. Supports loading INT8 checkpoints with static weight scale and dynamic activation scale. @@ -242,25 +243,7 @@ class BlockInt8MoEMethod: quant_config: The quantization config. 
""" - def __new__(cls, *args, **kwargs): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - - def __init__(self, quant_config): + def __init__(self, quant_config: BlockInt8Config): self.quant_config = quant_config assert self.quant_config.weight_block_size is not None assert self.quant_config.is_checkpoint_int8_serialized diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 7ce89345fd6b..50d90406d26f 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,5 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations import logging from contextlib import suppress @@ -18,12 +19,8 @@ ) from pydantic import BaseModel -from sglang.srt.layers.linear import ( - LinearBase, - LinearMethodBase, - UnquantizedLinearMethod, -) from sglang.srt.layers.quantization.base_config import ( + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -40,6 +37,7 @@ is_activation_quantization_format, should_ignore_layer, ) +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod try: import vllm @@ -97,7 +95,7 @@ def __init__( self.config = config self.packed_modules_mapping = packed_modules_mapping - def get_linear_method(self) -> "CompressedTensorsLinearMethod": + def get_linear_method(self) -> 
CompressedTensorsLinearMethod: return CompressedTensorsLinearMethod(self) def get_supported_act_dtypes(cls) -> List[torch.dtype]: @@ -117,7 +115,8 @@ def get_quant_method( self, layer: torch.nn.Module, prefix: str, - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase # Check if the layer is skipped for quantization. # TODO (@robertgshaw2): support module names @@ -138,7 +137,7 @@ def get_quant_method( return None @classmethod - def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig: ignore: List[str] = cast(List[str], config.get("ignore", [])) quant_format = cast(str, config.get("format")) target_scheme_map = cls._quantization_scheme_map_from_config(config=config) @@ -357,7 +356,7 @@ def _is_wNa16_group_channel( def _get_scheme_from_parts( self, weight_quant: BaseModel, input_quant: BaseModel - ) -> "CompressedTensorsScheme": + ) -> CompressedTensorsScheme: # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): @@ -435,7 +434,7 @@ def _get_scheme_from_parts( def get_scheme( self, layer: torch.nn.Module, layer_name: Optional[str] = None - ) -> Optional["CompressedTensorsScheme"]: + ) -> Optional[CompressedTensorsScheme]: """ compressed-tensors supports non uniform in the following way: diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 4d886de91818..38588c809039 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1,7 +1,9 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py +from __future__ import annotations + import logging -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import torch import 
torch.nn.functional as F @@ -28,17 +30,14 @@ def dummy_func(*args, **kwargs): from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading -from sglang.srt.layers.linear import ( - LinearBase, - LinearMethodBase, - UnquantizedLinearMethod, -) from sglang.srt.layers.parameter import ( BlockQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter, ) from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -56,6 +55,7 @@ def dummy_func(*args, **kwargs): normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import ( all_close_1d, convert_to_channelwise, @@ -77,6 +77,9 @@ def dummy_func(*args, **kwargs): use_intel_amx_backend, ) +if TYPE_CHECKING: + from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config + _is_hip = is_hip() _is_cuda = is_cuda() _is_npu = is_npu() @@ -152,7 +155,7 @@ def get_config_filenames(cls) -> List[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + def from_config(cls, config: Dict[str, Any]) -> Fp8Config: quant_method = cls.get_from_keys(config, ["quant_method"]) is_checkpoint_fp8_serialized = "fp8" in quant_method activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) @@ -167,7 +170,8 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): @@ -200,7 +204,7 @@ class Fp8LinearMethod(LinearMethodBase): quant_config: The quantization config. 
""" - def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]): + def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]): self.quant_config = quant_config self.cutlass_fp8_supported = cutlass_fp8_supported() @@ -486,7 +490,7 @@ def apply( ) -class Fp8MoEMethod: +class Fp8MoEMethod(FusedMoEMethodBase): """MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale. @@ -499,25 +503,7 @@ class Fp8MoEMethod: quant_config: The quantization config. """ - def __new__(cls, *args, **kwargs): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - - def __init__(self, quant_config): + def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None self.cutlass_fp8_supported = cutlass_fp8_supported() @@ -1169,6 +1155,254 @@ def maybe_apply_hip_fused_experts( return None +class Fp8EPMoEMethod(Fp8MoEMethod): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Args: + quant_config: The quantization config. 
+ """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + self.block_quant = self.quant_config.weight_block_size is not None + + def create_weights( + self, + layer: Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + tp_size = get_tensor_model_parallel_world_size() + if self.block_quant: + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. + # Required by column parallel or enabling merged weights + if intermediate_size % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1: + # Required by row parallel + if intermediate_size % block_k != 0: + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_k = {block_k}." 
+ ) + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + if self.block_quant: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts_per_partition, + 2 * ((intermediate_size + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts_per_partition, + (hidden_size + block_n - 1) // block_n, + (intermediate_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + assert self.quant_config.activation_scheme == "dynamic" + else: + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. 
+ w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, 2, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + if self.block_quant + else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. 
+ if not self.quant_config.is_checkpoint_fp8_serialized: + # If rocm, use float8_e4m3fnuz as dtype + fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts_per_partition, + dtype=torch.float32, + device=w13_weight.device, + ), + requires_grad=False, + ) + + for expert in range(layer.num_experts_per_partition): + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + if self.quant_config.activation_scheme == "static": + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." 
+ ) + layer.w13_weight_scale = torch.nn.Parameter( + torch.max(layer.w13_weight_scale, dim=1).values, + requires_grad=False, + ) + if self.block_quant: + # If ROCm, normalize the weights and scales to e4m3fnuz + if _is_fp8_fnuz: + # activation_scheme: dynamic + w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=layer.w13_weight, + weight_scale=layer.w13_weight_scale_inv, + input_scale=None, + ) + w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=layer.w2_weight, + weight_scale=layer.w2_weight_scale_inv, + input_scale=None, + ) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter( + w13_weight, requires_grad=False + ) + layer.w13_weight_scale_inv = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) + layer.w13_input_scale = None + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale_inv = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) + layer.w2_input_scale = None + if _use_aiter: + layer.w13_weight = torch.nn.Parameter( + shuffle_weight(layer.w13_weight.data, (16, 16)), + requires_grad=False, + ) + layer.w2_weight = torch.nn.Parameter( + shuffle_weight(layer.w2_weight.data, (16, 16)), + requires_grad=False, + ) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + raise NotImplementedError + + class Fp8KVCacheMethod(BaseKVCacheMethod): """ Supports loading kv-cache scaling factors from FP8 checkpoints. 
diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index 3658d0b85793..af56c3be719a 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from dataclasses import dataclass from fractions import Fraction @@ -5,7 +7,6 @@ import torch -from sglang.srt.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs from sglang.srt.layers.parameter import ( BasevLLMParameter, ChannelQuantScaleParameter, @@ -16,6 +17,8 @@ permute_param_layout_, ) from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -34,7 +37,11 @@ verify_marlin_supported, ) from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types -from sglang.srt.layers.quantization.utils import replace_parameter, unpack_cols +from sglang.srt.layers.quantization.utils import ( + get_linear_quant_method, + replace_parameter, + unpack_cols, +) try: from vllm import _custom_ops as ops @@ -49,8 +56,6 @@ from sgl_kernel import fused_marlin_moe -FusedMoEMethodBase = QuantizeMethodBase - logger = logging.getLogger(__name__) @@ -179,7 +184,7 @@ def get_config_filenames(cls) -> List[str]: return ["quantize_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": + def from_config(cls, config: Dict[str, Any]) -> GPTQConfig: dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) dynamic = {} if dynamic is None else dynamic @@ -191,10 +196,10 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["LinearMethodBase"]: + ) -> Optional[LinearMethodBase]: # Delay the import to avoid circular dependency + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE - 
from sglang.srt.layers.quantization import get_linear_quant_method if isinstance(layer, LinearBase): return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) @@ -303,7 +308,7 @@ def get_config_filenames(cls) -> List[str]: return ["quantize_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": + def from_config(cls, config: Dict[str, Any]) -> GPTQMarlinConfig: dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) dynamic = {} if dynamic is None else dynamic @@ -354,7 +359,6 @@ def get_quant_method( ) -> Optional[QuantizeMethodBase]: # Delay the import to avoid circular dependency from sglang.srt.layers.moe.fused_moe_triton import FusedMoE - from sglang.srt.layers.quantization import get_linear_quant_method if isinstance(layer, FusedMoE): return GPTQMarlinMoEMethod(self) @@ -832,6 +836,7 @@ def create_weights( **extra_weight_attrs, ): # Delay the import to avoid circular dependency + from sglang.srt.layers.linear import set_weight_attrs from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported intermediate_size = extra_weight_attrs.pop("intermediate_size") diff --git a/python/sglang/srt/layers/quantization/marlin_utils.py b/python/sglang/srt/layers/quantization/marlin_utils.py index 503c3d003632..1edc672ab3f8 100644 --- a/python/sglang/srt/layers/quantization/marlin_utils.py +++ b/python/sglang/srt/layers/quantization/marlin_utils.py @@ -1,25 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py +from __future__ import annotations + import logging -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import numpy import torch -from sglang.srt.layers.linear import LinearBase, LinearMethodBase from sglang.srt.layers.parameter import ( BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, PackedvLLMParameter, ) -from 
sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.base_config import ( + LinearMethodBase, + QuantizationConfig, +) from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols -from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.utils import get_device_capability +if TYPE_CHECKING: + from sglang.srt.layers.linear import LinearBase + try: from vllm import _custom_ops as ops except ImportError: @@ -617,7 +623,10 @@ def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str] def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["MarlinLinearMethod"]: + ) -> Optional[MarlinLinearMethod]: + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead + if isinstance(layer, LinearBase) or ( isinstance(layer, ParallelLMHead) and self.lm_head_quantized ): diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 85be4f8f4604..5263f3b920b1 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1,4 +1,5 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py +from __future__ import annotations import logging from typing import Any, Callable, Dict, List, Optional @@ -6,14 +7,11 @@ import torch from torch.nn.parameter import Parameter -from sglang.srt.layers.linear import ( - LinearBase, - LinearMethodBase, - UnquantizedLinearMethod, -) from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, 
+ LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -23,6 +21,7 @@ is_sm100_supported, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import ( convert_to_channelwise, is_layer_skipped, @@ -86,7 +85,7 @@ def get_config_filenames(cls) -> List[str]: return ["hf_quant_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": + def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config: quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo") kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get( "kv_cache_quant_algo" @@ -109,7 +108,11 @@ def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: + + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + if self.exclude_modules and any( module in prefix or ( @@ -125,9 +128,6 @@ def get_quant_method( if self.kv_cache_quant_method and isinstance(layer, RadixAttention): return ModelOptFp8KVCacheMethod(self) - # Add MoE support - from sglang.srt.layers.moe.fused_moe_triton import FusedMoE - if isinstance(layer, FusedMoE): return ModelOptFp8MoEMethod(self) @@ -246,7 +246,7 @@ def __init__(self, quant_config: ModelOptFp8Config): super().__init__(quant_config) -class ModelOptFp8MoEMethod: +class ModelOptFp8MoEMethod(FusedMoEMethodBase): """MoE method for ModelOpt FP8. Supports loading FP8 checkpoints with static weight scale and activation scale. @@ -254,30 +254,6 @@ class ModelOptFp8MoEMethod: quant_config: The ModelOpt quantization config. """ - def __new__(cls, *args, **kwargs): - """ - Dynamic class composition pattern. 
- - This allows us to effectively "inject" FusedMoEMethodBase as a parent class - at runtime while avoiding circular import issues. - """ - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - def __init__(self, quant_config: ModelOptFp8Config): self.quant_config = quant_config self.cutlass_fp8_supported = cutlass_fp8_supported() @@ -514,7 +490,7 @@ def get_config_filenames(cls) -> List[str]: return ["hf_quant_config.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config": + def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config: quant_config = cls.get_from_keys(config, ["quantization"]) quant_method = quant_config["quant_algo"] if not quant_method in ["FP8", "NVFP4"]: @@ -559,7 +535,8 @@ def is_layer_excluded(self, prefix: str, exclude_modules: list): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): @@ -740,31 +717,13 @@ def apply( return out.view(*output_shape) -class ModelOptNvFp4FusedMoEMethod: +class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): """ MoE Method for FP4 Quantization with Blockscales and PerTensorScales Args: quant_config: NVFP4 Quant Config """ - def __new__(cls, *args, **kwargs): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - 
"__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - def __init__(self, quant_config: ModelOptFp4Config): self.quant_config = quant_config if not is_sm100_supported(): diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py index fe812595a80b..f83b9bb1f71d 100644 --- a/python/sglang/srt/layers/quantization/moe_wna16.py +++ b/python/sglang/srt/layers/quantization/moe_wna16.py @@ -1,4 +1,5 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py +from __future__ import annotations import logging from typing import Any, Callable, Dict, List, Optional @@ -7,13 +8,14 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import get_tp_group -from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod from sglang.srt.layers.quantization.awq import AWQConfig from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, QuantizationConfig, QuantizeMethodBase, ) from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.utils import get_device_capability, set_weight_attrs logger = logging.getLogger(__name__) @@ -118,7 +120,7 @@ def get_scaled_act_names(self) -> List[str]: raise NotImplementedError @classmethod - def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config": + def from_config(cls, config: Dict[str, Any]) -> MoeWNA16Config: quant_method = cls.get_from_keys(config, ["quant_method"]) weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) @@ -177,8 +179,9 @@ def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]): def 
get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: # avoid circular import + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE if is_layer_skipped_quant(prefix, self.modules_to_not_convert): @@ -209,32 +212,13 @@ def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]): return any(module_name in prefix for module_name in modules_to_not_convert) -class MoeWNA16Method: +class MoeWNA16Method(FusedMoEMethodBase): """Linear method for MOE WNA16 (W8A16/W4A16) quantization. Args: quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. """ - def __new__(cls, *args, **kwargs): - # avoid circular import - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - def __init__(self, quant_config: MoeWNA16Config): self.quant_config = quant_config diff --git a/python/sglang/srt/layers/quantization/qoq.py b/python/sglang/srt/layers/quantization/qoq.py index 3e3a3dfb6340..ec0fda482c4b 100644 --- a/python/sglang/srt/layers/quantization/qoq.py +++ b/python/sglang/srt/layers/quantization/qoq.py @@ -1,16 +1,17 @@ -from typing import Any, Callable, Dict, List, Optional +from __future__ import annotations + +from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from sglang.srt.distributed import get_tensor_model_parallel_world_size -from sglang.srt.layers.linear import LinearMethodBase from sglang.srt.layers.parameter import ( ChannelQuantScaleParameter, GroupQuantScaleParameter, ModelWeightParameter, ) 
from sglang.srt.layers.quantization.base_config import ( + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -71,7 +72,7 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def get_name(self) -> str: + def get_name(cls) -> str: return "qoq" @classmethod @@ -83,7 +84,7 @@ def get_config_filenames(cls) -> List[str]: ] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "QoQConfig": + def from_config(cls, config: Dict[str, Any]) -> QoQConfig: weight_bits = cls.get_from_keys(config, ["wbits"]) group_size = cls.get_from_keys(config, ["group_size"]) return cls(weight_bits, group_size) @@ -92,7 +93,7 @@ def get_quant_method( self, layer: torch.nn.Module, prefix: str, - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: from sglang.srt.layers.linear import LinearBase if isinstance(layer, LinearBase): diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py new file mode 100644 index 000000000000..28d006255d8e --- /dev/null +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -0,0 +1,515 @@ +import importlib +from typing import Callable, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from sglang.srt.custom_op import CustomOp +from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading +from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, + QuantizeMethodBase, +) +from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, + is_cpu, + is_hip, + set_weight_attrs, + use_intel_amx_backend, +) + +has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None + + +_is_cpu_amx_available = cpu_has_amx_support() +_is_hip = is_hip() +_is_cpu = is_cpu() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip + +if _use_aiter: + from aiter import ActivationType + from aiter.fused_moe import fused_moe + from 
aiter.fused_moe_bf16_asm import ck_moe_2stages + from aiter.ops.shuffle import shuffle_weight + + +class UnquantizedEmbeddingMethod(QuantizeMethodBase): + """Unquantized method for embeddings.""" + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Create weights for embedding layer.""" + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return F.linear(x, layer.weight, bias) + + def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: + return F.embedding(input_, layer.weight) + + +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization.""" + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if _is_cpu and _is_cpu_amx_available: + _amx_process_weight_after_loading(layer, ["weight"]) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + 
if use_intel_amx_backend(layer): + return torch.ops.sgl_kernel.weight_packed_linear( + x, layer.weight, bias, True # is_vnni + ) + + return F.linear(x, layer.weight, bias) + + +class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): + """MoE method without quantization.""" + + def __init__(self, use_triton_kernels: bool = False): + super().__init__() + self.use_triton_kernels = use_triton_kernels + + from sglang.srt.layers.moe.fused_moe_native import moe_forward_native + + if torch.cuda.is_available(): + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + + if has_triton_kernels: + from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( + triton_kernel_moe_forward, + ) + else: + triton_kernel_moe_forward = None + else: + fused_experts = None # type: ignore + triton_kernel_moe_forward = None + + self.moe_forward_native = moe_forward_native + self.fused_experts = fused_experts + self.triton_kernel_moe_forward = triton_kernel_moe_forward + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Fused gate_up_proj (column parallel) + w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size + if self.use_triton_kernels: + w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n + w13_weight = torch.nn.Parameter( + torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight_n, w2_weight_k = ( + hidden_size, + intermediate_size, + ) + if self.use_triton_kernels: + w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + 
set_weight_attrs(w2_weight, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if _use_aiter: + layer.w13_weight = torch.nn.Parameter( + shuffle_weight(layer.w13_weight.data, (16, 16)), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + shuffle_weight(layer.w2_weight.data, (16, 16)), + requires_grad=False, + ) + torch.cuda.empty_cache() + + # Pack weight for get better performance on CPU + if _is_cpu and _is_cpu_amx_available: + _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) + + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + num_fused_shared_experts: int = 0, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + + return self.forward( + x=x, + layer=layer, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + num_fused_shared_experts=num_fused_shared_experts, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + inplace=inplace, + no_combine=no_combine, + routed_scaling_factor=routed_scaling_factor, + ) + + def forward_cuda( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + num_fused_shared_experts: int = 
0, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + + if self.use_triton_kernels: + return self.triton_kernel_moe_forward( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + else: + from sglang.srt.layers.moe.topk import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + num_fused_shared_experts=num_fused_shared_experts, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + routed_scaling_factor=routed_scaling_factor, + ) + + if _use_aiter: + assert not no_combine, "unsupported" + if apply_router_weight_on_input: + assert ( + topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + x = x * topk_weights.to(x.dtype) + topk_weights = torch.ones_like( + topk_weights, dtype=torch.float32 + ) # topk_weights must be FP32 (float32) + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=( + ActivationType.Silu + if activation == "silu" + else ActivationType.Gelu + ), + ) + else: + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace and not no_combine, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + no_combine=no_combine, + 
routed_scaling_factor=routed_scaling_factor, + ) + + def forward_cpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + num_fused_shared_experts: int = 0, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + assert activation == "silu", f"activation = {activation} is not supported." + + if use_intel_amx_backend(layer) and not apply_router_weight_on_input: + + from sglang.srt.layers.moe.topk import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + num_fused_shared_experts=num_fused_shared_experts, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + routed_scaling_factor=routed_scaling_factor, + ) + + # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel + return torch.ops.sgl_kernel.fused_experts_cpu( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + False, # inplace # See [Note] inplace should be False in fused_experts. 
+ False, # use_int8_w8a8 + False, # use_fp8_w8a16 + None, # w1_scale + None, # w2_scale + None, # block_size + None, # a1_scale + None, # a2_scale + True, # is_vnni + ) + else: + return self.moe_forward_native( + layer, + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + num_fused_shared_experts, + custom_routing_function, + correction_bias, + activation, + apply_router_weight_on_input, + inplace, + no_combine, + routed_scaling_factor, + ) + + def forward_npu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + num_fused_shared_experts: int = 0, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + return self.moe_forward_native( + layer, + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + num_fused_shared_experts, + custom_routing_function, + correction_bias, + activation, + apply_router_weight_on_input, + inplace, + no_combine, + routed_scaling_factor, + ) + + def forward_tpu(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("The TPU backend currently does not support MoE.") + + forward_native = forward_cpu + + +class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): + + def create_weights( + self, + layer: torch.nn.Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size, + 
dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # scale + layer.register_parameter("w13_input_scale", None) + layer.register_parameter("w13_weight_scale", None) + + ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32) + + w2_input_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 2371208f7895..51d70255d90c 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -1,7 +1,11 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py +from __future__ import annotations + +import re +from copy import deepcopy from types import MappingProxyType -from typing import List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, 
List, Mapping, Optional, Tuple, Union import numpy import torch @@ -10,6 +14,9 @@ from sglang.srt.layers.quantization.scalar_type import ScalarType from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu +if TYPE_CHECKING: + from sglang.srt.layers.quantization.base_config import QuantizationConfig + _is_cuda = is_cuda() _is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() @@ -147,6 +154,94 @@ def replace_parameter( mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False)) +# Match dynamic rules with module name (prefix) and override quantize +# config if module (prefix) matches a rule +def override_config(config: QuantizationConfig, prefix: str): + weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits) + if isinstance(weight_bits, int): + config.weight_bits = weight_bits + group_size = get_dynamic_override(config, prefix, "group_size", config.group_size) + if isinstance(group_size, int): + config.group_size = group_size + desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act) + if isinstance(desc_act, bool): + config.desc_act = desc_act + + config.pack_factor = 32 // config.weight_bits # packed into int32 + if config.get_name() == "gptq_marlin": + is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) + if isinstance(is_sym, bool): + config.is_sym = is_sym + + if (config.weight_bits, config.is_sym) not in config.TYPE_MAP: + raise ValueError( + "Unsupported quantization config: " + f"bits={config.weight_bits}, sym={config.is_sym}" + ) + + config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)] + elif config.get_name() == "gptq": + if config.weight_bits not in [2, 3, 4, 8]: + raise ValueError( + "Currently, only 2/3/4/8-bit weight quantization is " + f"supported for GPTQ, but got {config.weight_bits} bits." 
+ ) + + +def get_dynamic_override( + config: QuantizationConfig, + layer_name: str, + key: Optional[str] = None, + default_value: Union[int, bool, None] = None, +) -> Union[Dict, int, bool, None]: + for pattern, pattern_dict in config.dynamic.items(): + # Negative match: matched modules are excluded from quantized init + if pattern.startswith("-:"): + if re.match(pattern.removeprefix("-:"), layer_name): + return False + # Positive match: matched modules have quant properties overrides + # base quant config + elif re.match(pattern.removeprefix("+:"), layer_name): + if key is None: + return pattern_dict + else: + return pattern_dict.get(key, default_value) + return default_value + + +def get_linear_quant_method( + config: QuantizationConfig, + layer: torch.nn.Module, + prefix: str, + linear_method_cls: type, +): + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.quantization.unquant import ( + UnquantizedEmbeddingMethod, + UnquantizedLinearMethod, + ) + from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead + + cloned_config = deepcopy(config) + parallel_lm_head_quantized = ( + isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized + ) + + if isinstance(layer, LinearBase) or parallel_lm_head_quantized: + # False = skip module, None = no override, else = Positive match + if get_dynamic_override(cloned_config, layer_name=prefix) is False: + if parallel_lm_head_quantized: + return UnquantizedEmbeddingMethod() + return UnquantizedLinearMethod() + + if prefix: + # Dynamic per module/layer rules may override base config + override_config(cloned_config, prefix=prefix) + + return linear_method_cls(cloned_config) + return None + + def get_pack_factor(num_bits): assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" return 32 // num_bits diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index c2820bdfc8cf..1c9dc5d33710 100644 --- 
a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from typing import Any, Dict, List, Optional @@ -5,12 +7,13 @@ from torch.nn import Module from torch.nn.parameter import Parameter -from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, QuantizationConfig, QuantizeMethodBase, ) from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import set_weight_attrs @@ -62,7 +65,7 @@ def get_config_filenames(cls) -> List[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config": + def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config: quant_method = cls.get_from_keys(config, ["quant_method"]) is_checkpoint_fp8_serialized = "fp8" in quant_method is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method @@ -79,7 +82,8 @@ def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config": def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): @@ -94,7 +98,7 @@ def get_scaled_act_names(self) -> List[str]: return [] -class W4AFp8MoEMethod: +class W4AFp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: W4AFp8Config): self.quant_config = quant_config diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index b2e606f4d2ed..871a4534ca3e 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ 
-1,11 +1,14 @@ +from __future__ import annotations + from typing import Any, Callable, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from sglang.srt.layers.linear import LinearMethodBase from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) @@ -64,7 +67,7 @@ def get_config_filenames(cls) -> List[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config": + def from_config(cls, config: Dict[str, Any]) -> W8A8Fp8Config: quant_method = cls.get_from_keys(config, ["quant_method"]) is_checkpoint_fp8_serialized = ( "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method @@ -75,7 +78,7 @@ def get_quant_method( self, layer: torch.nn.Module, prefix: str, - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE @@ -183,7 +186,7 @@ def apply( ) -class W8A8FP8MoEMethod: +class W8A8FP8MoEMethod(FusedMoEMethodBase): """MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale. @@ -194,25 +197,7 @@ class W8A8FP8MoEMethod: quant_config: The quantization config. 
""" - def __new__(cls, *args, **kwargs): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - - def __init__(self, quant_config): + def __init__(self, quant_config: W8A8Fp8Config): self.quant_config = quant_config def create_weights( diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 49e6f0e8c91e..c8a024bf33ed 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import sys from types import MappingProxyType @@ -11,21 +13,19 @@ get_tensor_model_parallel_world_size, ) from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading -from sglang.srt.layers.linear import ( - LinearMethodBase, - RowParallelLinear, - UnquantizedLinearMethod, -) from sglang.srt.layers.parameter import ( ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter, ) from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.utils import ( apply_module_patch, cpu_has_amx_support, @@ -229,14 +229,14 @@ def get_config_filenames(cls) -> List[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config": + def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config: return cls(config) def get_quant_method( 
self, layer: torch.nn.Module, prefix: str, - ) -> Optional["QuantizeMethodBase"]: + ) -> Optional[QuantizeMethodBase]: from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE @@ -374,7 +374,7 @@ def apply( ) -class W8A8Int8MoEMethod: +class W8A8Int8MoEMethod(FusedMoEMethodBase): """MoE method for INT8. Supports loading INT8 checkpoints with static weight scale and dynamic/static activation scale. @@ -385,25 +385,7 @@ class W8A8Int8MoEMethod: quant_config: The quantization config. """ - def __new__(cls, *args, **kwargs): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase - - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) - - def __init__(self, quant_config): + def __init__(self, quant_config: W8A8Int8Config): self.quant_config = quant_config def create_weights( @@ -885,13 +867,15 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + from sglang.srt.layers.linear import RowParallelLinear + if isinstance(layer, RowParallelLinear): tp_rank = get_tensor_model_parallel_rank() return self.quant_method.apply(layer, x, bias, tp_rank) return self.quant_method.apply(layer, x, bias) -class NPU_W8A8MoEMethod: +class NPU_W8A8MoEMethod(FusedMoEMethodBase): """MoE method for NPU quantization. 
This class search for specific quantization diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 0e075a2518f2..d925506f5ecc 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -5,7 +5,6 @@ from typing import List, Optional, Sequence, Tuple import torch -import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter from sglang.srt.distributed import ( @@ -22,6 +21,7 @@ QuantizeMethodBase, method_has_implemented_embedding, ) +from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -32,44 +32,6 @@ logger = logging.getLogger(__name__) -class UnquantizedEmbeddingMethod(QuantizeMethodBase): - """Unquantized method for embeddings.""" - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - """Create weights for embedding layer.""" - weight = Parameter( - torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - layer.register_parameter("weight", weight) - set_weight_attrs(weight, extra_weight_attrs) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return F.linear(x, layer.weight, bias) - - def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: - return F.embedding(input_, layer.weight) - - def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" return ((vocab_size + pad_to - 1) // 
pad_to) * pad_to From af1cc8fe2dd8f87f3d79419e20cf655338eecf28 Mon Sep 17 00:00:00 2001 From: Yuan Luo Date: Thu, 17 Jul 2025 19:33:02 +0800 Subject: [PATCH 019/396] [kernel] opt moe align block kernel by block/warp scan algorithm (#7884) --- sgl-kernel/csrc/moe/moe_align_kernel.cu | 93 ++++++++++++++----------- 1 file changed, 51 insertions(+), 42 deletions(-) diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu index ad80b0c75e81..b66afa0e4843 100644 --- a/sgl-kernel/csrc/moe/moe_align_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu @@ -26,6 +26,12 @@ limitations under the License. #define VEC_SIZE 4 using Vec = int4; +#ifndef __CUDA_ARCH__ // HIP +#define SHFL_UP(mask, val, delta) __shfl_up((val), (delta)) +#else // CUDA +#define SHFL_UP(mask, val, delta) __shfl_up_sync((mask), (val), (delta)) +#endif + template __global__ void count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, @@ -42,6 +48,16 @@ __global__ void count_and_sort_expert_tokens_kernel( } } +__device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffffffffu) { + int original = v; +#pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + int n = SHFL_UP(mask, v, offset); + if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n; + } + return v - original; +} + template __global__ void moe_align_block_size_kernel( const scalar_t* __restrict__ topk_ids, @@ -58,6 +74,7 @@ __global__ void moe_align_block_size_kernel( int32_t* shared_counts = smem; // [num_experts] int32_t* prefix = shared_counts + num_experts; // [num_experts + 1] int32_t* scan_buf = prefix + num_experts + 1; // [scan_size] + int32_t* warp_sums = scan_buf + scan_size; // [<= 32] __shared__ int32_t s_total_tokens_post_pad; const size_t tid = threadIdx.x; @@ -76,6 +93,7 @@ __global__ void moe_align_block_size_kernel( __syncthreads(); + // Calculate padded_cnt, write scan_buf, directly prefix sum int32_t padded_count = 0; if (tid < 
num_experts) { int32_t count = shared_counts[tid]; @@ -83,58 +101,52 @@ __global__ void moe_align_block_size_kernel( scan_buf[tid] = padded_count; } - if (tid >= num_experts && tid < scan_size) { - scan_buf[tid] = 0; - } - + // Intra warp prefix sum + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid & (WARP_SIZE - 1); + const int num_warps_for_scan = (scan_size + WARP_SIZE - 1) / WARP_SIZE; + const int warp_sum = warp_exclusive_scan(padded_count) + padded_count; + if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = warp_sum; __syncthreads(); - // Blelloch scan - int offset = 1; -#pragma unroll - for (int d = scan_size >> 1; d > 0; d >>= 1) { - if (tid < d) { - int ai = offset * (2 * tid + 1) - 1; - int bi = offset * (2 * tid + 2) - 1; - scan_buf[bi] += scan_buf[ai]; - } - offset <<= 1; - __syncthreads(); + // warp0 accumulate all the block's prefix sum + if (tid < WARP_SIZE) { + int val = (tid < num_warps_for_scan) ? warp_sums[tid] : 0; + int incl = warp_exclusive_scan(val) + val; + warp_sums[tid] = incl; } + __syncthreads(); - // down-sweep + // Every thread obtains the whole block's sum if (tid == 0) { - prefix[num_experts] = scan_buf[scan_size - 1]; - scan_buf[scan_size - 1] = 0; + prefix[num_experts] = warp_sums[num_warps_for_scan - 1]; + s_total_tokens_post_pad = prefix[num_experts]; + *total_tokens_post_pad = s_total_tokens_post_pad; } __syncthreads(); -#pragma unroll - for (int d = 1; d < scan_size; d <<= 1) { - offset >>= 1; - if (tid < d) { - int ai = offset * (2 * tid + 1) - 1; - int bi = offset * (2 * tid + 2) - 1; - if (bi < scan_size) { - int temp = scan_buf[ai]; - scan_buf[ai] = scan_buf[bi]; - scan_buf[bi] += temp; - } - } - __syncthreads(); - } + // Fill 0 to scan_buf extended area (tid >= num_expert) + if (tid >= num_experts && tid < scan_size) scan_buf[tid] = 0; + __syncthreads(); - if (tid < num_experts) { - prefix[tid] = scan_buf[tid]; - } + // Perform 2 level exclusive-prefix-sum to scan_buf + int v = (tid < scan_size) ? 
scan_buf[tid] : 0; + int pre = warp_exclusive_scan(v); + if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v; + __syncthreads(); - if (tid == 0) { - s_total_tokens_post_pad = prefix[num_experts]; - *total_tokens_post_pad = s_total_tokens_post_pad; + if (warp_id == 0) { + int val = (lane_id < num_warps_for_scan) ? warp_sums[lane_id] : 0; + warp_sums[lane_id] = warp_exclusive_scan(val); } + __syncthreads(); + int offset = warp_sums[warp_id]; + if (tid < scan_size) scan_buf[tid] = pre + offset; __syncthreads(); + // Write prefix[0..num_experts - 1] and cumsum + if (tid < num_experts) prefix[tid] = scan_buf[tid]; if (tid <= num_experts) { cumsum[tid] = prefix[tid]; } @@ -250,9 +262,6 @@ void moe_align_block_size( bool pad_sorted_token_ids) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - - int experts_per_warp = WARP_SIZE; int threads = 1024; threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; @@ -278,7 +287,7 @@ void moe_align_block_size( auto align_kernel = moe_align_block_size_kernel; const size_t scan_size = next_pow2(num_experts); - const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size) * sizeof(int32_t); + const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * sizeof(int32_t); align_kernel<<<1, threads, shared_mem_size, stream>>>( topk_ids.data_ptr(), From 519ff5c8e69e076fa9120d8f3ffaed98c68b5236 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Thu, 17 Jul 2025 21:15:51 +0800 Subject: [PATCH 020/396] Super tiny fix typo (#8046) --- .../sglang/srt/layers/attention/flashattention_backend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index b0615be3c2ca..740b46b6be18 100644 --- 
a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -1617,7 +1617,7 @@ def init_forward_metadata_replay_cuda_graph( metadata.max_seq_len_k + self.page_size - 1 ) // self.page_size - normal_decode_set_medadata( + normal_decode_set_metadata( metadata.cache_seqlens_int32, metadata.cu_seqlens_k, metadata.page_table, @@ -1666,7 +1666,7 @@ def init_forward_metadata_replay_cuda_graph( max_seq_pages = (max_len + self.page_size - 1) // self.page_size metadata.max_seq_len_k = max_len - normal_decode_set_medadata( + normal_decode_set_metadata( metadata.cache_seqlens_int32, metadata.cu_seqlens_k, metadata.page_table, @@ -2089,7 +2089,7 @@ def init_forward_metadata_replay_cuda_graph( # @torch.compile(dynamic=True, backend=get_compiler_backend()) # TODO: fuse these kernels # NOTE: torch.compile makes it slower in speculative decoding -def normal_decode_set_medadata( +def normal_decode_set_metadata( cache_seqlens_int32: torch.Tensor, cu_seqlens_k: torch.Tensor, page_table: torch.Tensor, From 01857fab6189a81b31c9140b67bab5135cf36bb0 Mon Sep 17 00:00:00 2001 From: Ziqi Fan Date: Thu, 17 Jul 2025 06:24:34 -0700 Subject: [PATCH 021/396] fix: update HostKVCache init to report correct msg when available memory is not enough (#8102) --- python/sglang/srt/mem_cache/memory_pool_host.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index a5977fd1d7ce..1bc2ddf7ec45 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -71,11 +71,12 @@ def __init__( requested_bytes = self.size * self.size_per_token # preserve at least 10GB for other usage ten_gb = 10 * (1024**3) - if requested_bytes > host_mem.available - ten_gb: + available_bytes = host_mem.available - ten_gb + if requested_bytes > available_bytes: raise ValueError( f"Not 
enough host memory available. Requesting " f"{requested_bytes / 1e9:.2f} GB but only have " - f"{host_mem.available / 1e9:.2f} GB free. Please reduce the " + f"{available_bytes / 1e9:.2f} GB free. Please reduce the " f"size of the hierarchical cache." ) else: From 42960214994461d93dec2fc3e00383e33c9f0401 Mon Sep 17 00:00:00 2001 From: Asher Date: Fri, 18 Jul 2025 01:00:11 +0800 Subject: [PATCH 022/396] [Hunyuan]: Fix Dense Model Support (#8117) Signed-off-by: Asher Zhang --- python/sglang/srt/models/hunyuan.py | 66 ++++++++++++++++++++++++----- 1 file changed, 55 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py index 41a833f3df98..f23ccc0a8d94 100644 --- a/python/sglang/srt/models/hunyuan.py +++ b/python/sglang/srt/models/hunyuan.py @@ -206,6 +206,42 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states.view(orig_shape) +def get_head_dim(config): + if hasattr(config, "head_dim"): + return int(config.head_dim) + if hasattr(config, "attention_head_dim"): + return int(config.attention_head_dim) + + # since some hunyuan model don't follow the self.hidden_size // self.total_num_heads rule + # wrong setting may cause runtime error, just throw error if this field is missing. + raise ValueError("Missing head dim config, try set head_dim in config.json") + + +def check_head_dim(config): + # Some models may lack `head_dim` and use `attention_head_dim` instead. + # This attribute is also used by flashinfer_backend.py, so we check for + # consistency and raise an error if it's not met to avoid silent failures. + # Although we could adapt the HunYuan model to use `attention_head_dim`, + # flashinfer expects `head_dim`, so we enforce its presence for correctness. 
+ calc_head_dim = config.hidden_size // config.num_attention_heads + + if hasattr(config, "attention_head_dim"): + if calc_head_dim != config.attention_head_dim and not hasattr( + config, "head_dim" + ): + # in this case, flash infer(and other components may calculate wrong value.) + raise ValueError( + f"HunYuan model config error: calculated head_dim {calc_head_dim} != attention_head_dim {config.attention_head_dim}" + + f"\nPlease Add head_dim:{config.attention_head_dim} in config.json to make sure correctly inference." + ) + + if hasattr(config, "head_dim") and config.attention_head_dim != config.head_dim: + raise ValueError( + f"HunYuan model config error: head_dim({config.head_dim}) != attention_head_dim({config.attention_head_dim})" + + f"\nPlease change head_dim:{config.attention_head_dim} in config.json to make sure correctly inference." + ) + + class HunYuanAttention(nn.Module): def __init__( @@ -240,9 +276,11 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo - self.head_dim = getattr( - config, "head_dim", self.hidden_size // self.total_num_heads - ) + # Prioritize `head_dim` but fall back to `attention_head_dim` for Hunyuan models. 
+ self.head_dim = get_head_dim(config) + + check_head_dim(config) + self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -493,7 +531,6 @@ def forward( hidden_states = self.get_input_embeddings(input_ids) residual = None - cla_factor = _get_cla_factor(self.config) prev_kv_states = None for i in range(len(self.layers)): layer = self.layers[i] @@ -560,6 +597,11 @@ def __init__( if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight + self.hidden_size = config.hidden_size + self.head_dim = get_head_dim(config) + + check_head_dim(config) + logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale) self.sampler = Sampler() @@ -582,16 +624,14 @@ def _split_qkv_weight(self, qkv: torch.Tensor): self.config, "num_key_value_heads", self.config.num_attention_heads ) num_key_value_groups = num_attention_heads // num_kv_heads - hidden_size = self.config.hidden_size - attention_head_dim = self.config.hidden_size // num_attention_heads qkv = qkv.reshape( - num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size + num_kv_heads, num_key_value_groups + 2, self.head_dim, self.hidden_size ) q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1) - q = q.reshape(-1, hidden_size) - k = k.reshape(-1, hidden_size) - v = v.reshape(-1, hidden_size) + q = q.reshape(-1, self.hidden_size) + k = k.reshape(-1, self.hidden_size) + v = v.reshape(-1, self.hidden_size) return torch.concat((q, k, v)) # return qkv.reshape((num_kv_heads, num_key_value_groups+2 , attention_head_dim, hidden_size)).permute((1,0,2,3)).reshape((-1, hidden_size)), @@ -768,4 +808,8 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: ) -EntryClass = HunYuanMoEV1ForCausalLM +class HunYuanDenseV1ForCausalLM(HunYuanMoEV1ForCausalLM): + pass + + +EntryClass = [HunYuanMoEV1ForCausalLM, HunYuanDenseV1ForCausalLM] From 
3586b4cef232d829491fa47631d3522900f8ff35 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Fri, 18 Jul 2025 02:59:05 +0800 Subject: [PATCH 023/396] feat: add production metric for retracted requests due to insufficient kvcache (#7030) Signed-off-by: Zhao Chen --- python/sglang/srt/managers/scheduler.py | 7 ++++++- python/sglang/srt/metrics/collector.py | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ab966f924cc6..874ed60f0fd2 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -675,6 +675,7 @@ def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]): self.spec_num_total_forward_ct = 0 self.cum_spec_accept_length = 0 self.cum_spec_accept_count = 0 + self.total_retracted_reqs = 0 self.stats = SchedulerStats() if self.enable_metrics: engine_type = "unified" @@ -1477,6 +1478,7 @@ def log_decode_stats( self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.num_grammar_queue_reqs = len(self.grammar_queue) self.stats.spec_accept_length = spec_accept_length + self.stats.total_retracted_reqs = self.total_retracted_reqs self.metrics_collector.log_stats(self.stats) self._emit_kv_metrics() self._publish_kv_events() @@ -1824,14 +1826,17 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: old_ratio = self.new_token_ratio retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args) + num_retracted_reqs = len(retracted_reqs) self.new_token_ratio = new_token_ratio logger.info( "KV cache pool is full. Retract requests. 
" - f"#retracted_reqs: {len(retracted_reqs)}, " + f"#retracted_reqs: {num_retracted_reqs}, " f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" ) + self._extend_requests_to_queue(retracted_reqs, is_retracted=True) + self.total_retracted_reqs += num_retracted_reqs else: self.new_token_ratio = max( self.new_token_ratio - self.new_token_ratio_decay, diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index f8dac44727f2..4c32b8fc6348 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -145,6 +145,7 @@ class SchedulerStats: num_prefill_infight_queue_reqs: int = 0 num_decode_prealloc_queue_reqs: int = 0 num_decode_transfer_queue_reqs: int = 0 + total_retracted_reqs: int = 0 class SchedulerMetricsCollector: @@ -219,6 +220,13 @@ def __init__(self, labels: Dict[str, str]) -> None: multiprocess_mode="mostrecent", ) + self.total_retracted_reqs = Gauge( + name="sglang:total_retracted_reqs", + documentation="The total number of retracted requests due to kvcache full.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + # Disaggregation queue metrics self.num_prefill_prealloc_queue_reqs = Gauge( name="sglang:num_prefill_prealloc_queue_reqs", @@ -279,6 +287,7 @@ def log_stats(self, stats: SchedulerStats) -> None: self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs) self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) self._log_gauge(self.spec_accept_length, stats.spec_accept_length) + self._log_gauge(self.total_retracted_reqs, stats.total_retracted_reqs) # Disaggregation metrics self._log_gauge( From e1020dc5883b5a39191952d386f4af60a2ae7a5b Mon Sep 17 00:00:00 2001 From: Mick Date: Fri, 18 Jul 2025 08:59:15 +0800 Subject: [PATCH 024/396] refactor: simplify MultimodalTokens logic (#7924) --- .../multimodal/processors/base_processor.py | 79 +++++++++++-------- .../multimodal/processors/deepseek_vl_v2.py | 6 +-
.../srt/multimodal/processors/gemma3.py | 25 +++--- .../srt/multimodal/processors/gemma3n.py | 40 +++++----- .../srt/multimodal/processors/internvl.py | 14 ++-- .../srt/multimodal/processors/janus_pro.py | 12 +-- .../srt/multimodal/processors/kimi_vl.py | 24 +++--- .../srt/multimodal/processors/minicpm.py | 14 ++-- .../srt/multimodal/processors/mllama4.py | 4 +- .../srt/multimodal/processors/phi4mm.py | 2 +- .../srt/multimodal/processors/pixtral.py | 2 +- .../srt/multimodal/processors/qwen_vl.py | 35 ++++---- .../sglang/srt/multimodal/processors/vila.py | 22 +++--- 13 files changed, 146 insertions(+), 133 deletions(-) diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 44e22885caec..5c44c4d49953 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -21,7 +21,7 @@ class BaseMultiModalProcessorOutput: # input_text, with each frame of video/image represented with a image_token input_text: str - # frames loaded from image and video, in given order + # frames loaded from image, in given order images: Optional[list[Union[Image.Image, dict]]] = None # videos @@ -44,14 +44,26 @@ def organize_results(self) -> List[Tuple[Modality, Any]]: @dataclasses.dataclass class MultimodalSpecialTokens: - image_token: Optional[Union[int, str, List[str]]] = None - video_token: Optional[Union[int, str, List[str]]] = None - audio_token: Optional[Union[int, str, List[str]]] = None + image_token: Optional[Union[str, List[str]]] = None + video_token: Optional[Union[str, List[str]]] = None + audio_token: Optional[Union[str, List[str]]] = None + + image_token_id: Optional[int] = None + video_token_id: Optional[int] = None + audio_token_id: Optional[int] = None image_token_regex: Optional[re.Pattern] = None video_token_regex: Optional[re.Pattern] = None audio_token_regex: Optional[re.Pattern] = None + combined_regex: 
Optional[re.Pattern] = None + + def build(self, processor): + self.convert_to_strs(processor) + self.parse_regex() + self.get_combined_regex() + return self + def convert_to_str(self, token: Union[str, int], processor) -> str: if token is None: return token @@ -60,11 +72,14 @@ def convert_to_str(self, token: Union[str, int], processor) -> str: return processor.tokenizer.convert_ids_to_tokens([token])[0] def convert_to_strs(self, processor): - self.image_token = self.convert_to_str(self.image_token, processor) - self.video_token = self.convert_to_str(self.video_token, processor) - self.audio_token = self.convert_to_str(self.audio_token, processor) - - def get_modality_of_token(self, token) -> Optional[Modality]: + if not self.image_token: + self.image_token = self.convert_to_str(self.image_token_id, processor) + if not self.video_token: + self.video_token = self.convert_to_str(self.video_token_id, processor) + if not self.audio_token: + self.audio_token = self.convert_to_str(self.audio_token_id, processor) + + def get_modality_of_token(self, token: str) -> Optional[Modality]: """ :return: the modality associated with the given token, if the token is a special_token or matches with the multimodal token regex """ @@ -94,7 +109,12 @@ def parse_regex(self): if self.audio_token_regex is None and self.audio_token is not None: self.audio_token_regex = re.compile(re.escape(self.audio_token)) - def combine_regex(self) -> re.Pattern: + def get_combined_regex(self) -> re.Pattern: + """ + Builds and returns a regex, used to split input str into tokens (with mm special tokens) + """ + if self.combined_regex: + return self.combined_regex tokens = [ self.image_token_regex, self.video_token_regex, @@ -107,7 +127,8 @@ def combine_regex(self) -> re.Pattern: patterns.append(t.pattern) flags |= t.flags combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")" - return re.compile(combined, flags) + self.combined_regex = re.compile(combined, flags) + return self.combined_regex class 
BaseMultimodalProcessor(ABC): @@ -341,9 +362,8 @@ def load_mm_data( discard_alpha_channel: if True, discards the alpha channel in the returned images """ - multimodal_tokens.convert_to_strs(self._processor) - multimodal_tokens.parse_regex() - multimodal_tokens_pattern = multimodal_tokens.combine_regex() + multimodal_tokens_pattern = multimodal_tokens.get_combined_regex() + if isinstance(prompt, list) and return_text: assert len(prompt) and isinstance(prompt[0], int) prompt = self._processor.tokenizer.decode(prompt) @@ -445,7 +465,6 @@ def get_mm_items_offset( return result = [(2,4),(6,7)] """ mask = input_ids == mm_token_id - start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0] end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0] @@ -554,7 +573,9 @@ def _process_and_collect_mm_items( return collected_items, input_ids, ret def process_and_combine_mm_data( - self, base_output: BaseMultiModalProcessorOutput + self, + base_output: BaseMultiModalProcessorOutput, + mm_tokens: MultimodalSpecialTokens, ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]: """ Process multimodal data and return the combined multimodal items and input_ids. 
@@ -618,22 +639,14 @@ def process_and_combine_mm_data( # Add offsets to all items for mm_item in all_collected_items: - if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]: - mm_item.offsets = self.get_mm_items_offset( - input_ids=input_ids, - mm_token_id=self.IM_TOKEN_ID, - ) - elif mm_item.modality == Modality.AUDIO: - mm_item.offsets = self.get_mm_items_offset( - input_ids=input_ids, - mm_token_id=self.AUDIO_TOKEN_ID, - ) - elif mm_item.modality == Modality.VIDEO: - mm_item.offsets = self.get_mm_items_offset( - input_ids=input_ids, - mm_token_id=self.VIDEO_TOKEN_ID, - ) - else: - raise ValueError(f"Unknown modality: {mm_item.modality}") + mm_item.offsets = self.get_mm_items_offset( + input_ids=input_ids, + mm_token_id={ + Modality.IMAGE: mm_tokens.image_token_id, + Modality.MULTI_IMAGES: mm_tokens.image_token_id, + Modality.VIDEO: mm_tokens.video_token_id, + Modality.AUDIO: mm_tokens.audio_token_id, + }.get(mm_item.modality, None), + ) return all_collected_items, input_ids, ret diff --git a/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py b/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py index 50547ad2d714..c21dce176905 100644 --- a/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py +++ b/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py @@ -33,7 +33,9 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - self.IMAGE_TOKEN = "" + self.mm_tokens = MultimodalSpecialTokens(image_token="").build( + _processor + ) async def process_mm_data_async( self, @@ -47,7 +49,7 @@ async def process_mm_data_async( base_output = self.load_mm_data( input_text, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN), + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, ) res = self.process_mm_data( diff --git a/python/sglang/srt/multimodal/processors/gemma3.py 
b/python/sglang/srt/multimodal/processors/gemma3.py index e0858674a7b2..dac9bd5c8241 100644 --- a/python/sglang/srt/multimodal/processors/gemma3.py +++ b/python/sglang/srt/multimodal/processors/gemma3.py @@ -4,7 +4,6 @@ from sglang.srt.managers.multimodal_processor import ( BaseMultimodalProcessor as SGLangBaseProcessor, ) -from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens @@ -17,15 +16,17 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - # The single, pre-expanded image token. - self.IMAGE_TOKEN = "" - # The regex that matches expanded image tokens. - self.IMAGE_TOKEN_REGEX = re.compile( - r"(?:(?:)*)?" - ) self.IM_START_TOKEN_ID = hf_config.boi_token_index self.IM_END_TOKEN_ID = hf_config.eoi_token_index - self.IM_TOKEN_ID = hf_config.image_token_index + self.mm_tokens = MultimodalSpecialTokens( + # The single, pre-expanded image token. + image_token="", + image_token_id=hf_config.image_token_index, + # The regex that matches expanded image tokens. + image_token_regex=re.compile( + r"(?:(?:)*)?" 
+ ), + ).build(_processor) async def process_mm_data_async( self, @@ -39,14 +40,14 @@ async def process_mm_data_async( base_output = self.load_mm_data( prompt=input_text, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens( - image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX - ), + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, discard_alpha_channel=True, ) - mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output) + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) return { "input_ids": input_ids.tolist(), "mm_items": mm_items, diff --git a/python/sglang/srt/multimodal/processors/gemma3n.py b/python/sglang/srt/multimodal/processors/gemma3n.py index 92f3c0b939d5..aafeab7c9383 100644 --- a/python/sglang/srt/multimodal/processors/gemma3n.py +++ b/python/sglang/srt/multimodal/processors/gemma3n.py @@ -30,23 +30,23 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - self.IMAGE_TOKEN = "" - self.IMAGE_TOKEN_REGEX = re.compile( - r"(?:(?:)*)?" - ) - - self.AUDIO_TOKEN = "" - self.AUDIO_TOKEN_REGEX = re.compile( - r"(?:(?:)*)?" - ) - - self.IM_TOKEN_ID = hf_config.image_token_id self.IM_START_TOKEN_ID = hf_config.boi_token_id self.IM_END_TOKEN_ID = hf_config.eoi_token_id - self.AUDIO_TOKEN_ID = hf_config.audio_token_id self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id + self.mm_tokens = MultimodalSpecialTokens( + image_token="", + image_token_id=hf_config.image_token_id, + image_token_regex=re.compile( + r"(?:(?:)*)?" + ), + audio_token="", + audio_token_id=hf_config.audio_token_id, + audio_token_regex=re.compile( + r"(?:(?:)*)?" 
+ ), + ).build(_processor) async def process_mm_data_async( self, @@ -64,19 +64,17 @@ async def process_mm_data_async( image_data=image_data, audio_data=audio_data, max_req_input_len=max_req_input_len, - multimodal_tokens=MultimodalSpecialTokens( - image_token=self.IMAGE_TOKEN, - image_token_regex=self.IMAGE_TOKEN_REGEX, - audio_token=self.AUDIO_TOKEN, - audio_token_regex=self.AUDIO_TOKEN_REGEX, - ), + multimodal_tokens=self.mm_tokens, ) - mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output) + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) return { "input_ids": input_ids.tolist(), "mm_items": mm_items, - "im_token_id": self.IM_TOKEN_ID, - "audio_token_id": self.AUDIO_TOKEN_ID, + # TODO(mick): could we return MultimodalSpecialTokens directly? + "im_token_id": self.mm_tokens.image_token_id, + "audio_token_id": self.mm_tokens.audio_token_id, } diff --git a/python/sglang/srt/multimodal/processors/internvl.py b/python/sglang/srt/multimodal/processors/internvl.py index f9ed9ba76d86..d3413c457dde 100644 --- a/python/sglang/srt/multimodal/processors/internvl.py +++ b/python/sglang/srt/multimodal/processors/internvl.py @@ -24,7 +24,6 @@ def __init__(self, hf_config, server_args, _image_processor): self.IMG_CONTEXT_TOKEN = "" self.IMG_START_TOKEN = "" self.IMG_END_TOKEN = "" - self.IMG_TOKEN = "" self.num_image_token = int( (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2) ) @@ -32,9 +31,10 @@ def __init__(self, hf_config, server_args, _image_processor): tokenizer = self._processor self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN) self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN) - self.img_context_token_id = tokenizer.convert_tokens_to_ids( - self.IMG_CONTEXT_TOKEN - ) + self.mm_tokens = MultimodalSpecialTokens( + image_token="", + image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN), + ).build(_image_processor) 
@staticmethod def build_transform(input_size): @@ -175,7 +175,7 @@ async def process_mm_data_async( base_output = self.load_mm_data( prompt=input_text, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMG_TOKEN), + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, discard_alpha_channel=True, ) @@ -219,7 +219,7 @@ def process_image_internvl(image, input_size=448, max_num=12): input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten() image_offsets = self.get_mm_items_offset( input_ids=input_ids, - mm_token_id=self.img_context_token_id, + mm_token_id=self.mm_tokens.image_token_id, ) items = [ MultimodalDataItem( @@ -234,5 +234,5 @@ def process_image_internvl(image, input_size=448, max_num=12): "mm_items": items, "im_start_id": self.img_start_token_id, "im_end_id": self.img_end_token_id, - "im_token_id": self.img_context_token_id, + "im_token_id": self.mm_tokens.image_token_id, } diff --git a/python/sglang/srt/multimodal/processors/janus_pro.py b/python/sglang/srt/multimodal/processors/janus_pro.py index 8ea013d29aae..28be34c57b01 100644 --- a/python/sglang/srt/multimodal/processors/janus_pro.py +++ b/python/sglang/srt/multimodal/processors/janus_pro.py @@ -11,8 +11,12 @@ class JanusProImageProcessor(BaseMultimodalProcessor): models = [MultiModalityCausalLM] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, processor): + super().__init__(hf_config, server_args, processor) + + self.mm_tokens = MultimodalSpecialTokens( + image_token=processor.image_token + ).build(processor) async def process_mm_data_async( self, @@ -27,9 +31,7 @@ async def process_mm_data_async( base_out = self.load_mm_data( prompt=input_text, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens( - image_token=processor.image_token - ), + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, 
) diff --git a/python/sglang/srt/multimodal/processors/kimi_vl.py b/python/sglang/srt/multimodal/processors/kimi_vl.py index b593da48f27a..ef533c16d579 100644 --- a/python/sglang/srt/multimodal/processors/kimi_vl.py +++ b/python/sglang/srt/multimodal/processors/kimi_vl.py @@ -1,9 +1,6 @@ import re -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Union -import torch - -from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import ( BaseMultimodalProcessor as SGLangBaseProcessor, @@ -17,9 +14,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - self.IMAGE_TOKEN = "<|media_pad|>" - self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+") - self.IM_TOKEN_ID = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN) + self.mm_tokens = MultimodalSpecialTokens( + image_token="<|media_pad|>", + # TODO: could we convert in MultimodalSpecialTokens? 
+ image_token_id=hf_config.media_placeholder_token_id, + image_token_regex=re.compile(r"(?:<\|media_pad\|>)+"), + ).build(_processor) async def process_mm_data_async( self, @@ -33,16 +33,16 @@ async def process_mm_data_async( base_output = self.load_mm_data( prompt=input_text, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens( - image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX - ), + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, ) - mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output) + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) return { "input_ids": input_ids.tolist(), "mm_items": mm_items, - "im_token_id": self.IM_TOKEN_ID, + "im_token_id": self.mm_tokens.image_token_id, } diff --git a/python/sglang/srt/multimodal/processors/minicpm.py b/python/sglang/srt/multimodal/processors/minicpm.py index 369971ccbe53..3ba547b380e0 100644 --- a/python/sglang/srt/multimodal/processors/minicpm.py +++ b/python/sglang/srt/multimodal/processors/minicpm.py @@ -17,9 +17,11 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - self.image_token = "(./)" - self.audio_token = "()" - self.video_token = "()" + self.mm_tokens = MultimodalSpecialTokens( + image_token="(./)", + audio_token="()", + video_token="()", + ).build(_processor) async def process_mm_data_async( self, @@ -35,11 +37,7 @@ async def process_mm_data_async( max_req_input_len=max_req_input_len, audio_data=audio_data, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens( - image_token=self.image_token, - video_token=self.video_token, - audio_token=self.audio_token, - ), + multimodal_tokens=self.mm_tokens, ) if base_output is None: return None diff --git a/python/sglang/srt/multimodal/processors/mllama4.py b/python/sglang/srt/multimodal/processors/mllama4.py index 
ccf70adc8766..566eb3230c17 100644 --- a/python/sglang/srt/multimodal/processors/mllama4.py +++ b/python/sglang/srt/multimodal/processors/mllama4.py @@ -26,8 +26,8 @@ def __init__(self, hf_config, server_args, _processor): self.eoi_token_index = hf_config.eoi_token_index self.image_token_index = hf_config.image_token_index self.multimodal_tokens = MultimodalSpecialTokens( - image_token=_processor.image_token - ) + image_token=_processor.image_token, + ).build(_processor) async def process_mm_data_async( self, diff --git a/python/sglang/srt/multimodal/processors/phi4mm.py b/python/sglang/srt/multimodal/processors/phi4mm.py index d2e009d27f3e..aea06506d078 100644 --- a/python/sglang/srt/multimodal/processors/phi4mm.py +++ b/python/sglang/srt/multimodal/processors/phi4mm.py @@ -21,7 +21,7 @@ def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) self.multimodal_tokens = MultimodalSpecialTokens( image_token=_IMAGE_SPECIAL_TOKEN, - ) + ).build(_processor) async def process_mm_data_async( self, diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index 8b741d6279c0..b18dfa1b023e 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -55,7 +55,7 @@ def __init__(self, hf_config, server_args, _processor): self.patch_size = self.vision_config.patch_size self.multimodal_tokens = MultimodalSpecialTokens( image_token=_processor.image_token - ) + ).build(_processor) _processor.tokenizer.add_special_tokens( { "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN), diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index 1ecb4e119ac3..bdfaf140624f 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -203,16 +203,9 @@ class 
Qwen2_5VLImageProcessor(SGLangBaseProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - # The single, pre-expanded image token. - self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>" # The regex that matches expanded image tokens. - self.IMAGE_TOKEN_REGEX = re.compile( - r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>" - ) self.IM_START_TOKEN_ID = hf_config.vision_start_token_id self.IM_END_TOKEN_ID = hf_config.vision_end_token_id - self.IM_TOKEN_ID = hf_config.image_token_id - self.VIDEO_TOKEN_ID = hf_config.video_token_id self.vision_start_token_id = hf_config.vision_start_token_id self.vision_end_token_id = hf_config.vision_end_token_id self.NUM_TOKEN_PER_FRAME = 770 @@ -220,12 +213,14 @@ def __init__(self, hf_config, server_args, _processor): self.MIN_PIXELS = 4 * 28 * 28 self.MAX_PIXELS = 16384 * 28 * 28 self.MAX_RATIO = 200 - # TODO(mick): move all MultimodalSpecialTokens initializations into processor init - self.mm_special_tokens = MultimodalSpecialTokens( - image_token=self.IMAGE_TOKEN, - image_token_regex=self.IMAGE_TOKEN_REGEX, - video_token=self.VIDEO_TOKEN_ID, - ) + self.mm_tokens = MultimodalSpecialTokens( + image_token="<|vision_start|><|image_pad|><|vision_end|>", + image_token_id=hf_config.image_token_id, + image_token_regex=re.compile( + r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>" + ), + video_token_id=hf_config.video_token_id, + ).build(_processor) async def process_mm_data_async( self, @@ -241,7 +236,7 @@ async def process_mm_data_async( prompt=input_text, image_data=image_data, video_data=request_obj.video_data, - multimodal_tokens=self.mm_special_tokens, + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, ) @@ -255,13 +250,15 @@ async def process_mm_data_async( await preprocess_video(video) for video in base_output.videos ] - mm_items, input_ids, ret = self.process_and_combine_mm_data(base_output) + mm_items, input_ids, 
ret = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) input_ids = input_ids.flatten() mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( spatial_merge_size=self.hf_config.vision_config.spatial_merge_size, - image_token_id=self.IM_TOKEN_ID, - video_token_id=self.VIDEO_TOKEN_ID, + image_token_id=self.mm_tokens.image_token_id, + video_token_id=self.mm_tokens.video_token_id, vision_start_token_id=self.vision_start_token_id, model_type=self.hf_config.model_type, tokens_per_second=getattr( @@ -279,8 +276,8 @@ async def process_mm_data_async( "mm_items": mm_items, "im_start_id": self.IM_START_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID, - "im_token_id": self.IM_TOKEN_ID, - "video_token_id": self.VIDEO_TOKEN_ID, + "im_token_id": self.mm_tokens.image_token_id, + "video_token_id": self.mm_tokens.video_token_id, "mrope_positions": mrope_positions, "mrope_position_delta": mrope_position_delta, } diff --git a/python/sglang/srt/multimodal/processors/vila.py b/python/sglang/srt/multimodal/processors/vila.py index c4d676c6d09f..8e0f04acae89 100644 --- a/python/sglang/srt/multimodal/processors/vila.py +++ b/python/sglang/srt/multimodal/processors/vila.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, Dict, List, Optional, Type import torch.nn as nn from transformers.configuration_utils import PretrainedConfig @@ -10,7 +10,6 @@ GenerateReqInput, ImageDataInputItem, ) -from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.vila import VILAForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import ( BaseMultimodalProcessor, @@ -37,8 +36,11 @@ def __init__( _processor: VILAProcessor, ) -> None: super().__init__(hf_config, server_args, _processor) - self.IM_TOKEN_ID = hf_config.image_token_id - self.VIDEO_TOKEN_ID = hf_config.video_token_id + self.mm_tokens = MultimodalSpecialTokens( + 
image_token=self._processor.tokenizer.image_token, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + ).build(_processor) async def process_mm_data_async( self, @@ -50,18 +52,18 @@ async def process_mm_data_async( ) -> Optional[Dict[str, Any]]: base_output = self.load_mm_data( prompt=input_text, - multimodal_tokens=MultimodalSpecialTokens( - image_token=self._processor.tokenizer.image_token - ), + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, image_data=image_data, ) - mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output) + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) return { "input_ids": input_ids.tolist(), "mm_items": mm_items, - "im_token_id": self.IM_TOKEN_ID, - "video_token_id": self.VIDEO_TOKEN_ID, + "im_token_id": self.mm_tokens.image_token_id, + "video_token_id": self.mm_tokens.video_token_id, } From 6e92da8fca18c746a0aa15c7bd95b47b6827befa Mon Sep 17 00:00:00 2001 From: Qi Yuhang <45795032+HydraQYH@users.noreply.github.com> Date: Fri, 18 Jul 2025 11:49:36 +0800 Subject: [PATCH 025/396] [Fix][Ready]Fix register spilling in cutlass nvfp4 gemm kernel on Blackwell (#8127) --- .../csrc/gemm/nvfp4_scaled_mm_kernels.cu | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu index 4fc4972dc0e1..d1193ea4473a 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu @@ -40,27 +40,21 @@ using namespace cute; #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) // Kernel Perf config template -struct KernelTraits; - -template <> -struct KernelTraits { - using MmaTileShape = Shape<_128, _128, _256>; - using ClusterShape = Shape<_1, _1, _1>; - using PerSmTileShape_MNK = Shape<_128, _128, _256>; -}; - -template <> -struct KernelTraits { +struct KernelTraits { using 
MmaTileShape = Shape<_256, _256, _256>; - using ClusterShape = Shape<_4, _4, _1>; - using PerSmTileShape_MNK = Shape<_128, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; }; template <> -struct KernelTraits { - using MmaTileShape = Shape<_256, _256, _256>; - using ClusterShape = Shape<_4, _4, _1>; - using PerSmTileShape_MNK = Shape<_128, _256, _256>; +struct KernelTraits { + using MmaTileShape = Shape<_128, _128, _256>; + using ClusterShape = Shape; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; }; template @@ -90,23 +84,26 @@ struct Fp4GemmSm100 { // Kernel Perf config using MmaTileShape = typename KernelTraits::MmaTileShape; using ClusterShape = typename KernelTraits::ClusterShape; - using PerSmTileShape_MNK = typename KernelTraits::PerSmTileShape_MNK; + using EpilogueTile = typename KernelTraits::EpilogueTile; + using EpilogueSchedule = typename KernelTraits::EpilogueSchedule; + using MainloopSchedule = typename KernelTraits::MainloopSchedule; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, - OperatorClass, - PerSmTileShape_MNK, + cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, + EpilogueTile, ElementAccumulator, ElementAccumulator, - ElementC, + void, LayoutCTag, AlignmentC, ElementD, LayoutDTag, AlignmentD, - cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, @@ -122,7 +119,7 @@ struct Fp4GemmSm100 { 
ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + MainloopSchedule>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>; @@ -191,6 +188,13 @@ typename T::Gemm::Arguments args_from_options( stride_D}}; auto& fusion_args = arguments.epilogue.thread; fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + if constexpr (std::is_same_v) { + arguments.hw_info.cluster_shape = dim3(1, 4, 1); + arguments.hw_info.cluster_shape_fallback = dim3(1, 1, 1); + } else { + arguments.hw_info.cluster_shape = dim3(4, 4, 1); + arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1); + } return arguments; } From 8a3235570403f203021b4d1730dcce04f652ff96 Mon Sep 17 00:00:00 2001 From: Minglei Zhu Date: Thu, 17 Jul 2025 20:56:03 -0700 Subject: [PATCH 026/396] Feat: Support Granite 3.0 MoE in SGLang (#7959) --- docs/supported_models/generative_models.md | 2 + python/sglang/srt/models/granitemoe.py | 379 +++++++++++++++++++++ 2 files changed, 381 insertions(+) create mode 100644 python/sglang/srt/models/granitemoe.py diff --git a/docs/supported_models/generative_models.md b/docs/supported_models/generative_models.md index f1a941cdc3bf..0096d6e0932d 100644 --- a/docs/supported_models/generative_models.md +++ b/docs/supported_models/generative_models.md @@ -45,3 +45,5 @@ in the GitHub search bar. | **SmolLM** (135M–1.7B) | `HuggingFaceTB/SmolLM-1.7B` | Hugging Face’s ultra-small LLM series (135M–1.7B params) offering surprisingly strong results, enabling advanced AI on mobile/edge devices. | | **GLM-4** (Multilingual 9B) | `ZhipuAI/glm-4-9b-chat` | Zhipu’s GLM-4 series (up to 9B parameters) – open multilingual models with support for 1M-token context and even a 5.6B multimodal variant (Phi-4V). 
| | **MiMo** (7B series) | `XiaomiMiMo/MiMo-7B-RL` | Xiaomi's reasoning-optimized model series, leverages Multiple-Token Prediction for faster inference. | +| **Granite 3.0, 3.1** (IBM) | `ibm-granite/granite-3.1-8b-instruct` | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. | +| **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. | diff --git a/python/sglang/srt/models/granitemoe.py b/python/sglang/srt/models/granitemoe.py new file mode 100644 index 000000000000..b4a9c17af56f --- /dev/null +++ b/python/sglang/srt/models/granitemoe.py @@ -0,0 +1,379 @@ +"""Inference-only GraniteMoe model.""" + +from typing import Iterable, Optional + +import torch +from torch import nn +from transformers import GraniteConfig + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models import mixtral +from sglang.srt.utils import add_prefix + + +class GraniteMoeMoE(nn.Module): + """A tensor-parallel MoE implementation for GraniteMoe that shards each + expert across all ranks. 
+ Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
+ orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class GraniteMoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + layer_id: int = 0, + rope_theta: float = 10000, + quant_config: Optional[QuantizationConfig] = None, + attention_multiplier: Optional[float] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = ( + attention_multiplier + if attention_multiplier is not None + else self.head_dim**-1 + ) + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteMoeDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = GraniteMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + 
num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + layer_id=layer_id, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + attention_multiplier=config.attention_multiplier, + ) + self.block_sparse_moe = GraniteMoeMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.residual_multiplier = config.residual_multiplier + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states + + +class GraniteMoeModel(nn.Module): + + def __init__( + self, + config: GraniteConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.embedding_multiplier = config.embedding_multiplier + + self.layers = nn.ModuleList( + [ + GraniteMoeDecoderLayer( + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), + ) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.embedding_multiplier + + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + forward_batch, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GraniteMoeForCausalLM(nn.Module): + + def __init__( + self, + config: GraniteConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.quant_config = quant_config + + self.model = GraniteMoeModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + # Granite logit scaling factors are applied via division, but + # LogitsProcessor expects a multiplicative factor. 
+ if hasattr(config, "logits_scaling"): + logit_scale = 1.0 / config.logits_scaling + else: + logit_scale = None + self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + ) -> LogitsProcessorOutput: + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + if not get_embedding: + logits_processor_output: LogitsProcessorOutput = self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + return logits_processor_output + else: + return self.pooler(hidden_states, forward_batch) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + new_weights = {} + for n, p in weights: + if n.endswith(".block_sparse_moe.input_linear.weight"): + for e in range(p.size(0)): + w1_name = n.replace( + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w1.weight", + ) + w3_name = n.replace( + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w3.weight", + ) + w1_param, w3_param = p[e].chunk(2, dim=0) + assert w1_name not in new_weights + assert w3_name not in new_weights + new_weights[w1_name] = w1_param + new_weights[w3_name] = w3_param + elif n.endswith(".block_sparse_moe.output_linear.weight"): + for e in range(p.size(0)): + w2_name = n.replace( + ".block_sparse_moe.output_linear.weight", + f".block_sparse_moe.experts.{e}.w2.weight", + ) + w2_param = p[e] + assert w2_name not in new_weights + new_weights[w2_name] = w2_param + elif n.endswith(".block_sparse_moe.router.layer.weight"): + gate_name = n.replace( + ".block_sparse_moe.router.layer.weight", + ".block_sparse_moe.gate.weight", + ) + assert gate_name not in new_weights + new_weights[gate_name] = p + else: + new_weights[n] 
= p + mixtral.MixtralForCausalLM.load_weights(self, new_weights.items()) + + +EntryClass = [GraniteMoeForCausalLM] From 8aa5ae6b042f09d9beb2b0e814ea9c2311b6c2b6 Mon Sep 17 00:00:00 2001 From: yilian49 <43861414+yilian49@users.noreply.github.com> Date: Fri, 18 Jul 2025 00:41:32 -0400 Subject: [PATCH 027/396] load draft model fix (#7506) --- python/sglang/srt/model_loader/loader.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 733e6df9e4de..2e2f71078382 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -575,7 +575,13 @@ def load_model( # 2. Post-processing of weights, including assigning specific member variables. # For `dummy_init`, only the second stage is required. if hasattr(model, "post_load_weights"): - model.post_load_weights() + if ( + model_config.hf_config.architectures[0] + == "DeepseekV3ForCausalLMNextN" + ): + model.post_load_weights(is_nextn=True) + else: + model.post_load_weights() return model.eval() From 48c1fa7bb6950b81788a84da32c3c42bc7c77e67 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Fri, 18 Jul 2025 12:43:25 +0800 Subject: [PATCH 028/396] [CPU][Llama4] Fix Llama4 MoE inputs with "apply_router_weight_on_input" (#7889) --- python/sglang/srt/configs/update_config.py | 4 +++- python/sglang/srt/layers/moe/topk.py | 13 +++++++++++++ python/sglang/srt/layers/quantization/fp8.py | 6 ++++++ python/sglang/srt/layers/quantization/unquant.py | 11 ++++++++--- python/sglang/srt/layers/quantization/w8a8_int8.py | 5 +++++ 5 files changed, 35 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/configs/update_config.py b/python/sglang/srt/configs/update_config.py index f9e6d15a85f1..241d9566ab5e 100644 --- a/python/sglang/srt/configs/update_config.py +++ b/python/sglang/srt/configs/update_config.py @@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp( model_config = 
update_intermediate_size( model_config, "intermediate_size", intermediate_padding_size ) - + model_config = update_intermediate_size( + model_config, "intermediate_size_mlp", intermediate_padding_size + ) return model_config diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 1c8d219e4ec0..40fc0b61f650 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -93,6 +93,19 @@ def fused_topk_cpu( return topk_weights, topk_ids +def apply_topk_weights_cpu(need_apply, topk_weights, inputs): + if not need_apply: + return inputs, topk_weights + + # TODO: fuse below processing in fused_experts_cpu kernel + inputs = inputs * topk_weights.to(inputs.dtype) + topk_weights = torch.ones_like( + topk_weights, dtype=torch.float32 + ) # clear topk_weights as already applied + + return inputs, topk_weights + + def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 38588c809039..7275ea430132 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1005,6 +1005,12 @@ def apply( ) if use_intel_amx_backend(layer): + from sglang.srt.layers.moe.topk import apply_topk_weights_cpu + + x, topk_weights = apply_topk_weights_cpu( + apply_router_weight_on_input, topk_weights, x + ) + return torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 28d006255d8e..821b1cb8509b 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -344,9 +344,12 @@ def forward_cpu( ) -> torch.Tensor: assert activation == "silu", f"activation = {activation} is not supported." 
- if use_intel_amx_backend(layer) and not apply_router_weight_on_input: + if use_intel_amx_backend(layer): - from sglang.srt.layers.moe.topk import select_experts + from sglang.srt.layers.moe.topk import ( + select_experts, + apply_topk_weights_cpu, + ) topk_weights, topk_ids = select_experts( hidden_states=x, @@ -361,8 +364,10 @@ def forward_cpu( correction_bias=correction_bias, routed_scaling_factor=routed_scaling_factor, ) + x, topk_weights = apply_topk_weights_cpu( + apply_router_weight_on_input, topk_weights, x + ) - # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel return torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index c8a024bf33ed..56ac26c57823 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -497,6 +497,11 @@ def apply( ) if use_intel_amx_backend(layer): + from sglang.srt.layers.moe.topk import apply_topk_weights_cpu + + x, topk_weights = apply_topk_weights_cpu( + apply_router_weight_on_input, topk_weights, x + ) return torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, From 7891bac16b0a905aacfbbe49709d740916555ae0 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Fri, 18 Jul 2025 13:03:56 +0800 Subject: [PATCH 029/396] [Quantization][w8a8_int8] Fix weight loading issue for w8a8_int8 path with "ignore" layer list in quantization config (#7820) --- .../sglang/srt/layers/quantization/unquant.py | 2 +- .../srt/layers/quantization/w8a8_int8.py | 36 +++++++++++-------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 821b1cb8509b..06afcb70be91 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -347,8 +347,8 @@ def forward_cpu( if 
use_intel_amx_backend(layer): from sglang.srt.layers.moe.topk import ( - select_experts, apply_topk_weights_cpu, + select_experts, ) topk_weights, topk_ids = select_experts( diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 56ac26c57823..c9af7ae29cc7 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -3,7 +3,7 @@ import importlib import sys from types import MappingProxyType -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast import torch from torch.nn.parameter import Parameter @@ -24,6 +24,7 @@ QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.utils import ( @@ -178,17 +179,18 @@ class W8A8Int8Config(QuantizationConfig): - Activation: dynamic, per-token, symmetric """ - def __init__(self, quant_config: Dict[str, Any]): + def __init__(self, quant_config: Dict[str, Any] = {}): super().__init__() self.quant_description = quant_config self.is_dynamic = quant_config.get("is_dynamic", False) - if _is_npu: - if ( - "packed_modules_mapping" in quant_config - and quant_config["packed_modules_mapping"] is not None - ): - self.packed_modules_mapping = quant_config["packed_modules_mapping"] + ignore = cast(List[str], quant_config.get("ignore", [])) + self.ignore = ignore if ignore is not None else [] + packed_modules_mapping = quant_config.get("packed_modules_mapping", {}) + self.packed_modules_mapping = ( + packed_modules_mapping if packed_modules_mapping is not None else {} + ) + if _is_npu: # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between 
models for name in self.quant_description.keys(): if "norm.bias" in name: @@ -237,7 +239,7 @@ def get_quant_method( layer: torch.nn.Module, prefix: str, ) -> Optional[QuantizeMethodBase]: - from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if _is_npu: @@ -262,12 +264,16 @@ def get_quant_method( elif isinstance(layer, FusedMoE): return NPU_W8A8MoEMethod(self) return None - else: - if isinstance(layer, LinearBase): - return W8A8Int8LinearMethod(self) - elif isinstance(layer, FusedMoE): - return W8A8Int8MoEMethod(self) - return None + + if should_ignore_layer( + prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping + ): + return UnquantizedLinearMethod() + if isinstance(layer, LinearBase): + return W8A8Int8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return W8A8Int8MoEMethod(self) + return None def is_layer_skipped( self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) From 9d33fcfb8e93c4a01fb39c6609c71f7104cb3371 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Fri, 18 Jul 2025 00:20:19 -0700 Subject: [PATCH 030/396] Hicache Storage Layer Prototype (#7704) --- .../sglang/srt/managers/cache_controller.py | 241 ++++++++++++++++++ python/sglang/srt/managers/scheduler.py | 14 + .../sglang/srt/mem_cache/hicache_storage.py | 152 +++++++++++ python/sglang/srt/mem_cache/hiradix_cache.py | 183 ++++++++++++- .../sglang/srt/mem_cache/memory_pool_host.py | 38 +++ python/sglang/srt/mem_cache/radix_cache.py | 26 ++ python/sglang/srt/server_args.py | 8 + test/srt/run_suite.py | 1 + test/srt/test_hicache_storage.py | 55 ++++ 9 files changed, 714 insertions(+), 4 deletions(-) create mode 100644 python/sglang/srt/mem_cache/hicache_storage.py create mode 100644 test/srt/test_hicache_storage.py diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py 
index cad1d74b71de..5f43a5e9a033 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -25,6 +25,8 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool_host import HostKVCache +from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str + logger = logging.getLogger(__name__) @@ -159,6 +161,57 @@ def clear(self): self.buffers.queue.clear() +class StorageOperation: + counter = 0 + + def __init__( + self, + host_indices: torch.Tensor, + token_ids: List[int], + last_hash: Optional[str] = None, + ): + self.host_indices = host_indices + self.token_ids = token_ids + self.last_hash = last_hash + self.completed_tokens = 0 + self.hash_value = [] + + self.id = StorageOperation.counter + StorageOperation.counter += 1 + + def __lt__(self, other: "StorageOperation"): + return self.id < other.id + + +class PrefetchOperation(StorageOperation): + def __init__( + self, + request_id: str, + host_indices: torch.Tensor, + token_ids: List[int], + last_hash: Optional[str] = None, + ): + self.request_id = request_id + + self._done_flag = False + self._lock = threading.Lock() + + super().__init__(host_indices, token_ids, last_hash) + + def increment(self, num_tokens: int): + with self._lock: + if self._done_flag: + return + self.completed_tokens += num_tokens + + def mark_done(self): + with self._lock: + self._done_flag = True + + def is_done(self) -> bool: + return self._done_flag + + class HiCacheController: def __init__( @@ -169,6 +222,8 @@ def __init__( load_cache_event: threading.Event = None, write_policy: str = "write_through_selective", io_backend: str = "", + storage_backend: Optional[str] = None, + prefetch_threshold: int = 256, ): self.mem_pool_device_allocator = token_to_kv_pool_allocator self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache() @@ -186,6 +241,19 @@ def __init__( else: self.io_backend = io_backend + self.enable_storage = 
False + # todo: move backend initialization to storage backend module + if storage_backend is not None: + if storage_backend == "file": + self.storage_backend = HiCacheFile() + self.enable_storage = True + # todo: threshold policy for prefetching + self.prefetch_threshold = prefetch_threshold + else: + raise NotImplementedError( + f"Unsupported storage backend: {storage_backend}" + ) + self.load_cache_event = load_cache_event self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num) self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter) @@ -218,9 +286,26 @@ def __init__( self.load_thread = threading.Thread( target=self.load_thread_func_layer_by_layer, daemon=True ) + self.write_thread.start() self.load_thread.start() + if self.enable_storage: + self.prefetch_thread = threading.Thread( + target=self.prefetch_thread_func, daemon=True + ) + self.backup_thread = threading.Thread( + target=self.backup_thread_func, daemon=True + ) + self.prefetch_queue = Queue() + self.backup_queue = Queue() + + self.prefetch_revoke_queue = Queue() + self.ack_backup_queue = Queue() + + self.prefetch_thread.start() + self.backup_thread.start() + def reset(self): self.stop_event.set() self.write_thread.join() @@ -232,6 +317,13 @@ def reset(self): self.load_buffer.clear() self.ack_write_queue.queue.clear() self.ack_load_queue.queue.clear() + if self.enable_storage: + self.prefetch_thread.join() + self.backup_thread.join() + self.prefetch_queue.queue.clear() + self.backup_queue.queue.clear() + self.prefetch_revoke_queue.queue.clear() + self.ack_backup_queue.queue.clear() self.write_thread = threading.Thread( target=self.write_thread_func_direct, daemon=True @@ -243,6 +335,16 @@ def reset(self): self.write_thread.start() self.load_thread.start() + if self.enable_storage: + self.prefetch_thread = threading.Thread( + target=self.prefetch_thread_func, daemon=True + ) + self.backup_thread = threading.Thread( + target=self.backup_thread_func, daemon=True + ) 
+ self.prefetch_thread.start() + self.backup_thread.start() + def write( self, device_indices: torch.Tensor, @@ -383,3 +485,142 @@ def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> in raise ValueError( f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}" ) + + def prefetch( + self, + request_id: str, + host_indices: torch.Tensor, + new_input_tokens: List[int], + last_hash: Optional[str] = None, + ) -> int: + """ + Prefetch KV caches from storage backend to host memory. + """ + operation = PrefetchOperation( + request_id, host_indices, new_input_tokens, last_hash + ) + self.prefetch_queue.put(operation) + return operation + + def terminate_prefetch(self, operation): + operation.mark_done() + return operation.completed_tokens, operation.hash_value + + def prefetch_io_aux_func(self): + """ + Auxiliary function conducting IO operations for prefetching. + """ + while not self.stop_event.is_set(): + try: + operation = self.prefetch_buffer.get(block=True, timeout=1) + for h in operation.hash_value: + page_data = self.storage_backend.get(h) + if page_data is None: + logger.warning( + f"Prefetch operation {operation.request_id} failed to retrieve page {h}." + ) + break + self.mem_pool_host.set_from_flat_data_page( + operation.host_indices[operation.completed_tokens], + page_data, + ) + operation.increment(self.page_size) + if operation.is_done(): + # operation terminated by controller, release pre-allocated memory + self.mem_pool_host.free( + operation.host_indices[operation.completed_tokens :] + ) + break + except Empty: + continue + + def prefetch_thread_func(self): + """ + Manage prefetching operations from storage backend to host memory. 
+ """ + self.prefetch_buffer = Queue() + aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True) + aux_thread.start() + while (not self.stop_event.is_set()) or not self.prefetch_queue.empty(): + try: + operation = self.prefetch_queue.get(block=True, timeout=1) + if operation is None: + continue + + last_hash = operation.last_hash + tokens_to_fetch = operation.token_ids + + storage_hit_count = 0 + remaining_tokens = len(tokens_to_fetch) + hash_value = [] + while remaining_tokens >= self.page_size: + last_hash = get_hash_str( + tokens_to_fetch[ + storage_hit_count : storage_hit_count + self.page_size + ], + last_hash, + ) + if self.storage_backend.exists(last_hash): + storage_hit_count += self.page_size + hash_value.append(last_hash) + remaining_tokens -= self.page_size + else: + break + + if storage_hit_count < self.prefetch_threshold: + # not to prefetch if not enough benefits + self.prefetch_revoke_queue.put(operation.request_id) + else: + operation.hash_value = hash_value + logger.debug( + f"Prefetching {len(hash_value)} pages for request {operation.request_id}." + ) + self.prefetch_buffer.put(operation) + + except Empty: + continue + + def write_storage( + self, + host_indices: torch.Tensor, + token_ids: List[int], + last_hash: Optional[str] = None, + ) -> int: + """ + Write KV caches from host memory to storage backend. + """ + operation = StorageOperation(host_indices, token_ids, last_hash) + self.backup_queue.put(operation) + return operation.id + + def backup_thread_func(self): + """ + Manage backup operations from host memory to storage backend. 
+ """ + while not self.stop_event.is_set(): + try: + operation = self.backup_queue.get(block=True, timeout=1) + if operation is None: + continue + + last_hash = operation.last_hash + tokens_to_backup = operation.token_ids + + for i in range(0, len(tokens_to_backup), self.page_size): + last_hash = get_hash_str( + tokens_to_backup[i : i + self.page_size], last_hash + ) + # todo, handle failures in storage backend + self.storage_backend.set( + last_hash, + self.mem_pool_host.get_flat_data_page( + operation.host_indices[i] + ), + ) + operation.completed_tokens += self.page_size + operation.hash_value.append(last_hash) + + self.ack_backup_queue.put((operation.id, operation.hash_value)) + + except Empty: + continue diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 874ed60f0fd2..c79e296f60f9 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -262,6 +262,7 @@ def __init__( ) self.gpu_id = gpu_id self.enable_hierarchical_cache = server_args.enable_hierarchical_cache + self.enable_hicache_storage = server_args.hicache_storage_backend is not None self.page_size = server_args.page_size self.dp_size = server_args.dp_size self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = ( @@ -614,6 +615,7 @@ def init_memory_pool_and_cache(self): == "fa3" # hot fix for incompatibility else server_args.hicache_io_backend ), + hicache_storage_backend=server_args.hicache_storage_backend, ) self.tp_worker.register_hicache_layer_transfer_counter( self.tree_cache.cache_controller.layer_done_counter @@ -1258,6 +1260,15 @@ def _add_request_to_queue(self, req: Req): elif self.disaggregation_mode == DisaggregationMode.DECODE: self.disagg_decode_prealloc_queue.add(req) else: + if self.enable_hicache_storage: + req.init_next_round_input(self.tree_cache) + last_hash = req.last_host_node.get_last_hash_value() + matched_len = len(req.prefix_indices) + req.host_hit_length + if (matched_len > 0 and 
last_hash is not None) or matched_len == 0: + new_input_tokens = req.fill_ids[matched_len:] + self.tree_cache.prefetch_from_storage( + req.rid, req.last_host_node, new_input_tokens, last_hash + ) self.waiting_queue.append(req) def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False): @@ -1731,6 +1742,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch.batch_is_full = True break + if self.enable_hicache_storage: + self.tree_cache.check_prefetch_progress(req.rid) + req.init_next_round_input(self.tree_cache) res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None)) diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py new file mode 100644 index 000000000000..1dfe661ab5c9 --- /dev/null +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -0,0 +1,152 @@ +import hashlib +import logging +import os +from abc import ABC, abstractmethod +from typing import List, Optional + +import torch + +logger = logging.getLogger(__name__) + + +def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str: + hasher = hashlib.sha256() + + if prior_hash: + hasher.update(bytes.fromhex(prior_hash)) + + for t in token_ids: + hasher.update(t.to_bytes(4, byteorder="little", signed=False)) + + return hasher.hexdigest() + + +class HiCacheStorage(ABC): + """ + HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache. + It abstracts the underlying storage mechanism, allowing different implementations to be used. 
+ """ + + # todo, translate tensor object access for different TP ranks + # potentially pass model and TP configs into storage backend + # todo, the page size of storage backend does not have to be the same as the same as host memory pool + + @abstractmethod + def get( + self, key: str, target_location: Optional[torch.Tensor] = None + ) -> torch.Tensor | None: + """ + Retrieve the value associated with the given key. + Returns None if the key does not exist. + """ + pass + + @abstractmethod + def batch_get( + self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None + ) -> List[torch.Tensor | None]: + """ + Retrieve values for multiple keys. + Returns a list of tensors or None for each key. + """ + pass + + @abstractmethod + def set(self, key, value) -> bool: + """ + Store the value associated with the given key. + Returns True if the operation was successful, False otherwise. + """ + pass + + @abstractmethod + def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool: + """ + Store multiple key-value pairs. + Returns True if all operations were successful, False otherwise. + """ + pass + + @abstractmethod + def exists(self, key: str) -> bool: + """ + Check if the key exists in the storage. + Returns True if the key exists, False otherwise. 
+ """ + pass + + +class HiCacheFile(HiCacheStorage): + + def __init__(self, file_path: str = "/tmp/hicache"): + self.file_path = file_path + if not os.path.exists(self.file_path): + os.makedirs(self.file_path) + logger.info(f"Created HiCacheFile storage directory at {self.file_path}") + + def get( + self, key: str, target_location: Optional[torch.Tensor] = None + ) -> torch.Tensor | None: + tensor_path = os.path.join(self.file_path, f"{key}.bin") + try: + # todo: fixing the target_location logic to enable in-place loading + loaded_tensor = torch.load(tensor_path) + if isinstance(loaded_tensor, torch.Tensor): + return loaded_tensor + else: + logger.error(f"Loaded data for key {key} is not a tensor.") + return None + except FileNotFoundError: + return None + + def batch_get( + self, + keys: List[str], + target_locations: Optional[List[torch.Tensor]] = None, + ) -> List[torch.Tensor | None]: + return [ + self.get(key, target_location) + for key, target_location in zip( + keys, target_locations or [None] * len(keys) + ) + ] + + def set(self, key: str, value: torch.Tensor) -> bool: + tensor_path = os.path.join(self.file_path, f"{key}.bin") + if self.exists(key): + logger.debug(f"Key {key} already exists. Skipped.") + return True + try: + torch.save(value, tensor_path) + return True + except Exception as e: + logger.error(f"Failed to save tensor {key}: {e}") + return False + + def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool: + for key, value in zip(keys, values): + if not self.set(key, value): + return False + return True + + def exists(self, key: str) -> bool: + tensor_path = os.path.join(self.file_path, f"{key}.bin") + return os.path.exists(tensor_path) + + def delete(self, key: str) -> None: + tensor_path = os.path.join(self.file_path, f"{key}.bin") + try: + os.remove(tensor_path) + except FileNotFoundError: + logger.warning(f"Key {key} does not exist. 
Cannot delete.") + return + + def clear(self) -> None: + try: + for filename in os.listdir(self.file_path): + file_path = os.path.join(self.file_path, filename) + if os.path.isfile(file_path): + os.remove(file_path) + logger.info("Cleared all entries in HiCacheFile storage.") + except Exception as e: + logger.error(f"Failed to clear HiCacheFile storage: {e}") diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index cb7d95558bec..796f0553ceca 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -35,6 +35,7 @@ def __init__( hicache_size: int, hicache_write_policy: str, hicache_io_backend: str, + hicache_storage_backend: Optional[str] = None, ): self.kv_cache = token_to_kv_pool_allocator.get_kvcache() if isinstance(self.kv_cache, MHATokenToKVPool): @@ -49,6 +50,9 @@ def __init__( raise ValueError(f"HiRadixCache only supports MHA and MLA yet") self.tp_group = tp_cache_group + self.enable_storage = hicache_storage_backend is not None + # todo: customizable storage prefetch threshold + self.prefetch_threshold = 256 self.load_cache_event = threading.Event() self.cache_controller = HiCacheController( @@ -58,16 +62,22 @@ def __init__( load_cache_event=self.load_cache_event, write_policy=hicache_write_policy, io_backend=hicache_io_backend, + storage_backend=hicache_storage_backend, + prefetch_threshold=self.prefetch_threshold, ) # record the nodes with ongoing write through self.ongoing_write_through = {} # record the node segments with ongoing load back self.ongoing_load_back = {} + # record the ongoing prefetch requests + self.ongoing_prefetch = {} + self.ongoing_backup = {} # todo: dynamically adjust the threshold self.write_through_threshold = ( 1 if hicache_write_policy == "write_through" else 3 ) + self.write_through_threshold_storage = 3 self.load_back_threshold = 10 super().__init__( req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False 
@@ -108,13 +118,30 @@ def write_backup(self, node: TreeNode, write_back=False): return len(host_indices) + def write_backup_storage(self, node: TreeNode): + operation_id = self.cache_controller.write_storage( + node.host_value, node.key, node.parent.get_last_hash_value() + ) + self.ongoing_backup[operation_id] = node + node.protect_host() + def inc_hit_count(self, node: TreeNode): - if node.backuped or self.cache_controller.write_policy == "write_back": + if self.cache_controller.write_policy == "write_back": return node.hit_count += 1 - if node.hit_count >= self.write_through_threshold: - self.write_backup(node) - node.hit_count = 0 + + if not node.backuped: + if node.hit_count >= self.write_through_threshold: + # write to host if the node is not backuped + self.write_backup(node) + else: + if ( + self.enable_storage + and (not node.backuped_storage) + and node.hit_count >= self.write_through_threshold_storage + ): + # if the node is backuped on host memory but not on storage + self.write_backup_storage(node) def writing_check(self, write_back=False): if write_back: @@ -221,6 +248,10 @@ def evict_host(self, num_tokens: int): if not x.evicted: continue + # node is protected from eviction as it has ongoing prefetch or backup to storage + if x.host_ref_counter > 0: + continue + num_evicted += self.cache_controller.evict_host(x.host_value) for k, v in x.parent.children.items(): @@ -314,6 +345,85 @@ def ready_to_load_host_cache(self): def check_hicache_events(self): self.writing_check() self.loading_check() + if self.enable_storage: + self.check_revoked_prefetch() + self.check_backup_progress() + + def check_revoked_prefetch(self): + queue_size = torch.tensor( + self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int + ) + if torch.distributed.get_world_size(group=self.tp_group) > 1: + # synchrnoize TP workers to make the same update to hiradix cache + torch.distributed.all_reduce( + queue_size, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_group, + 
) + for _ in range(queue_size.item()): + req_id = self.cache_controller.prefetch_revoke_queue.get() + if req_id in self.ongoing_prefetch: + last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id] + last_host_node.release_host() + self.cache_controller.mem_pool_host.free(host_indices) + del self.ongoing_prefetch[req_id] + + def check_backup_progress(self): + queue_size = torch.tensor( + self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int + ) + if torch.distributed.get_world_size(group=self.tp_group) > 1: + # synchrnoize TP workers to make the same update to hiradix cache + torch.distributed.all_reduce( + queue_size, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_group, + ) + for _ in range(queue_size.item()): + ack_id, hash_value = self.cache_controller.ack_backup_queue.get() + self.ongoing_backup[ack_id].hash_value = hash_value + self.ongoing_backup[ack_id].release_host() + del self.ongoing_backup[ack_id] + + def check_prefetch_progress(self, req_id: str): + if req_id not in self.ongoing_prefetch: + # there is no ongoing prefetch for this request or it has been revoked + return + + # todo: more policies for prefetch progress such as timeout + # the current policy is to prefetch with best effort and terminate when queuing is over + last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[ + req_id + ] + completed_tokens, hash_value = self.cache_controller.terminate_prefetch( + operation + ) + logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens") + + min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int) + if torch.distributed.get_world_size(group=self.tp_group) > 1: + # synchrnoize TP workers to make the same update to hiradix cache + torch.distributed.all_reduce( + min_completed_tokens, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_group, + ) + min_completed_tokens = min_completed_tokens.item() + fetched_token_ids = token_ids[:min_completed_tokens] + written_indices = 
host_indices[:min_completed_tokens] + matched_length = self._insert_helper_host( + last_host_node, + fetched_token_ids, + written_indices, + hash_value[:min_completed_tokens], + ) + + self.cache_controller.mem_pool_host.free(host_indices[:matched_length]) + self.cache_controller.mem_pool_host.free( + host_indices[min_completed_tokens:completed_tokens] + ) + last_host_node.release_host() + del self.ongoing_prefetch[req_id] def match_prefix(self, key: List[int], **kwargs): empty_value = torch.empty((0,), dtype=torch.int64, device=self.device) @@ -348,6 +458,71 @@ def match_prefix(self, key: List[int], **kwargs): host_hit_length=host_hit_length, ) + def prefetch_from_storage( + self, + req_id: str, + last_host_node: TreeNode, + new_input_tokens: List[int], + last_hash: Optional[str] = None, + ): + if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold: + return + + last_host_node.protect_host() + host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens)) + if host_indices is None: + self.evict_host(len(new_input_tokens)) + host_indices = self.cache_controller.mem_pool_host.alloc( + len(new_input_tokens) + ) + if host_indices is None: + last_host_node.release_host() + # no sufficient host memory to prefetch + return + operation = self.cache_controller.prefetch( + req_id, host_indices, new_input_tokens, last_hash + ) + self.ongoing_prefetch[req_id] = ( + last_host_node, + new_input_tokens, + host_indices, + operation, + ) + + def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value): + node.last_access_time = time.monotonic() + if len(key) == 0: + return 0 + + child_key = self.get_child_key_fn(key) + + matched_length = 0 + while len(key) > 0 and child_key in node.children.keys(): + node = node.children[child_key] + node.last_access_time = time.monotonic() + prefix_len = self.key_match_fn(node.key, key) + key = key[prefix_len:] + host_value = host_value[prefix_len:] + hash_value = hash_value[prefix_len:] 
+ matched_length += prefix_len + + if prefix_len < len(node.key): + new_node = self._split_node(node.key, node, prefix_len) + node = new_node + + if len(key): + child_key = self.get_child_key_fn(key) + + if len(key): + new_node = TreeNode() + new_node.parent = node + new_node.key = key + new_node.value = None + new_node.host_value = host_value + new_node.hash_value = hash_value + node.children[child_key] = new_node + return matched_length + def _match_prefix_helper(self, node: TreeNode, key: List): node.last_access_time = time.monotonic() child_key = self.get_child_key_fn(key) diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index 1bc2ddf7ec45..f503479628a9 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -99,6 +99,20 @@ def get_size_per_token(self): def init_kv_buffer(self): raise NotImplementedError() + @abc.abstractmethod + def get_flat_data_page(self, index) -> torch.Tensor: + """ + Get a flat data page from the host memory pool. + """ + raise NotImplementedError() + + @abc.abstractmethod + def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None: + """ + Set a flat data page to the host memory pool. + """ + raise NotImplementedError() + @synchronized() def clear(self): # Initialize memory states and tracking structures. 
@@ -227,6 +241,19 @@ def init_kv_buffer(self): pin_memory=self.pin_memory, ) + # todo, page first memory layout + def get_flat_data_page(self, index) -> torch.Tensor: + return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten() + + def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None: + self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape( + 2, + self.layer_num, + self.page_size, + self.head_num, + self.head_dim, + ) + @property def k_buffer(self): return self.kv_buffer[0] @@ -276,3 +303,14 @@ def init_kv_buffer(self): device=self.device, pin_memory=self.pin_memory, ) + + def get_flat_data_page(self, index) -> torch.Tensor: + return self.kv_buffer[:, index : index + self.page_size, :, :].flatten() + + def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None: + self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape( + self.layer_num, + self.page_size, + 1, + self.kv_lora_rank + self.qk_rope_head_dim, + ) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 7064322090ae..0826990c21aa 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -55,8 +55,13 @@ def __init__(self, id: Optional[int] = None): self.hit_count = 0 # indicating the node is loading KV cache from host self.loading = False + # indicating the node is locked to protect from eviction + # incremented when the node is referenced by a storage operation + self.host_ref_counter = 0 # store the host indices of KV cache self.host_value: Optional[torch.Tensor] = None + # store hash values of each pages + self.hash_value: Optional[List[str]] = None self.id = TreeNode.counter if id is None else id TreeNode.counter += 1 @@ -69,6 +74,27 @@ def evicted(self): def backuped(self): return self.host_value is not None + @property + def backuped_storage(self): + return self.hash_value is not None and 
len(self.hash_value) > 0 + + def protect_host(self): + """Protect the host value from eviction.""" + self.host_ref_counter += 1 + + def release_host(self): + """Release the host value, allowing it to be evicted.""" + if self.host_ref_counter > 0: + self.host_ref_counter -= 1 + else: + raise RuntimeError("Host reference counter is already zero.") + + def get_last_hash_value(self) -> Optional[str]: + """Returns the hash value of the last page in this node.""" + if self.hash_value is None or len(self.hash_value) == 0: + return None + return self.hash_value[-1] + def __lt__(self, other: "TreeNode"): return self.last_access_time < other.last_access_time diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e475039d7380..cb8038d3366a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -222,6 +222,7 @@ class ServerArgs: hicache_size: int = 0 hicache_write_policy: str = "write_through_selective" hicache_io_backend: str = "" + hicache_storage_backend: Optional[str] = None flashinfer_mla_disable_ragged: bool = False disable_shared_experts_fusion: bool = False disable_chunked_prefix_cache: bool = False @@ -1604,6 +1605,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.hicache_io_backend, help="The IO backend for KV cache transfer between CPU and GPU", ) + parser.add_argument( + "--hicache-storage-backend", + type=str, + choices=["file"], # todo, mooncacke + default=ServerArgs.hicache_storage_backend, + help="The storage backend for hierarchical KV cache.", + ) parser.add_argument( "--flashinfer-mla-disable-ragged", action="store_true", diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 059955f3351c..41564869ed9b 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -64,6 +64,7 @@ class TestFile: TestFile("test_fused_moe.py", 30), TestFile("test_hicache.py", 116), TestFile("test_hicache_mla.py", 127), + TestFile("test_hicache_storage.py", 127), 
TestFile("test_hidden_states.py", 55), TestFile("test_int8_kernel.py", 8), TestFile("test_input_embeddings.py", 38), diff --git a/test/srt/test_hicache_storage.py b/test/srt/test_hicache_storage.py new file mode 100644 index 000000000000..aadc9529d50b --- /dev/null +++ b/test/srt/test_hicache_storage.py @@ -0,0 +1,55 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestHiCache(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--enable-hierarchical-cache", + "--mem-fraction-static", + 0.7, + "--hicache-size", + 100, + "--page-size", + "64", + "--hicache-storage-backend", + "file", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.65) + + +if __name__ == "__main__": + unittest.main() From d0510f08feaa155c4d99f01667e1b5673652478c Mon Sep 17 00:00:00 2001 From: Sai Enduri Date: Fri, 18 Jul 2025 01:12:11 -0700 Subject: [PATCH 031/396] Revert "Fix different device type adjustment in PP" (#8141) --- .../sglang/srt/distributed/parallel_state.py | 12 ++++--- python/sglang/srt/managers/scheduler.py | 5 --- python/sglang/srt/managers/tp_worker.py | 1 - python/sglang/srt/utils.py | 34 +++++++++++-------- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/distributed/parallel_state.py 
b/python/sglang/srt/distributed/parallel_state.py index 5ab2e3758115..509c71531062 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -699,14 +699,14 @@ def send_object(self, obj: Any, dst: int) -> None: ) # Serialize object to tensor and get the size as well - object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).to( - device=self.device + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda( + device=torch.cuda.current_device() ) size_tensor = torch.tensor( [object_tensor.numel()], dtype=torch.long, - device=self.device, + device=torch.cuda.current_device(), ) # Send object size @@ -731,7 +731,9 @@ def recv_object(self, src: int) -> Any: src != self.rank_in_group ), "Invalid source rank. Source rank is the same as the current rank." - size_tensor = torch.empty(1, dtype=torch.long, device=self.device) + size_tensor = torch.empty( + 1, dtype=torch.long, device=torch.cuda.current_device() + ) # Receive object size rank_size = torch.distributed.recv( @@ -742,7 +744,7 @@ def recv_object(self, src: int) -> Any: object_tensor = torch.empty( # type: ignore[call-overload] size_tensor.item(), # type: ignore[arg-type] dtype=torch.uint8, - device=self.device, + device=torch.cuda.current_device(), ) rank_object = torch.distributed.recv( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c79e296f60f9..748cb7322ade 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -975,7 +975,6 @@ def event_loop_pp(self): self.world_group.device_group, self.pp_rank * self.tp_size + dp_offset, (self.pp_rank + 1) * self.tp_size + dp_offset, - device=self.device, ) # send out proxy tensors to the next stage @@ -1024,7 +1023,6 @@ def recv_requests(self) -> List[Req]: self.world_group.device_group, (self.pp_rank - 1) * self.tp_size + dp_offset, self.pp_rank * self.tp_size + dp_offset, - 
device=self.device, ) else: recv_reqs = None @@ -1055,7 +1053,6 @@ def recv_requests(self) -> List[Req]: self.attn_tp_group.rank, self.attn_tp_cpu_group, src=self.attn_tp_group.ranks[0], - device=self.device, ) if self.tp_size != 1: control_reqs = broadcast_pyobj( @@ -1063,7 +1060,6 @@ def recv_requests(self) -> List[Req]: self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0], - device=self.device, ) recv_reqs = work_reqs + control_reqs elif self.tp_size != 1: @@ -1072,7 +1068,6 @@ def recv_requests(self) -> List[Req]: self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0], - device=self.device, ) return recv_reqs diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index daeed4faff7c..ff20ea01e4d3 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -144,7 +144,6 @@ def __init__( self.tp_size * self.pp_rank + tp_rank, self.world_group.cpu_group, src=self.world_group.ranks[0], - device=self.device, )[0] set_random_seed(self.random_seed) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 37e06b8dcc72..ce159a4da77b 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1100,15 +1100,15 @@ def broadcast_pyobj( rank: int, dist_group: Optional[torch.distributed.ProcessGroup] = None, src: int = 0, - device: Optional[str] = None, + force_cpu_device: bool = True, ): """Broadcast inputs from src rank to all other ranks with torch.dist backend. The `rank` here refer to the source rank on global process group (regardless of dist_group argument). 
""" - - if device is None: - device = get_device() + device = torch.device( + "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu" + ) if rank == src: if len(data) == 0: @@ -1148,38 +1148,44 @@ def point_to_point_pyobj( group: Optional[torch.distributed.ProcessGroup] = None, src: int = 0, dst: int = 1, - device: Optional[str] = None, ): """Send data from src to dst in group using DeviceToDevice communication.""" - if device is None: - device = get_device() + if rank == src: if len(data) == 0: - tensor_size = torch.tensor([0], dtype=torch.long, device=device) + tensor_size = torch.tensor( + [0], dtype=torch.long, device=torch.cuda.current_device() + ) dist.send(tensor_size, dst=dst, group=group) else: serialized_data = pickle.dumps(data) size = len(serialized_data) tensor_data = torch.ByteTensor( np.frombuffer(serialized_data, dtype=np.uint8) - ).to( - device=device - ) # Move to Device - tensor_size = torch.tensor([size], dtype=torch.long, device=device) + ).cuda( + device=torch.cuda.current_device() + ) # Move to GPU + tensor_size = torch.tensor( + [size], dtype=torch.long, device=torch.cuda.current_device() + ) dist.send(tensor_size, dst=dst, group=group) dist.send(tensor_data, dst=dst, group=group) return data elif rank == dst: - tensor_size = torch.tensor([0], dtype=torch.long, device=device) + tensor_size = torch.tensor( + [0], dtype=torch.long, device=torch.cuda.current_device() + ) dist.recv(tensor_size, src=src, group=group) size = tensor_size.item() if size == 0: return [] - tensor_data = torch.empty(size, dtype=torch.uint8, device=device) + tensor_data = torch.empty( + size, dtype=torch.uint8, device=torch.cuda.current_device() + ) dist.recv(tensor_data, src=src, group=group) serialized_data = bytes( From 719b29f218a09642193c4bda2a7ffa32829d5604 Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Fri, 18 Jul 2025 17:45:03 +0800 Subject: [PATCH 032/396] feat: enchance green context stream creation robust with backward compatibility (#8136) 
--- sgl-kernel/csrc/spatial/greenctx_stream.cu | 59 ++++++++++++---------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/sgl-kernel/csrc/spatial/greenctx_stream.cu b/sgl-kernel/csrc/spatial/greenctx_stream.cu index 8c2e6d813c95..9d7a44a1aab3 100644 --- a/sgl-kernel/csrc/spatial/greenctx_stream.cu +++ b/sgl-kernel/csrc/spatial/greenctx_stream.cu @@ -7,17 +7,15 @@ #include "cuda_utils.h" #include "greenctx_stream.h" -std::vector create_greenctx_stream_fallback(CUgreenCtx gctx[2]) { +static std::vector create_greenctx_stream_fallback(CUgreenCtx gctx[2]) { CUstream streamA, streamB; CUcontext ctx; - // Stream A CUDA_DRV(cuCtxFromGreenCtx(&ctx, gctx[0])); CUDA_DRV(cuCtxPushCurrent(ctx)); CUDA_DRV(cuStreamCreate(&streamA, CU_STREAM_NON_BLOCKING)); CUDA_DRV(cuCtxPopCurrent(nullptr)); - // Stream B CUDA_DRV(cuCtxFromGreenCtx(&ctx, gctx[1])); CUDA_DRV(cuCtxPushCurrent(ctx)); CUDA_DRV(cuStreamCreate(&streamB, CU_STREAM_NON_BLOCKING)); @@ -26,18 +24,31 @@ std::vector create_greenctx_stream_fallback(CUgreenCtx gctx[2]) { return {(int64_t)streamA, (int64_t)streamB}; } -#if CUDA_VERSION >= 12050 -std::vector create_greenctx_stream_direct(CUgreenCtx gctx[2]) { - CUstream streamA; - CUstream streamB; +typedef CUresult(CUDAAPI* PFN_cuGreenCtxStreamCreate)(CUstream*, CUgreenCtx, unsigned int, int); - CUDA_DRV(cuGreenCtxStreamCreate(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0)); - CUDA_DRV(cuGreenCtxStreamCreate(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0)); +static std::vector create_greenctx_stream_direct_dynamic(CUgreenCtx gctx[2]) { + static PFN_cuGreenCtxStreamCreate pfn = nullptr; + static std::once_flag pfn_probed_flag; - std::vector vec = {(int64_t)streamA, (int64_t)streamB}; - return vec; + // detect compatibility in runtime + std::call_once(pfn_probed_flag, []() { + cuGetProcAddress("cuGreenCtxStreamCreate", reinterpret_cast(&pfn), 0, 0, nullptr); + }); + + if (!pfn) { // fallback if not compatible + return create_greenctx_stream_fallback(gctx); + } + + 
CUstream streamA, streamB; + CUDA_DRV(pfn(&streamA, gctx[0], CU_STREAM_NON_BLOCKING, 0)); + CUDA_DRV(pfn(&streamB, gctx[1], CU_STREAM_NON_BLOCKING, 0)); + + return {(int64_t)streamA, (int64_t)streamB}; +} + +inline void destroy_green_context(int64_t h) { + if (h) CUDA_DRV(cuGreenCtxDestroy(reinterpret_cast(h))); } -#endif std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device) { TORCH_CHECK(CUDA_VERSION >= 12040, "Green Contexts feature requires CUDA Toolkit 12.4 or newer."); @@ -46,42 +57,38 @@ std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, i CUdevResourceDesc desc[3]; CUdevResource input; CUdevResource resources[4]; - unsigned int nbGroups = 1; - if (smA <= 0 || smB <= 0) { TORCH_CHECK(false, "SM counts must be positive"); } CUDA_DRV(cuDeviceGetDevResource((CUdevice)device, &input, CU_DEV_RESOURCE_TYPE_SM)); - unsigned int minCount = (unsigned int)(smA + smB); - unsigned int minCountA = (unsigned int)(smA); + + const unsigned minCount = smA + smB; + const unsigned minCountA = smA; TORCH_CHECK(minCount <= input.sm.smCount, "Not enough SMs available for the requested configuration"); + unsigned nbGroups = 1; CUDA_DRV(cuDevSmResourceSplitByCount(&resources[2], &nbGroups, &input, &resources[3], 0, minCount)); CUDA_DRV(cuDevResourceGenerateDesc(&desc[2], &resources[2], 1)); CUDA_DRV(cuGreenCtxCreate(&gctx[2], desc[2], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); CUDA_DRV(cuGreenCtxGetDevResource(gctx[2], &input, CU_DEV_RESOURCE_TYPE_SM)); + nbGroups = 1; CUDA_DRV(cuDevSmResourceSplitByCount(&resources[0], &nbGroups, &input, &resources[1], 0, minCountA)); CUDA_DRV(cuDevResourceGenerateDesc(&desc[0], &resources[0], 1)); CUDA_DRV(cuGreenCtxCreate(&gctx[0], desc[0], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); CUDA_DRV(cuDevResourceGenerateDesc(&desc[1], &resources[1], 1)); CUDA_DRV(cuGreenCtxCreate(&gctx[1], desc[1], (CUdevice)device, CU_GREEN_CTX_DEFAULT_STREAM)); - int smCountA = resources[0].sm.smCount; - 
int smCountB = resources[1].sm.smCount; - std::vector stream_handles; + const int smCountA = resources[0].sm.smCount; + const int smCountB = resources[1].sm.smCount; -#if CUDA_VERSION >= 12050 - stream_handles = create_greenctx_stream_direct(gctx); -#else - stream_handles = create_greenctx_stream_fallback(gctx); -#endif + std::vector streams = create_greenctx_stream_direct_dynamic(gctx); CUDA_DRV(cuGreenCtxDestroy(gctx[2])); std::vector vec = { - stream_handles[0], // streamA - stream_handles[1], // streamB + streams[0], // streamA + streams[1], // streamB (int64_t)smCountA, (int64_t)smCountB}; From fd63b62eaad903ac0b75630e5b1eee9002783b10 Mon Sep 17 00:00:00 2001 From: Enrique Shockwave <33002121+qeternity@users.noreply.github.com> Date: Fri, 18 Jul 2025 19:34:14 +0100 Subject: [PATCH 033/396] fix compressed tensors WNA16 imports (#8142) --- .../quantization/compressed_tensors/compressed_tensors.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 50d90406d26f..8afc15a73718 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -40,7 +40,10 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod try: - import vllm + from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( + WNA16_SUPPORTED_BITS, + CompressedTensorsWNA16, + ) VLLM_AVAILABLE = True except ImportError: From 6737671c82cd654dc052b3ffd7ddfcce73dfbe90 Mon Sep 17 00:00:00 2001 From: Even Zhou Date: Sat, 19 Jul 2025 02:34:55 +0800 Subject: [PATCH 034/396] [Bugfix] Fix w8a8_int8 import error on NPU (#8147) --- python/sglang/srt/layers/quantization/w8a8_int8.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index c9af7ae29cc7..19cf49c9bc86 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -754,6 +754,8 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + from sglang.srt.layers.linear import RowParallelLinear + if isinstance(layer, RowParallelLinear): tp_rank = get_tensor_model_parallel_rank() return self.quant_method.apply(layer, x, bias, tp_rank) From 1f76fc874759c257b4e928b9847a8da8e8ea2b30 Mon Sep 17 00:00:00 2001 From: Hongbo Xu <1320612015@qq.com> Date: Sat, 19 Jul 2025 02:45:22 +0800 Subject: [PATCH 035/396] [3/n] chore: decouple AWQ implementation from vLLM dependency (#8113) Co-authored-by: AniZpZ --- benchmark/deepseek_v3/README.md | 9 + .../srt/layers/quantization/__init__.py | 22 +- python/sglang/srt/layers/quantization/awq.py | 584 +++++++++++++++++- .../sglang/srt/layers/quantization/utils.py | 85 ++- python/sglang/srt/models/deepseek_v2.py | 4 +- python/sglang/test/test_marlin_moe.py | 286 +++++++++ python/sglang/test/test_marlin_utils.py | 171 +++++ test/srt/test_gptqmodel_dynamic.py | 2 +- 8 files changed, 1143 insertions(+), 20 deletions(-) create mode 100644 python/sglang/test/test_marlin_moe.py create mode 100644 python/sglang/test/test_marlin_utils.py diff --git a/benchmark/deepseek_v3/README.md b/benchmark/deepseek_v3/README.md index ebac6f41abaa..7fd380f91a62 100644 --- a/benchmark/deepseek_v3/README.md +++ b/benchmark/deepseek_v3/README.md @@ -178,6 +178,8 @@ python3 -m sglang.bench_one_batch_server --model None --base-url http://10.0.0.1 ### Example: Serving with 8 A100/A800 with AWQ Quantization +**Recommended Usage** + Add `--quantization moe_wna16` flag to enable moe wna16 kernel for better performance. 
One example is as follows: @@ -185,6 +187,13 @@ One example is as follows: python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization moe_wna16 ``` +Alternatively, you can use `--quantization awq_marlin` as follows: + +```bash +python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization awq_marlin --dtype float16 +``` + +Note that `awq_marlin` only supports `float16` now, which may lead to some precision loss. ### Example: Serving with 16 A100/A800 with int8 Quantization diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index e0f4363437b3..9995b72d0e0b 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -7,10 +7,6 @@ try: from vllm.model_executor.layers.quantization.aqlm import AQLMConfig - from vllm.model_executor.layers.quantization.awq_marlin import ( - AWQMarlinConfig, - AWQMoEMethod, - ) from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( CompressedTensorsW8A8Fp8MoEMethod, @@ -36,14 +32,14 @@ class DummyConfig: def override_quantization_method(self, *args, **kwargs): return None - AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = ( - DeepSpeedFPConfig - ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = ( - MarlinConfig - ) = QQQConfig = Int8TpuConfig = DummyConfig + AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = ( + ExpertsInt8Config + ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = ( + Int8TpuConfig + ) = DummyConfig -from sglang.srt.layers.quantization.awq import AWQConfig +from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig from 
sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( @@ -63,10 +59,7 @@ def override_quantization_method(self, *args, **kwargs): ) from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config from sglang.srt.layers.quantization.qoq import QoQConfig -from sglang.srt.layers.quantization.utils import ( - get_dynamic_override, - get_linear_quant_method, -) +from sglang.srt.layers.quantization.utils import get_linear_quant_method from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config @@ -237,7 +230,6 @@ def monkey_patch_quant_configs(): setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method) setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method) - monkey_patch_moe_apply(AWQMoEMethod) monkey_patch_moe_apply(GPTQMarlinMoEMethod) monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod) monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod) diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 6265f2217d79..4532673837dc 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -2,21 +2,52 @@ from __future__ import annotations import logging -from typing import Any, Dict, List, Optional +import warnings +from typing import Any, Callable, Dict, List, Optional import torch +from sglang.srt.layers.linear import LinearBase, set_weight_attrs from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, LinearMethodBase, QuantizationConfig, + QuantizeMethodBase, ) +from sglang.srt.layers.quantization.marlin_utils import ( + 
apply_awq_marlin_linear, + awq_to_marlin_zero_points, + check_marlin_supported, + check_marlin_supports_layer, + check_moe_marlin_supports_layer, + marlin_make_empty_g_idx, + marlin_make_workspace, + marlin_moe_permute_scales, + marlin_permute_scales, + moe_awq_to_marlin_zero_points, + verify_marlin_supported, + verify_marlin_supports_shape, +) +from sglang.srt.layers.quantization.scalar_type import scalar_types from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.layers.quantization.utils import replace_parameter + +try: + from vllm import _custom_ops as ops + + warnings.warn( + f"Using kernels directly from vllm. This might lead to performance degradation or " + f"missing functionalities as certain kernels may not be optimized. " + ) +except ImportError: + ops = None + from sglang.srt.utils import is_cuda _is_cuda = is_cuda() if _is_cuda: - from sgl_kernel import awq_dequantize + from sgl_kernel import awq_dequantize, fused_marlin_moe logger = logging.getLogger(__name__) @@ -103,6 +134,176 @@ def get_quant_method( return None +class AWQMarlinConfig(QuantizationConfig): + """Config class for AWQ Marlin""" + + # num_bits -> type + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + lm_head_quantized: bool, + modules_to_not_convert: Optional[list[str]], + full_config: dict[str, Any], + ) -> None: + super().__init__() + self.pack_factor = 32 // weight_bits # packed into int32 + self.group_size = group_size + self.zero_point = zero_point + self.lm_head_quantized = lm_head_quantized + self.weight_bits = weight_bits + self.modules_to_not_convert = modules_to_not_convert or [] + self.full_config = full_config + + if self.weight_bits not in self.TYPE_MAP: + raise ValueError( + f"Unsupported num_bits = {self.weight_bits}. 
" + f"Supported num_bits = {self.TYPE_MAP.keys()}" + ) + + self.quant_type = self.TYPE_MAP[self.weight_bits] + + verify_marlin_supported( + self.quant_type, group_size=self.group_size, has_zp=self.zero_point + ) + + def __repr__(self) -> str: + return ( + f"AWQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"lm_head_quantized={self.lm_head_quantized}, " + f"modules_to_not_convert={self.modules_to_not_convert})" + ) + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def get_name(cls) -> str: + return "awq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> AWQMarlinConfig: + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None + ) + return cls( + weight_bits, + group_size, + zero_point, + lm_head_quantized, + modules_to_not_convert, + config, + ) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) + is_valid_user_quant = ( + user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin" + ) + + if can_convert and is_valid_user_quant: + msg = ( + "The model is convertible to {} during runtime." 
+ " Using {} kernel.".format(cls.get_name(), cls.get_name()) + ) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "awq": + logger.info( + "Detected that the model can run with awq_marlin" + ", however you specified quantization=awq explicitly," + " so forcing awq. Use quantization=awq_marlin for" + " faster inference" + ) + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead + + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + # Check if the layer is supported by AWQMarlin. + if not check_marlin_supports_layer(layer, self.group_size): + logger.warning_once( + "Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501 + prefix, + ) + return AWQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) + return AWQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config + + if not check_moe_marlin_supports_layer(layer, self.group_size): + logger.warning_once( + f"Layer '{prefix}' is not supported by AWQMoeMarlin. " + "Falling back to Moe WNA16 kernels." + ) + return MoeWNA16Config.from_config(self.full_config).get_quant_method( + layer, prefix + ) + return AWQMoEMethod(self) + return None + + @classmethod + def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]): + # Extract data from quant config. 
+ quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + zero_point = quant_config.get("zero_point") + + if not _is_cuda: + return False + + if quant_method != "awq": + return False + + # If we cannot find the info needed in the config, cannot convert. + if num_bits is None or group_size is None or zero_point is None: + return False + + if num_bits not in cls.TYPE_MAP: + return False + + return check_marlin_supported( + quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point + ) + + class AWQLinearMethod(LinearMethodBase): """Linear method for AWQ. @@ -204,3 +405,382 @@ def apply( if bias is not None: out.add_(bias) return out.reshape(out_shape) + + +class AWQMarlinLinearMethod(LinearMethodBase): + """Linear method for AWQ Marlin. + + Args: + quant_config: The AWQ Marlin quantization config. + """ + + def __init__(self, quant_config: AWQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + verify_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=input_size, + group_size=group_size, + ) + + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + 
weight_loader=weight_loader, + ) + + num_groups = input_size_per_partition // group_size + + qzeros = PackedvLLMParameter( + data=torch.empty( + num_groups, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) + + scales = GroupQuantScaleParameter( + data=torch.empty( + num_groups, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader, + ) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.num_groups = num_groups + + # TODO: Update this docs + # Checkpoints are serialized in AutoAWQ format, which is different from the + # marlin format. This function is called after the weights are loaded. + # Here, we handle the repacking + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = layer.qweight.device + layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) + + # Allocate marlin workspace + layer.workspace = marlin_make_workspace(device) + + # Repack weights from AWQ format to marlin format. + marlin_qweight = ops.awq_marlin_repack( + layer.qweight, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.quant_type.size_bits, + ) + replace_parameter(layer, "qweight", marlin_qweight) + + # Permute scales from AWQ format to marlin format. 
+ marlin_scales = marlin_permute_scales( + layer.scales, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "scales", marlin_scales) + + # Permute zero-points from AWQ format to marlin format. + marlin_zp = awq_to_marlin_zero_points( + layer.qzeros, + size_k=layer.num_groups, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.quant_type.size_bits, + ) + replace_parameter(layer, "qzeros", marlin_zp) + + # Not-used + layer.g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_awq_marlin_linear( + input=x, + weight=layer.qweight, + weight_scale=layer.scales, + weight_zp=layer.qzeros, + g_idx=layer.g_idx, + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=layer.workspace, + quant_type=self.quant_config.quant_type, + output_size_per_partition=layer.output_size_per_partition, + input_size_per_partition=layer.input_size_per_partition, + bias=bias, + ) + + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQMarlinConfig): + self.quant_config = quant_config + if self.quant_config.weight_bits != 4: + raise ValueError("AWQMoEMethod only supports 4bit now.") + self.quant_type = scalar_types.uint4 + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Delay the import to avoid circular dependency + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + extra_weight_attrs.update( + { + "is_transposed": True, + "quant_method": FusedMoeWeightScaleSupported.GROUP.value, + } + ) + + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + 2 * 
intermediate_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = torch.nn.Parameter( + torch.empty( + num_experts, + num_groups_w13, + intermediate_size_per_partition * 2, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. 
+ w13_qzeros = torch.nn.Parameter( + torch.empty( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = torch.nn.Parameter( + torch.empty( + num_experts, + num_groups_w2, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + device = layer.w13_qweight.device + layer.workspace = marlin_make_workspace(device, 4) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.awq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + size_k=layer.w13_qweight.shape[1], + size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) + + marlin_w2_qweight = ops.awq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + size_k=layer.w2_qweight.shape[1], + size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + + # hidden_size->intermediate_size + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + 
group_size=self.quant_config.group_size, + ) + + replace_parameter(layer, "w13_scales", marlin_w13_scales) + + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "w2_scales", marlin_w2_scales) + + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_qzeros, + size_k=layer.w13_qzeros.shape[1], + size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w13_qzeros", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_qzeros, + size_k=layer.w2_qzeros.shape[1], + size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w2_qzeros", marlin_w2_zp) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + num_fused_shared_experts: int = 0, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + routed_scaling_factor: Optional[float] = None, + ) -> torch.Tensor: + # Delay the import to avoid circular dependency + from sglang.srt.layers.moe.topk import select_experts + + assert activation == "silu", "Only SiLU activation is supported." + assert ( + scoring_func == "softmax" + ), "Only softmax score func is supported for now." 
+ + # The input must currently be float16 + orig_dtype = x.dtype + x = x.half() + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + num_fused_shared_experts=num_fused_shared_experts, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + routed_scaling_factor=routed_scaling_factor, + ) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + w1_zeros=layer.w13_qzeros, + w2_zeros=layer.w2_qzeros, + num_bits=self.quant_config.weight_bits, + ).to(orig_dtype) diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 51d70255d90c..89e0eb84a2e6 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -11,7 +11,7 @@ import torch from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant -from sglang.srt.layers.quantization.scalar_type import ScalarType +from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu if TYPE_CHECKING: @@ -247,6 +247,36 @@ def get_pack_factor(num_bits): return 32 // num_bits +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else 
torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + def pack_cols( q_w: torch.Tensor, num_bits: int, @@ -399,3 +429,56 @@ def reshape_w(w): w_s if group_size is not None else None, maybe_w_zp, ) + + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 
bb1efde2941e..12aa9cb39c78 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -355,6 +355,7 @@ def __init__( self.shared_experts.gate_up_proj.quant_method, "quant_config" ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in { "awq", + "awq_marlin", "moe_wna16", } self.shared_experts_is_int8 = ( @@ -929,7 +930,7 @@ def __init__( has_fused_proj and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config") and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name() - in {"awq", "moe_wna16"} + in {"awq", "awq_marlin", "moe_wna16"} ) self.use_min_latency_fused_a_gemm = ( has_fused_proj @@ -2551,6 +2552,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal cat_dim = 0 if self.quant_config is not None and ( self.quant_config.get_name() == "awq" + or self.quant_config.get_name() == "awq_marlin" or self.quant_config.get_name() == "moe_wna16" ): cat_dim = 1 diff --git a/python/sglang/test/test_marlin_moe.py b/python/sglang/test/test_marlin_moe.py new file mode 100644 index 000000000000..e5b4c986a770 --- /dev/null +++ b/python/sglang/test/test_marlin_moe.py @@ -0,0 +1,286 @@ +import types +from typing import Optional + +import pytest +import torch +from sgl_kernel import fused_marlin_moe + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types +from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize + + +def stack_and_dev(tensors: list[torch.Tensor]): + dev = tensors[0].device + return torch.stack(tensors, dim=0).to(dev) + + +def torch_experts( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + quant_dtype: Optional[torch.dtype] = None, + apply_router_weights_on_input: bool = False, +) -> torch.Tensor: + 
assert ( + global_num_experts == -1 + or (global_num_experts == w1.shape[0] and expert_map is None) + or (expert_map is not None and global_num_experts == expert_map.shape[0]) + ) + + M, K = a.shape + topk = topk_ids.shape[1] + print("quant_dtype", quant_dtype) + # exit(0) + if apply_router_weights_on_input: + assert topk == 1 + a = a * topk_weight.to(a.dtype) + + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + + num_experts = w1.shape[0] + + topk_ids = topk_ids.view(-1) + if expert_map is not None: + topk_ids = expert_map[topk_ids] + + f32 = torch.float32 + + for i in range(num_experts): + mask = topk_ids == i + if mask.sum(): + if quant_dtype is None: + tmp1 = a[mask] @ w1[i].transpose(0, 1) + tmp2 = SiluAndMul()(tmp1) + out[mask] = tmp2 @ w2[i].transpose(0, 1) + + if apply_router_weights_on_input: + return out + else: + return ( + (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1)) + .sum(dim=1) + .to(out.dtype) + ) + + +def torch_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, +) -> torch.Tensor: + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + return torch_experts( + a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map + ) + + +def marlin_moe_generate_valid_test_cases(): + import itertools + + m_list = [1, 123, 666] + n_list = [128, 1024] + k_list = [256, 2048] + e_list = [4, 12] + topk_list = [2, 3] + dtype_list = [torch.half, torch.bfloat16] + group_size_list = [128] + act_order_list = [True, False] + quant_type_list = [ + scalar_types.uint4, + scalar_types.uint4b8, + ] + is_k_full_list = [True, False] + + all_combinations = itertools.product( + m_list, + n_list, + k_list, + e_list, + topk_list, + dtype_list, + group_size_list, + act_order_list, + 
quant_type_list, + is_k_full_list, + ) + + def is_invalid( + m, n, k, e, topk, dtype, group_size, act_order, quant_type, is_k_full + ): + + # Filter act_order + if act_order: + if group_size in (-1, k, n): + return False + if quant_type not in [scalar_types.uint4b8]: + return False + elif not is_k_full: + return False + + return True + + cases = [] + for case in all_combinations: + if is_invalid(*case): + cases.append(case) + return cases + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.parametrize( + ("m, n, k, e, topk, dtype, group_size," "act_order, quant_type, is_k_full"), + marlin_moe_generate_valid_test_cases(), +) +def test_fused_marlin_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + group_size: int, + act_order: bool, + quant_type: ScalarType, + is_k_full: bool, +): + if not torch.cuda.is_available(): + pytest.skip("CUDA device not available") + + torch.manual_seed(0) + + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size in (k, n): + return + if has_zp: + return + else: + if not is_k_full: + return + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 + + e_map = None + + w_ref1_l = [] + qweight1_l = [] + scales1_l = [] + zeros1_l = [] + g_idx1_l = [] + sort_indices1_l = [] + + for i in range(w1.shape[0]): + if has_zp: + w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size + ) + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + zeros1_l.append(zeros1) + else: + test_perm = torch.randperm(k) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + 
scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) + + w_ref1 = stack_and_dev(w_ref1_l) + qweight1 = stack_and_dev(qweight1_l).contiguous() + scales1 = stack_and_dev(scales1_l) + g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None + zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None + sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None + + w_ref2_l = [] + qweight2_l = [] + scales2_l = [] + zeros2_l = [] + g_idx2_l = [] + sort_indices2_l = [] + + for i in range(w2.shape[0]): + if has_zp: + w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size + ) + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + zeros2_l.append(zeros2) + else: + test_perm = torch.randperm(n) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) + + w_ref2 = stack_and_dev(w_ref2_l) + qweight2 = stack_and_dev(qweight2_l).contiguous() + scales2 = stack_and_dev(scales2_l) + g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None + zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None + sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None + + score = torch.randn((m, e), device="cuda", dtype=dtype) + from sglang.srt.layers.moe.topk import fused_topk_torch_native + + topk_weights, topk_ids = fused_topk_torch_native(a, score, topk, False) + + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map) + + marlin_output = fused_marlin_moe( + a, + qweight1, + qweight2, + scales1, + scales2, + score, + topk_weights, + topk_ids, + g_idx1=g_idx1, + g_idx2=g_idx2, + sort_indices1=sort_indices1, + sort_indices2=sort_indices2, + w1_zeros=zeros1, + w2_zeros=zeros2, + 
num_bits=4, + is_k_full=is_k_full, + ) + + torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) + + +if __name__ == "__main__": + # Run the specific test function directly + pytest.main([__file__]) diff --git a/python/sglang/test/test_marlin_utils.py b/python/sglang/test/test_marlin_utils.py new file mode 100644 index 000000000000..920cb7d8bef7 --- /dev/null +++ b/python/sglang/test/test_marlin_utils.py @@ -0,0 +1,171 @@ +""" +Adapted from +https://github.com/vllm-project/vllm/blob/020f58abcdea65302225663130d08fd8f4dd755a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +""" + +# SPDX-License-Identifier: Apache-2.0 +"""Utility functions used for tests and benchmarks""" + +from typing import Optional + +import numpy as np +import torch + +from sglang.srt.layers.quantization.marlin_utils import ( + GPTQ_MARLIN_TILE, + marlin_permute_scales, + marlin_zero_points, +) +from sglang.srt.layers.quantization.scalar_type import ScalarType +from sglang.srt.layers.quantization.utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, 
perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: list[int] = [] + for i in range(32): + perm1: list[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, 
g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py index feda8693459e..9be711d12420 100644 --- a/test/srt/test_gptqmodel_dynamic.py +++ b/test/srt/test_gptqmodel_dynamic.py @@ -24,7 +24,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool): set_custom_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state - from sglang.srt.layers.quantization import get_dynamic_override + from sglang.srt.layers.quantization.utils import get_dynamic_override from sglang.srt.model_loader import get_model from sglang.srt.server_args import PortArgs, ServerArgs From c8f31042a85fe49c19e5dd2b38bc8356d2bf9e94 Mon Sep 17 
00:00:00 2001 From: Simo Lin Date: Fri, 18 Jul 2025 14:24:24 -0700 Subject: [PATCH 036/396] [router] Refactor router and policy traits with dependency injection (#7987) Co-authored-by: Jin Pan Co-authored-by: Keru Yang Co-authored-by: Yingyi Huang Co-authored-by: Philip Zhu --- .github/workflows/pr-test-pd-router.yml | 310 ++-- scripts/ci_start_disaggregation_servers.sh | 20 +- sgl-router/benches/request_processing.rs | 2 +- sgl-router/py_test/test_launch_router.py | 35 +- sgl-router/src/config/types.rs | 4 +- sgl-router/src/config/validation.rs | 51 +- sgl-router/src/lib.rs | 14 +- sgl-router/src/policies/cache_aware.rs | 399 +++++ sgl-router/src/policies/factory.rs | 94 ++ sgl-router/src/policies/mod.rs | 143 ++ sgl-router/src/policies/power_of_two.rs | 201 +++ sgl-router/src/policies/random.rs | 116 ++ sgl-router/src/policies/round_robin.rs | 136 ++ sgl-router/src/router.rs | 1376 ----------------- sgl-router/src/routers/factory.rs | 66 + sgl-router/src/routers/mod.rs | 101 ++ sgl-router/src/{ => routers}/pd_router.rs | 633 +++++--- sgl-router/src/{ => routers}/pd_types.rs | 0 .../src/{ => routers}/request_adapter.rs | 2 +- sgl-router/src/routers/router.rs | 1055 +++++++++++++ sgl-router/src/server.rs | 193 +-- sgl-router/src/service_discovery.rs | 87 +- sgl-router/tests/benchmark_integration.rs | 2 +- sgl-router/tests/test_pd_routing.rs | 110 +- 24 files changed, 3198 insertions(+), 1952 deletions(-) create mode 100644 sgl-router/src/policies/cache_aware.rs create mode 100644 sgl-router/src/policies/factory.rs create mode 100644 sgl-router/src/policies/mod.rs create mode 100644 sgl-router/src/policies/power_of_two.rs create mode 100644 sgl-router/src/policies/random.rs create mode 100644 sgl-router/src/policies/round_robin.rs delete mode 100644 sgl-router/src/router.rs create mode 100644 sgl-router/src/routers/factory.rs create mode 100644 sgl-router/src/routers/mod.rs rename sgl-router/src/{ => routers}/pd_router.rs (67%) rename sgl-router/src/{ => 
routers}/pd_types.rs (100%) rename sgl-router/src/{ => routers}/request_adapter.rs (99%) create mode 100644 sgl-router/src/routers/router.rs diff --git a/.github/workflows/pr-test-pd-router.yml b/.github/workflows/pr-test-pd-router.yml index 271a8b3d92b6..91e809123934 100644 --- a/.github/workflows/pr-test-pd-router.yml +++ b/.github/workflows/pr-test-pd-router.yml @@ -131,110 +131,199 @@ jobs: SERVER_PID=$! echo "server_pid=$SERVER_PID" >> $GITHUB_OUTPUT - echo "Waiting for router to become healthy..." - TIMEOUT=300 - ELAPSED=0 - while [ $ELAPSED -lt $TIMEOUT ]; do - if curl --connect-timeout 5 --silent http://127.0.0.9:8000 > /dev/null 2>&1; then - echo "✓ Router is reachable" - break - fi - if ! ps -p $SERVER_PID > /dev/null; then - echo "Error: Server processes failed to start" - exit 1 + # Wait for all 8 servers to be healthy (script already does this) + wait_count=0 + while [ $wait_count -lt 30 ]; do + if ps -p $SERVER_PID > /dev/null; then + # Check if the startup script printed success message + sleep 2 + wait_count=$((wait_count + 1)) + else + # Script exited - check if it was successful + wait $SERVER_PID + exit_code=$? + if [ $exit_code -eq 0 ]; then + echo "✓ All disaggregation servers are healthy" + break + else + echo "Error: Server startup failed with code $exit_code" + exit 1 + fi fi - echo "Waiting for router... (${ELAPSED}s/${TIMEOUT}s)" - sleep 10 - ELAPSED=$((ELAPSED + 10)) done - if [ $ELAPSED -ge $TIMEOUT ]; then - echo "Error: Router health check timeout after ${TIMEOUT}s" - exit 1 - fi - - echo "✓ Servers started and healthy (PID: $SERVER_PID)" + echo "✓ Servers started (PID: $SERVER_PID)" - - name: Test API functionality - timeout-minutes: 5 + - name: Test all policies sequentially + timeout-minutes: 30 run: | + POLICIES=("random" "round_robin" "cache_aware" "power_of_two") BASE_URL="http://127.0.0.9:8000" - echo "Testing API completions..." 
- response=$(curl -s -X POST "$BASE_URL/v1/chat/completions" \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer test-token" \ - -d '{ - "model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct", - "messages": [ - {"role": "user", "content": "Write a Python function to calculate fibonacci numbers recursively"} - ], - "stream": false, - "max_tokens": 100 - }') - - if echo "$response" | jq -e '.choices[0].message.content' > /dev/null 2>&1; then - echo "✓ API test passed" - else - echo "✗ API test failed: $response" - exit 1 - fi + for policy in "${POLICIES[@]}"; do + echo "" + echo "==================================================" + echo "Testing policy: $policy" + echo "==================================================" + + # Start router with the current policy + echo "Starting router with policy: $policy..." + python3 -m sglang_router.launch_router \ + --pd-disaggregation \ + --policy "$policy" \ + --prefill http://127.0.0.1:30001 9001 \ + --prefill http://127.0.0.2:30002 9002 \ + --prefill http://127.0.0.3:30003 9003 \ + --prefill http://127.0.0.4:30004 9004 \ + --decode http://127.0.0.5:30005 \ + --decode http://127.0.0.6:30006 \ + --decode http://127.0.0.7:30007 \ + --decode http://127.0.0.8:30008 \ + --host 127.0.0.9 \ + --port 8000 & + ROUTER_PID=$! + + # Wait for router to become healthy + echo "Waiting for router to become healthy..." + TIMEOUT=60 + ELAPSED=0 + while [ $ELAPSED -lt $TIMEOUT ]; do + if curl --connect-timeout 5 --silent http://127.0.0.9:8000 > /dev/null 2>&1; then + echo "✓ Router is reachable" + break + fi + if ! ps -p $ROUTER_PID > /dev/null; then + echo "Error: Router process died" + exit 1 + fi + sleep 5 + ELAPSED=$((ELAPSED + 5)) + done - echo "Testing streaming API..." 
- stream_response=$(timeout 30 curl -s -X POST "$BASE_URL/v1/chat/completions" \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer test-token" \ - -d '{ - "model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct", - "messages": [ - {"role": "user", "content": "Count from 1 to 5"} - ], - "stream": true, - "max_tokens": 50 - }') - - if echo "$stream_response" | grep -q "data:"; then - echo "✓ Streaming API test passed" - else - echo "✗ Streaming API test failed" - exit 1 - fi + if [ $ELAPSED -ge $TIMEOUT ]; then + echo "Error: Router health check timeout" + kill $ROUTER_PID 2>/dev/null || true + exit 1 + fi - - name: Run benchmark test - timeout-minutes: 5 - run: | - echo "Running benchmark test..." - benchmark_output=$(python3 -m sglang.bench_one_batch_server \ - --model-path "/raid/models/meta-llama/Llama-3.1-8B-Instruct" \ - --base-url "http://127.0.0.9:8000" \ - --batch-size 8 \ - --input-len 4096 \ - --output-len 5 \ - --skip-warmup) - - echo "$benchmark_output" - - # Extract metrics from output - latency=$(echo "$benchmark_output" | grep "latency:" | awk '{print $2}' | sed 's/s//') - input_throughput=$(echo "$benchmark_output" | grep "input throughput:" | awk '{print $3}') - output_throughput=$(echo "$benchmark_output" | grep "output throughput:" | awk '{print $3}') - - # Validate performance (latency<1.5s, input>20k, output>1k) - command -v bc >/dev/null || (apt-get update && apt-get install -y bc) - - echo "Performance: ${latency}s | ${input_throughput} | ${output_throughput} tok/s" - - fail="" - (( $(echo "$latency > 1.5" | bc -l) )) && fail="Latency too high (${latency}s>1.5s) " - (( $(echo "$input_throughput < 20000" | bc -l) )) && fail="${fail}Input too low (${input_throughput}<20k) " - (( $(echo "$output_throughput < 1000" | bc -l) )) && fail="${fail}Output too low (${output_throughput}<1k) " - - if [ -n "$fail" ]; then - echo "✗ Benchmark failed: $fail" - exit 1 - else - echo "✓ Performance validation passed" - fi + # Test API 
functionality + echo "Testing API completions for $policy..." + response=$(curl -s -X POST "$BASE_URL/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer test-token" \ + -d '{ + "model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct", + "messages": [ + {"role": "user", "content": "Write a Python function to calculate fibonacci numbers recursively"} + ], + "stream": false, + "max_tokens": 100 + }') + + if echo "$response" | jq -e '.choices[0].message.content' > /dev/null 2>&1; then + echo "✓ API test passed for $policy" + else + echo "✗ API test failed for $policy: $response" + kill $ROUTER_PID 2>/dev/null || true + exit 1 + fi + + # Test streaming + echo "Testing streaming API for $policy..." + stream_response=$(timeout 30 curl -s -X POST "$BASE_URL/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer test-token" \ + -d '{ + "model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct", + "messages": [ + {"role": "user", "content": "Count from 1 to 5"} + ], + "stream": true, + "max_tokens": 50 + }') + + if echo "$stream_response" | grep -q "data:"; then + echo "✓ Streaming API test passed for $policy" + else + echo "✗ Streaming API test failed for $policy" + kill $ROUTER_PID 2>/dev/null || true + exit 1 + fi + + # Run benchmark + echo "Running benchmark for $policy..." 
+ benchmark_output=$(python3 -m sglang.bench_one_batch_server \ + --model-path "/raid/models/meta-llama/Llama-3.1-8B-Instruct" \ + --base-url "http://127.0.0.9:8000" \ + --batch-size 8 \ + --input-len 4096 \ + --output-len 5 \ + --skip-warmup) + + echo "$benchmark_output" + + # Save benchmark output + echo "$benchmark_output" > "benchmark_${policy}.txt" + + # Extract and validate metrics + latency=$(echo "$benchmark_output" | grep "latency:" | awk '{print $2}' | sed 's/s//') + input_throughput=$(echo "$benchmark_output" | grep "input throughput:" | awk '{print $3}') + output_throughput=$(echo "$benchmark_output" | grep "output throughput:" | awk '{print $3}') + + command -v bc >/dev/null || (apt-get update && apt-get install -y bc) + + echo "Performance for $policy: ${latency}s | ${input_throughput} | ${output_throughput} tok/s" + + # Validate performance + fail="" + (( $(echo "$latency > 1.5" | bc -l) )) && fail="Latency too high (${latency}s>1.5s) " + (( $(echo "$input_throughput < 20000" | bc -l) )) && fail="${fail}Input too low (${input_throughput}<20k) " + (( $(echo "$output_throughput < 1000" | bc -l) )) && fail="${fail}Output too low (${output_throughput}<1k) " + + if [ -n "$fail" ]; then + echo "✗ Benchmark failed for $policy: $fail" + kill $ROUTER_PID 2>/dev/null || true + exit 1 + else + echo "✓ Performance validation passed for $policy" + fi + + # Stop router before testing next policy + echo "Stopping router for $policy..." + # First try graceful shutdown + kill $ROUTER_PID 2>/dev/null || true + + # Wait up to 5 seconds for graceful shutdown + for i in {1..5}; do + if ! ps -p $ROUTER_PID > /dev/null 2>&1; then + echo "Router stopped gracefully" + break + fi + sleep 1 + done + + # Force kill if still running + if ps -p $ROUTER_PID > /dev/null 2>&1; then + echo "Force killing router..." 
+ kill -9 $ROUTER_PID 2>/dev/null || true + fi + + # Short delay to ensure port is released + sleep 2 + + echo "✓ Completed testing for $policy" + done + + echo "" + echo "✅ All policies tested successfully!" + + + - name: Upload benchmark results + if: success() + uses: actions/upload-artifact@v4 + with: + name: benchmark-results-all-policies + path: benchmark_*.txt - name: Cleanup servers if: always() @@ -247,3 +336,34 @@ jobs: sleep 5 remaining=$(ps aux | grep -c "sglang.launch_server" || echo "0") echo "Cleanup completed. Remaining processes: $remaining" + + summarize-benchmarks: + needs: test-disaggregation + runs-on: ubuntu-latest + if: success() + + steps: + - name: Download benchmark results + uses: actions/download-artifact@v4 + with: + name: benchmark-results-all-policies + + - name: Create benchmark summary + run: | + echo "## PD Router Benchmark Results Summary" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "| Policy | Latency (s) | Input Throughput (tok/s) | Output Throughput (tok/s) |" >> $GITHUB_STEP_SUMMARY + echo "|--------|-------------|-------------------------|--------------------------|" >> $GITHUB_STEP_SUMMARY + + for policy in random round_robin cache_aware power_of_two; do + if [ -f "benchmark_${policy}.txt" ]; then + latency=$(grep "latency:" "benchmark_${policy}.txt" | awk '{print $2}') + input_throughput=$(grep "input throughput:" "benchmark_${policy}.txt" | awk '{print $3}') + output_throughput=$(grep "output throughput:" "benchmark_${policy}.txt" | awk '{print $3}') + + echo "| ${policy} | ${latency} | ${input_throughput} | ${output_throughput} |" >> $GITHUB_STEP_SUMMARY + fi + done + + echo "" >> $GITHUB_STEP_SUMMARY + echo "✅ All policies tested successfully!" 
>> $GITHUB_STEP_SUMMARY diff --git a/scripts/ci_start_disaggregation_servers.sh b/scripts/ci_start_disaggregation_servers.sh index f652a4f048b4..22643e0df1a8 100755 --- a/scripts/ci_start_disaggregation_servers.sh +++ b/scripts/ci_start_disaggregation_servers.sh @@ -87,20 +87,8 @@ while true; do fi done -# Launch the router -echo "Launching router at 127.0.0.9:8000..." -python3 -m sglang_router.launch_router \ - --pd-disaggregation \ - --policy power_of_two \ - --prefill http://127.0.0.1:30001 9001 \ - --prefill http://127.0.0.2:30002 9002 \ - --prefill http://127.0.0.3:30003 9003 \ - --prefill http://127.0.0.4:30004 9004 \ - --decode http://127.0.0.5:30005 \ - --decode http://127.0.0.6:30006 \ - --decode http://127.0.0.7:30007 \ - --decode http://127.0.0.8:30008 \ - --host 127.0.0.9 \ - --port 8000 & +# Don't launch router here - just keep servers running +echo "✅ All disaggregation servers are ready and waiting for router connections" -wait # Wait for all background jobs to finish +# Keep the script running +wait # Wait for all background server jobs diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index c2cee90d51c1..576d07d2f79c 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -6,7 +6,7 @@ use sglang_router_rs::openai_api_types::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, SamplingParams, StringOrArray, UserMessageContent, }; -use sglang_router_rs::request_adapter::{RouteableRequest, ToPdRequest}; +use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest}; // Sample request data for benchmarks fn create_sample_generate_request() -> GenerateRequest { diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py index eb2018283070..14a0fa12d4a9 100644 --- a/sgl-router/py_test/test_launch_router.py +++ b/sgl-router/py_test/test_launch_router.py @@ -164,56 
+164,47 @@ def test_policy_validation(self): """Test that policy validation works correctly for PD and regular modes.""" from sglang_router.launch_router import RouterArgs, launch_router - # Test 1: PowerOfTwo is only valid in PD mode + # Test 1: PowerOfTwo requires at least 2 workers args = self.create_router_args( pd_disaggregation=False, policy="power_of_two", - worker_urls=["http://localhost:8000"], + worker_urls=["http://localhost:8000"], # Only 1 worker ) # Should raise error with self.assertRaises(ValueError) as cm: launch_router(args) self.assertIn( - "PowerOfTwo policy is only supported in PD disaggregated mode", + "Power-of-two policy requires at least 2 workers", str(cm.exception), ) - # Test 2: RoundRobin is not valid in PD mode + # Test 2: PowerOfTwo with sufficient workers should succeed args = self.create_router_args( - pd_disaggregation=True, - policy="round_robin", - prefill=[["http://prefill1:8080", "9000"]], - decode=[["http://decode1:8081"]], - worker_urls=[], - ) - - # Should raise error - with self.assertRaises(ValueError) as cm: - launch_router(args) - self.assertIn( - "RoundRobin policy is not supported in PD disaggregated mode", - str(cm.exception), + pd_disaggregation=False, + policy="power_of_two", + worker_urls=["http://localhost:8000", "http://localhost:8001"], # 2 workers ) + # This should not raise an error (validation passes) - # Test 3: Valid combinations should not raise errors + # Test 3: All policies now work in both modes # Regular mode with RoundRobin args = self.create_router_args( pd_disaggregation=False, policy="round_robin", worker_urls=["http://localhost:8000"], ) - # This should not raise (though it may fail to connect) + # This should not raise validation error - # PD mode with PowerOfTwo + # PD mode with RoundRobin (now supported!) 
args = self.create_router_args( pd_disaggregation=True, - policy="power_of_two", + policy="round_robin", prefill=[["http://prefill1:8080", "9000"]], decode=[["http://decode1:8081"]], worker_urls=[], ) - # This should not raise (though it may fail to connect) + # This should not raise validation error def test_pd_service_discovery_args_parsing(self): """Test PD service discovery CLI argument parsing.""" diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 9d57f439d756..6b24a5fd1f4a 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -1,4 +1,4 @@ -use super::{ConfigError, ConfigResult}; +use super::ConfigResult; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -215,6 +215,7 @@ impl RouterConfig { self.metrics.is_some() } + /* Commented out - no longer needed without compatibility layer /// Convert to routing PolicyConfig for internal use pub fn to_routing_policy_config(&self) -> ConfigResult { match (&self.mode, &self.policy) { @@ -291,4 +292,5 @@ impl RouterConfig { } } } + */ } diff --git a/sgl-router/src/config/validation.rs b/sgl-router/src/config/validation.rs index 838742722e19..381fcce075d8 100644 --- a/sgl-router/src/config/validation.rs +++ b/sgl-router/src/config/validation.rs @@ -255,29 +255,8 @@ impl ConfigValidator { /// Validate compatibility between different configuration sections fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> { - // Check mode and policy compatibility - match (&config.mode, &config.policy) { - (RoutingMode::Regular { .. }, PolicyConfig::PowerOfTwo { .. }) => { - // PowerOfTwo is only supported in PD mode - return Err(ConfigError::IncompatibleConfig { - reason: "PowerOfTwo policy is only supported in PD disaggregated mode" - .to_string(), - }); - } - (RoutingMode::PrefillDecode { .. 
}, PolicyConfig::RoundRobin) => { - return Err(ConfigError::IncompatibleConfig { - reason: "RoundRobin policy is not supported in PD disaggregated mode" - .to_string(), - }); - } - (RoutingMode::PrefillDecode { .. }, PolicyConfig::CacheAware { .. }) => { - return Err(ConfigError::IncompatibleConfig { - reason: "CacheAware policy is not supported in PD disaggregated mode" - .to_string(), - }); - } - _ => {} - } + // All policies are now supported for both router types thanks to the unified trait design + // No mode/policy restrictions needed anymore // Check if service discovery is enabled for worker count validation let has_service_discovery = config.discovery.as_ref().map_or(false, |d| d.enabled); @@ -459,8 +438,8 @@ mod tests { } #[test] - fn test_validate_incompatible_policy() { - // RoundRobin with PD mode + fn test_validate_roundrobin_with_pd_mode() { + // RoundRobin with PD mode is now supported let config = RouterConfig::new( RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill:8000".to_string(), None)], @@ -470,16 +449,12 @@ mod tests { ); let result = ConfigValidator::validate(&config); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("RoundRobin policy is not supported in PD disaggregated mode")); + assert!(result.is_ok()); } #[test] fn test_validate_cache_aware_with_pd_mode() { - // CacheAware with PD mode should fail + // CacheAware with PD mode is now supported let config = RouterConfig::new( RoutingMode::PrefillDecode { prefill_urls: vec![("http://prefill:8000".to_string(), None)], @@ -495,16 +470,12 @@ mod tests { ); let result = ConfigValidator::validate(&config); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("CacheAware policy is not supported in PD disaggregated mode")); + assert!(result.is_ok()); } #[test] fn test_validate_power_of_two_with_regular_mode() { - // PowerOfTwo with Regular mode should fail + // PowerOfTwo with Regular mode is now supported 
let config = RouterConfig::new( RoutingMode::Regular { worker_urls: vec![ @@ -518,10 +489,6 @@ mod tests { ); let result = ConfigValidator::validate(&config); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("PowerOfTwo policy is only supported in PD disaggregated mode")); + assert!(result.is_ok()); } } diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 2b1bcffce94d..49e8cc573059 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -4,11 +4,9 @@ pub mod logging; use std::collections::HashMap; pub mod core; pub mod openai_api_types; -pub mod pd_router; -pub mod pd_types; +pub mod policies; pub mod prometheus; -pub mod request_adapter; -pub mod router; +pub mod routers; pub mod server; pub mod service_discovery; pub mod tree; @@ -241,11 +239,6 @@ impl Router { )) })?; - // Convert to internal policy config - let policy_config = router_config - .to_routing_policy_config() - .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; - // Create service discovery config if enabled let service_discovery_config = if self.service_discovery { Some(service_discovery::ServiceDiscoveryConfig { @@ -282,8 +275,7 @@ impl Router { server::startup(server::ServerConfig { host: self.host.clone(), port: self.port, - worker_urls: self.worker_urls.clone(), - policy_config, + router_config, max_payload_size: self.max_payload_size, log_dir: self.log_dir.clone(), log_level: self.log_level.clone(), diff --git a/sgl-router/src/policies/cache_aware.rs b/sgl-router/src/policies/cache_aware.rs new file mode 100644 index 000000000000..db5972ba68a1 --- /dev/null +++ b/sgl-router/src/policies/cache_aware.rs @@ -0,0 +1,399 @@ +/* + Cache-Aware Load Balancing Router + + This router combines two strategies to optimize both cache utilization and request distribution: + + 1. Cache-Aware Routing (Approximate Tree) + 2. 
Load Balancing (Shortest Queue with Balance Thresholds) + + The router dynamically switches between these strategies based on load conditions: + - Uses load balancing when the system is imbalanced + - Uses cache-aware routing when the system is balanced + + A system is considered imbalanced if both conditions are met: + 1. (max - min) > abs_threshold + 2. max > rel_threshold * min + + Strategy Details: + + 1. Cache-Aware Routing (Approximate Tree) + ------------------------------------------- + This strategy maintains an approximate radix tree for each worker based on request history, + eliminating the need for direct cache state queries. The tree stores raw text characters + instead of token IDs to avoid tokenization overhead. + + Process: + a. For each request, find the worker with the highest prefix match + b. If match rate > cache_threshold: + Route to the worker with highest match (likely has relevant data cached) + c. If match rate ≤ cache_threshold: + Route to the worker with smallest tree size (most available cache capacity) + d. Background maintenance: + Periodically evict least recently used leaf nodes to prevent memory overflow + + 2. Load Balancing (Shortest Queue) + ------------------------------------------- + This strategy tracks pending request counts per worker and routes new requests + to the least busy worker when the system is detected to be imbalanced. + + Configuration Parameters: + ------------------------ + 1. cache_threshold: (float, 0.0 to 1.0) + Minimum prefix match ratio to use highest-match routing. + Below this threshold, routes to worker with most available cache space. + + 2. balance_abs_threshold: (integer) + Absolute difference threshold for load imbalance detection. + System is potentially imbalanced if (max_load - min_load) > abs_threshold + + 3. balance_rel_threshold: (float) + Relative ratio threshold for load imbalance detection. 
+ System is potentially imbalanced if max_load > min_load * rel_threshold + Used in conjunction with abs_threshold to determine final imbalance state. + + 4. eviction_interval_secs: (integer) + Interval between LRU eviction cycles for the approximate trees. + + 5. max_tree_size: (integer) + Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted + during the next eviction cycle. +*/ + +use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy}; +use crate::core::Worker; +use crate::tree::Tree; +use metrics::{counter, gauge}; +use std::sync::{Arc, Mutex}; +use std::thread; +use std::time::Duration; +use tracing::{debug, info}; + +/// Cache-aware routing policy +/// +/// Routes requests based on cache affinity when load is balanced, +/// switches to shortest-queue routing when load is imbalanced. +#[derive(Debug)] +pub struct CacheAwarePolicy { + config: CacheAwareConfig, + tree: Arc>, + eviction_handle: Option>, +} + +impl CacheAwarePolicy { + pub fn new() -> Self { + Self::with_config(CacheAwareConfig::default()) + } + + pub fn with_config(config: CacheAwareConfig) -> Self { + let tree = Arc::new(Mutex::new(Tree::new())); + + // Start background eviction thread if configured + let eviction_handle = if config.eviction_interval_secs > 0 { + let tree_clone = Arc::clone(&tree); + let max_tree_size = config.max_tree_size; + let interval = config.eviction_interval_secs; + + Some(thread::spawn(move || loop { + thread::sleep(Duration::from_secs(interval)); + + if let Ok(tree_guard) = tree_clone.lock() { + tree_guard.evict_tenant_by_size(max_tree_size); + debug!("Cache eviction completed, max_size: {}", max_tree_size); + } + })) + } else { + None + }; + + Self { + config, + tree, + eviction_handle, + } + } + + /// Initialize the tree with worker URLs + pub fn init_workers(&self, workers: &[Box]) { + if let Ok(tree) = self.tree.lock() { + for worker in workers { + tree.insert("", worker.url()); + } + } + } + + /// Remove a worker from the tree + 
pub fn remove_worker(&self, url: &str) { + if let Ok(tree) = self.tree.lock() { + tree.remove_tenant(url); + } + } + + /// Run cache eviction to prevent unbounded growth + pub fn evict_cache(&self, max_size: usize) { + if let Ok(tree) = self.tree.lock() { + tree.evict_tenant_by_size(max_size); + } + } +} + +impl LoadBalancingPolicy for CacheAwarePolicy { + fn select_worker( + &self, + workers: &[Box], + request_text: Option<&str>, + ) -> Option { + let healthy_indices = get_healthy_worker_indices(workers); + + if healthy_indices.is_empty() { + return None; + } + + // Get current load statistics + let loads: Vec = workers.iter().map(|w| w.load()).collect(); + let max_load = *loads.iter().max().unwrap_or(&0); + let min_load = *loads.iter().min().unwrap_or(&0); + + // Check if load is imbalanced + let is_imbalanced = max_load.saturating_sub(min_load) > self.config.balance_abs_threshold + && (max_load as f32) > (min_load as f32 * self.config.balance_rel_threshold); + + if is_imbalanced { + // Log load balancing trigger + let worker_loads: Vec<(String, usize)> = workers + .iter() + .map(|w| (w.url().to_string(), w.load())) + .collect(); + + info!( + "Load balancing triggered due to workload imbalance:\n\ + Max load: {}, Min load: {}\n\ + Current worker loads: {:?}", + max_load, min_load, worker_loads + ); + + counter!("sgl_router_load_balancing_events_total").increment(1); + gauge!("sgl_router_max_load").set(max_load as f64); + gauge!("sgl_router_min_load").set(min_load as f64); + + // Use shortest queue when imbalanced + let min_load_idx = healthy_indices + .iter() + .min_by_key(|&&idx| workers[idx].load()) + .copied()?; + + // Increment processed counter + workers[min_load_idx].increment_processed(); + counter!("sgl_router_processed_requests_total", "worker" => workers[min_load_idx].url().to_string()) + .increment(1); + + return Some(min_load_idx); + } + + // Use cache-aware routing when balanced + let text = request_text.unwrap_or(""); + + if let Ok(tree) = 
self.tree.lock() { + let (matched_text, matched_worker) = tree.prefix_match(text); + let match_rate = if text.is_empty() { + 0.0 + } else { + matched_text.chars().count() as f32 / text.chars().count() as f32 + }; + + let selected_url = if match_rate > self.config.cache_threshold { + counter!("sgl_router_cache_hits_total").increment(1); + matched_worker.to_string() + } else { + counter!("sgl_router_cache_misses_total").increment(1); + tree.get_smallest_tenant() + }; + + // Find the index of the selected worker + let selected_idx = workers.iter().position(|w| w.url() == selected_url)?; + + // Only proceed if the worker is healthy + if !workers[selected_idx].is_healthy() { + return healthy_indices.first().copied(); + } + + // Update the tree with this request + tree.insert(text, &selected_url); + + // Increment processed counter + workers[selected_idx].increment_processed(); + counter!("sgl_router_processed_requests_total", "worker" => selected_url).increment(1); + + return Some(selected_idx); + } + + // Fallback to first healthy worker if tree operations fail + healthy_indices.first().copied() + } + + fn name(&self) -> &'static str { + "cache_aware" + } + + fn on_request_complete(&self, worker_url: &str, success: bool) { + // Could track success rates per worker for more intelligent routing + if !success { + // Optionally reduce affinity for failed requests + tracing::debug!( + "Request to {} completed with success={}", + worker_url, + success + ); + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn select_worker_pair( + &self, + prefill_workers: &[Box], + decode_workers: &[Box], + request_text: Option<&str>, + ) -> Option<(usize, usize)> { + // In PD mode: + // - Prefill: Use cache-aware routing for better cache utilization + // - Decode: Use least-load routing for better load distribution + + // Select prefill worker using cache-aware logic + let prefill_idx = self.select_worker(prefill_workers, request_text)?; + + // Select decode worker using 
least-load logic + let healthy_decode = get_healthy_worker_indices(decode_workers); + if healthy_decode.is_empty() { + return None; + } + + let decode_idx = healthy_decode + .iter() + .min_by_key(|&&idx| decode_workers[idx].load()) + .copied()?; + + Some((prefill_idx, decode_idx)) + } +} + +impl Default for CacheAwarePolicy { + fn default() -> Self { + Self::new() + } +} + +impl Drop for CacheAwarePolicy { + fn drop(&mut self) { + // Note: We can't properly stop the eviction thread since it's in an infinite loop + // In a production system, we'd use a channel or atomic flag to signal shutdown + if let Some(handle) = self.eviction_handle.take() { + // The thread will continue running until the program exits + // This is acceptable for now since the router typically runs for the lifetime of the program + drop(handle); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{BasicWorker, WorkerType}; + + #[test] + fn test_cache_aware_with_balanced_load() { + // Create policy without eviction thread for testing + let config = CacheAwareConfig { + eviction_interval_secs: 0, // Disable eviction thread + ..Default::default() + }; + let policy = CacheAwarePolicy::with_config(config); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Initialize the policy with workers + policy.init_workers(&workers); + + // First request should be distributed + let idx1 = policy.select_worker(&workers, Some("hello world")).unwrap(); + + // Same request should go to same worker (cache hit) + let idx2 = policy.select_worker(&workers, Some("hello world")).unwrap(); + assert_eq!(idx1, idx2); + + // Similar request should also go to same worker + let idx3 = policy.select_worker(&workers, Some("hello")).unwrap(); + assert_eq!(idx1, idx3); + } + + #[test] + fn test_cache_aware_with_imbalanced_load() { + let 
policy = CacheAwarePolicy::with_config(CacheAwareConfig { + cache_threshold: 0.5, + balance_abs_threshold: 5, + balance_rel_threshold: 2.0, + eviction_interval_secs: 0, // Disable eviction thread + max_tree_size: 10000, + }); + + let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular); + let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular); + + // Create significant load imbalance + for _ in 0..20 { + worker1.increment_load(); + } + // worker2 has load 0 + + let workers: Vec> = vec![Box::new(worker1), Box::new(worker2)]; + policy.init_workers(&workers); + + // Should select worker2 (lower load) despite cache affinity + for _ in 0..5 { + let idx = policy.select_worker(&workers, Some("test")).unwrap(); + assert_eq!(idx, 1); // Should always pick worker2 + } + } + + #[test] + fn test_cache_aware_worker_removal() { + let config = CacheAwareConfig { + eviction_interval_secs: 0, // Disable eviction thread + ..Default::default() + }; + let policy = CacheAwarePolicy::with_config(config); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + ]; + + policy.init_workers(&workers); + + // Route some requests + policy.select_worker(&workers, Some("test1")); + policy.select_worker(&workers, Some("test2")); + + // Remove a worker + policy.remove_worker("http://w1:8000"); + workers[0].set_healthy(false); + + // All requests should now go to worker2 + let idx = policy.select_worker(&workers, Some("test1")).unwrap(); + assert_eq!(idx, 1); + } +} diff --git a/sgl-router/src/policies/factory.rs b/sgl-router/src/policies/factory.rs new file mode 100644 index 000000000000..c65785d637ce --- /dev/null +++ b/sgl-router/src/policies/factory.rs @@ -0,0 +1,94 @@ +//! 
Factory for creating load balancing policies + +use super::{ + CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy, + RoundRobinPolicy, +}; +use crate::config::PolicyConfig; +use std::sync::Arc; + +/// Factory for creating policy instances +pub struct PolicyFactory; + +impl PolicyFactory { + /// Create a policy from configuration + pub fn create_from_config(config: &PolicyConfig) -> Arc { + match config { + PolicyConfig::Random => Arc::new(RandomPolicy::new()), + PolicyConfig::RoundRobin => Arc::new(RoundRobinPolicy::new()), + PolicyConfig::PowerOfTwo { .. } => Arc::new(PowerOfTwoPolicy::new()), + PolicyConfig::CacheAware { + cache_threshold, + balance_abs_threshold, + balance_rel_threshold, + eviction_interval_secs, + max_tree_size, + } => { + let config = CacheAwareConfig { + cache_threshold: *cache_threshold, + balance_abs_threshold: *balance_abs_threshold, + balance_rel_threshold: *balance_rel_threshold, + eviction_interval_secs: *eviction_interval_secs, + max_tree_size: *max_tree_size, + }; + Arc::new(CacheAwarePolicy::with_config(config)) + } + } + } + + /// Create a policy by name (for dynamic loading) + pub fn create_by_name(name: &str) -> Option> { + match name.to_lowercase().as_str() { + "random" => Some(Arc::new(RandomPolicy::new())), + "round_robin" | "roundrobin" => Some(Arc::new(RoundRobinPolicy::new())), + "power_of_two" | "poweroftwo" => Some(Arc::new(PowerOfTwoPolicy::new())), + "cache_aware" | "cacheaware" => Some(Arc::new(CacheAwarePolicy::new())), + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_create_from_config() { + // Test Random + let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); + assert_eq!(policy.name(), "random"); + + // Test RoundRobin + let policy = PolicyFactory::create_from_config(&PolicyConfig::RoundRobin); + assert_eq!(policy.name(), "round_robin"); + + // Test PowerOfTwo + let policy = 
PolicyFactory::create_from_config(&PolicyConfig::PowerOfTwo { + load_check_interval_secs: 60, + }); + assert_eq!(policy.name(), "power_of_two"); + + // Test CacheAware + let policy = PolicyFactory::create_from_config(&PolicyConfig::CacheAware { + cache_threshold: 0.7, + balance_abs_threshold: 10, + balance_rel_threshold: 1.5, + eviction_interval_secs: 30, + max_tree_size: 1000, + }); + assert_eq!(policy.name(), "cache_aware"); + } + + #[test] + fn test_create_by_name() { + assert!(PolicyFactory::create_by_name("random").is_some()); + assert!(PolicyFactory::create_by_name("RANDOM").is_some()); + assert!(PolicyFactory::create_by_name("round_robin").is_some()); + assert!(PolicyFactory::create_by_name("RoundRobin").is_some()); + assert!(PolicyFactory::create_by_name("power_of_two").is_some()); + assert!(PolicyFactory::create_by_name("PowerOfTwo").is_some()); + assert!(PolicyFactory::create_by_name("cache_aware").is_some()); + assert!(PolicyFactory::create_by_name("CacheAware").is_some()); + assert!(PolicyFactory::create_by_name("unknown").is_none()); + } +} diff --git a/sgl-router/src/policies/mod.rs b/sgl-router/src/policies/mod.rs new file mode 100644 index 000000000000..83fdd95b085d --- /dev/null +++ b/sgl-router/src/policies/mod.rs @@ -0,0 +1,143 @@ +//! Load balancing policies for SGLang router +//! +//! This module provides a unified abstraction for routing policies that work +//! across both regular and prefill-decode (PD) routing modes. 
+ +use crate::core::Worker; +use std::fmt::Debug; + +mod cache_aware; +mod factory; +mod power_of_two; +mod random; +mod round_robin; + +pub use cache_aware::CacheAwarePolicy; +pub use factory::PolicyFactory; +pub use power_of_two::PowerOfTwoPolicy; +pub use random::RandomPolicy; +pub use round_robin::RoundRobinPolicy; + +/// Core trait for load balancing policies +/// +/// This trait provides a unified interface for implementing routing algorithms +/// that can work with both regular single-worker selection and PD dual-worker selection. +pub trait LoadBalancingPolicy: Send + Sync + Debug { + /// Select a single worker from the available workers + /// + /// This is used for regular routing mode where requests go to a single worker. + fn select_worker( + &self, + workers: &[Box], + request_text: Option<&str>, + ) -> Option; + + /// Select a pair of workers (prefill and decode) for PD routing + /// + /// Returns indices of (prefill_worker, decode_worker) from their respective arrays. + /// Default implementation uses select_worker for each array independently. + fn select_worker_pair( + &self, + prefill_workers: &[Box], + decode_workers: &[Box], + request_text: Option<&str>, + ) -> Option<(usize, usize)> { + // Default implementation: independently select from each pool + let prefill_idx = self.select_worker(prefill_workers, request_text)?; + let decode_idx = self.select_worker(decode_workers, request_text)?; + Some((prefill_idx, decode_idx)) + } + + /// Update policy state after request completion + /// + /// This is called when a request completes (successfully or not) to allow + /// policies to update their internal state. + fn on_request_complete(&self, _worker_url: &str, _success: bool) { + // Default: no-op for stateless policies + } + + /// Get policy name for metrics and debugging + fn name(&self) -> &'static str; + + /// Update worker load information + /// + /// This is called periodically with current load information for load-aware policies. 
+ fn update_loads(&self, _loads: &std::collections::HashMap) { + // Default: no-op for policies that don't use load information + } + + /// Reset any internal state + /// + /// This is useful for policies that maintain state (e.g., round-robin counters). + fn reset(&self) { + // Default: no-op for stateless policies + } + + /// Get as Any for downcasting + fn as_any(&self) -> &dyn std::any::Any; +} + +/// Configuration for cache-aware policy +#[derive(Debug, Clone)] +pub struct CacheAwareConfig { + pub cache_threshold: f32, + pub balance_abs_threshold: usize, + pub balance_rel_threshold: f32, + pub eviction_interval_secs: u64, + pub max_tree_size: usize, +} + +impl Default for CacheAwareConfig { + fn default() -> Self { + Self { + cache_threshold: 0.5, + balance_abs_threshold: 32, + balance_rel_threshold: 1.1, + eviction_interval_secs: 30, + max_tree_size: 10000, + } + } +} + +/// Helper function to filter healthy workers and return their indices +pub(crate) fn get_healthy_worker_indices(workers: &[Box]) -> Vec { + workers + .iter() + .enumerate() + .filter(|(_, w)| w.is_healthy()) + .map(|(idx, _)| idx) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{BasicWorker, WorkerType}; + + #[test] + fn test_get_healthy_worker_indices() { + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w3:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // All healthy initially + let indices = get_healthy_worker_indices(&workers); + assert_eq!(indices, vec![0, 1, 2]); + + // Mark one unhealthy + workers[1].set_healthy(false); + let indices = get_healthy_worker_indices(&workers); + assert_eq!(indices, vec![0, 2]); + } +} diff --git a/sgl-router/src/policies/power_of_two.rs b/sgl-router/src/policies/power_of_two.rs new file mode 100644 index 
000000000000..53c8461965ff --- /dev/null +++ b/sgl-router/src/policies/power_of_two.rs @@ -0,0 +1,201 @@ +//! Power-of-two choices load balancing policy + +use super::{get_healthy_worker_indices, LoadBalancingPolicy}; +use crate::core::Worker; +use metrics::counter; +use rand::Rng; +use std::collections::HashMap; +use std::sync::RwLock; +use tracing::info; + +/// Power-of-two choices policy +/// +/// Randomly selects two workers and routes to the one with lower load. +/// This provides good load distribution with minimal coordination overhead. +#[derive(Debug)] +pub struct PowerOfTwoPolicy { + /// Cached load information from external monitoring + cached_loads: RwLock>, +} + +impl PowerOfTwoPolicy { + pub fn new() -> Self { + Self { + cached_loads: RwLock::new(HashMap::new()), + } + } + + fn get_worker_load(&self, worker: &dyn Worker) -> isize { + // First check cached loads (from external monitoring) + if let Ok(loads) = self.cached_loads.read() { + if let Some(&load) = loads.get(worker.url()) { + return load; + } + } + + // Fall back to local load counter + worker.load() as isize + } +} + +impl LoadBalancingPolicy for PowerOfTwoPolicy { + fn select_worker( + &self, + workers: &[Box], + _request_text: Option<&str>, + ) -> Option { + let healthy_indices = get_healthy_worker_indices(workers); + + if healthy_indices.is_empty() { + return None; + } + + if healthy_indices.len() == 1 { + return Some(healthy_indices[0]); + } + + // Select two random workers + let mut rng = rand::thread_rng(); + let idx1 = rng.gen_range(0..healthy_indices.len()); + let mut idx2 = rng.gen_range(0..healthy_indices.len()); + + // Ensure we pick two different workers + while idx2 == idx1 { + idx2 = rng.gen_range(0..healthy_indices.len()); + } + + let worker_idx1 = healthy_indices[idx1]; + let worker_idx2 = healthy_indices[idx2]; + + // Compare loads and select the less loaded one + let load1 = self.get_worker_load(workers[worker_idx1].as_ref()); + let load2 = 
self.get_worker_load(workers[worker_idx2].as_ref()); + + // Log selection for debugging + let selected_idx = if load1 <= load2 { + worker_idx1 + } else { + worker_idx2 + }; + + info!( + "Power-of-two selection: {}={} vs {}={} -> selected {}", + workers[worker_idx1].url(), + load1, + workers[worker_idx2].url(), + load2, + workers[selected_idx].url() + ); + + // Increment processed counter + workers[selected_idx].increment_processed(); + counter!("sgl_router_processed_requests_total", "worker" => workers[selected_idx].url().to_string()) + .increment(1); + + Some(selected_idx) + } + + fn name(&self) -> &'static str { + "power_of_two" + } + + fn update_loads(&self, loads: &HashMap) { + if let Ok(mut cached) = self.cached_loads.write() { + *cached = loads.clone(); + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +impl Default for PowerOfTwoPolicy { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{BasicWorker, WorkerType}; + + #[test] + fn test_power_of_two_selection() { + let policy = PowerOfTwoPolicy::new(); + let worker1 = BasicWorker::new("http://w1:8000".to_string(), WorkerType::Regular); + let worker2 = BasicWorker::new("http://w2:8000".to_string(), WorkerType::Regular); + let worker3 = BasicWorker::new("http://w3:8000".to_string(), WorkerType::Regular); + + // Set different loads + for _ in 0..10 { + worker1.increment_load(); + } + for _ in 0..5 { + worker2.increment_load(); + } + // worker3 has load 0 + + let workers: Vec> = + vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)]; + + // Run multiple selections + let mut selected_counts = vec![0; 3]; + for _ in 0..100 { + if let Some(idx) = policy.select_worker(&workers, None) { + selected_counts[idx] += 1; + } + } + + // Worker with lowest load (worker3) should be selected most often + assert!(selected_counts[2] > selected_counts[1]); + assert!(selected_counts[1] > selected_counts[0]); + } + + #[test] + fn 
test_power_of_two_with_cached_loads() { + let policy = PowerOfTwoPolicy::new(); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Update cached loads + let mut loads = HashMap::new(); + loads.insert("http://w1:8000".to_string(), 100); + loads.insert("http://w2:8000".to_string(), 10); + policy.update_loads(&loads); + + // Should prefer worker2 with lower cached load + let mut w2_selected = 0; + for _ in 0..50 { + if let Some(idx) = policy.select_worker(&workers, None) { + if idx == 1 { + w2_selected += 1; + } + } + } + + // Worker2 should be selected significantly more often + assert!(w2_selected > 35); // Should win most of the time + } + + #[test] + fn test_power_of_two_single_worker() { + let policy = PowerOfTwoPolicy::new(); + let workers: Vec> = vec![Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + ))]; + + // With single worker, should always select it + assert_eq!(policy.select_worker(&workers, None), Some(0)); + } +} diff --git a/sgl-router/src/policies/random.rs b/sgl-router/src/policies/random.rs new file mode 100644 index 000000000000..50920bdf1800 --- /dev/null +++ b/sgl-router/src/policies/random.rs @@ -0,0 +1,116 @@ +//! Random load balancing policy + +use super::{get_healthy_worker_indices, LoadBalancingPolicy}; +use crate::core::Worker; +use rand::Rng; + +/// Random selection policy +/// +/// Selects workers randomly with uniform distribution among healthy workers. 
+#[derive(Debug, Default)] +pub struct RandomPolicy; + +impl RandomPolicy { + pub fn new() -> Self { + Self + } +} + +impl LoadBalancingPolicy for RandomPolicy { + fn select_worker( + &self, + workers: &[Box], + _request_text: Option<&str>, + ) -> Option { + let healthy_indices = get_healthy_worker_indices(workers); + + if healthy_indices.is_empty() { + return None; + } + + let mut rng = rand::thread_rng(); + let random_idx = rng.gen_range(0..healthy_indices.len()); + Some(healthy_indices[random_idx]) + } + + fn name(&self) -> &'static str { + "random" + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{BasicWorker, WorkerType}; + use std::collections::HashMap; + + #[test] + fn test_random_selection() { + let policy = RandomPolicy::new(); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w3:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Test multiple selections to ensure randomness + let mut counts = HashMap::new(); + for _ in 0..100 { + if let Some(idx) = policy.select_worker(&workers, None) { + *counts.entry(idx).or_insert(0) += 1; + } + } + + // All workers should be selected at least once + assert_eq!(counts.len(), 3); + assert!(counts.values().all(|&count| count > 0)); + } + + #[test] + fn test_random_with_unhealthy_workers() { + let policy = RandomPolicy::new(); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Mark first worker as unhealthy + workers[0].set_healthy(false); + + // Should always select the healthy worker (index 1) + for _ in 0..10 { + assert_eq!(policy.select_worker(&workers, None), Some(1)); + } 
+ } + + #[test] + fn test_random_no_healthy_workers() { + let policy = RandomPolicy::new(); + let workers: Vec> = vec![Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + ))]; + + workers[0].set_healthy(false); + assert_eq!(policy.select_worker(&workers, None), None); + } +} diff --git a/sgl-router/src/policies/round_robin.rs b/sgl-router/src/policies/round_robin.rs new file mode 100644 index 000000000000..4401605f007e --- /dev/null +++ b/sgl-router/src/policies/round_robin.rs @@ -0,0 +1,136 @@ +//! Round-robin load balancing policy + +use super::{get_healthy_worker_indices, LoadBalancingPolicy}; +use crate::core::Worker; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// Round-robin selection policy +/// +/// Selects workers in sequential order, cycling through all healthy workers. +#[derive(Debug, Default)] +pub struct RoundRobinPolicy { + counter: AtomicUsize, +} + +impl RoundRobinPolicy { + pub fn new() -> Self { + Self { + counter: AtomicUsize::new(0), + } + } +} + +impl LoadBalancingPolicy for RoundRobinPolicy { + fn select_worker( + &self, + workers: &[Box], + _request_text: Option<&str>, + ) -> Option { + let healthy_indices = get_healthy_worker_indices(workers); + + if healthy_indices.is_empty() { + return None; + } + + // Get and increment counter atomically + let count = self.counter.fetch_add(1, Ordering::Relaxed); + let selected_idx = count % healthy_indices.len(); + + Some(healthy_indices[selected_idx]) + } + + fn name(&self) -> &'static str { + "round_robin" + } + + fn reset(&self) { + self.counter.store(0, Ordering::Relaxed); + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{BasicWorker, WorkerType}; + + #[test] + fn test_round_robin_selection() { + let policy = RoundRobinPolicy::new(); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + 
"http://w2:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w3:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Should select workers in order: 0, 1, 2, 0, 1, 2, ... + assert_eq!(policy.select_worker(&workers, None), Some(0)); + assert_eq!(policy.select_worker(&workers, None), Some(1)); + assert_eq!(policy.select_worker(&workers, None), Some(2)); + assert_eq!(policy.select_worker(&workers, None), Some(0)); + assert_eq!(policy.select_worker(&workers, None), Some(1)); + } + + #[test] + fn test_round_robin_with_unhealthy_workers() { + let policy = RoundRobinPolicy::new(); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w3:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Mark middle worker as unhealthy + workers[1].set_healthy(false); + + // Should skip unhealthy worker: 0, 2, 0, 2, ... 
+ assert_eq!(policy.select_worker(&workers, None), Some(0)); + assert_eq!(policy.select_worker(&workers, None), Some(2)); + assert_eq!(policy.select_worker(&workers, None), Some(0)); + assert_eq!(policy.select_worker(&workers, None), Some(2)); + } + + #[test] + fn test_round_robin_reset() { + let policy = RoundRobinPolicy::new(); + let workers: Vec> = vec![ + Box::new(BasicWorker::new( + "http://w1:8000".to_string(), + WorkerType::Regular, + )), + Box::new(BasicWorker::new( + "http://w2:8000".to_string(), + WorkerType::Regular, + )), + ]; + + // Advance the counter + assert_eq!(policy.select_worker(&workers, None), Some(0)); + assert_eq!(policy.select_worker(&workers, None), Some(1)); + + // Reset should start from beginning + policy.reset(); + assert_eq!(policy.select_worker(&workers, None), Some(0)); + } +} diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs deleted file mode 100644 index e8b68d7c599e..000000000000 --- a/sgl-router/src/router.rs +++ /dev/null @@ -1,1376 +0,0 @@ -use crate::core::{HealthChecker, Worker, WorkerFactory}; -use crate::pd_router::PDRouter; -use crate::pd_types::PDSelectionPolicy; -use crate::tree::Tree; -use ::metrics::{counter, gauge, histogram}; -use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; -use actix_web::{HttpRequest, HttpResponse}; -use futures_util::{StreamExt, TryStreamExt}; -use std::fmt::Debug; -use std::sync::atomic::AtomicUsize; -use std::sync::{Arc, Mutex, RwLock}; -use std::thread; -use std::time::Duration; -use std::time::Instant; -use tokio; -use tracing::{debug, error, info, warn}; - -pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> { - req.headers() - .iter() - .filter_map(|(name, value)| { - value - .to_str() - .ok() - .map(|v| (name.to_string(), v.to_string())) - }) - .collect() -} - -#[derive(Debug)] -pub enum Router { - RoundRobin { - workers: Arc>>>, - current_index: AtomicUsize, - timeout_secs: u64, - interval_secs: u64, - _health_checker: Option, - }, - Random { - 
workers: Arc>>>, - timeout_secs: u64, - interval_secs: u64, - _health_checker: Option, - }, - PrefillDecode { - pd_router: Arc, - }, - CacheAware { - /* - Cache-Aware Load Balancing Router - - This router combines two strategies to optimize both cache utilization and request distribution: - - 1. Cache-Aware Routing (Approximate Tree) - 2. Load Balancing (Shortest Queue with Balance Thresholds) - - The router dynamically switches between these strategies based on load conditions: - - Uses load balancing when the system is imbalanced - - Uses cache-aware routing when the system is balanced - - A system is considered imbalanced if both conditions are met: - 1. (max - min) > abs_threshold - 2. max > rel_threshold * min - - Strategy Details: - - 1. Cache-Aware Routing (Approximate Tree) - ------------------------------------------- - This strategy maintains an approximate radix tree for each worker based on request history, - eliminating the need for direct cache state queries. The tree stores raw text characters - instead of token IDs to avoid tokenization overhead. - - Process: - a. For each request, find the worker with the highest prefix match - b. If match rate > cache_threshold: - Route to the worker with highest match (likely has relevant data cached) - c. If match rate ≤ cache_threshold: - Route to the worker with smallest tree size (most available cache capacity) - d. Background maintenance: - Periodically evict least recently used leaf nodes to prevent memory overflow - - 2. Load Balancing (Shortest Queue) - ------------------------------------------- - This strategy tracks pending request counts per worker and routes new requests - to the least busy worker when the system is detected to be imbalanced. - - Configuration Parameters: - ------------------------ - 1. cache_threshold: (float, 0.0 to 1.0) - Minimum prefix match ratio to use highest-match routing. - Below this threshold, routes to worker with most available cache space. - - 2. 
balance_abs_threshold: (integer) - Absolute difference threshold for load imbalance detection. - System is potentially imbalanced if (max_load - min_load) > abs_threshold - - 3. balance_rel_threshold: (float) - Relative ratio threshold for load imbalance detection. - System is potentially imbalanced if max_load > min_load * rel_threshold - Used in conjunction with abs_threshold to determine final imbalance state. - - 4. eviction_interval_secs: (integer) - Interval between LRU eviction cycles for the approximate trees. - - 5. max_tree_size: (integer) - Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted - during the next eviction cycle. - */ - workers: Arc>>>, - tree: Arc>, - cache_threshold: f32, - balance_abs_threshold: usize, - balance_rel_threshold: f32, - timeout_secs: u64, - interval_secs: u64, - _eviction_thread: Option>, - _health_checker: Option, - }, -} - -#[derive(Debug, Clone)] -pub enum PolicyConfig { - RandomConfig { - timeout_secs: u64, - interval_secs: u64, - }, - RoundRobinConfig { - timeout_secs: u64, - interval_secs: u64, - }, - CacheAwareConfig { - cache_threshold: f32, - balance_abs_threshold: usize, - balance_rel_threshold: f32, - eviction_interval_secs: u64, - max_tree_size: usize, - timeout_secs: u64, - interval_secs: u64, - }, - PrefillDecodeConfig { - selection_policy: PDSelectionPolicy, - prefill_urls: Vec<(String, Option)>, // (url, bootstrap_port) - decode_urls: Vec, - timeout_secs: u64, - interval_secs: u64, - }, -} - -impl Router { - pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Result { - // Update active workers gauge - gauge!("sgl_router_active_workers").set(worker_urls.len() as f64); - - // Get timeout and interval from policy config - let (timeout_secs, interval_secs) = match &policy_config { - PolicyConfig::RandomConfig { - timeout_secs, - interval_secs, - } => (*timeout_secs, *interval_secs), - PolicyConfig::RoundRobinConfig { - timeout_secs, - interval_secs, - } => (*timeout_secs, *interval_secs), - 
PolicyConfig::CacheAwareConfig { - timeout_secs, - interval_secs, - .. - } => (*timeout_secs, *interval_secs), - PolicyConfig::PrefillDecodeConfig { - timeout_secs, - interval_secs, - .. - } => (*timeout_secs, *interval_secs), - }; - - // For PrefillDecode, we need to handle workers differently - match &policy_config { - PolicyConfig::PrefillDecodeConfig { .. } => { - // PD mode doesn't use the worker_urls parameter - // We'll validate PD workers separately - } - _ => { - // Wait until all workers are healthy for regular modes - let worker_urls = worker_urls.clone(); - std::thread::spawn(move || { - Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs) - }) - .join() - .map_err(|e| { - error!("Health-check thread panicked: {:?}", e); - format!("Health-check thread panicked: {e:?}") - })??; - } - } - - // Create Worker trait objects from URLs - let workers: Vec> = worker_urls - .iter() - .map(|url| WorkerFactory::create_regular(url.clone())) - .collect(); - - // Create router based on policy... 
- Ok(match policy_config { - PolicyConfig::RandomConfig { - timeout_secs, - interval_secs, - } => { - let workers = Arc::new(RwLock::new(workers)); - let health_checker = - crate::core::start_health_checker(Arc::clone(&workers), interval_secs); - Router::Random { - workers, - timeout_secs, - interval_secs, - _health_checker: Some(health_checker), - } - } - PolicyConfig::RoundRobinConfig { - timeout_secs, - interval_secs, - } => { - let workers = Arc::new(RwLock::new(workers)); - let health_checker = - crate::core::start_health_checker(Arc::clone(&workers), interval_secs); - Router::RoundRobin { - workers, - current_index: std::sync::atomic::AtomicUsize::new(0), - timeout_secs, - interval_secs, - _health_checker: Some(health_checker), - } - } - PolicyConfig::CacheAwareConfig { - cache_threshold, - balance_abs_threshold, - balance_rel_threshold, - eviction_interval_secs, - max_tree_size, - timeout_secs, - interval_secs, - } => { - let tree = Arc::new(Mutex::new(Tree::new())); - - // Create background eviction thread - let tree_clone = Arc::clone(&tree); - let workers = Arc::new(RwLock::new(workers)); - let workers_clone = Arc::clone(&workers); - let eviction_thread = thread::spawn(move || { - loop { - // Sleep for the specified interval - thread::sleep(Duration::from_secs(eviction_interval_secs)); - - let locked_tree_clone = tree_clone.lock().unwrap(); - // Run eviction - locked_tree_clone.evict_tenant_by_size(max_tree_size); - drop(locked_tree_clone); - - // Log worker loads and processed requests - let workers_guard = workers_clone.read().unwrap(); - let loads: Vec<(String, usize)> = workers_guard - .iter() - .map(|w| (w.url().to_string(), w.load())) - .collect(); - info!("Worker loads: {:?}", loads); - - let processed: Vec<(String, usize)> = workers_guard - .iter() - .map(|w| (w.url().to_string(), w.processed_requests())) - .collect(); - info!("Processed requests: {:?}", processed); - } - }); - - for worker in workers.read().unwrap().iter() { - 
tree.lock().unwrap().insert("", worker.url()); - } - - let health_checker = - crate::core::start_health_checker(Arc::clone(&workers), interval_secs); - - Router::CacheAware { - workers, - tree, - cache_threshold, - balance_abs_threshold, - balance_rel_threshold, - timeout_secs, - interval_secs, - _eviction_thread: Some(eviction_thread), - _health_checker: Some(health_checker), - } - } - PolicyConfig::PrefillDecodeConfig { - selection_policy, - prefill_urls, - decode_urls, - timeout_secs, - interval_secs, - } => { - // Create PDRouter instance - let pd_router = PDRouter::new( - prefill_urls, - decode_urls, - selection_policy, - timeout_secs, - interval_secs, - )?; - - Router::PrefillDecode { - pd_router: Arc::new(pd_router), - } - } - }) - } - - /// Get the current list of worker URLs - pub fn get_worker_urls(&self) -> Vec { - match self { - Router::RoundRobin { workers, .. } - | Router::Random { workers, .. } - | Router::CacheAware { workers, .. } => workers - .read() - .unwrap() - .iter() - .map(|w| w.url().to_string()) - .collect(), - Router::PrefillDecode { .. } => Vec::new(), - } - } - - pub fn wait_for_healthy_workers( - worker_urls: &[String], - timeout_secs: u64, - interval_secs: u64, - ) -> Result<(), String> { - let start_time = std::time::Instant::now(); - let sync_client = reqwest::blocking::Client::builder() - .timeout(Duration::from_secs(timeout_secs)) - .build() - .map_err(|e| format!("Failed to create HTTP client: {}", e))?; - - loop { - if start_time.elapsed() > Duration::from_secs(timeout_secs) { - error!( - "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", - timeout_secs, worker_urls - ); - return Err(format!( - "Timeout {}s waiting for workers {:?} to become healthy. 
Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", - timeout_secs, worker_urls - )); - } - - let mut all_healthy = true; - let mut unhealthy_workers = Vec::new(); - - for url in worker_urls { - match sync_client.get(&format!("{}/health", url)).send() { - Ok(res) => { - if !res.status().is_success() { - let msg = format!( - "Worker heatlh check is pending with status {}", - res.status() - ); - info!("{}", msg); - all_healthy = false; - unhealthy_workers.push((url, msg)); - } - } - Err(_) => { - let msg = format!("Worker is not ready yet"); - info!("{}", msg); - all_healthy = false; - unhealthy_workers.push((url, msg)); - } - } - } - - if all_healthy { - info!("All workers are healthy"); - return Ok(()); - } else { - info!("Initializing workers:"); - for (url, reason) in &unhealthy_workers { - info!(" {} - {}", url, reason); - } - thread::sleep(Duration::from_secs(interval_secs)); - } - } - } - - fn select_first_worker(&self) -> Result { - match self { - Router::RoundRobin { workers, .. } - | Router::Random { workers, .. } - | Router::CacheAware { workers, .. } => { - let workers_guard = workers.read().unwrap(); - if workers_guard.is_empty() { - Err("No workers are available".to_string()) - } else { - Ok(workers_guard[0].url().to_string()) - } - } - Router::PrefillDecode { .. 
} => { - // For PD mode, we don't need this method as routing is handled by PDRouter - Err("PrefillDecode mode doesn't use select_first_worker".to_string()) - } - } - } - - pub async fn send_request( - &self, - client: &reqwest::Client, - worker_url: &str, - route: &str, - req: &HttpRequest, - ) -> HttpResponse { - let start = Instant::now(); - let mut request_builder = client.get(format!("{}{}", worker_url, route)); - - // Copy all headers from original request except for /health because it does not need authorization - if route != "/health" { - for (name, value) in copy_request_headers(req) { - // Skip Content-Type and Content-Length as .json() sets them - if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" - { - request_builder = request_builder.header(name, value); - } - } - } - - let response = match request_builder.send().await { - Ok(res) => { - let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - - match res.bytes().await { - Ok(body) => HttpResponse::build(status).body(body.to_vec()), - Err(e) => HttpResponse::InternalServerError() - .body(format!("Failed to read response body: {}", e)), - } - } - Err(e) => HttpResponse::InternalServerError().body(format!( - "Failed to send request to worker {}: {}", - worker_url, e - )), - }; - - // Record request metrics - if route != "/health" { - let duration = start.elapsed(); - counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1); - histogram!("sgl_router_request_duration_seconds", "route" => route.to_string()) - .record(duration.as_secs_f64()); - - if !response.status().is_success() { - counter!("sgl_router_request_errors_total", "route" => route.to_string()) - .increment(1); - } - } - response - } - - pub async fn route_to_first( - &self, - client: &reqwest::Client, - route: &str, - req: &HttpRequest, - ) -> HttpResponse { - const MAX_REQUEST_RETRIES: u32 = 3; - const 
MAX_TOTAL_RETRIES: u32 = 6; - let mut total_retries = 0; - - while total_retries < MAX_TOTAL_RETRIES { - match self.select_first_worker() { - Ok(worker_url) => { - let mut request_retries = 0; - - // Try the same worker multiple times - while request_retries < MAX_REQUEST_RETRIES { - if total_retries >= 1 { - info!("Retrying request after {} failed attempts", total_retries); - } - - let response = self.send_request(client, &worker_url, route, req).await; - - if response.status().is_success() { - return response; - } else { - // if the worker is healthy, it means the request is bad, so return the error response - let health_response = - self.send_request(client, &worker_url, "/health", req).await; - if health_response.status().is_success() { - return response; - } - } - - warn!( - "Request to {} failed (attempt {}/{})", - worker_url, - request_retries + 1, - MAX_REQUEST_RETRIES - ); - - request_retries += 1; - total_retries += 1; - - if request_retries == MAX_REQUEST_RETRIES { - warn!("Removing failed worker: {}", worker_url); - self.remove_worker(&worker_url); - break; - } - } - } - Err(e) => return HttpResponse::InternalServerError().body(e), - } - } - - HttpResponse::InternalServerError().body("All retry attempts failed") - } - - pub async fn route_to_all( - &self, - client: &reqwest::Client, - route: &str, - req: &HttpRequest, - ) -> HttpResponse { - // Get all worker URLs based on router type - let worker_urls = match self { - Router::PrefillDecode { .. 
} => { - // For PD mode, route_to_all is not supported directly - // It should be handled by PDRouter if needed - return HttpResponse::NotImplemented() - .body("route_to_all not implemented for PrefillDecode mode"); - } - _ => self.get_worker_urls(), - }; - - // Send requests to all workers concurrently - let mut tasks = Vec::new(); - for worker_url in &worker_urls { - let mut request_builder = client.post(format!("{}{}", worker_url, route)); - - // Copy headers from original request - for (name, value) in copy_request_headers(req) { - request_builder = request_builder.header(name, value); - } - - tasks.push(request_builder.send()); - } - - // Wait for all responses - let results = futures_util::future::join_all(tasks).await; - - // Check if all succeeded - let all_success = results.iter().all(|r| { - r.as_ref() - .map(|res| res.status().is_success()) - .unwrap_or(false) - }); - - if all_success { - HttpResponse::Ok().body("Operation completed on all servers") - } else { - HttpResponse::InternalServerError().body("Operation failed on one or more servers") - } - } - - pub async fn get_all_loads( - &self, - client: &reqwest::Client, - _req: &HttpRequest, - ) -> HttpResponse { - // For PD mode, delegate to PDRouter - match self { - Router::PrefillDecode { pd_router } => { - return pd_router.get_loads(client).await; - } - _ => { - // For non-PD routers, handle normally - } - } - - let urls = self.get_worker_urls(); - let prefill_urls: Vec = Vec::new(); - let decode_urls = urls; - - // Collect loads from all servers - let mut prefill_loads = Vec::new(); - let mut decode_loads = Vec::new(); - - // Get prefill loads - for url in &prefill_urls { - let load = self.get_worker_load(client, url).await.unwrap_or(-1); - prefill_loads.push(serde_json::json!({ - "engine": format!("(Prefill@{})", url), - "load": load as i64 - })); - } - - // Get decode loads - for url in &decode_urls { - let load = self.get_worker_load(client, url).await.unwrap_or(-1); - 
decode_loads.push(serde_json::json!({ - "engine": format!("(Decode@{})", url), - "load": load as i64 - })); - } - - HttpResponse::Ok().json(serde_json::json!({ - "prefill": prefill_loads, - "decode": decode_loads - })) - } - - // New method to route typed requests directly - pub async fn route_typed_request< - T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone, - >( - &self, - client: &reqwest::Client, - req: &HttpRequest, - typed_req: &T, - route: &str, - ) -> HttpResponse { - match self { - Router::PrefillDecode { .. } => HttpResponse::InternalServerError() - .body("PD routing should use specialized typed handlers"), - _ => { - // Handle retries like the original implementation - let start = Instant::now(); - const MAX_REQUEST_RETRIES: u32 = 3; - const MAX_TOTAL_RETRIES: u32 = 6; - let mut total_retries = 0; - - while total_retries < MAX_TOTAL_RETRIES { - // Extract routing text directly from typed request - let text = typed_req.extract_text_for_routing(); - let is_stream = typed_req.is_stream(); - - // Select worker based on text - let worker_url = self.select_generate_worker_from_text(&text); - let mut request_retries = 0; - - // Try the same worker multiple times - while request_retries < MAX_REQUEST_RETRIES { - if total_retries >= 1 { - info!("Retrying request after {} failed attempts", total_retries); - counter!("sgl_router_retries_total", "route" => route.to_string()) - .increment(1); - } - - // For CacheAware router, increment load before request - let load_incremented = match self { - Router::CacheAware { workers, .. 
} => { - let workers_guard = workers.read().unwrap(); - if let Some(worker) = - workers_guard.iter().find(|w| w.url() == &worker_url) - { - worker.increment_load(); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(worker.load() as f64); - true - } else { - false - } - } - _ => false, - }; - - // Send typed request directly - let response = self - .send_typed_request( - client, - req, - typed_req, - route, - &worker_url, - is_stream, - load_incremented, - ) - .await; - - if response.status().is_success() { - let duration = start.elapsed(); - histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()) - .record(duration.as_secs_f64()); - return response; - } else { - // if the worker is healthy, it means the request is bad, so return the error response - let health_response = - self.send_request(client, &worker_url, "/health", req).await; - if health_response.status().is_success() { - counter!("sgl_router_request_errors_total", "route" => route.to_string()) - .increment(1); - return response; - } - } - - warn!( - "Generate request to {} failed (attempt {}/{})", - worker_url, - request_retries + 1, - MAX_REQUEST_RETRIES - ); - - request_retries += 1; - total_retries += 1; - - if request_retries == MAX_REQUEST_RETRIES { - warn!("Removing failed worker: {}", worker_url); - self.remove_worker(&worker_url); - break; - } - } - } - - counter!("sgl_router_request_errors_total", "route" => route.to_string()) - .increment(1); - HttpResponse::InternalServerError().body("All retry attempts failed") - } - } - } - - // Helper method to select worker from text (returns index for RoundRobin/Random, URL for CacheAware) - fn select_generate_worker_from_text(&self, text: &str) -> String { - match self { - Router::RoundRobin { - workers, - current_index, - .. 
- } => { - let workers_guard = workers.read().unwrap(); - let idx = current_index - .fetch_update( - std::sync::atomic::Ordering::SeqCst, - std::sync::atomic::Ordering::SeqCst, - |x| Some((x + 1) % workers_guard.len()), - ) - .unwrap(); - workers_guard[idx].url().to_string() - } - - Router::Random { workers, .. } => { - let workers_guard = workers.read().unwrap(); - workers_guard[rand::random::() % workers_guard.len()] - .url() - .to_string() - } - - Router::CacheAware { - workers, - tree, - cache_threshold, - balance_abs_threshold, - balance_rel_threshold, - .. - } => { - let tree = tree.lock().unwrap(); - let workers_guard = workers.read().unwrap(); - - // Get current load statistics from workers - let loads: Vec = workers_guard.iter().map(|w| w.load()).collect(); - let max_load = *loads.iter().max().unwrap_or(&0); - let min_load = *loads.iter().min().unwrap_or(&0); - - // Load is considered imbalanced if: - // 1. (max - min) > abs_threshold AND - // 2. max > rel_threshold * min - let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold - && (max_load as f32) > (min_load as f32 * balance_rel_threshold); - - let selected_url = if is_imbalanced { - // Log load balancing trigger and current queue state - let worker_loads: Vec<(String, usize)> = workers_guard - .iter() - .map(|w| (w.url().to_string(), w.load())) - .collect(); - - info!( - "Load balancing triggered due to workload imbalance:\n\ - Max load: {}, Min load: {}\n\ - Current worker loads: {:?}", - max_load, min_load, worker_loads - ); - - counter!("sgl_router_load_balancing_events_total").increment(1); - gauge!("sgl_router_max_load").set(max_load as f64); - gauge!("sgl_router_min_load").set(min_load as f64); - - // Use shortest queue routing when load is imbalanced - workers_guard - .iter() - .min_by_key(|w| w.load()) - .map(|w| w.url().to_string()) - .unwrap_or_else(|| workers_guard[0].url().to_string()) - } else { - // Use cache-aware routing when load is balanced - let 
(matched_text, matched_worker) = tree.prefix_match(&text); - let matched_rate = - matched_text.chars().count() as f32 / text.chars().count() as f32; - - if matched_rate > *cache_threshold { - counter!("sgl_router_cache_hits_total").increment(1); - matched_worker.to_string() - } else { - counter!("sgl_router_cache_misses_total").increment(1); - tree.get_smallest_tenant() - } - }; - - // Find the selected worker and increment processed counter only - if let Some(worker) = workers_guard.iter().find(|w| w.url() == &selected_url) { - worker.increment_processed(); - counter!("sgl_router_processed_requests_total", "worker" => selected_url.to_string()) - .increment(1); - } - - tree.insert(&text, &selected_url); - - selected_url - } - Router::PrefillDecode { .. } => { - // For PD mode, we don't use this method - return "PD_MODE_ERROR".to_string(); - } - } - } - - // Send typed request directly without conversion - async fn send_typed_request( - &self, - client: &reqwest::Client, - req: &HttpRequest, - typed_req: &T, - route: &str, - worker_url: &str, - is_stream: bool, - load_incremented: bool, // Whether load was incremented for this request - ) -> HttpResponse { - let start = Instant::now(); - - // Debug: Log what we're sending - if let Ok(json_str) = serde_json::to_string_pretty(typed_req) { - debug!("Sending request to {}: {}", route, json_str); - } - - let mut request_builder = client - .post(format!("{}{}", worker_url, route)) - .json(typed_req); // Use json() directly with typed request - - // Copy all headers from original request - for (name, value) in copy_request_headers(req) { - // Skip Content-Type and Content-Length as .json() sets them - if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" { - request_builder = request_builder.header(&name, &value); - } - } - - let res = match request_builder.send().await { - Ok(res) => res, - Err(e) => { - error!("Failed to send request to {}: {}", worker_url, e); - - // Decrement load on error 
for CacheAware router - if load_incremented { - if let Router::CacheAware { workers, .. } = self { - if let Ok(workers_guard) = workers.read() { - if let Some(worker) = - workers_guard.iter().find(|w| w.url() == worker_url) - { - worker.decrement_load(); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(worker.load() as f64); - } - } - } - } - - return HttpResponse::InternalServerError().body(format!("Request failed: {}", e)); - } - }; - - let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - - if !is_stream { - // For non-streaming requests, get response first - let response = match res.bytes().await { - Ok(body) => HttpResponse::build(status).body(body.to_vec()), - Err(e) => { - let error_msg = format!("Failed to get response body: {}", e); - HttpResponse::InternalServerError().body(error_msg) - } - }; - - // Decrement load counter for non-streaming CacheAware requests - if load_incremented && !is_stream { - if let Router::CacheAware { workers, .. } = self { - if let Ok(workers_guard) = workers.read() { - if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) { - worker.decrement_load(); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(worker.load() as f64); - } - } - } - } - - // Record metrics - let duration = start.elapsed(); - histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()) - .record(duration.as_secs_f64()); - counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1); - - response - } else if let Router::CacheAware { workers, .. 
} = self { - // For streaming with CacheAware router, we need to manually decrement when done - let workers = Arc::clone(workers); - let worker_url = worker_url.to_string(); - - HttpResponse::build(status) - .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) - .streaming( - res.bytes_stream() - .map_err(|_| { - actix_web::error::ErrorInternalServerError("Failed to read stream") - }) - .inspect(move |bytes| { - if let Ok(bytes) = bytes { - if bytes - .as_ref() - .windows(12) - .any(|window| window == b"data: [DONE]") - { - if let Ok(workers_guard) = workers.read() { - if let Some(worker) = - workers_guard.iter().find(|w| w.url() == &worker_url) - { - worker.decrement_load(); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(worker.load() as f64); - debug!("Streaming is done!!") - } - } - } - } - }), - ) - } else { - // For non-CacheAware routers, just stream without load tracking - HttpResponse::build(status) - .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) - .streaming(res.bytes_stream().map_err(|_| { - actix_web::error::ErrorInternalServerError("Failed to read stream") - })) - } - } - - pub async fn add_worker(&self, worker_url: &str) -> Result { - let (timeout_secs, interval_secs) = match self { - Router::Random { - timeout_secs, - interval_secs, - .. - } => (*timeout_secs, *interval_secs), - Router::RoundRobin { - timeout_secs, - interval_secs, - .. - } => (*timeout_secs, *interval_secs), - Router::CacheAware { - timeout_secs, - interval_secs, - .. - } => (*timeout_secs, *interval_secs), - Router::PrefillDecode { .. } => { - // For PD mode, we don't support adding workers via this method - return Err("Adding workers to PrefillDecode router not supported via add_worker. 
Use dedicated PD management methods.".to_string()); - } - }; - - let start_time = std::time::Instant::now(); - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(timeout_secs)) - .build() - .map_err(|e| format!("Failed to create HTTP client: {}", e))?; - - loop { - if start_time.elapsed() > Duration::from_secs(timeout_secs) { - error!( - "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", - timeout_secs, worker_url - ); - return Err(format!( - "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", - timeout_secs, worker_url - )); - } - - match client.get(&format!("{}/health", worker_url)).send().await { - Ok(res) => { - if res.status().is_success() { - match self { - Router::RoundRobin { workers, .. } - | Router::Random { workers, .. } - | Router::CacheAware { workers, .. } => { - info!("Worker {} health check passed", worker_url); - let mut workers_guard = workers.write().unwrap(); - if workers_guard.iter().any(|w| w.url() == worker_url) { - return Err(format!("Worker {} already exists", worker_url)); - } - info!("Added worker: {}", worker_url); - let new_worker = - WorkerFactory::create_regular(worker_url.to_string()); - workers_guard.push(new_worker); - gauge!("sgl_router_active_workers").set(workers_guard.len() as f64); - } - Router::PrefillDecode { .. } => { - return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string()); - } - } - - // If cache aware, add worker to tree - if let Router::CacheAware { tree, .. 
} = self { - // Add worker to tree - tree.lock().unwrap().insert("", worker_url); - } - - return Ok(format!("Successfully added worker: {}", worker_url)); - } else { - info!( - "Worker {} health check is pending with status: {}.", - worker_url, - res.status() - ); - // if the url does not have http or https prefix, warn users - if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") - { - warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); - } - - tokio::time::sleep(Duration::from_secs(interval_secs)).await; - continue; - } - } - Err(e) => { - info!( - "Worker {} health check is pending with error: {}", - worker_url, e - ); - - // if the url does not have http or https prefix, warn users - if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") { - warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); - } - - tokio::time::sleep(Duration::from_secs(interval_secs)).await; - continue; - } - } - } - } - - pub fn remove_worker(&self, worker_url: &str) { - match self { - Router::RoundRobin { workers, .. } - | Router::Random { workers, .. } - | Router::CacheAware { workers, .. } => { - let mut workers_guard = workers.write().unwrap(); - if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) { - workers_guard.remove(index); - info!("Removed worker: {}", worker_url); - gauge!("sgl_router_active_workers").set(workers_guard.len() as f64); - } else { - warn!("Worker {} not found, skipping removal", worker_url); - return; - } - } - Router::PrefillDecode { .. } => { - warn!("Removing workers from PrefillDecode router not supported via remove_worker. Use dedicated PD management methods."); - return; - } - } - - // if cache aware, remove the worker from the tree - if let Router::CacheAware { tree, .. 
} = self { - tree.lock().unwrap().remove_tenant(&worker_url); - info!("Removed worker from tree: {}", worker_url); - } - } - - /// Add a worker with PD mode support - pub async fn add_pd_worker( - &self, - worker_url: &str, - pod_type: crate::service_discovery::PodType, - bootstrap_port: Option, - ) -> Result { - match self { - Router::PrefillDecode { pd_router } => match pod_type { - crate::service_discovery::PodType::Prefill => pd_router - .add_prefill_server(worker_url.to_string(), bootstrap_port) - .await - .map_err(|e| e.to_string()), - crate::service_discovery::PodType::Decode => pd_router - .add_decode_server(worker_url.to_string()) - .await - .map_err(|e| e.to_string()), - crate::service_discovery::PodType::Regular => { - Err("Regular pod type not supported in PD mode".to_string()) - } - }, - _ => Err("add_pd_worker only supported in PD mode".to_string()), - } - } - - /// Remove a worker with PD mode support - pub async fn remove_pd_worker( - &self, - worker_url: &str, - pod_type: crate::service_discovery::PodType, - ) -> Result { - match self { - Router::PrefillDecode { pd_router } => match pod_type { - crate::service_discovery::PodType::Prefill => pd_router - .remove_prefill_server(worker_url) - .await - .map_err(|e| e.to_string()), - crate::service_discovery::PodType::Decode => pd_router - .remove_decode_server(worker_url) - .await - .map_err(|e| e.to_string()), - crate::service_discovery::PodType::Regular => { - Err("Regular pod type not supported in PD mode".to_string()) - } - }, - _ => Err("remove_pd_worker only supported in PD mode".to_string()), - } - } - - async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option { - match client.get(&format!("{}/get_load", worker_url)).send().await { - Ok(res) if res.status().is_success() => match res.bytes().await { - Ok(bytes) => match serde_json::from_slice::(&bytes) { - Ok(data) => data - .get("load") - .and_then(|v| v.as_i64()) - .map(|v| v as isize), - Err(e) => { - debug!("Failed 
to parse load response from {}: {}", worker_url, e); - None - } - }, - Err(e) => { - debug!("Failed to read load response from {}: {}", worker_url, e); - None - } - }, - Ok(res) => { - debug!( - "Worker {} returned non-success status: {}", - worker_url, - res.status() - ); - None - } - Err(e) => { - debug!("Failed to get load from {}: {}", worker_url, e); - None - } - } - } - - // PD-specific wrapper methods that delegate to PDRouter - pub async fn route_pd_health_generate( - &self, - _client: &reqwest::Client, - _req: &HttpRequest, - ) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router.health_generate(&pd_router.http_client).await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } - - pub async fn route_pd_generate_typed( - &self, - _client: &reqwest::Client, - req: &HttpRequest, - typed_req: crate::pd_types::GenerateReqInput, - route: &str, - ) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router - .route_generate(&pd_router.http_client, req, typed_req, route) - .await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } - - pub async fn route_pd_chat_typed( - &self, - _client: &reqwest::Client, - req: &HttpRequest, - typed_req: crate::pd_types::ChatReqInput, - route: &str, - ) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router - .route_chat(&pd_router.http_client, req, typed_req, route) - .await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } - - pub async fn get_pd_server_info( - &self, - _client: &reqwest::Client, - _req: &HttpRequest, - ) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router.get_server_info(&pd_router.http_client).await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } - - pub async fn get_pd_models( - &self, - _client: &reqwest::Client, - req: 
&HttpRequest, - ) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router.get_models(&pd_router.http_client, req).await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } - - pub async fn route_pd_flush_cache(&self, _client: &reqwest::Client) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router.flush_cache(&pd_router.http_client).await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } - - pub async fn get_pd_model_info( - &self, - _client: &reqwest::Client, - req: &HttpRequest, - ) -> HttpResponse { - match self { - Router::PrefillDecode { pd_router } => { - pd_router.get_model_info(&pd_router.http_client, req).await - } - _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::service_discovery::PodType; - - fn create_test_regular_router() -> Router { - let workers = vec![ - WorkerFactory::create_regular("http://worker1:8080".to_string()), - WorkerFactory::create_regular("http://worker2:8080".to_string()), - ]; - Router::Random { - workers: Arc::new(RwLock::new(workers)), - timeout_secs: 5, - interval_secs: 1, - _health_checker: None, - } - } - - #[test] - fn test_router_get_worker_urls_regular() { - let router = create_test_regular_router(); - let urls = router.get_worker_urls(); - - assert_eq!(urls.len(), 2); - assert!(urls.contains(&"http://worker1:8080".to_string())); - assert!(urls.contains(&"http://worker2:8080".to_string())); - } - - // #[test] - // fn test_router_get_worker_urls_pd_mode() { - // // For PD mode, get_worker_urls returns empty list - // // Note: PDRouter::new requires health checks which fail in tests - // // This test would need a mock server or different test setup - // } - - #[tokio::test] - async fn test_add_pd_worker_with_regular_router() { - let router = create_test_regular_router(); - - let result = 
router - .add_pd_worker("http://new-worker:8080", PodType::Prefill, Some(8081)) - .await; - - assert!(result.is_err()); - assert!(result - .unwrap_err() - .contains("add_pd_worker only supported in PD mode")); - } - - #[tokio::test] - async fn test_remove_pd_worker_with_regular_router() { - let router = create_test_regular_router(); - - let result = router - .remove_pd_worker("http://worker:8080", PodType::Decode) - .await; - - assert!(result.is_err()); - assert!(result - .unwrap_err() - .contains("remove_pd_worker only supported in PD mode")); - } - - // #[tokio::test] - // async fn test_add_pd_worker_with_pd_router_regular_type() { - // // Note: PDRouter::new requires health checks which fail in tests - // // This test would need a mock server or different test setup - // } - - // #[tokio::test] - // async fn test_remove_pd_worker_with_pd_router_regular_type() { - // // Note: PDRouter::new requires health checks which fail in tests - // // This test would need a mock server or different test setup - // } - - #[test] - fn test_select_first_worker_regular() { - let router = create_test_regular_router(); - let result = router.select_first_worker(); - - assert!(result.is_ok()); - assert_eq!(result.unwrap(), "http://worker1:8080"); - } - - // #[test] - // fn test_select_first_worker_pd_mode() { - // // Note: PDRouter::new requires health checks which fail in tests - // // This test would need a mock server or different test setup - // } - - #[test] - fn test_wait_for_healthy_workers_empty_list() { - let result = Router::wait_for_healthy_workers(&[], 1, 1); - assert!(result.is_ok()); - } - - #[test] - fn test_wait_for_healthy_workers_invalid_urls() { - // This test will timeout quickly since the URLs are invalid - let result = - Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("Timeout")); - } -} diff --git a/sgl-router/src/routers/factory.rs 
b/sgl-router/src/routers/factory.rs new file mode 100644 index 000000000000..201240121645 --- /dev/null +++ b/sgl-router/src/routers/factory.rs @@ -0,0 +1,66 @@ +//! Factory for creating router instances + +use super::{pd_router::PDRouter, router::Router, RouterTrait}; +use crate::config::{PolicyConfig, RouterConfig, RoutingMode}; +use crate::policies::PolicyFactory; + +/// Factory for creating router instances based on configuration +pub struct RouterFactory; + +impl RouterFactory { + /// Create a router instance from configuration + pub fn create_router(config: &RouterConfig) -> Result, String> { + match &config.mode { + RoutingMode::Regular { worker_urls } => { + Self::create_regular_router(worker_urls, &config.policy, config) + } + RoutingMode::PrefillDecode { + prefill_urls, + decode_urls, + } => Self::create_pd_router(prefill_urls, decode_urls, &config.policy, config), + } + } + + /// Create a regular router with injected policy + fn create_regular_router( + worker_urls: &[String], + policy_config: &PolicyConfig, + router_config: &RouterConfig, + ) -> Result, String> { + // Create policy + let policy = PolicyFactory::create_from_config(policy_config); + + // Create regular router with injected policy + let router = Router::new( + worker_urls.to_vec(), + policy, + router_config.worker_startup_timeout_secs, + router_config.worker_startup_check_interval_secs, + )?; + + Ok(Box::new(router)) + } + + /// Create a PD router with injected policy + fn create_pd_router( + prefill_urls: &[(String, Option)], + decode_urls: &[String], + policy_config: &PolicyConfig, + router_config: &RouterConfig, + ) -> Result, String> { + // Create policy directly from PolicyConfig + // All policies now support PD mode through the select_worker_pair method + let policy = PolicyFactory::create_from_config(policy_config); + + // Create PD router with injected policy + let router = PDRouter::new( + prefill_urls.to_vec(), + decode_urls.to_vec(), + policy, + 
router_config.worker_startup_timeout_secs, + router_config.worker_startup_check_interval_secs, + )?; + + Ok(Box::new(router)) + } +} diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs new file mode 100644 index 000000000000..ffb6d93c7d24 --- /dev/null +++ b/sgl-router/src/routers/mod.rs @@ -0,0 +1,101 @@ +//! Router implementations + +use actix_web::{HttpRequest, HttpResponse}; +use async_trait::async_trait; +use reqwest::Client; +use std::fmt::Debug; + +pub mod factory; +pub mod pd_router; +pub mod pd_types; +pub mod request_adapter; +pub mod router; + +pub use factory::RouterFactory; + +/// Worker management trait for administrative operations +/// +/// This trait is separate from RouterTrait to allow Send futures +/// for use in service discovery and other background tasks +#[async_trait] +pub trait WorkerManagement: Send + Sync { + /// Add a worker to the router + async fn add_worker(&self, worker_url: &str) -> Result; + + /// Remove a worker from the router + fn remove_worker(&self, worker_url: &str); + + /// Get all worker URLs + fn get_worker_urls(&self) -> Vec; +} + +/// Core trait for all router implementations +/// +/// This trait provides a unified interface for routing requests, +/// regardless of whether it's a regular router or PD router. 
+#[async_trait(?Send)] +pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { + /// Get a reference to self as Any for downcasting + fn as_any(&self) -> &dyn std::any::Any; + /// Route a health check request + async fn health(&self, client: &Client, req: &HttpRequest) -> HttpResponse; + + /// Route a health generate request + async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse; + + /// Get server information + async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse; + + /// Get available models + async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse; + + /// Get model information + async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse; + + /// Route a generate request + async fn route_generate( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse; + + /// Route a chat completion request + async fn route_chat( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse; + + /// Route a completion request + async fn route_completion( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse; + + /// Flush cache on all workers + async fn flush_cache(&self, client: &Client) -> HttpResponse; + + /// Get worker loads (for monitoring) + async fn get_worker_loads(&self, client: &Client) -> HttpResponse; + + /// Get router type name + fn router_type(&self) -> &'static str; + + /// Check if this is a PD router + fn is_pd_mode(&self) -> bool { + self.router_type() == "pd" + } + + /// Server liveness check - is the server process running + fn liveness(&self) -> HttpResponse { + // Simple liveness check - if we can respond, we're alive + HttpResponse::Ok().body("OK") + } + + /// Server readiness check - is the server ready to handle requests + fn readiness(&self) -> HttpResponse; +} diff --git a/sgl-router/src/pd_router.rs 
b/sgl-router/src/routers/pd_router.rs similarity index 67% rename from sgl-router/src/pd_router.rs rename to sgl-router/src/routers/pd_router.rs index a1f04c7d29db..2ac8f9027762 100644 --- a/sgl-router/src/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -1,10 +1,11 @@ // PD (Prefill-Decode) Router Implementation // This module handles routing for disaggregated prefill-decode systems +use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError}; +use super::request_adapter::ToPdRequest; use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; -use crate::pd_types::{ - api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError, PDSelectionPolicy, -}; +use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; +use crate::policies::LoadBalancingPolicy; use crate::tree::Tree; use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; @@ -17,13 +18,11 @@ use std::time::{Duration, Instant}; use tracing::{debug, error, info, warn}; use uuid::Uuid; -// Removed over-engineered ProxyResponse - using HttpResponse directly - #[derive(Debug)] pub struct PDRouter { pub prefill_workers: Arc>>>, pub decode_workers: Arc>>>, - pub selection_policy: PDSelectionPolicy, + pub policy: Arc, pub prefill_tree: Option>>, pub timeout_secs: u64, pub interval_secs: u64, @@ -42,7 +41,7 @@ impl PDRouter { bootstrap_port: Option, ) -> Result { // Wait for the new server to be healthy - crate::router::Router::wait_for_healthy_workers( + crate::routers::router::Router::wait_for_healthy_workers( &[url.clone()], self.timeout_secs, self.interval_secs, @@ -78,7 +77,7 @@ impl PDRouter { pub async fn add_decode_server(&self, url: String) -> Result { // Wait for the new server to be healthy - crate::router::Router::wait_for_healthy_workers( + crate::routers::router::Router::wait_for_healthy_workers( &[url.clone()], self.timeout_secs, self.interval_secs, @@ -103,9 +102,6 
@@ impl PDRouter { workers.push(worker); - // Initialize load tracking - // Worker tracks its own load internally - info!("Added decode server: {}", url); Ok(format!("Successfully added decode server: {}", url)) } @@ -128,9 +124,6 @@ impl PDRouter { }); } - // Remove from load tracking - // Worker load tracking is internal - // Remove from cache tree if using cache-aware policy if let Some(ref tree) = self.prefill_tree { // Note: Tree doesn't have a remove method, so we rebuild it @@ -170,7 +163,7 @@ impl PDRouter { pub fn new( prefill_urls: Vec<(String, Option)>, decode_urls: Vec, - selection_policy: PDSelectionPolicy, + policy: Arc, timeout_secs: u64, interval_secs: u64, ) -> Result { @@ -185,25 +178,38 @@ impl PDRouter { .map(WorkerFactory::create_decode) .collect(); - // Wait for PD workers to be healthy + // Wait for PD workers to be healthy (skip if empty - for service discovery mode) let all_urls: Vec = prefill_workers .iter() .chain(decode_workers.iter()) .map(|worker| worker.url().to_string()) .collect(); - crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?; + if !all_urls.is_empty() { + crate::routers::router::Router::wait_for_healthy_workers( + &all_urls, + timeout_secs, + interval_secs, + )?; + } // Initialize cache-aware components if needed - let prefill_tree = match &selection_policy { - PDSelectionPolicy::CacheAware { .. 
} => { - let tree = Arc::new(Mutex::new(Tree::new())); - // Initialize tree with prefill workers - for worker in &prefill_workers { - tree.lock().unwrap().insert("", worker.url()); - } - Some(tree) + let prefill_tree = if policy.name() == "cache_aware" { + // Initialize the policy's internal tree with prefill workers + if let Some(cache_policy) = policy + .as_any() + .downcast_ref::() + { + cache_policy.init_workers(&prefill_workers); + } + + let tree = Arc::new(Mutex::new(Tree::new())); + // Initialize tree with prefill workers + for worker in &prefill_workers { + tree.lock().unwrap().insert("", worker.url()); } - _ => None, + Some(tree) + } else { + None }; // Set up background load monitoring for power-of-two selection @@ -216,10 +222,11 @@ impl PDRouter { .build() .map_err(|e| format!("Failed to create HTTP client: {}", e))?; - let load_monitor_handle = if matches!(selection_policy, PDSelectionPolicy::PowerOfTwo) { + let load_monitor_handle = if policy.name() == "power_of_two" { let monitor_urls = all_urls.clone(); let monitor_interval = interval_secs; let monitor_client = http_client.clone(); + let policy_clone = Arc::clone(&policy); Some(Arc::new(tokio::spawn(async move { Self::monitor_worker_loads_with_client( @@ -227,6 +234,7 @@ impl PDRouter { tx, monitor_interval, monitor_client, + policy_clone, ) .await; }))) @@ -246,7 +254,7 @@ impl PDRouter { Ok(PDRouter { prefill_workers, decode_workers, - selection_policy, + policy, prefill_tree, timeout_secs, interval_secs, @@ -270,15 +278,21 @@ impl PDRouter { let _request_id = Uuid::new_v4(); // Get stream flag and return_logprob flag before moving the request - let is_stream = typed_req.is_stream(); + let is_stream = typed_req.stream; let return_logprob = typed_req .other .get("return_logprob") .and_then(|v| v.as_bool()) .unwrap_or(false); + // Extract text for cache-aware routing from the typed request + let request_text = typed_req.text.as_ref().and_then(|t| match t { + super::pd_types::InputText::Single(s) => 
Some(s.as_str()), + super::pd_types::InputText::Batch(v) => v.first().map(|s| s.as_str()), + }); + // Select servers - let (prefill, decode) = match self.select_pd_pair(client).await { + let (prefill, decode) = match self.select_pd_pair(client, request_text).await { Ok(pair) => pair, Err(e) => { error!("Failed to select PD pair: {}", e); @@ -339,15 +353,24 @@ impl PDRouter { let start = Instant::now(); // Get stream flag and return_logprob flag before moving the request - let is_stream = typed_req.is_stream(); + let is_stream = typed_req.stream; let return_logprob = typed_req .other .get("return_logprob") .and_then(|v| v.as_bool()) .unwrap_or(false); + // Extract text for cache-aware routing from chat messages + let request_text = typed_req + .other + .get("messages") + .and_then(|messages| messages.as_array()) + .and_then(|arr| arr.first()) + .and_then(|msg| msg.get("content")) + .and_then(|content| content.as_str()); + // Select servers - let (prefill, decode) = match self.select_pd_pair(client).await { + let (prefill, decode) = match self.select_pd_pair(client, request_text).await { Ok(pair) => pair, Err(e) => { error!("Failed to select PD pair: {}", e); @@ -424,7 +447,7 @@ impl PDRouter { .json(&json_request); // Copy headers from original request - for (name, value) in crate::router::copy_request_headers(req) { + for (name, value) in crate::routers::router::copy_request_headers(req) { if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" { prefill_request = prefill_request.header(&name, &value); decode_request = decode_request.header(&name, &value); @@ -620,104 +643,47 @@ impl PDRouter { async fn select_pd_pair( &self, _client: &reqwest::Client, + request_text: Option<&str>, ) -> Result<(Box, Box), String> { - // Check we have workers - if self + // Get read locks for both worker lists + let prefill_workers = self .prefill_workers .read() - .map_err(|e| format!("Failed to acquire prefill workers lock: {}", e))? 
- .is_empty() - { - return Err("No prefill workers available. Please check if prefill servers are configured and healthy.".to_string()); - } - if self + .map_err(|e| format!("Failed to acquire prefill workers lock: {}", e))?; + let decode_workers = self .decode_workers .read() - .map_err(|e| format!("Failed to acquire decode workers lock: {}", e))? - .is_empty() - { + .map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?; + + // Check we have workers + if prefill_workers.is_empty() { + return Err("No prefill workers available. Please check if prefill servers are configured and healthy.".to_string()); + } + if decode_workers.is_empty() { return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string()); } - match &self.selection_policy { - PDSelectionPolicy::Random => self.select_random(), - PDSelectionPolicy::PowerOfTwo => self.select_power_of_two().await, - PDSelectionPolicy::CacheAware { .. } => { - // TODO: Implement cache-aware selection - self.select_power_of_two().await + // Use the policy to select worker pair + match self + .policy + .select_worker_pair(&prefill_workers, &decode_workers, request_text) + { + Some((prefill_idx, decode_idx)) => { + let prefill = prefill_workers[prefill_idx].clone_worker(); + let decode = decode_workers[decode_idx].clone_worker(); + Ok((prefill, decode)) } + None => Err("Failed to select worker pair".to_string()), } } - fn select_random(&self) -> Result<(Box, Box), String> { - let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?; - let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?; - - let prefill = prefill_list[rand::random::() % prefill_list.len()].clone_worker(); - let decode = decode_list[rand::random::() % decode_list.len()].clone_worker(); - - Ok((prefill, decode)) - } - - async fn select_power_of_two(&self) -> Result<(Box, Box), String> { - let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?; 
- let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?; - - let (p1_idx, p2_idx) = get_two_random_indices(prefill_list.len()); - let (d1_idx, d2_idx) = get_two_random_indices(decode_list.len()); - - let loads = self.worker_loads.borrow(); - - let p1_load = loads - .get(prefill_list[p1_idx].url()) - .copied() - .unwrap_or(isize::MAX); - let p2_load = loads - .get(prefill_list[p2_idx].url()) - .copied() - .unwrap_or(isize::MAX); - let d1_load = loads - .get(decode_list[d1_idx].url()) - .copied() - .unwrap_or(isize::MAX); - let d2_load = loads - .get(decode_list[d2_idx].url()) - .copied() - .unwrap_or(isize::MAX); - - info!( - "Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}", - prefill_list[p1_idx].url(), - p1_load, - prefill_list[p2_idx].url(), - p2_load, - decode_list[d1_idx].url(), - d1_load, - decode_list[d2_idx].url(), - d2_load - ); - - let selected_prefill = if p1_load <= p2_load { - prefill_list[p1_idx].clone_worker() - } else { - prefill_list[p2_idx].clone_worker() - }; - - let selected_decode = if d1_load <= d2_load { - decode_list[d1_idx].clone_worker() - } else { - decode_list[d2_idx].clone_worker() - }; - - Ok((selected_prefill, selected_decode)) - } - // Background task to monitor worker loads with shared client async fn monitor_worker_loads_with_client( worker_urls: Vec, tx: tokio::sync::watch::Sender>, interval_secs: u64, client: reqwest::Client, + policy: Arc, ) { loop { let mut loads = HashMap::new(); @@ -742,6 +708,9 @@ impl PDRouter { debug!("Worker loads updated: {:?}", loads); + // Update the policy with current loads + policy.update_loads(&loads); + // Check if receiver is still active if tx.send(loads).is_err() { info!("Load monitor receiver dropped, shutting down monitor task"); @@ -792,18 +761,6 @@ impl PDRouter { } // Helper functions -fn get_two_random_indices(len: usize) -> (usize, usize) { - if len == 1 { - (0, 0) - } else { - let idx1 = rand::random::() % len; - let mut idx2 = 
rand::random::() % len; - while idx2 == idx1 { - idx2 = rand::random::() % len; - } - (idx1, idx2) - } -} async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option { match client.get(format!("{}/get_load", worker_url)).send().await { @@ -841,61 +798,72 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option HttpResponse { - let mut all_healthy = true; - let mut unhealthy_servers = Vec::new(); + // Test model generation capability by selecting a random pair and testing them + // Note: This endpoint actually causes the model to generate tokens, so we only test one pair - // Collect all worker URLs with their types - let mut worker_infos = Vec::new(); + // Select a random worker pair using the policy + let (prefill, decode) = match self.select_pd_pair(client, None).await { + Ok(pair) => pair, + Err(e) => { + return HttpResponse::ServiceUnavailable() + .body(format!("No healthy worker pair available: {}", e)); + } + }; - for worker in self.prefill_workers.read().unwrap().iter() { - worker_infos.push((worker.url().to_string(), "prefill")); - } + // Test prefill server's health_generate + let prefill_url = format!("{}/health_generate", prefill.url()); + let prefill_result = client.get(&prefill_url).send().await; - for worker in self.decode_workers.read().unwrap().iter() { - worker_infos.push((worker.url().to_string(), "decode")); - } + // Test decode server's health_generate + let decode_url = format!("{}/health_generate", decode.url()); + let decode_result = client.get(&decode_url).send().await; - // Create tasks with URL tracking - let tasks: Vec<_> = worker_infos - .iter() - .map(|(url, _)| { - let health_url = format!("{}/health_generate", url); - client.get(&health_url).send() - }) - .collect(); + // Check results + let mut errors = Vec::new(); - let results = futures_util::future::join_all(tasks).await; + match prefill_result { + Ok(res) if res.status().is_success() => { + debug!( + "Health generate passed for prefill 
server: {}", + prefill.url() + ); + } + Ok(res) => { + errors.push(format!( + "Prefill {} returned status {}", + prefill.url(), + res.status() + )); + } + Err(e) => { + errors.push(format!("Prefill {} error: {}", prefill.url(), e)); + } + } - for ((url, worker_type), result) in worker_infos.iter().zip(results.into_iter()) { - match result { - Ok(res) if res.status().is_success() => { - debug!("Health check passed for {} server: {}", worker_type, url); - } - Ok(res) => { - all_healthy = false; - let msg = format!( - "{} server {} returned status {}", - worker_type, - url, - res.status() - ); - error!("{}", msg); - unhealthy_servers.push(msg); - } - Err(e) => { - all_healthy = false; - let msg = format!("{} server {} error: {}", worker_type, url, e); - error!("{}", msg); - unhealthy_servers.push(msg); - } + match decode_result { + Ok(res) if res.status().is_success() => { + debug!("Health generate passed for decode server: {}", decode.url()); + } + Ok(res) => { + errors.push(format!( + "Decode {} returned status {}", + decode.url(), + res.status() + )); + } + Err(e) => { + errors.push(format!("Decode {} error: {}", decode.url(), e)); } } - if all_healthy { - HttpResponse::Ok().body("Health check passed on all servers") + if errors.is_empty() { + HttpResponse::Ok().body(format!( + "Health generate passed on selected pair: prefill={}, decode={}", + prefill.url(), + decode.url() + )) } else { - HttpResponse::ServiceUnavailable() - .body(format!("Health check failed: {:?}", unhealthy_servers)) + HttpResponse::ServiceUnavailable().body(format!("Health generate failed: {:?}", errors)) } } @@ -955,7 +923,7 @@ impl PDRouter { if let Some(worker_url) = first_worker_url { // Send request directly without going through Router let mut request_builder = client.get(format!("{}/v1/models", worker_url)); - for (name, value) in crate::router::copy_request_headers(req) { + for (name, value) in crate::routers::router::copy_request_headers(req) { if name.to_lowercase() != "content-type" 
&& name.to_lowercase() != "content-length" { request_builder = request_builder.header(name, value); @@ -1035,7 +1003,7 @@ impl PDRouter { if let Some(worker_url) = first_worker_url { let mut request_builder = client.get(format!("{}/get_model_info", worker_url)); - for (name, value) in crate::router::copy_request_headers(req) { + for (name, value) in crate::routers::router::copy_request_headers(req) { if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" { request_builder = request_builder.header(name, value); @@ -1102,3 +1070,324 @@ impl PDRouter { } } } + +use crate::routers::{RouterTrait, WorkerManagement}; +use async_trait::async_trait; +use reqwest::Client; + +#[async_trait] +impl WorkerManagement for PDRouter { + async fn add_worker(&self, _worker_url: &str) -> Result { + // For PD router, we don't support adding workers via this generic method + Err( + "PD router requires specific add_prefill_server or add_decode_server methods" + .to_string(), + ) + } + + fn remove_worker(&self, worker_url: &str) { + // For PD router, we would need to know if it's a prefill or decode server + // For now, try both + if let Ok(mut workers) = self.prefill_workers.write() { + if let Some(index) = workers.iter().position(|w| w.url() == worker_url) { + workers.remove(index); + info!("Removed prefill worker: {}", worker_url); + return; + } + } + + if let Ok(mut workers) = self.decode_workers.write() { + if let Some(index) = workers.iter().position(|w| w.url() == worker_url) { + workers.remove(index); + info!("Removed decode worker: {}", worker_url); + } + } + } + + fn get_worker_urls(&self) -> Vec { + let mut urls = Vec::new(); + + // Add prefill worker URLs + if let Ok(workers) = self.prefill_workers.read() { + for worker in workers.iter() { + urls.push(worker.url().to_string()); + } + } + + // Add decode worker URLs + if let Ok(workers) = self.decode_workers.read() { + for worker in workers.iter() { + urls.push(worker.url().to_string()); + } + } + + 
urls + } +} + +#[async_trait(?Send)] +impl RouterTrait for PDRouter { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + async fn health(&self, _client: &Client, _req: &HttpRequest) -> HttpResponse { + // This is a server readiness check - checking if we have healthy workers + // Workers handle their own health checks in the background + let mut all_healthy = true; + let mut unhealthy_servers = Vec::new(); + + // Check prefill servers + for worker in self.prefill_workers.read().unwrap().iter() { + if !worker.is_healthy() { + all_healthy = false; + unhealthy_servers.push(format!("Prefill: {}", worker.url())); + } + } + + // Check decode servers + for worker in self.decode_workers.read().unwrap().iter() { + if !worker.is_healthy() { + all_healthy = false; + unhealthy_servers.push(format!("Decode: {}", worker.url())); + } + } + + if all_healthy { + HttpResponse::Ok().body("All servers healthy") + } else { + HttpResponse::ServiceUnavailable() + .body(format!("Unhealthy servers: {:?}", unhealthy_servers)) + } + } + + async fn health_generate(&self, client: &Client, _req: &HttpRequest) -> HttpResponse { + // Use the existing PDRouter health_generate method + PDRouter::health_generate(self, client).await + } + + async fn get_server_info(&self, client: &Client, _req: &HttpRequest) -> HttpResponse { + // Use the existing PDRouter get_server_info method + PDRouter::get_server_info(self, client).await + } + + async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse { + // Get first prefill worker URL to avoid holding lock across await + let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { + workers.first().map(|w| w.url().to_string()) + } else { + return HttpResponse::InternalServerError().body("Failed to access prefill workers"); + }; + + if let Some(worker_url) = first_worker_url { + // Send request directly without going through Router + let mut request_builder = client.get(format!("{}/v1/models", worker_url)); + for (name, 
value) in crate::routers::router::copy_request_headers(req) { + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" + { + request_builder = request_builder.header(name, value); + } + } + match request_builder.send().await { + Ok(res) => { + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to read response body: {}", e)), + } + } + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to send request: {}", e)), + } + } else { + HttpResponse::ServiceUnavailable().body("No prefill servers available") + } + } + + async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse { + // For PD router, get model info from the first prefill server + // Get first prefill worker URL to avoid holding lock across await + let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { + workers.first().map(|w| w.url().to_string()) + } else { + return HttpResponse::InternalServerError().body("Failed to access prefill workers"); + }; + + if let Some(worker_url) = first_worker_url { + let mut request_builder = client.get(format!("{}/get_model_info", worker_url)); + for (name, value) in crate::routers::router::copy_request_headers(req) { + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" + { + request_builder = request_builder.header(name, value); + } + } + match request_builder.send().await { + Ok(res) => { + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to read response 
body: {}", e)), + } + } + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to send request: {}", e)), + } + } else { + HttpResponse::ServiceUnavailable().body("No prefill servers available") + } + } + + async fn route_generate( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse { + match serde_json::from_value::(body.clone()) { + Ok(openai_req) => { + // Convert OpenAI format to PD format + let pd_req = openai_req.to_pd_request(); + PDRouter::route_generate(self, client, req, pd_req, "/generate").await + } + Err(_) => { + // If that fails, try to deserialize directly as PD format (for backwards compatibility) + match serde_json::from_value::(body) { + Ok(pd_req) => { + PDRouter::route_generate(self, client, req, pd_req, "/generate").await + } + Err(e) => { + HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)) + } + } + } + } + } + + async fn route_chat( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse { + match serde_json::from_value::(body.clone()) { + Ok(openai_req) => { + // Convert OpenAI format to PD format + let pd_req = openai_req.to_pd_request(); + PDRouter::route_chat(self, client, req, pd_req, "/v1/chat/completions").await + } + Err(_) => { + // If that fails, try to deserialize directly as PD format (for backwards compatibility) + match serde_json::from_value::(body) { + Ok(pd_req) => { + PDRouter::route_chat(self, client, req, pd_req, "/v1/chat/completions") + .await + } + Err(e) => { + HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)) + } + } + } + } + } + + async fn route_completion( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse { + match serde_json::from_value::(body.clone()) { + Ok(openai_req) => { + // Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput) + let pd_req = openai_req.to_pd_request(); + 
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await + } + Err(_) => { + // If that fails, try to deserialize directly as PD format (for backwards compatibility) + match serde_json::from_value::(body) { + Ok(pd_req) => { + PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await + } + Err(e) => { + HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)) + } + } + } + } + } + + async fn flush_cache(&self, client: &Client) -> HttpResponse { + // Use the existing PDRouter flush_cache method + PDRouter::flush_cache(self, client).await + } + + async fn get_worker_loads(&self, client: &Client) -> HttpResponse { + // Use the existing PDRouter get_loads method + PDRouter::get_loads(self, client).await + } + + fn router_type(&self) -> &'static str { + "pd" + } + + fn readiness(&self) -> HttpResponse { + // PD router is ready if it has at least one healthy prefill AND one healthy decode worker + let healthy_prefill_count = self + .prefill_workers + .read() + .unwrap() + .iter() + .filter(|w| w.is_healthy()) + .count(); + + let healthy_decode_count = self + .decode_workers + .read() + .unwrap() + .iter() + .filter(|w| w.is_healthy()) + .count(); + + let total_prefill = self.prefill_workers.read().unwrap().len(); + let total_decode = self.decode_workers.read().unwrap().len(); + + if healthy_prefill_count > 0 && healthy_decode_count > 0 { + HttpResponse::Ok().json(serde_json::json!({ + "status": "ready", + "prefill": { + "healthy": healthy_prefill_count, + "total": total_prefill + }, + "decode": { + "healthy": healthy_decode_count, + "total": total_decode + } + })) + } else { + let mut reasons = Vec::new(); + if healthy_prefill_count == 0 { + reasons.push("no healthy prefill workers"); + } + if healthy_decode_count == 0 { + reasons.push("no healthy decode workers"); + } + + HttpResponse::ServiceUnavailable().json(serde_json::json!({ + "status": "not_ready", + "reason": reasons.join(", "), + "prefill": { + "healthy": 
healthy_prefill_count, + "total": total_prefill + }, + "decode": { + "healthy": healthy_decode_count, + "total": total_decode + } + })) + } + } +} diff --git a/sgl-router/src/pd_types.rs b/sgl-router/src/routers/pd_types.rs similarity index 100% rename from sgl-router/src/pd_types.rs rename to sgl-router/src/routers/pd_types.rs diff --git a/sgl-router/src/request_adapter.rs b/sgl-router/src/routers/request_adapter.rs similarity index 99% rename from sgl-router/src/request_adapter.rs rename to sgl-router/src/routers/request_adapter.rs index 4396cc4d7eec..f5611bbe492b 100644 --- a/sgl-router/src/request_adapter.rs +++ b/sgl-router/src/routers/request_adapter.rs @@ -1,9 +1,9 @@ // Request adapter to bridge OpenAI API types with PD routing requirements +use super::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch}; use crate::openai_api_types::{ ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray, }; -use crate::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch}; use serde_json::Value; /// Adapter trait to convert OpenAI requests to PD-compatible requests diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs new file mode 100644 index 000000000000..ef44348eca20 --- /dev/null +++ b/sgl-router/src/routers/router.rs @@ -0,0 +1,1055 @@ +use crate::core::{HealthChecker, Worker, WorkerFactory}; +use crate::policies::LoadBalancingPolicy; +use ::metrics::{counter, gauge, histogram}; +use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; +use actix_web::{HttpRequest, HttpResponse}; +use futures_util::{StreamExt, TryStreamExt}; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use std::thread; +use std::time::{Duration, Instant}; +use tracing::{debug, error, info, warn}; + +pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> { + req.headers() + .iter() + .filter_map(|(name, value)| { + value + .to_str() + .ok() + .map(|v| (name.to_string(), 
v.to_string())) + }) + .collect() +} + +/// Regular router that uses injected load balancing policies +#[derive(Debug)] +pub struct Router { + workers: Arc>>>, + policy: Arc, + timeout_secs: u64, + interval_secs: u64, + _worker_loads: Arc>>, + _load_monitor_handle: Option>>, + _health_checker: Option, +} + +impl Router { + /// Create a new router with injected policy + pub fn new( + worker_urls: Vec, + policy: Arc, + timeout_secs: u64, + interval_secs: u64, + ) -> Result { + // Update active workers gauge + gauge!("sgl_router_active_workers").set(worker_urls.len() as f64); + + // Wait for workers to be healthy (skip if empty - for service discovery mode) + if !worker_urls.is_empty() { + Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?; + } + + // Create Worker trait objects from URLs + let workers: Vec> = worker_urls + .iter() + .map(|url| WorkerFactory::create_regular(url.clone())) + .collect(); + + // Initialize policy with workers if needed (e.g., for cache-aware) + if let Some(cache_aware) = policy + .as_any() + .downcast_ref::() + { + cache_aware.init_workers(&workers); + } + + let workers = Arc::new(RwLock::new(workers)); + let health_checker = crate::core::start_health_checker(Arc::clone(&workers), interval_secs); + + // Setup load monitoring for PowerOfTwo policy + let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); + let worker_loads = Arc::new(rx); + + let load_monitor_handle = if policy.name() == "power_of_two" { + let monitor_urls = worker_urls.clone(); + let monitor_interval = interval_secs; + let policy_clone = Arc::clone(&policy); + + Some(Arc::new(tokio::spawn(async move { + Self::monitor_worker_loads(monitor_urls, tx, monitor_interval, policy_clone).await; + }))) + } else { + None + }; + + Ok(Router { + workers, + policy, + timeout_secs, + interval_secs, + _worker_loads: worker_loads, + _load_monitor_handle: load_monitor_handle, + _health_checker: Some(health_checker), + }) + } + + /// Get the current list of 
worker URLs + pub fn get_worker_urls(&self) -> Vec { + self.workers + .read() + .unwrap() + .iter() + .map(|w| w.url().to_string()) + .collect() + } + + pub fn wait_for_healthy_workers( + worker_urls: &[String], + timeout_secs: u64, + interval_secs: u64, + ) -> Result<(), String> { + let start_time = std::time::Instant::now(); + let sync_client = reqwest::blocking::Client::builder() + .timeout(Duration::from_secs(timeout_secs)) + .build() + .map_err(|e| format!("Failed to create HTTP client: {}", e))?; + + loop { + if start_time.elapsed() > Duration::from_secs(timeout_secs) { + error!( + "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_urls + ); + return Err(format!( + "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_urls + )); + } + + let mut all_healthy = true; + let mut unhealthy_workers = Vec::new(); + + for url in worker_urls { + match sync_client.get(&format!("{}/health", url)).send() { + Ok(res) => { + if !res.status().is_success() { + let msg = format!( + "Worker heatlh check is pending with status {}", + res.status() + ); + info!("{}", msg); + all_healthy = false; + unhealthy_workers.push((url, msg)); + } + } + Err(_) => { + let msg = format!("Worker is not ready yet"); + info!("{}", msg); + all_healthy = false; + unhealthy_workers.push((url, msg)); + } + } + } + + if all_healthy { + info!("All workers are healthy"); + return Ok(()); + } else { + info!("Initializing workers:"); + for (url, reason) in &unhealthy_workers { + info!(" {} - {}", url, reason); + } + thread::sleep(Duration::from_secs(interval_secs)); + } + } + } + + fn select_first_worker(&self) -> Result { + let workers_guard = 
self.workers.read().unwrap(); + if workers_guard.is_empty() { + Err("No workers are available".to_string()) + } else { + Ok(workers_guard[0].url().to_string()) + } + } + + pub async fn send_request( + &self, + client: &reqwest::Client, + worker_url: &str, + route: &str, + req: &HttpRequest, + ) -> HttpResponse { + let start = Instant::now(); + let mut request_builder = client.get(format!("{}{}", worker_url, route)); + + // Copy all headers from original request except for /health because it does not need authorization + if route != "/health" { + for (name, value) in copy_request_headers(req) { + // Skip Content-Type and Content-Length as .json() sets them + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" + { + request_builder = request_builder.header(name, value); + } + } + } + + let response = match request_builder.send().await { + Ok(res) => { + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + + match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to read response body: {}", e)), + } + } + Err(e) => HttpResponse::InternalServerError().body(format!( + "Failed to send request to worker {}: {}", + worker_url, e + )), + }; + + // Record request metrics + if route != "/health" { + let duration = start.elapsed(); + counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1); + histogram!("sgl_router_request_duration_seconds", "route" => route.to_string()) + .record(duration.as_secs_f64()); + + if !response.status().is_success() { + counter!("sgl_router_request_errors_total", "route" => route.to_string()) + .increment(1); + } + } + response + } + + pub async fn route_to_first( + &self, + client: &reqwest::Client, + route: &str, + req: &HttpRequest, + ) -> HttpResponse { + const MAX_REQUEST_RETRIES: u32 = 3; + const 
MAX_TOTAL_RETRIES: u32 = 6; + let mut total_retries = 0; + + while total_retries < MAX_TOTAL_RETRIES { + match self.select_first_worker() { + Ok(worker_url) => { + let mut request_retries = 0; + + // Try the same worker multiple times + while request_retries < MAX_REQUEST_RETRIES { + if total_retries >= 1 { + info!("Retrying request after {} failed attempts", total_retries); + } + + let response = self.send_request(client, &worker_url, route, req).await; + + if response.status().is_success() { + return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + return response; + } + } + + warn!( + "Request to {} failed (attempt {}/{})", + worker_url, + request_retries + 1, + MAX_REQUEST_RETRIES + ); + + request_retries += 1; + total_retries += 1; + + if request_retries == MAX_REQUEST_RETRIES { + warn!("Removing failed worker: {}", worker_url); + self.remove_worker(&worker_url); + break; + } + } + } + Err(e) => return HttpResponse::InternalServerError().body(e), + } + } + + HttpResponse::InternalServerError().body("All retry attempts failed") + } + + pub async fn route_to_all( + &self, + client: &reqwest::Client, + route: &str, + req: &HttpRequest, + ) -> HttpResponse { + // Get all worker URLs + let worker_urls = self.get_worker_urls(); + + // Send requests to all workers concurrently + let mut tasks = Vec::new(); + for worker_url in &worker_urls { + let mut request_builder = client.post(format!("{}{}", worker_url, route)); + + // Copy headers from original request + for (name, value) in copy_request_headers(req) { + request_builder = request_builder.header(name, value); + } + + tasks.push(request_builder.send()); + } + + // Wait for all responses + let results = futures_util::future::join_all(tasks).await; + + // Check if all succeeded + let all_success = results.iter().all(|r| { + 
r.as_ref() + .map(|res| res.status().is_success()) + .unwrap_or(false) + }); + + if all_success { + HttpResponse::Ok().body("Operation completed on all servers") + } else { + HttpResponse::InternalServerError().body("Operation failed on one or more servers") + } + } + + pub async fn get_all_loads( + &self, + client: &reqwest::Client, + _req: &HttpRequest, + ) -> HttpResponse { + let urls = self.get_worker_urls(); + let prefill_urls: Vec = Vec::new(); + let decode_urls = urls; + + // Collect loads from all servers + let mut prefill_loads = Vec::new(); + let mut decode_loads = Vec::new(); + + // Get prefill loads + for url in &prefill_urls { + let load = self.get_worker_load(client, url).await.unwrap_or(-1); + prefill_loads.push(serde_json::json!({ + "engine": format!("(Prefill@{})", url), + "load": load as i64 + })); + } + + // Get decode loads + for url in &decode_urls { + let load = self.get_worker_load(client, url).await.unwrap_or(-1); + decode_loads.push(serde_json::json!({ + "engine": format!("(Decode@{})", url), + "load": load as i64 + })); + } + + HttpResponse::Ok().json(serde_json::json!({ + "prefill": prefill_loads, + "decode": decode_loads + })) + } + + // New method to route typed requests directly + pub async fn route_typed_request< + T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone, + >( + &self, + client: &reqwest::Client, + req: &HttpRequest, + typed_req: &T, + route: &str, + ) -> HttpResponse { + // Handle retries like the original implementation + let start = Instant::now(); + const MAX_REQUEST_RETRIES: u32 = 3; + const MAX_TOTAL_RETRIES: u32 = 6; + let mut total_retries = 0; + + while total_retries < MAX_TOTAL_RETRIES { + // Extract routing text directly from typed request + let text = typed_req.extract_text_for_routing(); + let is_stream = typed_req.is_stream(); + + // Select worker based on text + let worker_url = self.select_generate_worker_from_text(&text); + let mut request_retries = 0; + + // Try the same worker 
multiple times + while request_retries < MAX_REQUEST_RETRIES { + if total_retries >= 1 { + info!("Retrying request after {} failed attempts", total_retries); + counter!("sgl_router_retries_total", "route" => route.to_string()).increment(1); + } + + // Increment load before request if using RAII load tracking + let load_incremented = if self.policy.name() == "cache_aware" { + let workers_guard = self.workers.read().unwrap(); + if let Some(worker) = workers_guard.iter().find(|w| w.url() == &worker_url) { + worker.increment_load(); + gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) + .set(worker.load() as f64); + true + } else { + false + } + } else { + false + }; + + // Send typed request directly + let response = self + .send_typed_request( + client, + req, + typed_req, + route, + &worker_url, + is_stream, + load_incremented, + ) + .await; + + if response.status().is_success() { + let duration = start.elapsed(); + histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()) + .record(duration.as_secs_f64()); + return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + counter!("sgl_router_request_errors_total", "route" => route.to_string()) + .increment(1); + return response; + } + } + + warn!( + "Generate request to {} failed (attempt {}/{})", + worker_url, + request_retries + 1, + MAX_REQUEST_RETRIES + ); + + request_retries += 1; + total_retries += 1; + + if request_retries == MAX_REQUEST_RETRIES { + warn!("Removing failed worker: {}", worker_url); + self.remove_worker(&worker_url); + break; + } + } + } + + counter!("sgl_router_request_errors_total", "route" => route.to_string()).increment(1); + HttpResponse::InternalServerError().body("All retry attempts failed") + } + + // Helper method to select worker from text using the 
policy + fn select_generate_worker_from_text(&self, text: &str) -> String { + let workers = self.workers.read().unwrap(); + + match self.policy.select_worker(&workers, Some(text)) { + Some(idx) => workers[idx].url().to_string(), + None => { + warn!("No healthy workers available"); + String::new() + } + } + } + + // Send typed request directly without conversion + async fn send_typed_request( + &self, + client: &reqwest::Client, + req: &HttpRequest, + typed_req: &T, + route: &str, + worker_url: &str, + is_stream: bool, + load_incremented: bool, // Whether load was incremented for this request + ) -> HttpResponse { + let start = Instant::now(); + + // Debug: Log what we're sending + if let Ok(json_str) = serde_json::to_string_pretty(typed_req) { + debug!("Sending request to {}: {}", route, json_str); + } + + let mut request_builder = client + .post(format!("{}{}", worker_url, route)) + .json(typed_req); // Use json() directly with typed request + + // Copy all headers from original request + for (name, value) in copy_request_headers(req) { + // Skip Content-Type and Content-Length as .json() sets them + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" { + request_builder = request_builder.header(&name, &value); + } + } + + let res = match request_builder.send().await { + Ok(res) => res, + Err(e) => { + error!("Failed to send request to {}: {}", worker_url, e); + + // Decrement load on error if it was incremented + if load_incremented { + if let Ok(workers_guard) = self.workers.read() { + if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) { + worker.decrement_load(); + gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) + .set(worker.load() as f64); + } + } + } + + return HttpResponse::InternalServerError().body(format!("Request failed: {}", e)); + } + }; + + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + 
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + + if !is_stream { + // For non-streaming requests, get response first + let response = match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => { + let error_msg = format!("Failed to get response body: {}", e); + HttpResponse::InternalServerError().body(error_msg) + } + }; + + // Decrement load counter for non-streaming requests if it was incremented + if load_incremented && !is_stream { + if let Ok(workers_guard) = self.workers.read() { + if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) { + worker.decrement_load(); + gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) + .set(worker.load() as f64); + } + } + } + + // Record metrics + let duration = start.elapsed(); + histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()) + .record(duration.as_secs_f64()); + counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1); + + response + } else if load_incremented { + // For streaming with load tracking, we need to manually decrement when done + let workers = Arc::clone(&self.workers); + let worker_url = worker_url.to_string(); + + HttpResponse::build(status) + .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) + .streaming( + res.bytes_stream() + .map_err(|_| { + actix_web::error::ErrorInternalServerError("Failed to read stream") + }) + .inspect(move |bytes| { + if let Ok(bytes) = bytes { + if bytes + .as_ref() + .windows(12) + .any(|window| window == b"data: [DONE]") + { + if let Ok(workers_guard) = workers.read() { + if let Some(worker) = + workers_guard.iter().find(|w| w.url() == &worker_url) + { + worker.decrement_load(); + gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) + .set(worker.load() as f64); + debug!("Streaming is done!!") + } + } + } + } + }), + ) + } else { + // For requests without load tracking, just 
stream + HttpResponse::build(status) + .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) + .streaming(res.bytes_stream().map_err(|_| { + actix_web::error::ErrorInternalServerError("Failed to read stream") + })) + } + } + + pub async fn add_worker(&self, worker_url: &str) -> Result { + let start_time = std::time::Instant::now(); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(self.timeout_secs)) + .build() + .map_err(|e| format!("Failed to create HTTP client: {}", e))?; + + loop { + if start_time.elapsed() > Duration::from_secs(self.timeout_secs) { + error!( + "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + self.timeout_secs, worker_url + ); + return Err(format!( + "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + self.timeout_secs, worker_url + )); + } + + match client.get(&format!("{}/health", worker_url)).send().await { + Ok(res) => { + if res.status().is_success() { + info!("Worker {} health check passed", worker_url); + let mut workers_guard = self.workers.write().unwrap(); + if workers_guard.iter().any(|w| w.url() == worker_url) { + return Err(format!("Worker {} already exists", worker_url)); + } + info!("Added worker: {}", worker_url); + let new_worker = WorkerFactory::create_regular(worker_url.to_string()); + workers_guard.push(new_worker); + gauge!("sgl_router_active_workers").set(workers_guard.len() as f64); + + // If cache aware policy, initialize the worker in the tree + if let Some(cache_aware) = + self.policy + .as_any() + .downcast_ref::() + { + // Get updated workers after adding + drop(workers_guard); + let workers_guard = self.workers.read().unwrap(); + 
cache_aware.init_workers(&workers_guard); + } + + return Ok(format!("Successfully added worker: {}", worker_url)); + } else { + info!( + "Worker {} health check is pending with status: {}.", + worker_url, + res.status() + ); + // if the url does not have http or https prefix, warn users + if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") + { + warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); + } + + tokio::time::sleep(Duration::from_secs(self.interval_secs)).await; + continue; + } + } + Err(e) => { + info!( + "Worker {} health check is pending with error: {}", + worker_url, e + ); + + // if the url does not have http or https prefix, warn users + if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") { + warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); + } + + tokio::time::sleep(Duration::from_secs(self.interval_secs)).await; + continue; + } + } + } + } + + pub fn remove_worker(&self, worker_url: &str) { + let mut workers_guard = self.workers.write().unwrap(); + if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) { + workers_guard.remove(index); + info!("Removed worker: {}", worker_url); + gauge!("sgl_router_active_workers").set(workers_guard.len() as f64); + } else { + warn!("Worker {} not found, skipping removal", worker_url); + return; + } + + // If cache aware policy, remove the worker from the tree + if let Some(cache_aware) = self + .policy + .as_any() + .downcast_ref::() + { + cache_aware.remove_worker(worker_url); + info!("Removed worker from tree: {}", worker_url); + } + } + + async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option { + match client.get(&format!("{}/get_load", worker_url)).send().await { + Ok(res) if res.status().is_success() => match res.bytes().await { + Ok(bytes) => match serde_json::from_slice::(&bytes) { + 
Ok(data) => data + .get("load") + .and_then(|v| v.as_i64()) + .map(|v| v as isize), + Err(e) => { + debug!("Failed to parse load response from {}: {}", worker_url, e); + None + } + }, + Err(e) => { + debug!("Failed to read load response from {}: {}", worker_url, e); + None + } + }, + Ok(res) => { + debug!( + "Worker {} returned non-success status: {}", + worker_url, + res.status() + ); + None + } + Err(e) => { + debug!("Failed to get load from {}: {}", worker_url, e); + None + } + } + } + + // Background task to monitor worker loads + async fn monitor_worker_loads( + worker_urls: Vec, + tx: tokio::sync::watch::Sender>, + interval_secs: u64, + policy: Arc, + ) { + let client = match reqwest::Client::builder() + .timeout(Duration::from_secs(5)) + .build() + { + Ok(c) => c, + Err(e) => { + error!("Failed to create HTTP client for load monitoring: {}", e); + return; + } + }; + + let mut interval = tokio::time::interval(Duration::from_secs(interval_secs)); + + loop { + interval.tick().await; + + let mut loads = HashMap::new(); + for url in &worker_urls { + if let Some(load) = Self::get_worker_load_static(&client, url).await { + loads.insert(url.clone(), load); + debug!("Worker {} load: {}", url, load); + } + } + + if !loads.is_empty() { + // Update policy with new loads + policy.update_loads(&loads); + + // Send to watchers + if let Err(e) = tx.send(loads) { + error!("Failed to send load update: {}", e); + } + } + } + } + + // Static version of get_worker_load for use in monitoring task + async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option { + match client.get(&format!("{}/get_load", worker_url)).send().await { + Ok(res) if res.status().is_success() => match res.bytes().await { + Ok(bytes) => match serde_json::from_slice::(&bytes) { + Ok(data) => data + .get("load") + .and_then(|v| v.as_i64()) + .map(|v| v as isize), + Err(e) => { + debug!("Failed to parse load response from {}: {}", worker_url, e); + None + } + }, + Err(e) => { + 
debug!("Failed to read load response from {}: {}", worker_url, e); + None + } + }, + Ok(res) => { + debug!( + "Worker {} returned non-success status: {}", + worker_url, + res.status() + ); + None + } + Err(e) => { + debug!("Failed to get load from {}: {}", worker_url, e); + None + } + } + } +} + +use crate::routers::{RouterTrait, WorkerManagement}; +use async_trait::async_trait; +use reqwest::Client; + +#[async_trait] +impl WorkerManagement for Router { + async fn add_worker(&self, worker_url: &str) -> Result { + Router::add_worker(self, worker_url).await + } + + fn remove_worker(&self, worker_url: &str) { + Router::remove_worker(self, worker_url) + } + + fn get_worker_urls(&self) -> Vec { + Router::get_worker_urls(self) + } +} + +#[async_trait(?Send)] +impl RouterTrait for Router { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + async fn health(&self, _client: &Client, _req: &HttpRequest) -> HttpResponse { + // Check local health state of all workers (consistent with PD router) + // Note: This uses cached health status from background health checks, not live checks + let mut all_healthy = true; + let mut unhealthy_servers = Vec::new(); + + for worker in self.workers.read().unwrap().iter() { + if !worker.is_healthy() { + all_healthy = false; + unhealthy_servers.push(worker.url().to_string()); + } + } + + if all_healthy { + HttpResponse::Ok().body("All servers healthy") + } else { + HttpResponse::ServiceUnavailable() + .body(format!("Unhealthy servers: {:?}", unhealthy_servers)) + } + } + + async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse { + // Test model generation capability by sending to first available worker + // Note: This endpoint actually causes the model to generate a token, so we only test one worker + self.route_to_first(client, "/health_generate", req).await + } + + async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse { + self.route_to_first(client, "/get_server_info", 
req).await + } + + async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse { + self.route_to_first(client, "/v1/models", req).await + } + + async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse { + self.route_to_first(client, "/get_model_info", req).await + } + + async fn route_generate( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse { + // Convert JSON to typed request + match serde_json::from_value::(body) { + Ok(typed_req) => { + self.route_typed_request(client, req, &typed_req, "/generate") + .await + } + Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)), + } + } + + async fn route_chat( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse { + // Convert JSON to typed request + match serde_json::from_value::(body) { + Ok(typed_req) => { + self.route_typed_request(client, req, &typed_req, "/v1/chat/completions") + .await + } + Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)), + } + } + + async fn route_completion( + &self, + client: &Client, + req: &HttpRequest, + body: serde_json::Value, + ) -> HttpResponse { + // Convert JSON to typed request + match serde_json::from_value::(body) { + Ok(typed_req) => { + self.route_typed_request(client, req, &typed_req, "/v1/completions") + .await + } + Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)), + } + } + + async fn flush_cache(&self, client: &Client) -> HttpResponse { + // Get all worker URLs + let worker_urls = self.get_worker_urls(); + + // Send requests to all workers concurrently without headers + let mut tasks = Vec::new(); + for worker_url in &worker_urls { + let request_builder = client.post(format!("{}/flush_cache", worker_url)); + tasks.push(request_builder.send()); + } + + // Wait for all responses + let results = futures_util::future::join_all(tasks).await; + + // Check if all 
succeeded + let all_success = results.iter().all(|r| { + r.as_ref() + .map(|res| res.status().is_success()) + .unwrap_or(false) + }); + + if all_success { + HttpResponse::Ok().body("Cache flushed on all servers") + } else { + HttpResponse::InternalServerError().body("Cache flush failed on one or more servers") + } + } + + async fn get_worker_loads(&self, client: &Client) -> HttpResponse { + let urls = self.get_worker_urls(); + let mut loads = Vec::new(); + + // Get loads from all workers + for url in &urls { + let load = self.get_worker_load(client, url).await.unwrap_or(-1); + loads.push(serde_json::json!({ + "worker": url, + "load": load + })); + } + + HttpResponse::Ok().json(serde_json::json!({ + "workers": loads + })) + } + + fn router_type(&self) -> &'static str { + "regular" + } + + fn readiness(&self) -> HttpResponse { + // Regular router is ready if it has at least one healthy worker + let healthy_count = self + .workers + .read() + .unwrap() + .iter() + .filter(|w| w.is_healthy()) + .count(); + + if healthy_count > 0 { + HttpResponse::Ok().json(serde_json::json!({ + "status": "ready", + "healthy_workers": healthy_count, + "total_workers": self.workers.read().unwrap().len() + })) + } else { + HttpResponse::ServiceUnavailable().json(serde_json::json!({ + "status": "not_ready", + "reason": "no healthy workers available", + "total_workers": self.workers.read().unwrap().len() + })) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::policies::RandomPolicy; + use std::collections::HashMap; + + fn create_test_regular_router() -> Router { + let workers = vec![ + WorkerFactory::create_regular("http://worker1:8080".to_string()), + WorkerFactory::create_regular("http://worker2:8080".to_string()), + ]; + let (_, rx) = tokio::sync::watch::channel(HashMap::new()); + Router { + workers: Arc::new(RwLock::new(workers)), + policy: Arc::new(RandomPolicy::new()), + timeout_secs: 5, + interval_secs: 1, + _worker_loads: Arc::new(rx), + _load_monitor_handle: 
None, + _health_checker: None, + } + } + + #[test] + fn test_router_get_worker_urls_regular() { + let router = create_test_regular_router(); + let urls = router.get_worker_urls(); + + assert_eq!(urls.len(), 2); + assert!(urls.contains(&"http://worker1:8080".to_string())); + assert!(urls.contains(&"http://worker2:8080".to_string())); + } + + #[test] + fn test_select_first_worker_regular() { + let router = create_test_regular_router(); + let result = router.select_first_worker(); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "http://worker1:8080"); + } + + #[test] + fn test_wait_for_healthy_workers_empty_list() { + let result = Router::wait_for_healthy_workers(&[], 1, 1); + assert!(result.is_ok()); + } + + #[test] + fn test_wait_for_healthy_workers_invalid_urls() { + // This test will timeout quickly since the URLs are invalid + let result = + Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Timeout")); + } +} diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index bb2695b932ce..69340eefe52b 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,9 +1,8 @@ +use crate::config::RouterConfig; use crate::logging::{self, LoggingConfig}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::prometheus::{self, PrometheusConfig}; -use crate::request_adapter::ToPdRequest; -use crate::router::PolicyConfig; -use crate::router::Router; +use crate::routers::{RouterFactory, RouterTrait}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; use actix_web::{ error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder, @@ -19,27 +18,19 @@ use tracing::{error, info, warn, Level}; #[derive(Debug)] pub struct AppState { - router: Arc, + router: Arc, client: Client, - is_pd_mode: bool, // Add flag to track PD mode } impl AppState { - pub fn new( 
- worker_urls: Vec, - client: Client, - policy_config: PolicyConfig, - ) -> Result { - // Check if this is PD mode from policy config - let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. }); - - // Create router based on policy - let router = Arc::new(Router::new(worker_urls, policy_config)?); - Ok(Self { - router, - client, - is_pd_mode, - }) + pub fn new(router_config: RouterConfig, client: Client) -> Result { + // Use RouterFactory to create the appropriate router type + let router = RouterFactory::create_router(&router_config)?; + + // Convert Box to Arc + let router = Arc::from(router); + + Ok(Self { router, client }) } } @@ -76,65 +67,39 @@ fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error } } +#[get("/liveness")] +async fn liveness(_req: HttpRequest, data: web::Data) -> impl Responder { + data.router.liveness() +} + +#[get("/readiness")] +async fn readiness(_req: HttpRequest, data: web::Data) -> impl Responder { + data.router.readiness() +} + #[get("/health")] async fn health(req: HttpRequest, data: web::Data) -> impl Responder { - data.router - .route_to_first(&data.client, "/health", &req) - .await + data.router.health(&data.client, &req).await } #[get("/health_generate")] async fn health_generate(req: HttpRequest, data: web::Data) -> impl Responder { - // Check if we're in PD mode - if data.is_pd_mode { - // For PD mode, check health on all servers - data.router - .route_pd_health_generate(&data.client, &req) - .await - } else { - // Regular mode - data.router - .route_to_first(&data.client, "/health_generate", &req) - .await - } + data.router.health_generate(&data.client, &req).await } #[get("/get_server_info")] async fn get_server_info(req: HttpRequest, data: web::Data) -> impl Responder { - if data.is_pd_mode { - // For PD mode, aggregate info from both prefill and decode servers - data.router.get_pd_server_info(&data.client, &req).await - } else { - // Regular mode - return first server's 
info - data.router - .route_to_first(&data.client, "/get_server_info", &req) - .await - } + data.router.get_server_info(&data.client, &req).await } #[get("/v1/models")] async fn v1_models(req: HttpRequest, data: web::Data) -> impl Responder { - if data.is_pd_mode { - // For PD mode, return models from the first prefill server - data.router.get_pd_models(&data.client, &req).await - } else { - // Regular mode - data.router - .route_to_first(&data.client, "/v1/models", &req) - .await - } + data.router.get_models(&data.client, &req).await } #[get("/get_model_info")] async fn get_model_info(req: HttpRequest, data: web::Data) -> impl Responder { - if data.is_pd_mode { - // For PD mode, get model info from the first prefill server - data.router.get_pd_model_info(&data.client, &req).await - } else { - data.router - .route_to_first(&data.client, "/get_model_info", &req) - .await - } + data.router.get_model_info(&data.client, &req).await } #[post("/generate")] @@ -143,24 +108,12 @@ async fn generate( body: web::Json, state: web::Data, ) -> Result { - let client = &state.client; - let router = &state.router; - - // Use typed request directly for both PD and regular routing - if state.is_pd_mode { - // For PD mode, convert to PD request with bootstrap - let pd_request = body.into_inner().to_pd_request(); - - Ok(router - .route_pd_generate_typed(&client, &req, pd_request, "/generate") - .await) - } else { - // For regular mode, use typed request directly - let request = body.into_inner(); - Ok(router - .route_typed_request(&client, &req, &request, "/generate") - .await) - } + let json_body = serde_json::to_value(body.into_inner()) + .map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?; + Ok(state + .router + .route_generate(&state.client, &req, json_body) + .await) } #[post("/v1/chat/completions")] @@ -169,24 +122,12 @@ async fn v1_chat_completions( body: web::Json, state: web::Data, ) -> Result { - let client = &state.client; - let router = &state.router; - - 
// Use typed request directly for both PD and regular routing - if state.is_pd_mode { - // For PD mode, convert to PD request with bootstrap - let pd_request = body.into_inner().to_pd_request(); - - Ok(router - .route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions") - .await) - } else { - // For regular mode, use typed request directly - let request = body.into_inner(); - Ok(router - .route_typed_request(&client, &req, &request, "/v1/chat/completions") - .await) - } + let json_body = serde_json::to_value(body.into_inner()) + .map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?; + Ok(state + .router + .route_chat(&state.client, &req, json_body) + .await) } #[post("/v1/completions")] @@ -195,24 +136,12 @@ async fn v1_completions( body: web::Json, state: web::Data, ) -> Result { - let client = &state.client; - let router = &state.router; - - // Use typed request directly for both PD and regular routing - if state.is_pd_mode { - // For PD mode, convert to PD request with bootstrap - let pd_request = body.into_inner().to_pd_request(); - - Ok(router - .route_pd_generate_typed(&client, &req, pd_request, "/v1/completions") - .await) - } else { - // For regular mode, use typed request directly - let request = body.into_inner(); - Ok(router - .route_typed_request(&client, &req, &request, "/v1/completions") - .await) - } + let json_body = serde_json::to_value(body.into_inner()) + .map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?; + Ok(state + .router + .route_completion(&state.client, &req, json_body) + .await) } #[post("/add_worker")] @@ -254,29 +183,19 @@ async fn remove_worker( } #[post("/flush_cache")] -async fn flush_cache(req: HttpRequest, data: web::Data) -> impl Responder { - if data.is_pd_mode { - // For PD mode, flush cache on both prefill and decode servers - data.router.route_pd_flush_cache(&data.client).await - } else { - // Route to all workers for cache flushing - data.router - .route_to_all(&data.client, 
"/flush_cache", &req) - .await - } +async fn flush_cache(_req: HttpRequest, data: web::Data) -> impl Responder { + data.router.flush_cache(&data.client).await } #[get("/get_loads")] -async fn get_loads(req: HttpRequest, data: web::Data) -> impl Responder { - // Get loads from all workers - data.router.get_all_loads(&data.client, &req).await +async fn get_loads(_req: HttpRequest, data: web::Data) -> impl Responder { + data.router.get_worker_loads(&data.client).await } pub struct ServerConfig { pub host: String, pub port: u16, - pub worker_urls: Vec, - pub policy_config: PolicyConfig, + pub router_config: RouterConfig, pub max_payload_size: usize, pub log_dir: Option, pub log_level: Option, @@ -324,8 +243,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { } info!("🚧 Initializing router on {}:{}", config.host, config.port); - info!("🚧 Initializing workers on {:?}", config.worker_urls); - info!("🚧 Policy Config: {:?}", config.policy_config); + info!("🚧 Router mode: {:?}", config.router_config.mode); + info!("🚧 Policy: {:?}", config.router_config.policy); info!( "🚧 Max payload size: {} MB", config.max_payload_size / (1024 * 1024) @@ -345,12 +264,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .build() .expect("Failed to create HTTP client"); - let app_state_init = AppState::new( - config.worker_urls.clone(), - client.clone(), - config.policy_config.clone(), - ) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + let app_state_init = AppState::new(config.router_config.clone(), client.clone()) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; let router_arc = Arc::clone(&app_state_init.router); let app_state = web::Data::new(app_state_init); @@ -397,6 +312,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .service(v1_completions) .service(v1_models) .service(get_model_info) + .service(liveness) + .service(readiness) .service(health) .service(health_generate) 
.service(get_server_info) diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 0e78717ce23b..72d78b490951 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -1,4 +1,4 @@ -use crate::router::Router; +use crate::routers::RouterTrait; use futures::{StreamExt, TryStreamExt}; use k8s_openapi::api::core::v1::Pod; @@ -176,7 +176,7 @@ impl PodInfo { pub async fn start_service_discovery( config: ServiceDiscoveryConfig, - router: Arc, + router: Arc, ) -> Result, kube::Error> { // Don't initialize anything if service discovery is disabled if !config.enabled { @@ -346,7 +346,7 @@ pub async fn start_service_discovery( async fn handle_pod_event( pod_info: &PodInfo, tracked_pods: Arc>>, - router: Arc, + router: Arc, port: u16, pd_mode: bool, ) { @@ -379,17 +379,32 @@ async fn handle_pod_event( pod_info.name, pod_info.pod_type, worker_url ); + // Handle PD mode with specific pod types let result = if pd_mode && pod_info.pod_type.is_some() { - // Use PD-aware worker management - if let Some(pod_type) = &pod_info.pod_type { - router - .add_pd_worker(&worker_url, pod_type.clone(), pod_info.bootstrap_port) - .await + // Need to import PDRouter type + use crate::routers::pd_router::PDRouter; + + // Try to downcast to PDRouter + if let Some(pd_router) = router.as_any().downcast_ref::() { + match &pod_info.pod_type { + Some(PodType::Prefill) => pd_router + .add_prefill_server(worker_url.clone(), pod_info.bootstrap_port) + .await + .map_err(|e| e.to_string()), + Some(PodType::Decode) => pd_router + .add_decode_server(worker_url.clone()) + .await + .map_err(|e| e.to_string()), + Some(PodType::Regular) | None => { + // Fall back to regular add_worker for regular pods + router.add_worker(&worker_url).await + } + } } else { - Err("Pod type is None in PD mode".to_string()) + Err("PD mode enabled but router is not a PDRouter".to_string()) } } else { - // Fallback to regular worker management + // Regular mode or 
no pod type specified router.add_worker(&worker_url).await }; @@ -412,7 +427,7 @@ async fn handle_pod_event( async fn handle_pod_deletion( pod_info: &PodInfo, tracked_pods: Arc>>, - router: Arc, + router: Arc, port: u16, pd_mode: bool, ) { @@ -435,18 +450,34 @@ async fn handle_pod_deletion( pod_info.name, pod_info.pod_type, worker_url ); + // Handle PD mode removal if pd_mode && pod_info.pod_type.is_some() { - // Use PD-aware worker removal - if let Some(pod_type) = &pod_info.pod_type { - if let Err(e) = router.remove_pd_worker(&worker_url, pod_type.clone()).await { - error!( - "Failed to remove PD worker {} from router: {}", - worker_url, e - ); + use crate::routers::pd_router::PDRouter; + + // Try to downcast to PDRouter for PD-specific removal + if let Some(pd_router) = router.as_any().downcast_ref::() { + match &pod_info.pod_type { + Some(PodType::Prefill) => { + if let Err(e) = pd_router.remove_prefill_server(&worker_url).await { + error!("Failed to remove prefill server {}: {}", worker_url, e); + } + } + Some(PodType::Decode) => { + if let Err(e) = pd_router.remove_decode_server(&worker_url).await { + error!("Failed to remove decode server {}: {}", worker_url, e); + } + } + Some(PodType::Regular) | None => { + // Fall back to regular remove_worker + router.remove_worker(&worker_url); + } } + } else { + // PD mode but not a PDRouter, use generic removal + router.remove_worker(&worker_url); } } else { - // Fallback to regular worker removal + // Regular mode removal router.remove_worker(&worker_url); } } else { @@ -462,11 +493,9 @@ async fn handle_pod_deletion( #[cfg(test)] mod tests { use super::*; - use crate::router::Router; use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus}; use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta; use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time; - use std::sync::RwLock; // Helper function to create a Pod for testing PodInfo::from_pod fn create_k8s_pod( @@ -546,14 +575,14 @@ mod tests { } 
// Helper to create a Router instance for testing event handlers - fn create_test_router() -> Arc { - let workers = Arc::new(RwLock::new(Vec::new())); - Arc::new(Router::Random { - workers, - timeout_secs: 5, - interval_secs: 1, - _health_checker: None, - }) + fn create_test_router() -> Arc { + use crate::config::PolicyConfig; + use crate::policies::PolicyFactory; + use crate::routers::router::Router; + + let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); + let router = Router::new(vec![], policy, 5, 1).unwrap(); + Arc::new(router) as Arc } // Helper to create a PD config for testing diff --git a/sgl-router/tests/benchmark_integration.rs b/sgl-router/tests/benchmark_integration.rs index b21c93fcf7e7..31785900011f 100644 --- a/sgl-router/tests/benchmark_integration.rs +++ b/sgl-router/tests/benchmark_integration.rs @@ -6,7 +6,7 @@ use sglang_router_rs::openai_api_types::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, SamplingParams, StringOrArray, UserMessageContent, }; -use sglang_router_rs::request_adapter::{RouteableRequest, ToPdRequest}; +use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest}; #[test] fn test_benchmark_request_creation() { diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 02b8c99f5318..ceb5fe9e69d3 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -8,12 +8,18 @@ //! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type. //! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode. 
+// TODO: This test file needs to be updated for the new configuration structure +// where RoutingMode and PolicyConfig are separate + #[cfg(test)] mod test_pd_routing { use rand::Rng; use serde_json::json; - use sglang_router_rs::pd_types::PDSelectionPolicy; - use sglang_router_rs::router::{PolicyConfig, Router}; + use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; + use sglang_router_rs::core::{WorkerFactory, WorkerType}; + use sglang_router_rs::routers::pd_types::get_hostname; + use sglang_router_rs::routers::pd_types::PDSelectionPolicy; + use sglang_router_rs::routers::RouterFactory; // Test-only struct to help validate PD request parsing #[derive(Debug)] @@ -116,49 +122,68 @@ mod test_pd_routing { #[test] fn test_pd_router_configuration() { - // Test PrefillDecodeConfig creation with various policies - // This config is used when pd_disaggregation=true - let configs = vec![ - PolicyConfig::PrefillDecodeConfig { - selection_policy: PDSelectionPolicy::Random, - prefill_urls: vec![ - ("http://prefill1:8080".to_string(), Some(9000)), - ("http://prefill2:8080".to_string(), None), - ], - decode_urls: vec![ - "http://decode1:8080".to_string(), - "http://decode2:8080".to_string(), - ], - timeout_secs: 10, - interval_secs: 1, - }, - PolicyConfig::PrefillDecodeConfig { - selection_policy: PDSelectionPolicy::PowerOfTwo, - prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))], - decode_urls: vec!["http://decode:8080".to_string()], - timeout_secs: 5, - interval_secs: 1, - }, - PolicyConfig::PrefillDecodeConfig { - selection_policy: PDSelectionPolicy::CacheAware { + // Test PD router configuration with various policies + // In the new structure, RoutingMode and PolicyConfig are separate + let test_cases = vec![ + ( + RoutingMode::PrefillDecode { + prefill_urls: vec![ + ("http://prefill1:8080".to_string(), Some(9000)), + ("http://prefill2:8080".to_string(), None), + ], + decode_urls: vec![ + "http://decode1:8080".to_string(), + 
"http://decode2:8080".to_string(), + ], + }, + PolicyConfig::Random, + ), + ( + RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))], + decode_urls: vec!["http://decode:8080".to_string()], + }, + PolicyConfig::PowerOfTwo { + load_check_interval_secs: 5, + }, + ), + ( + RoutingMode::PrefillDecode { + prefill_urls: vec![ + ("http://p1:8080".to_string(), Some(9000)), + ("http://p2:8080".to_string(), Some(9001)), + ("http://p3:8080".to_string(), Some(9002)), + ], + decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()], + }, + PolicyConfig::CacheAware { cache_threshold: 0.7, balance_abs_threshold: 20, balance_rel_threshold: 1.2, + eviction_interval_secs: 60, + max_tree_size: 1000000, }, - prefill_urls: vec![ - ("http://p1:8080".to_string(), Some(9000)), - ("http://p2:8080".to_string(), Some(9001)), - ("http://p3:8080".to_string(), Some(9002)), - ], - decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()], - timeout_secs: 10, - interval_secs: 2, - }, + ), ]; - for config in configs { + for (mode, policy) in test_cases { + let config = RouterConfig { + mode, + policy, + host: "127.0.0.1".to_string(), + port: 3001, + max_payload_size: 1024 * 1024, + request_timeout_secs: 60, + worker_startup_timeout_secs: 10, + worker_startup_check_interval_secs: 1, + discovery: None, + metrics: None, + log_dir: None, + log_level: None, + }; + // Router creation will fail due to health checks, but config should be valid - let result = Router::new(vec![], config); + let result = RouterFactory::create_router(&config); assert!(result.is_err()); let error_msg = result.unwrap_err(); // Error should be about health/timeout, not configuration @@ -225,9 +250,6 @@ mod test_pd_routing { #[test] fn test_bootstrap_injection_simulation() { - use sglang_router_rs::core::{WorkerFactory, WorkerType}; - use sglang_router_rs::pd_types::get_hostname; - // Since we can't test the actual inject_bootstrap_fields 
function here // (it's private in the router module), we'll test the expected behavior @@ -315,8 +337,6 @@ mod test_pd_routing { #[test] fn test_hostname_extraction() { - use sglang_router_rs::pd_types::get_hostname; - // Test various URL formats let test_cases = vec![ ("http://localhost:8080", "localhost"), @@ -662,7 +682,6 @@ mod test_pd_routing { #[test] fn test_bootstrap_injection_with_benchmark_requests() { use sglang_router_rs::core::{WorkerFactory, WorkerType}; - use sglang_router_rs::pd_types::get_hostname; // Test bootstrap injection with actual benchmark request patterns let mut benchmark_request = json!({ @@ -790,9 +809,6 @@ mod test_pd_routing { #[test] fn test_large_batch_bootstrap_injection() { - use sglang_router_rs::core::{WorkerFactory, WorkerType}; - use sglang_router_rs::pd_types::get_hostname; - // Test bootstrap injection performance with very large batches // This simulates the bench_one_batch_server.py scenario let large_batch_sizes = vec![1024, 4096, 8192]; From 7750b91ca81d15b85290703f24f8cd2716fe149a Mon Sep 17 00:00:00 2001 From: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com> Date: Fri, 18 Jul 2025 14:27:25 -0700 Subject: [PATCH 037/396] [AMD] Add triton awq_dequantize kernel to support AWQ on ROCm (#7661) --- python/sglang/srt/layers/quantization/awq.py | 12 +- .../srt/layers/quantization/awq_triton.py | 339 ++++++++++++++++++ python/sglang/srt/models/deepseek_v2.py | 6 +- test/srt/run_suite.py | 1 + test/srt/test_awq_dequant.py | 175 +++++++++ 5 files changed, 530 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/awq_triton.py create mode 100644 test/srt/test_awq_dequant.py diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 4532673837dc..c20beb2ff0b9 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -43,11 +43,20 @@ except ImportError: ops = None -from sglang.srt.utils 
import is_cuda +from sglang.srt.utils import is_cuda, is_hip _is_cuda = is_cuda() +_is_hip = is_hip() if _is_cuda: from sgl_kernel import awq_dequantize, fused_marlin_moe +elif _is_hip: + from sglang.srt.layers.quantization.awq_triton import ( + awq_dequantize_triton as awq_dequantize, + ) + + warnings.warn(f"HIP does not support fused_marlin_moe currently.") +else: + warnings.warn(f"Only CUDA and HIP support AWQ currently.") logger = logging.getLogger(__name__) @@ -398,7 +407,6 @@ def apply( pack_factor = self.quant_config.pack_factor out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) reshaped_x = x.reshape(-1, x.shape[-1]) - out = awq_dequantize(qweight, scales, qzeros) out = torch.matmul(reshaped_x, out) diff --git a/python/sglang/srt/layers/quantization/awq_triton.py b/python/sglang/srt/layers/quantization/awq_triton.py new file mode 100644 index 000000000000..13352efdb650 --- /dev/null +++ b/python/sglang/srt/layers/quantization/awq_triton.py @@ -0,0 +1,339 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/awq_triton.py + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import triton +import triton.language as tl + +AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +@triton.jit +def awq_dequantize_kernel( + qweight_ptr, # quantized matrix + scales_ptr, # scales, per group + zeros_ptr, # zeros, per group + group_size, # Should always be one of the supported group sizes + result_ptr, # Output matrix + num_cols, # input num cols in qweight + num_rows, # input num rows in qweight + BLOCK_SIZE_X: tl.constexpr, + BLOCK_SIZE_Y: tl.constexpr, +): + # Setup the pids. + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + + # Compute offsets and masks for qweight_ptr. 
+ offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] + + masks_y = offsets_y < num_rows + masks_x = offsets_x < num_cols + + masks = masks_y[:, None] & masks_x[None, :] + + # Compute offsets and masks for result output ptr. + result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8) + result_offsets = ( + 8 * num_cols * result_offsets_y[:, None] + result_offsets_x[None, :] + ) + + result_masks_y = result_offsets_y < num_rows + result_masks_x = result_offsets_x < num_cols * 8 + result_masks = result_masks_y[:, None] & result_masks_x[None, :] + + # Load the weights. + iweights = tl.load(qweight_ptr + offsets, masks, 0.0) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ( + (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None] + ).reshape(8) + + # Use this to compute a set of shifts that can be used to unpack and + # reorder the values in iweights and zeros. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + iweights = (iweights >> shifts) & 0xF + + # Compute zero offsets and masks. 
+ zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] + + zero_masks_y = zero_offsets_y < num_rows // group_size + zero_masks_x = zero_offsets_x < num_cols + zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] + + # Load the zeros. + zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + zeros = (zeros >> shifts) & 0xF + + # Compute scale offsets and masks. + scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + scale_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8) + scale_offsets = num_cols * 8 * scale_offsets_y[:, None] + scale_offsets_x[None, :] + scale_masks_y = scale_offsets_y < num_rows // group_size + scale_masks_x = scale_offsets_x < num_cols * 8 + scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] + + # Load the scales. + scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Dequantize. + iweights = (iweights - zeros) * scales + iweights = iweights.to(result_ptr.type.element_ty) + + # Finally, store. + tl.store(result_ptr + result_offsets, iweights, result_masks) + + +@triton.jit +def awq_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + zeros_ptr, + scales_ptr, + M, + N, + K, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. 
+ # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = c_ptr.type.element_ty + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # accumulator = tl.arange(0, BLOCK_SIZE_N) + # accumulator = tl.broadcast_to(accumulator[None, :], + # (BLOCK_SIZE_M, BLOCK_SIZE_N)) + # accumulator = accumulator & 0x0 + # accumulator = accumulator.to(accumulator_dtype) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ( + (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None] + ).reshape(8) + + # Create the necessary shifts to use to unpack. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + # Offsets and masks. 
+ offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + masks_am = offsets_am < M + + offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) + masks_bn = offsets_bn < N // 8 + + offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) + masks_zn = offsets_zn < N // 8 + + offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + masks_sn = offsets_sn < N + + offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offsets_a = K * offsets_am[:, None] + offsets_k[None, :] + offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :] + + a_ptrs = a_ptr + offsets_a + b_ptrs = b_ptr + offsets_b + + # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv + # block_offset = BLOCK_SIZE_K * SPLIT_K + # for k in range(0, (K + block_offset - 1) // (block_offset)): + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a, other=0.0) + + masks_b = masks_k[:, None] & masks_bn[None, :] + b = tl.load(b_ptrs, mask=masks_b, other=0.0) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + + # Dequantize b. 
+ offsets_szk = ( + BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K + ) // group_size + tl.arange(0, 1) + offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] + masks_zk = offsets_szk < K // group_size + masks_z = masks_zk[:, None] & masks_zn[None, :] + zeros_ptrs = zeros_ptr + offsets_z + zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] + masks_sk = offsets_szk < K // group_size + masks_s = masks_sk[:, None] & masks_sn[None, :] + scales_ptrs = scales_ptr + offsets_s + scales = tl.load(scales_ptrs, mask=masks_s, other=0.0) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + b = (b >> shifts) & 0xF + zeros = (zeros >> shifts) & 0xF + b = (b - zeros) * scales + b = b.to(c_ptr.type.element_ty) + + # Accumulate results. 
+ accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + + offsets_k += BLOCK_SIZE_K * SPLIT_K + a_ptrs += BLOCK_SIZE_K * SPLIT_K + b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8) + + c = accumulator.to(c_ptr.type.element_ty) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# qweights - [K , M // 8], int32 +# scales - [K // G, M ], float16 +# zeros - [K // G, M // 8], int32 +def awq_dequantize_triton( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + block_size_x: int = 32, + block_size_y: int = 32, +) -> torch.Tensor: + K = qweight.shape[0] + M = scales.shape[1] + group_size = qweight.shape[0] // scales.shape[0] + + assert K > 0 and M > 0 + assert scales.shape[0] == K // group_size and scales.shape[1] == M + assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + # Result tensor: + # number of rows = same as input tensor + # number of cols = 8 x input tensor num cols + result = torch.empty( + qweight.shape[0], + qweight.shape[1] * 8, + device=qweight.device, + dtype=scales.dtype, + ) + + Y = qweight.shape[0] # num rows + X = qweight.shape[1] # num cols + + grid = lambda META: ( + triton.cdiv(X, META["BLOCK_SIZE_X"]), + triton.cdiv(Y, META["BLOCK_SIZE_Y"]), + ) + awq_dequantize_kernel[grid]( + qweight, + scales, + zeros, + group_size, + result, + X, + Y, + BLOCK_SIZE_X=block_size_x, + BLOCK_SIZE_Y=block_size_y, + ) + + return result + + +# input - [M, K] +# qweight - [K, N // 8] +# qzeros - [K // G, N // 8] +# scales - [K // G, N] +# split_k_iters - parallelism along K-dimension, int, power of 2. 
+def awq_gemm_triton( + input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + split_k_iters: int, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32, +) -> torch.Tensor: + M, K = input.shape + N = qweight.shape[1] * 8 + group_size = qweight.shape[0] // qzeros.shape[0] + + assert N > 0 and K > 0 and M > 0 + assert qweight.shape[0] == K and qweight.shape[1] == N // 8 + assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8 + assert scales.shape[0] == K // group_size and scales.shape[1] == N + assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0 + assert split_k_iters <= 32 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + split_k_iters, + ) + + result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device) + + # A = input, B = qweight, C = result + # A = M x K, B = K x N, C = M x N + awq_gemm_kernel[grid]( + input, + qweight, + result, + qzeros, + scales, + M, + N, + K, + group_size, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + SPLIT_K=split_k_iters, + ) + + result = result.sum(0) + + return result diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 12aa9cb39c78..0da956b0158f 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -127,6 +127,10 @@ ) elif _is_cpu and _is_cpu_amx_available: pass +elif _is_hip: + from sglang.srt.layers.quantization.awq_triton import ( + awq_dequantize_triton as awq_dequantize, + ) else: from vllm._custom_ops import awq_dequantize @@ -2176,7 +2180,7 @@ def post_load_weights(self, is_nextn=False, weight_names=None): ) if hasattr(self_attn.kv_b_proj, "qweight"): # AWQ compatible - if _is_cuda: + if _is_cuda or _is_hip: w = 
awq_dequantize( self_attn.kv_b_proj.qweight, self_attn.kv_b_proj.scales, diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 41564869ed9b..1a89971e1775 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -147,6 +147,7 @@ class TestFile: # TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701 TestFile("test_reasoning_parser.py", 5), TestFile("test_rope_rocm.py", 3), + TestFile("test_awq_dequant.py", 2), ], "per-commit-npu": [ TestFile("test_ascend_attention_backend.py", 400), diff --git a/test/srt/test_awq_dequant.py b/test/srt/test_awq_dequant.py new file mode 100644 index 000000000000..ec1f2b16a3d2 --- /dev/null +++ b/test/srt/test_awq_dequant.py @@ -0,0 +1,175 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/quantization/test_awq_triton.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +unittest version of the AWQ Triton kernel tests. 
+ +Run with: + python -m unittest test_awq_dequant.py +""" +import unittest + +import torch + +from sglang.srt.layers.quantization.awq_triton import ( + AWQ_TRITON_SUPPORTED_GROUP_SIZES, + awq_dequantize_triton, + awq_gemm_triton, +) +from sglang.test.test_utils import CustomTestCase + +device = "cuda" + + +def reverse_awq_order(t: torch.Tensor) -> torch.Tensor: + bits = 4 + AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + idx = torch.arange(t.shape[-1], dtype=torch.int32, device=t.device) + idx = idx.view(-1, 32 // bits)[:, AWQ_REVERSE_ORDER].view(-1) + return (t[:, idx] & 0xF).contiguous() + + +def awq_dequantize_torch( + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + group_size: int, +) -> torch.Tensor: + if group_size == -1: + group_size = qweight.shape[0] + + bits = 4 + shifts = torch.arange(0, 32, bits, device=qzeros.device) + + iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( + torch.int8 + ) + iweights = reverse_awq_order(iweights.view(iweights.shape[0], -1)) + + zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to( + torch.int8 + ) + zeros = reverse_awq_order(zeros.view(qzeros.shape[0], -1)) + + iweights = torch.bitwise_and(iweights, (2**bits) - 1) + zeros = torch.bitwise_and(zeros, (2**bits) - 1) + + scales = scales.repeat_interleave(group_size, dim=0) + zeros = zeros.repeat_interleave(group_size, dim=0) + return (iweights - zeros) * scales + + +class TestAWQTriton(CustomTestCase): + def test_dequantize(self): + rows_list = [3584, 18944, 128, 256, 512, 1024] + cols_list = [448, 576, 4736, 16, 32, 64, 128] + + for qweight_rows in rows_list: + for qweight_cols in cols_list: + for group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES: + with self.subTest( + rows=qweight_rows, cols=qweight_cols, g=group_size + ): + self._run_dequant_case( + qweight_rows=qweight_rows, + qweight_cols=qweight_cols, + group_size=group_size, + ) + + def _run_dequant_case(self, qweight_rows, 
qweight_cols, group_size): + if group_size == -1: + group_size = qweight_rows + + torch.manual_seed(0) + + qweight = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (qweight_rows, qweight_cols), + dtype=torch.int32, + device=device, + ) + scales = torch.rand( + qweight_rows // group_size, + qweight_cols * 8, + dtype=torch.float16, + device=device, + ) + zeros = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (qweight_rows // group_size, qweight_cols), + dtype=torch.int32, + device=device, + ) + + ref = awq_dequantize_torch(qweight, scales, zeros, group_size) + tri = awq_dequantize_triton(qweight, scales, zeros) + + # sanity + self.assertFalse(torch.any(torch.isinf(tri)) or torch.any(torch.isnan(tri))) + torch.testing.assert_close(ref, tri) + + # GEMM + def test_gemm(self): + N_list = [1, 2, 4, 8, 14, 17, 23, 32] + K_list = [128] + M_list = [16, 24, 32] + splitK_list = [1, 8] + + for N in N_list: + for K in K_list: + for M in M_list: + for group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES: + for splitK in splitK_list: + with self.subTest(N=N, K=K, M=M, g=group_size, sk=splitK): + self._run_gemm_case( + N=N, + K=K, + M=M, + group_size=group_size, + splitK=splitK, + ) + + def _run_gemm_case(self, N, K, M, group_size, splitK): + if group_size == -1: + group_size = K + + torch.manual_seed(0) + + x = torch.rand((N, K), dtype=torch.float32, device=device) + qweight = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (K, M // 8), + dtype=torch.int32, + device=device, + ) + qzeros = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (K // group_size, M // 8), + dtype=torch.int32, + device=device, + ) + scales = torch.rand((K // group_size, M), dtype=torch.float32, device=device) + + tri_out = awq_gemm_triton(x, qweight, scales, qzeros, splitK) + + self.assertFalse( + torch.any(torch.isinf(tri_out)) or torch.any(torch.isnan(tri_out)) + ) + + # dequantize & compare + w_deq = awq_dequantize_triton(qweight, scales, qzeros) + ref_out = torch.matmul(x, w_deq) + + 
self.assertFalse( + torch.any(torch.isinf(ref_out)) or torch.any(torch.isnan(ref_out)) + ) + + torch.testing.assert_close(tri_out.cpu(), ref_out.cpu(), atol=1e-1, rtol=1e-1) + + +if __name__ == "__main__": + unittest.main(verbosity=2) From 9c7a46180c251347c13bdf3325a04ceb77667bb3 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 18 Jul 2025 16:38:26 -0700 Subject: [PATCH 038/396] [Doc] Steps to add a new attention backend (#8155) --- .github/workflows/pr-test.yml | 4 ++-- docs/backend/attention_backend.md | 28 +++++++++++++++++++++++++ python/sglang/srt/managers/io_struct.py | 28 ++++++++++++------------- test/srt/run_suite.py | 22 +++++++++---------- 4 files changed, 55 insertions(+), 27 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 2378695e21ee..6c79b0ae63fa 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -56,7 +56,7 @@ jobs: strategy: fail-fast: false matrix: - part: [0, 1, 2, 3, 4, 5, 6, 7, 8] + part: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] steps: - name: Checkout code uses: actions/checkout@v4 @@ -69,7 +69,7 @@ jobs: timeout-minutes: 30 run: | cd test/srt - python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 9 + python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 10 unit-test-backend-2-gpu: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && diff --git a/docs/backend/attention_backend.md b/docs/backend/attention_backend.md index 4e9ecf8e206a..caf23446f5a6 100644 --- a/docs/backend/attention_backend.md +++ b/docs/backend/attention_backend.md @@ -52,3 +52,31 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti ```bash python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend ``` + + +## Steps to add a new attention backend +To add a new attention backend, you can learn from 
the existing backends +(`python/sglang/srt/layers/attention/triton_backend.py`, `python/sglang/srt/layers/attention/flashattention_backend.py`) +and follow the steps below. + +1. Run without cuda graph. Support the two forward functions + - forward_extend + - Will be used for prefill, prefill with KV cache, and target verification + - It will be called once per layer + - forward_decode + - Will be used for normal decode, and draft decode + - It will be called once per layer + - init_forward_metadata + - Initialize the class and common metadata shared by all layers + - Call the plan function for optimizations like split_kv + - It will be called once per forward +2. Run with cuda graph. It has two phases (capture and replay) and you need to implement three functions + - init_cuda_graph_state + - It will be called once during life time + - Create all common shared buffers + - init_forward_metadata_capture_cuda_graph + - It will be called before capturing a cuda graph + - It is similar to init_forward_metadata but write the medatada to some pre-defined buffers + - init_forward_metadata_replay_cuda_graph + - It will be called before replaying a cuda graph + - This function is in the critical path and needs to be fast diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 6eebf21e94b6..8e1d1075aab6 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -13,14 +13,14 @@ # ============================================================================== """ The definition of objects transferred between different -processes (TokenizerManager, DetokenizerManager, Controller). +processes (TokenizerManager, DetokenizerManager, Scheduler). 
""" import copy import uuid from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.multimodal.mm_utils import has_valid_data @@ -545,7 +545,7 @@ class EmbeddingReqInput: # The request id. rid: Optional[Union[List[str], str]] = None # Dummy sampling params for compatibility - sampling_params: Union[List[Dict], Dict] = None + sampling_params: Optional[Union[List[Dict], Dict]] = None # Dummy input embeds for compatibility input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None # Whether to log metrics for this request (e.g. health_generate calls do not log metrics) @@ -953,17 +953,6 @@ class ProfileReqType(Enum): STOP_PROFILE = 2 -class ExpertDistributionReq(Enum): - START_RECORD = 1 - STOP_RECORD = 2 - DUMP_RECORD = 3 - - -@dataclass -class ExpertDistributionReqOutput: - pass - - @dataclass class ProfileReq: type: ProfileReqType @@ -1013,6 +1002,17 @@ class HealthCheckOutput: pass +class ExpertDistributionReq(Enum): + START_RECORD = 1 + STOP_RECORD = 2 + DUMP_RECORD = 3 + + +@dataclass +class ExpertDistributionReqOutput: + pass + + @dataclass class Function: description: Optional[str] = None diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 1a89971e1775..e67362cf8258 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -155,11 +155,11 @@ class TestFile: "per-commit-2-gpu": [ TestFile("models/lora/test_lora_tp.py", 116), TestFile("test_data_parallelism.py", 73), - TestFile("test_dp_attention.py", 137), + TestFile("test_dp_attention.py", 277), TestFile("test_mla_tp.py", 170), TestFile("test_patch_torch.py", 19), TestFile("test_update_weights_from_distributed.py", 103), - TestFile("test_release_memory_occupation.py", 44), + TestFile("test_release_memory_occupation.py", 127), ], 
"per-commit-2-gpu-amd": [ TestFile("models/lora/test_lora_tp.py", 116), @@ -170,7 +170,7 @@ class TestFile: ], "per-commit-4-gpu": [ TestFile("test_local_attn.py", 250), - TestFile("test_pp_single_node.py", 150), + TestFile("test_pp_single_node.py", 372), TestFile("test_multi_instance_release_memory_occupation.py", 64), ], "per-commit-4-gpu-deepep": [ @@ -182,12 +182,12 @@ class TestFile: "per-commit-8-gpu": [ # Disabled because it hangs on the CI. # TestFile("test_moe_ep.py", 181), - TestFile("test_disaggregation.py", 270), + TestFile("test_disaggregation.py", 499), TestFile("test_disaggregation_different_tp.py", 155), - TestFile("test_full_deepseek_v3.py", 463), + TestFile("test_full_deepseek_v3.py", 333), ], "per-commit-8-gpu-deepep": [ - TestFile("test_deepep_large.py", 485), + TestFile("test_deepep_large.py", 338), ], "per-commit-8-gpu-amd": [ TestFile("test_full_deepseek_v3.py", 250), @@ -214,11 +214,11 @@ class TestFile: TestFile("test_nightly_gsm8k_eval_amd.py"), ], "vllm_dependency_test": [ - TestFile("test_awq.py"), - TestFile("test_bnb.py"), - TestFile("test_gguf.py", 78), - TestFile("test_gptqmodel_dynamic.py", 72), - TestFile("test_vllm_dependency.py"), + TestFile("test_awq.py", 163), + TestFile("test_bnb.py", 5), + TestFile("test_gguf.py", 96), + TestFile("test_gptqmodel_dynamic.py", 102), + TestFile("test_vllm_dependency.py", 185), ], } From 3964b352c3613b06b0f10fa5d7a8b2630fa80d61 Mon Sep 17 00:00:00 2001 From: Mick Date: Sat, 19 Jul 2025 08:19:27 +0800 Subject: [PATCH 039/396] chore: tune mem fraction static for vlm (#6881) --- .../sglang/srt/model_executor/model_runner.py | 4 +- python/sglang/srt/server_args.py | 48 ++++++++++++++++++- test/srt/test_vision_openai_server_a.py | 10 ++-- test/srt/test_vision_openai_server_b.py | 8 ++-- 4 files changed, 57 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 923b4d02b543..bbd5b000067f 100644 --- 
a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -411,7 +411,7 @@ def model_specific_adjustment(self): else: server_args.attention_backend = "triton" logger.info( - f"Attention backend not set. Use {server_args.attention_backend} backend by default." + f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default." ) elif self.use_mla_backend: if server_args.device != "cpu": @@ -463,7 +463,7 @@ def model_specific_adjustment(self): if not self.is_multimodal_chunked_prefill_supported: server_args.chunked_prefill_size = -1 logger.info( - f"Automatically turn of --chunked-prefill-size as it is not supported for " + f"Automatically turn off --chunked-prefill-size as it is not supported for " f"{self.model_config.hf_config.model_type}" ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index cb8038d3366a..20db0b4b9c79 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -337,8 +337,52 @@ def __post_init__(self): # Multimodal models need more memory for the image processor model_config = ModelConfig.from_server_args(self) - if model_config.is_multimodal: - self.mem_fraction_static *= 0.90 + + vision_config = getattr(model_config.hf_config, "vision_config", None) + + if model_config.is_multimodal and vision_config: + # roughly reduce the mem_fraction_static base on params of Vit + original_server_arg_mem_fraction = self.mem_fraction_static + # a base mem_fraction_static factor for regular Vit + base_mem_fraction_reduction_ratio = 0.95 + + vit_num_layers = getattr(vision_config, "num_hidden_layers", 24) + vit_hidden_size = getattr(vision_config, "hidden_size", 1024) + + # baseline ViT params (ViT-L/14) + baseline_vit_layers = 24 + baseline_vit_hidden_size = 1024 + + # weight params count + current_complexity_score = vit_num_layers * (vit_hidden_size**2) + baseline_complexity_score = baseline_vit_layers * ( + 
baseline_vit_hidden_size**2 + ) + complexity_ratio = ( + current_complexity_score / baseline_complexity_score + if baseline_complexity_score > 0 + else 1.0 + ) + + # every time the complexity grows 100%, adjust final factor for 10% + sensitivity_scale = 0.1 + dynamic_adjustment_factor = 1.0 - sensitivity_scale * ( + complexity_ratio - 1.0 + ) + dynamic_adjustment_factor = max( + 0.8, min(1.05, dynamic_adjustment_factor) + ) + + final_overall_factor = ( + base_mem_fraction_reduction_ratio * dynamic_adjustment_factor + ) + self.mem_fraction_static = ( + original_server_arg_mem_fraction * final_overall_factor + ) + logger.warning( + f"Multimodal model: Dynamically adjusted --mem-fraction-static " + f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}." + ) # Set chunked prefill size, which depends on the gpu memory capacity if self.chunked_prefill_size is None: diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index 90b91578f3cd..f252c4884eb0 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -30,7 +30,7 @@ def setUpClass(cls): api_key=cls.api_key, other_args=[ "--mem-fraction-static", - "0.4", + "0.35", ], ) cls.base_url += "/v1" @@ -52,7 +52,7 @@ def setUpClass(cls): api_key=cls.api_key, other_args=[ "--mem-fraction-static", - "0.4", + "0.35", ], ) cls.base_url += "/v1" @@ -75,7 +75,7 @@ def setUpClass(cls): other_args=[ "--context-length", "300", - "--mem-fraction-static=0.80", + "--mem-fraction-static=0.75", ], ) cls.base_url += "/v1" @@ -147,7 +147,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.4", + "0.35", ], ) cls.base_url += "/v1" @@ -181,7 +181,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.7", + "0.65", ], ) cls.base_url += "/v1" diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index 
7a5716cb18a6..f6152ea76dfc 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -22,7 +22,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.73", + "0.70", ], ) cls.base_url += "/v1" @@ -44,7 +44,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.8", + "0.75", ], ) cls.base_url += "/v1" @@ -88,7 +88,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.4", + "0.35", ], ) cls.base_url += "/v1" @@ -197,7 +197,7 @@ def setUpClass(cls): other_args=[ "--trust-remote-code", "--mem-fraction-static", - "0.75", + "0.70", "--disable-radix-cache", "--max-loras-per-batch", "1", From d918ab7985580cebea03216a5e309058df449821 Mon Sep 17 00:00:00 2001 From: Haohui Mai Date: Fri, 18 Jul 2025 19:59:39 -0700 Subject: [PATCH 040/396] Support NVFP4 quantized dense models on AMD CDNA2/CDNA3 GPUs (#7302) Co-authored-by: HAI Co-authored-by: Sai Enduri --- python/pyproject.toml | 1 + python/sglang/srt/configs/model_config.py | 3 + python/sglang/srt/layers/linear.py | 1 + .../srt/layers/quantization/__init__.py | 2 + .../sglang/srt/layers/quantization/petit.py | 249 ++++++++++++++++++ .../srt/layers/quantization/petit_utils.py | 104 ++++++++ python/sglang/srt/server_args.py | 1 + 7 files changed, 361 insertions(+) create mode 100644 python/sglang/srt/layers/quantization/petit.py create mode 100644 python/sglang/srt/layers/quantization/petit_utils.py diff --git a/python/pyproject.toml b/python/pyproject.toml index 7afb3581a3b5..5b6501afd192 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -79,6 +79,7 @@ blackwell = [ srt_hip = [ "sglang[runtime_common]", "torch", + "petit_kernel", ] # xpu is not enabled in public vllm and torch whl, diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 1a62178b96c8..7d7f2eb95b22 100644 --- 
a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -391,6 +391,7 @@ def _verify_quantization(self) -> None: "compressed-tensors", "fbgemm_fp8", "w8a8_fp8", + "petit_nvfp4", ] optimized_quantization_methods = [ "fp8", @@ -408,9 +409,11 @@ def _verify_quantization(self) -> None: "moe_wna16", "qoq", "w4afp8", + "petit_nvfp4", ] compatible_quantization_methods = { "modelopt_fp4": ["modelopt"], + "petit_nvfp4": ["modelopt"], "w8a8_int8": ["compressed-tensors", "compressed_tensors"], "w8a8_fp8": ["compressed-tensors", "compressed_tensors"], } diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 1c770193fccb..07be9a3c6b14 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -53,6 +53,7 @@ "ModelOptFp8LinearMethod", "ModelOptFp4LinearMethod", "IPEXAWQLinearMethod", + "PetitNvFp4LinearMethod", ] _is_cpu = is_cpu() diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 9995b72d0e0b..d51186465a0f 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -58,6 +58,7 @@ def override_quantization_method(self, *args, **kwargs): ModelOptFp8Config, ) from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config +from sglang.srt.layers.quantization.petit import PetitNvFp4Config from sglang.srt.layers.quantization.qoq import QoQConfig from sglang.srt.layers.quantization.utils import get_linear_quant_method from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config @@ -76,6 +77,7 @@ def override_quantization_method(self, *args, **kwargs): "compressed-tensors": CompressedTensorsConfig, "qoq": QoQConfig, "w4afp8": W4AFp8Config, + "petit_nvfp4": PetitNvFp4Config, } # VLLM-dependent quantization methods diff --git a/python/sglang/srt/layers/quantization/petit.py b/python/sglang/srt/layers/quantization/petit.py new file mode 100644 
index 000000000000..e7ee3239f64c --- /dev/null +++ b/python/sglang/srt/layers/quantization/petit.py @@ -0,0 +1,249 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py + + +import logging +from typing import Any, Callable, Dict, List, Optional + +import regex as re +import torch +from torch.nn.parameter import Parameter + +from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod +from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter +from sglang.srt.layers.quantization.base_config import ( + LinearMethodBase, + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.petit_utils import ( + apply_petit_nvfp4_linear, + prepare_nvfp4_layer_for_petit, + verify_petit_nvfp4_supported, +) +from sglang.srt.layers.quantization.utils import is_layer_skipped + +# Initialize logger for the module +logger = logging.getLogger(__name__) + + +# Configuration class to support the NVFP4 quantized model generated by the ModelOpt quantization tool +class PetitNvFp4Config(QuantizationConfig): + """Config class for Petit FP4.""" + + def __init__( + self, + is_checkpoint_nvfp4_serialized: bool = False, + kv_cache_quant_algo: str = None, + group_size: int = None, + exclude_modules: List[str] = None, + ) -> None: + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + if is_checkpoint_nvfp4_serialized: + logger.warning( + "Detected nvfp4 checkpoint. Please note that the " + "format is experimental and subject to change." 
+ ) + self.group_size = group_size + self.kv_cache_quant_algo = kv_cache_quant_algo + self.exclude_modules = exclude_modules + + @classmethod + def get_name(cls) -> str: + return "petit_nvfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # Petit supports the gfx90a and gfx942 GPUs + return 90 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "PetitNvFp4Config": + quant_config = cls.get_from_keys(config, ["quantization"]) + quant_method = quant_config["quant_algo"] + group_size = quant_config.get("group_size", None) + verify_petit_nvfp4_supported(quant_method, group_size) + + is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method + kv_cache_quant_algo = quant_config["kv_cache_quant_algo"] + if not kv_cache_quant_algo: + kv_cache_quant_algo = "auto" + exclude_modules = quant_config.get("exclude_modules", None) + if not (group_size and kv_cache_quant_algo and (exclude_modules is not None)): + logger.warning( + f"group_size: {group_size}," + f"kv_cache_quant_algo: {kv_cache_quant_algo}," + f"exclude_modules: {exclude_modules}" + ) + raise ValueError( + "NVFP4 quantization requires group size and " + "kv_cache_quant_algo specified in " + "hf_quant_config.json" + ) + return cls( + is_checkpoint_nvfp4_serialized, + kv_cache_quant_algo, + group_size, + exclude_modules, + ) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + can_convert = cls.is_petit_nvfp4_compatible(hf_quant_cfg) + if can_convert: + return cls.get_name() + return None + + @classmethod + def is_petit_nvfp4_compatible(cls, quant_config: Dict[str, Any]) -> bool: + quant_method = quant_config.get("quant_method", "").lower() + return quant_method == "modelopt" + + def is_layer_excluded(self, prefix: str, exclude_modules: 
list): + for pattern in exclude_modules: + regex_str = pattern.replace(".", r"\.").replace("*", r".*") + if re.fullmatch(regex_str, prefix): + return True + return False + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded( + prefix, self.exclude_modules + ): + return UnquantizedLinearMethod() + return PetitNvFp4LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class PetitNvFp4LinearMethod(LinearMethodBase): + """Linear method for NVFP4. + Supports loading NVFP4 checkpoints with the following structure: + + |Tensor Name | datatype | shape | + |----------------------------------------------------| + |input_scale | torch.float32 | scalar | + |weight | NVFP4(SE2M1) | [1, X, y/2] | + |weight_scale | FP8-E4M3 | [X, Y] | + |weight_scale_2 | torch.float32 | scalar | + + The weights are quantized per block of 16 elements. + Args: quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: PetitNvFp4Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError( + "NVFP4 quantization was selected, " + " dynamic quantization is not supported." 
+ ) + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + if input_size_per_partition % 16 != 0: + raise ValueError( + "Unsupported model when in features size is " "not multiple of 16" + ) + + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype + ) + + weight = ModelWeightParameter( + data=torch.empty( + # 2 fp4 data is packed in one uint8 in the input dimension + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + layer.register_parameter("input_scale", input_scale) + + weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale_2", weight_scale_2) + + weight_scale = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.quant_config.group_size, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + input_scale_2 = layer.input_scale.max().to(torch.float32) + weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) + layer.input_scale = Parameter(input_scale_2, requires_grad=False) + layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) + layer.alpha = Parameter( + layer.input_scale * layer.weight_scale_2, 
requires_grad=False + ) + + prepare_nvfp4_layer_for_petit(layer) + del layer.input_scale + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_petit_nvfp4_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/python/sglang/srt/layers/quantization/petit_utils.py b/python/sglang/srt/layers/quantization/petit_utils.py new file mode 100644 index 000000000000..529869f2413f --- /dev/null +++ b/python/sglang/srt/layers/quantization/petit_utils.py @@ -0,0 +1,104 @@ +from typing import Optional + +import torch + +try: + from petit_kernel import mul_nvfp4_a16, process_nvfp4_scales, repack_nvfp4 +except ImportError: + + def _check_petit_nvfp4_supported( + quant_method: str, group_size: Optional[int] + ) -> tuple[bool, Optional[str]]: + return ( + False, + "Petit is not installed. Please install it with `pip install petit-kernel`.", + ) + + def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None: + raise ValueError( + "Petit is not installed. Please install it with `pip install petit-kernel`." + ) + + def apply_petit_nvfp4_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise ValueError( + "Petit is not installed. Please install it with `pip install petit-kernel`." + ) + + +def _check_petit_nvfp4_supported( + quant_method: str, group_size: Optional[int] +) -> tuple[bool, Optional[str]]: + if quant_method != "NVFP4": + return ( + False, + "Petit currently only supports: NVFP4" + " quantizations in sglang. 
Please check the " + "`hf_quant_config.json` file for your model's " + "quant configuration.", + ) + if group_size is not None and group_size != 16: + return ( + False, + "Petit currently only supports: group_size=16" " quantizations.", + ) + return (True, None) + + +def verify_petit_nvfp4_supported(quant_method: str, group_size: Optional[int]) -> None: + supported, error_msg = _check_petit_nvfp4_supported(quant_method, group_size) + if not supported: + raise ValueError(error_msg) + + +def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None: + # Repack weights to petit format + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + qweight = layer.weight.view(torch.int32).contiguous() + petit_qweight = repack_nvfp4(qweight, size_n=part_size_n, size_k=part_size_k) + layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False) + + # Permute scales + weight_scale = process_nvfp4_scales( + scales=layer.weight_scale, size_k=part_size_k, size_n=part_size_n + ) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + return + + +def apply_petit_nvfp4_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + # TODO: Use auto-tuning to find the performant solution_id + output = mul_nvfp4_a16( + a=reshaped_x, + b=weight, + s=weight_scale, + global_scale=weight_scale_2, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + solution_id=-1, + ) + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 20db0b4b9c79..4f9e17e05dda 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -766,6 
+766,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "gguf", "modelopt", "modelopt_fp4", + "petit_nvfp4", "w8a8_int8", "w8a8_fp8", "moe_wna16", From b7e951a6dbcd64a1c011f276c57ab84fb7fa76f0 Mon Sep 17 00:00:00 2001 From: Binyao Jiang Date: Fri, 18 Jul 2025 21:03:53 -0700 Subject: [PATCH 041/396] Feat: Support audio in Phi4-mm model (#8048) --- .../multimodal_language_models.md | 2 +- python/sglang/srt/conversation.py | 1 + python/sglang/srt/managers/schedule_batch.py | 4 + python/sglang/srt/models/phi4mm.py | 41 +- python/sglang/srt/models/phi4mm_audio.py | 1260 +++++++++++ python/sglang/srt/models/phi4mm_utils.py | 1917 +++++++++++++++++ .../multimodal/processors/base_processor.py | 14 +- .../srt/multimodal/processors/phi4mm.py | 95 +- python/sglang/srt/utils.py | 7 +- test/srt/test_vision_openai_server_b.py | 22 +- test/srt/test_vision_openai_server_common.py | 22 +- 11 files changed, 3332 insertions(+), 53 deletions(-) create mode 100644 python/sglang/srt/models/phi4mm_audio.py create mode 100644 python/sglang/srt/models/phi4mm_utils.py diff --git a/docs/supported_models/multimodal_language_models.md b/docs/supported_models/multimodal_language_models.md index 665d8de7ed7d..66de3d8a1c15 100644 --- a/docs/supported_models/multimodal_language_models.md +++ b/docs/supported_models/multimodal_language_models.md @@ -37,5 +37,5 @@ in the GitHub search bar. | **Gemma 3 (Multimodal)** | `google/gemma-3-4b-it` | `gemma-it` | Gemma 3's larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context. | | **Kimi-VL** (A3B) | `moonshotai/Kimi-VL-A3B-Instruct` | `kimi-vl` | Kimi-VL is a multimodal model that can understand and generate text from images. | | **Mistral-Small-3.1-24B** | `mistralai/Mistral-Small-3.1-24B-Instruct-2503` | `mistral` | Mistral 3.1 is a multimodal model that can generate text from text or images input. It also supports tool calling and structured output. 
| -| **Phi-4-multimodal-instruct** | `microsoft/Phi-4-multimodal-instruct` | `phi-4-mm` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. Currently, it supports only text and vision modalities in SGLang. | +| **Phi-4-multimodal-instruct** | `microsoft/Phi-4-multimodal-instruct` | `phi-4-mm` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. It supports text, vision and audio modalities in SGLang. | | **MiMo-VL** (7B) | `XiaomiMiMo/MiMo-VL-7B-RL` | `mimo-vl` | Xiaomi's compact yet powerful vision-language model featuring a native resolution ViT encoder for fine-grained visual details, an MLP projector for cross-modal alignment, and the MiMo-7B language model optimized for complex reasoning tasks. | diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index c085c4423af6..cb4bdbc44a0c 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -729,6 +729,7 @@ def generate_chat_conv( sep="<|end|>", stop_str="<|end|>", image_token="<|endoftext10|>", + audio_token="<|endoftext11|>", ) ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 01da558b7bf9..a9ed66f9aa3d 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -239,6 +239,10 @@ class MultimodalDataItem: # For gemma3n input_features_mask: Optional[torch.Tensor] = None + # For phi4-mm + image_attention_mask: Optional[torch.Tensor] = None + audio_attention_mask: Optional[torch.Tensor] = None + @staticmethod def is_empty_list(l): if l is None: diff --git a/python/sglang/srt/models/phi4mm.py b/python/sglang/srt/models/phi4mm.py index 8a74888ac9c5..b7997fc0acae 100644 --- a/python/sglang/srt/models/phi4mm.py +++ b/python/sglang/srt/models/phi4mm.py @@ -40,6 +40,7 @@ from 
sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.idefics2 import Idefics2VisionTransformer from sglang.srt.models.llama import LlamaForCausalLM +from sglang.srt.models.phi4mm_audio import AudioEmbedding logger = logging.getLogger(__name__) @@ -420,16 +421,49 @@ def __init__( model_dir=config._name_or_path, ) + if isinstance(config.embd_layer["audio_embd_layer"], dict): + embedding_config = { + "embedding_cls": config.embd_layer["audio_embd_layer"]["embedding_cls"], + **config.embd_layer["audio_embd_layer"], + } + else: + embedding_config = {"embedding_cls": config.embd_layer["embedding_cls"]} + + self.embed_tokens_extend = AudioEmbedding(config, **embedding_config) + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: dtype = next(self.vision_encoder.parameters()).dtype pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype) - image_attention_mask = torch.cat([item.image_emb_mask for item in items], dim=0) + image_attention_mask = torch.cat( + [item.image_attention_mask for item in items], dim=0 + ) image_sizes = torch.cat([item.image_sizes for item in items], dim=0) image_embeds = self.vision_encoder( pixel_values, image_sizes, image_attention_mask ) return torch.cat(image_embeds).type(dtype) + def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + # (e.g. multiple examples) and the second dim is the multi-audio dim + # (e.g. 
multiple audios in the same example) + embed_tokens_extend_param = next(self.embed_tokens_extend.parameters()) + device = embed_tokens_extend_param.device + dtype = embed_tokens_extend_param.dtype + audio_embeds = [ + self.embed_tokens_extend( + # item.feature: (num_audios_in_a_sequence, T, D) + # item.audio_attention_mask: (num_audios_in_a_sequence, T, D) BoolTensor or None + audio_features=item.feature.to(device).type(dtype), + audio_attention_mask=( + item.audio_attention_mask.to(device) + if item.audio_attention_mask is not None + else None + ), + ) + for item in items + ] + return torch.cat(audio_embeds).type(dtype) + def forward( self, input_ids: torch.Tensor, @@ -443,6 +477,7 @@ def forward( language_model=self.language_model, data_embedding_funcs={ Modality.IMAGE: self.get_image_feature, + Modality.AUDIO: self.get_audio_feature, }, positions=positions, ) @@ -464,6 +499,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), ] prefix_mapping = { + "model.embed_tokens_extend.audio_embed.audio_projection.vision.": "embed_tokens_extend.audio_projection_for_vision.", + "model.embed_tokens_extend.audio_embed.audio_projection.speech.": "embed_tokens_extend.audio_projection.", + "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.", "model.embed_tokens_extend.image_embed.": "vision_encoder.", "model.": "language_model.model.", } @@ -472,7 +510,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): "img_processor.encoder.layers.26", "img_processor.head", "img_processor.post_layernorm", - "audio", ] def _should_skip(name: str) -> bool: diff --git a/python/sglang/srt/models/phi4mm_audio.py b/python/sglang/srt/models/phi4mm_audio.py new file mode 100644 index 000000000000..fd199836e9a9 --- /dev/null +++ b/python/sglang/srt/models/phi4mm_audio.py @@ -0,0 +1,1260 @@ +# Copyright 2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may 
not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +#!/usr/bin/env python3 +import abc +import math +from typing import Literal, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel +from transformers import PretrainedConfig + +from sglang.srt.models.phi4mm_utils import ( + AbsolutePositionalEncoding, + ConvModule, + FeedForward, + MeanVarianceNormLayer, + MultiHeadedAttention, + MultiSequential, + NemoConvSubsampling, + T5RelativeAttentionLogitBias, + adaptive_enc_mask, + get_offset, + unfold_tensor, +) + +_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> + + +class ConformerEncoderLayer(nn.Module): + """ConformerEncoder Layer module. + for more details see conformer paper: + https://arxiv.org/abs/2005.08100 + This module implement the Conformer block layer. + + Args: + d_model: int + attention dim. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of + depthwise_seperable_out_channel will be used as a + channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. 
+ depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + n_head: int + the number of heads for multihead attention module. + d_ffn: int + output size of the feed_forward blocks. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + activation: str, optional + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "relu". + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + chunk_size: int, optional + chunk_size for cnn. default 18 + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, optional + activation function used for the glu inside + the ConvModule part of the conformer. + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_inner_dim: int, optional + if equal to -1, attention dim for linears k/q/v is + equal to d_model. otherwise attention_inner_dim is used. + default -1. + attention_glu_type: str, optional + activation function for glu used in the multihead attention, + default "swish". 
+ activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + use_pt_scaled_dot_product_attention: bool, optional + if set to True, use pytorch's scaled dot product attention + implementation in training. + attn_group_sizes: int, optional + the number of groups to use for attention, default 1 + (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attn_group_sizes < attention_heads = Grouped-Query Attention + attn_group_sizes = attention_heads = Multi-Query Attention + """ + + def __init__( + self, + d_model=512, + ext_pw_out_channel=0, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + n_head=4, + d_ffn=2048, + ext_pw_kernel_size=1, + kernel_size=3, + dropout_rate=0.1, + causal=False, + batch_norm=False, + activation="relu", + chunk_se=0, + chunk_size=18, + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_inner_dim=-1, + attention_glu_type="swish", + activation_checkpointing="", + export=False, + use_pt_scaled_dot_product_attention=False, + attn_group_sizes: int = 1, + ): + super().__init__() + + self.feed_forward_in = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.self_attn = MultiHeadedAttention( 
+ n_head, + d_model, + dropout_rate, + attention_inner_dim, + attention_glu_type, + bias_in_glu, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + group_size=attn_group_sizes, + ) + self.conv = ConvModule( + d_model, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal, + batch_norm, + chunk_se, + chunk_size, + conv_activation, + conv_glu_type, + bias_in_glu, + linear_glu_in_convm, + export=export, + ) + + self.feed_forward_out = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.layer_norm_att = nn.LayerNorm(d_model) + self.layer_norm = nn.LayerNorm(d_model) + + def forward( + self, + x, + pos_k, + pos_v, + mask, + relative_attention_bias: Optional[Tensor] = None, + ): + """ConformerEncoder forward. + + Args: + x: torch.Tensor + input feature of shape (batch, max_time_in, size) + pos_k: torch.Tensor + positional key embedding. + mask: torch.Tensor + mask for x (batch, max_time_in) + relative_attention_bias: Optional[torch.Tensor] + bias added to attention logits w.r.t. relative positions + (1, n_head, time1, time2) + """ + x = x + 0.5 * self.feed_forward_in(x) + norm_x = self.layer_norm_att(x) + + x = x + self.self_attn( + norm_x, + norm_x, + norm_x, + pos_k, + pos_v, + mask, + relative_attention_bias=relative_attention_bias, + ) + x = x + self.conv(x) + x = x + 0.5 * self.feed_forward_out(x) + + out = self.layer_norm(x) + + return out, pos_k, pos_v, mask + + +class TransformerEncoderBase(abc.ABC, nn.Module): + """The Base class for Transformer based encoders + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. 
+ chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + time_reduction: int, optional + time reduction factor + default 4 + dropout_rate: float, optional + dropout rate. default 0.1 + padding_idx: int, optional + padding index for input_layer=embed + default -1 + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention + (Q*K^T + B) implemented in cmb.basics.embedding. + [T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see + transformer_base.py) + positional_dropout_rate: float, optional + dropout rate after positional encoding. default 0.0 + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default None + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. 
Choices are + (feat, feat_time, none, True). + if True or feat_time, the extra padding is added into non full + supraframe utts in batch. + Default: none + attention_group_size: int, optional + the number of groups to use for attention, default 1 + (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query + Attention + attention_group_size = attention_heads = Multi-Query Attention + """ + + def __init__( + self, + input_size, + chunk_size, + left_chunk, + attention_dim=256, + attention_heads=4, + input_layer="nemo_conv", + cnn_out=-1, + cnn_layer_norm=False, + time_reduction=4, + dropout_rate=0.0, + padding_idx=-1, + relative_attention_bias_args=None, + positional_dropout_rate=0.0, + nemo_conv_settings=None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__() + self.input_size = input_size + self.input_layer = input_layer + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.attention_dim = attention_dim + self.num_heads = attention_heads + self.attention_group_size = attention_group_size + self.time_reduction = time_reduction + self.nemo_conv_settings = nemo_conv_settings + self.encoder_embedding_config = encoder_embedding_config + + if self.input_layer == "nemo_conv": + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.time_reduction, + "feat_in": input_size, + "feat_out": attention_dim, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.embed = 
NemoConvSubsampling( + **default_nemo_conv_settings, + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + + self.pos_emb = AbsolutePositionalEncoding( + attention_dim, positional_dropout_rate + ) + + self.relative_attention_bias_type = ( + relative_attention_bias_args.get("type") + if relative_attention_bias_args + else None + ) + if self.relative_attention_bias_type == "t5": + assert ( + self.num_heads % self.attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( + self.num_heads // self.attention_group_size, + max_distance=relative_attention_bias_args.get( + "t5_bias_max_distance", 1000 + ), + symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False), + ) + else: + raise NotImplementedError + + self.encoder_embedding = MeanVarianceNormLayer( + self.encoder_embedding_config["input_size"] + ) + + def compute_lens_change(self, feature_lens): + """feature_lens: int + return updated feature lens. + + This used to return a different lambda function for each case that + computed the right thing. That does not work within Torchscript. + If you really need this to be faster, create nn.Module()-s for all + the cases and return one of them. Torchscript does support that. 
+ """ + if self.input_layer == "nemo_conv": + # Handle the special causal case + subsampling_causal_cond = self.nemo_conv_settings.get( + "subsampling", "dw_striding" + ) in [ + "dw_striding", + "striding", + "striding_conv1d", + ] + is_causal = self.nemo_conv_settings.get("is_causal", False) + if is_causal and subsampling_causal_cond: + lens_change = ( + torch.ceil(feature_lens / self.time_reduction).long() + if isinstance(feature_lens, Tensor) + else math.ceil(feature_lens / self.time_reduction) + ) + feature_lens_remainder = feature_lens % self.time_reduction + if isinstance(feature_lens, Tensor): + lens_change[feature_lens_remainder != 1] += 1 + elif feature_lens_remainder != 1: + lens_change += 1 + return lens_change + ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil + return ceil_func(feature_lens / self.time_reduction) + + @abc.abstractmethod + def forward(self): + """Abstract forward method implementation.""" + + def _chunk_size_selection(self, chunk_size=None, left_chunk=None): + """If chunk size is a list, we will randomly select a chunk size.""" + + if chunk_size is None: + chunk_size = self.chunk_size + if left_chunk is None: + left_chunk = self.left_chunk + if isinstance(chunk_size, list): + # Variable chunk size during training + chunk_size_index = int( + torch.randint(low=0, high=len(chunk_size), size=(1,)) + ) + chunk_size_train_eff = chunk_size[chunk_size_index] + if not isinstance(left_chunk, list): + raise ValueError( + "Since chunk_size is a list, left_chunk must be a list" + ) + if len(left_chunk) != len(chunk_size): + raise ValueError( + "The length of left_chunk must be the same as length of " + "chunk_size." 
+ ) + left_chunk_train_eff = left_chunk[chunk_size_index] + else: + chunk_size_train_eff = chunk_size + left_chunk_train_eff = left_chunk + + return chunk_size_train_eff, left_chunk_train_eff + + def _get_embed_class(self, embed): + # pylint: disable=protected-access + is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) + is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) + embed_class = embed + if is_embed_using_act_chkpt: + embed_class = embed._checkpoint_wrapped_module + if is_embed_fsdp_wrapped: + embed_class = embed.module + return embed_class + + def _forward_embeddings_core(self, input_tensor, masks): + embed_class = self._get_embed_class(self.embed) + assert isinstance(embed_class, NemoConvSubsampling) + input_tensor, masks = self.embed(input_tensor, masks) + return input_tensor, masks + + def _position_embedding(self, input_tensor): + pos_k = None + pos_v = None + if self.relative_attention_bias_layer is None: + input_tensor = self.pos_emb( + input_tensor + ) # default to add abs sinusoid embedding + return pos_k, pos_v + + def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): + chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( + chunk_size, left_chunk + ) + + # Create mask matrix for streaming + # S stores start index. if chunksize is 18, s is [0,18,36,....] 
+ chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) + + enc_streaming_mask = ( + adaptive_enc_mask( + seq_len, chunk_start_idx, left_window=left_chunk_train_eff + ) + .unsqueeze(0) + .expand([batch_size, -1, -1]) + ) + return enc_streaming_mask + + def forward_embeddings(self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None): + """Forwarding the inputs through the top embedding layers + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + input mask + chunk_size_nc: (optional, default is None) chunk size for + non-causal layers + left_chunk_nc: (optional, default is None) # of left chunks for + non-causal layers + """ + # pylint: disable=R0915 + # get new lens. + seq_len = int(self.compute_lens_change(xs_pad.shape[1])) + if seq_len <= 0: + raise ValueError( + f"""The sequence length after time reduction is invalid: + {seq_len}. Your input feature is too short. Consider + filtering out the very short sentence from data + loader""", + ) + + batch_size = xs_pad.shape[0] + + enc_streaming_mask = self._streaming_mask( + seq_len, batch_size, self.chunk_size, self.left_chunk + ) + + if xs_pad.is_cuda: + enc_streaming_mask = enc_streaming_mask.cuda() + xs_pad = xs_pad.cuda() + + input_tensor = xs_pad + input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) + + streaming_mask = enc_streaming_mask + if streaming_mask is not None and masks is not None: + hs_mask = masks & streaming_mask + elif masks is not None: + hs_mask = masks + else: + hs_mask = streaming_mask + + if chunk_size_nc is not None: + enc_streaming_mask_nc = self._streaming_mask( + seq_len, batch_size, chunk_size_nc, left_chunk_nc + ) + if xs_pad.is_cuda: + enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() + if masks is not None: + hs_mask_nc = masks & enc_streaming_mask_nc + else: + hs_mask_nc = enc_streaming_mask_nc + else: + hs_mask_nc = None + + pos_k, pos_v = self._position_embedding(input_tensor) + + if chunk_size_nc is None: + return input_tensor, 
pos_k, pos_v, hs_mask, masks + return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc + + def get_offset(self): + """Returns offset used when retaining inputs for decoding. + + This is essentially, how many additional frames have to be added to + the front-end CNN input to ensure it can produce a single output. + So if the "padding" parameter is 0, typically offset will be > 0. + """ + return get_offset(self.input_layer, self.time_reduction) + + +class ConformerEncoder(TransformerEncoderBase): + """ConformerEncoder module. + see original paper for more details: + https://arxiv.org/abs/2005.08100 + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + left_chunk: int + number of chunks used for masking in streaming mode. + num_lang: int + This parameter is used to store the number of languages in the + lang_dict, only used for multiseed/multilingual models. + default None. + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + linear_units: + the number of units of position-wise feed forward. + default 2048 + num_block: + number of Transformer layer. default 6 + dropout_rate: float, optional + dropout rate. 
default 0.1 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + ext_pw_out_channel: int, optional + the number of channel for CNN + before depthwise_seperable_CNN. + If 0 then use linear. default 0. + ext_pw_kernel_size: int, optional + kernel size of N before depthwise_seperable_CNN. + only work for ext_pw_out_channel > 0. + default 1 + depthwise_seperable_out_channel: int, optional + the number of channel for + depthwise_seperable_CNN. + default 256. + depthwise_multiplier: int, optional + the number of multiplier for + depthwise_seperable_CNN. + default 1. + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + kernel_size: int, optional + the number of kernels for depthwise_seperable_CNN. + default 3. + activation: str, optional + FeedForward block activation. + one of ["relu", "swish", "sigmoid"] + default "relu". + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, optional + activation used use glu in depthwise_seperable_CNN, + default "sigmoid" + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. default True + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. 
+ default to False. + attention_glu_type: str + only work for glu_in_attention !=0 + default "swish". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + extra_layer_output_idx: int + the layer index to be exposed. + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention + (Q*K^T + B) implemented in cmb.basics.embedding. + [T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see + transformer_base.py) + time_reduction: int optional + time reduction factor + default 4 + use_pt_scaled_dot_product_attention: whether to use pytorch scaled + dot product attention in training. + Default: False + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default: None + usage: nemo_conv_settings= + { + "subsampling": + dw_striding/striding/dw_striding_conv1d/striding_conv1d, + "conv_channels": int, + "subsampling_conv_chunking_factor": int, + "is_causal": True/False + } + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. 
Choices are + (feat, feat_time, none, True) + Default: none + replication_pad_for_subsample_embedding: For batched-streaming + decoding, use "replication" padding for the cache at start of + utterance. + Default: False + attention_group_size: int, optional + the number of groups to use for attention, default 1 + (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query + Attention + attention_group_size = attention_heads = Multi-Query Attention + """ + + extra_multi_layer_output_idxs: list[int] + + def __init__( # pylint: disable-all + self, + input_size, + chunk_size, + left_chunk, + num_lang=None, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + input_layer="nemo_conv", + causal=True, + batch_norm=False, + cnn_out=-1, + cnn_layer_norm=False, + ext_pw_out_channel=0, + ext_pw_kernel_size=1, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + chunk_se=0, + kernel_size=3, + activation="relu", + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_glu_type="swish", + export=False, + extra_layer_output_idx=-1, + extra_multi_layer_output_idxs=[], # noqa + activation_checkpointing="", + relative_attention_bias_args=None, + time_reduction=4, + use_pt_scaled_dot_product_attention=False, + nemo_conv_settings=None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", + replication_pad_for_subsample_embedding=False, + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__( + input_size, + chunk_size, + left_chunk, + attention_dim, + attention_heads, + input_layer, + cnn_out, + cnn_layer_norm, + time_reduction, + dropout_rate=dropout_rate, + relative_attention_bias_args=relative_attention_bias_args, + positional_dropout_rate=0.0, + nemo_conv_settings=nemo_conv_settings, + conv2d_extra_padding=conv2d_extra_padding, + 
attention_group_size=attention_group_size, + encoder_embedding_config=encoder_embedding_config, + ) + self.num_blocks = num_blocks + self.num_lang = num_lang + self.kernel_size = kernel_size + self.replication_pad_for_subsample_embedding: bool = ( + replication_pad_for_subsample_embedding + ) + assert ( + self.num_heads % attention_group_size == 0 + ), "attention_group_size must divide n_head" + self.num_heads_k = self.num_heads // attention_group_size + + self.encoders = MultiSequential( + *[ + ConformerEncoderLayer( + d_model=attention_dim, + ext_pw_out_channel=ext_pw_out_channel, + depthwise_seperable_out_channel=depthwise_seperable_out_channel, + depthwise_multiplier=depthwise_multiplier, + n_head=attention_heads, + d_ffn=linear_units, + ext_pw_kernel_size=ext_pw_kernel_size, + kernel_size=kernel_size, + dropout_rate=dropout_rate, + causal=causal, + batch_norm=batch_norm, + activation=activation, + chunk_se=chunk_se, + chunk_size=chunk_size, + conv_activation=conv_activation, + conv_glu_type=conv_glu_type, + bias_in_glu=bias_in_glu, + linear_glu_in_convm=linear_glu_in_convm, + attention_glu_type=attention_glu_type, + activation_checkpointing=activation_checkpointing, + export=export, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + attn_group_sizes=attention_group_size, + ) + for _ in range(num_blocks) + ] + ) + self.extra_layer_output_idx = extra_layer_output_idx + self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs + # Make a zeros scalar we can use in get_initial_state to determine + # the device and the needed dtype: + self.register_buffer("dev_type", torch.zeros(()), persistent=False) + + def init_relative_attention_bias(self, input_tensor): + if self.relative_attention_bias_layer: + return self.relative_attention_bias_layer(input_tensor) + + def calculate_hs_mask(self, xs_pad, device, mask): + max_audio_length = xs_pad.shape[1] + batch_size = xs_pad.shape[0] + enc_streaming_mask = self._streaming_mask( + 
max_audio_length, batch_size, self.chunk_size, self.left_chunk + ) + enc_streaming_mask = enc_streaming_mask.to(device) + if mask is None: + return enc_streaming_mask + + feature_lens = mask.sum(1) + padding_length = feature_lens + pad_mask = torch.arange(0, max_audio_length, device=device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) + pad_mask = pad_mask.unsqueeze(1) + pad_mask = pad_mask & enc_streaming_mask + return pad_mask + + @torch.jit.ignore + def forward(self, xs_pad, masks): + """Conformer Forward function + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + post-embedding input lengths + """ + xs_pad = self.encoder_embedding(xs_pad) + input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings( + xs_pad, masks + ) + + unfolded = False + ori_bz, seq_len, D = input_tensor.shape + max_seq_len = 500 # maximum position for absolute positional encoding + if seq_len > max_seq_len: + # audio sequence is longer than max_seq_len, unfold it into chunks + # of max_seq_len + unfolded = True + # the unfold op will drop residual frames, pad it to the multiple + # of max_seq_len + if seq_len % max_seq_len > 0: + chunk_pad_size = max_seq_len - (seq_len % max_seq_len) + else: + chunk_pad_size = 0 + if chunk_pad_size > 0: + input_tensor_pad = F.pad( + input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0 + ) + input_tensor = input_tensor_pad.to(input_tensor.device) + input_tensor = unfold_tensor(input_tensor, max_seq_len) + if masks is not None: + # revise hs_mask here because the previous calculated hs_mask + # did not consider extra pad + subsampled_pad_mask = masks.squeeze( + 1 + ) # [bz, subsampled_unmask_seq_len] + extra_padded_subsamlped_pad_mask = F.pad( + subsampled_pad_mask, (0, chunk_pad_size), "constant", False + ) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = ( + extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + ) + masks_unfold = unfold_tensor( + 
extra_padded_subsamlped_pad_mask, max_seq_len + ) # unfold the pad mask like we did to the input tensor + masks_unfold = masks_unfold.squeeze( + -1 + ).bool() # unfold op does not support bool tensor + else: + masks_unfold = None + hs_mask = self.calculate_hs_mask( + input_tensor, input_tensor.device, masks_unfold + ) # calculate hs_mask based on the unfolded pad mask + + # layer_emb = None + + relative_attention_bias = self.init_relative_attention_bias(input_tensor) + + _simplified_path = ( + self.extra_layer_output_idx == -1 and relative_attention_bias is None + ) + + if _simplified_path: + input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask) + else: + for i, layer in enumerate(self.encoders): + input_tensor, _, _, _ = layer( + input_tensor, + pos_k, + pos_v, + hs_mask, + relative_attention_bias=relative_attention_bias, + ) + + # if i == self.extra_layer_output_idx: + # layer_emb = input_tensor + + if unfolded: + embed_dim = input_tensor.shape[-1] + input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim) + # if we ever padded before unfolding, we need to remove the padding + if chunk_pad_size > 0: + input_tensor = input_tensor[:, :-chunk_pad_size, :] + + return input_tensor, masks # , layer_emb + + +class WindowQformer(nn.Module): + """Window-level Qformer""" + + def __init__( + self, + window_size: int = 8, + num_queries: int = 1, + num_blocks: int = 2, + attention_dim: int = 512, + attention_heads: int = 8, + linear_units: int = 2048, + dropout_rate: float = 0.0, + normalize_before: bool = True, + ): + super().__init__() + + self.decoders = nn.ModuleList( + [ + nn.TransformerDecoderLayer( + d_model=attention_dim, + nhead=attention_heads, + dim_feedforward=linear_units, + dropout=dropout_rate, + activation="relu", + batch_first=True, + norm_first=normalize_before, # TODO need to verify + ) + for _ in range(num_blocks) + ] + ) + + self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim)) + self.after_norm = ( + 
nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None + ) + self.window_size = window_size + + def forward(self, audio_embed, mask, embed_len=None): + """forward decoder""" + # audio_embed: N x T x D => N x D x T + + audio_embed = audio_embed.transpose(1, 2) + # audio_embed: N x D x 1 x T => N x DK x T' + padding = audio_embed.shape[-1] % self.window_size + if padding > 0: + audio_embed = F.pad( + audio_embed, (0, self.window_size - padding), "constant", 0 + ) + + embed_chunk = F.unfold( + audio_embed[..., None, :], + kernel_size=(1, self.window_size), + stride=(1, self.window_size), + ) + bsz, _, slen = embed_chunk.shape + # N x D x K x T' + embed_chunk = embed_chunk.view(bsz, -1, self.window_size, slen) + # N x T' x K x D + embed_chunk = embed_chunk.transpose(1, 3).contiguous() + # NT' x K x D + embed_chunk = embed_chunk.view(bsz * slen, self.window_size, -1) + # NT' x 1 x D + q = self.queries.expand(bsz * slen, -1, -1) + for layer in self.decoders: + q = layer(tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask) + + if self.after_norm is not None: + q = self.after_norm(q) + + if embed_len is not None: + embed_len = embed_len // self.window_size + # N x T' x D + out = q.view(bsz, slen, -1) + + return out, embed_len + + +class AudioEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: PretrainedConfig, **kwargs) -> None: + super().__init__() + self.config = config + # n_embed or hidden_size for text LM + hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size + + # self.wte = nn.Embedding(config.vocab_size, hidden_size) + + audio_dim_out = ( + None # Set this variable according to the actual audio processor + ) + self.layer_idx = -2 + + if ( + isinstance(config.audio_processor, dict) + and config.audio_processor.get("name", None) == "cascades" + ): + encoder_config = config.audio_processor.get("config", None) + assert encoder_config is not None + self.encoder = 
ConformerEncoder(**encoder_config) + + audio_dim_out = encoder_config["attention_dim"] + n_mels = encoder_config["input_size"] + else: + raise NotImplementedError("") + + assert audio_dim_out is not None, "Remember to set values for audio_dim_out" + self.audio_dim_out = audio_dim_out + self.audio_dim_in = n_mels + + self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False) + + self.downsample_rate = kwargs.get("downsample_rate", 1) + + if kwargs.get("use_qformer", False): + qformer_config = kwargs.get("qformer_config", {}) + qformer_config["attention_dim"] = audio_dim_out + self.qformer = WindowQformer(**qformer_config) + else: + self.qformer = None + + if kwargs.get("use_conv_downsample", False): + assert ( + self.qformer is None + ), "don't support use qformer and conv downsample together" + nemo_conv_settings = kwargs.get("nemo_conv_settings", {}) + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.downsample_rate, + "feat_in": audio_dim_out, + "feat_out": audio_dim_out, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert ( + i not in nemo_conv_settings + ), "{i} should be specified outside of the NeMo dictionary" + + self.conv_ds = NemoConvSubsampling( + **default_nemo_conv_settings, + ) + else: + self.conv_ds = None + + projection_cls = kwargs.get("projection_cls", "linear") + if projection_cls == "linear": + self.audio_projection = nn.Linear(audio_dim_out, hidden_size) + elif projection_cls == "mlp": + # follow llava-v1.5's implementation + # (do not use image_projection and image_proj_norm) + dim_projection = hidden_size + depth = 2 + self.linear_downsample_rate = ( + 1 if (self.qformer or self.conv_ds) else 
self.downsample_rate + ) + layers = [ + nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection) + ] + for _ in range(1, depth): + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + self.audio_projection = nn.Sequential(*layers) + # NOTE vision-speech tasks use a separate projection layer + layers = [ + nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection) + ] + for _ in range(1, depth): + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + self.audio_projection_for_vision = nn.Sequential(*layers) + else: + raise NotImplementedError( + f"projection_cls = {projection_cls}, not implemented" + ) + + # TODO: audio sequence compression - Qformer + self.vocab_size = config.vocab_size + self.input_embeds = None + self.audio_embed_sizes = None + + def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: + self.input_embeds = input_embeds + + def set_audio_embed_sizes(self, audio_embed_sizes: torch.LongTensor) -> None: + self.audio_embed_sizes = audio_embed_sizes + + def get_audio_features( + self, + input_embeds: torch.FloatTensor, + audio_attention_mask: torch.Tensor = None, + audio_projection_mode: str = "speech", + ) -> torch.FloatTensor: + """ + arguments: + input_embeds: audio features (B, T, D) B: num audios in a sequence + """ + if self.freeze_audio_processor: + with torch.no_grad(): + audio_features, masks = self.encoder(input_embeds, audio_attention_mask) + else: + audio_features, masks = self.encoder(input_embeds, audio_attention_mask) + + if self.qformer is not None: + audio_features, _ = self.qformer(audio_features, mask=None) + + if self.conv_ds is not None: + if masks is not None: + masks = masks.squeeze(1) + + audio_features, masks = self.conv_ds(audio_features, mask=masks) + + if self.linear_downsample_rate != 1: + bs, seq_len, feat_dim = audio_features.size() + padding = seq_len % self.linear_downsample_rate + if padding > 0: + audio_features = F.pad( + audio_features, + 
(0, 0, 0, self.linear_downsample_rate - padding), + "constant", + 0, + ) + + seq_len = audio_features.size(1) + audio_features = audio_features.view( + bs, + seq_len // self.linear_downsample_rate, + feat_dim * self.linear_downsample_rate, + ) + + if audio_projection_mode == "speech": + audio_set_tensor = self.audio_projection(audio_features) + elif audio_projection_mode == "vision": + audio_set_tensor = self.audio_projection_for_vision(audio_features) + else: + raise ValueError( + f"audio_projection_mode = {audio_projection_mode} not " "implemented" + ) + + return audio_set_tensor + + def forward( + self, + audio_features: torch.FloatTensor, + audio_attention_mask: torch.Tensor = None, + audio_projection_mode: str = "speech", + ) -> torch.FloatTensor: + """ + arguments: + audio_features: audio features (num_audio_tokens, T, D) + + returns: + audio_embeds: audio embeddings (num_audio_tokens, hidden_dim) + """ + audio_embeds = self.get_audio_features( + audio_features, + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + ) + return audio_embeds diff --git a/python/sglang/srt/models/phi4mm_utils.py b/python/sglang/srt/models/phi4mm_utils.py new file mode 100644 index 000000000000..e6bf35ebfc46 --- /dev/null +++ b/python/sglang/srt/models/phi4mm_utils.py @@ -0,0 +1,1917 @@ +# Copyright 2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +#!/usr/bin/env python3 +import math +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class BlockBase(nn.Module): + """Block abstract module""" + + def __init__(self, input_size, output_size): + super().__init__() + self.input_size = input_size + self.output_size = output_size + + +def get_activation(name="relu"): + """Select an activation function by name + + Args: + name: str + activation function name, + one of ["relu", "gelu", "swish", "sigmoid"], + default "relu". + """ + name = name.lower() + if name == "relu": + return nn.ReLU(inplace=True) + if name == "gelu": + return nn.GELU() + if name == "swish": + return Swish() + if name == "sigmoid": + return torch.nn.Sigmoid() + return nn.Identity() + + +def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): + """ + The function is very important for Transformer Transducer Streaming mode + Args: + xs_len (int): sequence length + chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. + It also supports adaptive chunk size [0,10,15,45] + left_window (int): how many left chunks can be seen + right_window (int): how many right chunks can be seen. It is used for + chunk overlap model. + Returns: + mask (torch.Tensor): a mask tensor for streaming model + Torch 1.0.1 + tensor([[1., 1., 0., 0.], + [0., 1., 1., 0.], + [0., 0., 1., 1.]]) + Torch 1.4.1 + tensor([[True., True., False., False.], + [False., True., True., False.], + [False., False., True., True.]]) + """ + chunk_start_idx = torch.Tensor( + chunk_start_idx + ).long() # first idx of each chunk, such as [0,18,36,48]. 
+ start_pad = torch.nn.functional.pad( + chunk_start_idx, (1, 0) + ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + end_pad = torch.nn.functional.pad( + chunk_start_idx, (0, 1), value=x_len + ) # append x_len to the end, so it becomes [0,18,36,48, x_len] + seq_range = torch.arange(0, x_len).unsqueeze(-1) # seq_range size: [x_len, 1] + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[ + :, 1 + ] # idx size: [x_len] + # boundary = end_pad[idx] # boundary size: [x_len] + seq_range_expand = ( + torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + ) # seq_range_expand size [x_len, x_len] + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + + +class Swish(nn.Module): + """Implement Swish activation module. + From https://arxiv.org/pdf/2005.03191.pdf + + """ + + def __init__(self) -> None: + super().__init__() + self.act_fn = nn.Sigmoid() + + def forward(self, x: Tensor) -> Tensor: + """Apply Swish function + + Args: + x: torch.Tensor + Input. 
+ """ + return x * self.act_fn(x) + + +class GLU(nn.Module): + """Implement Gated Linear Unit (GLU) module""" + + def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None: + super().__init__() + self.dim = dim + self.act_name = act_name.lower() + + if self.act_name == "relu": + self.act_fn = nn.ReLU(inplace=True) + elif self.act_name == "gelu": + self.act_fn = nn.GELU() + elif self.act_name == "swish": + self.act_fn = Swish() + elif self.act_name == "sigmoid": + self.act_fn = nn.Sigmoid() + else: + self.act_fn = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """GLU forward + Apply Swish function on the first half of input matrices + with sigmoid of the second half. + + Args: + x: torch.Tensor + Input. + + """ + half_x, gate = x.chunk(2, dim=self.dim) + return half_x * self.act_fn(gate) + + +# TODO: Abdel, this can be improved using GLU module +class GLUPointWiseConv(nn.Module): + """GLUPointWiseConv module + used for conformer architecture, + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + output_dim: int + output channel size. + kernel_size: int + kernel size + glu_type: str, optional + activation function one of + ["sigmoid", "relu", "gelu"] + default "sigmoid". + bias_in_glu: bool, optional + use addtive bias in glu + causal: bool, optional + if set to True, padding is set to the half of + kernel size, ie, convolution can't see future frames. + default False. 
+ + """ + + def __init__( + self, + input_dim, + output_dim, + kernel_size, + glu_type="sigmoid", + bias_in_glu=True, + causal=False, + ): + super().__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + self.bias_in_glu = bias_in_glu + if causal: + self.ext_pw_conv_1d = nn.Conv1d( + input_dim, + output_dim * 2, + kernel_size, + 1, + padding=(kernel_size - 1), + ) + else: + self.ext_pw_conv_1d = nn.Conv1d( + input_dim, + output_dim * 2, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ) + + if glu_type == "sigmoid": + self.glu_act = nn.Sigmoid() + elif glu_type == "relu": + self.glu_act = nn.ReLU() + elif glu_type == "gelu": + self.glu_act = nn.GELU() + elif glu_type == "swish": + self.glu_act = Swish() + else: + raise ValueError(f"Unsupported activation type {self.glu_act}") + + if bias_in_glu: + self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) + self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) + + def forward(self, x): + """ + Args: + x: torch.Tensor + input tensor + """ + # to be consistent with GLULinear, we assume the input always has the + # #channel (#dim) in the last dimension of the tensor, so need to + # switch the dimension first for 1D-Conv case + x = x.permute([0, 2, 1]) + x = self.ext_pw_conv_1d(x) + if self.glu_type == "bilinear": + if self.bias_in_glu: + x = (x[:, 0 : self.output_dim, :] + self.b1) * ( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) + else: + x = (x[:, 0 : self.output_dim, :]) * ( + x[:, self.output_dim : self.output_dim * 2, :] + ) + else: + if self.bias_in_glu: + x = (x[:, 0 : self.output_dim, :] + self.b1) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) + else: + x = (x[:, 0 : self.output_dim, :]) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + ) + + x = x.permute([0, 2, 1]) + return x + + +class DepthWiseSeperableConv1d(nn.Module): + """DepthWiseSeperableConv1d module used in Convnet module + for the conformer, for more details 
see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + depthwise_seperable_out_channel: int + if set different to 0, the number of + depthwise_seperable_out_channel will be used as a channel_out + of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + kernel_size: int + kernel_size + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + padding: int, optional + padding for the conv1d, + default: 0. + + """ + + def __init__( + self, + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=0, + ): + super().__init__() + + self.dw_conv = nn.Conv1d( + input_dim, + input_dim * depthwise_multiplier, + kernel_size, + 1, + padding=padding, + groups=input_dim, + ) + + if depthwise_seperable_out_channel != 0: + self.pw_conv = nn.Conv1d( + input_dim * depthwise_multiplier, + depthwise_seperable_out_channel, + 1, + 1, + 0, + ) + else: + self.pw_conv = nn.Identity() + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + + def forward(self, x): + """ + + Args: + x: torch.Tensor + input tensor + """ + x = self.dw_conv(x) + if self.depthwise_seperable_out_channel != 0: + x = self.pw_conv(x) + return x + + +class ConvModule(nn.Module): + """ConvModule Module for the conformer block. + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of + depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. 
+        kernel_size: int
+            kernel size.
+        depthwise_multiplier: int
+            number of input_dim channels duplication. this value
+            will be used to compute the hidden channels of the Conv1D.
+        dropout_rate: float
+            dropout rate.
+        causal: bool, optional
+            if set to True, convolution has no access
+            to future frames. default False.
+        batch_norm: bool, optional
+            if set to True, apply batchnorm before activation.
+            default False
+        chunk_se: int, optional
+            0 for offline SE.
+            1 for streaming SE, where mean is computed
+            by accumulated history until current chunk_se.
+            2 for streaming SE, where mean is computed
+            by only the current chunk.
+        chunk_size: int, optional
+            chunk size for cnn. default 18
+        activation: str, optional
+            activation function used in ConvModule,
+            default: "relu".
+        glu_type: str, optional
+            activation function used for the glu,
+            default: "sigmoid".
+        bias_in_glu: bool, optional
+            if set to True, use additive bias in the weight module
+            before GLU.
+        linear_glu_in_convm: bool, optional
+            if set to True, use GLULinear module,
+            otherwise, use GLUPointWiseConv module.
+            default to False.
+        export: bool, optional,
+            if set to True, padding is equal to 0. This is for inference,
+            or onnx export. Typically this is set by the export program or
+            the decoder program, and it isn't present in your config file.
+ default False + """ + + def __init__( + self, + input_dim, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal=False, + batch_norm=False, + chunk_se=0, + chunk_size=18, + activation="relu", + glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + export=False, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(input_dim) + self.input_dim = input_dim + self.ext_pw_out_channel = ext_pw_out_channel + self.ext_pw_kernel_size = ext_pw_kernel_size + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + self.glu_type = glu_type + self.bias_in_glu = bias_in_glu + self.linear_glu_in_convm = linear_glu_in_convm + self.causal = causal + + self._add_ext_pw_layer() + + self.batch_norm = batch_norm + self.kernel_size = kernel_size + + if batch_norm: + self.bn_layer = nn.BatchNorm1d(input_dim) + + self.act = get_activation(activation) + self.dropout = nn.Dropout(dropout_rate) + self.export = export + + if causal: + padding = 0 if export else kernel_size - 1 + else: + padding = (kernel_size - 1) // 2 + + self.dw_sep_conv_1d = DepthWiseSeperableConv1d( + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=padding, + ) + + if depthwise_seperable_out_channel != 0: + if input_dim != depthwise_seperable_out_channel: + self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim) + else: + if depthwise_multiplier != 1: + self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim) + + def _add_ext_pw_layer(self): + """ + This function is an extension of __init__ function + and dedicated to the convolution module creation + of the conformer. + """ + self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = ( + nn.Identity() + ) # jit hacks. + self.squeeze_excitation = nn.Identity() # jit. + self.apply_ln1 = self.fix_len1 = False # jit. 
+ + if self.ext_pw_out_channel != 0: + if self.causal: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1), + ) + if self.ext_pw_kernel_size > 1: + self.fix_len1 = True + else: + self.fix_len1 = False + else: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1) // 2, + ) + self.fix_len1 = False + + if self.linear_glu_in_convm: + self.glu = GLULinear( + self.input_dim, + self.ext_pw_out_channel, + self.glu_type, + self.bias_in_glu, + ) + else: + self.glu = GLUPointWiseConv( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + self.glu_type, + self.bias_in_glu, + self.causal, + ) + + if self.input_dim != self.ext_pw_out_channel: + self.apply_ln1 = True + self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim) + else: + self.apply_ln1 = False + else: + self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) + self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) + + def forward(self, x): + """ConvModule Forward. + + Args: + x: torch.Tensor + input tensor. 
+ """ + x = self.layer_norm(x) + + if self.ext_pw_out_channel != 0: + x = self.glu(x) + if self.causal and self.ext_pw_kernel_size > 1: + x = x[:, : -(self.ext_pw_kernel_size - 1), :] + if self.apply_ln1: + x = self.ln1(x) + else: + x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0] + x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1] + x = x_0 + x_1 + + x = x.permute([0, 2, 1]) + + x = self.dw_sep_conv_1d(x) + if self.causal and self.kernel_size > 1: + x = x[:, :, : -(self.kernel_size - 1)] + if hasattr(self, "ln2"): + x = x.permute([0, 2, 1]) + x = self.ln2(x) + x = x.permute([0, 2, 1]) + if self.batch_norm: + x = self.bn_layer(x) + x = self.act(x) + + if self.ext_pw_out_channel != 0: + x = self.ext_pw_conv_1d(x) + if self.fix_len1: + x = x[:, :, : -(self.ext_pw_kernel_size - 1)] + + if self.apply_ln1: + x = x.permute([0, 2, 1]) + x = self.ln1(x) + x = x.permute([0, 2, 1]) + + x = x.permute([0, 2, 1]) + else: + x = x.unsqueeze(1).permute([0, 1, 3, 2]) + x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2] + x = x.squeeze(1) + + x = self.dropout(x) + return x + + +class GLULinear(nn.Module): + """Linear + GLU module + + Args: + input_dim: int + input size + output_dim: int + output size. + glu_type: + activation function name used in glu module. + default "sigmoid" (swish function). + bias_in_glu: bool, optional + If True, the addtive bias is added. Default False. + """ + + def __init__( + self, + input_dim, + output_dim, + glu_type="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) + self.glu_act = GLU(-1, glu_type) + + def forward(self, x): + """GLULinear forward + + Args: + x: torch.Tensor + inpute tensor. + """ + x = self.linear(x) + return self.glu_act(x) + + +class FeedForward(nn.Module): + """FeedForward Module. + For more details see Conformer paper: + https://arxiv.org/pdf/2005.08100.pdf + + Args: + d_model: int + input size. 
+ d_inner: int + output size. + dropout_rate: float, + dropout rate. + activation: str, + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "sigmoid". + bias_in_glu: bool, optional + """ + + def __init__( + self, + d_model, + d_inner, + dropout_rate, + activation="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.d_model = d_model + self.d_inner = d_inner + + self.layer_norm = nn.LayerNorm(d_model) + module = GLULinear(d_model, d_inner, activation, bias_in_glu) + self.net = nn.Sequential( + module, + nn.Dropout(dropout_rate), + nn.Linear(d_inner, d_model), + nn.Dropout(dropout_rate), + ) + + def forward(self, x): + """FeedForward forward function. + + Args: + x: torch.Tensor + input tensor. + """ + out = self.net(self.layer_norm(x)) + + return out + + +#### positional encoding starts here +def _pre_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """Perform pre-hook in load_state_dict for backward compatibility. + + Note: + We saved self.pe until v.0.5.2 but we have omitted it later. + Therefore, we remove the item "pe" from `state_dict` for backward + compatibility. + + """ + k = prefix + "pe" + if k in state_dict: + state_dict.pop(k) + + +class T5RelativeAttentionLogitBias(nn.Module): + """ + This module implements the relative position bias described in Section + 2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf + + The Huggingface implementation is used as a reference + https://github.com/huggingface/transformers/blob/v4.30.0/src/ + transformers/models/t5/modeling_t5.py#L435 + + Modifies attention as Q*K^T + B, where B is a learned scalar bias based + on relative position of the query and key. It is HxNxN, where H is the + number of heads, N is the sequence length. + + I've made these modifications to the original T5 bias: + - Skipping of the bucketing step. 
Original T5 bias converted rel + position distances into logarithmically increasing buckets. This is + supposed to help with length generalization. + - I just directly use rel position index as bias values, as we don't + need length generalization (40s max is good enough for ASR encoder), + and it keeps ONNX export simple. + - I've also extended it so that biases can be asymmetric, the default + implementation treats L->R and R->L the same. Asymmetric was found to + yield better results in my experiments. + + Args: + num_heads: int + Number of attention heads + num_buckets: int + Number of buckets to use for relative attention bias. This is the + size of the learnable bias parameter. Bucketing is not yet + supported, so this defaults to -1 which means no bucketing is + used (max_distance determines size of bias param). + max_distance: int + Maximum distance to use for relative attention bias. With + num_buckets=-1, this directly controls the max size of the bias + parameter. When num_buckets > 0 is supported, this will control + the maximum distance for logarithmic bucketing after which all + positions are in the same bucket. + symmetric: bool + Whether to use symmetric or asymmetric biases. symmetric=False uses + 2x number of bias params to distinguish L->R from R->L. This was + found to be better for the encoder. 
+ """ + + def __init__(self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False): + super().__init__() + self.num_heads = num_heads + self.num_buckets = num_buckets + self.max_distance = max_distance + self.symmetric = symmetric + self._skip_bucketing = self.num_buckets < 0 + if self._skip_bucketing: + self.num_buckets = max_distance + else: + raise NotImplementedError( + "T5 attention bias with bucketed positions is not yet tested" + ) + if not self.symmetric: + self.num_buckets *= 2 + self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) + + def forward(self, x): + # instantiate bias compatible with shape of x + maxpos = x.size(1) + context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[ + :, None + ] + memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[ + None, : + ] + relative_position = memory_position - context_position + # clipping to a maximum distance using ops that play well with ONNX + # export + relative_position = relative_position.masked_fill( + relative_position < -self.max_distance, -self.max_distance + ) + relative_position = relative_position.masked_fill( + relative_position > self.max_distance - 1, self.max_distance - 1 + ) + + # mapping from relative position to index in the bias parameter + if self._skip_bucketing: + bias_idx = relative_position + else: + bias_idx = self._bucket_relative_position(relative_position) + if self.symmetric: + bias_idx = bias_idx.abs() + else: + bias_idx += self.num_buckets // 2 + + t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H] + t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0) # [1, H, L, L] + + return t5_rel_att_bias + + def _bucket_relative_position(self, relative_position): + # this is a placeholder (isn't tested, likely buggy) using HuggingFace + # implem as a reference this also needs to be extended to support + # asymmetric +/- ve positions + relative_buckets = 0 + if not self.causal: + self.num_buckets //= 2 + 
relative_buckets += (relative_position > 0).to( + torch.long + ) * self.num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = self.num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in + # positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(self.max_distance / max_exact) + * (self.num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, self.num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + +class AbsolutePositionalEncoding(nn.Module): + """Absolute Positional encoding module. + This module implement Absolute sinusoidal positional encoding + from: https://arxiv.org/pdf/1706.03762.pdf + + Args: + d_model: int + Input embedding size. + dropout_rate: float + dropout rate + max_len: int, optional + Maximum input length sequence, Default 5000 + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x): + """Reset the positional encodings. 
+ + Args: + x: torch.Tensor + """ + if self.pe is not None and self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x: torch.Tensor + Input tensor. shape is (batch, time, ...) + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +#### forward embedding layers starts here +class MeanVarianceNormLayer(nn.Module): + """Mean/variance normalization layer. + + Will subtract mean and multiply input by inverted standard deviation. + Typically used as a very first layer in a model. + + Args: + input_size: int + layer input size. + """ + + def __init__(self, input_size): + super().__init__() + self.input_size = input_size + self.global_mean = nn.Parameter(torch.zeros(input_size)) + self.global_invstd = nn.Parameter(torch.ones(input_size)) + + def forward(self, input_: Tensor) -> Tensor: + """MeanVarianceNormLayer Forward + + Args: + input_: torch.Tensor + input tensor. + """ + return (input_ - self.global_mean) * self.global_invstd + + +class CausalConv1D(nn.Conv1d): + """ + A causal version of nn.Conv1d where each step would have limited access to + locations on its right or left + All arguments are the same as nn.Conv1d except padding. 
+ + If padding is set None, then paddings are set automatically to make it a + causal convolution where each location would not see any steps on its right. + + If padding is set as a list (size of 2), then padding[0] would be used as + left padding and padding[1] as right padding. + It would make it possible to control the number of steps to be accessible + on the right and left. + This mode is not supported when stride > 1. padding[0]+padding[1] should + be equal to (kernel_size - 1). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + self.cache_drop_size = None + if padding is None: + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + else: + if stride != 1 and padding != kernel_size - 1: + raise ValueError("No striding allowed for non-symmetric convolutions!") + if isinstance(padding, int): + self._left_padding = padding + self._right_padding = padding + elif ( + isinstance(padding, list) + and len(padding) == 2 + and padding[0] + padding[1] == kernel_size - 1 + ): + self._left_padding = padding[0] + self._right_padding = padding[1] + else: + raise ValueError(f"Invalid padding param: {padding}!") + + self._max_cache_len = self._left_padding + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def update_cache(self, x, cache=None): + if cache is None: + new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) + next_cache = cache + else: + new_x = F.pad(x, pad=(0, self._right_padding)) + new_x = torch.cat([cache, new_x], dim=-1) + if self.cache_drop_size > 0: + next_cache = new_x[:, :, : -self.cache_drop_size] 
+ else: + next_cache = new_x + next_cache = next_cache[:, :, -cache.size(-1) :] + return new_x, next_cache + + def forward(self, x, cache=None): + x, cache = self.update_cache(x, cache=cache) + x = super().forward(x) + if cache is None: + return x + else: + return x, cache + + +class CausalConv2D(nn.Conv2d): + """ + A causal version of nn.Conv2d where each location in the 2D matrix would + have no access to locations on its right or down + All arguments are the same as nn.Conv2d except padding which should be + set as None + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + if padding is not None: + raise ValueError("Argument padding should be set to None for CausalConv2D.") + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + + padding = 0 + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + + def forward( + self, + x, + ): + x = F.pad( + x, + pad=(self._left_padding, self._right_padding, 0, 0), + ) + x = super().forward(x) + return x + + +class NemoConvSubsampling(torch.nn.Module): + """Convlutional subsampling module, taken from NeMo ASR + (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a + 34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) + + Striding Subsampling: "Speech-Transformer: A No-Recurrence + Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong + et al. (https://ieeexplore.ieee.org/document/8462506) + + + Compared with the EncoderConv2D (`input_layer: custom`), this is a + much simplified approach, and uses no LayerNorm and far fewer Conv2Ds. 
+ Moreover, depthwise convolutions are used to reduce FLOPs, but the first + layer is kept as a regular convolution so as not to degrade accuracy. + + `Striding` and `dw_striding` are the same except that the latter uses + depthwise convolutions after the first layer, whereas the former does not. + + Args: + subsampling_factor (int): Time reduction factor + feat_in (int): size of the input features + feat_out (int): size of the output features + subsampling (str): The subsampling technique, choose from + {"striding", "dw-striding", "striding_conv1d", + "dw_striding_conv1d"} + conv_channels (int): Number of channels for the convolution layers, + default is 256. + subsampling_conv_chunking_factor (int): Input chunking factor which + can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1 + activation (Module): activation function, default is nn.ReLU() + is_causal (bool): whether to use causal Conv1/2D, where each step will + have limited access to locations on its right or left + """ + + def __init__( + self, + feat_in, + feat_out, + subsampling_factor=4, + subsampling="dw_striding", + conv_channels=256, + subsampling_conv_chunking_factor=1, + activation=nn.ReLU(), # noqa: B008 + is_causal=False, + ): + super().__init__() + self._subsampling = subsampling + self._conv_channels = conv_channels + self._feat_in = feat_in + self._feat_out = feat_out + + if subsampling_factor % 2 != 0: + raise ValueError("Sampling factor should be a multiply of 2!") + self._sampling_num = int(math.log(subsampling_factor, 2)) + self.subsampling_factor = subsampling_factor + self.is_causal = is_causal + self.subsampling_causal_cond = subsampling in ( + "dw_striding", + "striding", + "striding_conv1d", + ) + + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): + raise ValueError( + "subsampling_conv_chunking_factor should be -1, 1, or a " "power of 2" + ) + self.subsampling_conv_chunking_factor 
= subsampling_conv_chunking_factor + + in_channels = 1 + layers = [] + + if subsampling == "dw_striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + # Layer 1 + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + groups=in_channels, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ) + ) + + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + 
self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv1D( + in_channels=in_channels, + out_channels=( + feat_out + if self._sampling_num == i + 1 + else conv_channels + ), + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=( + feat_out + if self._sampling_num == i + 1 + else conv_channels + ), + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "dw_striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + + # Layer 1 + layers.extend( + [ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + 
groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=( + feat_out if self._sampling_num == 1 else conv_channels + ), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ] + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + layers.extend( + [ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=( + feat_out + if self._sampling_num == i + 2 + else conv_channels + ), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ] + ) + layers.append(activation) + in_channels = conv_channels + + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + if subsampling in ["dw_striding", "striding"]: + in_length = torch.tensor(feat_in, dtype=torch.float) + out_length = calc_length( + lengths=in_length, + all_paddings=self._left_padding + self._right_padding, + kernel_size=self._kernel_size, + stride=self._stride, + ceil_mode=self._ceil_mode, + repeat_num=self._sampling_num, + ) + self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) + self.conv2d_subsampling = True + elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]: + self.out = None + self.conv2d_subsampling = False + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + self.conv = torch.nn.Sequential(*layers) + + def get_sampling_frames(self): + return [1, self.subsampling_factor] + + def get_streaming_cache_size(self): + return [0, self.subsampling_factor + 1] + + def forward(self, x, mask): + """ + Forward method for NeMo subsampling. 
+ + Args: + x[Batch, Time, Filters]: torch.Tensor + input tensor + x_mask: torch.Tensor + input mask + + Returns: + x: torch.Tensor + Resulting tensor from subsampling (B, T // + time_reduction_factor, feat_out) + pad_mask: torch.Tensor + tensor of padded hidden state sequences (B, 1, T // + time_reduction_factor) + """ + x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2) + + # split inputs if chunking_factor is set + if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling: + if self.subsampling_conv_chunking_factor == 1: + # if subsampling_conv_chunking_factor is 1, we split only + # if needed. + # avoiding a bug / feature limiting indexing of tensors + # to 2**31. + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride + need_to_split = torch.numel(x) > x_ceil + else: + # if subsampling_conv_chunking_factor > 1 we always split + need_to_split = True + + if need_to_split: + x, success = self.conv_split_by_batch(x) + if not success: # if unable to split by batch, try by channel + if self._subsampling == "dw_striding": + x = self.conv_split_by_channel(x) + else: + x = self.conv(x) # try anyway + else: + x = self.conv(x) + else: + x = self.conv(x) + + # Flatten Channel and Frequency Axes + if self.conv2d_subsampling: + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, -1)) + # Transpose to Channel Last mode + else: + x = x.transpose(1, 2) + + if mask is None: + return x, None + + max_audio_length = x.shape[1] + feature_lens = mask.sum(1) + padding_length = torch.ceil(feature_lens / self.subsampling_factor) + if self.is_causal and self.subsampling_causal_cond: + feature_lens_remainder = feature_lens % self.subsampling_factor + padding_length[feature_lens_remainder != 1] += 1 + pad_mask = torch.arange(0, max_audio_length, device=x.device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) + return x, pad_mask.unsqueeze(1) + + 
def reset_parameters(self): + # initialize weights + if self._subsampling == "dw_striding": + with torch.no_grad(): + # init conv + scale = 1.0 / self._kernel_size + dw_max = (self._kernel_size**2) ** -0.5 + pw_max = self._conv_channels**-0.5 + + torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) + torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) + + for idx in range(2, len(self.conv), 3): + torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max) + torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max) + torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max) + torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max) + + # init fc (80 * 64 = 5120 from https://github.com/kssteven418/ + # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/ + # src/models/conformer_encoder.py#L487 + fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5 + torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) + torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) + + def conv_split_by_batch(self, x): + """Tries to split input by batch, run conv and concat results""" + b, _, _, _ = x.size() + if b == 1: # can't split if batch size is 1 + return x, False + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride + p = math.ceil(math.log(torch.numel(x) / x_ceil, 2)) + cf = 2**p + + new_batch_size = b // cf + if new_batch_size == 0: # input is too big + return x, False + + return ( + torch.cat( + [self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)] + ), + True, + ) + + def conv_split_by_channel(self, x): + """For dw convs, tries to split input by time, run conv and concat + results""" + x = self.conv[0](x) # full conv2D + x = self.conv[1](x) # activation + + 
for i in range(self._sampling_num - 1): + _, c, t, _ = x.size() + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors + # to 2**31 + # see https://github.com/pytorch/pytorch/issues/80020 + p = math.ceil(math.log(torch.numel(x) / 2**31, 2)) + cf = 2**p + + new_c = int(c // cf) + if new_c == 0: + new_c = 1 + + new_t = int(t // cf) + if new_t == 0: + new_t = 1 + + x = self.channel_chunked_conv( + self.conv[i * 3 + 2], new_c, x + ) # conv2D, depthwise + + # splitting pointwise convs by time + x = torch.cat( + [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], + 2, + ) # conv2D, pointwise + x = self.conv[i * 3 + 4](x) # activation + return x + + def channel_chunked_conv(self, conv, chunk_size, x): + """Performs channel chunked convolution""" + + ind = 0 + out_chunks = [] + for chunk in torch.split(x, chunk_size, 1): + step = chunk.size()[1] + + if self.is_causal: + chunk = nn.functional.pad( + chunk, + pad=( + self._kernel_size - 1, + self._stride - 1, + self._kernel_size - 1, + self._stride - 1, + ), + ) + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], + stride=self._stride, + padding=0, + groups=step, + ) + else: + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], + stride=self._stride, + padding=self._left_padding, + groups=step, + ) + out_chunks.append(ch_out) + ind += step + + return torch.cat(out_chunks, 1) + + def change_subsampling_conv_chunking_factor( + self, subsampling_conv_chunking_factor: int + ): + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): + raise ValueError( + "subsampling_conv_chunking_factor should be -1, 1, or a " "power of 2" + ) + self.subsampling_conv_chunking_factor = 
subsampling_conv_chunking_factor + + +def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1): + """Calculates the output length of a Tensor passed through a convolution or + max pooling layer""" + add_pad: float = all_paddings - kernel_size + one: float = 1.0 + for i in range(repeat_num): + lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one + lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths) + return lengths.to(dtype=torch.int) + + +#### multihead attention starts here +class AttModule(nn.Module): + """Attention abstraction module""" + + def __init__(self): + super().__init__() + self.export_mode = False + + def set_export(self, mode=True): + """set the export mode""" + self.export_mode = mode + + def forward( + self, + x: Tensor, + memory: Optional[Tensor] = None, + pos_emb: Optional[Tensor] = None, + att_mask: Optional[Tensor] = None, + ) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + """AttModule forward + + Args: + x: torch.Tensor + input tensor. + memory: torch.Tensor, optional + memory tensor. + pos_emb: torch.Tensor, optional + positional encoder embedding. + att_mask: torch.Tensor, optional + attention mask tensor. 
+ """ + return x, memory, pos_emb, att_mask + + +class AttBlock(BlockBase, AttModule): + """Attention Block module to support both Attention and Block module.""" + + def memory_dims(self, max_len=False): + """memory dimensions""" + return (1, self.input_size) + + +def masked_softmax( + scores, + mask: Optional[Tensor], +): + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) + scores = scores.masked_fill(mask, -torch.inf) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0 + ) # (batch, head, time1, time2) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + return attn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer with optional relative position embedding + and GLU. + + Args: + n_head: int + the number of heads. + n_feat: int + input size features. + dropout_rate: float + dropout rate. + use_LN: bool + apply layer norm or not + dropout_at_output: bool + whether to apply dropout at output + attention_inner_dim: int, optional + the attention dimension used in the class, + it can be different from the input dimension n_feat. + default: -1 (equal to n_feat). + use_pt_scaled_dot_product_attention: bool, optional + if set True, use pytorch scaled dot product attention in training. + NOTE: this will NOT be used in ONNX decoding due to a lack of + support. In that case, we use the original attention + implementation, which shows no regression. + default: False. + n_value: int, optional + if set to values other than -1, use a different dimension for + value. With the default value (i.e. -1), it is backward compatible. + group_size: int, optional. 
must divide `n_head` + if group_size > 1: GQA + if group_size = 1: MHA + if group_size = n_head: MQA + """ + + inv_sqrt_d_k: torch.jit.Final[float] + h: torch.jit.Final[int] + h_k: torch.jit.Final[int] + g: torch.jit.Final[int] + + def __init__( + self, + n_head, + n_feat, + dropout_rate, + attention_inner_dim=-1, + glu_type="swish", + bias_in_glu=True, + use_pt_scaled_dot_product_attention=False, + n_value=-1, + group_size: int = 1, + ): + super().__init__() + if n_value == -1: + n_value = n_feat + if attention_inner_dim == -1: + attention_inner_dim = n_feat + assert attention_inner_dim % n_head == 0 + + # We assume d_v always equals d_k + self.d_k = attention_inner_dim // n_head + self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k) + self.h = n_head + assert n_head % group_size == 0, "group_size must divide n_head" + self.g = group_size + self.h_k = n_head // group_size + + self.linear_q = nn.Linear(n_feat, attention_inner_dim) + self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size) + self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size) + self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value) + + self.attn = torch.jit.Attribute(None, Optional[Tensor]) + self.dropout = nn.Dropout(p=dropout_rate) + self.dropout_rate = dropout_rate + self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention + + if use_pt_scaled_dot_product_attention and group_size > 1: + raise ValueError("Cannot use PT Scaled Attention with GQA") + + # Torchscript eager quantization. Note that these functions below are + # NOOPs and have very little impact on performance unless quantization + # is enabled. 
+ self.quant_q = torch.ao.quantization.QuantStub() + self.quant_x = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + self.ffunc = torch.ao.nn.quantized.FloatFunctional() + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_k: Tensor, + pos_v: Tensor, + mask: Optional[Tensor], + relative_attention_bias: Optional[Tensor] = None, + ): + """Compute 'Scaled Dot Product Attention'. + + Args: + query: torch.Tensor + query tensor (batch, time1, size) + key: torch.Tensor + key tensor (batch, time2, size) + value: torch.Tensor + value tensor (batch, time1, size) + pos_k: torch.Tensor + key tensor used for relative positional embedding. + pos_v: torch.Tensor + value tensor used for relative positional embedding. + mask: torch.Tensor + mask tensor (batch, time1, time2) + relative_attention_bias: torch.Tensor + bias added to attention logits w.r.t. relative positions + (1, n_head, time1, time2) + """ + n_batch = query.size(0) + + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) # (b, t, d) + k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k) # (b, t, d) + v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k) + q = ( + q.transpose(1, 2) + if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting() + else q.transpose(1, 2) * self.inv_sqrt_d_k + ) + k = k.transpose(1, 2) # (batch, head_k, time2, d_k) + v = v.transpose(1, 2) # (batch, head_k, time2, d_k) + + if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting(): + attn_mask = None + if mask is not None: + mask = mask.unsqueeze(1) + if relative_attention_bias is not None: + attn_mask = mask + relative_attention_bias + else: + attn_mask = mask + if mask.dtype != q.dtype: + attn_mask = attn_mask.to(q.dtype) + + with torch.nn.attention.sdpa_kernel( + [ + torch.nn.attention.SDPBackend.FLASH_ATTENTION, + torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, + torch.nn.attention.SDPBackend.MATH, + 
torch.nn.attention.SDPBackend.CUDNN_ATTENTION, + ] + ): + x = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.dropout_rate, + ) + else: + if self.h != self.h_k: + q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k) + A = torch.einsum("b g h t d, b h s d -> b h t s", q, k) + else: + A = torch.matmul(q, k.transpose(-2, -1)) + if pos_k is not None: + if self.h != self.h_k: + B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) + else: + reshape_q = ( + q.contiguous() + .view(n_batch * self.h, -1, self.d_k) + .transpose(0, 1) + ) # (t1,nh,dk) + B = torch.matmul( + reshape_q, pos_k.transpose(-2, -1) + ) # pos_k: (t1,dk,t2) + B = B.transpose(0, 1).view( + n_batch, self.h, pos_k.size(0), pos_k.size(1) + ) + scores = A + B + else: + scores = A + + if relative_attention_bias is not None: + scores = scores + relative_attention_bias + + attn = masked_softmax(scores, mask) # (batch, head, time1, time2) + + self.attn = attn + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn.to(v.dtype), v) # (batch, head, time1, d_k) + if pos_v is not None: + reshape_attn = ( + p_attn.contiguous() + .view(n_batch * self.h, pos_v.size(0), pos_v.size(1)) + .transpose(0, 1) + ) # (t1, bh, t2) + + attn_v = ( + torch.matmul(reshape_attn, pos_v) + .transpose(0, 1) + .contiguous() + .view(n_batch, self.h, pos_v.size(0), self.d_k) + ) + x = x + attn_v + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + +class MultiSequential(torch.nn.Sequential): + """Multi-input multi-output torch.nn.Sequential""" + + @torch.jit.ignore + def forward(self, *args): + """Forward method implementation.""" + for m in self: + args = m(*args) + return args + + +def get_offset(input_layer: str, time_reduction: int): + """Get an offset. We will use the offset for determining #frames of a + subsampled feature. 
+ + Args: + input_layer (str): Type of an input layer + time_reduction (int): time reduction factor for downsampling a feature + Returns: + int: offset + """ + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: + return 3 + if input_layer in ("conv2d",) and time_reduction == 6: + return 1 + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: + return 7 + return 0 + + +def unfold_tensor(xs_pad, max_seq_len): + """ + For a given tensor with shape of (N, T, D), if sequence length T is + longer than max_seq_len, this function unfold it to a + (NT', max_seq_len, D) where T' is T // max_seq_len. + Args: + xs_pad: N, T, D + """ + _, _, D = xs_pad.shape + xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T + # N x D x 1 x T => N x (D x max_seq_len) x T' + xs_pad = F.unfold( + xs_pad[..., None, :], + kernel_size=(1, max_seq_len), + stride=(1, max_seq_len), + ) + new_bsz, _, slen = xs_pad.shape + # N x D x max_seq_len x T' + xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen) + # N x T' x max_seq_len x D + xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous() + # NT' x max_seq_len x D + xs_pad = xs_pad.view(-1, max_seq_len, D) + return xs_pad diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 5c44c4d49953..6c6495c5f8f0 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -158,6 +158,7 @@ def __init__(self, hf_config, server_args, _processor): "pixel_values_videos": Modality.VIDEO, "image_sizes": Modality.IMAGE, "image_grid_thw": Modality.IMAGE, + "image_attention_mask": Modality.IMAGE, "image_emb_mask": Modality.IMAGE, "image_spatial_crop": Modality.IMAGE, "tgt_size": Modality.IMAGE, @@ -170,6 +171,7 @@ def __init__(self, hf_config, server_args, _processor): "audio_feature_lens": Modality.AUDIO, "input_features": Modality.AUDIO, "input_features_mask": Modality.AUDIO, + 
"audio_attention_mask": Modality.AUDIO, # Video-related attributes "video_grid_thw": Modality.VIDEO, # Generic attributes that could apply to multiple modalities @@ -251,7 +253,11 @@ def get_estimated_frames_list(self, image_data): @staticmethod def _load_single_item( - data, modality: Modality, frame_count_limit=None, discard_alpha_channel=True + data, + modality: Modality, + frame_count_limit=None, + audio_sample_rate: Optional[int] = None, + discard_alpha_channel=True, ): """ Load a single multimodal data. @@ -268,7 +274,7 @@ def _load_single_item( elif modality == Modality.VIDEO: return load_video(data, frame_count_limit) elif modality == Modality.AUDIO: - return load_audio(data) + return load_audio(data, audio_sample_rate) except Exception as e: raise RuntimeError(f"Error while loading data {data}: {e}") @@ -282,6 +288,7 @@ def submit_data_loading_tasks( image_estimated_frames_iter: Optional[iter] = None, image_scaling_factor: float = 1.0, max_image_frames: int = 30, + audio_sample_rate: Optional[int] = None, ) -> Tuple[List, List]: """ load multimodal data parallelly using iterators. 
@@ -324,6 +331,7 @@ def submit_data_loading_tasks( data, modality, frame_count_limit, + audio_sample_rate, discard_alpha_channel, ) ) @@ -352,6 +360,7 @@ def load_mm_data( audio_data: Optional[list] = None, return_text: Optional[bool] = True, discard_alpha_channel: bool = True, + audio_sample_rate: Optional[int] = None, ) -> BaseMultiModalProcessorOutput: """ Each frame of video/image will be replaced by a single image token @@ -390,6 +399,7 @@ def load_mm_data( multimodal_tokens=multimodal_tokens, data_iterators=data_iterators, discard_alpha_channel=discard_alpha_channel, + audio_sample_rate=audio_sample_rate, ) task_info_iter = iter(task_info) futures_iter = iter(futures) diff --git a/python/sglang/srt/multimodal/processors/phi4mm.py b/python/sglang/srt/multimodal/processors/phi4mm.py index aea06506d078..8772403dbdb7 100644 --- a/python/sglang/srt/multimodal/processors/phi4mm.py +++ b/python/sglang/srt/multimodal/processors/phi4mm.py @@ -1,6 +1,8 @@ import logging from typing import List, Union +from transformers.processing_utils import ProcessorMixin + from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.phi4mm import Phi4MMForCausalLM from sglang.srt.multimodal.processors.base_processor import ( @@ -10,18 +12,58 @@ logger = logging.getLogger(__name__) -_IMAGE_SPECIAL_TOKEN = "<|endoftext10|>" -_IMAGE_SPECIAL_TOKEN_ID = 200010 + +# It is an adapter of hf phi4 mm processor to make it work for sglang +# Ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py#L693 +class Phi4MMProcessorAdapter(ProcessorMixin): + def __init__(self, _processor) -> None: + self._processor = _processor + + def __call__(self, **kwargs): + result = self._processor(**kwargs) + + # Map HuggingFace output keys to sglang standard keys + key_mapping = { + "input_image_embeds": "pixel_values", + "input_audio_embeds": "audio_features", + "audio_embed_sizes": "audio_feature_lens", + } + for hf_key, sglang_key in 
key_mapping.items(): + if hf_key in result: + result[sglang_key] = result[hf_key] + + # Filter out None or empty tensors from the result. + # This prevents the sglang function base_processor.collect_mm_items_from_processor_output() + # from misclassifying audio content as image content, and vice versa. + filtered_result = { + k: v + for k, v in result.items() + if v is not None and (not hasattr(v, "numel") or v.numel() > 0) + } + return filtered_result -class Phi4MMImageProcessor(BaseMultimodalProcessor): +class Phi4MMMultimodalProcessor(BaseMultimodalProcessor): models = [Phi4MMForCausalLM] def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + self.processor = Phi4MMProcessorAdapter(_processor) + super().__init__(hf_config, server_args, self.processor) + + # the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file + # ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py + self.IMAGE_TOKEN = "<|endoftext10|>" + self.AUDIO_TOKEN = "<|endoftext11|>" + self.IM_TOKEN_ID = 200010 + self.AUDIO_TOKEN_ID = 200011 + self.AUDIO_SAMPLE_RATE = 16000 + self.multimodal_tokens = MultimodalSpecialTokens( - image_token=_IMAGE_SPECIAL_TOKEN, - ).build(_processor) + image_token=self.IMAGE_TOKEN, + image_token_id=self.IM_TOKEN_ID, + audio_token=self.AUDIO_TOKEN, + audio_token_id=self.AUDIO_TOKEN_ID, + ).build(self.processor) async def process_mm_data_async( self, @@ -32,46 +74,29 @@ async def process_mm_data_async( max_req_input_len, **kwargs, ): - if audio_data: - logger.warning( - "Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize." 
- ) - audio_data = [] - base_output = self.load_mm_data( prompt=input_text, max_req_input_len=max_req_input_len, audio_data=audio_data, image_data=image_data, multimodal_tokens=self.multimodal_tokens, + audio_sample_rate=self.AUDIO_SAMPLE_RATE, ) - if base_output is None: - return None - res = self.process_mm_data( - input_text=base_output.input_text, - images=base_output.images, - audios=base_output.audios, - ) + if base_output.audios is not None: + # hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file requires the audio input to be tuple of (audio, sample_rate) + # ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py + base_output.audios = [ + (audio, self.AUDIO_SAMPLE_RATE) for audio in base_output.audios + ] - input_ids = res["input_ids"].flatten() - image_offsets = self.get_mm_items_offset( - input_ids=input_ids, - mm_token_id=_IMAGE_SPECIAL_TOKEN_ID, + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.multimodal_tokens ) - items = [ - MultimodalDataItem( - feature=res["input_image_embeds"], - image_sizes=res["image_sizes"], - image_emb_mask=res["image_attention_mask"], - offsets=image_offsets, - modality=Modality.IMAGE, - ) - ] - return { - "mm_items": items, "input_ids": input_ids.tolist(), - "im_token_id": _IMAGE_SPECIAL_TOKEN_ID, + "mm_items": mm_items, + "im_token_id": self.IM_TOKEN_ID, + "audio_token_id": self.AUDIO_TOKEN_ID, } diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index ce159a4da77b..dc6e72d75dcd 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -691,12 +691,17 @@ def decode_video_base64(video_base64): ) # Return an empty array and size tuple if no frames were found -def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray: +def load_audio( + audio_file: str, sr: Optional[int] = None, mono: bool = True +) -> np.ndarray: # Use soundfile here, since librosa use it under the hood, 
# and librosa will not support audio loading in the future import soundfile as sf from scipy.signal import resample + if sr is None: + sr = 16000 + # Load audio data if isinstance(audio_file, bytes): audio, original_sr = sf.read(BytesIO(audio_file)) diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index f6152ea76dfc..53498946144c 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -200,16 +200,17 @@ def setUpClass(cls): "0.70", "--disable-radix-cache", "--max-loras-per-batch", - "1", + "2", "--revision", revision, "--lora-paths", f"vision={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/vision-lora", + f"speech={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/speech-lora", ], ) cls.base_url += "/v1" - def get_request_kwargs(self): + def get_vision_request_kwargs(self): return { "extra_body": { "lora_path": "vision", @@ -218,8 +219,21 @@ def get_request_kwargs(self): } } - def test_video_chat_completion(self): - pass + def get_audio_request_kwargs(self): + return { + "extra_body": { + "lora_path": "speech", + "top_k": 1, + "top_p": 1.0, + } + } + + def test_audio_chat_completion(self): + self._test_audio_speech_completion() + # TODO: currently phi4-mm cannot pass this test. + # We are investigating this issue. + # Response: La ciudad está situada en la costa este de la isla, en la desembocadura del río St. Lawrence. 
+ # self._test_audio_ambient_completion() class TestVILAServer(TestOpenAIVisionServer): diff --git a/test/srt/test_vision_openai_server_common.py b/test/srt/test_vision_openai_server_common.py index 5d958fd5a26c..341db654e053 100644 --- a/test/srt/test_vision_openai_server_common.py +++ b/test/srt/test_vision_openai_server_common.py @@ -47,6 +47,12 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) + def get_audio_request_kwargs(self): + return self.get_request_kwargs() + + def get_vision_request_kwargs(self): + return self.get_request_kwargs() + def get_request_kwargs(self): return {} @@ -71,7 +77,7 @@ def test_single_image_chat_completion(self): }, ], temperature=0, - **(self.get_request_kwargs()), + **(self.get_vision_request_kwargs()), ) assert response.choices[0].message.role == "assistant" @@ -134,7 +140,7 @@ def test_multi_turn_chat_completion(self): }, ], temperature=0, - **(self.get_request_kwargs()), + **(self.get_vision_request_kwargs()), ) assert response.choices[0].message.role == "assistant" @@ -177,7 +183,7 @@ def test_multi_images_chat_completion(self): }, ], temperature=0, - **(self.get_request_kwargs()), + **(self.get_vision_request_kwargs()), ) assert response.choices[0].message.role == "assistant" @@ -333,7 +339,7 @@ def _test_video_chat_completion(self): temperature=0, max_tokens=1024, stream=False, - **(self.get_request_kwargs()), + **(self.get_vision_request_kwargs()), ) video_response = response.choices[0].message.content @@ -376,7 +382,7 @@ def test_regex(self): + r"""\}""" ) - extra_kwargs = self.get_request_kwargs() + extra_kwargs = self.get_vision_request_kwargs() extra_kwargs.setdefault("extra_body", {})["regex"] = regex response = client.chat.completions.create( @@ -443,7 +449,7 @@ def run_decode_with_image(self, image_id): {"role": "user", "content": content}, ], temperature=0, - **(self.get_request_kwargs()), + **(self.get_vision_request_kwargs()), ) assert response.choices[0].message.role == 
"assistant" @@ -486,7 +492,7 @@ def get_audio_response(self, url: str, prompt, category): temperature=0, max_tokens=128, stream=False, - **(self.get_request_kwargs()), + **(self.get_audio_request_kwargs()), ) audio_response = response.choices[0].message.content @@ -500,7 +506,7 @@ def get_audio_response(self, url: str, prompt, category): self.assertIsNotNone(audio_response) self.assertGreater(len(audio_response), 0) - return audio_response + return audio_response.lower() def _test_audio_speech_completion(self): # a fragment of Trump's speech From 1403ea56949e4e388853f835288c83a86ec96027 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Sat, 19 Jul 2025 13:00:49 +0800 Subject: [PATCH 042/396] [PD] Support non-MLA models PD different TP with DP attention (#7931) Signed-off-by: Shangming Cai --- .../srt/disaggregation/mooncake/conn.py | 94 ++++++++----------- 1 file changed, 41 insertions(+), 53 deletions(-) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index a9e9bf2c5938..e345d9519eac 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -321,67 +321,60 @@ def send_kvcache_slice( This may introduce performance overhead (increased TTFT) for long sequences. 
""" # Extract configuration - local_tp_rank = self.kv_args.engine_rank local_tp_size = self.tp_size // self.dp_size + local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size + src_kv_item_len = self.kv_args.kv_item_lens[0] + dst_tp_rank_in_group = dst_tp_rank % dst_tp_size num_kv_heads = self.kv_args.kv_head_num num_layers = len(self.kv_args.kv_data_ptrs) page_size = self.kv_args.page_size # Calculate head distribution - heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size - heads_per_prefill_rank = num_kv_heads - decode_global_head_start = dst_tp_rank * heads_per_decode_rank - prefill_global_head_start = local_tp_rank * heads_per_prefill_rank - bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size - - decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)] + src_heads_per_rank = num_kv_heads + dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size + bytes_per_head_slice_to_send = ( + dst_kv_item_len // page_size // dst_heads_per_rank + ) # Determine slicing parameters based on TP configuration if local_tp_size > dst_tp_size: - src_head_offset = 0 - num_heads_to_send = heads_per_prefill_rank - dst_head_offset = prefill_global_head_start - decode_global_head_start + # Send KVCache from multiple prefill instances to 1 decode instance + src_head_start_offset = 0 + num_heads_to_send = src_heads_per_rank + dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank else: - src_head_offset = decode_global_head_start - prefill_global_head_start - num_heads_to_send = heads_per_decode_rank - dst_head_offset = 0 + # Send KVCache from 1 prefill instance to multiple decode instances + src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank + num_heads_to_send = dst_heads_per_rank + dst_head_start_offset = 0 - layer_transfer_params = [] + layers_params = [] for layer_id in range(num_layers): - item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id] - - # Page stride on the target dst 
decode rank for its slice pages - item_len_of_decode_rank_page = decode_rank_item_lens[layer_id] - - if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0: - logger.error( - f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}" - ) - return -1 - - # Calculate precise byte offset and length for the sub-slice within the prefill page data - src_slice_offset = src_head_offset * bytes_per_head - dst_slice_offset = dst_head_offset * bytes_per_head - slice_lens_per_page = num_heads_to_send * bytes_per_head + # Calculate precise byte offset and length for the sub-slice within the token + src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send + dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send + heads_bytes_per_token_to_send = ( + num_heads_to_send * bytes_per_head_slice_to_send + ) - # Sanity check: The data sub-slice to be sent should fit into the decode instance's page. - # This means slice_lens_per_page <= item_len_of_decode_rank_page - if slice_lens_per_page > item_len_of_decode_rank_page: + # Sanity check: The data sub-slice to be sent should fit into the dst buffer. 
+ # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size) + if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size): logger.error( f"[{mooncake_session_id}] Layer {layer_id}: " - f"slice size ({slice_lens_per_page}) exceeds " - f"target page size ({item_len_of_decode_rank_page})" + f"slice size ({heads_bytes_per_token_to_send}) exceeds " + f"target token slot size ({dst_kv_item_len // page_size})" ) return -1 - layer_transfer_params.append( + layers_params.append( ( self.kv_args.kv_data_ptrs[layer_id], dst_kv_ptrs[layer_id], - item_len_of_prefill_rank_page, - item_len_of_decode_rank_page, - src_slice_offset, - dst_slice_offset, - slice_lens_per_page, + src_kv_item_len, + dst_kv_item_len, + src_head_slice_offset, + dst_head_slice_offset, + heads_bytes_per_token_to_send, ) ) @@ -391,9 +384,9 @@ def process_layer_tp_aware(layer_params): dst_ptr, src_item_len, dst_item_len, - src_offset, - dst_offset, - slice_lens_per_page, + src_head_slice_offset, + dst_head_slice_offset, + heads_bytes_per_token_to_send, ) = layer_params src_addr_list = [] dst_addr_list = [] @@ -424,17 +417,12 @@ def process_layer_tp_aware(layer_params): ) # Calculate final src and dst addresses by applying head-slice offsets - src_slice_addr = src_token_slot_start_addr + src_offset - dst_slice_addr = dst_token_slot_start_addr + dst_offset + src_slice_addr = src_token_slot_start_addr + src_head_slice_offset + dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset src_addr_list.append(src_slice_addr) dst_addr_list.append(dst_slice_addr) - length_list.append(slice_lens_per_page) - - logger.debug( - f"SYNC: sid={mooncake_session_id}, " - f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}" - ) + length_list.append(heads_bytes_per_token_to_send) return self.engine.batch_transfer_sync( mooncake_session_id, src_addr_list, dst_addr_list, length_list @@ -445,7 +433,7 @@ def process_layer_tp_aware(layer_params): process_layer_tp_aware, 
layer_params, ) - for layer_params in layer_transfer_params + for layer_params in layers_params ] for future in concurrent.futures.as_completed(futures): From 610381b75e6317cf60870ed443f02967892cd729 Mon Sep 17 00:00:00 2001 From: Yingchun Lai Date: Sat, 19 Jul 2025 13:08:46 +0800 Subject: [PATCH 043/396] [health_generate] fix: fix the /health_generate always success bug (#8028) --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 38a8fa53af7a..7ba07f675120 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1359,7 +1359,7 @@ async def handle_loop(self): while True: recv_obj = await self.recv_from_detokenizer.recv_pyobj() self._result_dispatcher(recv_obj) - self.last_receive_tstamp = time.time() + self.last_receive_tstamp = time.perf_counter() def _handle_batch_output( self, From 8fcc55cfa1c365a3ab92ed097eb10b6658fe1e74 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 18 Jul 2025 22:09:17 -0700 Subject: [PATCH 044/396] [router] router metrics cleanup (#8158) --- sgl-router/src/lib.rs | 4 +- sgl-router/src/metrics.rs | 324 ++++++++++++++++++++++++ sgl-router/src/policies/cache_aware.rs | 16 +- sgl-router/src/policies/power_of_two.rs | 5 +- sgl-router/src/prometheus.rs | 40 --- sgl-router/src/routers/pd_router.rs | 40 +-- sgl-router/src/routers/pd_types.rs | 14 - sgl-router/src/routers/router.rs | 46 ++-- sgl-router/src/server.rs | 4 +- 9 files changed, 378 insertions(+), 115 deletions(-) create mode 100644 sgl-router/src/metrics.rs delete mode 100644 sgl-router/src/prometheus.rs diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 49e8cc573059..a37a4b474728 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -3,14 +3,14 @@ pub mod config; pub mod logging; use std::collections::HashMap; pub mod core; +pub 
mod metrics; pub mod openai_api_types; pub mod policies; -pub mod prometheus; pub mod routers; pub mod server; pub mod service_discovery; pub mod tree; -use crate::prometheus::PrometheusConfig; +use crate::metrics::PrometheusConfig; #[pyclass(eq)] #[derive(Clone, PartialEq, Debug)] diff --git a/sgl-router/src/metrics.rs b/sgl-router/src/metrics.rs new file mode 100644 index 000000000000..0ff2055c540c --- /dev/null +++ b/sgl-router/src/metrics.rs @@ -0,0 +1,324 @@ +use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram}; +use metrics_exporter_prometheus::{Matcher, PrometheusBuilder}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::time::Duration; + +#[derive(Debug, Clone)] +pub struct PrometheusConfig { + pub port: u16, + pub host: String, +} + +impl Default for PrometheusConfig { + fn default() -> Self { + Self { + port: 29000, + host: "0.0.0.0".to_string(), + } + } +} + +pub fn init_metrics() { + // Request metrics + describe_counter!( + "sgl_router_requests_total", + "Total number of requests by route and method" + ); + describe_histogram!( + "sgl_router_request_duration_seconds", + "Request duration in seconds by route" + ); + describe_counter!( + "sgl_router_request_errors_total", + "Total number of request errors by route and error type" + ); + describe_counter!( + "sgl_router_retries_total", + "Total number of request retries by route" + ); + + // Worker metrics + describe_gauge!( + "sgl_router_active_workers", + "Number of currently active workers" + ); + describe_gauge!( + "sgl_router_worker_health", + "Worker health status (1=healthy, 0=unhealthy)" + ); + describe_gauge!("sgl_router_worker_load", "Current load on each worker"); + describe_counter!( + "sgl_router_processed_requests_total", + "Total requests processed by each worker" + ); + + // Policy metrics + describe_counter!( + "sgl_router_policy_decisions_total", + "Total routing policy decisions by policy and worker" + ); + 
describe_counter!("sgl_router_cache_hits_total", "Total cache hits"); + describe_counter!("sgl_router_cache_misses_total", "Total cache misses"); + describe_gauge!( + "sgl_router_tree_size", + "Current tree size for cache-aware routing" + ); + describe_counter!( + "sgl_router_load_balancing_events_total", + "Total load balancing trigger events" + ); + describe_gauge!("sgl_router_max_load", "Maximum worker load"); + describe_gauge!("sgl_router_min_load", "Minimum worker load"); + + // PD-specific metrics + describe_counter!("sgl_router_pd_requests_total", "Total PD requests by route"); + describe_counter!( + "sgl_router_pd_prefill_requests_total", + "Total prefill requests per worker" + ); + describe_counter!( + "sgl_router_pd_decode_requests_total", + "Total decode requests per worker" + ); + describe_counter!( + "sgl_router_pd_errors_total", + "Total PD errors by error type" + ); + describe_counter!( + "sgl_router_pd_prefill_errors_total", + "Total prefill server errors" + ); + describe_counter!( + "sgl_router_pd_decode_errors_total", + "Total decode server errors" + ); + describe_counter!( + "sgl_router_pd_stream_errors_total", + "Total streaming errors per worker" + ); + describe_histogram!( + "sgl_router_pd_request_duration_seconds", + "PD request duration by route" + ); + + // Service discovery metrics + describe_counter!( + "sgl_router_discovery_updates_total", + "Total service discovery update events" + ); + describe_gauge!( + "sgl_router_discovery_workers_added", + "Number of workers added in last discovery update" + ); + describe_gauge!( + "sgl_router_discovery_workers_removed", + "Number of workers removed in last discovery update" + ); + + // Generate request specific metrics + describe_histogram!( + "sgl_router_generate_duration_seconds", + "Generate request duration" + ); + + // Running requests gauge for cache-aware policy + describe_gauge!( + "sgl_router_running_requests", + "Number of running requests per worker" + ); +} + +pub fn 
start_prometheus(config: PrometheusConfig) { + // Initialize metric descriptions + init_metrics(); + + let duration_matcher = Matcher::Suffix(String::from("duration")); + let duration_bucket = [ + 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0, + 60.0, 90.0, 120.0, 180.0, 240.0, + ]; + + let ip_addr: IpAddr = config + .host + .parse() + .unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); + let socket_addr = SocketAddr::new(ip_addr, config.port); + + PrometheusBuilder::new() + .with_http_listener(socket_addr) + .upkeep_timeout(Duration::from_secs(5 * 60)) + .set_buckets_for_metric(duration_matcher, &duration_bucket) + .expect("failed to set duration bucket") + .install() + .expect("failed to install Prometheus metrics exporter"); +} + +pub struct RouterMetrics; + +impl RouterMetrics { + // Request metrics + pub fn record_request(route: &str) { + counter!("sgl_router_requests_total", + "route" => route.to_string() + ) + .increment(1); + } + + pub fn record_request_duration(route: &str, duration: Duration) { + histogram!("sgl_router_request_duration_seconds", + "route" => route.to_string() + ) + .record(duration.as_secs_f64()); + } + + pub fn record_request_error(route: &str, error_type: &str) { + counter!("sgl_router_request_errors_total", + "route" => route.to_string(), + "error_type" => error_type.to_string() + ) + .increment(1); + } + + pub fn record_retry(route: &str) { + counter!("sgl_router_retries_total", + "route" => route.to_string() + ) + .increment(1); + } + + // Worker metrics + pub fn set_active_workers(count: usize) { + gauge!("sgl_router_active_workers").set(count as f64); + } + + pub fn set_worker_health(worker_url: &str, healthy: bool) { + gauge!("sgl_router_worker_health", + "worker" => worker_url.to_string() + ) + .set(if healthy { 1.0 } else { 0.0 }); + } + + pub fn set_worker_load(worker_url: &str, load: usize) { + gauge!("sgl_router_worker_load", + "worker" => worker_url.to_string() + ) + .set(load as f64); 
+ } + + pub fn record_processed_request(worker_url: &str) { + counter!("sgl_router_processed_requests_total", + "worker" => worker_url.to_string() + ) + .increment(1); + } + + // Policy metrics + pub fn record_policy_decision(policy: &str, worker: &str) { + counter!("sgl_router_policy_decisions_total", + "policy" => policy.to_string(), + "worker" => worker.to_string() + ) + .increment(1); + } + + pub fn record_cache_hit() { + counter!("sgl_router_cache_hits_total").increment(1); + } + + pub fn record_cache_miss() { + counter!("sgl_router_cache_misses_total").increment(1); + } + + pub fn set_tree_size(worker: &str, size: usize) { + gauge!("sgl_router_tree_size", + "worker" => worker.to_string() + ) + .set(size as f64); + } + + pub fn record_load_balancing_event() { + counter!("sgl_router_load_balancing_events_total").increment(1); + } + + pub fn set_load_range(max_load: usize, min_load: usize) { + gauge!("sgl_router_max_load").set(max_load as f64); + gauge!("sgl_router_min_load").set(min_load as f64); + } + + // PD-specific metrics + pub fn record_pd_request(route: &str) { + counter!("sgl_router_pd_requests_total", + "route" => route.to_string() + ) + .increment(1); + } + + pub fn record_pd_request_duration(route: &str, duration: Duration) { + histogram!("sgl_router_pd_request_duration_seconds", + "route" => route.to_string() + ) + .record(duration.as_secs_f64()); + } + + pub fn record_pd_prefill_request(worker: &str) { + counter!("sgl_router_pd_prefill_requests_total", + "worker" => worker.to_string() + ) + .increment(1); + } + + pub fn record_pd_decode_request(worker: &str) { + counter!("sgl_router_pd_decode_requests_total", + "worker" => worker.to_string() + ) + .increment(1); + } + + pub fn record_pd_error(error_type: &str) { + counter!("sgl_router_pd_errors_total", + "error_type" => error_type.to_string() + ) + .increment(1); + } + + pub fn record_pd_prefill_error(worker: &str) { + counter!("sgl_router_pd_prefill_errors_total", + "worker" => worker.to_string() 
+ ) + .increment(1); + } + + pub fn record_pd_decode_error(worker: &str) { + counter!("sgl_router_pd_decode_errors_total", + "worker" => worker.to_string() + ) + .increment(1); + } + + pub fn record_pd_stream_error(worker: &str) { + counter!("sgl_router_pd_stream_errors_total", + "worker" => worker.to_string() + ) + .increment(1); + } + + // Service discovery metrics + pub fn record_discovery_update(added: usize, removed: usize) { + counter!("sgl_router_discovery_updates_total").increment(1); + gauge!("sgl_router_discovery_workers_added").set(added as f64); + gauge!("sgl_router_discovery_workers_removed").set(removed as f64); + } + + // Generate request metrics + pub fn record_generate_duration(duration: Duration) { + histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64()); + } + + // Running requests for cache-aware policy + pub fn set_running_requests(worker: &str, count: usize) { + gauge!("sgl_router_running_requests", + "worker" => worker.to_string() + ) + .set(count as f64); + } +} diff --git a/sgl-router/src/policies/cache_aware.rs b/sgl-router/src/policies/cache_aware.rs index db5972ba68a1..9e30c0d01f70 100644 --- a/sgl-router/src/policies/cache_aware.rs +++ b/sgl-router/src/policies/cache_aware.rs @@ -61,8 +61,8 @@ use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy}; use crate::core::Worker; +use crate::metrics::RouterMetrics; use crate::tree::Tree; -use metrics::{counter, gauge}; use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; @@ -171,9 +171,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy { max_load, min_load, worker_loads ); - counter!("sgl_router_load_balancing_events_total").increment(1); - gauge!("sgl_router_max_load").set(max_load as f64); - gauge!("sgl_router_min_load").set(min_load as f64); + RouterMetrics::record_load_balancing_event(); + RouterMetrics::set_load_range(max_load, min_load); // Use shortest queue when imbalanced let min_load_idx = healthy_indices @@ -183,8 
+182,7 @@ impl LoadBalancingPolicy for CacheAwarePolicy { // Increment processed counter workers[min_load_idx].increment_processed(); - counter!("sgl_router_processed_requests_total", "worker" => workers[min_load_idx].url().to_string()) - .increment(1); + RouterMetrics::record_processed_request(workers[min_load_idx].url()); return Some(min_load_idx); } @@ -201,10 +199,10 @@ impl LoadBalancingPolicy for CacheAwarePolicy { }; let selected_url = if match_rate > self.config.cache_threshold { - counter!("sgl_router_cache_hits_total").increment(1); + RouterMetrics::record_cache_hit(); matched_worker.to_string() } else { - counter!("sgl_router_cache_misses_total").increment(1); + RouterMetrics::record_cache_miss(); tree.get_smallest_tenant() }; @@ -221,7 +219,7 @@ impl LoadBalancingPolicy for CacheAwarePolicy { // Increment processed counter workers[selected_idx].increment_processed(); - counter!("sgl_router_processed_requests_total", "worker" => selected_url).increment(1); + RouterMetrics::record_processed_request(&selected_url); return Some(selected_idx); } diff --git a/sgl-router/src/policies/power_of_two.rs b/sgl-router/src/policies/power_of_two.rs index 53c8461965ff..2167273aef35 100644 --- a/sgl-router/src/policies/power_of_two.rs +++ b/sgl-router/src/policies/power_of_two.rs @@ -2,7 +2,7 @@ use super::{get_healthy_worker_indices, LoadBalancingPolicy}; use crate::core::Worker; -use metrics::counter; +use crate::metrics::RouterMetrics; use rand::Rng; use std::collections::HashMap; use std::sync::RwLock; @@ -89,8 +89,7 @@ impl LoadBalancingPolicy for PowerOfTwoPolicy { // Increment processed counter workers[selected_idx].increment_processed(); - counter!("sgl_router_processed_requests_total", "worker" => workers[selected_idx].url().to_string()) - .increment(1); + RouterMetrics::record_processed_request(workers[selected_idx].url()); Some(selected_idx) } diff --git a/sgl-router/src/prometheus.rs b/sgl-router/src/prometheus.rs deleted file mode 100644 index 
ff5a221bd6e8..000000000000 --- a/sgl-router/src/prometheus.rs +++ /dev/null @@ -1,40 +0,0 @@ -use metrics_exporter_prometheus::{Matcher, PrometheusBuilder}; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::time::Duration; - -#[derive(Debug, Clone)] -pub struct PrometheusConfig { - pub port: u16, - pub host: String, -} - -impl Default for PrometheusConfig { - fn default() -> Self { - Self { - port: 29000, - host: "0.0.0.0".to_string(), - } - } -} - -pub fn start_prometheus(config: PrometheusConfig) { - let duration_matcher = Matcher::Suffix(String::from("duration")); - let duration_bucket = [ - 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0, - 60.0, 90.0, 120.0, 180.0, 240.0, - ]; - - let ip_addr: IpAddr = config - .host - .parse() - .unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); - let socket_addr = SocketAddr::new(ip_addr, config.port); - - PrometheusBuilder::new() - .with_http_listener(socket_addr) - .upkeep_timeout(Duration::from_secs(5 * 60)) - .set_buckets_for_metric(duration_matcher, &duration_bucket) - .expect("failed to set duration bucket") - .install() - .expect("failed to install Prometheus metrics exporter"); -} diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index 2ac8f9027762..d156c9f341d6 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -4,13 +4,13 @@ use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError}; use super::request_adapter::ToPdRequest; use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; +use crate::metrics::RouterMetrics; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::policies::LoadBalancingPolicy; use crate::tree::Tree; use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; use futures_util::{StreamExt, TryStreamExt}; -use metrics::{counter, 
histogram}; use serde_json::Value; use std::collections::HashMap; use std::sync::{Arc, Mutex, RwLock}; @@ -296,7 +296,7 @@ impl PDRouter { Ok(pair) => pair, Err(e) => { error!("Failed to select PD pair: {}", e); - counter!("sgl_router_pd_errors_total", "error" => "server_selection").increment(1); + RouterMetrics::record_pd_error("server_selection"); return HttpResponse::ServiceUnavailable() .body(format!("No available servers: {}", e)); } @@ -313,7 +313,7 @@ impl PDRouter { // Add bootstrap info using the trait method if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { error!("Failed to add bootstrap info: {}", e); - counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1); + RouterMetrics::record_pd_error("bootstrap_injection"); return HttpResponse::InternalServerError() .body(format!("Bootstrap injection failed: {}", e)); } @@ -374,7 +374,7 @@ impl PDRouter { Ok(pair) => pair, Err(e) => { error!("Failed to select PD pair: {}", e); - counter!("sgl_router_pd_errors_total", "error" => "server_selection").increment(1); + RouterMetrics::record_pd_error("server_selection"); return HttpResponse::ServiceUnavailable() .body(format!("No available servers: {}", e)); } @@ -391,7 +391,7 @@ impl PDRouter { // Add bootstrap info using the trait method if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { error!("Failed to add bootstrap info: {}", e); - counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1); + RouterMetrics::record_pd_error("bootstrap_injection"); return HttpResponse::InternalServerError() .body(format!("Bootstrap injection failed: {}", e)); } @@ -460,13 +460,10 @@ impl PDRouter { // Update metrics let duration = start_time.elapsed(); - histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string()) - .record(duration.as_secs_f64()); - counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1); - 
counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url().to_string()) - .increment(1); - counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url().to_string()) - .increment(1); + RouterMetrics::record_pd_request_duration(route, duration); + RouterMetrics::record_pd_request(route); + RouterMetrics::record_pd_prefill_request(prefill.url()); + RouterMetrics::record_pd_decode_request(decode.url()); // Process decode response match decode_result { @@ -475,7 +472,7 @@ impl PDRouter { .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); if !status.is_success() { - counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string()).increment(1); + RouterMetrics::record_pd_decode_error(decode.url()); error!( "Decode server {} returned error status: {}", decode.url(), @@ -501,7 +498,7 @@ impl PDRouter { prefill.url(), e ); - counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url().to_string()).increment(1); + RouterMetrics::record_pd_prefill_error(prefill.url()); } if is_stream { @@ -548,13 +545,19 @@ impl PDRouter { } else { // No logprob merging needed HttpResponse::build(status) - .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) + .insert_header(( + CONTENT_TYPE, + HeaderValue::from_static("text/event-stream"), + )) .streaming({ let decode_url = decode.url().to_string(); res.bytes_stream().map_err(move |e| { error!("Stream error from decode server {}: {}", decode_url, e); - counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string()).increment(1); - actix_web::error::ErrorInternalServerError(format!("Stream error: {}", e)) + RouterMetrics::record_pd_stream_error(&decode_url); + actix_web::error::ErrorInternalServerError(format!( + "Stream error: {}", + e + )) }) }) } @@ -578,8 +581,7 @@ impl PDRouter { } Err(e) => { error!("Decode request failed: {}", e); - counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string()) - 
.increment(1); + RouterMetrics::record_pd_decode_error(decode.url()); HttpResponse::BadGateway().body(format!("Decode server error: {}", e)) } } diff --git a/sgl-router/src/routers/pd_types.rs b/sgl-router/src/routers/pd_types.rs index 75473b0e33a8..155274b06f16 100644 --- a/sgl-router/src/routers/pd_types.rs +++ b/sgl-router/src/routers/pd_types.rs @@ -151,13 +151,6 @@ impl GenerateReqInput { if texts.is_empty() { return Err("Batch text array is empty".to_string()); } - if texts.len() > 10000 { - // Reasonable limit for production - return Err(format!( - "Batch size {} exceeds maximum allowed (10000)", - texts.len() - )); - } return Ok(Some(texts.len())); } @@ -166,13 +159,6 @@ impl GenerateReqInput { if ids.is_empty() { return Err("Batch input_ids array is empty".to_string()); } - if ids.len() > 10000 { - // Reasonable limit for production - return Err(format!( - "Batch size {} exceeds maximum allowed (10000)", - ids.len() - )); - } // Validate each sequence is not empty for (i, seq) in ids.iter().enumerate() { if seq.is_empty() { diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs index ef44348eca20..c198b0c1dba5 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/router.rs @@ -1,6 +1,6 @@ use crate::core::{HealthChecker, Worker, WorkerFactory}; +use crate::metrics::RouterMetrics; use crate::policies::LoadBalancingPolicy; -use ::metrics::{counter, gauge, histogram}; use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; use futures_util::{StreamExt, TryStreamExt}; @@ -43,7 +43,7 @@ impl Router { interval_secs: u64, ) -> Result { // Update active workers gauge - gauge!("sgl_router_active_workers").set(worker_urls.len() as f64); + RouterMetrics::set_active_workers(worker_urls.len()); // Wait for workers to be healthy (skip if empty - for service discovery mode) if !worker_urls.is_empty() { @@ -215,13 +215,11 @@ impl Router { // Record request metrics if route != 
"/health" { let duration = start.elapsed(); - counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1); - histogram!("sgl_router_request_duration_seconds", "route" => route.to_string()) - .record(duration.as_secs_f64()); + RouterMetrics::record_request(route); + RouterMetrics::record_request_duration(route, duration); if !response.status().is_success() { - counter!("sgl_router_request_errors_total", "route" => route.to_string()) - .increment(1); + RouterMetrics::record_request_error(route, "request_failed"); } } response @@ -390,7 +388,7 @@ impl Router { while request_retries < MAX_REQUEST_RETRIES { if total_retries >= 1 { info!("Retrying request after {} failed attempts", total_retries); - counter!("sgl_router_retries_total", "route" => route.to_string()).increment(1); + RouterMetrics::record_retry(route); } // Increment load before request if using RAII load tracking @@ -398,8 +396,7 @@ impl Router { let workers_guard = self.workers.read().unwrap(); if let Some(worker) = workers_guard.iter().find(|w| w.url() == &worker_url) { worker.increment_load(); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(worker.load() as f64); + RouterMetrics::set_running_requests(&worker_url, worker.load()); true } else { false @@ -423,16 +420,14 @@ impl Router { if response.status().is_success() { let duration = start.elapsed(); - histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()) - .record(duration.as_secs_f64()); + RouterMetrics::record_generate_duration(duration); return response; } else { // if the worker is healthy, it means the request is bad, so return the error response let health_response = self.send_request(client, &worker_url, "/health", req).await; if health_response.status().is_success() { - counter!("sgl_router_request_errors_total", "route" => route.to_string()) - .increment(1); + RouterMetrics::record_request_error(route, "request_failed"); return response; } } @@ -455,7 +450,7 @@ 
impl Router { } } - counter!("sgl_router_request_errors_total", "route" => route.to_string()).increment(1); + RouterMetrics::record_request_error(route, "request_failed"); HttpResponse::InternalServerError().body("All retry attempts failed") } @@ -512,8 +507,7 @@ impl Router { if let Ok(workers_guard) = self.workers.read() { if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) { worker.decrement_load(); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(worker.load() as f64); + RouterMetrics::set_running_requests(&worker_url, worker.load()); } } } @@ -540,17 +534,15 @@ impl Router { if let Ok(workers_guard) = self.workers.read() { if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) { worker.decrement_load(); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(worker.load() as f64); + RouterMetrics::set_running_requests(&worker_url, worker.load()); } } } // Record metrics let duration = start.elapsed(); - histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()) - .record(duration.as_secs_f64()); - counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1); + RouterMetrics::record_generate_duration(duration); + RouterMetrics::record_request(route); response } else if load_incremented { @@ -577,8 +569,10 @@ impl Router { workers_guard.iter().find(|w| w.url() == &worker_url) { worker.decrement_load(); - gauge!("sgl_router_running_requests", "worker" => worker_url.to_string()) - .set(worker.load() as f64); + RouterMetrics::set_running_requests( + &worker_url, + worker.load(), + ); debug!("Streaming is done!!") } } @@ -626,7 +620,7 @@ impl Router { info!("Added worker: {}", worker_url); let new_worker = WorkerFactory::create_regular(worker_url.to_string()); workers_guard.push(new_worker); - gauge!("sgl_router_active_workers").set(workers_guard.len() as f64); + RouterMetrics::set_active_workers(workers_guard.len()); // 
If cache aware policy, initialize the worker in the tree if let Some(cache_aware) = @@ -680,7 +674,7 @@ impl Router { if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) { workers_guard.remove(index); info!("Removed worker: {}", worker_url); - gauge!("sgl_router_active_workers").set(workers_guard.len() as f64); + RouterMetrics::set_active_workers(workers_guard.len()); } else { warn!("Worker {} not found, skipping removal", worker_url); return; diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 69340eefe52b..83774f172a35 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,7 +1,7 @@ use crate::config::RouterConfig; use crate::logging::{self, LoggingConfig}; +use crate::metrics::{self, PrometheusConfig}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; -use crate::prometheus::{self, PrometheusConfig}; use crate::routers::{RouterFactory, RouterTrait}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; use actix_web::{ @@ -237,7 +237,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { "🚧 Initializing Prometheus metrics on {}:{}", prometheus_config.host, prometheus_config.port ); - prometheus::start_prometheus(prometheus_config); + metrics::start_prometheus(prometheus_config); } else { info!("🚧 Prometheus metrics disabled"); } From b763cf7e8e2519f9b03ae29922ecbeba1db8e314 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 18 Jul 2025 22:09:54 -0700 Subject: [PATCH 045/396] [router] allow router to have empty workers (#8160) --- .../py_src/sglang_router/launch_router.py | 3 ++- sgl-router/py_test/test_launch_router.py | 23 ++++++++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py index 092946a2719b..f7aaf6dee628 100644 --- a/sgl-router/py_src/sglang_router/launch_router.py +++ 
b/sgl-router/py_src/sglang_router/launch_router.py @@ -97,7 +97,8 @@ def add_cli_args( parser.add_argument( "--worker-urls", type=str, - nargs="+", + nargs="*", + default=[], help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)", ) diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py index 14a0fa12d4a9..90d8aa664395 100644 --- a/sgl-router/py_test/test_launch_router.py +++ b/sgl-router/py_test/test_launch_router.py @@ -90,7 +90,9 @@ def test_launch_router_common(self): def test_launch_router_with_empty_worker_urls(self): args = self.create_router_args(worker_urls=[]) - self.run_router_process(args) # Expected error + self.run_router_process( + args + ) # Should start successfully with empty worker list def test_launch_router_with_service_discovery(self): # Test router startup with service discovery enabled but no selectors @@ -279,6 +281,25 @@ def test_regular_service_discovery_args_parsing(self): self.assertEqual(router_args.prefill_selector, {}) self.assertEqual(router_args.decode_selector, {}) + def test_empty_worker_urls_args_parsing(self): + """Test that router accepts no worker URLs and defaults to empty list.""" + import argparse + + from sglang_router.launch_router import RouterArgs + + parser = argparse.ArgumentParser() + RouterArgs.add_cli_args(parser) + + # Test with no --worker-urls argument at all + args = parser.parse_args(["--policy", "random", "--port", "30000"]) + router_args = RouterArgs.from_cli_args(args) + self.assertEqual(router_args.worker_urls, []) + + # Test with explicit empty --worker-urls + args = parser.parse_args(["--worker-urls", "--policy", "random"]) + router_args = RouterArgs.from_cli_args(args) + self.assertEqual(router_args.worker_urls, []) + if __name__ == "__main__": unittest.main() From cfab0ff6e291851ffb5c96bf25f5ae07c5af3614 Mon Sep 17 00:00:00 2001 From: kyleliang-nv Date: Fri, 18 Jul 2025 22:34:29 -0700 Subject: [PATCH 046/396] Add GB200 wide-EP docker 
(#8157) --- docker/Dockerfile.gb200 | 357 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 357 insertions(+) create mode 100644 docker/Dockerfile.gb200 diff --git a/docker/Dockerfile.gb200 b/docker/Dockerfile.gb200 new file mode 100644 index 000000000000..05b0f42043bc --- /dev/null +++ b/docker/Dockerfile.gb200 @@ -0,0 +1,357 @@ +ARG CUDA_VERSION=12.8.1 +FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04 + +ARG BUILD_TYPE=blackwell +ENV DEBIAN_FRONTEND=noninteractive \ + CUDA_HOME=/usr/local/cuda \ + GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/ \ + NVSHMEM_DIR=/sgl-workspace/nvshmem/install \ + BUILD_TYPE=${BUILD_TYPE} \ + TORCH_CUDA_ARCH_LIST="10.0 12.0" + +# Set timezone and install all packages +RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ + && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ + && apt-get update && apt-get install -y --no-install-recommends \ + tzdata \ + software-properties-common netcat-openbsd kmod unzip openssh-server \ + curl wget lsof zsh ccache tmux htop git-lfs tree \ + python3 python3-pip python3-dev libpython3-dev \ + build-essential cmake \ + libopenmpi-dev libnuma1 libnuma-dev \ + libibverbs-dev libibverbs1 libibumad3 \ + librdmacm1 libnl-3-200 libnl-route-3-200 libnl-route-3-dev libnl-3-dev \ + ibverbs-providers infiniband-diags perftest \ + libgoogle-glog-dev libgtest-dev libjsoncpp-dev libunwind-dev \ + libboost-all-dev libssl-dev \ + libgrpc-dev libgrpc++-dev libprotobuf-dev protobuf-compiler-grpc \ + pybind11-dev \ + libhiredis-dev libcurl4-openssl-dev \ + libczmq4 libczmq-dev \ + libfabric-dev \ + patchelf \ + nvidia-dkms-550 \ + devscripts debhelper fakeroot dkms check libsubunit0 libsubunit-dev \ + && ln -sf /usr/bin/python3 /usr/bin/python \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean + + +# --- Install SGLang missing package +RUN pip install netifaces + +# --- Install nightly PyTorch --- +RUN pip install --pre torch torchvision torchaudio --index-url 
https://download.pytorch.org/whl/nightly/cu128 --force-reinstall + + +# GDRCopy installation +RUN mkdir -p /tmp/gdrcopy && cd /tmp \ + && git clone https://github.com/NVIDIA/gdrcopy.git -b v2.4.4 \ + && cd gdrcopy/packages \ + && CUDA=/usr/local/cuda ./build-deb-packages.sh \ + && dpkg -i gdrdrv-dkms_*.deb libgdrapi_*.deb gdrcopy-tests_*.deb gdrcopy_*.deb \ + && cd / && rm -rf /tmp/gdrcopy + +# Fix DeepEP IBGDA symlink +RUN ln -sf /usr/lib/$(uname -m)-linux-gnu/libmlx5.so.1 /usr/lib/$(uname -m)-linux-gnu/libmlx5.so + +# Clone and install SGLang +# FIXME: Forcing SGLang to 2a2d3478afe8cdb336888f2e6faa3775ac40254e because sgl-kernel v0.2.5 is missing aarch64 package +WORKDIR /sgl-workspace +RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel html5lib six \ + && git clone https://github.com/sgl-project/sglang.git \ + && cd sglang \ + && git checkout 2a2d3478afe8cdb336888f2e6faa3775ac40254e \ + && case "$CUDA_VERSION" in \ + 12.6.1) CUINDEX=126 ;; \ + 12.8.1) CUINDEX=128 ;; \ + *) echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1 ;; \ + esac \ + && python3 -m pip install --no-cache-dir -e "python[${BUILD_TYPE}]" --extra-index-url https://download.pytorch.org/whl/cu${CUINDEX} \ + && if [ "$CUDA_VERSION" = "12.8.1" ]; then \ + python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.5 --force-reinstall --no-deps ; \ + python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.2.4/sgl_kernel-0.2.4+cu128-cp39-abi3-manylinux2014_$(uname -m).whl --force-reinstall --no-deps ; \ + fi + + +# Build NVSHMEM +# Build and install NVSHMEM + DeepEP +RUN wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz \ + && git clone https://github.com/fzyzcjy/DeepEP.git \ + && cd DeepEP \ + && git checkout 1b14ad661c7640137fcfe93cccb2694ede1220b0 \ + && cd .. 
\ + && tar -xf nvshmem_src_3.2.5-1.txz && mv nvshmem_src nvshmem \ + && cd nvshmem \ + && git apply /sgl-workspace/DeepEP/third-party/nvshmem.patch \ + && sed -i '1i#include ' examples/moe_shuffle.cu \ + && rm -f /sgl-workspace/nvshmem_src_3.2.5-1.txz \ + && NVSHMEM_SHMEM_SUPPORT=0 \ + NVSHMEM_UCX_SUPPORT=0 \ + NVSHMEM_USE_NCCL=0 \ + NVSHMEM_MPI_SUPPORT=0 \ + NVSHMEM_IBGDA_SUPPORT=1 \ + NVSHMEM_PMIX_SUPPORT=0 \ + NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ + NVSHMEM_USE_GDRCOPY=1 \ + cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=${NVSHMEM_DIR} -DCMAKE_CUDA_ARCHITECTURES="100;120" \ + && cmake --build build --target install -j \ + && cd /sgl-workspace/DeepEP \ + && NVSHMEM_DIR=${NVSHMEM_DIR} pip install . + +# Python tools +RUN python3 -m pip install --no-cache-dir \ + datamodel_code_generator \ + pre-commit \ + pytest \ + black \ + isort \ + icdiff \ + uv \ + wheel \ + scikit-build-core + +# Install development tools and utilities +RUN apt-get update && apt-get install -y \ + gdb \ + ninja-build \ + vim \ + tmux \ + htop \ + wget \ + curl \ + locales \ + lsof \ + git \ + git-lfs \ + zsh \ + tree \ + silversearcher-ag \ + cloc \ + unzip \ + pkg-config \ + libssl-dev \ + bear \ + ccache \ + less \ + && apt install -y rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean + +RUN apt update -y \ + && apt install -y --no-install-recommends gnupg \ + && echo "deb http://developer.download.nvidia.com/devtools/repos/ubuntu2204/$(if [ "$(uname -m)" = "aarch64" ]; then echo "arm64"; else echo "amd64"; fi) /" | tee /etc/apt/sources.list.d/nvidia-devtools.list \ + && apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$(if [ "$(uname -m)" = "aarch64" ]; then echo "sbsa"; else echo "x86_64"; fi)/3bf863cc.pub \ + && apt update -y \ + && apt install nsight-systems-cli -y + +RUN git clone 
https://github.com/kvcache-ai/Mooncake.git \ + && cd Mooncake \ + && bash dependencies.sh -y \ + && mkdir build \ + && cd build \ + && cmake .. -DUSE_MNNVL=ON \ + && make -j \ + && make install + +# Set up locale +RUN locale-gen en_US.UTF-8 +ENV LANG en_US.UTF-8 +ENV LANGUAGE en_US:en +ENV LC_ALL en_US.UTF-8 + +# Install minimal Python packages +RUN python3 -m pip install --no-cache-dir --break-system-packages \ + pytest \ + black \ + isort \ + icdiff \ + scikit_build_core \ + uv \ + pre-commit \ + pandas \ + matplotlib \ + tabulate + +# Install diff-so-fancy +RUN curl -LSso /usr/local/bin/diff-so-fancy https://github.com/so-fancy/diff-so-fancy/releases/download/v1.4.4/diff-so-fancy \ + && chmod +x /usr/local/bin/diff-so-fancy + +# Install clang-format +RUN curl -LSso /usr/local/bin/clang-format https://github.com/muttleyxd/clang-tools-static-binaries/releases/download/master-32d3ac78/clang-format-16_linux-amd64 \ + && chmod +x /usr/local/bin/clang-format + +# Install clangd +RUN curl -L https://github.com/clangd/clangd/releases/download/18.1.3/clangd-linux-18.1.3.zip -o clangd.zip \ + && unzip clangd.zip \ + && cp -r clangd_18.1.3/bin/* /usr/local/bin/ \ + && cp -r clangd_18.1.3/lib/* /usr/local/lib/ \ + && rm -rf clangd_18.1.3 clangd.zip + +# Install CMake +RUN CMAKE_VERSION=3.31.1 \ + && ARCH=$(uname -m) \ + && CMAKE_INSTALLER="cmake-${CMAKE_VERSION}-linux-${ARCH}" \ + && wget "https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/${CMAKE_INSTALLER}.tar.gz" \ + && tar -xzf "${CMAKE_INSTALLER}.tar.gz" \ + && cp -r "${CMAKE_INSTALLER}/bin/"* /usr/local/bin/ \ + && cp -r "${CMAKE_INSTALLER}/share/"* /usr/local/share/ \ + && rm -rf "${CMAKE_INSTALLER}" "${CMAKE_INSTALLER}.tar.gz" + +# Add yank script +COPY --chown=root:root <<-"EOF" /usr/local/bin/yank +#!/bin/bash +put() { + esc=$1 + test -n "$TMUX" -o -z "${TERM##screen*}" && esc="\033Ptmux;\033$esc\033\\" + printf "$esc" +} +put "\033]52;c;!\a" +buf=$( cat "$@" ) +len=$( printf %s "$buf" | wc -c ) 
max=74994 +test $len -gt $max && echo "$0: input is $(( len - max )) bytes too long" >&2 +put "\033]52;c;$( printf %s "$buf" | head -c $max | base64 | tr -d '\r\n' )\a" +test -n "$TMUX" && tmux set-buffer "$buf" ||: +EOF + +RUN chmod +x /usr/local/bin/yank + +# Install oh-my-zsh and plugins +RUN sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" "" --unattended \ + && git clone https://github.com/zsh-users/zsh-autosuggestions ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-autosuggestions \ + && git clone https://github.com/zsh-users/zsh-syntax-highlighting.git ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-syntax-highlighting + +# Configure Vim +COPY --chown=root:root <<-"EOF" /root/.vimrc +function! Yank(text) abort + let escape = system('yank', a:text) + if v:shell_error + echoerr escape + else + call writefile([escape], '/dev/tty', 'b') + endif +endfunction + +noremap y y:call Yank(@0) + +" automatically run yank(1) whenever yanking in Vim +function! 
CopyYank() abort + call Yank(join(v:event.regcontents, "\n")) +endfunction + +autocmd TextYankPost * call CopyYank() + +" Basic settings +set number +syntax on +set mouse=a +filetype indent on + +" Indentation +set autoindent nosmartindent +set smarttab +set expandtab +set shiftwidth=4 +set softtabstop=4 + +" Visual guides +set colorcolumn=120 +highlight ColorColumn ctermbg=5 + +" Status line +set laststatus=2 +set statusline=%<%f\ %h%m%r%=%{\"[\".(&fenc==\"\"?&enc:&fenc).((exists(\"+bomb\")\ &&\ &bomb)?\",B\":\"\").\"]\ \"}%k\ %-14.(%l,%c%V%)\ %P + +" Backspace behavior +set backspace=2 + +" Encoding +set encoding=utf-8 +set fileencoding=utf-8 +EOF + +# Configure tmux +COPY --chown=root:root <<-"EOF" /root/.tmux.conf +# Pane border styling +set -g pane-border-style fg='#742727',bg=black +set -g pane-active-border-style fg=red,bg=black + +# Status bar styling +set -g status-style bg='#0C8A92',fg=black + +# Change prefix key to backtick +set-option -g prefix ` +unbind C-b +bind-key ` send-prefix + +# Split panes using - and = with current path +unbind '"' +bind - splitw -v -c '#{pane_current_path}' +unbind '%' +bind = splitw -h -c '#{pane_current_path}' + +# Vi mode settings +bind-key -T copy-mode-vi Y send-keys -X copy-pipe 'yank > #{pane_tty}' +set-window-option -g mode-keys vi + +# Other settings +set-option -g escape-time 0 +set-option -g base-index 1 +set-window-option -g mouse on +EOF + +# Configure Git +RUN git config --global core.editor "vim" \ + && git config --global core.whitespace "fix,-indent-with-non-tab,trailing-space,cr-at-eol" \ + && git config --global core.pager "diff-so-fancy | less --tabs=4 -RFX" \ + && git config --global color.ui true \ + && git config --global color."diff-highlight".oldNormal "red bold" \ + && git config --global color."diff-highlight".oldHighlight "red bold 52" \ + && git config --global color."diff-highlight".newNormal "green bold" \ + && git config --global color."diff-highlight".newHighlight "green bold 22" \ + && git 
config --global color.diff.meta "11" \ + && git config --global color.diff.frag "magenta bold" \ + && git config --global color.diff.commit "yellow bold" \ + && git config --global color.diff.old "red bold" \ + && git config --global color.diff.new "green bold" \ + && git config --global color.diff.whitespace "red reverse" \ + && git config --global alias.lg "log --color --graph --pretty=format:'%Cred%h%Creset - %s %Cgreen(%cr) %C(bold blue)<%an>%Creset%C(auto)%d%Creset' --abbrev-commit --" \ + && git config --global http.sslVerify false \ + && git config --global pull.rebase true + +# Configure zsh +COPY --chown=root:root <<-"EOF" /root/.zshrc +export ZSH="/root/.oh-my-zsh" + +# Theme +ZSH_THEME="robbyrussell" + +# Plugins +plugins=( + git + z + zsh-autosuggestions + zsh-syntax-highlighting +) + +source $ZSH/oh-my-zsh.sh + +# Aliases +alias ll='ls -alF' +alias la='ls -A' +alias l='ls -CF' +alias vi='vim' + +# Enhanced history +HISTSIZE=10000 +SAVEHIST=10000 +setopt HIST_IGNORE_ALL_DUPS +setopt HIST_FIND_NO_DUPS +setopt INC_APPEND_HISTORY +EOF + +RUN set -euxo ; \ + curl --proto '=https' --tlsv1.2 -sSf https://just.systems/install.sh | bash -s -- --to /usr/local/bin + +# Set workspace directory +WORKDIR /sgl-workspace/sglang From 15ad6c908670492243cfcb820ca24c40cc9b840d Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sat, 19 Jul 2025 00:51:15 -0700 Subject: [PATCH 047/396] [1/N] MoE Refactor: refactor `select_experts` (#7966) --- python/sglang/srt/custom_op.py | 7 +- python/sglang/srt/layers/linear.py | 2 +- python/sglang/srt/layers/moe/ep_moe/layer.py | 87 ++------ .../sglang/srt/layers/moe/fused_moe_native.py | 54 +---- .../layers/moe/fused_moe_triton/fused_moe.py | 45 +--- .../srt/layers/moe/fused_moe_triton/layer.py | 35 +-- python/sglang/srt/layers/moe/topk.py | 176 ++++++++++++++- .../srt/layers/quantization/__init__.py | 32 +-- python/sglang/srt/layers/quantization/awq.py | 39 +--- 
.../srt/layers/quantization/base_config.py | 21 +- .../srt/layers/quantization/blockwise_int8.py | 35 +-- .../compressed_tensors_moe.py | 92 ++------ python/sglang/srt/layers/quantization/fp8.py | 52 +---- python/sglang/srt/layers/quantization/gptq.py | 35 +-- .../srt/layers/quantization/modelopt_quant.py | 63 +----- .../srt/layers/quantization/moe_wna16.py | 34 +-- .../sglang/srt/layers/quantization/unquant.py | 207 +++++------------- .../srt/layers/quantization/w8a8_fp8.py | 37 +--- .../srt/layers/quantization/w8a8_int8.py | 89 ++------ python/sglang/srt/models/deepseek.py | 15 +- python/sglang/srt/models/deepseek_v2.py | 52 ++--- python/sglang/srt/models/granitemoe.py | 10 +- python/sglang/srt/models/grok.py | 12 +- python/sglang/srt/models/hunyuan.py | 13 +- python/sglang/srt/models/llama4.py | 22 +- python/sglang/srt/models/mixtral.py | 11 +- python/sglang/srt/models/olmoe.py | 13 +- python/sglang/srt/models/phimoe.py | 12 +- python/sglang/srt/models/qwen2_moe.py | 14 +- python/sglang/srt/models/qwen3_moe.py | 31 ++- python/sglang/test/test_block_fp8.py | 11 +- python/sglang/test/test_block_fp8_ep.py | 2 +- python/sglang/test/test_cutlass_w4a8_moe.py | 4 +- python/sglang/test/test_fp4_moe.py | 4 +- test/srt/test_block_int8.py | 11 +- test/srt/test_fused_moe.py | 19 +- test/srt/test_int8_kernel.py | 10 +- .../srt/test_triton_moe_channel_fp8_kernel.py | 10 +- test/srt/test_triton_moe_wna16.py | 11 +- 39 files changed, 557 insertions(+), 872 deletions(-) diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py index 5b502a153326..8c662b5ccb57 100644 --- a/python/sglang/srt/custom_op.py +++ b/python/sglang/srt/custom_op.py @@ -29,15 +29,18 @@ def enter_torch_compile(self, num_tokens: int): self._original_forward_method = self._forward_method # NOTE: Temporarily workaround MoE + # The performance of torch.compile on this layer is not always good when bs > 1, + # so we decide to only use torch.compile when bs=1 if "FusedMoE" in 
self.__class__.__name__: if num_tokens == 1: from sglang.srt.layers.moe.fused_moe_native import ( fused_moe_forward_native, ) - # The performance of torch.compile on this layer is not always good when bs > 1, - # so we decide to only use torch.compile when bs =1 self._forward_method = fused_moe_forward_native + elif "TopK" in self.__class__.__name__: + if num_tokens == 1: + self._forward_method = self.forward_native else: self._forward_method = self.forward_native self.is_torch_compile = True diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 07be9a3c6b14..9d8ab8632752 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -756,7 +756,7 @@ def __init__( bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - quant_config: Optional["QuantizationConfig"] = None, + quant_config: Optional[QuantizationConfig] = None, prefix: str = "", tp_rank: Optional[int] = None, tp_size: Optional[int] = None, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index a839b47febed..77d849f3f67b 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -1,17 +1,13 @@ import logging -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple -import einops import torch -from torch.nn import Module -from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata -from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.layers.moe.ep_moe.kernels import ( ep_gather, ep_scatter, @@ -28,7 +24,7 @@ tma_align_input_scale, ) from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.layers.moe.topk import 
select_experts +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, @@ -162,16 +158,9 @@ def __init__( intermediate_size: int, layer_id: int, params_dtype: Optional[torch.dtype] = None, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - topk_group: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", - correction_bias: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, activation: str = "silu", routed_scaling_factor: Optional[float] = None, use_per_token_if_dynamic: bool = True, @@ -189,24 +178,12 @@ def __init__( self.layer_id = layer_id self.num_experts = num_experts assert self.num_experts % self.tp_size == 0 - assert ( - num_fused_shared_experts == 0 - ), "num_fused_shared_experts is not supported in EP" - self.num_fused_shared_experts = num_fused_shared_experts self.num_experts_per_partition, self.expert_map = self.determine_expert_map() self.start_expert_id = self.tp_rank * self.num_experts_per_partition self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1 self.top_k = top_k self.intermediate_size = intermediate_size - self.renormalize = renormalize - self.use_grouped_topk = use_grouped_topk - if self.use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - self.num_expert_group = num_expert_group - self.topk_group = topk_group - self.correction_bias = correction_bias - self.custom_routing_function = custom_routing_function self.activation = activation self.routed_scaling_factor = routed_scaling_factor self.use_per_token_if_dynamic = use_per_token_if_dynamic @@ -311,33 +288,24 @@ def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]: ) return (local_num_experts, 
expert_map) - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8: - return self.forward_deepgemm(hidden_states, router_logits) + return self.forward_deepgemm(hidden_states, topk_output) else: - return self.forward_normal(hidden_states, router_logits) + return self.forward_normal(hidden_states, topk_output) def forward_deepgemm( - self, hidden_states: torch.Tensor, router_logits: torch.Tensor + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, ): assert self.quant_method is not None assert self.activation == "silu" hidden_states_shape = hidden_states.shape hidden_states_dtype = hidden_states.dtype hidden_states_device = hidden_states.device - topk_weights, topk_ids = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=self.use_grouped_topk, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - correction_bias=self.correction_bias, - custom_routing_function=self.custom_routing_function, - routed_scaling_factor=self.routed_scaling_factor, - ) + + topk_weights, topk_ids, _ = topk_output if not self.use_block_quant: # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm @@ -469,8 +437,10 @@ def forward_deepgemm( ) return output - def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput): assert self.quant_method is not None + topk_weights, topk_ids, _ = topk_output + hidden_states_shape = hidden_states.shape hidden_states_dtype = hidden_states.dtype hidden_states_device = hidden_states.device @@ -481,23 +451,6 @@ def forward_normal(self, hidden_states: torch.Tensor, router_logits: 
torch.Tenso use_per_token_if_dynamic=self.use_per_token_if_dynamic, ) - topk_weights, topk_ids = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=self.use_grouped_topk, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - correction_bias=self.correction_bias, - custom_routing_function=self.custom_routing_function, - routed_scaling_factor=self.routed_scaling_factor, - expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( - layer_id=self.layer_id, - ), - ) - if self.use_w4afp8: local_topk_ids = topk_ids if self.expert_map is not None: @@ -916,16 +869,9 @@ def __init__( intermediate_size: int, layer_id: int, params_dtype: Optional[torch.dtype] = None, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - topk_group: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", - correction_bias: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, activation: str = "silu", routed_scaling_factor: Optional[float] = None, deepep_mode: DeepEPMode = DeepEPMode.auto, @@ -937,16 +883,9 @@ def __init__( intermediate_size=intermediate_size, layer_id=layer_id, params_dtype=params_dtype, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - topk_group=topk_group, quant_config=quant_config, tp_size=tp_size, prefix=prefix, - correction_bias=correction_bias, - custom_routing_function=custom_routing_function, activation=activation, routed_scaling_factor=routed_scaling_factor, ) diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py index 
25645ad00e91..61eacd78c02c 100644 --- a/python/sglang/srt/layers/moe/fused_moe_native.py +++ b/python/sglang/srt/layers/moe/fused_moe_native.py @@ -9,21 +9,14 @@ from torch.nn import functional as F from sglang.srt.layers.activation import GeluAndMul, SiluAndMul -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKOutput def fused_moe_forward_native( layer: torch.nn.Module, x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -34,20 +27,7 @@ def fused_moe_forward_native( if apply_router_weight_on_input: raise NotImplementedError() - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - torch_native=True, - ) + topk_weights, topk_ids, _ = topk_output w13_weights = layer.w13_weight[topk_ids] w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) @@ -67,15 +47,8 @@ def fused_moe_forward_native( def moe_forward_native( layer: torch.nn.Module, x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + 
topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -86,20 +59,7 @@ def moe_forward_native( if apply_router_weight_on_input: raise NotImplementedError() - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - torch_native=True, - routed_scaling_factor=routed_scaling_factor, - ) + topk_weights, topk_ids, _ = topk_output # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589 len_experts = layer.num_experts diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index baf8f5c87e5b..a39d6d5d3da4 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -6,13 +6,13 @@ import json import logging import os -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import triton import triton.language as tl -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, scaled_fp8_quant, @@ -1328,8 +1328,7 @@ def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + topk_output: TopKOutput, inplace: bool = False, activation: str = "silu", apply_router_weight_on_input: bool = False, @@ -1348,7 +1347,7 @@ def fused_experts( no_combine: bool = False, routed_scaling_factor: 
Optional[float] = None, ): - + topk_weights, topk_ids, _ = topk_output if inplace: assert not no_combine, "no combine + inplace makes no sense" torch.ops.sglang.inplace_fused_experts( @@ -1732,17 +1731,10 @@ def fused_moe( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, + topk_output: TopKOutput, inplace: bool = False, activation: str = "silu", apply_router_weight_on_input: bool = False, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - topk_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -1766,16 +1758,9 @@ def fused_moe( - hidden_states (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - topk_output (TopKOutput): The top-k output of the experts. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - num_expert_group: Optional[int]: additional parameter for grouped_topk - - topk_group: Optional[int]: additional parameter for grouped_topk - - use_grouped_topk: If True, use grouped_topk instead of fused_topk - note: Deepseek V2/V3/R1 series models use grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner @@ -1799,28 +1784,12 @@ def fused_moe( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ - # Check constraints. 
- assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - - topk_weights, topk_ids = select_experts( - hidden_states=hidden_states, - router_logits=gating_output, - use_grouped_topk=use_grouped_topk, - top_k=topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - routed_scaling_factor=routed_scaling_factor, - ) return fused_experts( hidden_states, w1, w2, - topk_weights, - topk_ids, + topk_output, inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 41ae6274b087..0c3cb0422f55 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -2,7 +2,7 @@ import logging from enum import Enum -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple import torch @@ -11,6 +11,7 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -59,22 +60,15 @@ class FusedMoE(torch.nn.Module): def __init__( self, num_experts: int, - top_k: int, hidden_size: int, intermediate_size: int, + top_k: Optional[int] = None, layer_id: Optional[int] = None, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = False, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - topk_group: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", - custom_routing_function: Optional[Callable] = None, - correction_bias: 
Optional[torch.Tensor] = None, activation: str = "silu", apply_router_weight_on_input: bool = False, use_presharded_weights: bool = False, @@ -89,6 +83,7 @@ def __init__( if params_dtype is None: params_dtype = torch.get_default_dtype() + self.top_k = top_k self.hidden_size = hidden_size self.tp_size = ( tp_size if tp_size is not None else get_tensor_model_parallel_world_size() @@ -126,19 +121,9 @@ def __init__( self.ep_rank = 0 self.local_num_experts = num_experts self.routed_scaling_factor = routed_scaling_factor - self.top_k = top_k assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results - self.renormalize = renormalize - self.use_grouped_topk = use_grouped_topk - if self.use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - self.num_expert_group = num_expert_group - self.num_fused_shared_experts = num_fused_shared_experts - self.topk_group = topk_group - self.custom_routing_function = custom_routing_function - self.correction_bias = correction_bias self.activation = activation self.apply_router_weight_on_input = apply_router_weight_on_input self.use_presharded_weights = use_presharded_weights @@ -562,22 +547,14 @@ def weight_loader( ) return - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): assert self.quant_method is not None # Matrix multiply. 
final_hidden_states = self.quant_method.apply( layer=self, x=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - renormalize=self.renormalize, - use_grouped_topk=self.use_grouped_topk, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - custom_routing_function=self.custom_routing_function, - correction_bias=self.correction_bias, + topk_output=topk_output, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, routed_scaling_factor=self.routed_scaling_factor, diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 40fc0b61f650..bb3cf651542a 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -12,12 +12,15 @@ # limitations under the License. # ============================================================================== +from __future__ import annotations + import math -from typing import Callable, Optional +from typing import TYPE_CHECKING, Callable, NamedTuple, Optional import torch import torch.nn.functional as F +from sglang.srt.custom_op import CustomOp from sglang.srt.eplb import expert_location_dispatch from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location_dispatch import ( @@ -52,6 +55,168 @@ except ImportError: raise ImportError("aiter is required when SGLANG_USE_AITER is set to True") +if _is_npu: + import torch_npu + + +class TopKOutput(NamedTuple): + topk_weights: torch.Tensor + topk_ids: torch.Tensor + router_logits: torch.Tensor + + +class TopK(CustomOp): + + # TODO(ch-wan): support triton_kernels + + def __init__( + self, + top_k: int, + *, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + renormalize: bool = True, + num_fused_shared_experts: int = 0, + custom_routing_function: Optional[Callable] = None, + 
scoring_func: str = "softmax", + correction_bias: Optional[torch.Tensor] = None, + routed_scaling_factor: Optional[float] = None, + ): + # NOTE: scoring_func is not used for now, but we keep it for future use + # see https://github.com/sgl-project/sglang/pull/4505 for more details + super().__init__() + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.top_k = top_k + self.use_grouped_topk = use_grouped_topk + self.renormalize = renormalize + self.topk_group = topk_group + self.num_expert_group = num_expert_group + self.num_fused_shared_experts = num_fused_shared_experts + self.custom_routing_function = custom_routing_function + self.correction_bias = correction_bias + self.routed_scaling_factor = routed_scaling_factor + + def forward_native( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + *, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ) -> TopKOutput: + torch_native = True + return select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=self.use_grouped_topk, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + num_fused_shared_experts=self.num_fused_shared_experts, + custom_routing_function=self.custom_routing_function, + correction_bias=self.correction_bias, + torch_native=torch_native, + routed_scaling_factor=self.routed_scaling_factor, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + ) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + *, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ) -> TopKOutput: + torch_native = False + return select_experts( + hidden_states=hidden_states, + 
router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=self.use_grouped_topk, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + num_fused_shared_experts=self.num_fused_shared_experts, + custom_routing_function=self.custom_routing_function, + correction_bias=self.correction_bias, + torch_native=torch_native, + routed_scaling_factor=self.routed_scaling_factor, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + ) + + def forward_cpu( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + *, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ) -> TopKOutput: + return select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=self.use_grouped_topk, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + num_fused_shared_experts=self.num_fused_shared_experts, + custom_routing_function=self.custom_routing_function, + correction_bias=self.correction_bias, + routed_scaling_factor=self.routed_scaling_factor, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + ) + + def forward_npu( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + *, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ) -> TopKOutput: + global_num_experts = router_logits.shape[-1] + + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + if global_num_experts == 256: + return torch_npu.npu_moe_gating_top_k( + router_logits, + k=self.top_k, + bias=self.correction_bias, + k_group=self.topk_group, + group_count=self.num_expert_group, + group_select_mode=1, + renorm=0, + norm_type=1, + 
routed_scaling_factor=1, + eps=float(1e-20), + ) + else: + torch_native = True + return select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=self.use_grouped_topk, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + num_fused_shared_experts=self.num_fused_shared_experts, + custom_routing_function=self.custom_routing_function, + correction_bias=self.correction_bias, + torch_native=torch_native, + routed_scaling_factor=self.routed_scaling_factor, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + ) + def fused_topk_torch_native( hidden_states: torch.Tensor, @@ -436,8 +601,9 @@ def select_experts( hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int, - use_grouped_topk: bool, - renormalize: bool, + *, + use_grouped_topk: bool = False, + renormalize: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, num_fused_shared_experts: int = 0, @@ -447,7 +613,7 @@ def select_experts( routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, -): +) -> TopKOutput: router_logits, correction_bias = ( expert_location_dispatch.transform_select_experts_inputs( router_logits=router_logits, @@ -522,4 +688,4 @@ def select_experts( get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) - return topk_weights, topk_ids + return TopKOutput(topk_weights, topk_ids, router_logits) diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index d51186465a0f..496cbc8f5392 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -1,7 +1,9 @@ # Adapted from 
https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py +from __future__ import annotations + import builtins import inspect -from typing import Callable, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union import torch @@ -65,6 +67,9 @@ def override_quantization_method(self, *args, **kwargs): from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + # Base quantization methods that don't depend on vllm BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "fp8": Fp8Config, @@ -186,15 +191,8 @@ def new_apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -208,20 +206,8 @@ def new_apply( "self": self, "layer": layer, "x": x, - "router_logits": router_logits, - "top_k": top_k, - "renormalize": renormalize, - "use_grouped_topk": use_grouped_topk, - "topk_group": topk_group, - "num_expert_group": num_expert_group, - "custom_routing_function": custom_routing_function, + "topk_output": topk_output, } - if correction_bias is not None: - if not has_correction_bias: - raise ValueError( - "Please increase the version of your vllm. 
Try `pip install vllm==0.9.0.1`" - ) - kwargs["e_score_correction_bias"] = correction_bias return original_apply(**kwargs) setattr(class_obj, "apply", new_apply) diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index c20beb2ff0b9..0f66b954ca72 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -3,7 +3,7 @@ import logging import warnings -from typing import Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import torch @@ -33,6 +33,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import replace_parameter +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + try: from vllm import _custom_ops as ops @@ -737,45 +740,19 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - correction_bias: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, + topk_output: TopKOutput, + *, activation: str = "silu", - routed_scaling_factor: Optional[float] = None, + **kwargs, ) -> torch.Tensor: - # Delay the import to avoid circular dependency - from sglang.srt.layers.moe.topk import select_experts assert activation == "silu", "Only SiLU activation is supported." - assert ( - scoring_func == "softmax" - ), "Only softmax score func is supported for now." 
# The input must currently be float16 orig_dtype = x.dtype x = x.half() - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) + topk_weights, topk_ids, router_logits = topk_output return fused_marlin_moe( x, diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py index 607151671bff..bf24c3701076 100644 --- a/python/sglang/srt/layers/quantization/base_config.py +++ b/python/sglang/srt/layers/quantization/base_config.py @@ -1,12 +1,16 @@ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py +from __future__ import annotations import inspect from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type import torch from torch import nn +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + class QuantizeMethodBase(ABC): """Base class for different quantized methods.""" @@ -88,19 +92,22 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - raise NotImplementedError() + raise NotImplementedError @abstractmethod def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, + topk_output: TopKOutput, + *, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: - raise NotImplementedError() + raise NotImplementedError 
class QuantizationConfig(ABC): diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py index a1da999b3af1..62dc45ad9ca9 100644 --- a/python/sglang/srt/layers/quantization/blockwise_int8.py +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import torch from torch.nn import Module @@ -21,6 +21,9 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import set_weight_attrs +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + ACTIVATION_SCHEMES = ["static", "dynamic"] logger = logging.getLogger(__name__) @@ -344,15 +347,8 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -360,30 +356,13 @@ def apply( routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - from sglang.srt.layers.moe.topk import select_experts - - # Expert selection - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - 
routed_scaling_factor=routed_scaling_factor, - ) # Expert fusion with INT8 quantization return fused_experts( x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index b471184d2260..39e5f9e252da 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1,15 +1,17 @@ # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations import enum import logging from enum import Enum -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, List, Optional import torch from compressed_tensors import CompressionFormat from compressed_tensors.quantization import QuantizationStrategy +from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.utils import ( @@ -20,6 +22,12 @@ ) from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( + CompressedTensorsConfig, + ) + _is_cuda = is_cuda() _is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() @@ -51,7 +59,7 @@ class GPTQMarlinState(Enum): ] -class CompressedTensorsMoEMethod: +class 
CompressedTensorsMoEMethod(FusedMoEMethodBase): def __new__(cls, *args, **kwargs): if cls is CompressedTensorsMoEMethod: return super().__new__(cls) @@ -59,7 +67,7 @@ def __new__(cls, *args, **kwargs): @staticmethod def get_moe_method( - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + quant_config: CompressedTensorsConfig, ) -> "CompressedTensorsMoEMethod": # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. @@ -82,9 +90,7 @@ def get_moe_method( class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): - def __init__( - self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 - ): + def __init__(self, quant_config: CompressedTensorsConfig): self.quant_config = quant_config self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") self.input_quant = self.quant_config.target_scheme_map["Linear"].get( @@ -270,47 +276,21 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, - apply_router_weight_on_input: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton import fused_experts - from sglang.srt.layers.moe.topk import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - 
renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) return fused_experts( x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace, activation=activation, use_fp8_w8a8=True, @@ -327,9 +307,7 @@ def apply( class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): - def __init__( - self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 - ): + def __init__(self, quant_config: CompressedTensorsConfig): self.quant_config = quant_config # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. @@ -628,43 +606,15 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", - routed_scaling_factor: Optional[float] = None, + **kwargs, ) -> torch.Tensor: - from sglang.srt.layers.moe.topk import select_experts assert activation == "silu", "Only SiLU activation is supported." - if expert_map is not None: - raise NotImplementedError( - "Expert Parallelism is not supported for " "fused Marlin MoE method." 
- ) - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) + topk_weights, topk_ids, router_logits = topk_output return torch.ops.vllm.fused_marlin_moe( x, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 7275ea430132..23daa5d26fb8 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch import torch.nn.functional as F @@ -78,6 +78,7 @@ def dummy_func(*args, **kwargs): ) if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config _is_hip = is_hip() @@ -971,15 +972,8 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -987,26 +981,11 @@ def apply( routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - from sglang.srt.layers.moe.topk import 
select_experts - - # Expert selection - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) if use_intel_amx_backend(layer): from sglang.srt.layers.moe.topk import apply_topk_weights_cpu + topk_weights, topk_ids, _ = topk_output x, topk_weights = apply_topk_weights_cpu( apply_router_weight_on_input, topk_weights, x ) @@ -1032,8 +1011,7 @@ def apply( ret = self.maybe_apply_hip_fused_experts( layer, x, - topk_weights, - topk_ids, + topk_output, activation, no_combine, ) @@ -1048,6 +1026,7 @@ def apply( ): from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 + topk_weights, topk_ids, _ = topk_output return cutlass_fused_experts_fp8( x, layer.w13_weight.transpose(1, 2), @@ -1076,8 +1055,7 @@ def apply( x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace and not no_combine, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, @@ -1101,11 +1079,11 @@ def maybe_apply_hip_fused_experts( self, layer: torch.nn.Module, x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + topk_output: TopKOutput, activation: str = "silu", no_combine: bool = False, ) -> Optional[torch.Tensor]: + topk_weights, topk_ids, _ = topk_output if _use_hip_int4: # TODO: add triton kernel and add check _use_aiter assert not no_combine, f"{no_combine=} is not supported." 
@@ -1397,14 +1375,8 @@ def process_weights_after_loading(self, layer: Module) -> None: def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + hidden_states: torch.Tensor, + topk_output: TopKOutput, ) -> torch.Tensor: raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index af56c3be719a..4f2eba4e3f48 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -3,7 +3,7 @@ import logging from dataclasses import dataclass from fractions import Fraction -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import torch @@ -43,6 +43,9 @@ unpack_cols, ) +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + try: from vllm import _custom_ops as ops except ImportError: @@ -1057,42 +1060,20 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", + **kwargs, ) -> torch.Tensor: # Delay the import to avoid circular dependency - from sglang.srt.layers.moe.topk import select_experts assert activation == "silu", "Only SiLU activation is supported." - assert ( - scoring_func == "softmax" - ), "Only softmax score func is supported for now." 
# The input must currently be float16 orig_dtype = x.dtype x = x.half() - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - correction_bias=e_score_correction_bias, - ) + topk_weights, topk_ids, router_logits = topk_output return fused_marlin_moe( x, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 5263f3b920b1..73de5b0d1594 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging -from typing import Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import torch from torch.nn.parameter import Parameter @@ -31,6 +31,9 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.utils import is_cuda, next_power_of_2 +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + if is_cuda(): from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant @@ -402,15 +405,8 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -418,29 +414,12 @@ def apply( routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import 
fused_experts - from sglang.srt.layers.moe.topk import select_experts - - # Expert selection - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) return fused_experts( x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace, activation=activation, use_fp8_w8a8=True, @@ -961,15 +940,8 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -982,21 +954,6 @@ def apply( ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." 
- from sglang.srt.layers.moe.topk import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) if self.enable_flashinfer_moe: assert ( @@ -1004,6 +961,7 @@ def apply( ), "apply_router_weight_on_input is not supported for Flashinfer" # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision # and fp4 quantized weights loaded from the checkpoint + topk_weights, topk_ids, _ = topk_output output = flashinfer_cutlass_fused_moe( x, topk_ids.to(torch.int), @@ -1029,6 +987,7 @@ def apply( from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 + topk_weights, topk_ids, _ = topk_output return cutlass_moe_fp4( a=x, a1_gscale=layer.w13_input_scale_quant, diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py index f83b9bb1f71d..fbbf1106616d 100644 --- a/python/sglang/srt/layers/quantization/moe_wna16.py +++ b/python/sglang/srt/layers/quantization/moe_wna16.py @@ -2,8 +2,9 @@ from __future__ import annotations import logging -from typing import Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional +import numpy as np import torch from sglang.srt.distributed import get_tensor_model_parallel_rank @@ -20,6 +21,9 @@ logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + def get_weight_perm(num_bits: int): perm_list: List[int] = [] @@ -348,15 +352,8 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: 
Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -365,22 +362,8 @@ def apply( ) -> torch.Tensor: # avoid circular import from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - from sglang.srt.layers.moe.topk import select_experts assert activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) weight_bits = self.quant_config.weight_bits has_zp = self.quant_config.has_zp @@ -389,8 +372,7 @@ def apply( x, layer.w13_qweight, layer.w2_qweight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace, apply_router_weight_on_input=apply_router_weight_on_input, use_int4_w4a16=weight_bits == 4, diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 06afcb70be91..fa4cbf582027 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import importlib -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional import torch import torch.nn.functional as F @@ -21,6 +23,9 @@ use_intel_amx_backend, ) +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + has_triton_kernels = 
importlib.util.find_spec("triton_kernels") is not None @@ -125,25 +130,6 @@ def __init__(self, use_triton_kernels: bool = False): super().__init__() self.use_triton_kernels = use_triton_kernels - from sglang.srt.layers.moe.fused_moe_native import moe_forward_native - - if torch.cuda.is_available(): - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - - if has_triton_kernels: - from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( - triton_kernel_moe_forward, - ) - else: - triton_kernel_moe_forward = None - else: - fused_experts = None # type: ignore - triton_kernel_moe_forward = None - - self.moe_forward_native = moe_forward_native - self.fused_experts = fused_experts - self.triton_kernel_moe_forward = triton_kernel_moe_forward - def create_weights( self, layer: torch.nn.Module, @@ -201,34 +187,18 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: - return self.forward( x=x, layer=layer, - router_logits=router_logits, - top_k=top_k, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, + topk_output=topk_output, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, inplace=inplace, @@ -240,15 +210,8 @@ def forward_cuda( self, layer: torch.nn.Module, x: 
torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -257,33 +220,20 @@ def forward_cuda( ) -> torch.Tensor: if self.use_triton_kernels: - return self.triton_kernel_moe_forward( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - ) + # TODO(ch-wan): re-enable the Triton kernel + raise NotImplementedError("The Triton kernel is temporarily disabled.") + # return triton_kernel_moe_forward( + # hidden_states=x, + # w1=layer.w13_weight, + # w2=layer.w2_weight, + # gating_output=router_logits, + # topk=top_k, + # renormalize=renormalize, + # ) else: - from sglang.srt.layers.moe.topk import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) - if _use_aiter: assert not no_combine, "unsupported" + topk_weights, topk_ids, _ = topk_output if apply_router_weight_on_input: assert ( topk_weights.dim() == 2 @@ -296,7 +246,6 @@ def forward_cuda( topk_weights = torch.ones_like( topk_weights, dtype=torch.float32 ) # topk_weights must be FP32 (float32) - return fused_moe( x, layer.w13_weight, @@ -310,12 +259,15 @@ def forward_cuda( ), ) else: - return self.fused_experts( + from sglang.srt.layers.moe.fused_moe_triton.fused_moe 
import ( + fused_experts, + ) + + return fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace and not no_combine, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, @@ -327,15 +279,8 @@ def forward_cpu( self, layer: torch.nn.Module, x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -344,30 +289,13 @@ def forward_cpu( ) -> torch.Tensor: assert activation == "silu", f"activation = {activation} is not supported." - if use_intel_amx_backend(layer): + if use_intel_amx_backend(layer) and not apply_router_weight_on_input: + from sglang.srt.layers.moe.topk import apply_topk_weights_cpu - from sglang.srt.layers.moe.topk import ( - apply_topk_weights_cpu, - select_experts, - ) - - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) + topk_weights, topk_ids, _ = topk_output x, topk_weights = apply_topk_weights_cpu( apply_router_weight_on_input, topk_weights, x ) - return torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, @@ -385,61 +313,42 @@ def forward_cpu( True, # is_vnni ) else: - return self.moe_forward_native( + from sglang.srt.layers.moe.fused_moe_native 
import moe_forward_native + + return moe_forward_native( layer, x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - num_fused_shared_experts, - custom_routing_function, - correction_bias, - activation, - apply_router_weight_on_input, - inplace, - no_combine, - routed_scaling_factor, + topk_output, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + inplace=inplace, + no_combine=no_combine, + routed_scaling_factor=routed_scaling_factor, ) def forward_npu( self, layer: torch.nn.Module, x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: - return self.moe_forward_native( + from sglang.srt.layers.moe.fused_moe_native import moe_forward_native + + return moe_forward_native( layer, x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - num_fused_shared_experts, - custom_routing_function, - correction_bias, - activation, - apply_router_weight_on_input, - inplace, - no_combine, - routed_scaling_factor, + topk_output, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + inplace=inplace, + no_combine=no_combine, + routed_scaling_factor=routed_scaling_factor, ) def forward_tpu(self, *args, **kwargs) -> torch.Tensor: @@ -508,13 +417,7 @@ def create_weights( def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] 
= None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + hidden_states: torch.Tensor, + topk_output: TopKOutput, ) -> torch.Tensor: raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index 871a4534ca3e..e486fef0b3a8 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter @@ -25,6 +25,9 @@ ) from sglang.srt.utils import set_weight_attrs +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + _is_fp8_fnuz = is_fp8_fnuz() @@ -266,45 +269,23 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - from sglang.srt.layers.moe.topk import select_experts - - # Expert selection - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, 
- routed_scaling_factor=routed_scaling_factor, - ) return fused_experts( x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace, + apply_router_weight_on_input=apply_router_weight_on_input, activation=activation, use_fp8_w8a8=True, per_channel_quant=True, diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 19cf49c9bc86..22e8b108f7f8 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -3,7 +3,7 @@ import importlib import sys from types import MappingProxyType -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast import torch from torch.nn.parameter import Parameter @@ -37,6 +37,9 @@ use_intel_amx_backend, ) +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + _is_cuda = is_cuda() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() @@ -239,7 +242,7 @@ def get_quant_method( layer: torch.nn.Module, prefix: str, ) -> Optional[QuantizeMethodBase]: - from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if _is_npu: @@ -469,15 +472,8 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -485,26 +481,11 @@ def apply( routed_scaling_factor: 
Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - from sglang.srt.layers.moe.topk import select_experts - - # Expert selection - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) if use_intel_amx_backend(layer): from sglang.srt.layers.moe.topk import apply_topk_weights_cpu + topk_weights, topk_ids, _ = topk_output x, topk_weights = apply_topk_weights_cpu( apply_router_weight_on_input, topk_weights, x ) @@ -529,8 +510,7 @@ def apply( x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, @@ -907,7 +887,7 @@ def create_weights( layer: torch.nn.Module, num_experts: int, hidden_size: int, - intermediate_size: List[int], + intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: @@ -984,52 +964,11 @@ def apply( self, layer, x, - router_logits, - top_k, - renormalize, - use_grouped_topk, - topk_group, - num_expert_group, - num_fused_shared_experts, - custom_routing_function, - correction_bias, - activation, - apply_router_weight_on_input, - routed_scaling_factor, + topk_output: TopKOutput, **kwargs, ) -> torch.Tensor: - from sglang.srt.layers.moe.topk import select_experts - - global_num_experts = router_logits.shape[-1] - # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if global_num_experts == 256: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, - bias=correction_bias, - 
k_group=topk_group, - group_count=num_expert_group, - group_select_mode=1, - renorm=0, - norm_type=1, - routed_scaling_factor=1, - eps=float(1e-20), - ) - else: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - torch_native=True, - routed_scaling_factor=routed_scaling_factor, - ) + + topk_weights, topk_ids, _ = topk_output topk_ids = topk_ids.to(torch.int32) topk_weights = topk_weights.to(x.dtype) return npu_fused_experts( @@ -1040,5 +979,5 @@ def apply( w2_scale=layer.w2_weight_scale, topk_weights=topk_weights, topk_ids=topk_ids, - top_k=top_k, + top_k=topk_ids.shape[1], ) diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index 95bfe001a2da..f2f0d0344ad2 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -37,6 +37,7 @@ ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import fused_moe +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -109,7 +110,10 @@ def __init__( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {self.n_routed_experts}." 
) - + self.topk = TopK( + top_k=self.top_k, + renormalize=config.norm_topk_prob, + ) self.experts = nn.ModuleList( [ DeepseekMLP( @@ -170,13 +174,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) final_hidden_states = fused_moe.fused_moe( hidden_states, - self.w1, - self.w2, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, + w1=self.w1, + w2=self.w2, + topk_output=topk_output, inplace=True, ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0da956b0158f..9ec5db9260d3 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -58,7 +58,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_kernel import ( @@ -303,6 +303,17 @@ def __init__( config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn ) + self.topk = TopK( + top_k=config.num_experts_per_tok + self.num_fused_shared_experts, + renormalize=config.norm_topk_prob, + use_grouped_topk=True, + num_expert_group=config.n_group, + num_fused_shared_experts=self.num_fused_shared_experts, + topk_group=config.topk_group, + correction_bias=self.gate.e_score_correction_bias, + routed_scaling_factor=self.routed_scaling_factor, + ) + self.experts = get_moe_impl_class()( num_experts=config.n_routed_experts + self.num_fused_shared_experts @@ -311,13 +322,7 @@ 
def __init__( hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, layer_id=self.layer_id, - renormalize=config.norm_topk_prob, quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - num_fused_shared_experts=self.num_fused_shared_experts, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, routed_scaling_factor=self.routed_scaling_factor, prefix=add_prefix("experts", prefix), **( @@ -451,8 +456,9 @@ def forward_normal_dual_stream( with torch.cuda.stream(self.alt_stream): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits + hidden_states=hidden_states, topk_output=topk_output ) if not _is_cuda: final_hidden_states *= self.routed_scaling_factor @@ -473,8 +479,9 @@ def forward_normal( shared_output = self._forward_shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits + hidden_states=hidden_states, topk_output=topk_output ) if not _is_cuda and not _use_aiter: # fused in biased_grouped_topk so we can skip here @@ -490,8 +497,9 @@ def forward_cpu( ) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) fused_experts_out = self.experts( - hidden_states=hidden_states, router_logits=router_logits + hidden_states=hidden_states, topk_output=topk_output ) assert use_intel_amx_backend( @@ -549,17 +557,9 @@ def forward_deepep( # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) shared_output = self._forward_shared_experts(hidden_states) - topk_weights, topk_idx = 
select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=True, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - correction_bias=self.correction_bias, - routed_scaling_factor=self.routed_scaling_factor, + topk_weights, topk_idx, _ = self.topk( + hidden_states, + router_logits, num_token_non_padded=forward_batch.num_token_non_padded, expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, @@ -649,17 +649,9 @@ def op_select_experts(self, state): with get_global_expert_distribution_recorder().with_current_layer( self.layer_id ): - state.topk_weights_local, state.topk_idx_local = select_experts( + state.topk_weights_local, state.topk_idx_local, _ = self.topk( hidden_states=hidden_states, router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=True, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - correction_bias=self.correction_bias, - routed_scaling_factor=self.routed_scaling_factor, num_token_non_padded=state.forward_batch.num_token_non_padded, expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, diff --git a/python/sglang/srt/models/granitemoe.py b/python/sglang/srt/models/granitemoe.py index b4a9c17af56f..1e61092090ac 100644 --- a/python/sglang/srt/models/granitemoe.py +++ b/python/sglang/srt/models/granitemoe.py @@ -15,6 +15,7 @@ ) from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import 
RadixAttention @@ -60,6 +61,11 @@ def __init__( prefix=f"{prefix}.gate", ) + self.topk = TopK( + top_k=top_k, + renormalize=True, + ) + self.experts = FusedMoE( num_experts=num_experts, top_k=top_k, @@ -67,7 +73,6 @@ def __init__( intermediate_size=intermediate_size, params_dtype=params_dtype, reduce_results=True, - renormalize=True, quant_config=quant_config, tp_size=tp_size, prefix=f"{prefix}.experts", @@ -78,7 +83,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states, router_logits) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) return final_hidden_states.view(orig_shape) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index a8cde8e09c02..4a46bf1973d8 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -45,6 +45,7 @@ from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.router import fused_moe_router_shim +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -108,6 +109,12 @@ def __init__( fused_moe_router_shim, self.router_logit_softcapping ) + self.topk = TopK( + top_k=top_k, + renormalize=False, + custom_routing_function=custom_routing_function, + ) + kwargs = {} if global_server_args_dict["enable_ep_moe"]: MoEImpl = EPMoE @@ -124,17 +131,16 @@ def __init__( hidden_size=hidden_size, intermediate_size=intermediate_size, params_dtype=params_dtype, - renormalize=False, quant_config=quant_config, tp_size=tp_size, - 
custom_routing_function=custom_routing_function, activation="gelu", **kwargs, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # need to assert self.gate.quant_method is unquantized - return self.experts(hidden_states, self.gate.weight) + topk_output = self.topk(hidden_states, self.gate.weight) + return self.experts(hidden_states, topk_output) class Grok1Attention(nn.Module): diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py index f23ccc0a8d94..58e95bbb1cd8 100644 --- a/python/sglang/srt/models/hunyuan.py +++ b/python/sglang/srt/models/hunyuan.py @@ -40,6 +40,7 @@ ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -152,13 +153,16 @@ def __init__( else config.moe_intermediate_size[layer_id] ) + self.topk = TopK( + top_k=top_k, + renormalize=True if top_k > 1 else False, + ) + self.experts = FusedMoE( num_experts=config.num_experts, - top_k=top_k, hidden_size=config.hidden_size, intermediate_size=intermediate_size, reduce_results=False, - renormalize=True if top_k > 1 else False, quant_config=quant_config, ) @@ -195,9 +199,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py index 
1bb6fcc12193..cf0b20800410 100644 --- a/python/sglang/srt/models/llama4.py +++ b/python/sglang/srt/models/llama4.py @@ -40,6 +40,7 @@ RowParallelLinear, ) from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -103,14 +104,17 @@ def __init__( prefix=add_prefix("router", prefix), ) + self.topk = TopK( + top_k=self.top_k, + renormalize=False, + custom_routing_function=Llama4MoE.custom_routing_function, + ) + self.experts = FusedMoE( num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, - custom_routing_function=Llama4MoE.custom_routing_function, intermediate_size=intermediate_size_moe, reduce_results=False, - renormalize=False, quant_config=quant_config, apply_router_weight_on_input=True, prefix=add_prefix("experts", prefix), @@ -147,10 +151,8 @@ def _forward_core_normal(self, hidden_states): # router_scores: [num_tokens, num_experts] router_logits, _ = self.router(hidden_states) shared_out = self.shared_expert(hidden_states) - routed_out = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - ) + topk_output = self.topk(hidden_states, router_logits) + routed_out = self.experts(hidden_states, topk_output) return shared_out, routed_out def _forward_core_shared_routed_overlap(self, hidden_states): @@ -163,10 +165,8 @@ def _forward_core_shared_routed_overlap(self, hidden_states): with self.device_module.stream(alt_stream): # router_scores: [num_tokens, num_experts] router_logits, _ = self.router(hidden_states) - routed_out = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - ) + topk_output = self.topk(hidden_states, router_logits) + routed_out = self.experts(hidden_states, topk_output) 
self.device_module.current_stream().wait_stream(alt_stream) return shared_out, routed_out diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 90a12f12f1dd..b09fc2f24827 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -37,6 +37,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -86,6 +87,12 @@ def __init__( quant_config=None, prefix=add_prefix("gate", prefix), ) + + self.topk = TopK( + top_k=top_k, + renormalize=True, + ) + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE self.experts = MoEImpl( num_experts=num_experts, @@ -93,7 +100,6 @@ def __init__( hidden_size=hidden_size, intermediate_size=intermediate_size, params_dtype=params_dtype, - renormalize=True, quant_config=quant_config, tp_size=tp_size, prefix=add_prefix("experts", prefix), @@ -105,7 +111,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states, router_logits) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(orig_shape) diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index 612120fe939b..ce53f2b0148a 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -32,6 
+32,7 @@ ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -76,13 +77,16 @@ def __init__( prefix=add_prefix("gate", prefix), ) + self.topk = TopK( + top_k=top_k, + renormalize=False, + ) + self.experts = FusedMoE( num_experts=num_experts, - top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, reduce_results=True, - renormalize=False, quant_config=quant_config, tp_size=tp_size, prefix=add_prefix("experts", prefix), @@ -94,9 +98,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) return final_hidden_states.view(orig_shape) diff --git a/python/sglang/srt/models/phimoe.py b/python/sglang/srt/models/phimoe.py index 22ee023c83c4..865b94f51665 100644 --- a/python/sglang/srt/models/phimoe.py +++ b/python/sglang/srt/models/phimoe.py @@ -13,6 +13,7 @@ ) from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention @@ -200,15 +201,19 @@ def __init__( quant_config=None, ) + self.topk = TopK( + top_k=top_k, + renormalize=False, + 
custom_routing_function=phimoe_routing_function, + ) + self.experts = FusedMoE( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, reduce_results=True, - renormalize=False, quant_config=quant_config, - custom_routing_function=phimoe_routing_function, prefix=add_prefix("experts", prefix), ) @@ -219,7 +224,8 @@ def forward( orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states, router_logits) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) return final_hidden_states.view(orig_shape) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index fe2636ab74e8..e033424cf023 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -61,6 +61,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -134,13 +135,17 @@ def __init__( f"the number of experts {config.num_experts}." 
) + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=config.norm_topk_prob, + ) + self.experts = get_moe_impl_class()( layer_id=self.layer_id, - num_experts=config.num_experts, top_k=config.num_experts_per_tok, + num_experts=config.num_experts, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - renormalize=config.norm_topk_prob, quant_config=quant_config, prefix=add_prefix("experts", prefix), # Additional args for FusedMoE @@ -189,9 +194,8 @@ def forward( # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 75d3b475cb0e..c75a384990e8 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -56,8 +56,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher -from sglang.srt.layers.moe.fused_moe_triton import FusedMoE -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -102,6 +101,12 @@ def __init__( f"the number of experts {config.num_experts}." 
) + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=config.norm_topk_prob, + use_grouped_topk=False, + ) + self.experts = get_moe_impl_class()( num_experts=config.num_experts + global_server_args_dict["ep_num_redundant_experts"], @@ -109,7 +114,6 @@ def __init__( layer_id=layer_id, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - renormalize=config.norm_topk_prob, quant_config=quant_config, prefix=add_prefix("experts", prefix), **( @@ -143,7 +147,6 @@ def __init__( config.num_experts + global_server_args_dict["ep_num_redundant_experts"] ) self.top_k = config.num_experts_per_tok - self.renormalize = config.norm_topk_prob self.deepep_dispatcher = MaybeTboDeepEPDispatcher( group=parallel_state.get_tp_group().device_group, @@ -180,9 +183,8 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) @@ -195,13 +197,9 @@ def forward_deepep( if is_non_idle_and_non_empty(forward_mode, hidden_states): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - - topk_weights, topk_idx = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=False, - renormalize=self.renormalize, + topk_weights, topk_idx, _ = self.topk( + hidden_states, + router_logits, num_token_non_padded=forward_batch.num_token_non_padded, expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, @@ -267,12 +265,9 @@ def op_select_experts(self, state): with get_global_expert_distribution_recorder().with_current_layer( 
self.layer_id ): - state.topk_weights_local, state.topk_idx_local = select_experts( + state.topk_weights_local, state.topk_idx_local, _ = self.topk( hidden_states=hidden_states, router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=False, - renormalize=self.renormalize, num_token_non_padded=state.forward_batch.num_token_non_padded, expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index a5a338632f58..fd2c95608a17 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -6,6 +6,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import ( per_tensor_quant_mla_fp8, per_token_group_quant_fp8, @@ -497,13 +498,17 @@ def _w8a8_block_fp8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) with torch.inference_mode(): + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + renormalize=False, + ) out = fused_moe( a, w1, w2, - score, - topk, - renormalize=False, + topk_output, use_fp8_w8a8=True, w1_scale=w1_s, w2_scale=w2_s, diff --git a/python/sglang/test/test_block_fp8_ep.py b/python/sglang/test/test_block_fp8_ep.py index bd735edbdc50..2f92c5435b8f 100644 --- a/python/sglang/test/test_block_fp8_ep.py +++ b/python/sglang/test/test_block_fp8_ep.py @@ -40,7 +40,7 @@ def ep_moe( block_shape: Optional[List[int]] = None, ): use_blockwise_fp8 = block_shape is not None - topk_weights, topk_ids = select_experts( + topk_weights, topk_ids, _ = select_experts( hidden_states=hidden_states, router_logits=router_logits, top_k=top_k, diff --git a/python/sglang/test/test_cutlass_w4a8_moe.py b/python/sglang/test/test_cutlass_w4a8_moe.py index 
acf8a27b918f..c823bf1f7e48 100644 --- a/python/sglang/test/test_cutlass_w4a8_moe.py +++ b/python/sglang/test/test_cutlass_w4a8_moe.py @@ -100,12 +100,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype): s_strides2 = c_strides2 score = torch.randn((M, E), dtype=dtype, device=device) - topk_weights, topk_ids = select_experts( + topk_weights, topk_ids, _ = select_experts( hidden_states=a, router_logits=score, top_k=topk, - use_grouped_topk=False, - renormalize=False, ) expert_map = torch.arange(E, dtype=torch.int32, device=device) expert_map[local_e:] = E diff --git a/python/sglang/test/test_fp4_moe.py b/python/sglang/test/test_fp4_moe.py index 7e3de278cbe9..30b1fe9db5a4 100644 --- a/python/sglang/test/test_fp4_moe.py +++ b/python/sglang/test/test_fp4_moe.py @@ -159,12 +159,10 @@ def test_cutlass_fp4_moe_no_graph( score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = select_experts( + topk_weights, topk_ids, _ = select_experts( hidden_states=a, router_logits=score, top_k=topk, - use_grouped_topk=False, - renormalize=False, ) a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32) diff --git a/test/srt/test_block_int8.py b/test/srt/test_block_int8.py index 2b8b841f02f4..58bd7c1e1998 100644 --- a/test/srt/test_block_int8.py +++ b/test/srt/test_block_int8.py @@ -5,6 +5,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.topk import select_experts from sglang.test.test_utils import CustomTestCase @@ -171,14 +172,18 @@ def _w8a8_block_int8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + ) + with torch.inference_mode(): out = fused_moe( a, w1, w2, - score, - topk, - renormalize=False, + topk_output, use_int8_w8a8=True, w1_scale=w1_s, w2_scale=w2_s, diff --git 
a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py index d1c2735d13c2..1a0452c41196 100644 --- a/test/srt/test_fused_moe.py +++ b/test/srt/test_fused_moe.py @@ -6,6 +6,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.utils import is_hip @@ -132,13 +133,17 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): input_scale=a2_scale, ) + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + ) + sglang_output = fused_moe( a, w1, w2, - score, - topk, - renormalize=False, + topk_output, use_fp8_w8a8=True, w1_scale=w1_scale, w2_scale=w2_scale, @@ -166,7 +171,13 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): w2 = self.create_random_cuda_tensor((e, k, n), dtype) score = self.create_random_cuda_tensor((m, e), dtype) - triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + ) + + triton_output = fused_moe(a, w1, w2, topk_output) torch_output = self.torch_naive_moe(a, w1, w2, score, topk) torch.testing.assert_close( triton_output, torch_output, rtol=rtol, atol=atol diff --git a/test/srt/test_int8_kernel.py b/test/srt/test_int8_kernel.py index 3e9f7a7dd98b..bbadce230304 100644 --- a/test/srt/test_int8_kernel.py +++ b/test/srt/test_int8_kernel.py @@ -5,6 +5,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 from sglang.test.test_utils import CustomTestCase @@ -114,13 +115,16 @@ def 
_w8a8_int8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed): with torch.inference_mode(): ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + ) out = fused_moe( a, w1, w2, - score, - topk, - renormalize=False, + topk_output, use_fp8_w8a8=False, # Not using fp8 use_int8_w8a16=False, # Not using int8-w8a16 use_int8_w8a8=True, # Using int8-w8a8 diff --git a/test/srt/test_triton_moe_channel_fp8_kernel.py b/test/srt/test_triton_moe_channel_fp8_kernel.py index 89b5af650df4..577570757d35 100644 --- a/test/srt/test_triton_moe_channel_fp8_kernel.py +++ b/test/srt/test_triton_moe_channel_fp8_kernel.py @@ -5,6 +5,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.test.test_utils import CustomTestCase @@ -126,13 +127,16 @@ def _w8a8_fp8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed): with torch.inference_mode(): ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + ) out = fused_moe( a, w1, w2, - score, - topk, - renormalize=False, + topk_output, use_fp8_w8a8=True, # using fp8 use_int8_w8a16=False, use_int8_w8a8=False, diff --git a/test/srt/test_triton_moe_wna16.py b/test/srt/test_triton_moe_wna16.py index 2613586a8466..51583c2f200f 100644 --- a/test/srt/test_triton_moe_wna16.py +++ b/test/srt/test_triton_moe_wna16.py @@ -5,6 +5,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.topk import select_experts NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] @@ -219,13 +220,17 @@ def test_fused_moe_wn16( if has_zp: w_qzeros[expert_id] = qzeros 
+ topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + ) + triton_output = fused_moe( a, w1_qweight, w2_qweight, - score, - topk, - renormalize=False, + topk_output, use_int4_w4a16=weight_bits == 4, use_int8_w8a16=weight_bits == 8, w1_scale=w1_scales, From f98e88b9fbbb59ad700892da765bc49bda34c59b Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 19 Jul 2025 00:56:18 -0700 Subject: [PATCH 048/396] chore: bump sgl-kernel v0.2.6 (#8165) --- docker/Dockerfile | 2 +- sgl-kernel/pyproject.toml | 2 +- sgl-kernel/pyproject_cpu.toml | 2 +- sgl-kernel/pyproject_rocm.toml | 2 +- sgl-kernel/python/sgl_kernel/version.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index eac2c8a4c446..bc0eb095e917 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -60,7 +60,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel html5li && python3 -m pip install --no-cache-dir -e "python[${BUILD_TYPE}]" --extra-index-url https://download.pytorch.org/whl/cu${CUINDEX} \ && if [ "$CUDA_VERSION" = "12.8.1" ]; then \ python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.5 --force-reinstall --no-deps ; \ - python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.2.5/sgl_kernel-0.2.5+cu128-cp39-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \ + python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.2.6/sgl_kernel-0.2.6+cu128-cp39-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \ fi # Build and install NVSHMEM + DeepEP diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index bb460f05986f..4d8ff394df4d 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "sgl-kernel" -version = "0.2.5" +version = "0.2.6" description = "Kernel Library for SGLang" readme = 
"README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/pyproject_cpu.toml b/sgl-kernel/pyproject_cpu.toml index b88b38b4a497..c243596515bd 100644 --- a/sgl-kernel/pyproject_cpu.toml +++ b/sgl-kernel/pyproject_cpu.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "sgl-kernel" -version = "0.2.5" +version = "0.2.6" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/pyproject_rocm.toml b/sgl-kernel/pyproject_rocm.toml index de2e9bcf384f..6ab48599c5cf 100644 --- a/sgl-kernel/pyproject_rocm.toml +++ b/sgl-kernel/pyproject_rocm.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.2.5" +version = "0.2.6" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/python/sgl_kernel/version.py b/sgl-kernel/python/sgl_kernel/version.py index fe404ae570d5..01ef12070dc3 100644 --- a/sgl-kernel/python/sgl_kernel/version.py +++ b/sgl-kernel/python/sgl_kernel/version.py @@ -1 +1 @@ -__version__ = "0.2.5" +__version__ = "0.2.6" From 561dd7b2ce2b1a4ef9bbffa840eb5b60f520f839 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 19 Jul 2025 03:17:08 -0700 Subject: [PATCH 049/396] chore: upgrade sgl-kernel 0.2.6 (#8166) --- python/pyproject.toml | 2 +- python/sglang/srt/entrypoints/engine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 5b6501afd192..5949a100a96e 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -54,7 +54,7 @@ runtime_common = [ srt = [ "sglang[runtime_common]", - "sgl-kernel==0.2.5", + "sgl-kernel==0.2.6", "torch==2.7.1", "torchaudio==2.7.1", "torchvision==0.22.1", diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index e130dc227d21..990fac9a12a7 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ 
b/python/sglang/srt/entrypoints/engine.py @@ -654,7 +654,7 @@ def _set_envs_and_config(server_args: ServerArgs): if _is_cuda: assert_pkg_version( "sgl-kernel", - "0.2.5", + "0.2.6", "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", ) From f3d97361564331d591af6038cccad1099025cfe3 Mon Sep 17 00:00:00 2001 From: Charles Chen Date: Sun, 20 Jul 2025 01:11:24 +0800 Subject: [PATCH 050/396] Fix suffix mismatch for the metrics. (#8168) Signed-off-by: Charles Chen --- sgl-router/src/metrics.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgl-router/src/metrics.rs b/sgl-router/src/metrics.rs index 0ff2055c540c..76e952a03736 100644 --- a/sgl-router/src/metrics.rs +++ b/sgl-router/src/metrics.rs @@ -132,7 +132,7 @@ pub fn start_prometheus(config: PrometheusConfig) { // Initialize metric descriptions init_metrics(); - let duration_matcher = Matcher::Suffix(String::from("duration")); + let duration_matcher = Matcher::Suffix(String::from("duration_seconds")); let duration_bucket = [ 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0, 60.0, 90.0, 120.0, 180.0, 240.0, From 1b427dae0269024ec7c7330bc7b5e181b557d342 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 19 Jul 2025 11:04:19 -0700 Subject: [PATCH 051/396] Update README.md (#8171) --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a1f9f904677f..b19a9cdabfc0 100644 --- a/README.md +++ b/README.md @@ -25,14 +25,14 @@ - [2025/05] 🔥 Deploying DeepSeek with PD Disaggregation and Large-scale Expert Parallelism on 96 H100 GPUs ([blog](https://lmsys.org/blog/2025-05-05-large-scale-ep/)). 
- [2025/03] Supercharge DeepSeek-R1 Inference on AMD Instinct MI300X ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1-Part2/README.html)) - [2025/03] SGLang Joins PyTorch Ecosystem: Efficient LLM Serving Engine ([PyTorch blog](https://pytorch.org/blog/sglang-joins-pytorch/)) -- [2025/01] 🔥 SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html), [10+ other companies](https://x.com/lmsysorg/status/1887262321636221412)) -- [2024/12] 🔥 v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)). +- [2024/12] v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)). - [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
More - [2025/02] Unlock DeepSeek-R1 Inference Performance on AMD Instinct™ MI300X GPU ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1_Perf/README.html)) +- [2025/01] SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html), [10+ other companies](https://x.com/lmsysorg/status/1887262321636221412)) - [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)). - [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)). - [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)). @@ -59,7 +59,7 @@ The core features include: - [Contribution Guide](https://docs.sglang.ai/references/contribution_guide.html) ## Benchmark and Performance -Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/). +Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/), [Large-scale expert parallelism](https://lmsys.org/blog/2025-05-05-large-scale-ep/). 
## Roadmap [Development Roadmap (2025 H1)](https://github.com/sgl-project/sglang/issues/4042) From bb0e8a32b579b57ecc18863620dd5c7366f15af5 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 19 Jul 2025 11:32:52 -0700 Subject: [PATCH 052/396] Clean up server args (#8161) --- .github/CODEOWNERS | 19 +- docs/backend/server_arguments.md | 128 +++-- python/sglang/srt/configs/model_config.py | 8 +- python/sglang/srt/managers/scheduler.py | 3 - python/sglang/srt/model_loader/utils.py | 8 +- python/sglang/srt/server_args.py | 556 ++++++++++---------- python/sglang/test/runners.py | 4 +- test/srt/models/test_transformers_models.py | 6 +- 8 files changed, 389 insertions(+), 343 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 7369e035cede..9d640b90b60f 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,23 +1,24 @@ /3rdparty/amd @HaiShaw /docker @zhyncs @HaiShaw @ByronHsu /docs @zhaochenyang20 -/python/sglang/lang @merrymercy @Ying1123 @hnyls2002 @ByronHsu +/python/sglang/lang @merrymercy @Ying1123 @hnyls2002 /python/sglang/srt @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu /python/sglang/srt/constrained @hnyls2002 -/python/sglang/srt/disaggregation @hnyls2002 @ByronHsu +/python/sglang/srt/disaggregation @ByronHsu @hnyls2002 /python/sglang/srt/distributed @yizhang2077 -/python/sglang/srt/entrypoints @zhaochenyang20 -/python/sglang/srt/entrypoints/openai @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu @CatherineSue +/python/sglang/srt/entrypoints @zhaochenyang20 @CatherineSue +/python/sglang/srt/eplb @fzyzcjy +/python/sglang/srt/function_call @CatherineSue /python/sglang/srt/layers @merrymercy @Ying1123 @zhyncs @ispobock @HaiShaw @ch-wan @BBuf /python/sglang/srt/lora @Ying1123 @Fridge003 /python/sglang/srt/managers @merrymercy @Ying1123 @hnyls2002 @xiezhq-hermann /python/sglang/srt/mem_cache @merrymercy @Ying1123 @hnyls2002 @xiezhq-hermann /python/sglang/srt/model_executor @merrymercy @Ying1123 @hnyls2002 
@zhyncs @ispobock -/python/sglang/srt/models @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu @zhaochenyang20 -/python/sglang/srt/sampling @merrymercy @hnyls2002 -/python/sglang/srt/speculative @Ying1123 @merrymercy @rkooo567 @kssteven418 +/python/sglang/srt/models @zhyncs @ispobock @ByronHsu @zhaochenyang20 /python/sglang/srt/multimodal @mickqian @JustinTong0323 -/test/lang @merrymercy @Ying1123 @ByronHsu +/python/sglang/srt/sampling @hnyls2002 +/python/sglang/srt/speculative @Ying1123 @merrymercy @rkooo567 @kssteven418 +/test/lang @merrymercy @Ying1123 /test/srt @merrymercy @Ying1123 @zhyncs -/sgl-router @ByronHsu @Ying1123 @slin1237 +/sgl-router @ByronHsu @slin1237 /sgl-kernel @zhyncs @ispobock @HandH1998 @BBuf @yizhang2077 @merrymercy @yinfan98 @HaiShaw diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index ad9c136c8b78..6320a6e61aac 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -51,7 +51,7 @@ You can find all arguments by `python3 -m sglang.launch_server --help` Please consult the documentation below and [server_args.py](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py) to learn more about the arguments you may provide when launching a server. -## Model, processor and tokenizer +## Model and tokenizer | Arguments | Description | Defaults | |-----------|-------------|----------| @@ -61,20 +61,30 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--skip-tokenizer-init` | If set, skip init tokenizer and pass input_ids in generate request. | False | | `--load-format` | The format of the model weights to load. 'auto' will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available. 'pt' will load the weights in the pytorch bin format. 'safetensors' will load the weights in the safetensors format. 
'npcache' will load the weights in pytorch format and store a numpy cache to speed up the loading. 'dummy' will initialize the weights with random values, which is mainly for profiling. 'gguf' will load the weights in the gguf format. 'bitsandbytes' will load the weights using bitsandbytes quantization. 'layered' loads weights layer by layer so that one can quantize a layer before loading another to make the peak memory envelope smaller. | auto | | `--trust-remote-code` | Whether or not to allow for custom models defined on the Hub in their own modeling files. | False | -| `--dtype` | Data type for model weights and activations. 'auto' will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. 'half' for FP16. Recommended for AWQ quantization. 'float16' is the same as 'half'. 'bfloat16' for a balance between precision and range. 'float' is shorthand for FP32 precision. 'float32' for FP32 precision. | auto | -| `--kv-cache-dtype` | Data type for kv cache storage. 'auto' will use model data type. 'fp8_e5m2' and 'fp8_e4m3' is supported for CUDA 11.8+. | auto | -| `--quantization` | The quantization method. | None | -| `--quantization-param-path` | Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. | None | | `--context-length` | The model's maximum context length. Defaults to None (will use the value from the model's config.json instead). | None | -| `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | None | -| `--served-model-name` | Override the model name returned by the v1/models endpoint in OpenAI API server. | None | -| `--chat-template` | The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server. 
| None | -| `--completion-template` | The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently. | None | | `--is-embedding` | Whether to use a CausalLM as an embedding model. | False | | `--enable-multimodal` | Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen. | None | | `--revision` | The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. | None | -| `--impl` | Which implementation of the model to use. 'auto' will try to use the SGLang implementation if it exists and fall back to the Transformers implementation if no SGLang implementation is available. 'sglang' will use the SGLang model implementation. 'transformers' will use the Transformers model implementation. | auto | +| `--model-impl` | Which implementation of the model to use. 'auto' will try to use the SGLang implementation if it exists and fall back to the Transformers implementation if no SGLang implementation is available. 'sglang' will use the SGLang model implementation. 'transformers' will use the Transformers model implementation. | auto | + +## HTTP server + +| Arguments | Description | Defaults | +|-----------|-------------|----------| +| `--host` | The host address for the server. | 127.0.0.1 | +| `--port` | The port number for the server. | 30000 | +| `--skip-server-warmup` | If set, skip the server warmup process. | False | +| `--warmups` | Warmup configurations. | None | +| `--nccl-port` | The port for NCCL initialization. | None | +## Quantization and data type + +| Arguments | Description | Defaults | +|-----------|-------------|----------| +| `--dtype` | Data type for model weights and activations. 'auto' will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. 'half' for FP16. 
Recommended for AWQ quantization. 'float16' is the same as 'half'. 'bfloat16' for a balance between precision and range. 'float' is shorthand for FP32 precision. 'float32' for FP32 precision. | auto | +| `--quantization` | The quantization method. | None | +| `--quantization-param-path` | Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. | None | +| `--kv-cache-dtype` | Data type for kv cache storage. 'auto' will use model data type. 'fp8_e5m2' and 'fp8_e4m3' is supported for CUDA 11.8+. | auto | ## Memory and scheduling @@ -90,13 +100,13 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--cpu-offload-gb` | How many GBs of RAM to reserve for CPU offloading. | 0 | | `--page-size` | The number of tokens in a page. | 1 | - -## Other runtime options +## Runtime options | Arguments | Description | Defaults | |-----------|-------------|----------| -| `--tensor-parallel-size` or `--tp-size` | The tensor parallelism size. | 1 | -| `--pipeline-parallel-size` or `--pp-size` | The pipeline parallelism size. | 1 | +| `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | None | +| `--tp-size` | The tensor parallelism size. | 1 | +| `--pp-size` | The pipeline parallelism size. | 1 | | `--max-micro-batch-size` | The maximum micro batch size in pipeline parallelism. | None | | `--stream-interval` | The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher. | 1 | | `--stream-output` | Whether to output as a sequence of disjoint segments. 
| False | @@ -132,20 +142,22 @@ Please consult the documentation below and [server_args.py](https://github.com/s | Arguments | Description | Defaults | |-----------|-------------|----------| | `--api-key` | Set API key of the server. It is also used in the OpenAI API compatible server. | None | +| `--served-model-name` | Override the model name returned by the v1/models endpoint in OpenAI API server. | None | +| `--chat-template` | The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server. | None | +| `--completion-template` | The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently. | None | | `--file-storage-path` | The path of the file storage in backend. | sglang_storage | | `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | False | | `--reasoning-parser` | Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}. | None | | `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'. | None | -## Data parallelism +## Data parallelism | Arguments | Description | Defaults | |-----------|-------------|----------| -| `--data-parallel-size` or `--dp-size` | The data parallelism size. | 1 | +| `--dp-size` | The data parallelism size. | 1 | | `--load-balance-method` | The load balancing strategy for data parallelism. | round_robin | - -## Multi-node distributed serving +## Multi-node distributed serving | Arguments | Description | Defaults | |-----------|-------------|----------| @@ -153,7 +165,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--nnodes` | The number of nodes. | 1 | | `--node-rank` | The node rank. 
| 0 | -## Model override args +## Model override args in JSON | Arguments | Description | Defaults | |-----------|-------------|----------| @@ -164,11 +176,11 @@ Please consult the documentation below and [server_args.py](https://github.com/s | Arguments | Description | Defaults | |-----------|-------------|----------| +| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None | +| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None | | `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None | | `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 | | `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton | -| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None | -| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. 
| None | ## Kernel backend @@ -179,7 +191,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--grammar-backend` | Choose the backend for grammar-guided decoding. | None | | `--mm-attention-backend` | Set multimodal attention backend. | None | -## Speculative decoding +## Speculative decoding | Arguments | Description | Defaults | |-----------|-------------|----------| @@ -192,13 +204,14 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--speculative-accept-threshold-acc` | The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc). | 1.0 | | `--speculative-token-map` | The path of the draft model's small vocab table. | None | -## Expert parallelism +## Expert parallelism | Arguments | Description | Defaults | |-----------|-------------|----------| -| `--expert-parallel-size` or `--ep-size` | The expert parallelism size. | 1 | +| `--ep-size` | The expert parallelism size. | 1 | | `--enable-ep-moe` | Enabling expert parallelism for moe. The ep size is equal to the tp size. | False | | `--enable-deepep-moe` | Enabling DeepEP MoE implementation for EP MoE. | False | +| `--enable-flashinfer-moe` | Enabling Flashinfer MoE implementation. | False | | `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto | | `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 | | `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in expert parallel. | None | @@ -213,7 +226,18 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--deepep-config` | Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path. 
| None | | `--moe-dense-tp-size` | TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports. | None | -## Optimization/debug options +## Hierarchical cache + +| Arguments | Description | Defaults | +|-----------|-------------|----------| +| `--enable-hierarchical-cache` | Enable hierarchical cache. | False | +| `--hicache-ratio` | The ratio of the size of host KV cache memory pool to the size of device pool. | 2.0 | +| `--hicache-size` | The size of the hierarchical cache. | 0 | +| `--hicache-write-policy` | The write policy for hierarchical cache. | write_through_selective | +| `--hicache-io-backend` | The IO backend for hierarchical cache. | | +| `--hicache-storage-backend` | The storage backend for hierarchical cache. | None | + +## Optimization/debug options | Arguments | Description | Defaults | |-----------|-------------|----------| @@ -229,7 +253,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--disable-custom-all-reduce` | Disable the custom all-reduce kernel and fall back to NCCL. | False | | `--enable-mscclpp` | Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL. | False | | `--disable-overlap-schedule` | Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker. | False | -| `--disable-overlap-cg-plan` | Disable the overlap optimization for cudagraph preparation in eagle verify. | False | | `--enable-mixed-chunk` | Enabling mixing prefill and decode in a batch when using chunked prefill. | False | | `--enable-dp-attention` | Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported. 
| False | | `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False | @@ -246,24 +269,43 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--enable-memory-saver` | Allow saving memory using release_memory_occupation and resume_memory_occupation. | False | | `--allow-auto-truncate` | Allow automatically truncating requests that exceed the maximum input length instead of returning an error. | False | | `--enable-custom-logit-processor` | Enable users to pass custom logit processors to the server (disabled by default for security). | False | -| `--enable-hierarchical-cache` | Enable hierarchical cache. | False | -| `--hicache-ratio` | The ratio of the size of host KV cache memory pool to the size of device pool. | 2.0 | -| `--hicache-size` | The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set. | 0 | -| `--hicache-write-policy` | The write policy of hierarchical cache. | write_through_selective | -| `--flashinfer-mla-disable-ragged` | Not using ragged prefill wrapper when running flashinfer mla. | False | -| `--disable-shared-experts-fusion` | Disable shared experts fusion optimization for deepseek v3/r1. | False | -| `--disable-chunked-prefix-cache` | Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences. | False | -| `--disable-fast-image-processor` | Adopt base image processor instead of fast image processor. | False | -| `--enable-return-hidden-states` | Enable returning hidden states with responses. | False | -| `--warmups` | Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests. 
| None | +| `--flashinfer-mla-disable-ragged` | Disable ragged processing in Flashinfer MLA. | False | +| `--disable-shared-experts-fusion` | Disable shared experts fusion. | False | +| `--disable-chunked-prefix-cache` | Disable chunked prefix cache. | False | +| `--disable-fast-image-processor` | Disable fast image processor. | False | +| `--enable-return-hidden-states` | Enable returning hidden states. | False | +| `--enable-triton-kernel-moe` | Enable Triton kernel for MoE. | False | + +## Debug tensor dumps + +| Arguments | Description | Defaults | +|-----------|-------------|----------| +| `--debug-tensor-dump-output-folder` | The output folder for debug tensor dumps. | None | +| `--debug-tensor-dump-input-file` | The input file for debug tensor dumps. | None | +| `--debug-tensor-dump-inject` | Enable injection of debug tensor dumps. | False | +| `--debug-tensor-dump-prefill-only` | Enable prefill-only mode for debug tensor dumps. | False | + +## PD disaggregation + +| Arguments | Description | Defaults | +|-----------|-------------|----------| +| `--disaggregation-mode` | PD disaggregation mode: "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only). | null | +| `--disaggregation-transfer-backend` | The transfer backend for PD disaggregation. | mooncake | +| `--disaggregation-bootstrap-port` | The bootstrap port for PD disaggregation. | 8998 | +| `--disaggregation-decode-tp` | The decode TP for PD disaggregation. | None | +| `--disaggregation-decode-dp` | The decode DP for PD disaggregation. | None | +| `--disaggregation-prefill-pp` | The prefill PP for PD disaggregation. | 1 | + +## Model weight update + +| Arguments | Description | Defaults | +|-----------|-------------|----------| +| `--custom-weight-loader` | Custom weight loader paths. | None | +| `--weight-loader-disable-mmap` | Disable mmap for weight loader. 
| False | -## Prefill decode disaggregation +## PD-Multiplexing | Arguments | Description | Defaults | |-----------|-------------|----------| -| `--disaggregation-mode` | Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated. | null | -| `--disaggregation-transfer-backend` | The backend for disaggregation transfer. Default is mooncake. | mooncake | -| `--disaggregation-bootstrap-port` | Bootstrap server port on the prefill server. Default is 8998. | 8998 | -| `--disaggregation-ib-device` | The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). Default is None, which triggers automatic device detection when mooncake backend is enabled. | None | -| `--num-reserved-decode-tokens` | Number of decode tokens that will have memory reserved when adding new request to the running batch. | 512 | -| `--pdlb-url` | The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer. | None | +| `--enable-pdmux` | Enable PD-Multiplexing. | False | +| `--sm-group-num` | Number of SM groups for PD-Multiplexing. 
| 3 | diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 7d7f2eb95b22..84c96d91df0b 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -53,7 +53,7 @@ def __init__( trust_remote_code: bool = True, revision: Optional[str] = None, context_length: Optional[int] = None, - model_override_args: Optional[str] = None, + model_override_args: str = "{}", is_embedding: Optional[bool] = None, enable_multimodal: Optional[bool] = None, dtype: str = "auto", @@ -61,13 +61,13 @@ def __init__( override_config_file: Optional[str] = None, is_draft_model: bool = False, hybrid_kvcache_ratio: Optional[float] = None, - impl: Union[str, ModelImpl] = ModelImpl.AUTO, + model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, ) -> None: self.model_path = model_path self.revision = revision self.quantization = quantization - self.impl = impl + self.model_impl = model_impl # Parse args self.maybe_pull_model_tokenizer_from_remote() @@ -286,7 +286,7 @@ def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs): dtype=server_args.dtype, quantization=server_args.quantization, hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio, - impl=server_args.impl, + model_impl=server_args.model_impl, **kwargs, ) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 748cb7322ade..e6dd80d717ad 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1389,8 +1389,6 @@ def log_prefill_stats( f += f"#running-req: {running_bs}, " f += f"#queue-req: {len(self.waiting_queue)}, " - f += f"timestamp: {datetime.datetime.now().isoformat()}" - logger.info(f) if self.enable_metrics: @@ -1471,7 +1469,6 @@ def log_decode_stats( f"cuda graph: {can_run_cuda_graph}, " f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " f"#queue-req: {len(self.waiting_queue)}, " - f"timestamp: 
{datetime.datetime.now().isoformat()}" ) logger.info(msg) diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py index 4f65ad5fecd1..dfbbd154d627 100644 --- a/python/sglang/srt/model_loader/utils.py +++ b/python/sglang/srt/model_loader/utils.py @@ -56,14 +56,14 @@ def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str "if the model is custom)." ) model_module = auto_modules["AutoModel"] - if model_config.impl == ModelImpl.TRANSFORMERS: + if model_config.model_impl == ModelImpl.TRANSFORMERS: if not model_module.is_backend_compatible(): raise ValueError( f"The Transformers implementation of {arch} is not " - "compatible with vLLM." + "compatible with SGLang." ) architectures[i] = "TransformersForCausalLM" - if model_config.impl == ModelImpl.AUTO: + if model_config.model_impl == ModelImpl.AUTO: if not model_module.is_backend_compatible(): raise ValueError( f"{arch} has no SGlang implementation and the Transformers " @@ -97,7 +97,7 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], supported_archs = ModelRegistry.get_supported_archs() is_native_supported = any(arch in supported_archs for arch in architectures) - if not is_native_supported or model_config.impl == ModelImpl.TRANSFORMERS: + if not is_native_supported or model_config.model_impl == ModelImpl.TRANSFORMERS: architectures = resolve_transformers_arch(model_config, architectures) return ModelRegistry.resolve_model_cls(architectures) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4f9e17e05dda..24292bcd79b8 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,6 +20,7 @@ import os import random import tempfile +from token import OP from typing import List, Literal, Optional, Union from sglang.srt.hf_transformers_utils import check_gguf_file, get_config @@ -46,31 +47,28 @@ class ServerArgs: tokenizer_path: Optional[str] = None tokenizer_mode: 
str = "auto" skip_tokenizer_init: bool = False - skip_server_warmup: bool = False load_format: str = "auto" model_loader_extra_config: str = "{}" trust_remote_code: bool = False - dtype: str = "auto" - kv_cache_dtype: str = "auto" - quantization: Optional[str] = None - quantization_param_path: Optional[str] = None context_length: Optional[int] = None - device: Optional[str] = None - served_model_name: Optional[str] = None - chat_template: Optional[str] = None - completion_template: Optional[str] = None is_embedding: bool = False enable_multimodal: Optional[bool] = None revision: Optional[str] = None - hybrid_kvcache_ratio: Optional[float] = None - swa_full_tokens_ratio: float = 0.8 - impl: str = "auto" + model_impl: str = "auto" - # Port for the HTTP server + # HTTP server host: str = "127.0.0.1" port: int = 30000 + skip_server_warmup: bool = False + warmups: Optional[str] = None nccl_port: Optional[int] = None + # Quantization and data type + dtype: str = "auto" + quantization: Optional[str] = None + quantization_param_path: Optional[str] = None + kv_cache_dtype: str = "auto" + # Memory and scheduling mem_fraction_static: Optional[float] = None max_running_requests: Optional[int] = None @@ -81,8 +79,12 @@ class ServerArgs: schedule_conservativeness: float = 1.0 cpu_offload_gb: int = 0 page_size: int = 1 + hybrid_kvcache_ratio: Optional[float] = None + swa_full_tokens_ratio: float = 0.8 + disable_hybrid_swa_memory: bool = False - # Other runtime options + # Runtime options + device: Optional[str] = None tp_size: int = 1 pp_size: int = 1 max_micro_batch_size: Optional[int] = None @@ -107,8 +109,8 @@ class ServerArgs: enable_metrics: bool = False enable_metrics_for_all_schedulers: bool = False bucket_time_to_first_token: Optional[List[float]] = None - bucket_e2e_request_latency: Optional[List[float]] = None bucket_inter_token_latency: Optional[List[float]] = None + bucket_e2e_request_latency: Optional[List[float]] = None collect_tokens_histogram: bool = False 
decode_log_interval: int = 40 enable_request_time_stats_logging: bool = False @@ -116,6 +118,9 @@ class ServerArgs: # API related api_key: Optional[str] = None + served_model_name: Optional[str] = None + chat_template: Optional[str] = None + completion_template: Optional[str] = None file_storage_path: str = "sglang_storage" enable_cache_report: bool = False reasoning_parser: Optional[str] = None @@ -179,6 +184,14 @@ class ServerArgs: deepep_config: Optional[str] = None moe_dense_tp_size: Optional[int] = None + # Hierarchical cache + enable_hierarchical_cache: bool = False + hicache_ratio: float = 2.0 + hicache_size: int = 0 + hicache_write_policy: str = "write_through_selective" + hicache_io_backend: str = "" + hicache_storage_backend: Optional[str] = None + # Double Sparsity enable_double_sparsity: bool = False ds_channel_config_path: Optional[str] = None @@ -200,7 +213,6 @@ class ServerArgs: disable_custom_all_reduce: bool = False enable_mscclpp: bool = False disable_overlap_schedule: bool = False - disable_overlap_cg_plan: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False enable_dp_lm_head: bool = False @@ -217,20 +229,12 @@ class ServerArgs: enable_memory_saver: bool = False allow_auto_truncate: bool = False enable_custom_logit_processor: bool = False - enable_hierarchical_cache: bool = False - hicache_ratio: float = 2.0 - hicache_size: int = 0 - hicache_write_policy: str = "write_through_selective" - hicache_io_backend: str = "" - hicache_storage_backend: Optional[str] = None flashinfer_mla_disable_ragged: bool = False disable_shared_experts_fusion: bool = False disable_chunked_prefix_cache: bool = False disable_fast_image_processor: bool = False enable_return_hidden_states: bool = False enable_triton_kernel_moe: bool = False - warmups: Optional[str] = None - disable_hybrid_swa_memory: bool = False # Debug tensor dumps debug_tensor_dump_output_folder: Optional[str] = None @@ -238,7 +242,7 @@ class ServerArgs: 
debug_tensor_dump_inject: bool = False debug_tensor_dump_prefill_only: bool = False - # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only) + # PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only) disaggregation_mode: str = "null" disaggregation_transfer_backend: str = "mooncake" disaggregation_bootstrap_port: int = 8998 @@ -273,6 +277,7 @@ def __post_init__(self): logger.warning( f"Flashinfer MoE is enabled. Shared expert fusion is disabled." ) + # Set missing default values if self.tokenizer_path is None: self.tokenizer_path = self.model_path @@ -333,56 +338,12 @@ def __post_init__(self): self.mem_fraction_static = 0.88 # Lazy init to avoid circular import + # Multimodal models need more memory for the image processor from sglang.srt.configs.model_config import ModelConfig - # Multimodal models need more memory for the image processor model_config = ModelConfig.from_server_args(self) - - vision_config = getattr(model_config.hf_config, "vision_config", None) - - if model_config.is_multimodal and vision_config: - # roughly reduce the mem_fraction_static base on params of Vit - original_server_arg_mem_fraction = self.mem_fraction_static - # a base mem_fraction_static factor for regular Vit - base_mem_fraction_reduction_ratio = 0.95 - - vit_num_layers = getattr(vision_config, "num_hidden_layers", 24) - vit_hidden_size = getattr(vision_config, "hidden_size", 1024) - - # baseline ViT params (ViT-L/14) - baseline_vit_layers = 24 - baseline_vit_hidden_size = 1024 - - # weight params count - current_complexity_score = vit_num_layers * (vit_hidden_size**2) - baseline_complexity_score = baseline_vit_layers * ( - baseline_vit_hidden_size**2 - ) - complexity_ratio = ( - current_complexity_score / baseline_complexity_score - if baseline_complexity_score > 0 - else 1.0 - ) - - # every time the complexity grows 100%, adjust final factor for 10% - sensitivity_scale = 0.1 
- dynamic_adjustment_factor = 1.0 - sensitivity_scale * ( - complexity_ratio - 1.0 - ) - dynamic_adjustment_factor = max( - 0.8, min(1.05, dynamic_adjustment_factor) - ) - - final_overall_factor = ( - base_mem_fraction_reduction_ratio * dynamic_adjustment_factor - ) - self.mem_fraction_static = ( - original_server_arg_mem_fraction * final_overall_factor - ) - logger.warning( - f"Multimodal model: Dynamically adjusted --mem-fraction-static " - f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}." - ) + if model_config.is_multimodal: + self.adjust_mem_fraction_for_vlm(model_config) # Set chunked prefill size, which depends on the gpu memory capacity if self.chunked_prefill_size is None: @@ -406,23 +367,6 @@ def __post_init__(self): else: self.cuda_graph_max_bs = 80 - assert self.moe_dense_tp_size in { - 1, - None, - }, "moe_dense_tp_size only support 1 and None currently" - - if self.attention_backend == "flashmla": - logger.warning( - "FlashMLA only supports a page_size of 64, change page_size to 64." - ) - self.page_size = 64 - - if self.attention_backend == "cutlass_mla": - logger.warning( - "Cutlass MLA only supports a page_size of 128, change page_size to 128." - ) - self.page_size = 128 - # Set kernel backends for hpu device if self.device == "hpu": self.attention_backend = "torch_native" @@ -451,6 +395,18 @@ def __post_init__(self): ) self.page_size = 128 + if self.attention_backend == "flashmla": + logger.warning( + "FlashMLA only supports a page_size of 64, change page_size to 64." + ) + self.page_size = 64 + + if self.attention_backend == "cutlass_mla": + logger.warning( + "Cutlass MLA only supports a page_size of 128, change page_size to 128." + ) + self.page_size = 128 + # Choose grammar backend if self.grammar_backend is None: self.grammar_backend = "xgrammar" @@ -482,12 +438,6 @@ def __post_init__(self): f"DeepEP MoE is enabled. 
The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." ) - if self.pp_size > 1: - self.disable_overlap_schedule = True - logger.warning( - "Pipeline parallelism is incompatible with overlap schedule." - ) - if self.enable_eplb and (self.expert_distribution_recorder_mode is None): self.expert_distribution_recorder_mode = "stat" logger.info( @@ -513,6 +463,13 @@ def __post_init__(self): elif self.expert_distribution_recorder_mode is not None: self.expert_distribution_recorder_buffer_size = 1000 + # Pipeline parallelism + if self.pp_size > 1: + self.disable_overlap_schedule = True + logger.warning( + "Pipeline parallelism is incompatible with overlap schedule." + ) + # Speculative Decoding if self.speculative_algorithm == "NEXTN": # NEXTN shares the same implementation of EAGLE @@ -533,8 +490,7 @@ def __post_init__(self): "eagle speculative decoding." ) - model_arch = get_model_arch(self) - + model_arch = self.get_hf_config().architectures[0] if model_arch == "DeepseekV3ForCausalLM": # Auto set draft_model_path DeepSeek-V3/R1 if self.speculative_draft_model_path is None: @@ -624,17 +580,9 @@ def __post_init__(self): if self.custom_weight_loader is None: self.custom_weight_loader = [] - def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int): - larger_tp = max(decode_tp, prefill_tp) - smaller_tp = min(decode_tp, prefill_tp) - assert larger_tp % smaller_tp == 0, ( - "Different tp size is supported only when one tp is multiple of the other. 
" - f"decode_tp={decode_tp}, prefill_tp={prefill_tp}" - ) - @staticmethod def add_cli_args(parser: argparse.ArgumentParser): - # Model and port args + # Model and tokenizer parser.add_argument( "--model-path", "--model", @@ -648,24 +596,6 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tokenizer_path, help="The path of the tokenizer.", ) - parser.add_argument( - "--host", - type=str, - default=ServerArgs.host, - help="The host of the HTTP server.", - ) - parser.add_argument( - "--port", - type=int, - default=ServerArgs.port, - help="The port of the HTTP server.", - ) - parser.add_argument( - "--nccl-port", - type=int, - default=ServerArgs.nccl_port, - help="The port for NCCL distributed environment setup. Defaults to a random port.", - ) parser.add_argument( "--tokenizer-mode", type=str, @@ -680,11 +610,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="If set, skip init tokenizer and pass input_ids in generate request.", ) - parser.add_argument( - "--skip-server-warmup", - action="store_true", - help="If set, skip warmup.", - ) parser.add_argument( "--load-format", type=str, @@ -730,6 +655,77 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", ) + parser.add_argument( + "--context-length", + type=int, + default=ServerArgs.context_length, + help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).", + ) + parser.add_argument( + "--is-embedding", + action="store_true", + help="Whether to use a CausalLM as an embedding model.", + ) + parser.add_argument( + "--enable-multimodal", + default=ServerArgs.enable_multimodal, + action="store_true", + help="Enable the multimodal functionality for the served model. 
If the model being served is not multimodal, nothing will happen", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="The specific model version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) + parser.add_argument( + "--model-impl", + type=str, + default=ServerArgs.model_impl, + help="Which implementation of the model to use.\n\n" + '* "auto" will try to use the SGLang implementation if it exists ' + "and fall back to the Transformers implementation if no SGLang " + "implementation is available.\n" + '* "sglang" will use the SGLang model implementation.\n' + '* "transformers" will use the Transformers model ' + "implementation.\n", + ) + + # HTTP server + parser.add_argument( + "--host", + type=str, + default=ServerArgs.host, + help="The host of the HTTP server.", + ) + parser.add_argument( + "--port", + type=int, + default=ServerArgs.port, + help="The port of the HTTP server.", + ) + parser.add_argument( + "--skip-server-warmup", + action="store_true", + help="If set, skip warmup.", + ) + parser.add_argument( + "--warmups", + type=str, + required=False, + help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 " + "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", + ) + parser.add_argument( + "--nccl-port", + type=int, + default=ServerArgs.nccl_port, + help="The port for NCCL distributed environment setup. 
Defaults to a random port.", + ) + + # Quantization and data type parser.add_argument( "--dtype", type=str, @@ -744,13 +740,6 @@ def add_cli_args(parser: argparse.ArgumentParser): '* "float" is shorthand for FP32 precision.\n' '* "float32" for FP32 precision.', ) - parser.add_argument( - "--kv-cache-dtype", - type=str, - default=ServerArgs.kv_cache_dtype, - choices=["auto", "fp8_e5m2", "fp8_e4m3"], - help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.', - ) parser.add_argument( "--quantization", type=str, @@ -785,65 +774,11 @@ def add_cli_args(parser: argparse.ArgumentParser): "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( - "--context-length", - type=int, - default=ServerArgs.context_length, - help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).", - ) - parser.add_argument( - "--device", - type=str, - default=ServerArgs.device, - help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.", - ) - parser.add_argument( - "--served-model-name", - type=str, - default=ServerArgs.served_model_name, - help="Override the model name returned by the v1/models endpoint in OpenAI API server.", - ) - parser.add_argument( - "--chat-template", - type=str, - default=ServerArgs.chat_template, - help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.", - ) - parser.add_argument( - "--completion-template", - type=str, - default=ServerArgs.completion_template, - help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. 
only for code completion currently.", - ) - parser.add_argument( - "--is-embedding", - action="store_true", - help="Whether to use a CausalLM as an embedding model.", - ) - parser.add_argument( - "--enable-multimodal", - default=ServerArgs.enable_multimodal, - action="store_true", - help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - help="The specific model version to use. It can be a branch " - "name, a tag name, or a commit id. If unspecified, will use " - "the default version.", - ) - parser.add_argument( - "--impl", + "--kv-cache-dtype", type=str, - default=ServerArgs.impl, - help="Which implementation of the model to use.\n\n" - '* "auto" will try to use the SGLang implementation if it exists ' - "and fall back to the Transformers implementation if no SGLang " - "implementation is available.\n" - '* "sglang" will use the SGLang model implementation.\n' - '* "transformers" will use the Transformers model ' - "implementation.\n", + default=ServerArgs.kv_cache_dtype, + choices=["auto", "fp8_e5m2", "fp8_e4m3"], + help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.', ) # Memory and scheduling @@ -928,7 +863,13 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Disable the hybrid SWA memory.", ) - # Other runtime options + # Runtime options + parser.add_argument( + "--device", + type=str, + default=ServerArgs.device, + help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). 
Defaults to auto-detection if not specified.", + ) parser.add_argument( "--tensor-parallel-size", "--tp-size", @@ -970,7 +911,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--constrained-json-whitespace-pattern", type=str, default=ServerArgs.constrained_json_whitespace_pattern, - help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*", + help="(outlines backend only) Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*", ) parser.add_argument( "--watchdog-timeout", @@ -1083,12 +1024,6 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.collect_tokens_histogram, help="Collect prompt/generation tokens histogram.", ) - parser.add_argument( - "--kv-events-config", - type=str, - default=None, - help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.", - ) parser.add_argument( "--decode-log-interval", type=int, @@ -1101,6 +1036,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.enable_request_time_stats_logging, help="Enable per request time stats logging", ) + parser.add_argument( + "--kv-events-config", + type=str, + default=None, + help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.", + ) # API related parser.add_argument( @@ -1109,6 +1050,24 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.api_key, help="Set API key of the server. 
It is also used in the OpenAI API compatible server.",
     )
+        parser.add_argument(
+            "--served-model-name",
+            type=str,
+            default=ServerArgs.served_model_name,
+            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
+        )
+        parser.add_argument(
+            "--chat-template",
+            type=str,
+            default=ServerArgs.chat_template,
+            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
+        )
+        parser.add_argument(
+            "--completion-template",
+            type=str,
+            default=ServerArgs.completion_template,
+            help="The builtin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. Only for code completion currently.",
+        )
         parser.add_argument(
             "--file-storage-path",
             type=str,
@@ -1427,6 +1386,46 @@ def add_cli_args(parser: argparse.ArgumentParser):
             help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
         )

+        # Hierarchical cache
+        parser.add_argument(
+            "--enable-hierarchical-cache",
+            action="store_true",
+            help="Enable hierarchical cache",
+        )
+        parser.add_argument(
+            "--hicache-ratio",
+            type=float,
+            default=ServerArgs.hicache_ratio,
+            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
+        )
+        parser.add_argument(
+            "--hicache-size",
+            type=int,
+            default=ServerArgs.hicache_size,
+            help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
+        )
+        parser.add_argument(
+            "--hicache-write-policy",
+            type=str,
+            choices=["write_back", "write_through", "write_through_selective"],
+            default=ServerArgs.hicache_write_policy,
+            help="The write policy of hierarchical cache.",
+        )
+        parser.add_argument(
+            "--hicache-io-backend",
+            type=str,
+            choices=["direct", "kernel"],
+            default=ServerArgs.hicache_io_backend,
+
help="The IO backend for KV cache transfer between CPU and GPU", + ) + parser.add_argument( + "--hicache-storage-backend", + type=str, + choices=["file"], # todo, mooncake + default=ServerArgs.hicache_storage_backend, + help="The storage backend for hierarchical KV cache.", + ) + # Double Sparsity parser.add_argument( "--enable-double-sparsity", @@ -1619,44 +1618,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable users to pass custom logit processors to the server (disabled by default for security)", ) - parser.add_argument( - "--enable-hierarchical-cache", - action="store_true", - help="Enable hierarchical cache", - ) - parser.add_argument( - "--hicache-ratio", - type=float, - default=ServerArgs.hicache_ratio, - help="The ratio of the size of host KV cache memory pool to the size of device pool.", - ) - parser.add_argument( - "--hicache-size", - type=int, - default=ServerArgs.hicache_size, - help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.", - ) - parser.add_argument( - "--hicache-write-policy", - type=str, - choices=["write_back", "write_through", "write_through_selective"], - default=ServerArgs.hicache_write_policy, - help="The write policy of hierarchical cache.", - ) - parser.add_argument( - "--hicache-io-backend", - type=str, - choices=["direct", "kernel"], - default=ServerArgs.hicache_io_backend, - help="The IO backend for KV cache transfer between CPU and GPU", - ) - parser.add_argument( - "--hicache-storage-backend", - type=str, - choices=["file"], # todo, mooncacke - default=ServerArgs.hicache_storage_backend, - help="The storage backend for hierarchical KV cache.", - ) parser.add_argument( "--flashinfer-mla-disable-ragged", action="store_true", @@ -1687,13 +1648,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Use triton moe grouped gemm kernel.", ) - parser.add_argument( - "--warmups", - type=str, - required=False, - 
help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 " - "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests", - ) # Debug tensor dumps parser.add_argument( @@ -1720,7 +1674,7 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Only dump the tensors for prefill requests (i.e. batch size > 1).", ) - # Disaggregation + # PD disaggregation parser.add_argument( "--disaggregation-mode", type=str, @@ -1779,6 +1733,8 @@ def add_cli_args(parser: argparse.ArgumentParser): default=None, help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.", ) + + # Custom weight loader parser.add_argument( "--custom-weight-loader", type=str, @@ -1791,6 +1747,8 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable PD-Multiplexing, PD running on greenctx stream.", ) + + # For PD-Multiplexing parser.add_argument( "--sm-group-num", type=int, @@ -1818,6 +1776,17 @@ def url(self): else: return f"http://{self.host}:{self.port}" + def get_hf_config(self): + kwargs = {} + hf_config = get_config( + self.model_path, + trust_remote_code=self.trust_remote_code, + revision=self.revision, + model_override_args=json.loads(self.json_model_override_args), + **kwargs, + ) + return hf_config + def check_server_args(self): assert ( self.tp_size * self.pp_size @@ -1842,6 +1811,11 @@ def check_server_args(self): assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative" assert self.gpu_id_step >= 1, "gpu_id_step must be positive" + assert self.moe_dense_tp_size in { + 1, + None, + }, "moe_dense_tp_size only support 1 and None currently" + if isinstance(self.lora_paths, list): lora_paths = self.lora_paths self.lora_paths = {} @@ -1852,6 +1826,56 @@ def check_server_args(self): else: self.lora_paths[lora_path] = lora_path + def validate_disagg_tp_size(self, 
prefill_tp: int, decode_tp: int): + larger_tp = max(decode_tp, prefill_tp) + smaller_tp = min(decode_tp, prefill_tp) + assert larger_tp % smaller_tp == 0, ( + "Different tp size is supported only when one tp is multiple of the other. " + f"decode_tp={decode_tp}, prefill_tp={prefill_tp}" + ) + + def adjust_mem_fraction_for_vlm(self, model_config): + vision_config = getattr(model_config.hf_config, "vision_config", None) + if vision_config is None: + return + + # roughly reduce the mem_fraction_static base on params of Vit + original_server_arg_mem_fraction = self.mem_fraction_static + # a base mem_fraction_static factor for regular Vit + base_mem_fraction_reduction_ratio = 0.95 + + vit_num_layers = getattr(vision_config, "num_hidden_layers", 24) + vit_hidden_size = getattr(vision_config, "hidden_size", 1024) + + # baseline ViT params (ViT-L/14) + baseline_vit_layers = 24 + baseline_vit_hidden_size = 1024 + + # weight params count + current_complexity_score = vit_num_layers * (vit_hidden_size**2) + baseline_complexity_score = baseline_vit_layers * (baseline_vit_hidden_size**2) + complexity_ratio = ( + current_complexity_score / baseline_complexity_score + if baseline_complexity_score > 0 + else 1.0 + ) + + # every time the complexity grows 100%, adjust final factor for 10% + sensitivity_scale = 0.1 + dynamic_adjustment_factor = 1.0 - sensitivity_scale * (complexity_ratio - 1.0) + dynamic_adjustment_factor = max(0.8, min(1.05, dynamic_adjustment_factor)) + + final_overall_factor = ( + base_mem_fraction_reduction_ratio * dynamic_adjustment_factor + ) + self.mem_fraction_static = ( + original_server_arg_mem_fraction * final_overall_factor + ) + logger.warning( + f"Multimodal model: Dynamically adjusted --mem-fraction-static " + f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}." 
+ ) + def prepare_server_args(argv: List[str]) -> ServerArgs: """ @@ -1895,16 +1919,16 @@ class PortArgs: @staticmethod def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs": if server_args.nccl_port is None: - port = server_args.port + random.randint(100, 1000) + nccl_port = server_args.port + random.randint(100, 1000) while True: - if is_port_available(port): + if is_port_available(nccl_port): break - if port < 60000: - port += 42 + if nccl_port < 60000: + nccl_port += 42 else: - port -= 43 + nccl_port -= 43 else: - port = server_args.nccl_port + nccl_port = server_args.nccl_port if not server_args.enable_dp_attention: # Normal case, use IPC within a single node @@ -1912,7 +1936,7 @@ def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs": tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", - nccl_port=port, + nccl_port=nccl_port, rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", ) @@ -1942,7 +1966,7 @@ def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs": tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}", scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}", detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}", - nccl_port=port, + nccl_port=nccl_port, rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}", metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}", ) @@ -1969,31 +1993,13 @@ def __call__(self, parser, namespace, values, option_string=None): raise ValueError(self.help) -def get_model_arch(args: ServerArgs): - hf_config = get_config( - args.model_path, - trust_remote_code=args.trust_remote_code, - revision=args.revision, - 
model_override_args=json.loads(args.json_model_override_args), - ) - return hf_config.architectures[0] - - def auto_choose_speculative_params(self: ServerArgs): """ Automatically choose the parameters for speculative decoding. You can tune them on your own models and prompts with scripts/playground/bench_speculative.py """ - kwargs = {} - - hf_config = get_config( - self.model_path, - trust_remote_code=self.trust_remote_code, - revision=self.revision, - model_override_args=json.loads(self.json_model_override_args), - **kwargs, - ) + hf_config = self.get_hf_config() arch = hf_config.architectures[0] if arch in ["LlamaForCausalLM"]: diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 64a1b34c20a6..941940fe0fd8 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -481,7 +481,7 @@ def __init__( torch_dtype: torch.dtype, model_type: str, tp_size: int = 1, - impl: str = "auto", + model_impl: str = "auto", port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, lora_paths: List[str] = None, max_loras_per_batch: int = 4, @@ -525,7 +525,7 @@ def __init__( tp_size=tp_size, dtype=get_dtype_str(torch_dtype), port=port, - impl=impl, + model_impl=model_impl, torchao_config=torchao_config, mem_fraction_static=mem_fraction_static, trust_remote_code=trust_remote_code, diff --git a/test/srt/models/test_transformers_models.py b/test/srt/models/test_transformers_models.py index 7e92b49d1637..95592453fb10 100644 --- a/test/srt/models/test_transformers_models.py +++ b/test/srt/models/test_transformers_models.py @@ -27,7 +27,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--impl", "transformers"], + other_args=["--model-impl", "transformers"], ) cls.mmlu_lower_bound = 0.65 cls.gsm8k_lower_bound = 0.65 @@ -76,7 +76,7 @@ def setUpClass(cls): cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ - "--impl", + "--model-impl", "transformers", "--torchao-config", 
"int4wo-128", @@ -127,7 +127,7 @@ def assert_close_logits_and_output_strs( tp_size=model_case.tp_size, torch_dtype=model_case.torch_dtype, model_type="generation", - impl="transformers", + model_impl="transformers", trust_remote_code=model_case.trust_remote_code, torchao_config=model_case.torchao_config, ) as srt_runner: From 3de617a75bc9682763ba4f5f402a679e0df5dd22 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Sat, 19 Jul 2025 13:14:08 -0700 Subject: [PATCH 053/396] Fix LoRA buffer contamination during adapter eviction (#8103) --- python/sglang/srt/lora/mem_pool.py | 51 +++++++--- test/srt/models/lora/test_lora_eviction.py | 111 +++++++++++++++++++++ test/srt/run_suite.py | 1 + 3 files changed, 148 insertions(+), 15 deletions(-) create mode 100644 test/srt/models/lora/test_lora_eviction.py diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 713b03650cf1..1b36cac5e1a7 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -188,10 +188,18 @@ def load_lora_weight_to_buffer( lora_adapter: LoRAAdapter, lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]], ): - def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor): - assert ( - buffer_view.shape == weight.shape - ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}." + def load_lora_weight_tensor( + buffer_view: torch.Tensor, weight: Optional[torch.Tensor] + ): + if weight is None: + # If the particular weight is not present in the adapter, we initialize the buffer to zero + # to avoid contamination from the residual weight of the evicted adapters. + buffer_view.zero_() + else: + assert ( + buffer_view.shape == weight.shape + ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}." 
+ buffer_view.copy_(weight) if uid is None: for i in range(self.num_layer): @@ -203,8 +211,12 @@ def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor): lora_rank = lora_adapter.config.hf_config["r"] for layer_id in range(self.num_layer): layer_weights = lora_adapter.layers[layer_id].weights - temp_A_buffer: Dict[str, torch.Tensor] = {} - temp_B_buffer: Dict[str, torch.Tensor] = {} + temp_A_buffer: Dict[str, Optional[torch.Tensor]] = { + weight_name: None for weight_name in self.A_buffer + } + temp_B_buffer: Dict[str, Optional[torch.Tensor]] = { + weight_name: None for weight_name in self.B_buffer + } for name, weights in layer_weights.items(): if "lora_A" in name: lora_weight_name = get_weight_name( @@ -220,6 +232,14 @@ def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor): if self.tp_size > 1: cur_layer_modules = lora_modules[layer_id] for module_name, module in cur_layer_modules.items(): + weight_name = get_weight_name( + module_name, self.lora_weight_names, LoRAType.LORA_A + ) + + if temp_A_buffer[weight_name] is None: + # Skip weight slicing if the weight is not present in the adapter + continue + if "qkv_proj" in module_name: temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights( temp_A_buffer["qkv_proj"], self.tp_rank @@ -231,9 +251,10 @@ def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor): ) ) else: - weight_name = get_weight_name( - module_name, self.lora_weight_names, LoRAType.LORA_A - ) + # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B. + # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and + # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the + # FlashInfer LoRA backend. 
temp_A_buffer[weight_name] = module.slice_lora_a_weights( temp_A_buffer[weight_name], self.tp_rank ) @@ -246,8 +267,7 @@ def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor): buffer_view = self.A_buffer[name][layer_id][buffer_id][ : lora_rank * c, : ] - check_lora_weight_shape(buffer_view, weights) - buffer_view.copy_(weights) + load_lora_weight_tensor(buffer_view, weights) for name, weights in temp_B_buffer.items(): c = get_stacked_multiply(name) @@ -256,14 +276,15 @@ def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor): buffer_view = self.B_buffer[name][layer_id][stacked_id][ buffer_id ][:, :lora_rank] - check_lora_weight_shape(buffer_view, weights[stacked_id]) - buffer_view.copy_(weights[stacked_id]) + weight_slice = ( + weights[stacked_id] if weights is not None else None + ) + load_lora_weight_tensor(buffer_view, weight_slice) else: buffer_view = self.B_buffer[name][layer_id][0][buffer_id][ :, :lora_rank ] - check_lora_weight_shape(buffer_view, weights) - buffer_view.copy_(weights) + load_lora_weight_tensor(buffer_view, weights) def get_tensor( self, weight_name: str, layer_id: int, lora_type: LoRAType diff --git a/test/srt/models/lora/test_lora_eviction.py b/test/srt/models/lora/test_lora_eviction.py new file mode 100644 index 000000000000..e74af0a0e61d --- /dev/null +++ b/test/srt/models/lora/test_lora_eviction.py @@ -0,0 +1,111 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import multiprocessing as mp +import unittest +from typing import Dict, List, Tuple + +import torch + +from sglang.test.runners import SRTRunner +from sglang.test.test_utils import CustomTestCase + +PROMPTS = [ + "AI is a field of computer science focused on", + """ + ### Instruction: + Compose a SQL query that uses the following table: users, and returns the user_id and name of all users whose name that does not have a duplicate in the table. + ### Response: + SELECT user_id, name FROM users WHERE name LIKE 'A%'; + """, +] + +ADAPTERS = [ + "faridlazuarda/valadapt-llama-3.1-8B-it-chinese", # target_modules = q, v + "philschmid/code-llama-3-1-8b-text-to-sql-lora", # target_modules = q, k, v, o, gate, up, down +] + +BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct" + + +class TestLoRAEviction(CustomTestCase): + def test_lora_eviction_with_different_target_modules(self): + """ + Test LoRA eviction with different target modules. + + This test runs inference against two LoRA adapters in different orders to force eviction behavior, and ensures + that the outputs of the same (adapter, prompt) pair are consistent across runs. 
+ """ + output_history = {} + self._run_test(ADAPTERS, output_history, reverse=False) + self._run_test(ADAPTERS, output_history, reverse=True) + + def _run_test( + self, + lora_paths: List[str], + output_history: Dict[Tuple[str, str], str], + reverse: bool, + repeat: int = 2, + ): + max_new_tokens = 256 + backend = "triton" + torch_dtype = torch.float16 + base_path = BASE_MODEL + assert len(lora_paths) >= 2 + + # Initialize runners + with SRTRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + lora_paths=lora_paths, + max_loras_per_batch=1, + lora_backend=backend, + disable_radix_cache=True, + ) as srt_runner: + adapter_sequence = lora_paths if not reverse else lora_paths[::-1] + + for i in range(repeat): + for j, adapter in enumerate(adapter_sequence): + print( + f"\n========== Testing LoRA eviction with adapter '{adapter}' (#{j+1}/{len(adapter_sequence)}), reversed: {reverse}, repeat: {i+1}/{repeat} ---" + ) + for prompt in PROMPTS: + print("\nprompt:\n", prompt) + srt_outputs = srt_runner.forward( + [prompt], + max_new_tokens=max_new_tokens, + lora_paths=[adapter], + ) + output = srt_outputs.output_strs[0].strip() + print("\noutput:\n", output) + + prev_output = output_history.get((adapter, prompt)) + if prev_output is not None: + self.assertEqual( + prev_output, + output, + f"Output mismatch for adapter {adapter} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.", + ) + else: + output_history[(adapter, prompt)] = output + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main(warnings="ignore") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index e67362cf8258..f59aed623e0f 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -14,6 +14,7 @@ class TestFile: suites = { "per-commit": [ TestFile("models/lora/test_lora.py", 200), + TestFile("models/lora/test_lora_eviction.py", 120), 
TestFile("models/lora/test_lora_backend.py", 99), TestFile("models/lora/test_multi_lora_backend.py", 60), TestFile("models/lora/test_lora_cuda_graph.py", 250), From bfdd226f355721fd93d35f48c3132130fc3ff70e Mon Sep 17 00:00:00 2001 From: kyleliang-nv Date: Sat, 19 Jul 2025 14:37:53 -0700 Subject: [PATCH 054/396] Fix Dockerfile.gb200 (#8169) --- docker/Dockerfile.gb200 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile.gb200 b/docker/Dockerfile.gb200 index 05b0f42043bc..1e0e665234f1 100644 --- a/docker/Dockerfile.gb200 +++ b/docker/Dockerfile.gb200 @@ -140,8 +140,8 @@ RUN apt-get update && apt-get install -y \ RUN apt update -y \ && apt install -y --no-install-recommends gnupg \ - && echo "deb http://developer.download.nvidia.com/devtools/repos/ubuntu2204/$(if [ "$(uname -m)" = "aarch64" ]; then echo "arm64"; else echo "amd64"; fi) /" | tee /etc/apt/sources.list.d/nvidia-devtools.list \ - && apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$(if [ "$(uname -m)" = "aarch64" ]; then echo "sbsa"; else echo "x86_64"; fi)/3bf863cc.pub \ + && echo "deb http://developer.download.nvidia.com/devtools/repos/ubuntu2004/$(if [ "$(uname -m)" = "aarch64" ]; then echo "arm64"; else echo "amd64"; fi) /" | tee /etc/apt/sources.list.d/nvidia-devtools.list \ + && apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/$(if [ "$(uname -m)" = "aarch64" ]; then echo "arm64"; else echo "x86_64"; fi)/7fa2af80.pub \ && apt update -y \ && apt install nsight-systems-cli -y From 41d33e4736707cea54aa731055cf88f367befefc Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Sat, 19 Jul 2025 14:38:33 -0700 Subject: [PATCH 055/396] [router] add ut for worker and errors (#8170) --- sgl-router/src/core/error.rs | 179 ++++++++++ sgl-router/src/core/worker.rs | 610 ++++++++++++++++++++++++++++++++++ 2 files changed, 789 insertions(+) diff --git a/sgl-router/src/core/error.rs 
b/sgl-router/src/core/error.rs index 02a87dbbc630..4d50ccee0df5 100644 --- a/sgl-router/src/core/error.rs +++ b/sgl-router/src/core/error.rs @@ -55,3 +55,182 @@ impl From for WorkerError { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::error::Error; + + #[test] + fn test_health_check_failed_display() { + let error = WorkerError::HealthCheckFailed { + url: "http://worker1:8080".to_string(), + reason: "Connection refused".to_string(), + }; + assert_eq!( + error.to_string(), + "Health check failed for worker http://worker1:8080: Connection refused" + ); + } + + #[test] + fn test_worker_not_found_display() { + let error = WorkerError::WorkerNotFound { + url: "http://worker2:8080".to_string(), + }; + assert_eq!(error.to_string(), "Worker not found: http://worker2:8080"); + } + + #[test] + fn test_invalid_configuration_display() { + let error = WorkerError::InvalidConfiguration { + message: "Missing port number".to_string(), + }; + assert_eq!( + error.to_string(), + "Invalid worker configuration: Missing port number" + ); + } + + #[test] + fn test_network_error_display() { + let error = WorkerError::NetworkError { + url: "http://worker3:8080".to_string(), + error: "Timeout after 30s".to_string(), + }; + assert_eq!( + error.to_string(), + "Network error for worker http://worker3:8080: Timeout after 30s" + ); + } + + #[test] + fn test_worker_at_capacity_display() { + let error = WorkerError::WorkerAtCapacity { + url: "http://worker4:8080".to_string(), + }; + assert_eq!(error.to_string(), "Worker at capacity: http://worker4:8080"); + } + + #[test] + fn test_worker_error_implements_std_error() { + let error = WorkerError::WorkerNotFound { + url: "http://test".to_string(), + }; + // Verify it implements Error trait + let _: &dyn Error = &error; + assert!(error.source().is_none()); + } + + #[test] + fn test_error_send_sync() { + fn assert_send_sync() {} + assert_send_sync::(); + } + + #[test] + fn test_worker_result_type_alias() { + // Test Ok variant + let 
result: WorkerResult = Ok(42); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), 42); + + // Test Err variant + let error = WorkerError::WorkerNotFound { + url: "test".to_string(), + }; + let result: WorkerResult = Err(error); + assert!(result.is_err()); + } + + #[test] + fn test_empty_url_handling() { + // Test empty URLs in error variants + let error1 = WorkerError::HealthCheckFailed { + url: "".to_string(), + reason: "No connection".to_string(), + }; + assert_eq!( + error1.to_string(), + "Health check failed for worker : No connection" + ); + + let error2 = WorkerError::NetworkError { + url: "".to_string(), + error: "DNS failure".to_string(), + }; + assert_eq!(error2.to_string(), "Network error for worker : DNS failure"); + + let error3 = WorkerError::WorkerNotFound { + url: "".to_string(), + }; + assert_eq!(error3.to_string(), "Worker not found: "); + } + + #[test] + fn test_special_characters_in_messages() { + // Test with special characters + let error = WorkerError::InvalidConfiguration { + message: "Invalid JSON: {\"error\": \"test\"}".to_string(), + }; + assert_eq!( + error.to_string(), + "Invalid worker configuration: Invalid JSON: {\"error\": \"test\"}" + ); + + // Test with unicode + let error2 = WorkerError::HealthCheckFailed { + url: "http://测试:8080".to_string(), + reason: "连接被拒绝".to_string(), + }; + assert_eq!( + error2.to_string(), + "Health check failed for worker http://测试:8080: 连接被拒绝" + ); + } + + #[test] + fn test_very_long_error_messages() { + let long_message = "A".repeat(10000); + let error = WorkerError::InvalidConfiguration { + message: long_message.clone(), + }; + let display = error.to_string(); + assert!(display.contains(&long_message)); + assert_eq!( + display.len(), + "Invalid worker configuration: ".len() + long_message.len() + ); + } + + // Mock reqwest error for testing conversion + #[test] + fn test_reqwest_error_conversion() { + // Test that NetworkError is the correct variant + let network_error = WorkerError::NetworkError 
{ + url: "http://example.com".to_string(), + error: "connection timeout".to_string(), + }; + + match network_error { + WorkerError::NetworkError { url, error } => { + assert_eq!(url, "http://example.com"); + assert_eq!(error, "connection timeout"); + } + _ => panic!("Expected NetworkError variant"), + } + } + + #[test] + fn test_error_equality() { + // WorkerError doesn't implement PartialEq, but we can test that + // the same error construction produces the same display output + let error1 = WorkerError::WorkerNotFound { + url: "http://test".to_string(), + }; + let error2 = WorkerError::WorkerNotFound { + url: "http://test".to_string(), + }; + assert_eq!(error1.to_string(), error2.to_string()); + } +} diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index ae88bdd1cc99..1aa6766c1886 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -452,3 +452,613 @@ pub fn start_health_checker( HealthChecker { handle, shutdown } } + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::RwLock; + use std::time::Duration; + use tokio::time::timeout; + + // Test WorkerType + #[test] + fn test_worker_type_display() { + assert_eq!(WorkerType::Regular.to_string(), "Regular"); + assert_eq!( + WorkerType::Prefill { + bootstrap_port: Some(8080) + } + .to_string(), + "Prefill(bootstrap:8080)" + ); + assert_eq!( + WorkerType::Prefill { + bootstrap_port: None + } + .to_string(), + "Prefill" + ); + assert_eq!(WorkerType::Decode.to_string(), "Decode"); + } + + #[test] + fn test_worker_type_equality() { + assert_eq!(WorkerType::Regular, WorkerType::Regular); + assert_ne!(WorkerType::Regular, WorkerType::Decode); + assert_eq!( + WorkerType::Prefill { + bootstrap_port: Some(8080) + }, + WorkerType::Prefill { + bootstrap_port: Some(8080) + } + ); + assert_ne!( + WorkerType::Prefill { + bootstrap_port: Some(8080) + }, + WorkerType::Prefill { + bootstrap_port: Some(8081) + } + ); + } + + #[test] + fn test_worker_type_clone() { + let 
original = WorkerType::Prefill { + bootstrap_port: Some(8080), + }; + let cloned = original.clone(); + assert_eq!(original, cloned); + } + + // Test HealthConfig + #[test] + fn test_health_config_default() { + let config = HealthConfig::default(); + assert_eq!(config.timeout_secs, 5); + assert_eq!(config.check_interval_secs, 30); + assert_eq!(config.endpoint, "/health"); + } + + #[test] + fn test_health_config_custom() { + let config = HealthConfig { + timeout_secs: 10, + check_interval_secs: 60, + endpoint: "/healthz".to_string(), + }; + assert_eq!(config.timeout_secs, 10); + assert_eq!(config.check_interval_secs, 60); + assert_eq!(config.endpoint, "/healthz"); + } + + // Test BasicWorker + #[test] + fn test_basic_worker_creation() { + let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + assert_eq!(worker.url(), "http://test:8080"); + assert_eq!(worker.worker_type(), WorkerType::Regular); + assert!(worker.is_healthy()); + assert_eq!(worker.load(), 0); + assert_eq!(worker.processed_requests(), 0); + } + + #[test] + fn test_worker_with_labels() { + let mut labels = std::collections::HashMap::new(); + labels.insert("env".to_string(), "prod".to_string()); + labels.insert("zone".to_string(), "us-west".to_string()); + + let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular) + .with_labels(labels.clone()); + + assert_eq!(worker.metadata().labels, labels); + } + + #[test] + fn test_worker_with_health_config() { + let custom_config = HealthConfig { + timeout_secs: 15, + check_interval_secs: 45, + endpoint: "/custom-health".to_string(), + }; + + let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular) + .with_health_config(custom_config.clone()); + + assert_eq!(worker.metadata().health_config.timeout_secs, 15); + assert_eq!(worker.metadata().health_config.check_interval_secs, 45); + assert_eq!(worker.metadata().health_config.endpoint, "/custom-health"); + } + + // Test Worker trait 
implementation + #[test] + fn test_worker_url() { + let worker = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular); + assert_eq!(worker.url(), "http://worker1:8080"); + } + + #[test] + fn test_worker_type_getter() { + let regular = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + assert_eq!(regular.worker_type(), WorkerType::Regular); + + let prefill = BasicWorker::new( + "http://test:8080".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(9090), + }, + ); + assert_eq!( + prefill.worker_type(), + WorkerType::Prefill { + bootstrap_port: Some(9090) + } + ); + + let decode = BasicWorker::new("http://test:8080".to_string(), WorkerType::Decode); + assert_eq!(decode.worker_type(), WorkerType::Decode); + } + + #[test] + fn test_health_status() { + let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + + // Initial state is healthy + assert!(worker.is_healthy()); + + // Set unhealthy + worker.set_healthy(false); + assert!(!worker.is_healthy()); + + // Set healthy again + worker.set_healthy(true); + assert!(worker.is_healthy()); + } + + #[test] + fn test_load_counter_operations() { + let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + + // Initial load is 0 + assert_eq!(worker.load(), 0); + + // Increment once + worker.increment_load(); + assert_eq!(worker.load(), 1); + + // Increment twice more + worker.increment_load(); + worker.increment_load(); + assert_eq!(worker.load(), 3); + + // Decrement once + worker.decrement_load(); + assert_eq!(worker.load(), 2); + + // Decrement to 0 + worker.decrement_load(); + worker.decrement_load(); + assert_eq!(worker.load(), 0); + + // Decrement below 0 should stay at 0 + worker.decrement_load(); + assert_eq!(worker.load(), 0); + } + + #[test] + fn test_processed_counter() { + let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + + // Initial count is 0 + 
assert_eq!(worker.processed_requests(), 0); + + // Increment multiple times + for i in 1..=100 { + worker.increment_processed(); + assert_eq!(worker.processed_requests(), i); + } + } + + #[test] + fn test_clone_worker() { + let original = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + original.increment_load(); + original.increment_processed(); + original.set_healthy(false); + + let cloned = original.clone_worker(); + + // Verify cloned worker has same URL and type + assert_eq!(cloned.url(), original.url()); + assert_eq!(cloned.worker_type(), original.worker_type()); + + // Load counters should be independent (cloned shares the Arc) + assert_eq!(cloned.load(), original.load()); + + // Modify original and verify clone is affected (shared state) + original.increment_load(); + assert_eq!(cloned.load(), original.load()); + } + + // Test concurrent operations + #[tokio::test] + async fn test_concurrent_load_increments() { + let worker = Arc::new(BasicWorker::new( + "http://test:8080".to_string(), + WorkerType::Regular, + )); + + let mut handles = vec![]; + + // Spawn 100 tasks incrementing load + for _ in 0..100 { + let worker_clone = Arc::clone(&worker); + let handle = tokio::spawn(async move { + worker_clone.increment_load(); + }); + handles.push(handle); + } + + // Wait for all tasks + for handle in handles { + handle.await.unwrap(); + } + + // Final count should be 100 + assert_eq!(worker.load(), 100); + } + + #[tokio::test] + async fn test_concurrent_load_decrements() { + let worker = Arc::new(BasicWorker::new( + "http://test:8080".to_string(), + WorkerType::Regular, + )); + + // Set initial load to 100 + for _ in 0..100 { + worker.increment_load(); + } + assert_eq!(worker.load(), 100); + + let mut handles = vec![]; + + // Spawn 100 tasks decrementing load + for _ in 0..100 { + let worker_clone = Arc::clone(&worker); + let handle = tokio::spawn(async move { + worker_clone.decrement_load(); + }); + handles.push(handle); + } + + // Wait for 
all tasks + for handle in handles { + handle.await.unwrap(); + } + + // Final count should be 0 + assert_eq!(worker.load(), 0); + } + + #[tokio::test] + async fn test_concurrent_health_updates() { + let worker = Arc::new(BasicWorker::new( + "http://test:8080".to_string(), + WorkerType::Regular, + )); + + let mut handles = vec![]; + + // Spawn threads randomly setting health status + for i in 0..100 { + let worker_clone = Arc::clone(&worker); + let handle = tokio::spawn(async move { + worker_clone.set_healthy(i % 2 == 0); + tokio::time::sleep(Duration::from_micros(10)).await; + }); + handles.push(handle); + } + + // Wait for all tasks + for handle in handles { + handle.await.unwrap(); + } + + // Final state should be deterministic (last write wins) + // We can't predict the exact final state due to scheduling, + // but we can verify no data corruption occurred + let final_health = worker.is_healthy(); + assert!(final_health == true || final_health == false); + } + + // Test WorkerFactory + #[test] + fn test_create_regular_worker() { + let worker = WorkerFactory::create_regular("http://regular:8080".to_string()); + assert_eq!(worker.url(), "http://regular:8080"); + assert_eq!(worker.worker_type(), WorkerType::Regular); + } + + #[test] + fn test_create_prefill_worker() { + // With bootstrap port + let worker1 = WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9090)); + assert_eq!(worker1.url(), "http://prefill:8080"); + assert_eq!( + worker1.worker_type(), + WorkerType::Prefill { + bootstrap_port: Some(9090) + } + ); + + // Without bootstrap port + let worker2 = WorkerFactory::create_prefill("http://prefill:8080".to_string(), None); + assert_eq!( + worker2.worker_type(), + WorkerType::Prefill { + bootstrap_port: None + } + ); + } + + #[test] + fn test_create_decode_worker() { + let worker = WorkerFactory::create_decode("http://decode:8080".to_string()); + assert_eq!(worker.url(), "http://decode:8080"); + assert_eq!(worker.worker_type(), 
WorkerType::Decode); + } + + #[test] + fn test_create_from_urls() { + let regular_urls = vec![ + "http://regular1:8080".to_string(), + "http://regular2:8080".to_string(), + ]; + let prefill_urls = vec![ + ("http://prefill1:8080".to_string(), Some(9090)), + ("http://prefill2:8080".to_string(), None), + ]; + let decode_urls = vec![ + "http://decode1:8080".to_string(), + "http://decode2:8080".to_string(), + ]; + + let (regular, prefill, decode) = + WorkerFactory::create_from_urls(regular_urls, prefill_urls, decode_urls); + + assert_eq!(regular.len(), 2); + assert_eq!(prefill.len(), 2); + assert_eq!(decode.len(), 2); + + assert_eq!(regular[0].url(), "http://regular1:8080"); + assert_eq!(prefill[0].url(), "http://prefill1:8080"); + assert_eq!(decode[0].url(), "http://decode1:8080"); + } + + // Test WorkerCollection trait + #[test] + fn test_healthy_workers_filter() { + let workers: Vec> = vec![ + WorkerFactory::create_regular("http://w1:8080".to_string()), + WorkerFactory::create_regular("http://w2:8080".to_string()), + WorkerFactory::create_regular("http://w3:8080".to_string()), + ]; + + // Set some workers unhealthy + workers[0].set_healthy(false); + workers[2].set_healthy(false); + + let healthy = workers.healthy_workers(); + assert_eq!(healthy.len(), 1); + assert_eq!(healthy[0].url(), "http://w2:8080"); + } + + #[test] + fn test_total_load_calculation() { + let workers: Vec> = vec![ + WorkerFactory::create_regular("http://w1:8080".to_string()), + WorkerFactory::create_regular("http://w2:8080".to_string()), + WorkerFactory::create_regular("http://w3:8080".to_string()), + ]; + + // Set different loads + workers[0].increment_load(); + workers[0].increment_load(); // load = 2 + + workers[1].increment_load(); + workers[1].increment_load(); + workers[1].increment_load(); // load = 3 + + workers[2].increment_load(); // load = 1 + + assert_eq!(workers.total_load(), 6); + } + + #[test] + fn test_find_worker() { + let workers: Vec> = vec![ + 
WorkerFactory::create_regular("http://w1:8080".to_string()), + WorkerFactory::create_regular("http://w2:8080".to_string()), + WorkerFactory::create_regular("http://w3:8080".to_string()), + ]; + + // Found case + let found = workers.find_worker("http://w2:8080"); + assert!(found.is_some()); + assert_eq!(found.unwrap().url(), "http://w2:8080"); + + // Not found case + let not_found = workers.find_worker("http://w4:8080"); + assert!(not_found.is_none()); + } + + #[test] + fn test_find_worker_mut() { + let mut workers: Vec> = vec![ + WorkerFactory::create_regular("http://w1:8080".to_string()), + WorkerFactory::create_regular("http://w2:8080".to_string()), + ]; + + // Find and modify + if let Some(worker) = workers.find_worker_mut("http://w1:8080") { + worker.set_healthy(false); + } + + // Verify modification + assert!(!workers[0].is_healthy()); + assert!(workers[1].is_healthy()); + } + + // Test WorkerLoadGuard + #[test] + fn test_load_guard_single_worker() { + let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + assert_eq!(worker.load(), 0); + + { + let _guard = WorkerLoadGuard::new(&worker); + assert_eq!(worker.load(), 1); + } + + // Guard dropped, load decremented + assert_eq!(worker.load(), 0); + } + + #[test] + fn test_load_guard_multiple_workers() { + let workers: Vec> = vec![ + WorkerFactory::create_regular("http://w1:8080".to_string()), + WorkerFactory::create_regular("http://w2:8080".to_string()), + WorkerFactory::create_regular("http://w3:8080".to_string()), + ]; + + let worker_refs: Vec<&dyn Worker> = workers.iter().map(|w| w.as_ref()).collect(); + + { + let _guard = WorkerLoadGuard::new_multi(worker_refs); + // All loads incremented + assert_eq!(workers[0].load(), 1); + assert_eq!(workers[1].load(), 1); + assert_eq!(workers[2].load(), 1); + } + + // All loads decremented + assert_eq!(workers[0].load(), 0); + assert_eq!(workers[1].load(), 0); + assert_eq!(workers[2].load(), 0); + } + + #[test] + fn 
test_load_guard_panic_safety() { + let worker = Arc::new(BasicWorker::new( + "http://test:8080".to_string(), + WorkerType::Regular, + )); + assert_eq!(worker.load(), 0); + + // Clone for use inside catch_unwind + let worker_clone = Arc::clone(&worker); + + // This will panic, but the guard should still clean up + let result = std::panic::catch_unwind(|| { + let _guard = WorkerLoadGuard::new(worker_clone.as_ref()); + assert_eq!(worker_clone.load(), 1); + panic!("Test panic"); + }); + + // Verify panic occurred + assert!(result.is_err()); + + // Load should be decremented even after panic + assert_eq!(worker.load(), 0); + } + + // Test helper functions + #[test] + fn test_urls_to_workers() { + let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()]; + + let workers = urls_to_workers(urls); + assert_eq!(workers.len(), 2); + assert_eq!(workers[0].url(), "http://w1:8080"); + assert_eq!(workers[1].url(), "http://w2:8080"); + assert_eq!(workers[0].worker_type(), WorkerType::Regular); + } + + #[test] + fn test_workers_to_urls() { + let workers: Vec> = vec![ + WorkerFactory::create_regular("http://w1:8080".to_string()), + WorkerFactory::create_regular("http://w2:8080".to_string()), + ]; + + let urls = workers_to_urls(&workers); + assert_eq!(urls, vec!["http://w1:8080", "http://w2:8080"]); + } + + // Test synchronous health check wrapper + #[test] + fn test_check_health_sync_wrapper() { + // We can't easily test the actual HTTP call without mocking, + // but we can verify the sync wrapper works + let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + + // This will fail because there's no server at this URL, + // but it tests that the sync wrapper doesn't panic + let result = worker.check_health(); + assert!(result.is_err()); + } + + // Test HealthChecker background task + #[tokio::test] + async fn test_health_checker_startup() { + let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular( + 
"http://w1:8080".to_string(), + )])); + + let checker = start_health_checker(workers.clone(), 60); + + // Verify it starts without panic + tokio::time::sleep(Duration::from_millis(100)).await; + + // Shutdown + checker.shutdown().await; + } + + #[tokio::test] + async fn test_health_checker_shutdown() { + let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular( + "http://w1:8080".to_string(), + )])); + + let checker = start_health_checker(workers.clone(), 60); + + // Shutdown should complete quickly + let shutdown_result = timeout(Duration::from_secs(1), checker.shutdown()).await; + assert!(shutdown_result.is_ok()); + } + + // Performance test for load counter + #[test] + fn test_load_counter_performance() { + use std::time::Instant; + + let worker = BasicWorker::new("http://test:8080".to_string(), WorkerType::Regular); + let iterations = 1_000_000; + + let start = Instant::now(); + for _ in 0..iterations { + worker.increment_load(); + } + let duration = start.elapsed(); + + let ops_per_sec = iterations as f64 / duration.as_secs_f64(); + println!("Load counter operations per second: {:.0}", ops_per_sec); + + // Should be well over 1M ops/sec + assert!(ops_per_sec > 1_000_000.0); + } +} From 60468da4e2d7bda65ee3ad04857d7e29db9396af Mon Sep 17 00:00:00 2001 From: Garry Fang Date: Sun, 20 Jul 2025 05:41:27 +0800 Subject: [PATCH 056/396] bugfix: fix sglang crash in NVIDIA MIG container (#8167) Signed-off-by: Garrybest --- python/sglang/srt/utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index dc6e72d75dcd..7123722eb80a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1422,6 +1422,13 @@ def get_nvgpu_memory_capacity(): ] if not memory_values: + # Fallback to torch.cuda.mem_get_info() when failed to get memory capacity from nvidia-smi, + # typically in NVIDIA MIG mode. 
+ if torch.cuda.is_available(): + logger.warning( + "Failed to get GPU memory capacity from nvidia-smi, falling back to torch.cuda.mem_get_info()." + ) + return torch.cuda.mem_get_info()[1] // 1024 // 1024 # unit: MB raise ValueError("No GPU memory values found.") # Return the minimum memory value From 4e3defe5a77e14d70ad4ebfb3115ce507789f6e9 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Sat, 19 Jul 2025 15:38:09 -0700 Subject: [PATCH 057/396] Support start up LoRA server without initial adapters (#8019) --- docs/backend/lora.ipynb | 161 ++++++++---------- docs/backend/server_arguments.md | 3 +- python/sglang/srt/lora/lora_manager.py | 6 +- .../sglang/srt/managers/tokenizer_manager.py | 10 +- .../srt/model_executor/cuda_graph_runner.py | 11 +- .../srt/model_executor/forward_batch_info.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 8 +- python/sglang/srt/server_args.py | 74 +++++--- python/sglang/srt/utils.py | 14 ++ python/sglang/test/runners.py | 2 + test/srt/models/lora/test_lora_update.py | 82 ++++++++- test/srt/run_suite.py | 2 +- 12 files changed, 235 insertions(+), 140 deletions(-) diff --git a/docs/backend/lora.ipynb b/docs/backend/lora.ipynb index 6c089b654fd5..8626d3e71a68 100644 --- a/docs/backend/lora.ipynb +++ b/docs/backend/lora.ipynb @@ -27,6 +27,8 @@ "source": [ "The following server arguments are relevant for multi-LoRA serving:\n", "\n", + "* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n", + "\n", "* `lora_paths`: A mapping from each adaptor's name to its path, in the form of `{name}={path} {name}={path}`.\n", "\n", "* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. 
Defaults to be 8.\n", @@ -35,7 +37,7 @@ "\n", "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n", "\n", - "* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n", + "* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters.\n", "\n", "* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. 
More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n", "\n", @@ -79,6 +81,7 @@ "server_process, port = launch_server_cmd(\n", " \"\"\"\n", "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", + " --enable-lora \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " --max-loras-per-batch 1 --lora-backend triton \\\n", " --disable-radix-cache\n", @@ -98,7 +101,7 @@ "json_data = {\n", " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", - " \"AI is a field of computer science focused on\",\n", + " \"List 3 countries and their capitals.\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " # The first input uses lora0, and the second input uses the base model\n", @@ -137,6 +140,7 @@ "server_process, port = launch_server_cmd(\n", " \"\"\"\n", "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", + " --enable-lora \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n", " --max-loras-per-batch 2 --lora-backend triton \\\n", @@ -157,7 +161,7 @@ "json_data = {\n", " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", - " \"AI is a field of computer science focused on\",\n", + " \"List 3 countries and their capitals.\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " # The first input uses lora0, and the second input uses lora1\n", @@ -191,11 +195,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Basic Usage\n", - "\n", "Instead of specifying all adapters during server startup via `--lora-paths`. 
You can also load & unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` API.\n", "\n", - "(Please note that, currently we still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature, this limitation will be lifted soon.)" + "When using dynamic LoRA loading, it's recommended to explicitly specify both `--max-lora-rank` and `--lora-target-modules` at startup. For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. However, in that case, you would have to ensure that all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"." ] }, { @@ -204,13 +206,22 @@ "metadata": {}, "outputs": [], "source": [ + "lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n", + "lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", + "lora0_new = \"philschmid/code-llama-3-1-8b-text-to-sql-lora\" # rank - 256, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", + "\n", + "\n", + "# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n", + "# We are adding it here just to demonstrate usage.\n", "server_process, port = launch_server_cmd(\n", " \"\"\"\n", " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", - " --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n", + " --enable-lora \\\n", " --cuda-graph-max-bs 2 \\\n", " --max-loras-per-batch 2 --lora-backend triton \\\n", " --disable-radix-cache\n", + " --max-lora-rank 256\n", + " --lora-target-modules all\n", " \"\"\"\n", ")\n", "\n", @@ -218,6 
+229,13 @@ "wait_for_server(url)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load adapter lora0" + ] + }, { "cell_type": "code", "execution_count": null, @@ -227,8 +245,8 @@ "response = requests.post(\n", " url + \"/load_lora_adapter\",\n", " json={\n", - " \"lora_name\": \"lora1\",\n", - " \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n", + " \"lora_name\": \"lora0\",\n", + " \"lora_path\": lora0,\n", " },\n", ")\n", "\n", @@ -239,38 +257,10 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "response = requests.post(\n", - " url + \"/generate\",\n", - " json={\n", - " \"text\": [\n", - " \"List 3 countries and their capitals.\",\n", - " \"List 3 countries and their capitals.\",\n", - " ],\n", - " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", - " \"lora_path\": [\"lora0\", \"lora1\"],\n", - " },\n", - ")\n", - "print(f\"Output from lora0: {response.json()[0]['text']}\")\n", - "print(f\"Output from lora1: {response.json()[1]['text']}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "response = requests.post(\n", - " url + \"/unload_lora_adapter\",\n", - " json={\n", - " \"lora_name\": \"lora0\",\n", - " },\n", - ")" + "Load adapter lora1:" ] }, { @@ -282,8 +272,8 @@ "response = requests.post(\n", " url + \"/load_lora_adapter\",\n", " json={\n", - " \"lora_name\": \"lora2\",\n", - " \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n", + " \"lora_name\": \"lora1\",\n", + " \"lora_path\": lora1,\n", " },\n", ")\n", "\n", @@ -294,24 +284,10 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "response = requests.post(\n", - " url + \"/generate\",\n", - " json={\n", - " \"text\": [\n", - " \"List 3 countries and their capitals.\",\n", - " \"List 3 
countries and their capitals.\",\n", - " ],\n", - " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", - " \"lora_path\": [\"lora1\", \"lora2\"],\n", - " },\n", - ")\n", - "print(f\"Output from lora1: {response.json()[0]['text']}\")\n", - "print(f\"Output from lora2: {response.json()[1]['text']}\")" + "Check inference output:" ] }, { @@ -320,18 +296,29 @@ "metadata": {}, "outputs": [], "source": [ - "terminate_process(server_process)" + "url = f\"http://127.0.0.1:{port}\"\n", + "json_data = {\n", + " \"text\": [\n", + " \"List 3 countries and their capitals.\",\n", + " \"List 3 countries and their capitals.\",\n", + " ],\n", + " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", + " # The first input uses lora0, and the second input uses lora1\n", + " \"lora_path\": [\"lora0\", \"lora1\"],\n", + "}\n", + "response = requests.post(\n", + " url + \"/generate\",\n", + " json=json_data,\n", + ")\n", + "print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n", + "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Advanced: hosting adapters of different shapes\n", - "\n", - "In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n", - "\n", - "For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"." 
+ "Unload lora0 and replace it with a different adapter:" ] }, { @@ -340,39 +327,18 @@ "metadata": {}, "outputs": [], "source": [ - "lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n", - "lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", - "\n", - "\n", - "# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n", - "# We are adding it here just to demonstrate usage.\n", - "server_process, port = launch_server_cmd(\n", - " f\"\"\"\n", - " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", - " --lora-paths lora0={lora0} \\\n", - " --cuda-graph-max-bs 2 \\\n", - " --max-loras-per-batch 2 --lora-backend triton \\\n", - " --disable-radix-cache\n", - " --max-lora-rank 64\n", - " --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n", - " \"\"\"\n", + "response = requests.post(\n", + " url + \"/unload_lora_adapter\",\n", + " json={\n", + " \"lora_name\": \"lora0\",\n", + " },\n", ")\n", "\n", - "url = f\"http://127.0.0.1:{port}\"\n", - "wait_for_server(url)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ "response = requests.post(\n", " url + \"/load_lora_adapter\",\n", " json={\n", - " \"lora_name\": \"lora1\",\n", - " \"lora_path\": lora1,\n", + " \"lora_name\": \"lora0\",\n", + " \"lora_path\": lora0_new,\n", " },\n", ")\n", "\n", @@ -382,6 +348,13 @@ " print(\"Failed to load LoRA adapter.\", response.json())" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check output again:" + ] + }, { "cell_type": "code", "execution_count": null, @@ -392,7 +365,7 @@ "json_data = {\n", " \"text\": [\n", " \"List 3 countries and their 
capitals.\",\n", - " \"AI is a field of computer science focused on\",\n", + " \"List 3 countries and their capitals.\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " # The first input uses lora0, and the second input uses lora1\n", @@ -402,8 +375,8 @@ " url + \"/generate\",\n", " json=json_data,\n", ")\n", - "print(f\"Output from lora0: {response.json()[0]['text']}\")\n", - "print(f\"Output from lora1: {response.json()[1]['text']}\")" + "print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n", + "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")" ] }, { diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 6320a6e61aac..d7c5ff520dc9 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -176,8 +176,9 @@ Please consult the documentation below and [server_args.py](https://github.com/s | Arguments | Description | Defaults | |-----------|-------------|----------| +| `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility. | False | | `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None | -| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None | +| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). 
If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None | | `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None | | `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 | | `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton | diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 96102d1efd5c..85fd246163c1 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -186,9 +186,9 @@ def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig): ) if incompatible: raise ValueError( - f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration." - "We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, " - "You can specify expected configs via --max_lora_rank and --enable_lora_modules." + f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. " + "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are " + "included in `--enable_lora_modules`." 
) def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 7ba07f675120..631d23f17335 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -574,7 +574,7 @@ def _validate_one_request( "The server is not configured to enable custom logit processor. " "Please set `--enable-custom-logits-processor` to enable this feature." ) - if self.server_args.lora_paths and obj.lora_path: + if self.server_args.enable_lora and obj.lora_path: self._validate_lora_adapters(obj) def _validate_input_ids_in_vocab( @@ -1037,6 +1037,10 @@ async def load_lora_adapter( _: Optional[fastapi.Request] = None, ) -> LoadLoRAAdapterReqOutput: self.auto_create_handle_loop() + if not self.server_args.enable_lora: + raise ValueError( + "LoRA is not enabled. Please set `--enable-lora` to enable LoRA." + ) # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works # with dp_size > 1. @@ -1060,6 +1064,10 @@ async def unload_lora_adapter( _: Optional[fastapi.Request] = None, ) -> UnloadLoRAAdapterReqOutput: self.auto_create_handle_loop() + if not self.server_args.enable_lora: + raise ValueError( + "LoRA is not enabled. Please set `--enable-lora` to enable LoRA." + ) # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works # with dp_size > 1. 
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 1f654ca7ecff..520a631c5ecf 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -264,7 +264,7 @@ def __init__(self, model_runner: ModelRunner): if self.enable_torch_compile: set_torch_compile_config() - if self.model_runner.server_args.lora_paths is not None: + if self.model_runner.server_args.enable_lora: self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs) # Graph inputs @@ -510,11 +510,10 @@ def capture_one_batch_size(self, bs: int, forward: Callable): spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL ) - if self.model_runner.server_args.lora_paths is not None: - # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a - # different logic to handle lora, so we need to set `lora_paths` to a list of non-None - # values if lora is enabled. - lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs + if self.model_runner.server_args.enable_lora: + # It is safe to capture CUDA graph using empty LoRA path, as the LoRA kernels will always be launched whenever + # `--enable-lora` is set to True (and return immediately if the LoRA path is empty for perf optimization). 
+ lora_paths = [None] * bs else: lora_paths = None diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index fde60e0e5012..6f3ea547477f 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -418,7 +418,7 @@ def init_new( ret._compute_mrope_positions(model_runner, batch) # Init lora information - if model_runner.server_args.lora_paths is not None: + if model_runner.server_args.enable_lora: model_runner.lora_manager.prepare_lora_batch(ret) TboForwardBatchPreparer.prepare( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index bbd5b000067f..4f0b1d64ce8a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -304,11 +304,7 @@ def initialize(self, min_per_gpu_memory: float): self.apply_torch_tp() # Init lora - # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add - # a new server arg `enable_lora` to control whether to init LoRA manager to be more - # explicit, as it is perfectly valid to start a server with an empty lora_paths and - # load LoRA adapters dynamically later. - if server_args.lora_paths is not None: + if server_args.enable_lora: self.init_lora_manager() # Init memory pool and attention backends @@ -895,7 +891,7 @@ def init_lora_manager(self): max_lora_rank=self.server_args.max_lora_rank, target_modules=self.server_args.lora_target_modules, ) - result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths) + result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths or {}) if result.success: logger.info( f"LoRA manager ready. 
Loaded LoRA adapters: {', '.join(result.loaded_adapters)}" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 24292bcd79b8..6464f9f40a39 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -26,6 +26,8 @@ from sglang.srt.hf_transformers_utils import check_gguf_file, get_config from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.utils import ( + LORA_TARGET_ALL_MODULES, + SUPPORTED_LORA_TARGET_MODULES, configure_ipv6, get_device, get_device_memory_capacity, @@ -140,8 +142,9 @@ class ServerArgs: preferred_sampling_params: Optional[str] = None # LoRA + enable_lora: Optional[bool] = None max_lora_rank: Optional[int] = None - lora_target_modules: Optional[List[str]] = None + lora_target_modules: Optional[Union[set[str], List[str]]] = None lora_paths: Optional[Union[dict[str, str], List[str]]] = None max_loras_per_batch: int = 8 lora_backend: str = "triton" @@ -1148,6 +1151,12 @@ def add_cli_args(parser: argparse.ArgumentParser): ) # LoRA + parser.add_argument( + "--enable-lora", + default=ServerArgs.enable_lora, + action="store_true", + help="Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.", + ) parser.add_argument( "--max-lora-rank", default=ServerArgs.max_lora_rank, @@ -1157,18 +1166,12 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--lora-target-modules", type=str, - choices=[ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], + choices=SUPPORTED_LORA_TARGET_MODULES + [LORA_TARGET_ALL_MODULES], nargs="*", default=None, - help="The union set of all target modules where LoRA should be applied. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.", + help="The union set of all target modules where LoRA should be applied. 
If not specified, " + "it will be automatically inferred from the adapters provided in --lora-paths. If 'all' is specified, " + "all supported modules will be targeted.", ) parser.add_argument( "--lora-paths", @@ -1816,15 +1819,46 @@ def check_server_args(self): None, }, "moe_dense_tp_size only support 1 and None currently" - if isinstance(self.lora_paths, list): - lora_paths = self.lora_paths - self.lora_paths = {} - for lora_path in lora_paths: - if "=" in lora_path: - name, path = lora_path.split("=", 1) - self.lora_paths[name] = path - else: - self.lora_paths[lora_path] = lora_path + self.check_lora_server_args() + + def check_lora_server_args(self): + # Enable LoRA if any LoRA paths are provided for backward compatibility. + if self.lora_paths: + if self.enable_lora is None: + self.enable_lora = True + logger.info( + "--enable-lora is set to True because --lora-paths is provided." + ) + elif self.enable_lora is False: + logger.warning( + "--enable-lora is set to False, any provided lora_paths will be ignored." + ) + + if self.enable_lora: + # Normalize lora_paths to a dictionary if it is a list. + if isinstance(self.lora_paths, list): + lora_paths = self.lora_paths + self.lora_paths = {} + for lora_path in lora_paths: + if "=" in lora_path: + name, path = lora_path.split("=", 1) + self.lora_paths[name] = path + else: + self.lora_paths[lora_path] = lora_path + + # Expand target modules + if self.lora_target_modules: + self.lora_target_modules = set(self.lora_target_modules) + if "all" in self.lora_target_modules: + assert ( + len(self.lora_target_modules) == 1 + ), "If 'all' is specified in --lora-target-modules, it should be the only module specified." + self.lora_target_modules = set(SUPPORTED_LORA_TARGET_MODULES) + + # Ensure sufficient information is provided for LoRA initialization. 
+ assert self.lora_paths or ( + self.max_lora_rank and self.lora_target_modules + ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization." def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int): larger_tp = max(decode_tp, prefill_tp) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 7123722eb80a..23960a8c1123 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2892,3 +2892,17 @@ def placeholder(*args, **kwargs): return final_module, getattr(final_module, function_name) return final_module, None + + +# LoRA-related constants and utilities +SUPPORTED_LORA_TARGET_MODULES = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] + +LORA_TARGET_ALL_MODULES = "all" diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 941940fe0fd8..9ec71c29bac8 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -507,6 +507,7 @@ def __init__( sleep_on_idle=False, max_lora_rank: Optional[int] = None, lora_target_modules: Optional[List[str]] = None, + enable_lora: Optional[bool] = None, ): self.model_type = model_type self.is_generation = model_type == "generation" @@ -547,6 +548,7 @@ def __init__( sleep_on_idle=sleep_on_idle, max_lora_rank=max_lora_rank, lora_target_modules=lora_target_modules, + enable_lora=enable_lora, **spec_kwargs, ) diff --git a/test/srt/models/lora/test_lora_update.py b/test/srt/models/lora/test_lora_update.py index 785b44e953fd..83392b9247be 100644 --- a/test/srt/models/lora/test_lora_update.py +++ b/test/srt/models/lora/test_lora_update.py @@ -64,8 +64,9 @@ class TestCase: base: str max_loras_per_batch: int all_adapters: List[str] - initial_adapters: List[str] op_sequence: List[Operation] + initial_adapters: Optional[List[str]] = None + enable_lora: Optional[bool] = None max_lora_rank: Optional[int] = None lora_target_modules: 
Optional[List] = None max_new_tokens: int = 32 @@ -171,6 +172,64 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: ), ], ), + TestCase( + description="dynamic lora update without initial lora_paths", + base="meta-llama/Llama-3.1-8B-Instruct", + enable_lora=True, + max_lora_rank=256, + lora_target_modules=["all"], + max_loras_per_batch=4, + all_adapters=[ + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + ], + op_sequence=[ + Operation( + type=OperationType.LOAD, + data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + ), + Operation( + type=OperationType.LOAD, + data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + ), + Operation( + type=OperationType.LOAD, + data="pbevan11/llama-3.1-8b-ocr-correction", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + None, + ] + ), + ), + Operation( + type=OperationType.UNLOAD, + data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + expected_error="not loaded", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + None, + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + None, + ] + ), + ), + ], + ), TestCase( description="dynamic lora update with evictions", base="meta-llama/Llama-3.1-8B-Instruct", @@ -371,7 +430,7 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: Operation( type=OperationType.LOAD, data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", - expected_error="updating LoRA shapes", + expected_error="incompatible", ), Operation( type=OperationType.FORWARD, @@ -431,7 +490,7 @@ def 
create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: Operation( type=OperationType.LOAD, data="philschmid/code-llama-3-1-8b-text-to-sql-lora", - expected_error="updating LoRA shapes", + expected_error="incompatible", ), Operation( type=OperationType.FORWARD, @@ -470,7 +529,7 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: Operation( type=OperationType.LOAD, data="philschmid/code-llama-3-1-8b-text-to-sql-lora", - expected_error="updating LoRA shapes", + expected_error="incompatible", ), Operation( type=OperationType.FORWARD, @@ -521,6 +580,7 @@ def __init__( lora_paths: list[str], max_loras_per_batch: int, max_lora_rank: Optional[int], + enable_lora: Optional[bool] = None, lora_target_modules: Optional[List[str]] = None, lora_backend: str = "triton", disable_cuda_graph: bool = False, @@ -535,8 +595,9 @@ def __init__( self.lora_backend = lora_backend self.disable_cuda_graph = disable_cuda_graph self.cuda_graph_max_bs = cuda_graph_max_bs + self.enable_lora = enable_lora - self.expected_adapters = set(lora_paths) + self.expected_adapters = set(lora_paths or []) self.handle = None # Will be set in __enter__ def __enter__(self): @@ -596,6 +657,7 @@ def __enter__(self): disable_cuda_graph=self.disable_cuda_graph, cuda_graph_max_bs=self.cuda_graph_max_bs, disable_radix_cache=True, + enable_lora=self.enable_lora, ) self.handle.__enter__() return self @@ -690,8 +752,6 @@ def __enter__(self): other_args = [ "--cuda-graph-max-bs", str(self.cuda_graph_max_bs), - "--lora-paths", - *self.lora_paths, "--max-loras-per-batch", str(self.max_loras_per_batch), "--lora-backend", @@ -704,6 +764,10 @@ def __enter__(self): "--mem-fraction-static", str(MEM_FRACTION_STATIC), ] + if self.enable_lora: + other_args.append("--enable-lora") + if self.lora_paths: + other_args.extend(["--lora-paths"] + self.lora_paths) if self.disable_cuda_graph: other_args.append("--disable-cuda-graph") if self.max_lora_rank is not None: @@ -836,6 +900,7 @@ def 
_run_operation_sequence( initial_adapters: List[str], max_loras_per_batch: int, op_sequence: List[Operation], + enable_lora: Optional[bool] = None, max_lora_rank: Optional[int] = None, lora_target_modules: Optional[List[str]] = None, max_new_tokens: int = 32, @@ -854,6 +919,7 @@ def _run_operation_sequence( max_loras_per_batch=max_loras_per_batch, max_lora_rank=max_lora_rank, lora_target_modules=lora_target_modules, + enable_lora=enable_lora, ) as session: for op in op_sequence: op_type = op.type @@ -903,6 +969,7 @@ def _run_dynamic_adapter_updates( dynamic_output = self._run_operation_sequence( mode=mode, initial_adapters=test_case.initial_adapters, + enable_lora=test_case.enable_lora, base=test_case.base, max_loras_per_batch=test_case.max_loras_per_batch, op_sequence=test_case.op_sequence, @@ -923,6 +990,7 @@ def _run_dynamic_adapter_updates( static_output = self._run_operation_sequence( mode=mode, initial_adapters=test_case.all_adapters, + enable_lora=test_case.enable_lora, base=test_case.base, max_loras_per_batch=test_case.max_loras_per_batch, op_sequence=forward_ops, diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index f59aed623e0f..d7b4739e38cb 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -18,7 +18,7 @@ class TestFile: TestFile("models/lora/test_lora_backend.py", 99), TestFile("models/lora/test_multi_lora_backend.py", 60), TestFile("models/lora/test_lora_cuda_graph.py", 250), - TestFile("models/lora/test_lora_update.py", 700), + TestFile("models/lora/test_lora_update.py", 800), TestFile("models/test_embedding_models.py", 73), # TestFile("models/test_clip_models.py", 52), TestFile("models/test_encoder_embedding_models.py", 100), From 8cddfa56a14e9ac03677bfc9e8df2f59b5bce382 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sat, 19 Jul 2025 15:56:50 -0700 Subject: [PATCH 058/396] Clean warning logs for gate_proj loading in Lora (#8172) --- python/sglang/srt/lora/lora.py | 4 ---- 1 file changed, 4 deletions(-) diff --git 
a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 2a3d2acfdff5..7bc6af532e8c 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -186,10 +186,6 @@ def normalize_gate_up_proj( up_name = weight_name.replace("gate_proj", "up_proj") gate_up_name = weight_name.replace("gate_proj", "gate_up_proj") if up_name not in weights: - logger.warning( - f"Gate projection {weight_name} does not have a corresponding up projection {up_name}. " - f"Initializing up projection to zero." - ) weights[up_name] = torch.zeros_like(weights[weight_name]) # FIXME: Add gate-only support for flashinfer in future implementations assert self.lora_backend.name == "triton", ( From abda2542d5cd465bbbfa5971139090df2dc02646 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sat, 19 Jul 2025 17:33:50 -0700 Subject: [PATCH 059/396] Fix tuning_fused_moe_triton.py (#8175) --- .../fused_moe_triton/tuning_fused_moe_triton.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py index 5af1b32be8f9..69b0563e9cbf 100644 --- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py @@ -18,6 +18,7 @@ get_default_config, get_moe_configs, ) +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.utils import is_hip _is_hip = is_hip() @@ -115,10 +116,15 @@ def benchmark_config( w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn) w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn) - input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + topk_output = select_experts(x, input_gating, topk, renormalize=True) def prepare(i: int): - 
input_gating.copy_(gating_output[i]) + input_gating = gating_output[i] + new_topk_output = select_experts(x, input_gating, topk, renormalize=True) + topk_output.topk_weights.copy_(new_topk_output.topk_weights) + topk_output.topk_ids.copy_(new_topk_output.topk_ids) + topk_output.router_logits.copy_(new_topk_output.router_logits) def run(): from sglang.srt.layers.moe.fused_moe_triton import override_config @@ -128,9 +134,7 @@ def run(): x, w1, w2, - input_gating, - topk, - renormalize=True, + topk_output, inplace=True, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, From 4540a4666a112a82dcf21505b781f3e31e50d178 Mon Sep 17 00:00:00 2001 From: ybyang <10629930+whybeyoung@users.noreply.github.com> Date: Sun, 20 Jul 2025 09:10:00 +0800 Subject: [PATCH 060/396] [Feature] Simple Improve Health Check Mechanism for Production-Grade Stability (#8115) Signed-off-by: ybyang --- python/sglang/srt/entrypoints/engine.py | 4 ++ python/sglang/srt/entrypoints/http_server.py | 57 ++++++++++++++++--- python/sglang/srt/managers/io_struct.py | 6 ++ python/sglang/srt/managers/scheduler.py | 3 + .../sglang/srt/managers/tokenizer_manager.py | 7 ++- python/sglang/srt/utils.py | 16 ++++++ 6 files changed, 82 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 990fac9a12a7..957d85aa5998 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -65,6 +65,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( MultiprocessingSerializer, + ServerStatus, assert_pkg_version, configure_logger, get_zmq_socket, @@ -73,6 +74,7 @@ launch_dummy_health_check_server, maybe_set_triton_cache_manager, prepare_model_and_tokenizer, + report_health, set_prometheus_multiproc_dir, set_ulimit, ) @@ -661,6 +663,7 @@ def _set_envs_and_config(server_args: ServerArgs): def sigchld_handler(signum, frame): pid, exitcode = os.waitpid(0, 
os.WNOHANG) if exitcode != 0: + report_health(ServerStatus.Crashed, server_args.host, server_args.port) logger.warning( f"Child process unexpectedly failed with {exitcode=}. {pid=}" ) @@ -674,6 +677,7 @@ def sigquit_handler(signum, frame): logger.error( "Received sigquit from a child process. It usually means the child failed." ) + report_health(ServerStatus.Crashed, server_args.host, server_args.port) kill_process_tree(os.getpid()) signal.signal(signal.SIGQUIT, sigquit_handler) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 43819e1a65e4..f880c4aa5cd4 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -77,6 +77,7 @@ ParseFunctionCallReq, ProfileReqInput, ReleaseMemoryOccupationReqInput, + ReportHealthInput, ResumeMemoryOccupationReqInput, SeparateReasoningReqInput, SetInternalStateReq, @@ -93,6 +94,7 @@ from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( + ServerStatus, add_api_key_middleware, add_prometheus_middleware, delete_directory, @@ -220,8 +222,31 @@ async def validate_json_request(raw_request: Request): @app.get("/health") async def health() -> Response: - """Check the health of the http server.""" - return Response(status_code=200) + """Check the status of the http server.""" + code = HTTPStatus.SERVICE_UNAVAILABLE.value + if _global_state.tokenizer_manager.server_status == ServerStatus.Up: + code = HTTPStatus.OK.value + return Response( + status_code=code, + content=json.dumps( + {"status": _global_state.tokenizer_manager.server_status.value} + ), + ) + + +@app.post("/health") +async def health_update(obj: ReportHealthInput, request: Request) -> Response: + """Update the Status of the http server.""" + try: + server_status = ServerStatus(obj.status) + _global_state.tokenizer_manager.server_status = server_status + if server_status != ServerStatus.Up: + 
return Response( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, content=obj.msg + ) + except Exception as e: + logger.error(e) + return Response(status_code=HTTPStatus.SERVICE_UNAVAILABLE.value) @app.get("/health_generate") @@ -256,7 +281,7 @@ async def gen(): if _global_state.tokenizer_manager.last_receive_tstamp > tic: task.cancel() _global_state.tokenizer_manager.rid_to_state.pop(rid, None) - _global_state.tokenizer_manager.health_check_failed = False + _global_state.tokenizer_manager.server_status = ServerStatus.Up return Response(status_code=200) task.cancel() @@ -270,7 +295,7 @@ async def gen(): f"last_heartbeat time: {last_receive_time}" ) _global_state.tokenizer_manager.rid_to_state.pop(rid, None) - _global_state.tokenizer_manager.health_check_failed = True + _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy return Response(status_code=503) @@ -1022,9 +1047,13 @@ def _execute_server_warmup( headers=headers, timeout=600, ) - assert res.status_code == 200, f"{res}" + if res.status_code == 200: + _global_state.tokenizer_manager.server_status = ServerStatus.Up + else: + _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy + logger.info(f"{res}") else: - logger.info(f"Start of prefill warmup ...") + logger.info(f"Start of prefill/decode warmup ...") json_data = { "sampling_params": { "temperature": 0.0, @@ -1046,15 +1075,25 @@ def _execute_server_warmup( headers=headers, timeout=1800, # because of deep gemm precache is very long if not precache. 
) - logger.info( - f"End of prefill warmup with status {res.status_code}, resp: {res.json()}" - ) + if res.status_code == 200: + logger.info( + f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}" + ) + _global_state.tokenizer_manager.server_status = ServerStatus.Up + else: + logger.info( + "Prefill disaggregation mode warm Up Failed, status code: {}".format( + res.status_code + ) + ) + _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy except Exception: last_traceback = get_exception_traceback() if pipe_finish_writer is not None: pipe_finish_writer.send(last_traceback) logger.error(f"Initialization failed. warmup error: {last_traceback}") + _global_state.tokenizer_manager.server_status = ServerStatus.Crashed kill_process_tree(os.getpid()) return False diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 8e1d1075aab6..b8332fdf656c 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1083,3 +1083,9 @@ class LoRAUpdateResult: LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult + + +@dataclass +class ReportHealthInput: + status: str + msg: Optional[str] = "" diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e6dd80d717ad..aee1596dbe9c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -143,6 +143,7 @@ from sglang.srt.utils import ( DeepEPMode, DynamicGradMode, + ServerStatus, broadcast_pyobj, configure_gc_logger, configure_logger, @@ -154,6 +155,7 @@ kill_itself_when_parent_died, point_to_point_pyobj, pyspy_dump_schedulers, + report_health, require_mlp_sync, require_mlp_tp_gather, set_gpu_proc_affinity, @@ -2964,4 +2966,5 @@ def run_scheduler_process( except Exception: traceback = get_exception_traceback() logger.error(f"Scheduler hit an exception: {traceback}") + 
report_health(ServerStatus.Crashed, server_args.host, ServerArgs.port) parent_process.send_signal(signal.SIGQUIT) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 631d23f17335..a0f66419e768 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -116,6 +116,7 @@ from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( + ServerStatus, dataclass_to_string_truncated, get_bool_env_var, get_zmq_socket, @@ -173,6 +174,9 @@ def __init__( server_args: ServerArgs, port_args: PortArgs, ): + # Server Status + self.server_status = ServerStatus.Starting + # Parse args self.server_args = server_args self.enable_metrics = server_args.enable_metrics @@ -251,7 +255,6 @@ def __init__( # Store states self.no_create_loop = False self.rid_to_state: Dict[str, ReqState] = {} - self.health_check_failed = False self.gracefully_exit = False self.last_receive_tstamp = 0 self.dump_requests_folder = "" # By default do not dump @@ -1332,7 +1335,7 @@ async def sigterm_watchdog(self): while True: remain_num_req = len(self.rid_to_state) - if self.health_check_failed: + if not self.server_status.is_healthy(): # if health check failed, we should exit immediately logger.error( "Signal SIGTERM received while health check failed. Exiting... 
remaining number of requests: %d", diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 23960a8c1123..03565a018c34 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -93,6 +93,22 @@ HIP_FP8_E4M3_FNUZ_MAX = 224.0 +class ServerStatus(Enum): + Up = "Up" + Starting = "Starting" + UnHealthy = "UnHealthy" + Crashed = "Crashed" + + def is_healthy(self) -> bool: + return self == ServerStatus.Up + + +def report_health(status: ServerStatus, host: str, http_port: int, msg: str = ""): + requests.post( + f"http://{host}:{http_port}/health", json={"status": status.value, "msg": msg} + ) + + # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip def is_hip() -> bool: return torch.version.hip is not None From 282eb59ff352e616eb311e6ac036f28d4a87ea13 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sat, 19 Jul 2025 18:49:37 -0700 Subject: [PATCH 061/396] Add bf16 output option for dsv3_router_gemm kernel (#7999) --- sgl-kernel/CMakeLists.txt | 4 +- .../benchmark/bench_dsv3_router_gemm.py | 53 +++- .../csrc/gemm/dsv3_router_gemm_bf16_out.cu | 234 ++++++++++++++++++ .../csrc/gemm/dsv3_router_gemm_entry.cu | 127 ++++++++++ ..._gemm.cu => dsv3_router_gemm_float_out.cu} | 131 +++------- sgl-kernel/python/sgl_kernel/gemm.py | 3 +- sgl-kernel/tests/test_dsv3_router_gemm.py | 17 +- 7 files changed, 465 insertions(+), 104 deletions(-) create mode 100644 sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu create mode 100644 sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu rename sgl-kernel/csrc/gemm/{dsv3_router_gemm.cu => dsv3_router_gemm_float_out.cu} (54%) diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 89a298c3469f..e8f9a0839658 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -222,7 +222,9 @@ set(SOURCES "csrc/gemm/awq_kernel.cu" "csrc/gemm/bmm_fp8.cu" "csrc/gemm/dsv3_fused_a_gemm.cu" - "csrc/gemm/dsv3_router_gemm.cu" + "csrc/gemm/dsv3_router_gemm_bf16_out.cu" + 
"csrc/gemm/dsv3_router_gemm_entry.cu" + "csrc/gemm/dsv3_router_gemm_float_out.cu" "csrc/gemm/fp8_blockwise_gemm_kernel.cu" "csrc/gemm/fp8_gemm_kernel.cu" "csrc/gemm/int8_gemm_kernel.cu" diff --git a/sgl-kernel/benchmark/bench_dsv3_router_gemm.py b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py index 16b3143f0623..4502746f9b39 100644 --- a/sgl-kernel/benchmark/bench_dsv3_router_gemm.py +++ b/sgl-kernel/benchmark/bench_dsv3_router_gemm.py @@ -7,6 +7,48 @@ from sgl_kernel import dsv3_router_gemm +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens"], + x_vals=[i + 1 for i in range(16)], + x_log=False, + line_arg="impl", + line_vals=["torch", "sgl-kernel"], + line_names=["torch", "dsv3_router_gemm"], + styles=[("blue", "-"), ("orange", "-")], + ylabel="TFLOPs", + plot_name="input-bf16-output-bf16 dsv3 router gemm throughput", + args={}, + ) +) +def benchmark_bf16_output(num_tokens, impl): + # M: num_tokens, K: hidden_dim, N: num_experts + M, K, N = num_tokens, 7168, 256 + + mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous() + mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous() + + quantiles = [0.5, 0.2, 0.8] + + if impl == "torch": + + def runner(): + F.linear(mat_a, mat_b) + + elif impl == "sgl-kernel": + + def runner(): + dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16) + + ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles) + + def tflops(t_ms): + flops = 2 * M * K * N + return flops / (t_ms * 1e-3) / 1e12 + + return tflops(ms), tflops(max_ms), tflops(min_ms) + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["num_tokens"], @@ -21,7 +63,7 @@ args={}, ) ) -def benchmark(num_tokens, impl): +def benchmark_float_output(num_tokens, impl): # M: num_tokens, K: hidden_dim, N: num_experts M, K, N = num_tokens, 7168, 256 @@ -38,7 +80,7 @@ def runner(): elif impl == "sgl-kernel": def runner(): - dsv3_router_gemm(mat_a, mat_b) + dsv3_router_gemm(mat_a, 
mat_b, out_dtype=torch.float32) ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles) @@ -53,4 +95,9 @@ def tflops(t_ms): parser = argparse.ArgumentParser() args = parser.parse_args() - benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm") + benchmark_bf16_output.run( + print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm" + ) + benchmark_float_output.run( + print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm" + ) diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu new file mode 100644 index 000000000000..ef011dfb0b54 --- /dev/null +++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm_bf16_out.cu @@ -0,0 +1,234 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp + * + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include "cuda_bf16.h" +#include "cuda_runtime.h" +#include "utils.h" + +// Custom FMA implementation using PTX assembly instructions +__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) { + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" + : "=l"(reinterpret_cast(d)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); +} + +// Convert 8 bfloat16 values from a uint4 to float array - optimized conversion +template +__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) { + __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast(&vec)); + +#pragma unroll + for (int i = 0; i < VPT; i++) { + dst[i] = __bfloat162float(bf16_ptr[i]); + } +} + +template +__global__ +__launch_bounds__(128, 1) void router_gemm_kernel_bf16_output(__nv_bfloat16* out, T const* mat_a, T const* mat_b) { + // Each block handles one expert column + int const n_idx = blockIdx.x; + int const tid = threadIdx.x; + constexpr int kWarpSize = 32; + constexpr int kNumWarps = kBlockSize / kWarpSize; + // Constants for this kernel + constexpr int k_elems_per_k_iteration = VPT * kBlockSize; + constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration; // Total K iterations + + // Initialize accumulators for all M rows + float acc[kNumTokens] = {}; + + // Shared memory for warp-level reduction + __shared__ float sm_reduction[kNumTokens][kNumWarps]; // kNumWarps + + // B matrix is in column-major order, so we can directly load a column for the n_idx expert + T const* b_col = mat_b + n_idx * kHiddenDim; + + // Pre-compute k_base values for each iteration to help compiler optimize + // int k_bases[k_iterations]; + int k_bases[k_iterations]; +#pragma unroll + for (int ki = 0; ki < k_iterations; ki++) { + k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT; + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); 
+#endif + + // Process the GEMM in chunks + for (int ki = 0; ki < k_iterations; ki++) { + int const k_base = k_bases[ki]; + + // Load B matrix values using vector load (8 bf16 values) + uint4 b_vec = *reinterpret_cast(b_col + k_base); + + // Convert B values to float + float b_float[VPT]; + bf16_uint4_to_float8(b_vec, b_float); + +// Process each token +#pragma unroll + for (int m_idx = 0; m_idx < kNumTokens; m_idx++) { + // Load both rows of A matrix using vector loads + uint4 a_vec = *reinterpret_cast(mat_a + (m_idx * kHiddenDim) + k_base); + + // Convert A values to float + float a_float[VPT]; + bf16_uint4_to_float8(a_vec, a_float); + +// Process elements in this chunk +#pragma unroll + for (int k = 0; k < VPT; k++) { + float a = a_float[k]; + float b = b_float[k]; + acc[m_idx] += a * b; + } + } + } + + // Perform warp-level reduction + int const warpSize = 32; + int const warpId = tid / warpSize; + int const laneId = tid % warpSize; + + // Register for warp-level reduction results + float warp_result[kNumTokens]; + +#pragma unroll + for (int m_idx = 0; m_idx < kNumTokens; m_idx++) { + warp_result[m_idx] = acc[m_idx]; + } + +// Perform warp-level reduction using optimized butterfly pattern +#pragma unroll + for (int m = 0; m < kNumTokens; m++) { + float sum = warp_result[m]; + + // Butterfly reduction pattern + sum += __shfl_xor_sync(0xffffffff, sum, 16); + sum += __shfl_xor_sync(0xffffffff, sum, 8); + sum += __shfl_xor_sync(0xffffffff, sum, 4); + sum += __shfl_xor_sync(0xffffffff, sum, 2); + sum += __shfl_xor_sync(0xffffffff, sum, 1); + + // Only the first thread in each warp stores to shared memory + if (laneId == 0) { + sm_reduction[m][warpId] = sum; + } + } + + __syncthreads(); + + // Final reduction across warps (only first thread) + if (tid == 0) { +#pragma unroll + for (int m = 0; m < kNumTokens; m++) { + float final_sum = 0.0f; + +// Sum across the kNumWarps +#pragma unroll + for (int w = 0; w < kNumWarps; w++) { + final_sum += sm_reduction[m][w]; + } + 
+ // Write final result + out[m * kNumExperts + n_idx] = __float2bfloat16(final_sum); + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const* mat_b, cudaStream_t stream) { + constexpr int VPT = 16 / sizeof(T); + constexpr int kBlockSize = 128; + cudaLaunchConfig_t config; + config.gridDim = kNumExperts; + config.blockDim = kBlockSize; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL(); + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx( + &config, + router_gemm_kernel_bf16_output, + output, + mat_a, + mat_b); +} + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 2, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 3, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 4, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 5, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 6, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 7, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 8, 
256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 9, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 10, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 11, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 12, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 13, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 14, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); + +template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 256, 7168>( + __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu new file mode 100644 index 000000000000..c316a8193ea4 --- /dev/null +++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm_entry.cu @@ -0,0 +1,127 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu + * https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp + * + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "cuda_bf16.h" +#include "cuda_runtime.h" +#include "utils.h" + +template +void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream); + +template +void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const* mat_b, cudaStream_t stream); + +template +struct LoopUnroller { + static void unroll_float_output( + int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) { + if (num_tokens == kBegin) { + invokeRouterGemmFloatOutput<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream); + } else { + LoopUnroller::unroll_float_output( + num_tokens, output, input, weights, stream); + } + } + + static void unroll_bf16_output( + int num_tokens, + __nv_bfloat16* output, + __nv_bfloat16 const* input, + __nv_bfloat16 const* weights, + cudaStream_t stream) { + if (num_tokens == kBegin) { + invokeRouterGemmBf16Output<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream); + } else { + LoopUnroller::unroll_bf16_output( + num_tokens, output, input, weights, stream); + } + } +}; + +template +struct LoopUnroller { + static void unroll_float_output( + int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) { + if (num_tokens == kEnd) { + invokeRouterGemmFloatOutput<__nv_bfloat16, 
kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream); + } else { + throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16"); + } + } + + static void unroll_bf16_output( + int num_tokens, + __nv_bfloat16* output, + __nv_bfloat16 const* input, + __nv_bfloat16 const* weights, + cudaStream_t stream) { + if (num_tokens == kEnd) { + invokeRouterGemmBf16Output<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream); + } else { + throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16"); + } + } +}; + +void dsv3_router_gemm( + torch::Tensor& output, // [num_tokens, num_experts] + const torch::Tensor& mat_a, // [num_tokens, hidden_dim] + const torch::Tensor& mat_b // [num_experts, hidden_dim] +) { + TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2); + + const int num_tokens = mat_a.size(0); + constexpr int num_experts = 256; + constexpr int hidden_dim = 7168; + + TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim"); + TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168"); + TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256"); + TORCH_CHECK( + num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be less than or equal to 16 for router_gemm"); + TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16"); + TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16"); + TORCH_CHECK( + output.dtype() == torch::kFloat32 || output.dtype() == torch::kBFloat16, "output must be float32 or bf16"); + + auto const sm = getSMVersion(); + TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90"); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (output.dtype() == torch::kFloat32) { + LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_float_output( + num_tokens, + reinterpret_cast(output.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16 
const*>(mat_a.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), + stream); + } else if (output.dtype() == torch::kBFloat16) { + LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_bf16_output( + num_tokens, + reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()), + reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), + stream); + } +} diff --git a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu b/sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu similarity index 54% rename from sgl-kernel/csrc/gemm/dsv3_router_gemm.cu rename to sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu index 410bbcefd3a6..e7577c55bc44 100644 --- a/sgl-kernel/csrc/gemm/dsv3_router_gemm.cu +++ b/sgl-kernel/csrc/gemm/dsv3_router_gemm_float_out.cu @@ -46,7 +46,7 @@ __device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* ds } template -__global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const* mat_a, T const* mat_b) { +__global__ __launch_bounds__(128, 1) void router_gemm_kernel_float_output(float* out, T const* mat_a, T const* mat_b) { // Each block handles one expert column int const n_idx = blockIdx.x; int const tid = threadIdx.x; @@ -163,7 +163,7 @@ __global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const } template -void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) { +void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) { constexpr int VPT = 16 / sizeof(T); constexpr int kBlockSize = 128; cudaLaunchConfig_t config; @@ -177,110 +177,57 @@ void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_ config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx( - &config, router_gemm_kernel, output, mat_a, mat_b); + &config, + router_gemm_kernel_float_output, + output, + mat_a, + mat_b); } -template void 
-invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, 
cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>( + 
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); -template void -invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); - -template -struct LoopUnroller { - static void - unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) { - if (num_tokens == kBegin) { - invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream); - } else { - LoopUnroller::unroll(num_tokens, output, input, weights, stream); - } - } -}; - -template -struct LoopUnroller { - static void - unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) { - if (num_tokens == kEnd) { - invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream); - } else { - throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16"); - } - } -}; - -void dsv3_router_gemm( - torch::Tensor& output, // [num_tokens, num_experts] - const torch::Tensor& mat_a, // [num_tokens, hidden_dim] - const torch::Tensor& mat_b // [num_experts, hidden_dim] -) { - TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2); - - const int num_tokens = mat_a.size(0); - constexpr int num_experts = 256; - constexpr int hidden_dim = 7168; - - TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim"); - TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168"); - TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256"); - TORCH_CHECK( - num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be less than or equal to 16 for router_gemm"); - TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16"); - TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16"); - TORCH_CHECK(output.dtype() == torch::kFloat32, "output must be float32"); - - auto 
const sm = getSMVersion(); - TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90"); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - LoopUnroller<1, 16, num_experts, hidden_dim>::unroll( - num_tokens, - reinterpret_cast(output.mutable_data_ptr()), - reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()), - reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), - stream); -} +template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>( + float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index 6ec4ce78ab32..7435cfdda1e4 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -262,12 +262,13 @@ def qserve_w4a8_per_group_gemm( def dsv3_router_gemm( hidden_states: torch.Tensor, router_weights: torch.Tensor, + out_dtype: torch.dtype = torch.bfloat16, ) -> torch.Tensor: output = torch.empty( hidden_states.shape[0], router_weights.shape[0], device=hidden_states.device, - dtype=torch.float32, + dtype=out_dtype, ) torch.ops.sgl_kernel.dsv3_router_gemm( output, diff --git a/sgl-kernel/tests/test_dsv3_router_gemm.py b/sgl-kernel/tests/test_dsv3_router_gemm.py index 1b60bcf920d5..169c996719d5 100644 --- a/sgl-kernel/tests/test_dsv3_router_gemm.py +++ b/sgl-kernel/tests/test_dsv3_router_gemm.py @@ -15,17 +15,20 @@ def test_dsv3_router_gemm(num_tokens): mat_b = torch.randn( (num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda" ).contiguous() - output = torch.empty( - (num_tokens, num_experts), dtype=torch.float32, device="cuda" - ).contiguous() - ref = F.linear(mat_a, mat_b).to(torch.float32) + bf16_ref = F.linear(mat_a, mat_b) + float_ref = bf16_ref.to(torch.float32) + + bf16_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16) + float_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32) - output = dsv3_router_gemm(mat_a, mat_b) + assert torch.allclose( + 
bf16_output, bf16_ref, rtol=1e-2, atol=1e-3 + ), "Router GEMM output in bf16 dtype mismatch with torch.nn.functional.linear reference" assert torch.allclose( - output, ref, rtol=1e-2, atol=1e-3 - ), "Router GEMM output mismatch with torch.nn.functional.linear reference" + float_output, float_ref, rtol=1e-2, atol=1e-3 + ), "Router GEMM output in float32 dtype mismatch with torch.nn.functional.linear reference" if __name__ == "__main__": From cbdfb77123e020aa6d45e423b283f9a3d96e4f96 Mon Sep 17 00:00:00 2001 From: Clay Date: Sun, 20 Jul 2025 10:30:16 +0800 Subject: [PATCH 062/396] Enable FlashInfer support encoder models and add head_dim padding workaround (#6230) --- .../srt/layers/attention/flashinfer_backend.py | 11 ++++++++++- .../srt/models/test_encoder_embedding_models.py | 17 +++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index f65e533d92fb..c7da38ac51cc 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -25,6 +25,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -486,12 +487,20 @@ def forward_extend( v_scale=layer.v_scale, ) else: + causal = True + if layer.attn_type == AttentionType.ENCODER_ONLY: + save_kv_cache = False + causal = False + if self.forward_metadata.extend_no_prefix: + # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions + # The FlashInfer 
head_dim limitation itself is tracked here: + # https://github.com/flashinfer-ai/flashinfer/issues/1048 o = self.prefill_wrapper_ragged.forward( q.view(-1, layer.tp_q_head_num, layer.head_dim), k.view(-1, layer.tp_k_head_num, layer.head_dim), v.view(-1, layer.tp_v_head_num, layer.head_dim), - causal=True, + causal=causal, sm_scale=layer.scaling, logits_soft_cap=logits_soft_cap, ) diff --git a/test/srt/models/test_encoder_embedding_models.py b/test/srt/models/test_encoder_embedding_models.py index bea5d4affe47..dafaa72db595 100644 --- a/test/srt/models/test_encoder_embedding_models.py +++ b/test/srt/models/test_encoder_embedding_models.py @@ -27,9 +27,9 @@ MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)] -ATTENTION_BACKEND = ["torch_native", "triton"] +ATTENTION_BACKEND = ["torch_native", "triton", "flashinfer"] BATCH_SIZE = [1, 2] -TORCH_DTYPES = [torch.float32] +TORCH_DTYPES = [torch.float32, torch.float16] sgl_to_st_ratio = [] @@ -126,6 +126,19 @@ def test_prefill_logits(self): for attention_backend in ATTENTION_BACKEND: for batch_size in BATCH_SIZE: for torch_dtype in TORCH_DTYPES: + # NOTE: FlashInfer currently has limitations with head_dim = 32 or + # other dimensions. 
+ # The FlashInfer head_dim limitation itself is tracked here: + # https://github.com/flashinfer-ai/flashinfer/issues/1048 + # + # Flashinfer does not support torch.float32 for dtype_q, so skip it + if attention_backend == "flashinfer": + if ( + model == "BAAI/bge-small-en" + or torch_dtype == torch.float32 + ): + continue + self.assert_close_prefill_logits( DEFAULT_PROMPTS, model, From 877e35d7754cd1fa60b3f1226929dbc84146ea70 Mon Sep 17 00:00:00 2001 From: Pavel Logachev Date: Sun, 20 Jul 2025 05:31:16 +0300 Subject: [PATCH 063/396] Add get_hidden_dim to qwen3.py for correct lora (#7312) --- python/sglang/srt/models/qwen3.py | 24 +++ python/sglang/test/runners.py | 7 +- test/srt/models/lora/test_lora.py | 1 - test/srt/models/lora/test_lora_qwen3.py | 209 ++++++++++++++++++++++++ test/srt/run_suite.py | 1 + 5 files changed, 240 insertions(+), 2 deletions(-) create mode 100644 test/srt/models/lora/test_lora_qwen3.py diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 6289e61e7a72..7d7c3bf7b19f 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -330,6 +330,30 @@ def __init__( def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) + def get_hidden_dim(self, module_name: str) -> Tuple[int]: + # return input_dim, output_dim + if module_name in ["q_proj", "qkv_proj"]: + return ( + self.config.hidden_size, + self.config.head_dim * self.config.num_attention_heads, + ) + elif module_name in ["o_proj"]: + return ( + self.config.head_dim * self.config.num_attention_heads, + self.config.hidden_size, + ) + elif module_name in ["kv_proj"]: + return ( + self.config.hidden_size, + self.config.head_dim * self.config.num_key_value_heads, + ) + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size + elif module_name == "down_proj": + return self.config.intermediate_size, self.config.hidden_size + 
else: + raise NotImplementedError() + @torch.no_grad() def forward( self, diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 9ec71c29bac8..ed30b3687922 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -134,10 +134,12 @@ def __init__( model_type: str = "generation", output_str_only: bool = False, trust_remote_code: bool = False, + patch_model_do_sample_false: bool = False, ): self.model_type = model_type self.output_str_only = output_str_only self.trust_remote_code = trust_remote_code + self.patch_model_do_sample_false = patch_model_do_sample_false self.in_queue = mp.Queue() self.out_queue = mp.Queue() @@ -292,6 +294,7 @@ def start_model_process(self, in_queue, out_queue, model_path, torch_dtype): torch_dtype=torch_dtype, output_str_only=self.output_str_only, token_ids_logprob=token_ids_logprob, + patch_model_do_sample_false=self.patch_model_do_sample_false, ) ) elif self.model_type == "embedding": @@ -380,6 +383,7 @@ def forward_generation_raw( lora_paths: Optional[List[str]] = None, output_str_only: bool = False, token_ids_logprob: Optional[int] = None, + patch_model_do_sample_false: Optional[bool] = False, ) -> ModelOutput: output_strs = [] top_input_logprobs = [] @@ -407,7 +411,8 @@ def forward_generation_raw( ) else: model = base_model - + if patch_model_do_sample_false: + model.generation_config.do_sample = False outputs = model.generate( input_ids=input_ids, generation_config=GenerationConfig( diff --git a/test/srt/models/lora/test_lora.py b/test/srt/models/lora/test_lora.py index bfa727234072..17aa6f3b8c00 100644 --- a/test/srt/models/lora/test_lora.py +++ b/test/srt/models/lora/test_lora.py @@ -84,7 +84,6 @@ def ensure_reproducibility(self): torch.use_deterministic_algorithms(True) def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]): - for model_case in model_cases: for torch_dtype in TORCH_DTYPES: max_new_tokens = 32 diff --git 
a/test/srt/models/lora/test_lora_qwen3.py b/test/srt/models/lora/test_lora_qwen3.py new file mode 100644 index 000000000000..4519c3c1f8d8 --- /dev/null +++ b/test/srt/models/lora/test_lora_qwen3.py @@ -0,0 +1,209 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import multiprocessing as mp +import os +import random +import unittest +from typing import List + +from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase + +from sglang.test.runners import HFRunner, SRTRunner +from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci + +LORA_MODELS_QWEN3 = [ + LoRAModelCase( + base="Qwen/Qwen3-4B", + adaptors=[ + LoRAAdaptor( + name="nissenj/Qwen3-4B-lora-v2", + prefill_tolerance=3e-1, + ), + LoRAAdaptor( + name="y9760210/Qwen3-4B-lora_model", + prefill_tolerance=3e-1, + ), + ], + max_loras_per_batch=2, + ), +] + + +TEST_MULTIPLE_BATCH_PROMPTS = [ + """ + ### Instruction: + Tell me about llamas and alpacas + ### Response: + Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." 
Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing. + ### Question 2: + What do you know about llamas? + ### Answer: + """, + """ + ### Instruction: + Write a poem about the transformers Python library. + Mention the word "large language models" in that poem. + ### Response: + The Transformers are large language models, + They're used to make predictions on text. + """, + # "AI is a field of computer science focused on", TODO: Add it back after fixing its bug + "Computer science is the study of", + "Write a short story.", + "What are the main components of a computer?", +] + + +class TestLoRA(CustomTestCase): + + def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]): + for model_case in model_cases: + for torch_dtype in TORCH_DTYPES: + max_new_tokens = 10 + backend = "triton" + base_path = model_case.base + lora_adapter_paths = [a.name for a in model_case.adaptors] + assert len(lora_adapter_paths) >= 2 + + batches = [ + ( + [ + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + ], + [ + None, + lora_adapter_paths[0], + lora_adapter_paths[1], + ], + ), + ( + [ + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + ], + [ + lora_adapter_paths[0], + None, + lora_adapter_paths[1], + ], + ), + ( + [ + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + ], + [lora_adapter_paths[0], lora_adapter_paths[1], None], + ), + ( + [ + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + ], + [None, 
lora_adapter_paths[1], None], + ), + ( + [ + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + random.choice(TEST_MULTIPLE_BATCH_PROMPTS), + ], + [None, None, None], + ), + ] + + print( + f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---" + ) + + # Initialize runners + srt_runner = SRTRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]], + max_loras_per_batch=len(lora_adapter_paths) + 1, + lora_backend=backend, + disable_radix_cache=True, + ) + hf_runner = HFRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + patch_model_do_sample_false=True, + ) + + with srt_runner, hf_runner: + for i, (prompts, lora_paths) in enumerate(batches): + print( + f"\n--- Running Batch {i+1} --- prompts: {prompts}, lora_paths: {lora_paths}" + ) + + srt_outputs = srt_runner.batch_forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths, + ) + + hf_outputs = hf_runner.forward( + prompts, + max_new_tokens=max_new_tokens, + lora_paths=lora_paths, + ) + + print("SRT outputs:", [s for s in srt_outputs.output_strs]) + print("HF outputs:", [s for s in hf_outputs.output_strs]) + + for srt_out, hf_out in zip( + srt_outputs.output_strs, hf_outputs.output_strs + ): + srt_str = srt_out.strip() + hf_str = hf_out.strip() + rouge_tol = model_case.rouge_l_tolerance + rouge_score = calculate_rouge_l([srt_str], [hf_str])[0] + if rouge_score < rouge_tol: + raise AssertionError( + f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} " + f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'" + ) + + print(f"--- Batch {i+1} Comparison Passed --- ") + + def test_ci_lora_models(self): + self._run_lora_multiple_batch_on_model_cases(LORA_MODELS_QWEN3) + + def test_all_lora_models(self): + if is_in_ci(): + return + qwen_filtered_models = [] + for 
model_case in LORA_MODELS_QWEN3: + if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base: + continue + qwen_filtered_models.append(model_case) + + self._run_lora_multiple_batch_on_model_cases(qwen_filtered_models) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main(warnings="ignore") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index d7b4739e38cb..0e62760ab72f 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -19,6 +19,7 @@ class TestFile: TestFile("models/lora/test_multi_lora_backend.py", 60), TestFile("models/lora/test_lora_cuda_graph.py", 250), TestFile("models/lora/test_lora_update.py", 800), + TestFile("models/lora/test_lora_qwen3.py", 97), TestFile("models/test_embedding_models.py", 73), # TestFile("models/test_clip_models.py", 52), TestFile("models/test_encoder_embedding_models.py", 100), From 0f9b11e3101b691fa89df8db212a01b13344431d Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 19 Jul 2025 20:04:02 -0700 Subject: [PATCH 064/396] feat: add h200 tp 16 kimi k2 moe config (#8176) --- ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..b5c45dd7231e --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, 
+ "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, 
+ "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} From f62d75b6a17d836aac6d1d81c1b124d0708e9ca0 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 19 Jul 2025 20:04:12 -0700 Subject: [PATCH 065/396] feat: add b200 tp 16 kimi k2 moe config (#8178) --- ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..faf1aa4d4ce0 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, 
+ "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} From a589a0716774196d437bdbfe282283e593f0882a Mon Sep 17 00:00:00 2001 From: Atream <80757050+Atream@users.noreply.github.com> Date: Sun, 20 Jul 2025 13:13:46 +0800 Subject: [PATCH 066/396] fix moe gate dtype, fix tbo, fix fake dispatch (#7825) --- python/sglang/srt/eplb/expert_location_dispatch.py | 2 +- python/sglang/srt/layers/moe/topk.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/eplb/expert_location_dispatch.py b/python/sglang/srt/eplb/expert_location_dispatch.py index 36224eee7f8a..8d2160b6ed73 100644 --- a/python/sglang/srt/eplb/expert_location_dispatch.py +++ b/python/sglang/srt/eplb/expert_location_dispatch.py @@ -66,7 +66,7 @@ def transform_select_experts_inputs( info: Optional[ExpertLocationDispatchInfo], ): if (info is not None) and (info.ep_dispatch_algorithm == "fake"): - router_logits = torch.randn_like(router_logits) + router_logits.uniform_(5, 10) if correction_bias is not None: correction_bias = torch.zeros_like(correction_bias) return router_logits, correction_bias diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index bb3cf651542a..c3ae9af25d0d 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -499,7 +499,7 @@ def biased_grouped_topk_gpu( and is_power_of_two(correction_bias.shape[0]) ): topk_weights, topk_ids = moe_fused_gate( - gating_output, + gating_output.to(dtype=torch.float32), correction_bias, num_expert_group, topk_group, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 9ec5db9260d3..a65337945f6b 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -229,7 +229,7 @@ def __init__( ) if config.topk_method == "noaux_tc": self.e_score_correction_bias = nn.Parameter( - torch.empty((config.n_routed_experts)) 
+ torch.empty((config.n_routed_experts), dtype=torch.float32) ) else: self.e_score_correction_bias = None From 55381a46ac6bf7d9b0e39d0673f8318feea2ff7e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 19 Jul 2025 22:41:30 -0700 Subject: [PATCH 067/396] Revert "[Feature] Simple Improve Health Check Mechanism for Production-Grade Stability" (#8181) --- python/sglang/srt/entrypoints/engine.py | 4 -- python/sglang/srt/entrypoints/http_server.py | 57 +++---------------- python/sglang/srt/managers/io_struct.py | 6 -- python/sglang/srt/managers/scheduler.py | 3 - .../sglang/srt/managers/tokenizer_manager.py | 7 +-- python/sglang/srt/utils.py | 16 ------ 6 files changed, 11 insertions(+), 82 deletions(-) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 957d85aa5998..990fac9a12a7 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -65,7 +65,6 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( MultiprocessingSerializer, - ServerStatus, assert_pkg_version, configure_logger, get_zmq_socket, @@ -74,7 +73,6 @@ launch_dummy_health_check_server, maybe_set_triton_cache_manager, prepare_model_and_tokenizer, - report_health, set_prometheus_multiproc_dir, set_ulimit, ) @@ -663,7 +661,6 @@ def _set_envs_and_config(server_args: ServerArgs): def sigchld_handler(signum, frame): pid, exitcode = os.waitpid(0, os.WNOHANG) if exitcode != 0: - report_health(ServerStatus.Crashed, server_args.host, server_args.port) logger.warning( f"Child process unexpectedly failed with {exitcode=}. {pid=}" ) @@ -677,7 +674,6 @@ def sigquit_handler(signum, frame): logger.error( "Received sigquit from a child process. It usually means the child failed." 
) - report_health(ServerStatus.Crashed, server_args.host, server_args.port) kill_process_tree(os.getpid()) signal.signal(signal.SIGQUIT, sigquit_handler) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index f880c4aa5cd4..43819e1a65e4 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -77,7 +77,6 @@ ParseFunctionCallReq, ProfileReqInput, ReleaseMemoryOccupationReqInput, - ReportHealthInput, ResumeMemoryOccupationReqInput, SeparateReasoningReqInput, SetInternalStateReq, @@ -94,7 +93,6 @@ from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( - ServerStatus, add_api_key_middleware, add_prometheus_middleware, delete_directory, @@ -222,31 +220,8 @@ async def validate_json_request(raw_request: Request): @app.get("/health") async def health() -> Response: - """Check the status of the http server.""" - code = HTTPStatus.SERVICE_UNAVAILABLE.value - if _global_state.tokenizer_manager.server_status == ServerStatus.Up: - code = HTTPStatus.OK.value - return Response( - status_code=code, - content=json.dumps( - {"status": _global_state.tokenizer_manager.server_status.value} - ), - ) - - -@app.post("/health") -async def health_update(obj: ReportHealthInput, request: Request) -> Response: - """Update the Status of the http server.""" - try: - server_status = ServerStatus(obj.status) - _global_state.tokenizer_manager.server_status = server_status - if server_status != ServerStatus.Up: - return Response( - status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, content=obj.msg - ) - except Exception as e: - logger.error(e) - return Response(status_code=HTTPStatus.SERVICE_UNAVAILABLE.value) + """Check the health of the http server.""" + return Response(status_code=200) @app.get("/health_generate") @@ -281,7 +256,7 @@ async def gen(): if _global_state.tokenizer_manager.last_receive_tstamp > tic: 
task.cancel() _global_state.tokenizer_manager.rid_to_state.pop(rid, None) - _global_state.tokenizer_manager.server_status = ServerStatus.Up + _global_state.tokenizer_manager.health_check_failed = False return Response(status_code=200) task.cancel() @@ -295,7 +270,7 @@ async def gen(): f"last_heartbeat time: {last_receive_time}" ) _global_state.tokenizer_manager.rid_to_state.pop(rid, None) - _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy + _global_state.tokenizer_manager.health_check_failed = True return Response(status_code=503) @@ -1047,13 +1022,9 @@ def _execute_server_warmup( headers=headers, timeout=600, ) - if res.status_code == 200: - _global_state.tokenizer_manager.server_status = ServerStatus.Up - else: - _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy - logger.info(f"{res}") + assert res.status_code == 200, f"{res}" else: - logger.info(f"Start of prefill/decode warmup ...") + logger.info(f"Start of prefill warmup ...") json_data = { "sampling_params": { "temperature": 0.0, @@ -1075,25 +1046,15 @@ def _execute_server_warmup( headers=headers, timeout=1800, # because of deep gemm precache is very long if not precache. ) - if res.status_code == 200: - logger.info( - f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}" - ) - _global_state.tokenizer_manager.server_status = ServerStatus.Up - else: - logger.info( - "Prefill disaggregation mode warm Up Failed, status code: {}".format( - res.status_code - ) - ) - _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy + logger.info( + f"End of prefill warmup with status {res.status_code}, resp: {res.json()}" + ) except Exception: last_traceback = get_exception_traceback() if pipe_finish_writer is not None: pipe_finish_writer.send(last_traceback) logger.error(f"Initialization failed. 
warmup error: {last_traceback}") - _global_state.tokenizer_manager.server_status = ServerStatus.Crashed kill_process_tree(os.getpid()) return False diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index b8332fdf656c..8e1d1075aab6 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1083,9 +1083,3 @@ class LoRAUpdateResult: LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult - - -@dataclass -class ReportHealthInput: - status: str - msg: Optional[str] = "" diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index aee1596dbe9c..e6dd80d717ad 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -143,7 +143,6 @@ from sglang.srt.utils import ( DeepEPMode, DynamicGradMode, - ServerStatus, broadcast_pyobj, configure_gc_logger, configure_logger, @@ -155,7 +154,6 @@ kill_itself_when_parent_died, point_to_point_pyobj, pyspy_dump_schedulers, - report_health, require_mlp_sync, require_mlp_tp_gather, set_gpu_proc_affinity, @@ -2966,5 +2964,4 @@ def run_scheduler_process( except Exception: traceback = get_exception_traceback() logger.error(f"Scheduler hit an exception: {traceback}") - report_health(ServerStatus.Crashed, server_args.host, ServerArgs.port) parent_process.send_signal(signal.SIGQUIT) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a0f66419e768..631d23f17335 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -116,7 +116,6 @@ from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( - ServerStatus, dataclass_to_string_truncated, get_bool_env_var, get_zmq_socket, @@ -174,9 +173,6 @@ def __init__( server_args: ServerArgs, port_args: PortArgs, 
): - # Server Status - self.server_status = ServerStatus.Starting - # Parse args self.server_args = server_args self.enable_metrics = server_args.enable_metrics @@ -255,6 +251,7 @@ def __init__( # Store states self.no_create_loop = False self.rid_to_state: Dict[str, ReqState] = {} + self.health_check_failed = False self.gracefully_exit = False self.last_receive_tstamp = 0 self.dump_requests_folder = "" # By default do not dump @@ -1335,7 +1332,7 @@ async def sigterm_watchdog(self): while True: remain_num_req = len(self.rid_to_state) - if not self.server_status.is_healthy(): + if self.health_check_failed: # if health check failed, we should exit immediately logger.error( "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d", diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 03565a018c34..23960a8c1123 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -93,22 +93,6 @@ HIP_FP8_E4M3_FNUZ_MAX = 224.0 -class ServerStatus(Enum): - Up = "Up" - Starting = "Starting" - UnHealthy = "UnHealthy" - Crashed = "Crashed" - - def is_healthy(self) -> bool: - return self == ServerStatus.Up - - -def report_health(status: ServerStatus, host: str, http_port: int, msg: str = ""): - requests.post( - f"http://{host}:{http_port}/health", json={"status": status.value, "msg": msg} - ) - - # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip def is_hip() -> bool: return torch.version.hip is not None From 2db6719cc5bb1de607c07bcefe06d915fd0ca45d Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 19 Jul 2025 22:55:45 -0700 Subject: [PATCH 068/396] feat: update nccl 2.27.6 (#8182) --- docker/Dockerfile | 2 +- docker/Dockerfile.gb200 | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index bc0eb095e917..97be3625af7c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -59,7 +59,7 @@ RUN python3 -m pip install --no-cache-dir 
--upgrade pip setuptools wheel html5li esac \ && python3 -m pip install --no-cache-dir -e "python[${BUILD_TYPE}]" --extra-index-url https://download.pytorch.org/whl/cu${CUINDEX} \ && if [ "$CUDA_VERSION" = "12.8.1" ]; then \ - python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.5 --force-reinstall --no-deps ; \ + python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.6 --force-reinstall --no-deps ; \ python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.2.6/sgl_kernel-0.2.6+cu128-cp39-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \ fi diff --git a/docker/Dockerfile.gb200 b/docker/Dockerfile.gb200 index 1e0e665234f1..b4da2c5ddb84 100644 --- a/docker/Dockerfile.gb200 +++ b/docker/Dockerfile.gb200 @@ -69,7 +69,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel html5li esac \ && python3 -m pip install --no-cache-dir -e "python[${BUILD_TYPE}]" --extra-index-url https://download.pytorch.org/whl/cu${CUINDEX} \ && if [ "$CUDA_VERSION" = "12.8.1" ]; then \ - python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.5 --force-reinstall --no-deps ; \ + python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.6 --force-reinstall --no-deps ; \ python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.2.4/sgl_kernel-0.2.4+cu128-cp39-abi3-manylinux2014_$(uname -m).whl --force-reinstall --no-deps ; \ fi From 83c104b18823ea52dba4e90e6a3ca6c54ca037a4 Mon Sep 17 00:00:00 2001 From: Praneth Paruchuri <34855725+ppraneth@users.noreply.github.com> Date: Sun, 20 Jul 2025 11:37:47 +0530 Subject: [PATCH 069/396] Feat: Support for Persimmon Model (#7983) --- docs/supported_models/generative_models.md | 1 + python/sglang/srt/layers/activation.py | 12 + python/sglang/srt/models/persimmon.py | 330 +++++++++++++++++++++ test/srt/models/test_generation_models.py | 1 + 4 files changed, 344 insertions(+) create mode 100644 python/sglang/srt/models/persimmon.py 
diff --git a/docs/supported_models/generative_models.md b/docs/supported_models/generative_models.md index 0096d6e0932d..8aeac1ae4dbc 100644 --- a/docs/supported_models/generative_models.md +++ b/docs/supported_models/generative_models.md @@ -45,5 +45,6 @@ in the GitHub search bar. | **SmolLM** (135M–1.7B) | `HuggingFaceTB/SmolLM-1.7B` | Hugging Face’s ultra-small LLM series (135M–1.7B params) offering surprisingly strong results, enabling advanced AI on mobile/edge devices. | | **GLM-4** (Multilingual 9B) | `ZhipuAI/glm-4-9b-chat` | Zhipu’s GLM-4 series (up to 9B parameters) – open multilingual models with support for 1M-token context and even a 5.6B multimodal variant (Phi-4V). | | **MiMo** (7B series) | `XiaomiMiMo/MiMo-7B-RL` | Xiaomi's reasoning-optimized model series, leverages Multiple-Token Prediction for faster inference. | +| **Persimmon** (8B) | `adept/persimmon-8b-chat` | Adept’s open 8B model with a 16K context window and fast inference; trained for broad usability and licensed under Apache 2.0. | | **Granite 3.0, 3.1** (IBM) | `ibm-granite/granite-3.1-8b-instruct` | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. | | **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. | diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 056c5693e466..63e9fcdd3cc9 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -110,6 +110,17 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: return self.forward_native(x) +class ReLU2(nn.Module): + """ + Applies the squared Rectified Linear Unit function. 
+ y = max(0, x)^2 + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.relu(x) + return x * x + + class QuickGELU(CustomOp): def forward_native(self, x: torch.Tensor) -> torch.Tensor: return x * torch.sigmoid(1.702 * x) @@ -165,6 +176,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): "gelu": nn.GELU(), "gelu_pytorch_tanh": nn.GELU(approximate="tanh"), "gelu_new": NewGELU(), + "relu2": ReLU2(), } diff --git a/python/sglang/srt/models/persimmon.py b/python/sglang/srt/models/persimmon.py new file mode 100644 index 000000000000..5f8885e716e5 --- /dev/null +++ b/python/sglang/srt/models/persimmon.py @@ -0,0 +1,330 @@ +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers import PersimmonConfig + +from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.utils import PPMissingLayer +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix, make_layers + + +class PersimmonMLP(nn.Module): + + def __init__( + self, config: PersimmonConfig, quant_config: Optional[QuantizationConfig] = None + ): + super().__init__() + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, config.intermediate_size, quant_config=quant_config + ) + self.dense_4h_to_h = 
RowParallelLinear( + config.intermediate_size, config.hidden_size, quant_config=quant_config + ) + self.act = get_act_fn(config.hidden_act) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class PersimmonAttention(nn.Module): + + def __init__( + self, + config: PersimmonConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + layer_id: int = 0, + ): + super().__init__() + self.config = config + tensor_parallel_world_size = get_tensor_model_parallel_world_size() + + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.num_heads = self.total_num_heads // tensor_parallel_world_size + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.partial_rotary_factor = config.partial_rotary_factor + self.is_causal = True + + assert (self.head_dim * self.total_num_heads) == self.hidden_size + assert self.total_num_heads % tensor_parallel_world_size == 0 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + self.is_qk_layernorm = config.qk_layernorm + + if self.is_qk_layernorm: + self.q_layernorm = nn.LayerNorm(self.head_dim) + self.k_layernorm = nn.LayerNorm(self.head_dim) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + partial_rotary_factor=self.partial_rotary_factor, + ) + self.scaling = self.head_dim**-0.5 + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + 
self.scaling, + num_kv_heads=self.num_heads, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + + def _split_heads(self, x: torch.Tensor) -> torch.Tensor: + seq_length = x.shape[0] + return x.view(seq_length, self.num_heads, self.head_dim) + + def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: + seq_length = x.shape[0] + return x.view(seq_length, self.num_heads * self.head_dim) + + def forward( + self, + position_ids: torch.Tensor, + forward_batch: ForwardBatch, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + + if self.is_qk_layernorm: + q = self._split_heads(q) + k = self._split_heads(k) + + q = self.q_layernorm(q) + k = self.k_layernorm(k) + + q = self._merge_heads(q) + k = self._merge_heads(k) + + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v, forward_batch=forward_batch) + output, _ = self.dense(attn_output) + return output + + +class PersimmonDecoderLayer(nn.Module): + + def __init__( + self, + config: PersimmonConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + idx: int = 0, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = PersimmonAttention( + config=config, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + layer_id=idx, + ) + self.mlp = PersimmonMLP(config, quant_config=quant_config) + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + + def forward( + self, + position_ids: torch.Tensor, + forward_batch: ForwardBatch, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + 
forward_batch=forward_batch, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = hidden_states + residual + + outputs = hidden_states + return outputs + + +class PersimmonModel(nn.Module): + + def __init__( + self, + config: PersimmonConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.pp_group = get_pp_group() + + if self.pp_group.is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + else: + self.embed_tokens = PPMissingLayer() + + self.layers, self.start_layer, self.end_layer = make_layers( + config.num_hidden_layers, + lambda idx, prefix: PersimmonDecoderLayer( + config, quant_config=quant_config, prefix=prefix, idx=idx + ), + prefix="model.layers", + pp_rank=self.pp_group.rank_in_group, + pp_size=self.pp_group.world_size, + ) + + if self.pp_group.is_last_rank: + self.final_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + else: + self.final_layernorm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + forward_batch: ForwardBatch, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.pp_group.is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + hidden_states = forward_batch.pp_input_hidden + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer( + position_ids=positions, + forward_batch=forward_batch, + hidden_states=hidden_states, + ) + return self.final_layernorm(hidden_states) + + +class PersimmonForCausalLM(nn.Module): + + 
def __init__( + self, + config: PersimmonConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = PersimmonModel( + config=config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> LogitsProcessorOutput: + hidden_states = self.model( + input_ids=input_ids, + forward_batch=forward_batch, + positions=positions, + inputs_embeds=inputs_embeds, + ) + + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if name not in params_dict: + if name == "lm_head.weight": + continue + print(f"Warning: weight {name} not found in model.") + continue + param = params_dict[name] + if "query_key_value" in name: + output_dim = getattr(param, "output_dim", None) + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + num_heads = self.config.num_attention_heads + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1 :] + ) + loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, 
loaded_weight) + + +EntryClass = PersimmonForCausalLM diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index bfeb00c8d809..daa99001d7b7 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -66,6 +66,7 @@ class ModelCase: ), ModelCase("openai-community/gpt2"), ModelCase("microsoft/phi-1_5", trust_remote_code=True), + ModelCase("adept/persimmon-8b-chat"), ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True), ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True), ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True), From bbcfbc1a024980e76926c114ca2daf8cec9098a2 Mon Sep 17 00:00:00 2001 From: Qiaolin Yu Date: Sun, 20 Jul 2025 02:30:08 -0400 Subject: [PATCH 070/396] feat: add h200 tp 16 kimi k2 moe config (#8183) --- ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..358873315860 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 
64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} From 99aefa037edf44b855a22bf350adac7c559cded8 Mon Sep 17 00:00:00 2001 From: Jay Zhou <50169346+Ja1Zhou@users.noreply.github.com> Date: Sun, 20 Jul 2025 00:28:06 -0700 Subject: [PATCH 071/396] Fix eagle3 cuda graph (#8163) --- .../eagle_draft_extend_cuda_graph_runner.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index b4ffde60ef62..7057c502da0e 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -84,7 +84,15 @@ def __init__(self, eagle_worker: EAGLEWorker): self.hidden_states = torch.zeros( ( self.max_num_token, - self.model_runner.model_config.hidden_size * 3, + ( + self.model_runner.model_config.hf_config.target_hidden_size + * 3 + if hasattr( + self.model_runner.model_config.hf_config, + "target_hidden_size", + ) + else self.model_runner.model_config.hidden_size * 3 + ), ), dtype=self.model_runner.dtype, ) From 750838adc4f9f7c8f4c9464ca92043a06197540a Mon Sep 17 00:00:00 2001 From: GuoYipin <64318822+coco-alen@users.noreply.github.com> Date: Sun, 20 Jul 2025 22:22:54 +0800 Subject: [PATCH 072/396] fix: fix the bug of loading Internvl3 (#8067) Co-authored-by: Xinyuan Tong --- python/sglang/srt/configs/internvl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/configs/internvl.py b/python/sglang/srt/configs/internvl.py index 14b6482524db..b4ddda22773d 100644 --- a/python/sglang/srt/configs/internvl.py +++ b/python/sglang/srt/configs/internvl.py @@ -9,6 +9,7 @@ LlamaConfig, PretrainedConfig, 
PreTrainedTokenizer, + Qwen2Config, ) from sglang.utils import logger @@ -311,6 +312,8 @@ def __init__( self.llm_config = LlamaConfig(**llm_config) elif llm_config.get("architectures")[0] == "InternLM2ForCausalLM": self.llm_config = InternLM2Config(**llm_config) + elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM": + self.llm_config = Qwen2Config(**llm_config) else: raise ValueError( "Unsupported architecture: {}".format( From 465968b2e328623758e69801386c51d6384ac944 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Mon, 21 Jul 2025 00:27:55 +0800 Subject: [PATCH 073/396] Fix dtype error in CI (#8197) --- python/sglang/srt/layers/moe/topk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index c3ae9af25d0d..a806a40520be 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -524,7 +524,7 @@ def biased_grouped_topk_gpu( topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) aiter_biased_grouped_topk( - gating_output, + gating_output.to(dtype=torch.float32), correction_bias, topk_weights, topk_ids, From 1fc455e8b65f0fcbe5d1c41ac5868667650317c9 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Sun, 20 Jul 2025 10:53:42 -0700 Subject: [PATCH 074/396] [router] add ut for pd request, metrics and config (#8184) --- sgl-router/src/config/types.rs | 649 +++++++++++-- sgl-router/src/metrics.rs | 411 +++++++++ sgl-router/src/routers/pd_types.rs | 2 +- sgl-router/src/routers/request_adapter.rs | 1013 +++++++++++++++++++++ 4 files changed, 2003 insertions(+), 72 deletions(-) diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 6b24a5fd1f4a..5e25b2c3b218 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -214,83 +214,590 @@ impl RouterConfig { pub fn has_metrics(&self) -> bool { 
self.metrics.is_some() } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ============= RouterConfig Tests ============= - /* Commented out - no longer needed without compatibility layer - /// Convert to routing PolicyConfig for internal use - pub fn to_routing_policy_config(&self) -> ConfigResult { - match (&self.mode, &self.policy) { - ( - RoutingMode::PrefillDecode { - prefill_urls, - decode_urls, - }, - policy, - ) => { - // Map policy to PDSelectionPolicy - let selection_policy = match policy { - PolicyConfig::Random => crate::pd_types::PDSelectionPolicy::Random, - PolicyConfig::PowerOfTwo { .. } => { - crate::pd_types::PDSelectionPolicy::PowerOfTwo - } - PolicyConfig::CacheAware { .. } => { - return Err(ConfigError::IncompatibleConfig { - reason: "CacheAware policy is not supported in PD disaggregated mode" - .to_string(), - }); - } - PolicyConfig::RoundRobin => { - return Err(ConfigError::IncompatibleConfig { - reason: "RoundRobin policy is not supported in PD disaggregated mode" - .to_string(), - }); - } - }; - - Ok(crate::router::PolicyConfig::PrefillDecodeConfig { - selection_policy, - prefill_urls: prefill_urls.clone(), - decode_urls: decode_urls.clone(), - timeout_secs: self.worker_startup_timeout_secs, - interval_secs: self.worker_startup_check_interval_secs, - }) + #[test] + fn test_router_config_default() { + let config = RouterConfig::default(); + + assert!( + matches!(config.mode, RoutingMode::Regular { worker_urls } if worker_urls.is_empty()) + ); + assert!(matches!(config.policy, PolicyConfig::Random)); + assert_eq!(config.host, "127.0.0.1"); + assert_eq!(config.port, 3001); + assert_eq!(config.max_payload_size, 268_435_456); + assert_eq!(config.request_timeout_secs, 600); + assert_eq!(config.worker_startup_timeout_secs, 300); + assert_eq!(config.worker_startup_check_interval_secs, 10); + assert!(config.discovery.is_none()); + assert!(config.metrics.is_none()); + assert!(config.log_dir.is_none()); + assert!(config.log_level.is_none()); + } 
+ + #[test] + fn test_router_config_new() { + let mode = RoutingMode::Regular { + worker_urls: vec!["http://worker1".to_string(), "http://worker2".to_string()], + }; + let policy = PolicyConfig::RoundRobin; + + let config = RouterConfig::new(mode, policy); + + match config.mode { + RoutingMode::Regular { worker_urls } => { + assert_eq!(worker_urls.len(), 2); + assert_eq!(worker_urls[0], "http://worker1"); + assert_eq!(worker_urls[1], "http://worker2"); } - (RoutingMode::Regular { .. }, PolicyConfig::Random) => { - Ok(crate::router::PolicyConfig::RandomConfig { - timeout_secs: self.worker_startup_timeout_secs, - interval_secs: self.worker_startup_check_interval_secs, - }) + _ => panic!("Expected Regular mode"), + } + + assert!(matches!(config.policy, PolicyConfig::RoundRobin)); + // Other fields should be default + assert_eq!(config.host, "127.0.0.1"); + assert_eq!(config.port, 3001); + } + + #[test] + fn test_router_config_serialization() { + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec!["http://worker1".to_string()], + }, + policy: PolicyConfig::Random, + host: "0.0.0.0".to_string(), + port: 8080, + max_payload_size: 1024, + request_timeout_secs: 30, + worker_startup_timeout_secs: 60, + worker_startup_check_interval_secs: 5, + discovery: Some(DiscoveryConfig::default()), + metrics: Some(MetricsConfig::default()), + log_dir: Some("/var/log".to_string()), + log_level: Some("debug".to_string()), + }; + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: RouterConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(config.host, deserialized.host); + assert_eq!(config.port, deserialized.port); + assert_eq!(config.max_payload_size, deserialized.max_payload_size); + assert!(deserialized.discovery.is_some()); + assert!(deserialized.metrics.is_some()); + } + + // ============= RoutingMode Tests ============= + + #[test] + fn test_routing_mode_is_pd_mode() { + let regular = RoutingMode::Regular { + worker_urls: 
vec!["http://worker1".to_string()], + }; + assert!(!regular.is_pd_mode()); + + let pd = RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill1".to_string(), Some(8001))], + decode_urls: vec!["http://decode1".to_string()], + }; + assert!(pd.is_pd_mode()); + } + + #[test] + fn test_routing_mode_worker_count() { + let regular = RoutingMode::Regular { + worker_urls: vec![ + "http://worker1".to_string(), + "http://worker2".to_string(), + "http://worker3".to_string(), + ], + }; + assert_eq!(regular.worker_count(), 3); + + let pd = RoutingMode::PrefillDecode { + prefill_urls: vec![ + ("http://prefill1".to_string(), Some(8001)), + ("http://prefill2".to_string(), None), + ], + decode_urls: vec![ + "http://decode1".to_string(), + "http://decode2".to_string(), + "http://decode3".to_string(), + ], + }; + assert_eq!(pd.worker_count(), 5); + + let empty_regular = RoutingMode::Regular { + worker_urls: vec![], + }; + assert_eq!(empty_regular.worker_count(), 0); + } + + #[test] + fn test_routing_mode_serialization() { + // Test Regular mode + let regular = RoutingMode::Regular { + worker_urls: vec!["http://worker1".to_string()], + }; + let json = serde_json::to_string(®ular).unwrap(); + assert!(json.contains("\"type\":\"regular\"")); + assert!(json.contains("\"worker_urls\"")); + + // Test PrefillDecode mode + let pd = RoutingMode::PrefillDecode { + prefill_urls: vec![("http://prefill1".to_string(), Some(8001))], + decode_urls: vec!["http://decode1".to_string()], + }; + let json = serde_json::to_string(&pd).unwrap(); + assert!(json.contains("\"type\":\"prefill_decode\"")); + assert!(json.contains("\"prefill_urls\"")); + assert!(json.contains("\"decode_urls\"")); + } + + // ============= PolicyConfig Tests ============= + + #[test] + fn test_policy_config_name() { + assert_eq!(PolicyConfig::Random.name(), "random"); + assert_eq!(PolicyConfig::RoundRobin.name(), "round_robin"); + + let cache_aware = PolicyConfig::CacheAware { + cache_threshold: 0.8, + 
balance_abs_threshold: 10, + balance_rel_threshold: 1.5, + eviction_interval_secs: 300, + max_tree_size: 1000, + }; + assert_eq!(cache_aware.name(), "cache_aware"); + + let power_of_two = PolicyConfig::PowerOfTwo { + load_check_interval_secs: 60, + }; + assert_eq!(power_of_two.name(), "power_of_two"); + } + + #[test] + fn test_policy_config_serialization() { + // Test Random + let random = PolicyConfig::Random; + let json = serde_json::to_string(&random).unwrap(); + assert_eq!(json, r#"{"type":"random"}"#); + + // Test CacheAware with all parameters + let cache_aware = PolicyConfig::CacheAware { + cache_threshold: 0.8, + balance_abs_threshold: 10, + balance_rel_threshold: 1.5, + eviction_interval_secs: 300, + max_tree_size: 1000, + }; + let json = serde_json::to_string(&cache_aware).unwrap(); + assert!(json.contains("\"type\":\"cache_aware\"")); + assert!(json.contains("\"cache_threshold\":0.8")); + assert!(json.contains("\"balance_abs_threshold\":10")); + + // Test PowerOfTwo + let power_of_two = PolicyConfig::PowerOfTwo { + load_check_interval_secs: 60, + }; + let json = serde_json::to_string(&power_of_two).unwrap(); + assert!(json.contains("\"type\":\"power_of_two\"")); + assert!(json.contains("\"load_check_interval_secs\":60")); + } + + #[test] + fn test_cache_aware_parameters() { + let cache_aware = PolicyConfig::CacheAware { + cache_threshold: 0.75, + balance_abs_threshold: 20, + balance_rel_threshold: 2.0, + eviction_interval_secs: 600, + max_tree_size: 5000, + }; + + match cache_aware { + PolicyConfig::CacheAware { + cache_threshold, + balance_abs_threshold, + balance_rel_threshold, + eviction_interval_secs, + max_tree_size, + } => { + assert!((cache_threshold - 0.75).abs() < 0.0001); + assert_eq!(balance_abs_threshold, 20); + assert!((balance_rel_threshold - 2.0).abs() < 0.0001); + assert_eq!(eviction_interval_secs, 600); + assert_eq!(max_tree_size, 5000); } - (RoutingMode::Regular { .. 
}, PolicyConfig::RoundRobin) => { - Ok(crate::router::PolicyConfig::RoundRobinConfig { - timeout_secs: self.worker_startup_timeout_secs, - interval_secs: self.worker_startup_check_interval_secs, - }) + _ => panic!("Expected CacheAware"), + } + } + + #[test] + fn test_power_of_two_parameters() { + let power_of_two = PolicyConfig::PowerOfTwo { + load_check_interval_secs: 120, + }; + + match power_of_two { + PolicyConfig::PowerOfTwo { + load_check_interval_secs, + } => { + assert_eq!(load_check_interval_secs, 120); } - ( - RoutingMode::Regular { .. }, - PolicyConfig::CacheAware { - cache_threshold, - balance_abs_threshold, - balance_rel_threshold, - eviction_interval_secs, - max_tree_size, - }, - ) => Ok(crate::router::PolicyConfig::CacheAwareConfig { - cache_threshold: *cache_threshold, - balance_abs_threshold: *balance_abs_threshold, - balance_rel_threshold: *balance_rel_threshold, - eviction_interval_secs: *eviction_interval_secs, - max_tree_size: *max_tree_size, - timeout_secs: self.worker_startup_timeout_secs, - interval_secs: self.worker_startup_check_interval_secs, + _ => panic!("Expected PowerOfTwo"), + } + } + + // ============= DiscoveryConfig Tests ============= + + #[test] + fn test_discovery_config_default() { + let config = DiscoveryConfig::default(); + + assert!(!config.enabled); + assert!(config.namespace.is_none()); + assert_eq!(config.port, 8000); + assert_eq!(config.check_interval_secs, 60); + assert!(config.selector.is_empty()); + assert!(config.prefill_selector.is_empty()); + assert!(config.decode_selector.is_empty()); + assert_eq!(config.bootstrap_port_annotation, "sglang.ai/bootstrap-port"); + } + + #[test] + fn test_discovery_config_with_selectors() { + let mut selector = HashMap::new(); + selector.insert("app".to_string(), "sglang".to_string()); + selector.insert("role".to_string(), "worker".to_string()); + + let config = DiscoveryConfig { + enabled: true, + namespace: Some("default".to_string()), + port: 9000, + check_interval_secs: 30, + 
selector: selector.clone(), + prefill_selector: selector.clone(), + decode_selector: selector.clone(), + bootstrap_port_annotation: "custom.io/port".to_string(), + }; + + assert!(config.enabled); + assert_eq!(config.namespace, Some("default".to_string())); + assert_eq!(config.port, 9000); + assert_eq!(config.selector.len(), 2); + assert_eq!(config.selector.get("app"), Some(&"sglang".to_string())); + } + + #[test] + fn test_discovery_config_namespace() { + // Test None namespace (all namespaces) + let config = DiscoveryConfig { + namespace: None, + ..Default::default() + }; + assert!(config.namespace.is_none()); + + // Test specific namespace + let config = DiscoveryConfig { + namespace: Some("production".to_string()), + ..Default::default() + }; + assert_eq!(config.namespace, Some("production".to_string())); + } + + // ============= MetricsConfig Tests ============= + + #[test] + fn test_metrics_config_default() { + let config = MetricsConfig::default(); + + assert_eq!(config.port, 29000); + assert_eq!(config.host, "127.0.0.1"); + } + + #[test] + fn test_metrics_config_custom() { + let config = MetricsConfig { + port: 9090, + host: "0.0.0.0".to_string(), + }; + + assert_eq!(config.port, 9090); + assert_eq!(config.host, "0.0.0.0"); + } + + // ============= RouterConfig Utility Methods Tests ============= + + #[test] + fn test_mode_type() { + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec![], + }, + ..Default::default() + }; + assert_eq!(config.mode_type(), "regular"); + + let config = RouterConfig { + mode: RoutingMode::PrefillDecode { + prefill_urls: vec![], + decode_urls: vec![], + }, + ..Default::default() + }; + assert_eq!(config.mode_type(), "prefill_decode"); + } + + #[test] + fn test_has_service_discovery() { + let config = RouterConfig::default(); + assert!(!config.has_service_discovery()); + + let config = RouterConfig { + discovery: Some(DiscoveryConfig { + enabled: false, + ..Default::default() }), - (RoutingMode::Regular { 
.. }, PolicyConfig::PowerOfTwo { .. }) => { - Err(ConfigError::IncompatibleConfig { - reason: "PowerOfTwo policy is only supported in PD disaggregated mode" - .to_string(), - }) + ..Default::default() + }; + assert!(!config.has_service_discovery()); + + let config = RouterConfig { + discovery: Some(DiscoveryConfig { + enabled: true, + ..Default::default() + }), + ..Default::default() + }; + assert!(config.has_service_discovery()); + } + + #[test] + fn test_has_metrics() { + let config = RouterConfig::default(); + assert!(!config.has_metrics()); + + let config = RouterConfig { + metrics: Some(MetricsConfig::default()), + ..Default::default() + }; + assert!(config.has_metrics()); + } + + // ============= Edge Cases ============= + + #[test] + fn test_large_worker_lists() { + let large_urls: Vec = (0..1000).map(|i| format!("http://worker{}", i)).collect(); + + let mode = RoutingMode::Regular { + worker_urls: large_urls.clone(), + }; + + assert_eq!(mode.worker_count(), 1000); + + // Test serialization with large list + let config = RouterConfig { + mode, + ..Default::default() + }; + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: RouterConfig = serde_json::from_str(&json).unwrap(); + + match deserialized.mode { + RoutingMode::Regular { worker_urls } => { + assert_eq!(worker_urls.len(), 1000); } + _ => panic!("Expected Regular mode"), } } - */ + + #[test] + fn test_unicode_in_config() { + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec!["http://работник1".to_string(), "http://工作者2".to_string()], + }, + log_dir: Some("/日志/目录".to_string()), + ..Default::default() + }; + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: RouterConfig = serde_json::from_str(&json).unwrap(); + + match deserialized.mode { + RoutingMode::Regular { worker_urls } => { + assert_eq!(worker_urls[0], "http://работник1"); + assert_eq!(worker_urls[1], "http://工作者2"); + } + _ => panic!("Expected Regular mode"), + } + + 
assert_eq!(deserialized.log_dir, Some("/日志/目录".to_string())); + } + + #[test] + fn test_empty_string_fields() { + let config = RouterConfig { + host: "".to_string(), + log_dir: Some("".to_string()), + log_level: Some("".to_string()), + ..Default::default() + }; + + assert_eq!(config.host, ""); + assert_eq!(config.log_dir, Some("".to_string())); + assert_eq!(config.log_level, Some("".to_string())); + } + + // ============= Complex Configuration Tests ============= + + #[test] + fn test_full_pd_mode_config() { + let config = RouterConfig { + mode: RoutingMode::PrefillDecode { + prefill_urls: vec![ + ("http://prefill1:8000".to_string(), Some(8001)), + ("http://prefill2:8000".to_string(), None), + ], + decode_urls: vec![ + "http://decode1:8000".to_string(), + "http://decode2:8000".to_string(), + ], + }, + policy: PolicyConfig::PowerOfTwo { + load_check_interval_secs: 30, + }, + host: "0.0.0.0".to_string(), + port: 3000, + max_payload_size: 1048576, + request_timeout_secs: 120, + worker_startup_timeout_secs: 60, + worker_startup_check_interval_secs: 5, + discovery: Some(DiscoveryConfig { + enabled: true, + namespace: Some("sglang".to_string()), + ..Default::default() + }), + metrics: Some(MetricsConfig { + port: 9090, + host: "0.0.0.0".to_string(), + }), + log_dir: Some("/var/log/sglang".to_string()), + log_level: Some("info".to_string()), + }; + + assert!(config.mode.is_pd_mode()); + assert_eq!(config.mode.worker_count(), 4); + assert_eq!(config.policy.name(), "power_of_two"); + assert!(config.has_service_discovery()); + assert!(config.has_metrics()); + } + + #[test] + fn test_full_regular_mode_config() { + let mut selector = HashMap::new(); + selector.insert("app".to_string(), "sglang".to_string()); + + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec![ + "http://worker1:8000".to_string(), + "http://worker2:8000".to_string(), + "http://worker3:8000".to_string(), + ], + }, + policy: PolicyConfig::CacheAware { + cache_threshold: 0.9, + 
balance_abs_threshold: 5, + balance_rel_threshold: 1.2, + eviction_interval_secs: 600, + max_tree_size: 10000, + }, + host: "0.0.0.0".to_string(), + port: 3001, + max_payload_size: 536870912, + request_timeout_secs: 300, + worker_startup_timeout_secs: 180, + worker_startup_check_interval_secs: 15, + discovery: Some(DiscoveryConfig { + enabled: true, + namespace: None, + port: 8080, + check_interval_secs: 45, + selector, + ..Default::default() + }), + metrics: Some(MetricsConfig::default()), + log_dir: None, + log_level: Some("debug".to_string()), + }; + + assert!(!config.mode.is_pd_mode()); + assert_eq!(config.mode.worker_count(), 3); + assert_eq!(config.policy.name(), "cache_aware"); + assert!(config.has_service_discovery()); + assert!(config.has_metrics()); + } + + #[test] + fn test_config_with_all_options() { + let mut selectors = HashMap::new(); + selectors.insert("env".to_string(), "prod".to_string()); + selectors.insert("version".to_string(), "v1".to_string()); + + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec!["http://worker1".to_string()], + }, + policy: PolicyConfig::RoundRobin, + host: "::1".to_string(), // IPv6 + port: 8888, + max_payload_size: 1024 * 1024 * 512, // 512MB + request_timeout_secs: 900, + worker_startup_timeout_secs: 600, + worker_startup_check_interval_secs: 20, + discovery: Some(DiscoveryConfig { + enabled: true, + namespace: Some("production".to_string()), + port: 8443, + check_interval_secs: 120, + selector: selectors.clone(), + prefill_selector: selectors.clone(), + decode_selector: selectors, + bootstrap_port_annotation: "mycompany.io/bootstrap".to_string(), + }), + metrics: Some(MetricsConfig { + port: 9999, + host: "::".to_string(), // IPv6 any + }), + log_dir: Some("/opt/logs/sglang".to_string()), + log_level: Some("trace".to_string()), + }; + + assert!(config.has_service_discovery()); + assert!(config.has_metrics()); + assert_eq!(config.mode_type(), "regular"); + + // Test round-trip serialization + 
let json = serde_json::to_string_pretty(&config).unwrap(); + let deserialized: RouterConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.host, "::1"); + assert_eq!(deserialized.port, 8888); + assert_eq!( + deserialized.discovery.unwrap().namespace, + Some("production".to_string()) + ); + } } diff --git a/sgl-router/src/metrics.rs b/sgl-router/src/metrics.rs index 76e952a03736..78a06de44e45 100644 --- a/sgl-router/src/metrics.rs +++ b/sgl-router/src/metrics.rs @@ -322,3 +322,414 @@ impl RouterMetrics { .set(count as f64); } } + +#[cfg(test)] +mod tests { + use super::*; + use std::net::TcpListener; + + // ============= PrometheusConfig Tests ============= + + #[test] + fn test_prometheus_config_default() { + let config = PrometheusConfig::default(); + assert_eq!(config.port, 29000); + assert_eq!(config.host, "0.0.0.0"); + } + + #[test] + fn test_prometheus_config_custom() { + let config = PrometheusConfig { + port: 8080, + host: "127.0.0.1".to_string(), + }; + assert_eq!(config.port, 8080); + assert_eq!(config.host, "127.0.0.1"); + } + + #[test] + fn test_prometheus_config_clone() { + let config = PrometheusConfig { + port: 9090, + host: "192.168.1.1".to_string(), + }; + let cloned = config.clone(); + assert_eq!(cloned.port, config.port); + assert_eq!(cloned.host, config.host); + } + + // ============= IP Address Parsing Tests ============= + + #[test] + fn test_valid_ipv4_parsing() { + let test_cases = vec!["127.0.0.1", "192.168.1.1", "0.0.0.0"]; + + for ip_str in test_cases { + let config = PrometheusConfig { + port: 29000, + host: ip_str.to_string(), + }; + + let ip_addr: IpAddr = config.host.parse().unwrap(); + assert!(matches!(ip_addr, IpAddr::V4(_))); + } + } + + #[test] + fn test_valid_ipv6_parsing() { + let test_cases = vec!["::1", "2001:db8::1", "::"]; + + for ip_str in test_cases { + let config = PrometheusConfig { + port: 29000, + host: ip_str.to_string(), + }; + + let ip_addr: IpAddr = config.host.parse().unwrap(); + 
assert!(matches!(ip_addr, IpAddr::V6(_))); + } + } + + #[test] + fn test_invalid_ip_parsing() { + let test_cases = vec!["invalid", "256.256.256.256", "hostname"]; + + for ip_str in test_cases { + let config = PrometheusConfig { + port: 29000, + host: ip_str.to_string(), + }; + + let ip_addr: IpAddr = config + .host + .parse() + .unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); + + // Should fall back to 0.0.0.0 + assert_eq!(ip_addr, IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); + } + } + + // ============= Socket Address Creation Tests ============= + + #[test] + fn test_socket_addr_creation() { + let test_cases = vec![("127.0.0.1", 8080), ("0.0.0.0", 29000), ("::1", 9090)]; + + for (host, port) in test_cases { + let config = PrometheusConfig { + port, + host: host.to_string(), + }; + + let ip_addr: IpAddr = config.host.parse().unwrap(); + let socket_addr = SocketAddr::new(ip_addr, config.port); + + assert_eq!(socket_addr.port(), port); + assert_eq!(socket_addr.ip().to_string(), host); + } + } + + #[test] + fn test_socket_addr_with_different_ports() { + let ports = vec![0, 80, 8080, 65535]; + + for port in ports { + let config = PrometheusConfig { + port, + host: "127.0.0.1".to_string(), + }; + + let ip_addr: IpAddr = config.host.parse().unwrap(); + let socket_addr = SocketAddr::new(ip_addr, config.port); + + assert_eq!(socket_addr.port(), port); + } + } + + // ============= Duration Bucket Tests ============= + + #[test] + fn test_duration_bucket_values() { + let expected_buckets = vec![ + 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0, + 60.0, 90.0, 120.0, 180.0, 240.0, + ]; + + // The buckets are defined in start_prometheus function + assert_eq!(expected_buckets.len(), 20); + + // Verify proper ordering + for i in 1..expected_buckets.len() { + assert!(expected_buckets[i] > expected_buckets[i - 1]); + } + } + + #[test] + fn test_duration_bucket_coverage() { + let test_cases = vec![ + (0.0005, "sub-millisecond"), + (0.005, 
"5ms"), + (0.05, "50ms"), + (1.0, "1s"), + (10.0, "10s"), + (60.0, "1m"), + (240.0, "4m"), + ]; + + let buckets = vec![ + 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0, + 60.0, 90.0, 120.0, 180.0, 240.0, + ]; + + for (duration, label) in test_cases { + let bucket_found = buckets + .iter() + .any(|&b| ((b - duration) as f64).abs() < 0.0001 || b > duration); + assert!(bucket_found, "No bucket found for {} ({})", duration, label); + } + } + + // ============= Matcher Configuration Tests ============= + + #[test] + fn test_duration_suffix_matcher() { + let matcher = Matcher::Suffix(String::from("duration_seconds")); + + // Test matching behavior + let _matching_metrics = vec![ + "request_duration_seconds", + "response_duration_seconds", + "sgl_router_request_duration_seconds", + ]; + + let _non_matching_metrics = + vec!["duration_total", "duration_seconds_total", "other_metric"]; + + // Note: We can't directly test Matcher matching without the internals, + // but we can verify the matcher is created correctly + match matcher { + Matcher::Suffix(suffix) => assert_eq!(suffix, "duration_seconds"), + _ => panic!("Expected Suffix matcher"), + } + } + + // ============= Builder Configuration Tests ============= + + #[test] + fn test_prometheus_builder_configuration() { + // This test verifies the builder configuration without actually starting Prometheus + let _config = PrometheusConfig::default(); + + let duration_matcher = Matcher::Suffix(String::from("duration_seconds")); + let duration_bucket = [ + 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0, + 60.0, 90.0, 120.0, 180.0, 240.0, + ]; + + // Verify bucket configuration + assert_eq!(duration_bucket.len(), 20); + + // Verify matcher is suffix type + match duration_matcher { + Matcher::Suffix(s) => assert_eq!(s, "duration_seconds"), + _ => panic!("Expected Suffix matcher"), + } + } + + // ============= Upkeep Timeout Tests ============= + + 
#[test] + fn test_upkeep_timeout_duration() { + let timeout = Duration::from_secs(5 * 60); + assert_eq!(timeout.as_secs(), 300); + } + + // ============= Custom Bucket Tests ============= + + #[test] + fn test_custom_buckets_for_different_metrics() { + // Test that we can create different bucket configurations + let request_buckets = vec![0.001, 0.01, 0.1, 1.0, 10.0]; + let generate_buckets = vec![0.1, 0.5, 1.0, 5.0, 30.0, 60.0]; + + assert_eq!(request_buckets.len(), 5); + assert_eq!(generate_buckets.len(), 6); + + // Verify each set is sorted + for i in 1..request_buckets.len() { + assert!(request_buckets[i] > request_buckets[i - 1]); + } + + for i in 1..generate_buckets.len() { + assert!(generate_buckets[i] > generate_buckets[i - 1]); + } + } + + // ============= RouterMetrics Tests ============= + + #[test] + fn test_metrics_static_methods() { + // Test that all static methods can be called without panic + RouterMetrics::record_request("/generate"); + RouterMetrics::record_request_duration("/generate", Duration::from_millis(100)); + RouterMetrics::record_request_error("/generate", "timeout"); + RouterMetrics::record_retry("/generate"); + + RouterMetrics::set_active_workers(5); + RouterMetrics::set_worker_health("http://worker1", true); + RouterMetrics::set_worker_load("http://worker1", 10); + RouterMetrics::record_processed_request("http://worker1"); + + RouterMetrics::record_policy_decision("random", "http://worker1"); + RouterMetrics::record_cache_hit(); + RouterMetrics::record_cache_miss(); + RouterMetrics::set_tree_size("http://worker1", 1000); + RouterMetrics::record_load_balancing_event(); + RouterMetrics::set_load_range(20, 5); + + RouterMetrics::record_pd_request("/v1/chat/completions"); + RouterMetrics::record_pd_request_duration("/v1/chat/completions", Duration::from_secs(1)); + RouterMetrics::record_pd_prefill_request("http://prefill1"); + RouterMetrics::record_pd_decode_request("http://decode1"); + RouterMetrics::record_pd_error("invalid_request"); + 
RouterMetrics::record_pd_prefill_error("http://prefill1"); + RouterMetrics::record_pd_decode_error("http://decode1"); + RouterMetrics::record_pd_stream_error("http://decode1"); + + RouterMetrics::record_discovery_update(3, 1); + RouterMetrics::record_generate_duration(Duration::from_secs(2)); + RouterMetrics::set_running_requests("http://worker1", 15); + } + + // ============= Port Availability Tests ============= + + #[test] + fn test_port_already_in_use() { + // Skip this test if we can't bind to the port + let port = 29123; // Use a different port to avoid conflicts + + if let Ok(_listener) = TcpListener::bind(("127.0.0.1", port)) { + // Port is available, we can test + let config = PrometheusConfig { + port, + host: "127.0.0.1".to_string(), + }; + + // Just verify config is created correctly + assert_eq!(config.port, port); + } + } + + // ============= Integration Test Helpers ============= + + #[test] + fn test_metrics_endpoint_accessibility() { + // This would be an integration test in practice + // Here we just verify the configuration + let config = PrometheusConfig { + port: 29000, + host: "127.0.0.1".to_string(), + }; + + let ip_addr: IpAddr = config.host.parse().unwrap(); + let socket_addr = SocketAddr::new(ip_addr, config.port); + + assert_eq!(socket_addr.to_string(), "127.0.0.1:29000"); + } + + #[test] + fn test_concurrent_metric_updates() { + // Test that metric updates can be called concurrently + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use std::thread; + + let done = Arc::new(AtomicBool::new(false)); + let mut handles = vec![]; + + for i in 0..3 { + let done_clone = done.clone(); + let handle = thread::spawn(move || { + let worker = format!("http://worker{}", i); + while !done_clone.load(Ordering::Relaxed) { + RouterMetrics::set_worker_load(&worker, i * 10); + RouterMetrics::record_processed_request(&worker); + thread::sleep(Duration::from_millis(1)); + } + }); + handles.push(handle); + } + + // Let threads run briefly 
+ thread::sleep(Duration::from_millis(10)); + done.store(true, Ordering::Relaxed); + + // Wait for all threads + for handle in handles { + handle.join().unwrap(); + } + + // If we get here without panic, concurrent access works + assert!(true); + } + + // ============= Edge Cases Tests ============= + + #[test] + fn test_empty_string_metrics() { + // Test that empty strings don't cause issues + RouterMetrics::record_request(""); + RouterMetrics::set_worker_health("", true); + RouterMetrics::record_policy_decision("", ""); + + // If we get here without panic, empty strings are handled + assert!(true); + } + + #[test] + fn test_very_long_metric_labels() { + let long_label = "a".repeat(1000); + + RouterMetrics::record_request(&long_label); + RouterMetrics::set_worker_health(&long_label, false); + + // If we get here without panic, long labels are handled + assert!(true); + } + + #[test] + fn test_special_characters_in_labels() { + let special_labels = vec![ + "test/with/slashes", + "test-with-dashes", + "test_with_underscores", + "test.with.dots", + "test:with:colons", + ]; + + for label in special_labels { + RouterMetrics::record_request(label); + RouterMetrics::set_worker_health(label, true); + } + + // If we get here without panic, special characters are handled + assert!(true); + } + + #[test] + fn test_extreme_metric_values() { + // Test extreme values + RouterMetrics::set_active_workers(0); + RouterMetrics::set_active_workers(usize::MAX); + + RouterMetrics::set_worker_load("worker", 0); + RouterMetrics::set_worker_load("worker", usize::MAX); + + RouterMetrics::record_request_duration("route", Duration::from_nanos(1)); + RouterMetrics::record_request_duration("route", Duration::from_secs(86400)); // 24 hours + + // If we get here without panic, extreme values are handled + assert!(true); + } +} diff --git a/sgl-router/src/routers/pd_types.rs b/sgl-router/src/routers/pd_types.rs index 155274b06f16..e83ab5b60f5b 100644 --- a/sgl-router/src/routers/pd_types.rs +++ 
b/sgl-router/src/routers/pd_types.rs @@ -58,7 +58,7 @@ pub enum PDSelectionPolicy { }, } // Bootstrap types from PDLB -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Deserialize, Serialize, PartialEq)] #[serde(untagged)] pub enum SingleOrBatch { Single(T), diff --git a/sgl-router/src/routers/request_adapter.rs b/sgl-router/src/routers/request_adapter.rs index f5611bbe492b..201c61aa55c8 100644 --- a/sgl-router/src/routers/request_adapter.rs +++ b/sgl-router/src/routers/request_adapter.rs @@ -211,6 +211,7 @@ impl ToPdRequest for ChatCompletionRequest { self.temperature => "temperature", self.top_p => "top_p", self.n => "n", + self.stream_options => "stream_options", self.stop => "stop", self.max_tokens => "max_tokens", self.max_completion_tokens => "max_completion_tokens", @@ -262,3 +263,1015 @@ pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone { impl RouteableRequest for GenerateRequest {} impl RouteableRequest for CompletionRequest {} impl RouteableRequest for ChatCompletionRequest {} + +#[cfg(test)] +mod tests { + use super::*; + use crate::openai_api_types::*; + use serde_json::json; + use std::collections::HashMap; + + // ============= GenerateRequest to_pd_request Tests ============= + + #[test] + fn test_generate_to_pd_request_with_text_only() { + let req = GenerateRequest { + text: Some("Hello world".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + // Check text field conversion + assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Hello world")); + assert!(pd_req.input_ids.is_none()); + + // Check bootstrap fields are None + assert!(pd_req.bootstrap_host.is_none()); + assert!(pd_req.bootstrap_port.is_none()); + assert!(pd_req.bootstrap_room.is_none()); + + // Check stream flag + assert_eq!(pd_req.stream, false); + + // Check other fields + let other = 
pd_req.other.as_object().unwrap(); + assert_eq!(other.get("stream"), Some(&json!(false))); + assert_eq!(other.get("return_logprob"), Some(&json!(false))); + } + + #[test] + fn test_generate_to_pd_request_with_prompt_string() { + let req = GenerateRequest { + text: None, + prompt: Some(StringOrArray::String("Test prompt".to_string())), + input_ids: None, + stream: true, + parameters: None, + sampling_params: None, + return_logprob: true, + }; + + let pd_req = req.to_pd_request(); + + assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Test prompt")); + assert!(pd_req.input_ids.is_none()); + assert_eq!(pd_req.stream, true); + + let other = pd_req.other.as_object().unwrap(); + assert_eq!(other.get("stream"), Some(&json!(true))); + assert_eq!(other.get("return_logprob"), Some(&json!(true))); + } + + #[test] + fn test_generate_to_pd_request_with_prompt_array() { + let req = GenerateRequest { + text: None, + prompt: Some(StringOrArray::Array(vec![ + "Prompt 1".to_string(), + "Prompt 2".to_string(), + "Prompt 3".to_string(), + ])), + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + match pd_req.text { + Some(SingleOrBatch::Batch(ref batch)) => { + assert_eq!(batch.len(), 3); + assert_eq!(batch[0], "Prompt 1"); + assert_eq!(batch[1], "Prompt 2"); + assert_eq!(batch[2], "Prompt 3"); + } + _ => panic!("Expected batch text"), + } + } + + #[test] + fn test_generate_to_pd_request_with_single_input_ids() { + let req = GenerateRequest { + text: None, + prompt: None, + input_ids: Some(InputIds::Single(vec![100, 200, 300, 400])), + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + assert!(pd_req.text.is_none()); + assert!(matches!( + pd_req.input_ids, + Some(SingleOrBatch::Single(ref ids)) if ids == &vec![100, 200, 300, 400] + )); + } + + #[test] + fn 
test_generate_to_pd_request_with_batch_input_ids() { + let req = GenerateRequest { + text: None, + prompt: None, + input_ids: Some(InputIds::Batch(vec![ + vec![1, 2, 3], + vec![4, 5, 6, 7], + vec![8, 9], + ])), + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + match pd_req.input_ids { + Some(SingleOrBatch::Batch(ref batch)) => { + assert_eq!(batch.len(), 3); + assert_eq!(batch[0], vec![1, 2, 3]); + assert_eq!(batch[1], vec![4, 5, 6, 7]); + assert_eq!(batch[2], vec![8, 9]); + } + _ => panic!("Expected batch input_ids"), + } + } + + #[test] + fn test_generate_to_pd_request_priority_text_over_prompt() { + let req = GenerateRequest { + text: Some("SGLang text".to_string()), + prompt: Some(StringOrArray::String("OpenAI prompt".to_string())), + input_ids: Some(InputIds::Single(vec![1, 2, 3])), + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + // text should take priority + assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "SGLang text")); + assert!(pd_req.input_ids.is_none()); + } + + #[test] + fn test_generate_to_pd_request_priority_prompt_over_input_ids() { + let req = GenerateRequest { + text: None, + prompt: Some(StringOrArray::String("OpenAI prompt".to_string())), + input_ids: Some(InputIds::Single(vec![1, 2, 3])), + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + // prompt should take priority over input_ids + assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "OpenAI prompt")); + assert!(pd_req.input_ids.is_none()); + } + + #[test] + fn test_generate_to_pd_request_with_parameters() { + let params = GenerateParameters { + max_new_tokens: Some(100), + temperature: Some(0.8), + top_p: Some(0.95), + seed: Some(12345), + stop: Some(vec!["END".to_string(), "STOP".to_string()]), 
+ repetition_penalty: Some(1.1), + ..Default::default() + }; + + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: Some(params), + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Check that max_new_tokens and temperature were extracted to top level + assert_eq!(other.get("max_new_tokens"), Some(&json!(100))); + assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.8 < 0.0001); + + // Check that other parameters remain under "parameters" + let params = other.get("parameters").unwrap().as_object().unwrap(); + assert!(params.get("top_p").unwrap().as_f64().unwrap() - 0.95 < 0.0001); + assert_eq!(params.get("seed"), Some(&json!(12345))); + assert_eq!(params.get("stop"), Some(&json!(vec!["END", "STOP"]))); + assert!(params.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1 < 0.0001); + } + + #[test] + fn test_generate_to_pd_request_with_sampling_params() { + let sampling = SamplingParams { + max_new_tokens: Some(200), + temperature: Some(0.7), + top_p: Some(0.9), + top_k: Some(50), + frequency_penalty: Some(0.1), + presence_penalty: Some(0.2), + repetition_penalty: Some(1.05), + ..Default::default() + }; + + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: None, + sampling_params: Some(sampling), + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Check extracted top-level fields + assert_eq!(other.get("max_new_tokens"), Some(&json!(200))); + assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.7 < 0.0001); + + // Check full sampling_params is preserved + let sampling = other.get("sampling_params").unwrap().as_object().unwrap(); + assert_eq!(sampling.get("max_new_tokens"), Some(&json!(200))); + 
assert!(sampling.get("temperature").unwrap().as_f64().unwrap() - 0.7 < 0.0001); + assert!(sampling.get("top_p").unwrap().as_f64().unwrap() - 0.9 < 0.0001); + assert_eq!(sampling.get("top_k"), Some(&json!(50))); + assert!(sampling.get("frequency_penalty").unwrap().as_f64().unwrap() - 0.1 < 0.0001); + assert!(sampling.get("presence_penalty").unwrap().as_f64().unwrap() - 0.2 < 0.0001); + } + + #[test] + fn test_generate_to_pd_request_sampling_params_override_parameters() { + // When both parameters and sampling_params have max_new_tokens/temperature, + // sampling_params should take precedence (processed last) + let params = GenerateParameters { + max_new_tokens: Some(100), + temperature: Some(0.5), + ..Default::default() + }; + + let sampling = SamplingParams { + max_new_tokens: Some(200), + temperature: Some(0.9), + ..Default::default() + }; + + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: Some(params), + sampling_params: Some(sampling), + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Should use values from sampling_params since they're processed last + assert_eq!(other.get("max_new_tokens"), Some(&json!(200))); + assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.9 < 0.0001); + } + + #[test] + fn test_generate_to_pd_request_empty_parameters() { + let params = GenerateParameters::default(); + + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: Some(params), + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Should not have parameters field if all values are None/default + assert!(!other.contains_key("parameters")); + assert!(!other.contains_key("max_new_tokens")); + assert!(!other.contains_key("temperature")); + } + + 
#[test] + fn test_generate_to_pd_request_all_fields() { + let params = GenerateParameters { + max_new_tokens: Some(150), + temperature: Some(0.6), + top_k: Some(40), + ..Default::default() + }; + + let sampling = SamplingParams { + max_new_tokens: Some(250), // Will override parameters + temperature: Some(0.8), // Will override parameters + presence_penalty: Some(0.1), + ..Default::default() + }; + + let req = GenerateRequest { + text: Some("Complex test".to_string()), + prompt: Some(StringOrArray::String("Ignored prompt".to_string())), + input_ids: None, + stream: true, + parameters: Some(params), + sampling_params: Some(sampling), + return_logprob: true, + }; + + let pd_req = req.to_pd_request(); + + // Verify all fields + assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Complex test")); + assert!(pd_req.input_ids.is_none()); + assert_eq!(pd_req.stream, true); + assert!(pd_req.bootstrap_host.is_none()); + assert!(pd_req.bootstrap_port.is_none()); + assert!(pd_req.bootstrap_room.is_none()); + + let other = pd_req.other.as_object().unwrap(); + assert_eq!(other.get("stream"), Some(&json!(true))); + assert_eq!(other.get("return_logprob"), Some(&json!(true))); + // Sampling params override parameters + assert_eq!(other.get("max_new_tokens"), Some(&json!(250))); + assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.8 < 0.0001); + assert!(other.contains_key("parameters")); + assert!(other.contains_key("sampling_params")); + } + + // ============= CompletionRequest to_pd_request Tests ============= + + #[test] + fn test_completion_to_pd_request_basic() { + let req = CompletionRequest { + model: "gpt-3.5-turbo".to_string(), + prompt: StringOrArray::String("Complete this sentence".to_string()), + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + 
user: None, + seed: None, + suffix: None, + }; + + let pd_req = req.to_pd_request(); + + assert!( + matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Complete this sentence") + ); + assert!(pd_req.input_ids.is_none()); + assert_eq!(pd_req.stream, false); + + let other = pd_req.other.as_object().unwrap(); + assert_eq!(other.get("model"), Some(&json!("gpt-3.5-turbo"))); + assert_eq!(other.get("stream"), Some(&json!(false))); + } + + #[test] + fn test_completion_to_pd_request_array_prompt() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::Array(vec![ + "First prompt".to_string(), + "Second prompt".to_string(), + ]), + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + suffix: None, + }; + + let pd_req = req.to_pd_request(); + + match pd_req.text { + Some(SingleOrBatch::Batch(ref batch)) => { + assert_eq!(batch.len(), 2); + assert_eq!(batch[0], "First prompt"); + assert_eq!(batch[1], "Second prompt"); + } + _ => panic!("Expected batch text"), + } + } + + #[test] + fn test_completion_to_pd_request_parameter_mapping() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::String("test".to_string()), + max_tokens: Some(150), // -> max_new_tokens + temperature: Some(0.75), + top_p: Some(0.92), + n: Some(3), // -> best_of + stream: true, + stream_options: None, + logprobs: Some(10), // -> top_n_tokens + echo: true, // -> return_full_text + stop: Some(StringOrArray::Array(vec![ + "\\n".to_string(), + "END".to_string(), + ])), + presence_penalty: Some(0.5), // -> repetition_penalty = 1.5 + frequency_penalty: Some(0.2), + best_of: Some(5), + logit_bias: None, + user: Some("user123".to_string()), + seed: Some(42), + suffix: Some("...".to_string()), + }; + + let pd_req = 
req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + let params = other.get("parameters").unwrap().as_object().unwrap(); + + // Check parameter mappings + assert_eq!(params.get("max_new_tokens"), Some(&json!(150))); + assert!(params.get("temperature").unwrap().as_f64().unwrap() - 0.75 < 0.0001); + assert!(params.get("top_p").unwrap().as_f64().unwrap() - 0.92 < 0.0001); + assert_eq!(params.get("best_of"), Some(&json!(3))); + assert_eq!(params.get("top_n_tokens"), Some(&json!(10))); + assert_eq!(params.get("return_full_text"), Some(&json!(true))); + assert_eq!(params.get("stop"), Some(&json!(vec!["\\n", "END"]))); + assert!(params.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.5 < 0.0001); + assert_eq!(params.get("seed"), Some(&json!(42))); + + // Check other fields + assert_eq!(other.get("model"), Some(&json!("test"))); + assert_eq!(other.get("stream"), Some(&json!(true))); + } + + #[test] + fn test_completion_to_pd_request_stop_string() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::String("test".to_string()), + stop: Some(StringOrArray::String("STOP".to_string())), + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + suffix: None, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + let params = other.get("parameters").unwrap().as_object().unwrap(); + + // Single string stop should be converted to array + assert_eq!(params.get("stop"), Some(&json!(vec!["STOP"]))); + } + + #[test] + fn test_completion_to_pd_request_no_presence_penalty() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::String("test".to_string()), + presence_penalty: None, + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: 
false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + suffix: None, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + let params = other.get("parameters").unwrap().as_object().unwrap(); + + // Should not have repetition_penalty if presence_penalty is None + assert!(!params.contains_key("repetition_penalty")); + } + + // ============= ChatCompletionRequest to_pd_request Tests ============= + + #[test] + fn test_chat_to_pd_request_basic() { + let messages = vec![ + ChatMessage::System { + role: "system".to_string(), + content: "You are a helpful assistant".to_string(), + name: None, + }, + ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("Hello!".to_string()), + name: None, + }, + ]; + + let req = ChatCompletionRequest { + messages, + model: "gpt-4".to_string(), + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + logprobs: false, + top_logprobs: None, + user: None, + seed: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + functions: None, + function_call: None, + }; + + let pd_req = req.to_pd_request(); + + assert_eq!(pd_req.stream, false); + assert!(pd_req.bootstrap_host.is_none()); + assert!(pd_req.bootstrap_port.is_none()); + assert!(pd_req.bootstrap_room.is_none()); + + let other = pd_req.other.as_object().unwrap(); + assert!(other.contains_key("messages")); + assert_eq!(other.get("model"), Some(&json!("gpt-4"))); + assert_eq!(other.get("stream"), Some(&json!(false))); + + // Check messages are preserved + let messages = other.get("messages").unwrap().as_array().unwrap(); + assert_eq!(messages.len(), 2); + } + + #[test] + fn 
test_chat_to_pd_request_with_all_optional_fields() { + let messages = vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("Test".to_string()), + name: Some("test_user".to_string()), + }]; + + let mut logit_bias = HashMap::new(); + logit_bias.insert("50256".to_string(), -100); + + let tool = Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get weather info".to_string()), + parameters: json!({"type": "object"}), + }, + }; + + let req = ChatCompletionRequest { + messages, + model: "gpt-4".to_string(), + temperature: Some(0.8), + top_p: Some(0.95), + n: Some(2), + stream: true, + stream_options: Some(StreamOptions { + include_usage: Some(true), + }), + stop: Some(StringOrArray::String("\\n\\n".to_string())), + max_tokens: Some(200), + max_completion_tokens: Some(150), + presence_penalty: Some(0.1), + frequency_penalty: Some(0.2), + logit_bias: Some(logit_bias), + logprobs: true, + top_logprobs: Some(5), + user: Some("user456".to_string()), + seed: Some(12345), + response_format: Some(ResponseFormat::JsonObject), + tools: Some(vec![tool]), + tool_choice: Some(ToolChoice::Auto), + parallel_tool_calls: Some(false), + functions: None, + function_call: None, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Check all fields are preserved + assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.8 < 0.0001); + assert!(other.get("top_p").unwrap().as_f64().unwrap() - 0.95 < 0.0001); + assert_eq!(other.get("n"), Some(&json!(2))); + assert_eq!(other.get("stream"), Some(&json!(true))); + assert!(other.contains_key("stream_options")); + assert!(other.contains_key("stop")); + assert_eq!(other.get("max_tokens"), Some(&json!(200))); + assert_eq!(other.get("max_completion_tokens"), Some(&json!(150))); + assert!(other.get("presence_penalty").unwrap().as_f64().unwrap() - 0.1 < 0.0001); + 
assert!(other.get("frequency_penalty").unwrap().as_f64().unwrap() - 0.2 < 0.0001); + assert!(other.contains_key("logit_bias")); + assert_eq!(other.get("logprobs"), Some(&json!(true))); + assert_eq!(other.get("top_logprobs"), Some(&json!(5))); + assert_eq!(other.get("user"), Some(&json!("user456"))); + assert_eq!(other.get("seed"), Some(&json!(12345))); + assert!(other.contains_key("response_format")); + assert!(other.contains_key("tools")); + assert!(other.contains_key("tool_choice")); + assert_eq!(other.get("parallel_tool_calls"), Some(&json!(false))); + } + + #[test] + fn test_chat_to_pd_request_multimodal_content() { + let messages = vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Parts(vec![ + ContentPart::Text { + text: "What's in this image?".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://example.com/image.jpg".to_string(), + detail: Some("high".to_string()), + }, + }, + ]), + name: None, + }]; + + let req = ChatCompletionRequest { + messages, + model: "gpt-4-vision".to_string(), + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + logprobs: false, + top_logprobs: None, + user: None, + seed: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + functions: None, + function_call: None, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Messages with multimodal content should be preserved + assert!(other.contains_key("messages")); + let messages = other.get("messages").unwrap().as_array().unwrap(); + assert_eq!(messages.len(), 1); + + // Verify the message structure is preserved + let msg = &messages[0]; + assert_eq!(msg["role"], "user"); + assert!(msg["content"].is_array()); + } + + #[test] + fn 
test_chat_to_pd_request_logprobs_boolean() { + let messages = vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("Test".to_string()), + name: None, + }]; + + let req = ChatCompletionRequest { + messages, + model: "test".to_string(), + logprobs: true, // Boolean logprobs flag + top_logprobs: Some(3), + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + seed: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + functions: None, + function_call: None, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + assert_eq!(other.get("logprobs"), Some(&json!(true))); + assert_eq!(other.get("top_logprobs"), Some(&json!(3))); + } + + #[test] + fn test_chat_to_pd_request_minimal_fields() { + let messages = vec![ChatMessage::Assistant { + role: "assistant".to_string(), + content: Some("I can help with that.".to_string()), + name: None, + tool_calls: None, + function_call: None, + }]; + + let req = ChatCompletionRequest { + messages, + model: "gpt-3.5-turbo".to_string(), + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + logprobs: false, + top_logprobs: None, + user: None, + seed: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + functions: None, + function_call: None, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Should only have required fields + assert!(other.contains_key("messages")); + assert!(other.contains_key("model")); + assert!(other.contains_key("stream")); + + // Optional fields should not 
be present + assert!(!other.contains_key("temperature")); + assert!(!other.contains_key("top_p")); + assert!(!other.contains_key("max_tokens")); + assert!(!other.contains_key("stop")); + } + + #[test] + fn test_routeable_request_to_json() { + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let json = req.to_json().unwrap(); + assert_eq!(json["text"], "test"); + assert_eq!(json["stream"], false); + } + + // ============= Macro Tests ============= + + #[test] + fn test_insert_if_some_macro() { + let mut map = serde_json::Map::new(); + + let some_value: Option = Some(42); + let none_value: Option = None; + + insert_if_some!(map, + some_value => "present", + none_value => "absent" + ); + + assert_eq!(map.get("present"), Some(&json!(42))); + assert!(!map.contains_key("absent")); + } + + #[test] + fn test_insert_value_macro() { + let mut map = serde_json::Map::new(); + + let value1 = "test"; + let value2 = 42; + + insert_value!(map, + value1 => "string_field", + value2 => "int_field" + ); + + assert_eq!(map.get("string_field"), Some(&json!("test"))); + assert_eq!(map.get("int_field"), Some(&json!(42))); + } + + // ============= Edge Cases and Error Handling ============= + + #[test] + fn test_null_value_handling() { + let params = GenerateParameters { + max_new_tokens: None, + temperature: None, + ..Default::default() + }; + + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: Some(params), + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Should not have parameters field if all fields are None + assert!(!other.contains_key("parameters")); + } + + #[test] + fn test_large_batch_conversion() { + let large_batch: Vec = (0..1000).map(|i| format!("item_{}", 
i)).collect(); + + let req = GenerateRequest { + text: None, + prompt: Some(StringOrArray::Array(large_batch.clone())), + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + if let Some(SingleOrBatch::Batch(batch)) = pd_req.text { + assert_eq!(batch.len(), 1000); + assert_eq!(batch[0], "item_0"); + assert_eq!(batch[999], "item_999"); + } else { + panic!("Expected batch text"); + } + } + + #[test] + fn test_unicode_string_handling() { + let unicode_text = "Hello 世界 🌍 नमस्ते мир".to_string(); + + let req = GenerateRequest { + text: Some(unicode_text.clone()), + prompt: None, + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + + if let Some(SingleOrBatch::Single(text)) = pd_req.text { + assert_eq!(text, unicode_text); + } else { + panic!("Expected single text"); + } + } + + #[test] + fn test_deeply_nested_parameters() { + let mut nested_params = serde_json::Map::new(); + nested_params.insert( + "nested".to_string(), + json!({ + "level1": { + "level2": { + "level3": "value" + } + } + }), + ); + + let params = GenerateParameters { + max_new_tokens: Some(100), + ..Default::default() + }; + + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: Some(params), + sampling_params: None, + return_logprob: false, + }; + + let pd_req = req.to_pd_request(); + let other = pd_req.other.as_object().unwrap(); + + // Parameters should be preserved even with nested structures + assert!(other.contains_key("max_new_tokens")); + } + + // ============= Bootstrap Field Tests ============= + + #[test] + fn test_bootstrap_fields_none() { + let req = GenerateRequest { + text: Some("test".to_string()), + prompt: None, + input_ids: None, + stream: false, + parameters: None, + sampling_params: None, + return_logprob: false, + }; 
+ + let pd_req = req.to_pd_request(); + + assert_eq!(pd_req.bootstrap_host, None); + assert_eq!(pd_req.bootstrap_port, None); + assert_eq!(pd_req.bootstrap_room, None); + } +} From 93d124ef5a4b71a11b409150c85e70d4a0256bab Mon Sep 17 00:00:00 2001 From: ronnie_zheng Date: Sun, 20 Jul 2025 23:12:42 +0300 Subject: [PATCH 075/396] [feature] enable NPU CI (#7935) Co-authored-by: Even Zhou <14368888+iforgetmyname@users.noreply.github.com> --- .github/workflows/pr-test-npu.yml | 64 +++++++++++++++++++++++ .pre-commit-config.yaml | 2 +- scripts/npu_ci_install_dependency.sh | 46 ++++++++++++++++ test/srt/test_ascend_attention_backend.py | 16 +----- 4 files changed, 113 insertions(+), 15 deletions(-) create mode 100644 .github/workflows/pr-test-npu.yml create mode 100755 scripts/npu_ci_install_dependency.sh diff --git a/.github/workflows/pr-test-npu.yml b/.github/workflows/pr-test-npu.yml new file mode 100644 index 000000000000..be924d7bbc09 --- /dev/null +++ b/.github/workflows/pr-test-npu.yml @@ -0,0 +1,64 @@ +name: PR Test (Ascend NPU) + +on: + push: + branches: [ main ] + paths: + - "python/**" + - "scripts/**" + - "test/**" + - ".github/workflows/pr-test-npu.yml" + pull_request: + branches: [ main ] + paths: + - "python/**" + - "scripts/**" + - "test/**" + - ".github/workflows/pr-test-npu.yml" + workflow_dispatch: + +concurrency: + group: pr-test-npu-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit-test-basic: + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && + github.event.pull_request.draft == false + runs-on: linux-arm64-npu-1 + container: + image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.2.rc1.alpha003-910b-ubuntu22.04-py3.11 + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + run: | + bash scripts/npu_ci_install_dependency.sh + # copy required dataset file from our daily cache + cp 
~/.cache/modelscope/hub/datasets/otavia/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json /tmp + + - name: Run test + timeout-minutes: 30 + env: + SGLANG_USE_MODELSCOPE: true + HF_ENDPOINT: https://hf-mirror.com + run: | + cd test/srt + python3 run_suite.py --suite per-commit-npu + finish: + if: always() + needs: [ unit-test-basic ] + runs-on: ubuntu-latest + steps: + - name: Check all dependent job statuses + run: | + results=(${{ join(needs.*.result, ' ') }}) + for result in "${results[@]}"; do + if [ "$result" = "failure" ] || [ "$result" = "cancelled" ]; then + echo "Job failed with result: $result" + exit 1 + fi + done + echo "All jobs completed successfully" + exit 0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 89d4664c5715..e9e9af1d0a02 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: hooks: - id: codespell additional_dependencies: ['tomli'] - args: ['--toml', 'python/pyproject.toml'] + args: ['--toml', 'python/pyproject.toml', '-L', 'cann'] exclude: test/srt/test_reasoning_parser.py # Exclude the test file that is expected to fail - repo: https://github.com/pre-commit/mirrors-clang-format rev: v18.1.8 diff --git a/scripts/npu_ci_install_dependency.sh b/scripts/npu_ci_install_dependency.sh new file mode 100755 index 000000000000..ec3a162d52a4 --- /dev/null +++ b/scripts/npu_ci_install_dependency.sh @@ -0,0 +1,46 @@ +#!/bin/bash +set -euo pipefail + +# Install the required dependencies in CI. 
+sed -i 's|ports.ubuntu.com|mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list +apt update -y +apt install -y build-essential cmake python3-pip python3-dev wget net-tools zlib1g-dev lld clang software-properties-common + + +pip config set global.index-url https://mirrors.huaweicloud.com/repository/pypi/simple +python3 -m pip install --upgrade pip +pip uninstall sgl-kernel -y || true + + +### Download MemFabricV2 +MF_WHL_NAME="mf_adapter-1.0.0-cp311-cp311-linux_aarch64.whl" +MEMFABRIC_URL="https://sglang-ascend.obs.cn-east-3.myhuaweicloud.com:443/sglang/${MF_WHL_NAME}" +wget "${MEMFABRIC_URL}" && pip install "./${MF_WHL_NAME}" + + +### Install vLLM +VLLM_TAG=v0.8.5 +git clone --depth 1 https://github.com/vllm-project/vllm.git --branch $VLLM_TAG +(cd vllm && VLLM_TARGET_DEVICE="empty" pip install -v -e .) + + +### Install PyTorch and PTA +PYTORCH_VERSION=2.6.0 +TORCHVISION_VERSION=0.21.0 +PTA_VERSION=2.6.0rc1 +pip install torch==$PYTORCH_VERSION torchvision==$TORCHVISION_VERSION --index-url https://download.pytorch.org/whl/cpu +pip install torch_npu==$PTA_VERSION + + +### Install Triton-Ascend +TRITON_ASCEND_VERSION=3.2.0rc2 +pip install attrs==24.2.0 numpy==1.26.4 scipy==1.13.1 decorator==5.1.1 psutil==6.0.0 pytest==8.3.2 pytest-xdist==3.6.1 pyyaml pybind11 +pip install triton-ascend==$TRITON_ASCEND_VERSION + + +pip install -e "python[srt_npu]" + + +### Modify PyTorch TODO: to be removed later +TORCH_LOCATION=$(python3 -c 'import torch; print(torch.__path__[0])') +sed -i 's/from triton.runtime.autotuner import OutOfResources/from triton.runtime.errors import OutOfResources/' "${TORCH_LOCATION}/_inductor/runtime/triton_heuristics.py" diff --git a/test/srt/test_ascend_attention_backend.py b/test/srt/test_ascend_attention_backend.py index 4ca6bba8f3dc..e406fee3c070 100644 --- a/test/srt/test_ascend_attention_backend.py +++ b/test/srt/test_ascend_attention_backend.py @@ -20,22 +20,10 @@ run_bench_offline_throughput, ) +DEFAULT_MODEL_NAME_FOR_TEST = 
"Qwen/Qwen2.5-7B-Instruct" -class TestAscendAttnBackend(CustomTestCase): - def test_latency(self): - output_throughput = run_bench_offline_throughput( - DEFAULT_MODEL_NAME_FOR_TEST, - [ - "--attention-backend", - "ascend", - ], - ) - - print(f"{output_throughput=}") - - if is_in_ci(): - self.assertGreater(output_throughput, 18) +class TestAscendAttnBackend(CustomTestCase): def test_gsm8k(self): model = DEFAULT_MODEL_NAME_FOR_TEST base_url = DEFAULT_URL_FOR_TEST From 7eebd4404764dd778e18cc0fc4866d97504271f0 Mon Sep 17 00:00:00 2001 From: JieXin Liang Date: Mon, 21 Jul 2025 08:39:57 +0800 Subject: [PATCH 076/396] [fix] fix modelopt fp4 on b200 (#8195) --- python/sglang/srt/layers/quantization/petit.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/petit.py b/python/sglang/srt/layers/quantization/petit.py index e7ee3239f64c..2c608507c9c2 100644 --- a/python/sglang/srt/layers/quantization/petit.py +++ b/python/sglang/srt/layers/quantization/petit.py @@ -21,6 +21,9 @@ verify_petit_nvfp4_supported, ) from sglang.srt.layers.quantization.utils import is_layer_skipped +from sglang.srt.utils import is_hip + +_is_hip = is_hip() # Initialize logger for the module logger = logging.getLogger(__name__) @@ -104,7 +107,7 @@ def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str] @classmethod def is_petit_nvfp4_compatible(cls, quant_config: Dict[str, Any]) -> bool: quant_method = quant_config.get("quant_method", "").lower() - return quant_method == "modelopt" + return _is_hip and quant_method == "modelopt" def is_layer_excluded(self, prefix: str, exclude_modules: list): for pattern in exclude_modules: From 429bb0efa2032c6f2826b97477f44f5326ba0a22 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sun, 20 Jul 2025 19:50:28 -0700 Subject: [PATCH 077/396] chore: bump sgl-kernel v0.2.6.post1 (#8200) --- docker/Dockerfile | 2 +- sgl-kernel/pyproject.toml | 2 +- sgl-kernel/pyproject_cpu.toml | 2 +- 
sgl-kernel/pyproject_rocm.toml | 2 +- sgl-kernel/python/sgl_kernel/version.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 97be3625af7c..1e5f21c9d5f5 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -60,7 +60,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel html5li && python3 -m pip install --no-cache-dir -e "python[${BUILD_TYPE}]" --extra-index-url https://download.pytorch.org/whl/cu${CUINDEX} \ && if [ "$CUDA_VERSION" = "12.8.1" ]; then \ python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.6 --force-reinstall --no-deps ; \ - python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.2.6/sgl_kernel-0.2.6+cu128-cp39-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \ + python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.2.6.post1/sgl_kernel-0.2.6.post1+cu128-cp39-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \ fi # Build and install NVSHMEM + DeepEP diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 4d8ff394df4d..3b49eab5d9a8 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "sgl-kernel" -version = "0.2.6" +version = "0.2.6.post1" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/pyproject_cpu.toml b/sgl-kernel/pyproject_cpu.toml index c243596515bd..6746b212d364 100644 --- a/sgl-kernel/pyproject_cpu.toml +++ b/sgl-kernel/pyproject_cpu.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "sgl-kernel" -version = "0.2.6" +version = "0.2.6.post1" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/pyproject_rocm.toml b/sgl-kernel/pyproject_rocm.toml index 6ab48599c5cf..0ba8b0399bff 100644 --- 
a/sgl-kernel/pyproject_rocm.toml +++ b/sgl-kernel/pyproject_rocm.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.2.6" +version = "0.2.6.post1" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/python/sgl_kernel/version.py b/sgl-kernel/python/sgl_kernel/version.py index 01ef12070dc3..e39bc3f224a0 100644 --- a/sgl-kernel/python/sgl_kernel/version.py +++ b/sgl-kernel/python/sgl_kernel/version.py @@ -1 +1 @@ -__version__ = "0.2.6" +__version__ = "0.2.6.post1" From c9e8613c9708afb1138f3ecef30517fb606a07a7 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Mon, 21 Jul 2025 11:19:48 +0800 Subject: [PATCH 078/396] Apply fused sorted token ids padding (#8193) --- .../sglang/srt/layers/moe/fused_moe_triton/fused_moe.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index a39d6d5d3da4..2466067461cf 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -752,14 +752,13 @@ def moe_align_block_size( sorted_ids = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device ) - sorted_ids.fill_(topk_ids.numel()) - max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) expert_ids = torch.empty( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) if enable_moe_align_block_size_triton: + sorted_ids.fill_(topk_ids.numel()) moe_align_block_size_triton( topk_ids, num_experts, @@ -778,6 +777,11 @@ def moe_align_block_size( device=topk_ids.device, ) + # Threshold based on benchmark results + fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096 + if not fuse_sorted_ids_padding: + sorted_ids.fill_(topk_ids.numel()) + 
sgl_moe_align_block_size( topk_ids, num_experts, @@ -787,6 +791,7 @@ def moe_align_block_size( num_tokens_post_pad, token_cnts_buffer, cumsum_buffer, + fuse_sorted_ids_padding, ) return sorted_ids, expert_ids, num_tokens_post_pad From 8430bfe3e9ae7591feeca6c102e3b21984934a61 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Date: Sun, 20 Jul 2025 21:43:09 -0700 Subject: [PATCH 079/396] [Refactor] simplify multimodal data processing (#8107) Signed-off-by: Xinyuan Tong --- docs/backend/vlm_query.ipynb | 4 +- python/sglang/srt/configs/deepseekvl2.py | 13 ++- python/sglang/srt/configs/janus_pro.py | 3 + python/sglang/srt/managers/mm_utils.py | 14 +-- .../multimodal_processors/qwen_audio.py | 94 ------------------- python/sglang/srt/managers/schedule_batch.py | 72 ++++++-------- python/sglang/srt/models/deepseek_vl2.py | 6 +- python/sglang/srt/models/mllama4.py | 4 +- python/sglang/srt/models/phi4mm.py | 9 +- .../multimodal/processors/base_processor.py | 81 ++++------------ .../sglang/srt/multimodal/processors/clip.py | 40 ++++---- .../multimodal/processors/deepseek_vl_v2.py | 34 ++----- .../srt/multimodal/processors/gemma3.py | 2 - .../srt/multimodal/processors/gemma3n.py | 2 - .../srt/multimodal/processors/internvl.py | 3 +- .../srt/multimodal/processors/janus_pro.py | 39 +++----- .../srt/multimodal/processors/kimi_vl.py | 2 - .../sglang/srt/multimodal/processors/llava.py | 4 +- .../srt/multimodal/processors/minicpm.py | 61 ++++++------ .../sglang/srt/multimodal/processors/mlama.py | 39 ++++---- .../srt/multimodal/processors/mllama4.py | 3 +- .../srt/multimodal/processors/phi4mm.py | 13 ++- .../srt/multimodal/processors/pixtral.py | 47 +++------- .../srt/multimodal/processors/qwen_audio.py | 65 +++++++++++++ .../srt/multimodal/processors/qwen_vl.py | 2 - .../sglang/srt/multimodal/processors/vila.py | 2 - test/srt/test_vision_openai_server_a.py | 33 +++---- test/srt/test_vision_openai_server_b.py | 1 + 
test/srt/test_vision_openai_server_common.py | 22 ++++- test/srt/test_vlm_input_format.py | 10 +- 30 files changed, 300 insertions(+), 424 deletions(-) delete mode 100644 python/sglang/srt/managers/multimodal_processors/qwen_audio.py create mode 100644 python/sglang/srt/multimodal/processors/qwen_audio.py diff --git a/docs/backend/vlm_query.ipynb b/docs/backend/vlm_query.ipynb index b47d55580bc3..3f03a5671626 100644 --- a/docs/backend/vlm_query.ipynb +++ b/docs/backend/vlm_query.ipynb @@ -126,14 +126,14 @@ " images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n", ")\n", "input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n", - "precomputed_features = vision(\n", + "precomputed_embeddings = vision(\n", " processed_prompt[\"pixel_values\"].cuda(), processed_prompt[\"image_grid_thw\"].cuda()\n", ")\n", "\n", "mm_item = dict(\n", " modality=\"IMAGE\",\n", " image_grid_thw=processed_prompt[\"image_grid_thw\"],\n", - " precomputed_features=precomputed_features,\n", + " precomputed_embeddings=precomputed_embeddings,\n", ")\n", "out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n", "print(out[\"text\"])" diff --git a/python/sglang/srt/configs/deepseekvl2.py b/python/sglang/srt/configs/deepseekvl2.py index 29fc49696fbd..bcb0afe5ae74 100644 --- a/python/sglang/srt/configs/deepseekvl2.py +++ b/python/sglang/srt/configs/deepseekvl2.py @@ -42,6 +42,9 @@ def select_best_resolution(image_size, candidate_resolutions): class DictOutput(object): + def items(self): + return self.__dict__.items() + def keys(self): return self.__dict__.keys() @@ -59,7 +62,9 @@ def __setitem__(self, key, value): class VLChatProcessorOutput(DictOutput): input_ids: torch.LongTensor target_ids: torch.LongTensor - images: torch.Tensor + pixel_values: ( + torch.Tensor + ) # rename from "images" to "pixel_values" for compatibility images_seq_mask: torch.BoolTensor images_spatial_crop: torch.LongTensor @@ -312,10 +317,14 @@ def process_one( images = 
torch.stack(images_list, dim=0) images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) + images_spatial_crop = torch.stack( + [images_spatial_crop], dim=0 + ) # stack the tensor to make it a batch of 1 + prepare = VLChatProcessorOutput( input_ids=input_ids, target_ids=target_ids, - images=images, + pixel_values=images, images_seq_mask=images_seq_mask, images_spatial_crop=images_spatial_crop, ) diff --git a/python/sglang/srt/configs/janus_pro.py b/python/sglang/srt/configs/janus_pro.py index 143ebf578836..d574953e95d9 100644 --- a/python/sglang/srt/configs/janus_pro.py +++ b/python/sglang/srt/configs/janus_pro.py @@ -284,6 +284,9 @@ def default_shape(self): class DictOutput(object): + def items(self): + return self.__dict__.items() + def keys(self): return self.__dict__.keys() diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index d36d5d1d968c..f3faa75d9a07 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -221,17 +221,17 @@ def _get_precomputed_embedding( items: List[MultimodalDataItem], ) -> Optional[torch.Tensor]: """ - If all items have precomputed_features, return their concatenation. - If some but not all have precomputed_features, raise NotImplementedError. - If none have precomputed_features, return None. + If all items have precomputed_embeddings, return their concatenation. + If some but not all have precomputed_embeddings, raise NotImplementedError. + If none have precomputed_embeddings, return None. 
""" - precomputed_features = [item.precomputed_features for item in items] - if any(feature is not None for feature in precomputed_features): - if not all(feature is not None for feature in precomputed_features): + precomputed_embeddings = [item.precomputed_embeddings for item in items] + if any(feature is not None for feature in precomputed_embeddings): + if not all(feature is not None for feature in precomputed_embeddings): raise NotImplementedError( "MM inputs where only some items are precomputed." ) - result = torch.concat(precomputed_features) + result = torch.concat(precomputed_embeddings) # some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk) result = result.reshape(-1, result.shape[-1]) return result diff --git a/python/sglang/srt/managers/multimodal_processors/qwen_audio.py b/python/sglang/srt/managers/multimodal_processors/qwen_audio.py deleted file mode 100644 index 23b7de5cfd96..000000000000 --- a/python/sglang/srt/managers/multimodal_processors/qwen_audio.py +++ /dev/null @@ -1,94 +0,0 @@ -import re -from typing import List, Union - -import torch - -from sglang.srt.managers.multimodal_processors.base_processor import ( - BaseMultimodalProcessor, - MultimodalSpecialTokens, -) -from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem -from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration - - -class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor): - models = [Qwen2AudioForConditionalGeneration] - - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) - self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>" - self.AUDIO_TOKEN_REGEX = re.compile( - r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>" - ) - - async def process_mm_data_async( - self, - image_data: List[Union[str, bytes]], - input_text, - request_obj, - max_req_input_len, - **kwargs, - ): - audio_data = request_obj.audio_data - if not isinstance(audio_data, 
list): - audio_data = [audio_data] - - base_output = self.load_mm_data( - prompt=input_text, - max_req_input_len=max_req_input_len, - audio_data=audio_data, - multimodal_tokens=MultimodalSpecialTokens( - audio_token=self.AUDIO_TOKEN, - audio_token_regex=self.AUDIO_TOKEN_REGEX, - ), - ) - if base_output is None: - return None - - res = self.process_mm_data( - input_text=base_output.input_text, - audio=base_output.audios, - ) - - # Collect special token ids - tokenizer = self._processor.tokenizer - audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>") - audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>") - audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>") - - items = [] - input_ids = res["input_ids"].flatten() - - if ( - "input_features" in res - and res["input_features"] is not None - and len(res["input_features"]) != 0 - ): - if audio_start_id is not None and audio_end_id is not None: - audio_offsets = self.get_mm_items_offset_by_pair( - input_ids=input_ids, - mm_start_id=audio_start_id, - mm_end_id=audio_end_id, - ) - else: - audio_offsets = None - - input_lengths = res["feature_attention_mask"].sum(dim=-1) - input_lengths = (input_lengths - 1) // 2 + 1 - output_lengths = (input_lengths - 2) // 2 + 1 - - item = MultimodalDataItem( - feature=res["input_features"], - audio_feature_lens=output_lengths, - audio_offsets=audio_offsets, - modality=Modality.AUDIO, - ) - items += [item] - - return { - "mm_items": items, - "input_ids": input_ids.tolist(), - "audio_start_id": audio_start_id, - "audio_token_id": audio_token_id, - "audio_end_id": audio_end_id, - } diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a9ed66f9aa3d..536198cd27b4 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -201,7 +201,7 @@ class MultimodalDataItem: For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem. 
One for images and one for audio. - We put the common fields first and the model-specific fields last. + We put the common fields first and the model-specific fields in model_specific_data. """ modality: Modality @@ -211,37 +211,31 @@ class MultimodalDataItem: # the raw features returned by processor, e.g. pixel_values or audio_features feature: Union[torch.Tensor, np.ndarray] = None - image_sizes: Tuple[int, int] = None + # the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio + precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None - audio_feature_lens: Optional[List[torch.Tensor]] = None - audio_offsets: Optional[List[Tuple[int, int]]] = None - precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None + # Model-specific data stored in a dictionary + model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict) - # For qwen-vl - image_grid_thw: Union[torch.Tensor, np.ndarray] = None - second_per_grid_ts: Optional[List[torch.Tensor]] = None - - # For deepseek-vl - image_emb_mask: Optional[torch.Tensor] = None - image_spatial_crop: Optional[torch.Tensor] = None - - # For minicpmv - # [num_images, (n, w, h)] - tgt_size: Tuple[int, int] = None - - # For mllama - aspect_ratio_id: Optional[List[torch.Tensor]] = None - aspect_ratio_mask: Optional[List[torch.Tensor]] = None - - # For kimi-vl - image_grid_hws: Optional[List[torch.Tensor]] = None + def __getattr__(self, name: str): + if ( + "model_specific_data" in self.__dict__ + and name in self.__dict__["model_specific_data"] + ): + return self.__dict__["model_specific_data"][name] + else: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) - # For gemma3n - input_features_mask: Optional[torch.Tensor] = None + def __setitem__(self, key: str, value: Any): + if key in self.__dict__: + self.__dict__[key] = value + else: + self.model_specific_data[key] = value - # For phi4-mm - image_attention_mask: 
Optional[torch.Tensor] = None - audio_attention_mask: Optional[torch.Tensor] = None + def set(self, key: str, value: Any): + self.__setitem__(key, value) @staticmethod def is_empty_list(l): @@ -259,7 +253,7 @@ def set_pad_value(self): if self.feature is not None: hashed_feature = self.feature else: - hashed_feature = self.precomputed_features + hashed_feature = self.precomputed_embeddings self.hash = hash_feature(hashed_feature) assert self.hash is not None self.pad_value = self.hash % (1 << 30) @@ -268,24 +262,13 @@ def is_modality(self, modality: Modality) -> bool: return self.modality == modality def is_audio(self): - return (self.modality == Modality.AUDIO) and ( - self.precomputed_features is not None - or not MultimodalDataItem.is_empty_list(self.feature) - ) + return self.modality == Modality.AUDIO def is_image(self): - return ( - self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES) - ) and ( - self.precomputed_features is not None - or not MultimodalDataItem.is_empty_list(self.feature) - ) + return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES] def is_video(self): - return (self.modality == Modality.VIDEO) and ( - self.precomputed_features is not None - or not MultimodalDataItem.is_empty_list(self.feature) - ) + return self.modality == Modality.VIDEO def is_valid(self) -> bool: return self.is_image() or self.is_video() or self.is_audio() @@ -306,8 +289,7 @@ def from_dict(obj: dict): def merge(self, other): self.feature += other.feature - self.image_sizes += other.image_sizes - self.image_offsets += other.image_offsets + self.offsets += other.offsets self.hash = hash((self.hash, other.hash)) self.set_pad_value() diff --git a/python/sglang/srt/models/deepseek_vl2.py b/python/sglang/srt/models/deepseek_vl2.py index cf4988b5201b..3fba37008b64 100644 --- a/python/sglang/srt/models/deepseek_vl2.py +++ b/python/sglang/srt/models/deepseek_vl2.py @@ -260,7 +260,7 @@ def pad_input_ids(self, input_ids: List[int], mm_inputs: 
MultimodalInputs): def get_image_feature(self, items: List[MultimodalDataItem]): images_spatial_crop = torch.cat( - [item.image_spatial_crop for item in items], dim=0 + [item.images_spatial_crop for item in items], dim=0 ) assert images_spatial_crop.dim() == 3 @@ -278,8 +278,8 @@ def get_image_feature(self, items: List[MultimodalDataItem]): _, hw, n_dim = images_embeds.shape h = w = int(hw**0.5) tile_index = 0 - for jdx in range(item.image_spatial_crop.shape[1]): - num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx] + for jdx in range(item.images_spatial_crop.shape[1]): + num_width_tiles, num_height_tiles = item.images_spatial_crop[0, jdx] if num_width_tiles == 0 or num_height_tiles == 0: break num_tiles_in_image = num_width_tiles * num_height_tiles diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index 18b7e57e5872..8712191a98af 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -81,6 +81,7 @@ def __init__( self.logits_processor = LogitsProcessor( config.text_config if hasattr(config, "text_config") else config ) + self.padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens() def _has_vision_weights(self, config) -> bool: """Check if the model has vision components by examining the checkpoint.""" @@ -135,8 +136,7 @@ def _check_vision_weights_in_index(self, index_file: str) -> bool: return False def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): - pattern = MultiModalityDataPaddingPatternMultimodalTokens() - return pattern.pad_input_tokens(input_ids, mm_inputs) + return self.padding_pattern.pad_input_tokens(input_ids, mm_inputs) def get_image_feature( self, diff --git a/python/sglang/srt/models/phi4mm.py b/python/sglang/srt/models/phi4mm.py index b7997fc0acae..e1c5fee7837e 100644 --- a/python/sglang/srt/models/phi4mm.py +++ b/python/sglang/srt/models/phi4mm.py @@ -435,7 +435,12 @@ def get_image_feature(self, items: 
List[MultimodalDataItem]) -> torch.Tensor: dtype = next(self.vision_encoder.parameters()).dtype pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype) image_attention_mask = torch.cat( - [item.image_attention_mask for item in items], dim=0 + [ + item.image_attention_mask + for item in items + if hasattr(item, "image_attention_mask") + ], + dim=0, ) image_sizes = torch.cat([item.image_sizes for item in items], dim=0) image_embeds = self.vision_encoder( @@ -456,7 +461,7 @@ def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: audio_features=item.feature.to(device).type(dtype), audio_attention_mask=( item.audio_attention_mask.to(device) - if item.audio_attention_mask is not None + if hasattr(item, "audio_attention_mask") else None ), ) diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 6c6495c5f8f0..b79d90b987ea 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -5,7 +5,7 @@ import os import re from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import numpy as np import torch @@ -155,17 +155,15 @@ def __init__(self, hf_config, server_args, _processor): self.ATTR_NAME_TO_MODALITY = { # Image-related attributes "pixel_values": Modality.IMAGE, - "pixel_values_videos": Modality.VIDEO, "image_sizes": Modality.IMAGE, "image_grid_thw": Modality.IMAGE, "image_attention_mask": Modality.IMAGE, "image_emb_mask": Modality.IMAGE, - "image_spatial_crop": Modality.IMAGE, + "images_spatial_crop": Modality.IMAGE, "tgt_size": Modality.IMAGE, "image_grid_hws": Modality.IMAGE, - "aspect_ratio_id": Modality.IMAGE, + "aspect_ratio_ids": Modality.IMAGE, "aspect_ratio_mask": Modality.IMAGE, - "second_per_grid_ts": Modality.IMAGE, # Audio-related attributes 
"audio_features": Modality.AUDIO, "audio_feature_lens": Modality.AUDIO, @@ -173,9 +171,11 @@ def __init__(self, hf_config, server_args, _processor): "input_features_mask": Modality.AUDIO, "audio_attention_mask": Modality.AUDIO, # Video-related attributes + "pixel_values_videos": Modality.VIDEO, + "second_per_grid_ts": Modality.VIDEO, "video_grid_thw": Modality.VIDEO, # Generic attributes that could apply to multiple modalities - # "precomputed_features" - handled specially as it can be any modality + # "precomputed_embeddings" - handled specially as it can be any modality } # name of the feature filed @@ -222,7 +222,6 @@ async def process_mm_data_async( audio_data, input_text, request_obj, - max_req_input_len, **kwargs, ) -> Optional[Dict[str, Any]]: pass @@ -283,7 +282,7 @@ def submit_data_loading_tasks( self, text_parts: List[str], multimodal_tokens: MultimodalSpecialTokens, - data_iterators: dict, + data_iterators: dict[Modality, Iterator[Any]], discard_alpha_channel: bool = True, image_estimated_frames_iter: Optional[iter] = None, image_scaling_factor: float = 1.0, @@ -354,7 +353,6 @@ def load_mm_data( self, prompt: str, multimodal_tokens: MultimodalSpecialTokens, - max_req_input_len: int, image_data: Optional[list] = None, video_data: Optional[list] = None, audio_data: Optional[list] = None, @@ -489,50 +487,11 @@ def get_mm_items_offset_by_pair( return list(zip(indices_start.tolist(), indices_end.tolist())) - @staticmethod - def _extract_processor_features( - items: List[dict], attr_name: str - ) -> Optional[torch.Tensor]: - """ - Helper function to concat extracted attributes from processor output. 
- """ - values = [value for item in items if (value := item.get(attr_name)) is not None] - return torch.cat(values) if values else None - - # When we assume that all the items have the same attributes - def _extract_processor_features_from_all_attributes( - self, items: List[dict] - ) -> dict: - values = {} - # Verify all items have the same keys - first_keys = set(items[0].keys()) - for item in items[1:]: - if set(item.keys()) != first_keys: - raise ValueError( - f"All items must have the same attributes. " - f"First item has {first_keys}, but found {set(item.keys())}" - ) - - # Process each attribute - for k, v in items[0].items(): - if isinstance(v, list): - values[k] = self._extract_processor_features(items, k) - else: - # Verify all items have the same value for non-list attributes - for item in items[1:]: - if item[k] != v: - raise ValueError( - f"All items must have the same value for attribute {k}. " - f"First item has {v}, but found {item[k]}" - ) - values[k] = v - return values - def collect_mm_items_from_processor_output( self, data_dict: dict ) -> List[MultimodalDataItem]: """Create mm_items directly from processor output.""" - items = {} # modality -> MultimodalDataItem + items: dict[Modality, MultimodalDataItem] = {} for attr_name, value in data_dict.items(): if attr_name == "input_ids": @@ -541,16 +500,15 @@ def collect_mm_items_from_processor_output( # Get modality for this attribute modality = self.ATTR_NAME_TO_MODALITY.get(attr_name) - if not modality and attr_name == "precomputed_features": + if attr_name == "precomputed_embeddings": modality_str = data_dict.get("modality") - try: - modality = ( - Modality.from_str(modality_str) - if modality_str - else Modality.IMAGE - ) - except ValueError: - modality = Modality.IMAGE + modality = Modality.IMAGE + if modality_str: + try: + modality = Modality.from_str(modality_str) + except ValueError: + pass + if modality: # Create item if needed if modality not in items: @@ -559,8 +517,7 @@ def 
collect_mm_items_from_processor_output( if attr_name in self.FEATURE_NAMES: attr_name = "feature" - # Set attribute - setattr(items[modality], attr_name, value) + items[modality].set(attr_name, value) return list(items.values()) @@ -586,6 +543,7 @@ def process_and_combine_mm_data( self, base_output: BaseMultiModalProcessorOutput, mm_tokens: MultimodalSpecialTokens, + **kwargs, ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]: """ Process multimodal data and return the combined multimodal items and input_ids. @@ -618,7 +576,7 @@ def process_and_combine_mm_data( else: raise ValueError(f"Unknown multimodal item type: {type(item)}") # Process items and get input_ids - all_collected_items = [] + all_collected_items: list[MultimodalDataItem] = [] input_ids = None # Handle dict items (already processed) @@ -634,6 +592,7 @@ def process_and_combine_mm_data( images=raw_images, audios=raw_audios, videos=raw_videos, + **kwargs, ) all_collected_items.extend(collected_items) else: diff --git a/python/sglang/srt/multimodal/processors/clip.py b/python/sglang/srt/multimodal/processors/clip.py index a36269819c42..0925212cb44c 100644 --- a/python/sglang/srt/multimodal/processors/clip.py +++ b/python/sglang/srt/multimodal/processors/clip.py @@ -1,9 +1,10 @@ from typing import List, Union -from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.clip import CLIPModel -from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor -from sglang.srt.utils import load_image +from sglang.srt.multimodal.processors.base_processor import ( + BaseMultimodalProcessor, + MultimodalSpecialTokens, +) class ClipImageProcessor(BaseMultimodalProcessor): @@ -11,23 +12,24 @@ class ClipImageProcessor(BaseMultimodalProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) + self.mm_tokens = MultimodalSpecialTokens(image_token="").build( + _processor + ) async def 
process_mm_data_async( self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs ): - if isinstance(input_text, list): - assert len(input_text) and isinstance(input_text[0], int) - input_text = self._processor.tokenizer.decode(input_text) - - images = [load_image(image)[0] for image in image_data] - - image_inputs = self.process_mm_data(input_text=input_text, images=images) - image_inputs["data_hashes"] = [hash(str(image_data))] - image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] - image_inputs["mm_items"] = [ - MultimodalDataItem( - feature=image_inputs["pixel_values"], modality=Modality.IMAGE - ) - ] - - return image_inputs + base_output = self.load_mm_data( + prompt=input_text, + multimodal_tokens=self.mm_tokens, + image_data=image_data, + ) + + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) + + return { + "input_ids": input_ids.tolist(), + "mm_items": mm_items, + } diff --git a/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py b/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py index c21dce176905..9847929f7b0f 100644 --- a/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py +++ b/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py @@ -33,9 +33,9 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - self.mm_tokens = MultimodalSpecialTokens(image_token="").build( - _processor - ) + self.mm_tokens = MultimodalSpecialTokens( + image_token="", image_token_id=self._processor.image_token_id + ).build(_processor) async def process_mm_data_async( self, @@ -50,36 +50,16 @@ async def process_mm_data_async( input_text, image_data=image_data, multimodal_tokens=self.mm_tokens, - max_req_input_len=max_req_input_len, ) - res = self.process_mm_data( - input_text=base_output.input_text, - images=base_output.images, + mm_items, input_ids, _ = 
self.process_and_combine_mm_data( + base_output, + self.mm_tokens, max_req_input_len=max_req_input_len, conversations=base_output.input_text, ) - images_seq_mask = res["images_seq_mask"] - images_spatial_crop = res["images_spatial_crop"] - batched_images_spatial_crop = [] - batched_images_spatial_crop.append(images_spatial_crop) - batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0) - - items = [] - input_ids = res["input_ids"] - image_offsets = self.get_mm_items_offset( - input_ids=input_ids, mm_token_id=self._processor.image_token_id - ) - item = MultimodalDataItem( - feature=res["images"], - offsets=image_offsets, - modality=Modality.IMAGE, - image_emb_mask=images_seq_mask, - image_spatial_crop=batched_images_spatial_crop, - ) - items += [item] return { - "mm_items": items, + "mm_items": mm_items, "input_ids": input_ids.tolist(), "im_token_id": self._processor.image_token_id, } diff --git a/python/sglang/srt/multimodal/processors/gemma3.py b/python/sglang/srt/multimodal/processors/gemma3.py index dac9bd5c8241..9abf172b2c09 100644 --- a/python/sglang/srt/multimodal/processors/gemma3.py +++ b/python/sglang/srt/multimodal/processors/gemma3.py @@ -33,7 +33,6 @@ async def process_mm_data_async( image_data: List[Union[str, bytes, Dict]], input_text, request_obj, - max_req_input_len, *args, **kwargs, ): @@ -41,7 +40,6 @@ async def process_mm_data_async( prompt=input_text, image_data=image_data, multimodal_tokens=self.mm_tokens, - max_req_input_len=max_req_input_len, discard_alpha_channel=True, ) diff --git a/python/sglang/srt/multimodal/processors/gemma3n.py b/python/sglang/srt/multimodal/processors/gemma3n.py index aafeab7c9383..938819d9143e 100644 --- a/python/sglang/srt/multimodal/processors/gemma3n.py +++ b/python/sglang/srt/multimodal/processors/gemma3n.py @@ -54,7 +54,6 @@ async def process_mm_data_async( audio_data: Optional[List[Union[str, bytes, Dict]]] = None, input_text: str = "", request_obj=None, - max_req_input_len: int = 0, 
*args, **kwargs, ): @@ -63,7 +62,6 @@ async def process_mm_data_async( prompt=input_text, image_data=image_data, audio_data=audio_data, - max_req_input_len=max_req_input_len, multimodal_tokens=self.mm_tokens, ) diff --git a/python/sglang/srt/multimodal/processors/internvl.py b/python/sglang/srt/multimodal/processors/internvl.py index d3413c457dde..12823077f0ad 100644 --- a/python/sglang/srt/multimodal/processors/internvl.py +++ b/python/sglang/srt/multimodal/processors/internvl.py @@ -170,13 +170,12 @@ def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=3 return pixel_values, num_patches_list async def process_mm_data_async( - self, image_data, input_text, request_obj, max_req_input_len, **kwargs + self, image_data, input_text, request_obj, **kwargs ): base_output = self.load_mm_data( prompt=input_text, image_data=image_data, multimodal_tokens=self.mm_tokens, - max_req_input_len=max_req_input_len, discard_alpha_channel=True, ) diff --git a/python/sglang/srt/multimodal/processors/janus_pro.py b/python/sglang/srt/multimodal/processors/janus_pro.py index 28be34c57b01..4dd8c1a8476a 100644 --- a/python/sglang/srt/multimodal/processors/janus_pro.py +++ b/python/sglang/srt/multimodal/processors/janus_pro.py @@ -11,52 +11,35 @@ class JanusProImageProcessor(BaseMultimodalProcessor): models = [MultiModalityCausalLM] - def __init__(self, hf_config, server_args, processor): - super().__init__(hf_config, server_args, processor) + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) self.mm_tokens = MultimodalSpecialTokens( - image_token=processor.image_token - ).build(processor) + image_token=_processor.image_token, + image_token_id=_processor.image_id, + ).build(_processor) async def process_mm_data_async( self, image_data: List[Union[str, bytes]], input_text, request_obj, - max_req_input_len, **kwargs, ): - processor = self._processor - base_out = self.load_mm_data( prompt=input_text, 
image_data=image_data, multimodal_tokens=self.mm_tokens, - max_req_input_len=max_req_input_len, ) - images = base_out.images - res = self.process_mm_data( - input_text=base_out.input_text, - prompt=base_out.input_text, - images=images, + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_out, self.mm_tokens, prompt=base_out.input_text ) - input_ids = res["input_ids"].flatten() - image_offsets = self.get_mm_items_offset( - input_ids=input_ids, mm_token_id=processor.image_id - ) return { - "mm_items": [ - MultimodalDataItem( - feature=res["pixel_values"], - image_emb_mask=res["images_emb_mask"], - offsets=image_offsets, - modality=Modality.IMAGE, - ) - ], + "mm_items": mm_items, "input_ids": input_ids.tolist(), - "im_start_id": processor.image_start_id, - "im_end_id": processor.image_end_id, - "im_token_id": processor.image_id, + "im_start_id": self._processor.image_start_id, + "im_end_id": self._processor.image_end_id, + "im_token_id": self.mm_tokens.image_token_id, } diff --git a/python/sglang/srt/multimodal/processors/kimi_vl.py b/python/sglang/srt/multimodal/processors/kimi_vl.py index ef533c16d579..84c4a5133853 100644 --- a/python/sglang/srt/multimodal/processors/kimi_vl.py +++ b/python/sglang/srt/multimodal/processors/kimi_vl.py @@ -26,7 +26,6 @@ async def process_mm_data_async( image_data: List[Union[str, bytes, Dict]], input_text, request_obj, - max_req_input_len, *args, **kwargs, ): @@ -34,7 +33,6 @@ async def process_mm_data_async( prompt=input_text, image_data=image_data, multimodal_tokens=self.mm_tokens, - max_req_input_len=max_req_input_len, ) mm_items, input_ids, _ = self.process_and_combine_mm_data( diff --git a/python/sglang/srt/multimodal/processors/llava.py b/python/sglang/srt/multimodal/processors/llava.py index 03c4bf5ec634..f4504ecea2de 100644 --- a/python/sglang/srt/multimodal/processors/llava.py +++ b/python/sglang/srt/multimodal/processors/llava.py @@ -159,7 +159,9 @@ async def process_mm_data_async( "mm_items": [ 
MultimodalDataItem( feature=pixel_values, - image_sizes=image_sizes, + model_specific_data={ + "image_sizes": image_sizes, + }, modality=modality, ) ], diff --git a/python/sglang/srt/multimodal/processors/minicpm.py b/python/sglang/srt/multimodal/processors/minicpm.py index 3ba547b380e0..ed4f86511b1d 100644 --- a/python/sglang/srt/multimodal/processors/minicpm.py +++ b/python/sglang/srt/multimodal/processors/minicpm.py @@ -17,10 +17,21 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) + # Collect special token ids + tokenizer = self._processor.tokenizer + self.slice_start_id = getattr(tokenizer, "slice_start_id", None) + self.slice_end_id = getattr(tokenizer, "slice_end_id", None) + self.audio_start_id = getattr(tokenizer, "audio_start_id", None) + self.audio_end_id = getattr(tokenizer, "audio_end_id", None) + self.im_start_id = getattr(tokenizer, "im_start_id", None) + self.im_end_id = getattr(tokenizer, "im_end_id", None) + self.im_token_id = getattr(tokenizer, "unk_id", None) + self.mm_tokens = MultimodalSpecialTokens( image_token="(./)", audio_token="()", video_token="()", + image_token_id=self.im_token_id, ).build(_processor) async def process_mm_data_async( @@ -29,12 +40,10 @@ async def process_mm_data_async( audio_data: List[Union[str, bytes]], input_text, request_obj, - max_req_input_len, **kwargs, ): base_output = self.load_mm_data( prompt=input_text, - max_req_input_len=max_req_input_len, audio_data=audio_data, image_data=image_data, multimodal_tokens=self.mm_tokens, @@ -48,24 +57,6 @@ async def process_mm_data_async( audios=base_output.audios, ) - # Collect special token ids - tokenizer = self._processor.tokenizer - slice_start_id, slice_end_id, audio_start_id, audio_end_id = ( - None, - None, - None, - None, - ) - if tokenizer.slice_start_id: - slice_start_id = tokenizer.slice_start_id - slice_end_id = tokenizer.slice_end_id - if 
hasattr(tokenizer, "audio_start_id"): - audio_start_id = tokenizer.audio_start_id - audio_end_id = tokenizer.audio_end_id - - im_start_id = tokenizer.im_start_id - im_end_id = tokenizer.im_end_id - im_token_id = tokenizer.unk_id pixel_values = res["pixel_values"] tgt_sizes = res["tgt_sizes"] @@ -102,10 +93,12 @@ async def process_mm_data_async( items = [] input_ids = res["input_ids"].flatten() image_offsets = self.get_mm_items_offset_by_pair( - input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id + input_ids=input_ids, mm_start_id=self.im_start_id, mm_end_id=self.im_end_id ) slice_offsets = self.get_mm_items_offset_by_pair( - input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id + input_ids=input_ids, + mm_start_id=self.slice_start_id, + mm_end_id=self.slice_end_id, ) image_offsets.extend(slice_offsets) image_offsets = sorted(image_offsets) @@ -114,7 +107,7 @@ async def process_mm_data_async( item = MultimodalDataItem( feature=pixel_values, offsets=image_offsets, - tgt_size=tgt_sizes_flat, + model_specific_data={"tgt_size": tgt_sizes_flat}, modality=Modality.IMAGE, ) items += [item] @@ -124,17 +117,17 @@ async def process_mm_data_async( and res["audio_features"] is not None and len(res["audio_features"]) != 0 ): - if audio_start_id is not None and audio_end_id is not None: + if self.audio_start_id is not None and self.audio_end_id is not None: audio_offsets = self.get_mm_items_offset_by_pair( input_ids=input_ids, - mm_start_id=audio_start_id, - mm_end_id=audio_end_id, + mm_start_id=self.audio_start_id, + mm_end_id=self.audio_end_id, ) else: audio_offsets = None item = MultimodalDataItem( feature=[res["audio_features"]], - audio_feature_lens=res["audio_feature_lens"], + model_specific_data={"audio_feature_lens": res["audio_feature_lens"]}, offsets=audio_offsets, modality=Modality.AUDIO, ) @@ -142,11 +135,11 @@ async def process_mm_data_async( return { "mm_items": items, "input_ids": input_ids.tolist(), - "audio_start_id": 
audio_start_id, - "audio_end_id": audio_end_id, - "im_token_id": im_token_id, - "im_start_id": im_start_id, - "im_end_id": im_end_id, - "slice_start_id": slice_start_id, - "slice_end_id": slice_end_id, + "audio_start_id": self.audio_start_id, + "audio_end_id": self.audio_end_id, + "im_token_id": self.im_token_id, + "im_start_id": self.im_start_id, + "im_end_id": self.im_end_id, + "slice_start_id": self.slice_start_id, + "slice_end_id": self.slice_end_id, } diff --git a/python/sglang/srt/multimodal/processors/mlama.py b/python/sglang/srt/multimodal/processors/mlama.py index 783145027b79..dd31844525b4 100644 --- a/python/sglang/srt/multimodal/processors/mlama.py +++ b/python/sglang/srt/multimodal/processors/mlama.py @@ -1,9 +1,10 @@ from typing import List, Union -from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.mllama import MllamaForConditionalGeneration -from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor -from sglang.srt.utils import load_image +from sglang.srt.multimodal.processors.base_processor import ( + BaseMultimodalProcessor, + MultimodalSpecialTokens, +) class MllamaImageProcessor(BaseMultimodalProcessor): @@ -11,24 +12,26 @@ class MllamaImageProcessor(BaseMultimodalProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) + self.mm_tokens = MultimodalSpecialTokens( + image_token=self._processor.image_token, + image_token_id=self._processor.image_token_id, + ).build(_processor) async def process_mm_data_async( self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs ): - if isinstance(input_text, list): - assert len(input_text) and isinstance(input_text[0], int) - input_text = self._processor.tokenizer.decode(input_text) + base_out = self.load_mm_data( + prompt=input_text, + image_data=image_data, + multimodal_tokens=self.mm_tokens, + ) - images = [load_image(image)[0] for image in image_data] - 
image_inputs = self.process_mm_data(input_text=input_text, images=images) - image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] - image_inputs["mm_items"] = [ - MultimodalDataItem( - feature=image_inputs["pixel_values"], - aspect_ratio_id=image_inputs["aspect_ratio_ids"], - aspect_ratio_mask=image_inputs["aspect_ratio_mask"], - modality=Modality.IMAGE, - ) - ] + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_out, self.mm_tokens + ) - return image_inputs + return { + "mm_items": mm_items, + "input_ids": input_ids.tolist(), + "im_token_id": self.mm_tokens.image_token_id, + } diff --git a/python/sglang/srt/multimodal/processors/mllama4.py b/python/sglang/srt/multimodal/processors/mllama4.py index 566eb3230c17..2d0eba2fd499 100644 --- a/python/sglang/srt/multimodal/processors/mllama4.py +++ b/python/sglang/srt/multimodal/processors/mllama4.py @@ -27,13 +27,13 @@ def __init__(self, hf_config, server_args, _processor): self.image_token_index = hf_config.image_token_index self.multimodal_tokens = MultimodalSpecialTokens( image_token=_processor.image_token, + image_token_id=self.image_token_index, ).build(_processor) async def process_mm_data_async( self, image_data: List[Union[str, bytes]], input_text, - max_req_input_len=None, *args, **kwargs, ): @@ -45,7 +45,6 @@ async def process_mm_data_async( processed_data = self.load_mm_data( prompt=input_text, multimodal_tokens=self.multimodal_tokens, - max_req_input_len=max_req_input_len or 4096, image_data=image_data, return_text=True, ) diff --git a/python/sglang/srt/multimodal/processors/phi4mm.py b/python/sglang/srt/multimodal/processors/phi4mm.py index 8772403dbdb7..720e3c1324e7 100644 --- a/python/sglang/srt/multimodal/processors/phi4mm.py +++ b/python/sglang/srt/multimodal/processors/phi4mm.py @@ -31,6 +31,7 @@ def __call__(self, **kwargs): for hf_key, sglang_key in key_mapping.items(): if hf_key in result: result[sglang_key] = result[hf_key] + del result[hf_key] # Filter out None or empty 
tensors from the result. # This prevents the sglang function base_processor.collect_mm_items_from_processor_output() @@ -58,7 +59,7 @@ def __init__(self, hf_config, server_args, _processor): self.AUDIO_TOKEN_ID = 200011 self.AUDIO_SAMPLE_RATE = 16000 - self.multimodal_tokens = MultimodalSpecialTokens( + self.mm_tokens = MultimodalSpecialTokens( image_token=self.IMAGE_TOKEN, image_token_id=self.IM_TOKEN_ID, audio_token=self.AUDIO_TOKEN, @@ -71,15 +72,13 @@ async def process_mm_data_async( audio_data, input_text, request_obj, - max_req_input_len, **kwargs, ): base_output = self.load_mm_data( prompt=input_text, - max_req_input_len=max_req_input_len, audio_data=audio_data, image_data=image_data, - multimodal_tokens=self.multimodal_tokens, + multimodal_tokens=self.mm_tokens, audio_sample_rate=self.AUDIO_SAMPLE_RATE, ) @@ -91,12 +90,12 @@ async def process_mm_data_async( ] mm_items, input_ids, _ = self.process_and_combine_mm_data( - base_output, self.multimodal_tokens + base_output, self.mm_tokens ) return { "input_ids": input_ids.tolist(), "mm_items": mm_items, - "im_token_id": self.IM_TOKEN_ID, - "audio_token_id": self.AUDIO_TOKEN_ID, + "im_token_id": self.mm_tokens.image_token_id, + "audio_token_id": self.mm_tokens.audio_token_id, } diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index b18dfa1b023e..fdfd6bd627ee 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -6,7 +6,6 @@ _num_image_tokens as _get_pixtral_hf_num_image_tokens, ) -from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.pixtral import PixtralVisionModel from sglang.srt.multimodal.processors.base_processor import ( BaseMultimodalProcessor, @@ -45,7 +44,7 @@ def get_patch_grid_size( def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - self.image_token_id = getattr( 
+ self.IM_TOKEN_ID = getattr( hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID ) # Instantiate the patcher logic helper using the class defined above @@ -53,8 +52,9 @@ def __init__(self, hf_config, server_args, _processor): self.vision_config = hf_config.vision_config self.image_size = self.vision_config.image_size self.patch_size = self.vision_config.patch_size - self.multimodal_tokens = MultimodalSpecialTokens( - image_token=_processor.image_token + self.mm_tokens = MultimodalSpecialTokens( + image_token=_processor.image_token, + image_token_id=self.IM_TOKEN_ID, ).build(_processor) _processor.tokenizer.add_special_tokens( { @@ -80,42 +80,21 @@ async def process_mm_data_async( ): mm_data = self.load_mm_data( prompt=input_text, - multimodal_tokens=self.multimodal_tokens, - max_req_input_len=kwargs.get("max_req_input_len", 4096), + multimodal_tokens=self.mm_tokens, image_data=image_data, return_text=True, ) - if mm_data.images: resize_tasks = [self._resize(image) for image in mm_data.images] mm_data.images = await asyncio.gather(*resize_tasks) - processor_output = self.process_mm_data( - input_text=mm_data.input_text, - images=mm_data.images, + mm_items, input_ids, _ = self.process_and_combine_mm_data( + mm_data, self.mm_tokens ) - if "pixel_values" in processor_output: - input_ids = processor_output["input_ids"].view(-1) - image_offsets = self.get_mm_items_offset( - input_ids=input_ids, - mm_token_id=self.image_token_id, - ) - mm_items = [ - MultimodalDataItem( - feature=processor_output["pixel_values"], - image_sizes=processor_output["image_sizes"], - modality=Modality.IMAGE, - offsets=image_offsets, - ) - ] - - input_ids = input_ids.tolist() - processor_output.update( - input_ids=input_ids, - mm_items=mm_items, - # there's no im_start_id for pixtral, only im_token and im_end_token - im_end_id=self.IMG_END_TOKEN_ID, - im_token_id=self.image_token_id, - ) - return processor_output + return { + "mm_items": mm_items, + "input_ids": 
input_ids.tolist(), + "im_token_id": self.IM_TOKEN_ID, + "im_token": self._processor.image_token, + } diff --git a/python/sglang/srt/multimodal/processors/qwen_audio.py b/python/sglang/srt/multimodal/processors/qwen_audio.py new file mode 100644 index 000000000000..34d440375ae3 --- /dev/null +++ b/python/sglang/srt/multimodal/processors/qwen_audio.py @@ -0,0 +1,65 @@ +import re + +from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem +from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration +from sglang.srt.multimodal.processors.base_processor import ( + BaseMultimodalProcessor, + MultimodalSpecialTokens, +) + + +class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor): + models = [Qwen2AudioForConditionalGeneration] + + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>" + self.AUDIO_TOKEN_REGEX = re.compile( + r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>" + ) + # Collect special token ids + tokenizer = self._processor.tokenizer + self.audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>") + self.audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>") + self.audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>") + + self.mm_tokens = MultimodalSpecialTokens( + audio_token=self.AUDIO_TOKEN, + audio_token_regex=self.AUDIO_TOKEN_REGEX, + audio_token_id=self.audio_token_id, + ).build(_processor) + + async def process_mm_data_async( + self, + audio_data, + input_text, + **kwargs, + ): + base_output = self.load_mm_data( + prompt=input_text, + audio_data=audio_data, + multimodal_tokens=self.mm_tokens, + ) + if base_output is None: + return None + + mm_items, input_ids, ret = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) + + assert ( + "feature_attention_mask" in ret + ), "feature_attention_mask not found in processor output" + input_lengths = 
ret["feature_attention_mask"].sum(dim=-1) + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + + mm_items[0].model_specific_data["audio_feature_lens"] = output_lengths + + return { + "mm_items": mm_items, + "input_ids": input_ids.tolist(), + "audio_start_id": self.audio_start_id, + "audio_token_id": self.audio_token_id, + "audio_end_id": self.audio_end_id, + } diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index bdfaf140624f..1b1de43695bb 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -227,7 +227,6 @@ async def process_mm_data_async( image_data: List[Union[str, bytes]], input_text, request_obj, - max_req_input_len, *args, **kwargs, ): @@ -237,7 +236,6 @@ async def process_mm_data_async( image_data=image_data, video_data=request_obj.video_data, multimodal_tokens=self.mm_tokens, - max_req_input_len=max_req_input_len, ) # Qwen-specific: resize images if they are raw Image objects diff --git a/python/sglang/srt/multimodal/processors/vila.py b/python/sglang/srt/multimodal/processors/vila.py index 8e0f04acae89..7070dfe73dc9 100644 --- a/python/sglang/srt/multimodal/processors/vila.py +++ b/python/sglang/srt/multimodal/processors/vila.py @@ -47,13 +47,11 @@ async def process_mm_data_async( image_data: Optional[ImageDataInputItem | List[ImageDataInputItem]], input_text: str | List[int], request_obj: GenerateReqInput | EmbeddingReqInput, - max_req_input_len: int, **kwargs, ) -> Optional[Dict[str, Any]]: base_output = self.load_mm_data( prompt=input_text, multimodal_tokens=self.mm_tokens, - max_req_input_len=max_req_input_len, image_data=image_data, ) diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index f252c4884eb0..4c41e2feca90 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -116,22 
+116,23 @@ def test_single_image_chat_completion(self): ) -class TestMllamaServer(TestOpenAIVisionServer): - @classmethod - def setUpClass(cls): - cls.model = "meta-llama/Llama-3.2-11B-Vision-Instruct" - cls.base_url = DEFAULT_URL_FOR_TEST - cls.api_key = "sk-123456" - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - api_key=cls.api_key, - ) - cls.base_url += "/v1" - - def test_video_chat_completion(self): - pass +# Note(Xinyuan): mllama is not stable for now, skip for CI +# class TestMllamaServer(TestOpenAIVisionServer): +# @classmethod +# def setUpClass(cls): +# cls.model = "meta-llama/Llama-3.2-11B-Vision-Instruct" +# cls.base_url = DEFAULT_URL_FOR_TEST +# cls.api_key = "sk-123456" +# cls.process = popen_launch_server( +# cls.model, +# cls.base_url, +# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, +# api_key=cls.api_key, +# ) +# cls.base_url += "/v1" + +# def test_video_chat_completion(self): +# pass class TestMinicpmvServer(TestOpenAIVisionServer): diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index 53498946144c..dabf948b3567 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -67,6 +67,7 @@ def setUpClass(cls): "--trust-remote-code", "--context-length", "4096", + "--disable-cuda-graph", ], ) cls.base_url += "/v1" diff --git a/test/srt/test_vision_openai_server_common.py b/test/srt/test_vision_openai_server_common.py index 341db654e053..2f7e404cb697 100644 --- a/test/srt/test_vision_openai_server_common.py +++ b/test/srt/test_vision_openai_server_common.py @@ -308,19 +308,35 @@ def test_video_images_chat_completion(self): "iPod" in video_response or "device" in video_response or "microphone" in video_response - ), video_response + ), f""" + ====================== video_response ===================== + {video_response} + =========================================================== + should contain 'iPod' or 
'device' or 'microphone' + """ assert ( "man" in video_response or "person" in video_response or "individual" in video_response or "speaker" in video_response - ), video_response + or "Steve" in video_response + ), f""" + ====================== video_response ===================== + {video_response} + =========================================================== + should contain 'man' or 'person' or 'individual' or 'speaker' + """ assert ( "present" in video_response or "examine" in video_response or "display" in video_response or "hold" in video_response - ) + ), f""" + ====================== video_response ===================== + {video_response} + =========================================================== + should contain 'present' or 'examine' or 'display' or 'hold' + """ assert "black" in video_response or "dark" in video_response self.assertIsNotNone(video_response) self.assertGreater(len(video_response), 0) diff --git a/test/srt/test_vlm_input_format.py b/test/srt/test_vlm_input_format.py index d2670ecac5f0..79625ee82cbb 100644 --- a/test/srt/test_vlm_input_format.py +++ b/test/srt/test_vlm_input_format.py @@ -104,15 +104,15 @@ async def test_understands_image(self): ) self.verify_response(output) - async def test_understands_precomputed_features(self): + async def test_understands_precomputed_embeddings(self): req = self.get_completion_request() processor_output = self.get_processor_output(req=req) with torch.inference_mode(): - precomputed_features = self.__class__.visual(processor_output) + precomputed_embeddings = self.__class__.visual(processor_output) output = await self.engine.async_generate( input_ids=processor_output["input_ids"][0].detach().cpu().tolist(), image_data=[ - self._precomputed_image_data(processor_output, precomputed_features) + self._precomputed_image_data(processor_output, precomputed_embeddings) ], sampling_params=dict(temperature=0.0), ) @@ -128,11 +128,11 @@ async def test_understands_pixel_values(self): ) 
self.verify_response(output) - def _precomputed_image_data(self, processor_output, precomputed_features): + def _precomputed_image_data(self, processor_output, precomputed_embeddings): """This should not be overridden.""" return dict( modality="IMAGE", - precomputed_features=precomputed_features, + precomputed_embeddings=precomputed_embeddings, ) def _pixel_values_image_data(self, processor_output): From 5c8365a0516ae908c1733054afb6852f3bee91dd Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Sun, 20 Jul 2025 23:12:52 -0700 Subject: [PATCH 080/396] [router] add ut for pd router (#8208) --- sgl-router/src/routers/pd_router.rs | 512 ++++++++++++++++++++++++++++ sgl-router/tests/test_pd_routing.rs | 21 -- 2 files changed, 512 insertions(+), 21 deletions(-) diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index d156c9f341d6..7c70a3873fc3 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -1393,3 +1393,515 @@ impl RouterTrait for PDRouter { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::{BasicWorker, WorkerType}; + use crate::policies::{CacheAwarePolicy, RandomPolicy}; + use crate::routers::pd_types::SingleOrBatch; + use actix_web::test::TestRequest; + + fn create_test_pd_router() -> PDRouter { + let policy = Arc::new(RandomPolicy::new()); + + PDRouter { + prefill_workers: Arc::new(RwLock::new(vec![])), + decode_workers: Arc::new(RwLock::new(vec![])), + policy, + prefill_tree: None, + timeout_secs: 5, + interval_secs: 1, + worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), + load_monitor_handle: None, + http_client: reqwest::Client::new(), + _prefill_health_checker: None, + _decode_health_checker: None, + } + } + + fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box { + let worker = BasicWorker::new(url, worker_type); + worker.set_healthy(healthy); + Box::new(worker) + } + + // ============= Worker Management Tests 
============= + + #[tokio::test] + async fn test_add_prefill_server_already_exists() { + let router = create_test_pd_router(); + + // Add a worker first + let worker = create_test_worker( + "http://localhost:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(8080), + }, + true, + ); + router.prefill_workers.write().unwrap().push(worker); + + // Try to add the same URL again - this would fail during health check in real scenario + // For unit test, we test the duplicate check logic + let workers = router.prefill_workers.read().unwrap(); + let exists = workers.iter().any(|w| w.url() == "http://localhost:8000"); + assert!(exists); + } + + #[tokio::test] + async fn test_remove_prefill_server_success() { + let router = create_test_pd_router(); + + // Add servers first + let worker1 = create_test_worker( + "http://worker1".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + let worker2 = create_test_worker( + "http://worker2".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(8080), + }, + true, + ); + + router.prefill_workers.write().unwrap().push(worker1); + router.prefill_workers.write().unwrap().push(worker2); + + // Remove one + let result = router.remove_prefill_server("http://worker1").await; + + assert!(result.is_ok()); + assert!(result.unwrap().contains("Successfully removed")); + + let workers = router.prefill_workers.read().unwrap(); + assert_eq!(workers.len(), 1); + assert_eq!(workers[0].url(), "http://worker2"); + } + + #[tokio::test] + async fn test_remove_prefill_server_not_found() { + let router = create_test_pd_router(); + + let result = router.remove_prefill_server("http://nonexistent").await; + + assert!(result.is_err()); + match result.unwrap_err() { + PDRouterError::WorkerNotFound { url } => { + assert_eq!(url, "http://nonexistent"); + } + _ => panic!("Expected WorkerNotFound error"), + } + } + + #[tokio::test] + async fn test_remove_decode_server_success() { + let router = create_test_pd_router(); 
+ + // Add server first + let worker = create_test_worker("http://decode1".to_string(), WorkerType::Decode, true); + router.decode_workers.write().unwrap().push(worker); + + let result = router.remove_decode_server("http://decode1").await; + + assert!(result.is_ok()); + assert!(result.unwrap().contains("Successfully removed")); + + let workers = router.decode_workers.read().unwrap(); + assert_eq!(workers.len(), 0); + } + + // ============= Lock Error Handling Tests ============= + + #[test] + fn test_lock_operations() { + let router = create_test_pd_router(); + + // Test read/write locks work correctly + { + let read_guard = router.prefill_workers.read().unwrap(); + assert_eq!(read_guard.len(), 0); + } + + { + let mut write_guard = router.prefill_workers.write().unwrap(); + write_guard.push(create_test_worker( + "http://test".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + )); + } + + { + let read_guard = router.prefill_workers.read().unwrap(); + assert_eq!(read_guard.len(), 1); + } + } + + // ============= Cache Tree Integration Tests ============= + + #[tokio::test] + async fn test_cache_tree_operations() { + let policy = Arc::new(CacheAwarePolicy::new()); + let mut router = create_test_pd_router(); + router.policy = policy; + + // Initialize cache tree + let tree = Arc::new(Mutex::new(Tree::new())); + router.prefill_tree = Some(Arc::clone(&tree)); + + // Manually add worker and update tree + let worker = create_test_worker( + "http://worker1".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + router.prefill_workers.write().unwrap().push(worker); + + // Update tree + tree.lock().unwrap().insert("", "http://worker1"); + + // Verify tree contains the worker + let tree_guard = tree.lock().unwrap(); + let (_matched_text, tenant) = tree_guard.prefix_match(""); + // Since we inserted with empty prefix, we should get a match + assert_eq!(tenant, "http://worker1"); + } + + #[tokio::test] + async fn 
test_cache_tree_rebuild_on_remove() { + let policy = Arc::new(CacheAwarePolicy::new()); + let mut router = create_test_pd_router(); + router.policy = policy; + + // Initialize cache tree + let tree = Arc::new(Mutex::new(Tree::new())); + router.prefill_tree = Some(Arc::clone(&tree)); + + // Add multiple workers + let worker1 = create_test_worker( + "http://worker1".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + let worker2 = create_test_worker( + "http://worker2".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + + router.prefill_workers.write().unwrap().push(worker1); + router.prefill_workers.write().unwrap().push(worker2); + + // Initialize tree with both workers + { + let tree_guard = tree.lock().unwrap(); + tree_guard.insert("", "http://worker1"); + tree_guard.insert("", "http://worker2"); + } + + // Remove one worker + let result = router.remove_prefill_server("http://worker1").await; + assert!(result.is_ok()); + + // Verify tree only contains remaining worker + let tree_guard = tree.lock().unwrap(); + let (_matched_text, tenant) = tree_guard.prefix_match(""); + // After rebuild, tree should only have worker2 + assert_eq!(tenant, "http://worker2"); + } + + #[tokio::test] + async fn test_no_cache_tree_operations() { + let router = create_test_pd_router(); + assert!(router.prefill_tree.is_none()); + + // Add a worker without cache tree + let worker = create_test_worker( + "http://worker1".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + router.prefill_workers.write().unwrap().push(worker); + + // Remove should work without tree + let result = router.remove_prefill_server("http://worker1").await; + assert!(result.is_ok()); + } + + // ============= Bootstrap Injection Tests ============= + + #[test] + fn test_bootstrap_injection_with_existing_fields() { + let mut req = GenerateReqInput { + text: Some(SingleOrBatch::Single("Test".to_string())), + input_ids: None, + 
stream: false, + bootstrap_host: Some(SingleOrBatch::Single("existing-host".to_string())), + bootstrap_port: Some(SingleOrBatch::Single(Some(9999))), + bootstrap_room: Some(SingleOrBatch::Single(12345)), + other: Value::Object(serde_json::Map::new()), + }; + + let prefill_worker = create_test_worker( + "http://new-host:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(8080), + }, + true, + ); + + // Bootstrap info is added regardless of existing fields + let result = req.add_bootstrap_info(prefill_worker.as_ref()); + assert!(result.is_ok()); + + // Bootstrap info should be updated with new values + assert_eq!( + req.bootstrap_host, + Some(SingleOrBatch::Single("new-host".to_string())) + ); + assert_eq!(req.bootstrap_port, Some(SingleOrBatch::Single(Some(8080)))); + // Room should be regenerated (different from original) + if let Some(SingleOrBatch::Single(room)) = req.bootstrap_room { + assert_ne!(room, 12345); + } else { + panic!("Expected single room ID"); + } + } + + #[test] + fn test_bootstrap_room_generation() { + let mut req1 = GenerateReqInput { + text: Some(SingleOrBatch::Single("Test".to_string())), + input_ids: None, + stream: false, + bootstrap_host: None, + bootstrap_port: None, + bootstrap_room: None, + other: Value::Object(serde_json::Map::new()), + }; + + let mut req2 = GenerateReqInput { + text: Some(SingleOrBatch::Single("Test".to_string())), + input_ids: None, + stream: false, + bootstrap_host: None, + bootstrap_port: None, + bootstrap_room: None, + other: Value::Object(serde_json::Map::new()), + }; + + let prefill_worker = create_test_worker( + "http://host:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(8080), + }, + true, + ); + + // Add bootstrap info to both requests + let _ = req1.add_bootstrap_info(prefill_worker.as_ref()); + let _ = req2.add_bootstrap_info(prefill_worker.as_ref()); + + // Room IDs should be different + if let (Some(SingleOrBatch::Single(room1)), Some(SingleOrBatch::Single(room2))) = + 
(req1.bootstrap_room, req2.bootstrap_room) + { + assert_ne!(room1, room2, "Room IDs should be unique"); + } else { + panic!("Expected single room IDs"); + } + } + + // ============= Worker Selection Tests ============= + + #[tokio::test] + async fn test_select_healthy_prefill_worker() { + let router = create_test_pd_router(); + + // Add mix of healthy and unhealthy workers + let healthy_worker = create_test_worker( + "http://healthy".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + let unhealthy_worker = create_test_worker( + "http://unhealthy".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + false, + ); + let decode_worker = + create_test_worker("http://decode".to_string(), WorkerType::Decode, true); + + router + .prefill_workers + .write() + .unwrap() + .push(unhealthy_worker); + router.prefill_workers.write().unwrap().push(healthy_worker); + router.decode_workers.write().unwrap().push(decode_worker); + + let client = reqwest::Client::new(); + let result = router.select_pd_pair(&client, None).await; + + assert!(result.is_ok()); + let (prefill, _decode) = result.unwrap(); + + // Should select the healthy worker + assert_eq!(prefill.url(), "http://healthy"); + assert!(prefill.is_healthy()); + } + + #[tokio::test] + async fn test_empty_worker_lists() { + let router = create_test_pd_router(); + + let client = reqwest::Client::new(); + let result = router.select_pd_pair(&client, None).await; + + assert!(result.is_err()); + assert!(result.unwrap_err().contains("No prefill workers available")); + } + + // ============= Health Endpoints Tests ============= + + #[tokio::test] + async fn test_health_endpoints() { + let router = create_test_pd_router(); + + // Add healthy workers + let prefill_worker = create_test_worker( + "http://localhost:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + let decode_worker = create_test_worker( + "http://localhost:8001".to_string(), + 
WorkerType::Decode, + true, + ); + + router.prefill_workers.write().unwrap().push(prefill_worker); + router.decode_workers.write().unwrap().push(decode_worker); + + // Test health endpoint + let client = reqwest::Client::new(); + let http_req = TestRequest::default().to_http_request(); + let response = router.health(&client, &http_req).await; + + assert_eq!(response.status(), 200); + + // Test readiness endpoint + let response = router.readiness(); + assert_eq!(response.status(), 200); + } + + // ============= Load Monitoring Tests ============= + + #[tokio::test] + async fn test_load_monitor_updates() { + let policy = Arc::new(crate::policies::PowerOfTwoPolicy::new()); + let mut router = create_test_pd_router(); + router.policy = policy; + + // Create load channel + let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); + router.worker_loads = Arc::new(rx); + + // Simulate load updates + let mut loads = HashMap::new(); + loads.insert("http://worker1".to_string(), 10); + loads.insert("http://worker2".to_string(), 5); + + let _ = tx.send(loads.clone()); + + // Router should receive updates + let received = router.worker_loads.borrow().clone(); + assert_eq!(received.get("http://worker1"), Some(&10)); + assert_eq!(received.get("http://worker2"), Some(&5)); + } + + // ============= Worker Load Tests ============= + + #[test] + fn test_worker_load_metrics() { + let prefill_worker = create_test_worker( + "http://prefill".to_string(), + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + let decode_worker = + create_test_worker("http://decode".to_string(), WorkerType::Decode, true); + + // Create load guard for both workers + let _guard = + WorkerLoadGuard::new_multi(vec![prefill_worker.as_ref(), decode_worker.as_ref()]); + + // Load should be incremented + assert_eq!(prefill_worker.load(), 1); + assert_eq!(decode_worker.load(), 1); + + // Drop guard - load should decrement + drop(_guard); + + assert_eq!(prefill_worker.load(), 0); + 
assert_eq!(decode_worker.load(), 0); + } + + // ============= Concurrent Operations Tests ============= + + #[tokio::test] + async fn test_concurrent_worker_operations() { + let router = Arc::new(create_test_pd_router()); + + let mut handles = vec![]; + + // Spawn tasks to add workers + for i in 0..5 { + let router_clone = Arc::clone(&router); + let url = format!("http://worker{}", i); + let handle = tokio::spawn(async move { + let worker = create_test_worker( + url, + WorkerType::Prefill { + bootstrap_port: None, + }, + true, + ); + router_clone.prefill_workers.write().unwrap().push(worker); + }); + handles.push(handle); + } + + // Wait for all tasks + for handle in handles { + let _ = handle.await; + } + + // Check final state + let workers = router.prefill_workers.read().unwrap(); + assert_eq!(workers.len(), 5); + } +} diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index ceb5fe9e69d3..a2c0d7e3197d 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -1,16 +1,3 @@ -//! Comprehensive tests for PrefillDecode (PD) routing functionality -//! -//! This test suite covers: -//! - Phase 1: Basic PD router creation and configuration -//! - Phase 2: Bootstrap injection and request handling -//! - Phase 3: Cache-aware selection (when implemented) -//! -//! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type. -//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode. 
- -// TODO: This test file needs to be updated for the new configuration structure -// where RoutingMode and PolicyConfig are separate - #[cfg(test)] mod test_pd_routing { use rand::Rng; @@ -921,14 +908,6 @@ mod test_pd_routing { #[test] fn test_policy_type_to_pd_selection_policy_mapping() { - // Document the mapping from PolicyType to PDSelectionPolicy - // This mapping happens in lib.rs when pd_disaggregation=true - - // PolicyType::Random -> PDSelectionPolicy::Random - // PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo - // PolicyType::CacheAware -> PDSelectionPolicy::CacheAware { ... } - // PolicyType::RoundRobin -> ERROR (not supported in PD mode) - // Test that PDSelectionPolicy doesn't include RoundRobin let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware assert_eq!( From 9b5de6cb069ba7af66de45762dab489941ad0947 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Sun, 20 Jul 2025 23:13:20 -0700 Subject: [PATCH 081/396] [router] upgade router version to 0.1.6 (#8209) --- sgl-router/pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sgl-router/pyproject.toml b/sgl-router/pyproject.toml index 915a15de966d..7422aa6bb428 100644 --- a/sgl-router/pyproject.toml +++ b/sgl-router/pyproject.toml @@ -4,8 +4,8 @@ build-backend = "setuptools.build_meta" [project] name = "sglang-router" -version = "0.1.5" -description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances." 
+version = "0.1.6" +description = "High-performance Rust-based load balancer for SGLang with multiple routing algorithms and prefill-decode disaggregation support" authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}] requires-python = ">=3.8" readme = "README.md" From 6936be32210fdf16b0159b2de3f1b8a27e5a679d Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Mon, 21 Jul 2025 15:37:00 +0800 Subject: [PATCH 082/396] Remve router gemm output dtype conversion (#8204) --- python/sglang/srt/models/deepseek_v2.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index a65337945f6b..e02d30839007 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -254,9 +254,8 @@ def forward(self, hidden_states): and self.weight.shape[0] == 256 and _device_sm >= 90 ): - logits = dsv3_router_gemm(hidden_states, self.weight).to( - hidden_states.dtype - ) + # router gemm output float32 + logits = dsv3_router_gemm(hidden_states, self.weight) else: logits = F.linear(hidden_states, self.weight, None) From 74f59ae55557b307484fedace0ee30a41b384ab2 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 21 Jul 2025 02:10:24 -0700 Subject: [PATCH 083/396] chore: upgrade sgl-kernel 0.2.6.post1 (#8202) --- python/pyproject.toml | 2 +- python/sglang/srt/entrypoints/engine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 5949a100a96e..5f53a5ca328f 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -54,7 +54,7 @@ runtime_common = [ srt = [ "sglang[runtime_common]", - "sgl-kernel==0.2.6", + "sgl-kernel==0.2.6.post1", "torch==2.7.1", "torchaudio==2.7.1", "torchvision==0.22.1", diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 990fac9a12a7..e2cb02cc3014 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ 
b/python/sglang/srt/entrypoints/engine.py @@ -654,7 +654,7 @@ def _set_envs_and_config(server_args: ServerArgs): if _is_cuda: assert_pkg_version( "sgl-kernel", - "0.2.6", + "0.2.6.post1", "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", ) From 7b68d271119655e993232b4785a9cec26e0180ec Mon Sep 17 00:00:00 2001 From: Xiaoze Fan Date: Mon, 21 Jul 2025 22:06:15 +0800 Subject: [PATCH 084/396] [Feature] Add a test for Layer-wise Prefill (#8231) Signed-off-by: jason-fxz --- test/srt/test_forward_split_prefill.py | 299 +++++++++++++++++++++++++ 1 file changed, 299 insertions(+) create mode 100644 test/srt/test_forward_split_prefill.py diff --git a/test/srt/test_forward_split_prefill.py b/test/srt/test_forward_split_prefill.py new file mode 100644 index 000000000000..bbd247583f84 --- /dev/null +++ b/test/srt/test_forward_split_prefill.py @@ -0,0 +1,299 @@ +""" +Test forward_split_prefill functionality. + +Usage: +python3 -m unittest test_forward_split_prefill.TestForwardSplitPrefill +or +python3 test_forward_split_prefill.py +""" + +import time +import unittest + +import numpy as np +import torch + +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase + + +class TestForwardSplitPrefill(CustomTestCase): + """Test cases for forward_split_prefill functionality.""" + + @classmethod + def setUpClass(cls): + """Set up the test environment once for all tests.""" + cls.model_path = 
DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.tp_size = 1 + cls.device = "cuda" + + # Initialize server args + cls.server_args = ServerArgs( + model_path=cls.model_path, + tokenizer_path=cls.model_path, + host="127.0.0.1", + disable_cuda_graph=True, # Disable CUDA graph for testing split prefill + disable_hybrid_swa_memory=True, + port=30000, + tp_size=cls.tp_size, + mem_fraction_static=0.8, + trust_remote_code=True, + ) + + cls.port_args = PortArgs.init_new(cls.server_args) + + # Load model and tokenizer + cls.model_config = ModelConfig.from_server_args(cls.server_args) + cls.model_runner = ModelRunner( + model_config=cls.model_config, + mem_fraction_static=cls.server_args.mem_fraction_static, + gpu_id=0, + tp_rank=0, + tp_size=cls.tp_size, + pp_rank=0, + pp_size=1, + nccl_port=cls.port_args.nccl_port, + server_args=cls.server_args, + ) + + cls.tokenizer = get_tokenizer( + cls.server_args.tokenizer_path, + tokenizer_mode=cls.server_args.tokenizer_mode, + trust_remote_code=cls.server_args.trust_remote_code, + ) + + print( + f"Test with model: {cls.model_path}, num_hidden_layers: {cls.model_config.num_hidden_layers}" + ) + + def prepare_test_batch(self, batch_size=2, input_len=128, is_split_prefill=True): + """Prepare a test batch for split prefill testing.""" + # Create synthetic input + input_ids = np.random.randint(10, 1000, (batch_size, input_len), dtype=np.int32) + + sampling_params = SamplingParams( + temperature=0.0, + max_new_tokens=8, + ) + + reqs = [] + for i in range(batch_size): + req = Req( + rid=i, + origin_input_text="", + origin_input_ids=list(input_ids[i]), + sampling_params=sampling_params, + ) + req.prefix_indices = [] + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + req.logprob_start_len = len(req.origin_input_ids) - 1 + reqs.append(req) + + batch = ScheduleBatch.init_new( + reqs=reqs, + req_to_token_pool=self.model_runner.req_to_token_pool, + 
token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, + tree_cache=None, + model_config=self.model_config, + enable_overlap=False, + spec_algorithm=SpeculativeAlgorithm.NONE, + enable_custom_logit_processor=False, + ) + if is_split_prefill: + batch.prepare_for_split_prefill() + else: + batch.prepare_for_extend() + + # Create forward batch + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + + return forward_batch + + def test_split_prefill_functionality(self): + """Test that split prefill can complete successfully.""" + print("\n=== Testing split prefill functionality ===") + + forward_batch = self.prepare_test_batch(batch_size=2, input_len=64) + + # Reset split index + forward_batch.split_index = 0 + + # Test split prefill in chunks + num_layers = self.model_config.num_hidden_layers + chunk_size = max(1, num_layers // 4) # Split into 4 chunks + + results = [] + split_count = 0 + + while forward_batch.split_index < num_layers: + print( + f"Processing split {split_count}, split_index: {forward_batch.split_index}" + ) + + result = self.model_runner.forward_split_prefill( + forward_batch=forward_batch, + reinit_attn_backend=(split_count == 0), + forward_count=chunk_size, + ) + + results.append(result) + split_count += 1 + + # Verify split_index is updated correctly + expected_next_index = min(split_count * chunk_size, num_layers) + self.assertEqual(forward_batch.split_index, expected_next_index) + + # The last result should contain logits + self.assertIsNotNone(results[-1], "Final split should return logits") + print(f"Split prefill completed in {split_count} splits") + + def test_split_prefill_vs_normal_prefill(self): + """Test that split prefill produces the same results as normal prefill.""" + print("\n=== Testing split prefill vs normal prefill consistency ===") + + forward_batch_normal = self.prepare_test_batch( + batch_size=2, input_len=128, 
is_split_prefill=False + ) + forward_batch_split = self.prepare_test_batch( + batch_size=2, input_len=128, is_split_prefill=True + ) + + # Ensure same input + forward_batch_split.input_ids = forward_batch_normal.input_ids.clone() + forward_batch_split.positions = forward_batch_normal.positions.clone() + + # Method 1: Normal extend (prefill) + print("Running normal extend (prefill)...") + normal_result = self.model_runner.forward_extend(forward_batch_normal) + + # Method 2: Split prefill + print("Running split prefill...") + num_layers = self.model_config.num_hidden_layers + chunk_size = max(1, num_layers // 3) # Split into 3 chunks + + split_result = None + + while forward_batch_split.split_index < num_layers: + result = self.model_runner.forward_split_prefill( + forward_batch=forward_batch_split, + forward_count=chunk_size, + ) + if result is not None: + split_result = result + + # Compare results + self.assertIsNotNone(normal_result, "Normal prefill should return result") + self.assertIsNotNone(split_result, "Split prefill should return result") + + # Compare logits shapes + self.assertEqual( + normal_result.next_token_logits.shape, + split_result.next_token_logits.shape, + "Logits shapes should match", + ) + + # Compare logits values (should be very close due to same computation) + # Use a larger tolerance for numerical differences in split computation + torch.testing.assert_close( + normal_result.next_token_logits, + split_result.next_token_logits, + rtol=1e-3, + atol=1e-3, + msg="Split prefill and normal prefill should produce similar logits", + ) + + print("✓ Split prefill and normal prefill produce consistent results") + + def test_split_prefill_different_chunk_sizes(self): + """Test split prefill with different chunk sizes.""" + print("\n=== Testing split prefill with different chunk sizes ===") + + num_layers = self.model_config.num_hidden_layers + chunk_sizes = [1, 2, max(1, num_layers // 2), num_layers] + + # Prepare identical batches for each test + 
base_batch = self.prepare_test_batch(batch_size=1, input_len=16) + base_input_ids = base_batch.input_ids.clone() + base_positions = base_batch.positions.clone() + + results = [] + + for chunk_size in chunk_sizes: + if chunk_size > num_layers: + continue + + print(f"Testing chunk size: {chunk_size}") + + # Prepare fresh batch + forward_batch = self.prepare_test_batch(batch_size=1, input_len=16) + forward_batch.input_ids = base_input_ids.clone() + forward_batch.positions = base_positions.clone() + forward_batch.split_index = 0 + + # Run split prefill + split_result = None + + while forward_batch.split_index < num_layers: + result = self.model_runner.forward_split_prefill( + forward_batch=forward_batch, + forward_count=chunk_size, + ) + if result is not None: + split_result = result + + self.assertIsNotNone( + split_result, + f"Split prefill should succeed with chunk_size={chunk_size}", + ) + results.append(split_result) + + # Compare all results should be identical (same input, same computation) + if len(results) > 1: + for i, result in enumerate(results[1:], 1): + torch.testing.assert_close( + results[0].next_token_logits, + result.next_token_logits, + rtol=1e-3, + atol=1e-3, + msg=f"Results with different chunk sizes should be identical (chunk_size {chunk_sizes[i]})", + ) + + print("✓ All chunk sizes produce consistent results") + + def test_split_prefill_edge_cases(self): + """Test edge cases for split prefill.""" + print("\n=== Testing split prefill edge cases ===") + + # Test with single layer chunks + forward_batch = self.prepare_test_batch(batch_size=1, input_len=8) + + # Process one layer at a time + num_layers = self.model_config.num_hidden_layers + for layer_idx in range(num_layers): + result = self.model_runner.forward_split_prefill( + forward_batch=forward_batch, + reinit_attn_backend=(layer_idx == 0), + forward_count=1, # One layer at a time + ) + + if layer_idx == num_layers - 1: + # Last layer should return result + self.assertIsNotNone(result, "Last 
layer should return logits") + else: + # Intermediate layers should return None + self.assertIsNone(result, f"Layer {layer_idx} should return None") + + print("✓ Single layer processing works correctly") + + +if __name__ == "__main__": + unittest.main() From 114837854fdc1c94d36ce0ffcde6cd0d16f87a97 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 21 Jul 2025 14:02:48 -0700 Subject: [PATCH 085/396] docs: update 2025 h2 roadmap (#8237) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b19a9cdabfc0..0a0a78577228 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ The core features include: Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/), [Large-scale expert parallelism](https://lmsys.org/blog/2025-05-05-large-scale-ep/). ## Roadmap -[Development Roadmap (2025 H1)](https://github.com/sgl-project/sglang/issues/4042) +[Development Roadmap (2025 H2)](https://github.com/sgl-project/sglang/issues/7736) ## Adoption and Sponsorship SGLang has been deployed at large scale, generating trillions of tokens in production each day. It is trusted and adopted by a wide range of leading enterprises and institutions, including xAI, AMD, NVIDIA, Intel, LinkedIn, Cursor, Oracle Cloud, Google Cloud, Microsoft Azure, AWS, Atlas Cloud, Voltage Park, Nebius, DataCrunch, Novita, InnoMatrix, MIT, UCLA, the University of Washington, Stanford, UC Berkeley, Tsinghua University, Jam & Tea Studios, Baseten, and other major technology organizations across North America and Asia. As an open-source LLM inference engine, SGLang has become the de facto industry standard, with deployments running on over 1,000,000 GPUs worldwide. 
From 69adc4f81c56403803840e49e4fe5385667bb55f Mon Sep 17 00:00:00 2001 From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Date: Mon, 21 Jul 2025 17:06:35 -0700 Subject: [PATCH 086/396] fix: retrieve mm token by modality, raise error if none (#8221) Signed-off-by: Xinyuan Tong Co-authored-by: Xinyuan Tong --- .../multimodal/processors/base_processor.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index b79d90b987ea..3d548a19ee9e 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -101,6 +101,14 @@ def get_modality_of_token(self, token: str) -> Optional[Modality]: return None + def get_token_id_by_modality(self, modality: Modality) -> Optional[int]: + return { + Modality.IMAGE: self.image_token_id, + Modality.MULTI_IMAGES: self.image_token_id, + Modality.VIDEO: self.video_token_id, + Modality.AUDIO: self.audio_token_id, + }.get(modality) + def parse_regex(self): if self.image_token_regex is None and self.image_token is not None: self.image_token_regex = re.compile(re.escape(self.image_token)) @@ -608,14 +616,12 @@ def process_and_combine_mm_data( # Add offsets to all items for mm_item in all_collected_items: + mm_token_id = mm_tokens.get_token_id_by_modality(mm_item.modality) + if mm_token_id is None: + raise ValueError(f"No token id found for modality: {mm_item.modality}") mm_item.offsets = self.get_mm_items_offset( input_ids=input_ids, - mm_token_id={ - Modality.IMAGE: mm_tokens.image_token_id, - Modality.MULTI_IMAGES: mm_tokens.image_token_id, - Modality.VIDEO: mm_tokens.video_token_id, - Modality.AUDIO: mm_tokens.audio_token_id, - }.get(mm_item.modality, None), + mm_token_id=mm_token_id, ) return all_collected_items, input_ids, ret From e50109f2edfec7cc48a56c41b05fcaef3190087f Mon Sep 17 00:00:00 2001 
From: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com> Date: Mon, 21 Jul 2025 17:33:19 -0700 Subject: [PATCH 087/396] [AMD] Remove vllm's scaled_fp8_quant and moe_sum when SGLANG_USE_AITER=1 (#7484) --- python/sglang/srt/layers/moe/ep_moe/layer.py | 5 +- .../layers/moe/fused_moe_triton/fused_moe.py | 26 ++- .../compressed_tensors_moe.py | 5 +- python/sglang/srt/layers/quantization/fp8.py | 3 +- .../srt/layers/quantization/fp8_kernel.py | 161 +++++++++++++----- .../sglang/srt/layers/quantization/unquant.py | 1 - .../sglang/srt/layers/quantization/utils.py | 5 +- python/sglang/test/test_custom_ops.py | 19 ++- 8 files changed, 156 insertions(+), 69 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 77d849f3f67b..83f74fb27019 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -54,14 +54,11 @@ _is_fp8_fnuz = is_fp8_fnuz() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip -if not _is_npu: +if not (_is_npu or _is_hip): from sgl_kernel import silu_and_mul from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe -if _is_hip: - from vllm._custom_ops import scaled_fp8_quant - if _use_aiter: from aiter import ActivationType, QuantType from aiter.fused_moe import fused_moe diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 2466067461cf..9c13c7e9dcb5 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -39,11 +39,20 @@ _is_cuda = is_cuda() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if _is_cuda: from sgl_kernel import gelu_and_mul, silu_and_mul elif _is_cpu and _is_cpu_amx_available: pass +elif _is_hip: + from vllm import _custom_ops as vllm_ops # gelu_and_mul, 
silu_and_mul + + if _use_aiter: + try: + from aiter import moe_sum + except ImportError: + raise ImportError("aiter is required when SGLANG_USE_AITER is set to True") else: from vllm import _custom_ops as vllm_ops from vllm._custom_ops import scaled_fp8_quant @@ -1521,11 +1530,7 @@ def fused_experts_impl( routed_scaling_factor: Optional[float] = None, ): padded_size = padding_size - if ( - not (use_fp8_w8a8 or use_int8_w8a8) - or block_shape is not None - or (_is_hip and get_bool_env_var("SGLANG_USE_AITER")) - ): + if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: padded_size = 0 # Check constraints. @@ -1723,6 +1728,17 @@ def fused_experts_impl( out_hidden_states[begin_chunk_idx:end_chunk_idx], routed_scaling_factor, ) + elif _is_hip: + if _use_aiter: + moe_sum( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + ) + else: + vllm_ops.moe_sum( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + ) else: vllm_ops.moe_sum( intermediate_cache3.view(*intermediate_cache3.shape), diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 39e5f9e252da..af1f6cbf7cc2 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -20,7 +20,7 @@ per_tensor_dequantize, replace_parameter, ) -from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs +from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs if TYPE_CHECKING: from sglang.srt.layers.moe.topk import TopKOutput @@ -32,8 +32,9 @@ _is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() +_is_hip = is_hip() -if not (_is_cuda or _is_npu or (_is_cpu and 
_is_cpu_amx_available)): +if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip): from vllm import _custom_ops as vllm_ops from vllm._custom_ops import scaled_fp8_quant diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 23daa5d26fb8..6fa3ccc59ee5 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -95,10 +95,9 @@ def dummy_func(*args, **kwargs): if _is_hip and (_use_aiter or _use_hip_int4): from aiter import ActivationType, QuantType from aiter.fused_moe import fused_moe - from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages from aiter.ops.shuffle import shuffle_weight -if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)): +if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip): from vllm._custom_ops import scaled_fp8_quant diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 79504265c299..b488a65c08d9 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -27,6 +27,7 @@ from sglang.srt.utils import ( align, direct_register_custom_op, + get_bool_env_var, get_device_core_count, get_device_name, is_cpu, @@ -39,6 +40,7 @@ _is_hip = is_hip() _is_cuda = is_cuda() _is_cpu = is_cpu() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if _is_cuda: from sgl_kernel import ( @@ -47,6 +49,22 @@ sgl_per_token_quant_fp8, ) +if _is_hip: + if _use_aiter: + try: + from aiter import ( # v0.1.3 + dynamic_per_tensor_quant, + dynamic_per_token_scaled_quant, + static_per_tensor_quant, + ) + except ImportError: + raise ImportError("aiter is required when SGLANG_USE_AITER is set to True") + else: + try: + import vllm._C + except ImportError: + raise ImportError("vllm is required when SGLANG_USE_AITER is set to False") + logger = 
logging.getLogger(__name__) @@ -1116,58 +1134,109 @@ def per_token_group_quant_mla_deep_gemm_masked_fp8( return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - use_per_token_if_dynamic: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 (8-bit floating point) format. +""" +Quantize input tensor to FP8 (8-bit floating point) format. + +Args: + input (torch.Tensor): Input tensor to be quantized + scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization. + If None, scales will be computed dynamically. + num_token_padding (Optional[int]): If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None), + determines the quantization granularity: + - True: compute scale per token + - False: compute single scale per tensor + +Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - quantized_tensor: The FP8 quantized version of input + - scale_tensor: The scaling factors used for quantization + +Raises: + AssertionError: If input is not 2D or if static scale's numel != 1 +""" +if _is_hip: - Args: - input (torch.Tensor): Input tensor to be quantized - scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization. - If None, scales will be computed dynamically. - num_token_padding (Optional[int]): If specified, pad the first dimension - of the output to at least this value. 
- use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None), - determines the quantization granularity: - - True: compute scale per token - - False: compute single scale per tensor + def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + use_per_token_if_dynamic: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D" + shape = input.shape + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=fp8_dtype) + + if scale is None: + # Dynamic scaling + if use_per_token_if_dynamic: + scale = torch.empty( + (shape[0], 1), device=input.device, dtype=torch.float32 + ) + if _use_aiter: + dynamic_per_token_scaled_quant(output, input, scale) + else: + torch.ops._C.dynamic_per_token_scaled_fp8_quant( + output, input.contiguous(), scale, None + ) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + if _use_aiter: + dynamic_per_tensor_quant(output, input, scale) + else: + torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + else: + # Static scaling + assert ( + scale.numel() == 1 + ), f"Expected scalar scale, got numel={scale.numel()}" + if _use_aiter: + static_per_tensor_quant(output, input, scale) + else: + torch.ops._C.static_scaled_fp8_quant(output, input, scale) - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - quantized_tensor: The FP8 quantized version of input - - scale_tensor: The scaling factors used for quantization + return output, scale - Raises: - AssertionError: If input is not 2D or if static scale's numel != 1 - """ - assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D" - shape = input.shape - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=fp8_dtype) - - 
if scale is None: - # Dynamic scaling - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - sgl_per_token_quant_fp8(input, output, scale) +else: + + def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + use_per_token_if_dynamic: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + + assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D" + shape = input.shape + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=fp8_dtype) + + if scale is None: + # Dynamic scaling + if use_per_token_if_dynamic: + scale = torch.empty( + (shape[0], 1), device=input.device, dtype=torch.float32 + ) + sgl_per_token_quant_fp8(input, output, scale) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + sgl_per_tensor_quant_fp8( + input, output, scale, is_static=False + ) # False for dynamic else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) + # Static scaling + assert ( + scale.numel() == 1 + ), f"Expected scalar scale, got numel={scale.numel()}" sgl_per_tensor_quant_fp8( - input, output, scale, is_static=False - ) # False for dynamic - else: - # Static scaling - assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}" - sgl_per_tensor_quant_fp8( - input, output, scale, is_static=True - ) # True for static + input, output, scale, is_static=True + ) # True for static - return output, scale + return output, scale fp8_autotune = triton.autotune( diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index fa4cbf582027..ddafcc6f5d9f 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -37,7 +37,6 @@ if _use_aiter: from aiter import ActivationType from aiter.fused_moe 
import fused_moe - from aiter.fused_moe_bf16_asm import ck_moe_2stages from aiter.ops.shuffle import shuffle_weight diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 89e0eb84a2e6..8904247a6a8f 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -12,7 +12,7 @@ from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types -from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu +from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu if TYPE_CHECKING: from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -21,8 +21,9 @@ _is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() +_is_hip = is_hip() -if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)): +if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip): from vllm._custom_ops import scaled_fp8_quant diff --git a/python/sglang/test/test_custom_ops.py b/python/sglang/test/test_custom_ops.py index 873f9960e0f9..c07c95db6998 100644 --- a/python/sglang/test/test_custom_ops.py +++ b/python/sglang/test/test_custom_ops.py @@ -3,8 +3,13 @@ import pytest import torch -from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant -from sglang.srt.utils import is_cuda +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant +from sglang.srt.utils import is_cuda, is_hip + +_is_cuda = is_cuda() +_is_hip = is_hip() +_is_fp8_fnuz = is_fp8_fnuz() +fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None: def quantize_ref_per_tensor(tensor, inv_scale): # The reference implementation that fully aligns 
to # the kernel being tested. - finfo = torch.finfo(torch.float8_e4m3fn) + finfo = torch.finfo(fp8_dtype) scale = inv_scale.reciprocal() qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max) - qweight = qweight.to(torch.float8_e4m3fn) + qweight = qweight.to(fp8_dtype) return qweight def dequantize_per_tensor(tensor, inv_scale, dtype): @@ -48,19 +53,19 @@ def dequantize_per_tensor(tensor, inv_scale, dtype): ) -if is_cuda: +if _is_cuda or _is_hip: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None: def quantize_ref_per_token(tensor, inv_scale): # The reference implementation that fully aligns to # the kernel being tested. - finfo = torch.finfo(torch.float8_e4m3fn) + finfo = torch.finfo(fp8_dtype) scale = inv_scale.reciprocal() qweight = (tensor.to(torch.float32) * scale).clamp( min=finfo.min, max=finfo.max ) - qweight = qweight.to(torch.float8_e4m3fn) + qweight = qweight.to(fp8_dtype) return qweight def dequantize_per_token(tensor, inv_scale, dtype): From c33499a67b3e7bf62facdb3f59b36822a4bea2fb Mon Sep 17 00:00:00 2001 From: Rui Chen Date: Tue, 22 Jul 2025 23:41:23 +0800 Subject: [PATCH 088/396] fix: sgl-router remove dead code (#8257) --- sgl-router/src/routers/router.rs | 76 -------------------------------- 1 file changed, 76 deletions(-) diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs index c198b0c1dba5..84bb28fb58e8 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/router.rs @@ -283,82 +283,6 @@ impl Router { HttpResponse::InternalServerError().body("All retry attempts failed") } - pub async fn route_to_all( - &self, - client: &reqwest::Client, - route: &str, - req: &HttpRequest, - ) -> HttpResponse { - // Get all worker URLs - let worker_urls = self.get_worker_urls(); - - // Send requests to all workers concurrently - let mut tasks = Vec::new(); - for worker_url in &worker_urls { - let mut 
request_builder = client.post(format!("{}{}", worker_url, route)); - - // Copy headers from original request - for (name, value) in copy_request_headers(req) { - request_builder = request_builder.header(name, value); - } - - tasks.push(request_builder.send()); - } - - // Wait for all responses - let results = futures_util::future::join_all(tasks).await; - - // Check if all succeeded - let all_success = results.iter().all(|r| { - r.as_ref() - .map(|res| res.status().is_success()) - .unwrap_or(false) - }); - - if all_success { - HttpResponse::Ok().body("Operation completed on all servers") - } else { - HttpResponse::InternalServerError().body("Operation failed on one or more servers") - } - } - - pub async fn get_all_loads( - &self, - client: &reqwest::Client, - _req: &HttpRequest, - ) -> HttpResponse { - let urls = self.get_worker_urls(); - let prefill_urls: Vec = Vec::new(); - let decode_urls = urls; - - // Collect loads from all servers - let mut prefill_loads = Vec::new(); - let mut decode_loads = Vec::new(); - - // Get prefill loads - for url in &prefill_urls { - let load = self.get_worker_load(client, url).await.unwrap_or(-1); - prefill_loads.push(serde_json::json!({ - "engine": format!("(Prefill@{})", url), - "load": load as i64 - })); - } - - // Get decode loads - for url in &decode_urls { - let load = self.get_worker_load(client, url).await.unwrap_or(-1); - decode_loads.push(serde_json::json!({ - "engine": format!("(Decode@{})", url), - "load": load as i64 - })); - } - - HttpResponse::Ok().json(serde_json::json!({ - "prefill": prefill_loads, - "decode": decode_loads - })) - } - // New method to route typed requests directly pub async fn route_typed_request< T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone, From 0f8b5386145c3c326fcd37d725df56738f7a93e3 Mon Sep 17 00:00:00 2001 From: Peter Pan Date: Tue, 22 Jul 2025 23:55:35 +0800 Subject: [PATCH 089/396] [fix] benchmark : routed_scaling_factor is None (#8059) Co-authored-by: Xiaoyu 
Zhang <35585791+BBuf@users.noreply.github.com> --- sgl-kernel/benchmark/bench_moe_fused_gate.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sgl-kernel/benchmark/bench_moe_fused_gate.py b/sgl-kernel/benchmark/bench_moe_fused_gate.py index 2405c49b6c93..36cc9c4984fd 100644 --- a/sgl-kernel/benchmark/bench_moe_fused_gate.py +++ b/sgl-kernel/benchmark/bench_moe_fused_gate.py @@ -18,10 +18,13 @@ def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk): renormalize=True, num_expert_group=num_expert_group, topk_group=topk_group, + routed_scaling_factor=2.5, # DeepSeek-R1 : 2.5, Kimi K2: 2.872 ) -def biased_grouped_topk_org_kernel(scores, bias, num_expert_group, topk_group, topk): +def biased_grouped_topk_org_fuse_kernel( + scores, bias, num_expert_group, topk_group, topk +): return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk) @@ -61,7 +64,7 @@ def benchmark(seq_length, provider): ) elif provider == "kernel": ms, min_ms, max_ms = triton.testing.do_bench( - lambda: biased_grouped_topk_org_kernel( + lambda: biased_grouped_topk_org_fuse_kernel( scores.clone(), bias.clone(), num_expert_group, topk_group, topk ), quantiles=quantiles, From ff45ab7a5fa726193d4d4a01fae4e85cf775ac41 Mon Sep 17 00:00:00 2001 From: zhongwei <974337380@qq.com> Date: Wed, 23 Jul 2025 05:02:40 +0800 Subject: [PATCH 090/396] [Benchmark] add disable-auto-run param for hicache/bench_multiturn (#7822) Co-authored-by: zhongwei.ren Co-authored-by: Zhiqiang Xie --- benchmark/hicache/bench_multiturn.py | 67 +++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 5 deletions(-) diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py index 5b8d706a399c..5e954ecd6466 100644 --- a/benchmark/hicache/bench_multiturn.py +++ b/benchmark/hicache/bench_multiturn.py @@ -9,6 +9,7 @@ from typing import Optional import aiohttp +import numpy as np import requests from tqdm.asyncio import tqdm @@ -97,6 +98,30 @@ def 
parse_args(): default="performance_metrics.jsonl", help="File to log performance metrics", ) + parser.add_argument( + "--disable-auto-run", + action="store_true", + help="If set, disable automatically testing with a range of request rates.", + ) + + parser.add_argument( + "--disable-random-sample", + action="store_true", + help="If set, disable random sampling of requests from the ShareGPT dataset.", + ) + parser.add_argument( + "--sub-question-input-length", + type=int, + default=0, + help="Length of the sub question input for each request, if set 0 use request_length", + ) + parser.add_argument( + "--ready-queue-policy", + type=str, + default="random", + help="Policy for popping requests from the ready queue (random or fifo)", + ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") return parser.parse_args() @@ -234,13 +259,29 @@ def __init__(self, args): self.candidate_inputs = sample_random_requests( input_len=args.request_length, output_len=args.output_length, - num_prompts=args.num_clients * args.num_rounds, + num_prompts=args.num_clients, range_ratio=1.0, tokenizer=self.tokenizer, dataset_path=args.dataset_path, + random_sample=not args.disable_random_sample, ) self.candidate_inputs = [i.prompt for i in self.candidate_inputs] + if args.sub_question_input_length != 0: + sub_question_input_length = args.sub_question_input_length + else: + sub_question_input_length = args.request_length + + self.sub_question_inputs = sample_random_requests( + input_len=sub_question_input_length, + output_len=args.output_length, + num_prompts=args.num_clients * max(args.num_rounds - 1, 1), + range_ratio=1.0, + tokenizer=self.tokenizer, + dataset_path=args.dataset_path, + random_sample=not args.disable_random_sample, + ) + init_requests = [ (i, gen_payload(self.candidate_inputs[i], args.output_length)) for i in range(args.num_clients) @@ -249,7 +290,9 @@ def __init__(self, args): i: {"round": 0, "history": init_requests[i][1]["text"]} for i in 
range(args.num_clients) } - self.ready_queue = ReadyQueue(init_requests=init_requests) + self.ready_queue = ReadyQueue( + init_requests=init_requests, policy=args.ready_queue_policy + ) self.candidate_inputs = self.candidate_inputs[args.num_clients :] self.response_queue = queue.Queue() @@ -314,9 +357,10 @@ def response_handler(self): self.completed_requests += 1 if self.client_records[client_id]["round"] < args.num_rounds: + # append new request to client's history self.client_records[client_id][ "history" - ] += self.candidate_inputs.pop() + ] += self.sub_question_inputs.pop() self.ready_queue.append( ( client_id, @@ -329,6 +373,9 @@ def response_handler(self): except queue.Empty: if self.pbar.n == self.pbar.total: break + except ValueError as e: + print(f"Error processing response for client {client_id}: {e}") + continue def run(self): request_thread = threading.Thread(target=self.request_sender, daemon=True) @@ -388,8 +435,18 @@ def run(self): args = parse_args() flush_cache_url = f"http://{args.host}:{args.port}/flush_cache" - for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]: - args.request_rate = request_rate + random.seed(args.seed) + np.random.seed(args.seed) + + if args.disable_auto_run: + print("Running with specified request rate...") + request_rates = [args.request_rate] + else: + print("Auto-running with different request rates...") + request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1] + + for rate in request_rates: + args.request_rate = rate requests.post(flush_cache_url) time.sleep(1) WorkloadGenerator(args).run() From 0dfe2491aceb6847bd1e8845b3443801164d4600 Mon Sep 17 00:00:00 2001 From: yhyang201 <47235274+yhyang201@users.noreply.github.com> Date: Wed, 23 Jul 2025 06:49:38 +0800 Subject: [PATCH 091/396] Preliminary Support for Qwen3XMLDetector (#8260) Co-authored-by: Chayenne --- .../srt/function_call/function_call_parser.py | 2 + .../srt/function_call/qwen3_detector.py | 150 ++++++++++++++++++ 
python/sglang/srt/server_args.py | 1 + 3 files changed, 153 insertions(+) create mode 100644 python/sglang/srt/function_call/qwen3_detector.py diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index a6708024f876..4c38d9d4fb04 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -14,6 +14,7 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.pythonic_detector import PythonicDetector +from sglang.srt.function_call.qwen3_detector import Qwen3XMLDetector from sglang.srt.function_call.qwen25_detector import Qwen25Detector logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ class FunctionCallParser: "deepseekv3": DeepSeekV3Detector, "pythonic": PythonicDetector, "kimi_k2": KimiK2Detector, + "qwen3": Qwen3XMLDetector, } def __init__(self, tools: List[Tool], tool_call_parser: str): diff --git a/python/sglang/srt/function_call/qwen3_detector.py b/python/sglang/srt/function_call/qwen3_detector.py new file mode 100644 index 000000000000..5c6ac698e8ea --- /dev/null +++ b/python/sglang/srt/function_call/qwen3_detector.py @@ -0,0 +1,150 @@ +import ast +import html +import json +import logging +import re +from typing import Any, Dict, List, Tuple + +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + StructureInfo, + ToolCallItem, + _GetInfoFunc, +) +from sglang.srt.function_call.ebnf_composer import EBNFComposer + +logger = logging.getLogger(__name__) + + +def _safe_val(raw: str) -> Any: + raw = html.unescape(raw.strip()) + try: + return json.loads(raw) + except Exception: + try: + return ast.literal_eval(raw) + except Exception: + return raw + + 
+class Qwen3XMLDetector(BaseFormatDetector): + """ + Detector for Qwen 3 models. + Assumes function call format: + + + + pwd && ls + + + + """ + + def __init__(self): + super().__init__() + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + self.tool_call_prefix: str = "(.*?)|(.*?)$", re.DOTALL + ) + self.tool_call_function_regex = re.compile( + r"|| bool: + return self.tool_call_start_token in text + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + normal, calls = self._extract(text, tools) + return StreamingParseResult(normal_text=normal, calls=calls) + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + self._buf += new_text + normal = "" + calls: List[ToolCallItem] = [] + while True: + if self.tool_call_start_token not in self._buf: + normal += self._buf + self._buf = "" + break + s = self._buf.find(self.tool_call_start_token) + if s > 0: + normal += self._buf[:s] + self._buf = self._buf[s:] + e = self._buf.find(self.tool_call_end_token) + if e == -1: + break + block = self._buf[: e + len(self.tool_call_end_token)] + self._buf = self._buf[e + len(self.tool_call_end_token) :] + calls.extend(self._parse_block(block, tools)) + return StreamingParseResult(normal_text=normal, calls=calls) + + def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]: + normal_parts: List[str] = [] + calls: List[ToolCallItem] = [] + cursor = 0 + while True: + s = text.find(self.tool_call_start_token, cursor) + if s == -1: + normal_parts.append(text[cursor:]) + break + normal_parts.append(text[cursor:s]) + e = text.find(self.tool_call_end_token, s) + if e == -1: + normal_parts.append(text[s:]) + break + block = text[s : e + len(self.tool_call_end_token)] + cursor = e + len(self.tool_call_end_token) + calls.extend(self._parse_block(block, tools)) + return "".join(normal_parts), calls + + def _parse_block(self, block: str, tools: List[Tool]) -> 
List[ToolCallItem]: + res: List[ToolCallItem] = [] + for m in self.tool_call_function_regex.findall(block): + txt = m[0] if m[0] else m[1] + if ">" not in txt: + continue + idx = txt.index(">") + fname = txt[:idx].strip() + body = txt[idx + 1 :] + params: Dict[str, Any] = {} + for pm in self.tool_call_parameter_regex.findall(body): + ptxt = pm[0] if pm[0] else pm[1] + if ">" not in ptxt: + continue + pidx = ptxt.index(">") + pname = ptxt[:pidx].strip() + pval = ptxt[pidx + 1 :].lstrip("\n").rstrip("\n") + params[pname] = _safe_val(pval) + raw = {"name": fname, "arguments": params} + try: + res.extend(self.parse_base_json(raw, tools)) + except Exception: + logger.warning("invalid tool call for %s dropped", fname) + return res + + def structure_info(self) -> _GetInfoFunc: + return lambda n: StructureInfo( + begin=f"{self.tool_call_start_token}\n", + end=f"\n{self.tool_call_end_token}", + trigger=self.tool_call_start_token, + ) + + # TODO: fake ebnf for xml + outlines backend + def build_ebnf(self, tools: List[Tool]): + return EBNFComposer.build_ebnf( + tools, + individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"), + individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"), + tool_call_separator="\\n", + function_format="json", + ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6464f9f40a39..400a1bf99e8e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1099,6 +1099,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "deepseekv3", "pythonic", "kimi_k2", + "qwen3", ], default=ServerArgs.tool_call_parser, help="Specify the parser for handling tool-call interactions. 
Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.", From 01c000043c96e50d3bd33416cf26d394834729cc Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Tue, 22 Jul 2025 15:55:48 -0700 Subject: [PATCH 092/396] chore: bump v0.4.9.post3 (#8265) --- benchmark/deepseek_v3/README.md | 2 +- docs/references/setup_github_runner.md | 4 ++-- docs/start/install.md | 12 ++++++------ python/pyproject.toml | 2 +- python/sglang/version.py | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/benchmark/deepseek_v3/README.md b/benchmark/deepseek_v3/README.md index 7fd380f91a62..bb202fcf4924 100644 --- a/benchmark/deepseek_v3/README.md +++ b/benchmark/deepseek_v3/README.md @@ -33,7 +33,7 @@ Add [performance optimization options](#performance-optimization-options) as nee ```bash # Installation -pip install "sglang[all]>=0.4.9.post2" +pip install "sglang[all]>=0.4.9.post3" # Launch python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code diff --git a/docs/references/setup_github_runner.md b/docs/references/setup_github_runner.md index c99f903a454a..6b13b8150d11 100644 --- a/docs/references/setup_github_runner.md +++ b/docs/references/setup_github_runner.md @@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 # Nvidia docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash # AMD -docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.9.post2-rocm630 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.9.post3-rocm630 /bin/bash # AMD just the last 2 GPUs -docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.9.post2-rocm630 
/bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.9.post3-rocm630 /bin/bash ``` ### Step 2: Configure the runner by `config.sh` diff --git a/docs/start/install.md b/docs/start/install.md index 4ec191f71b46..cd2e731108c3 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -11,7 +11,7 @@ It is recommended to use uv to install the dependencies for faster installation: ```bash pip install --upgrade pip pip install uv -uv pip install "sglang[all]>=0.4.9.post2" +uv pip install "sglang[all]>=0.4.9.post3" ``` **Quick Fixes to Common Problems** @@ -27,7 +27,7 @@ uv pip install "sglang[all]>=0.4.9.post2" ```bash # Use the last release branch -git clone -b v0.4.9.post2 https://github.com/sgl-project/sglang.git +git clone -b v0.4.9.post3 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -42,7 +42,7 @@ Note: For AMD ROCm system with Instinct/MI GPUs, do following instead: ```bash # Use the last release branch -git clone -b v0.4.9.post2 https://github.com/sgl-project/sglang.git +git clone -b v0.4.9.post3 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -71,7 +71,7 @@ docker run --gpus all \ Note: For AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below: ```bash -docker build --build-arg SGL_BRANCH=v0.4.9.post2 -t v0.4.9.post2-rocm630 -f Dockerfile.rocm . +docker build --build-arg SGL_BRANCH=v0.4.9.post3 -t v0.4.9.post3-rocm630 -f Dockerfile.rocm . 
alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \ --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -80,11 +80,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d drun -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=" \ - v0.4.9.post2-rocm630 \ + v0.4.9.post3-rocm630 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.4.9.post2-rocm630 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.4.9.post3-rocm630 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/python/pyproject.toml b/python/pyproject.toml index 5f53a5ca328f..aa9fc460d977 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.4.9.post2" +version = "0.4.9.post3" description = "SGLang is yet another fast serving framework for large language models and vision language models." 
readme = "README.md" requires-python = ">=3.8" diff --git a/python/sglang/version.py b/python/sglang/version.py index 2b4f02700a07..d07dcd150de0 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.4.9.post2" +__version__ = "0.4.9.post3" From e2d66f60c8f8c90ed9491e21061b73d959c2c4d7 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Wed, 23 Jul 2025 12:41:25 +0800 Subject: [PATCH 093/396] Skip llama4 vision module loading when multimodal disabled (#8272) Co-authored-by: Mick --- python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/models/mllama4.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 536198cd27b4..714af6fba588 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -106,6 +106,7 @@ "num_reserved_decode_tokens", "weight_loader_disable_mmap", "enable_triton_kernel_moe", + "enable_multimodal", ] # Put some global args for easy access diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index 8712191a98af..4a2d5f7ded4b 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -23,6 +23,7 @@ Modality, MultimodalDataItem, MultimodalInputs, + global_server_args_dict, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -55,13 +56,17 @@ def __init__( self.quant_config = quant_config # Check if this is a text-only model (modelopt fp8 llama4 has no vision components) - self.has_vision = self._has_vision_weights(config) - if not self.has_vision: + self.has_vision_weights = self._has_vision_weights(config) + if not self.has_vision_weights: logger.warning( "No vision weights found in checkpoint. Model will run in text-only mode. 
" "Multimodal capabilities (image processing) will be unavailable." ) + self.has_vision = ( + self.has_vision_weights and global_server_args_dict["enable_multimodal"] + ) + if self.has_vision: self.vision_model = Llama4VisionModel(config.vision_config) self.multi_modal_projector = Llama4MultiModalProjector(config) @@ -269,7 +274,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: def _should_skip_weight(self, name: str) -> bool: """Check if we should skip loading this weight.""" - return "vision" in name and not self.has_vision + return not self.has_vision and ( + "vision" in name or "multi_modal_projector" in name + ) def _transform_weight_name(self, name: str) -> str: """Transform weight name by adding language_model prefix if needed.""" From e885bfdc6a4da0766213e80162410abcfe34574b Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Wed, 23 Jul 2025 14:01:47 +0800 Subject: [PATCH 094/396] Fix sgl-kernel ci test (#8284) --- sgl-kernel/tests/test_moe_fused_gate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sgl-kernel/tests/test_moe_fused_gate.py b/sgl-kernel/tests/test_moe_fused_gate.py index 1e1b108c7d3d..b08e0d97b23d 100644 --- a/sgl-kernel/tests/test_moe_fused_gate.py +++ b/sgl-kernel/tests/test_moe_fused_gate.py @@ -10,7 +10,6 @@ list(range(1, 10)) + [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536], ) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16]) @pytest.mark.parametrize( "params", [ @@ -20,13 +19,14 @@ ], ) @pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2]) -def test_moe_fused_gate_combined(seq_length, dtype, params, num_fused_shared_experts): +def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts): num_experts, num_expert_group, topk_group, topk = params + dtype = torch.float32 torch.manual_seed(seq_length) - tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda() + tensor = torch.rand((seq_length, 
num_experts), dtype=dtype, device="cuda") scores = tensor.clone() - bias = torch.rand(num_experts).to(dtype).cuda() + bias = torch.rand(num_experts, dtype=dtype, device="cuda") topk = topk + num_fused_shared_experts output, indices = moe_fused_gate( From 8abd3e77feca9ed740356c1b879e524d09482fb2 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Wed, 23 Jul 2025 00:32:16 -0700 Subject: [PATCH 095/396] Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261) --- python/sglang/srt/lora/lora_manager.py | 302 ++++++++---------- python/sglang/srt/lora/lora_registry.py | 124 +++++++ python/sglang/srt/lora/mem_pool.py | 4 +- python/sglang/srt/managers/io_struct.py | 20 +- python/sglang/srt/managers/scheduler.py | 20 +- .../sglang/srt/managers/tokenizer_manager.py | 53 +-- python/sglang/srt/managers/tp_worker.py | 6 +- .../sglang/srt/model_executor/model_runner.py | 25 +- python/sglang/srt/server_args.py | 23 +- test/srt/models/lora/test_lora_eviction.py | 80 +++-- test/srt/run_suite.py | 2 +- 11 files changed, 399 insertions(+), 260 deletions(-) create mode 100644 python/sglang/srt/lora/lora_registry.py diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 85fd246163c1..719c52ef8d7c 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -16,7 +16,7 @@ # and "Punica: Multi-Tenant LoRA Serving" import logging -from typing import Dict, Iterable, Optional, Set, Tuple +from typing import Dict, Iterable, List, Optional, Set, Tuple import torch @@ -26,6 +26,7 @@ from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer from sglang.srt.lora.lora import LoRAAdapter from sglang.srt.lora.lora_config import LoRAConfig +from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.lora.mem_pool import LoRAMemoryPool from sglang.srt.lora.utils import ( LoRABatchInfo, @@ -55,6 +56,7 @@ def __init__( tp_rank: int = 0, max_lora_rank: Optional[int] = None, 
target_modules: Optional[Iterable[str]] = None, + lora_paths: Optional[Dict[str, LoRARef]] = None, ): self.base_model: torch.nn.Module = base_model self.base_hf_config: AutoConfig = base_hf_config @@ -64,10 +66,6 @@ def __init__( self.device: torch.device = next(self.base_model.parameters()).device self.tp_size: int = tp_size self.tp_rank: int = tp_rank - self.max_lora_rank: Optional[int] = max_lora_rank - self.target_modules: Optional[Set[str]] = ( - set(target_modules) if target_modules else None - ) # LoRA backend for running sgemm kernels logger.info(f"Using {lora_backend} as backend of LoRA kernels.") @@ -75,7 +73,11 @@ def __init__( self.lora_backend: BaseLoRABackend = backend_type(lora_backend) # Initialize mutable internal state of the LoRAManager. - self.init_state() + self.init_state( + max_lora_rank=max_lora_rank, + target_modules=target_modules, + lora_paths=lora_paths, + ) def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int): self.max_bs_in_cuda_graph = max_bs_in_cuda_graph @@ -112,108 +114,87 @@ def create_lora_update_result( success=success, error_message=error_message, loaded_adapters={ - name: config.path for name, config in self.configs.items() + lora_ref.lora_name: lora_ref.lora_path + for lora_ref in self.lora_refs.values() }, ) - def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult: - """ - Load LoRA adapters from the specified paths. - - Args: - lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths. - If a LoRA adapter is already loaded, it will be skipped with a warning. 
- """ - - results = [] - for lora_name, lora_path in lora_paths.items(): - result = self.load_lora_adapter(lora_name, lora_path, update_state=False) - results.append(result) - - self.update_state_from_configs() - - return self.create_lora_update_result( - success=all(result.success for result in results), - error_message="\n".join( - result.error_message for result in results if not result.success - ), - ) - - def load_lora_adapter( - self, lora_name: str, lora_path: str, update_state: bool = True - ) -> LoRAUpdateResult: + def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult: """ Load a single LoRA adapter from the specified path. Args: - lora_name (str): The name of the LoRA adapter. - lora_path (str): The file path to the LoRA adapter. - update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading. + lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID. """ + assert ( + lora_ref.lora_name is not None and lora_ref.lora_path is not None + ), "LoRARef must have both lora_name and lora_path set for loading." + assert ( + lora_ref.lora_id not in self.loras + ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend." - success = True - error_message = "" + try: + # load configs + new_adapter = LoRAConfig(lora_ref.lora_path) + self.validate_new_adapter(new_adapter, lora_ref) + self.configs[lora_ref.lora_id] = new_adapter - if lora_name in self.loras: - success = False - error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first." 
+ # load weights + self.load_lora_weights(lora_ref) - try: - new_adapter = LoRAConfig(lora_path) - self.validate_new_adapter(lora_name, new_adapter) - self.configs[lora_name] = new_adapter + # keep metadata for displayed messages + self.lora_refs[lora_ref.lora_id] = lora_ref except Exception as e: - success = False - error_message = ( - f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}" + return self.create_lora_update_result( + success=False, + error_message=str(e), ) - if update_state: - self.update_state_from_configs() + return self.create_lora_update_result(success=True) - return self.create_lora_update_result( - success=success, - error_message=error_message, - ) - - def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig): + def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef): """ Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible. """ - incompatible = self.memory_pool and not self.memory_pool.can_support( - lora_config - ) + memory_pool = getattr(self, "memory_pool", None) + incompatible = memory_pool and not memory_pool.can_support(lora_config) if incompatible: raise ValueError( - f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. " + f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. " "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are " "included in `--enable_lora_modules`." ) - def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult: + def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult: """ Unload LoRA adapters by their names. This will remove the adapters from the memory pool and delete the corresponding LoRA modules. 
""" - success = True - error_message = "" - if lora_name in self.loras: - del self.configs[lora_name] - else: - error_message = f"LoRA adapter {lora_name} is not loaded." - success = False + adapter = self.configs.get(lora_ref.lora_id, None) + assert ( + adapter is not None + ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend." - self.update_state_from_configs() + try: + del self.configs[lora_ref.lora_id] + del self.loras[lora_ref.lora_id] + del self.lora_refs[lora_ref.lora_id] + except Exception as e: + return self.create_lora_update_result( + success=False, + error_message=str(e), + ) - return self.create_lora_update_result( - success=success, - error_message=error_message, - ) + return self.create_lora_update_result(success=True) def prepare_lora_batch(self, forward_batch: ForwardBatch): - # load active loras into lora memory pool + # Load active loras into lora memory pool + # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique + # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we + # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in + # the current API schema and introducing a better request schema in the future (e.g., use `model_name`). 
cur_uids = set(forward_batch.lora_paths) assert len(cur_uids) <= self.max_loras_per_batch self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules) @@ -233,10 +214,10 @@ def transfer_adapter_info( weight_indices = [0] * len(forward_batch.lora_paths) lora_ranks = [0] * self.max_loras_per_batch scalings = [0] * self.max_loras_per_batch - for i, lora_path in enumerate(forward_batch.lora_paths): - weight_indices[i] = self.memory_pool.get_buffer_id(lora_path) - if lora_path is not None: - lora = self.loras[lora_path] + for i, uid in enumerate(forward_batch.lora_paths): + weight_indices[i] = self.memory_pool.get_buffer_id(uid) + if uid is not None: + lora = self.loras[uid] lora_ranks[weight_indices[i]] = lora.config.r scalings[weight_indices[i]] = lora.scaling @@ -326,7 +307,7 @@ def update_lora_info(self): """ Update all LoRA modules to associate them with the latest memory buffer. """ - for layer_id, layer_modules in self.lora_modules.items(): + for layer_id, layer_modules in enumerate(self.lora_modules): for module_name, module in layer_modules.items(): if "qkv_proj" in module_name: module.set_lora_info( @@ -353,115 +334,94 @@ def update_lora_info(self): ), ) - def init_state(self): + def init_state( + self, + max_lora_rank: Optional[int] = None, + target_modules: Optional[Iterable[str]] = None, + lora_paths: Optional[Dict[str, LoRARef]] = None, + ): """ Initialize the internal (mutable) state of the LoRAManager. - These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically. + When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as + the target modules and max_lora_rank. """ - # Configs of all active LoRA adapters. - self.configs: Dict[str, LoRAConfig] = {} - - # LoRA adapter weights cached in CPU memory. 
- self.loras: Dict[str, LoRAAdapter] = {} + assert lora_paths or ( + max_lora_rank is not None and target_modules is not None + ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization." - # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively. - self.lora_weight_names: Tuple[Set[str]] = (set(), set()) - - # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module. - self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = { - i: {} for i in range(self.base_hf_config.num_hidden_layers) - } + self.init_lora_adapters(lora_paths) + self.init_lora_shapes( + max_lora_rank=max_lora_rank, + target_modules=target_modules, + ) + self.init_lora_weight_names() + self.init_lora_modules() + self.init_memory_pool() - # The LoRA memory pool that manages the GPU buffers for active LoRA weights. - # It is initialized lazily when the first LoRA adapter is loaded. - self.memory_pool: Optional[LoRAMemoryPool] = None + def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None): + # Configs of all active LoRA adapters, indexed by LoRA ID. + self.configs: Dict[str, LoRAConfig] = {} - def update_state_from_configs(self): - """ - Update the internal state of the LoRAManager based on the current `self.configs`. This method - should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded). - """ + # LoRA adapter weights cached in CPU memory, indexed by LoRA ID. + self.loras: Dict[str, LoRAAdapter] = {} - # Loads / unloads LoRA adapters based on the latest configs. - self.update_lora_adapters() - # Apply the latest LoRA configurations to the internal state for inferencing. - self.apply_lora_configs() + # Mapping from LoRA ID to LoRARef object. 
+ self.lora_refs: Dict[str, LoRARef] = {} - def apply_lora_configs(self): - """ - Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing. + if lora_paths: + for lora_ref in lora_paths.values(): + result = self.load_lora_adapter(lora_ref) + if not result.success: + raise RuntimeError( + f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}" + ) - Notes: - - Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as - we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer - LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in - early CY25H2. - """ + def init_lora_shapes( + self, + max_lora_rank: Optional[int] = None, + target_modules: Optional[Iterable[str]] = None, + ): + """Infer LoRA target modules and max_lora_rank from loaded adapters if not provided.""" - if self.memory_pool is None: - # Infer max_lora_rank and target_modules if not explicitly specified in server args. - if self.target_modules is None: - self.target_modules = set() - for config in self.configs.values(): - self.target_modules.update(config.target_modules) - - if self.max_lora_rank is None: - self.max_lora_rank = max( - [x.hf_config["r"] for x in self.configs.values()], - default=0, - ) + if target_modules is not None: + self.target_modules = set(target_modules) + else: + self.target_modules = set() + for config in self.configs.values(): + self.target_modules.update(config.target_modules) - self.update_lora_weight_names() - self.update_lora_modules() - self.update_memory_buffers() + if max_lora_rank is not None: + self.max_lora_rank = max_lora_rank else: - # No-op if the memory pool can support the current LoRA configurations. 
- # TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target - # module is changed once FlashInfer backend is deprecated. - assert self.memory_pool.can_support(self.configs.values()), ( - "LoRA memory pool cannot support the current LoRA configuration. " - "This should never happen as we should have validated adapter compatibility. " - "Please create a Github issue to report.", + self.max_lora_rank = max( + [x.hf_config["r"] for x in self.configs.values()], + default=0, ) - def update_lora_weight_names(self): + def init_lora_weight_names(self): """ Add new LoRA weight names if needed based on the current `self.configs`. """ # Target lora weight names for lora_a and lora_b modules respectively. lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules) - self.lora_weight_names[0].update(lora_A) - self.lora_weight_names[1].update(lora_B) + self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B)) - def update_lora_adapters(self): + def load_lora_weights(self, lora_ref: LoRARef): """ - Update the LoRA adapters in CPU memory based on the current `self.configs`. - It loads any new adapters that are not already loaded, and unloads any adapters - that are no longer in `self.configs` (e.g., unloaded). + Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation. """ - - # Load new adapter weights to cpu - for name, config in self.configs.items(): - if name not in self.loras: - logger.info(f"Loading weight of LoRA adapter {name} from {config.path}") - lora_adapter = LoRAAdapter( - name, - config, - self.base_hf_config, - self.load_config, - self.lora_backend, - ) - lora_adapter.initialize_weights() - self.loras[name] = lora_adapter - - # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration. 
- for name in list(self.loras): - if name not in self.configs: - logger.info(f"Unloading LoRA adapter {name}") - del self.loras[name] + lora_adapter = LoRAAdapter( + lora_ref.lora_id, + self.configs[lora_ref.lora_id], + self.base_hf_config, + self.load_config, + self.lora_backend, + ) + lora_adapter.initialize_weights() + self.loras[lora_ref.lora_id] = lora_adapter # Additional checks for flashinfer backend # FIXME remove the restrictions after supporting multi-rank for flashinfer backend @@ -472,7 +432,7 @@ def update_lora_adapters(self): len(lora_dims) == 1 and len(scalings) == 1 ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. " - def update_memory_buffers(self): + def init_memory_pool(self): """(Re)initialize the LoRA memory pool based on the current configurations.""" self.memory_pool = LoRAMemoryPool( base_hf_config=self.base_hf_config, @@ -490,7 +450,12 @@ def set_lora_module(self, module_name, module): replace_submodule(self.base_model, module_name, lora_module) return lora_module - def update_lora_modules(self): + def init_lora_modules(self): + # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module. 
+ self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [ + {} for _ in range(self.base_hf_config.num_hidden_layers) + ] + # Target module names of customized layers defined in python/sglang/srt/layers # e.g., {"qkv_proj", "o_proj"} customized_target_names = get_customized_names_from_hf_names( @@ -511,7 +476,6 @@ def update_lora_modules(self): # The module should be converted if it is included in target_names if module_name.split(".")[-1] in customized_target_names: layer_id = get_layer_id(module_name) - if module_name not in self.lora_modules[layer_id]: - self.lora_modules[layer_id][module_name] = self.set_lora_module( - module_name, module - ) + self.lora_modules[layer_id][module_name] = self.set_lora_module( + module_name, module + ) diff --git a/python/sglang/srt/lora/lora_registry.py b/python/sglang/srt/lora/lora_registry.py new file mode 100644 index 000000000000..b596c7371f9c --- /dev/null +++ b/python/sglang/srt/lora/lora_registry.py @@ -0,0 +1,124 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +import asyncio +from dataclasses import dataclass, field, fields +from typing import Dict, List, Optional, Union +from uuid import uuid4 + + +@dataclass(frozen=True, slots=True) +class LoRARef: + """ + Reference record for a LoRA model. + + This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. 
The ID + eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache + keys (e.g., radix cache). + """ + + lora_id: str = field(default_factory=lambda: uuid4().hex) + lora_name: Optional[str] = None + lora_path: Optional[str] = None + + def __post_init__(self): + if self.lora_id is None: + raise ValueError("lora_id cannot be None") + + def __str__(self) -> str: + parts = [ + f"{f.name}={value}" + for f in fields(self) + if (value := getattr(self, f.name)) is not None + ] + return f"{self.__class__.__name__}({', '.join(parts)})" + + +class LoRARegistry: + """ + The central registry to keep track of available LoRA adapters. + + TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided + to keep it in a separate PR to keep code review simple and to unblock the radix cache work. + """ + + def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None): + assert lora_paths is None or all( + isinstance(lora, LoRARef) for lora in lora_paths.values() + ), ( + "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. " + "Please file an issue if you see this error." + ) + + # A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef. + self._registry: Dict[str, LoRARef] = dict(lora_paths or {}) + + async def register(self, lora_ref: LoRARef): + """ + Register a new LoRARef object in the registry. + + Args: + lora_ref (LoRARef): The LoRARef object to register. + """ + if lora_ref.lora_name in self._registry: + raise ValueError( + f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}" + ) + self._registry[lora_ref.lora_name] = lora_ref + + async def unregister(self, lora_name: str) -> str: + """ + Unregister a LoRARef object from the registry and returns the removed LoRA ID. + + Args: + lora_name (str): The name of the LoRA model to unregister. 
+ """ + lora_ref = self._registry.get(lora_name, None) + if lora_ref is None: + raise ValueError( + f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}" + ) + del self._registry[lora_name] + + return lora_ref.lora_id + + async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]: + """ + Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters + by incrementing its counter. + + TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters. + """ + + async def _acquire_single(name: str) -> str: + lora_ref = self._registry.get(name, None) + if lora_ref is None: + raise ValueError( + f"The following requested LoRA adapters are not loaded: {name}\n" + f"Loaded adapters: {self._registry.keys()}." + ) + # await self._counters[lora_ref.lora_id].increment() + return lora_ref.lora_id + + if isinstance(lora_name, str): + lora_id = await _acquire_single(lora_name) + return lora_id + elif isinstance(lora_name, list): + lora_ids = await asyncio.gather( + *[_acquire_single(name) for name in lora_name] + ) + return lora_ids + else: + raise TypeError("lora_name must be either a string or a list of strings.") diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 1b36cac5e1a7..ae856246dd92 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -153,7 +153,7 @@ def prepare_lora_batch( self, cur_uids: Set[Optional[str]], lora_adapters: Dict[str, LoRAAdapter], - lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]], + lora_modules: List[Dict[str, BaseLayerWithLoRA]], ): def get_available_buffer_slot(): for buffer_id in range(self.max_loras_per_batch): @@ -186,7 +186,7 @@ def load_lora_weight_to_buffer( uid: str, buffer_id: int, lora_adapter: LoRAAdapter, - lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]], + lora_modules: List[Dict[str, BaseLayerWithLoRA]], 
): def load_lora_weight_tensor( buffer_view: torch.Tensor, weight: Optional[torch.Tensor] diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 8e1d1075aab6..3d18e1af450d 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -22,6 +22,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.multimodal.mm_utils import has_valid_data from sglang.srt.sampling.sampling_params import SamplingParams @@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput: lora_name: str # The path of loading. lora_path: str + # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`. + lora_id: Optional[str] = None + + def to_ref(self) -> LoRARef: + return LoRARef( + lora_id=self.lora_id, + lora_name=self.lora_name, + lora_path=self.lora_path, + ) @dataclass class UnloadLoRAAdapterReqInput: # The name of lora module to unload. lora_name: str + # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`. 
+ lora_id: Optional[str] = None + + def to_ref(self) -> LoRARef: + return LoRARef( + lora_id=self.lora_id, + lora_name=self.lora_name, + ) @dataclass class LoRAUpdateResult: success: bool error_message: Optional[str] = None - loaded_adapters: Dict[str, str] = field(default_factory=dict) + loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict) LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e6dd80d717ad..c3b5fc2e885f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -247,7 +247,7 @@ def __init__( self.pp_size = server_args.pp_size self.dp_size = server_args.dp_size self.schedule_policy = server_args.schedule_policy - self.lora_paths = server_args.lora_paths + self.enable_lora = server_args.enable_lora self.max_loras_per_batch = server_args.max_loras_per_batch self.enable_overlap = not server_args.disable_overlap_schedule self.skip_tokenizer_init = server_args.skip_tokenizer_init @@ -1706,13 +1706,13 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.chunked_req.init_next_round_input() self.chunked_req = adder.add_chunked_req(self.chunked_req) - if self.lora_paths: + if self.enable_lora: lora_set = set([req.lora_path for req in self.running_batch.reqs]) # Get requests from the waiting queue to a new prefill batch for req in self.waiting_queue: if ( - self.lora_paths + self.enable_lora and len( lora_set | set([req.lora_path for req in adder.can_run_list]) @@ -2466,12 +2466,6 @@ def load_lora_adapter( """In-place loading a new lora adapter from disk or huggingface.""" result = self.tp_worker.load_lora_adapter(recv_req) - - if result.success: - flush_cache_success = self.flush_cache() - assert flush_cache_success, "Cache flush failed after loading lora adapter." 
- else: - logger.error(result.error_message) return result def unload_lora_adapter( @@ -2480,14 +2474,6 @@ def unload_lora_adapter( """Unload the lora adapter.""" result = self.tp_worker.unload_lora_adapter(recv_req) - - if result.success: - flush_cache_success = self.flush_cache() - assert ( - flush_cache_success - ), "Cache flush failed after unloading LoRA weights" - else: - logger.error(result.error_message) return result def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 631d23f17335..0f65fa9257e8 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -62,6 +62,7 @@ get_tokenizer, get_tokenizer_from_processor, ) +from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, @@ -242,11 +243,11 @@ def __init__( revision=server_args.revision, ) - # Initialize loaded loRA adapters with the initial lora paths in the server_args. - # This list will be updated when new LoRA adapters are loaded or unloaded dynamically. - self.loaded_lora_adapters: Dict[str, str] = dict( - self.server_args.lora_paths or {} - ) + # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`. + # The registry dynamically updates as adapters are loaded / unloaded during runtime. It + # serves as the source of truth for available adapters and maps user-friendly LoRA names + # to internally used unique LoRA IDs. + self.lora_registry = LoRARegistry(self.server_args.lora_paths or {}) # Store states self.no_create_loop = False @@ -523,6 +524,10 @@ async def _tokenize_one_request( else: mm_inputs = None + if self.server_args.enable_lora and obj.lora_path: + # Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs. 
+ obj.lora_path = await self.lora_registry.acquire(obj.lora_path) + self._validate_one_request(obj, input_ids) return self._create_tokenized_object( obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids @@ -574,8 +579,6 @@ def _validate_one_request( "The server is not configured to enable custom logit processor. " "Please set `--enable-custom-logits-processor` to enable this feature." ) - if self.server_args.enable_lora and obj.lora_path: - self._validate_lora_adapters(obj) def _validate_input_ids_in_vocab( self, input_ids: List[int], vocab_size: int @@ -689,21 +692,6 @@ def _validate_batch_tokenization_constraints( "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`." ) - def _validate_lora_adapters(self, obj: GenerateReqInput): - """Validate that the requested LoRA adapters are loaded.""" - requested_adapters = ( - set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path} - ) - loaded_adapters = ( - self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set() - ) - unloaded_adapters = requested_adapters - loaded_adapters - if unloaded_adapters: - raise ValueError( - f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n" - f"Loaded adapters: {loaded_adapters}." - ) - def _send_one_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -1054,8 +1042,18 @@ async def load_lora_adapter( ) async with self.model_update_lock.writer_lock: + # Generate new uniquely identifiable LoRARef object. + new_adapter = LoRARef( + lora_name=obj.lora_name, + lora_path=obj.lora_path, + ) + + # Register the new adapter in the registry. 
+ obj.lora_id = new_adapter.lora_id result = (await self.update_lora_adapter_communicator(obj))[0] - self.loaded_lora_adapters = result.loaded_adapters + if result.success: + await self.lora_registry.register(new_adapter) + return result async def unload_lora_adapter( @@ -1069,6 +1067,10 @@ async def unload_lora_adapter( "LoRA is not enabled. Please set `--enable-lora` to enable LoRA." ) + assert ( + obj.lora_name is not None + ), "lora_name must be provided to unload LoRA adapter" + # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works # with dp_size > 1. assert ( @@ -1080,8 +1082,9 @@ async def unload_lora_adapter( ) async with self.model_update_lock.writer_lock: + obj.lora_id = await self.lora_registry.unregister(obj.lora_name) result = (await self.update_lora_adapter_communicator(obj))[0] - self.loaded_lora_adapters = result.loaded_adapters + return result async def get_weights_by_name( @@ -1309,7 +1312,7 @@ def dump_requests_before_crash(self): filename = os.path.join( self.crash_dump_folder, os.getenv("HOSTNAME", None), - f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl', + f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl", ) os.makedirs(os.path.dirname(filename), exist_ok=True) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index ff20ea01e4d3..d0939ffcaeaa 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -293,11 +293,9 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): return parameter def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput): - result = self.model_runner.load_lora_adapter( - recv_req.lora_name, recv_req.lora_path - ) + result = self.model_runner.load_lora_adapter(recv_req.to_ref()) return result def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput): - result = self.model_runner.unload_lora_adapter(recv_req.lora_name) + result = 
self.model_runner.unload_lora_adapter(recv_req.to_ref()) return result diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4f0b1d64ce8a..9e6d14aaca55 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -68,6 +68,7 @@ from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.lora.lora_manager import LoRAManager +from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.managers.schedule_batch import ( GLOBAL_SERVER_ARGS_KEYS, global_server_args_dict, @@ -890,44 +891,38 @@ def init_lora_manager(self): tp_rank=self.tp_rank, max_lora_rank=self.server_args.max_lora_rank, target_modules=self.server_args.lora_target_modules, + lora_paths=self.server_args.lora_paths, ) - result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths or {}) - if result.success: - logger.info( - f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}" - ) - else: - raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}") - def load_lora_adapter(self, lora_name: str, lora_path: str): + def load_lora_adapter(self, lora_ref: LoRARef): """Load a new lora adapter from disk or huggingface.""" logger.info( - f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. " + f"LoRA adapter loading starts: {lora_ref}. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) - result = self.lora_manager.load_lora_adapter(lora_name, lora_path) + result = self.lora_manager.load_lora_adapter(lora_ref) logger.info( - f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. " + f"LoRA adapter loading completes: {lora_ref}. 
" f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) return result - def unload_lora_adapter(self, lora_name: str): + def unload_lora_adapter(self, lora_ref: LoRARef): """Unload a lora adapter that was previously loaded during initialization or dynamic loading.""" logger.info( - f"LoRA adapter unloading starts: name={lora_name}. " + f"LoRA adapter unloading starts: {lora_ref}. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) - result = self.lora_manager.unload_lora_adapter(lora_name) + result = self.lora_manager.unload_lora_adapter(lora_ref) logger.info( - f"LoRA adapter unloading completes: name={lora_name}. " + f"LoRA adapter unloading completes: {lora_ref}. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 400a1bf99e8e..1625f2c3af21 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,10 +20,10 @@ import os import random import tempfile -from token import OP from typing import List, Literal, Optional, Union from sglang.srt.hf_transformers_utils import check_gguf_file, get_config +from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.utils import ( LORA_TARGET_ALL_MODULES, @@ -145,7 +145,7 @@ class ServerArgs: enable_lora: Optional[bool] = None max_lora_rank: Optional[int] = None lora_target_modules: Optional[Union[set[str], List[str]]] = None - lora_paths: Optional[Union[dict[str, str], List[str]]] = None + lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None max_loras_per_batch: int = 8 lora_backend: str = "triton" @@ -1843,9 +1843,24 @@ def check_lora_server_args(self): for lora_path in lora_paths: if "=" in lora_path: name, path = lora_path.split("=", 1) - self.lora_paths[name] = path + self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path) else: - 
self.lora_paths[lora_path] = lora_path + self.lora_paths[lora_path] = LoRARef( + lora_name=lora_path, + lora_path=lora_path, + ) + elif isinstance(self.lora_paths, dict): + self.lora_paths = { + k: LoRARef(lora_name=k, lora_path=v) + for k, v in self.lora_paths.items() + } + elif self.lora_paths is None: + self.lora_paths = {} + else: + raise ValueError( + f"Invalid type for --lora-paths: {type(self.lora_paths)}. " + "Expected a list or a dictionary." + ) # Expand target modules if self.lora_target_modules: diff --git a/test/srt/models/lora/test_lora_eviction.py b/test/srt/models/lora/test_lora_eviction.py index e74af0a0e61d..b352da2d5d99 100644 --- a/test/srt/models/lora/test_lora_eviction.py +++ b/test/srt/models/lora/test_lora_eviction.py @@ -12,6 +12,7 @@ # limitations under the License. # ============================================================================== +import contextlib import multiprocessing as mp import unittest from typing import Dict, List, Tuple @@ -39,6 +40,16 @@ BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct" +@contextlib.contextmanager +def dynamically_loaded_adapter(runner, lora_path: str, lora_name: str): + """A context manager to load and automatically unload a LoRA adapter.""" + try: + runner.load_lora_adapter(lora_name=lora_name, lora_path=lora_path) + yield + finally: + runner.unload_lora_adapter(lora_name=lora_name) + + class TestLoRAEviction(CustomTestCase): def test_lora_eviction_with_different_target_modules(self): """ @@ -51,55 +62,80 @@ def test_lora_eviction_with_different_target_modules(self): self._run_test(ADAPTERS, output_history, reverse=False) self._run_test(ADAPTERS, output_history, reverse=True) + def test_lora_eviction_with_reused_lora_name(self): + """ + Test LoRA eviction with reused LoRA names. + + This test runs inference against two LoRA adapters with the same name to ensure that the eviction behavior + works correctly when reusing LoRA names. 
+ """ + output_history = {} + self._run_test(ADAPTERS, output_history, reuse_lora_name=True, repeat=1) + self._run_test(ADAPTERS, output_history, reuse_lora_name=False, repeat=1) + def _run_test( self, lora_paths: List[str], output_history: Dict[Tuple[str, str], str], - reverse: bool, + reverse: bool = False, repeat: int = 2, + reuse_lora_name: bool = False, ): + REUSED_LORA_NAME = "lora" max_new_tokens = 256 backend = "triton" torch_dtype = torch.float16 base_path = BASE_MODEL assert len(lora_paths) >= 2 + initial_lora_paths = lora_paths if not reuse_lora_name else None # Initialize runners with SRTRunner( base_path, torch_dtype=torch_dtype, model_type="generation", - lora_paths=lora_paths, + lora_paths=initial_lora_paths, max_loras_per_batch=1, lora_backend=backend, disable_radix_cache=True, + enable_lora=True, + max_lora_rank=256, + lora_target_modules=["all"], ) as srt_runner: adapter_sequence = lora_paths if not reverse else lora_paths[::-1] for i in range(repeat): - for j, adapter in enumerate(adapter_sequence): + for j, lora_path in enumerate(adapter_sequence): print( - f"\n========== Testing LoRA eviction with adapter '{adapter}' (#{j+1}/{len(adapter_sequence)}), reversed: {reverse}, repeat: {i+1}/{repeat} ---" + f"\n========== Testing LoRA eviction with adapter '{lora_path}' (#{j + 1}/{len(adapter_sequence)}), reuse_lora_name: {reuse_lora_name}, reversed: {reverse}, repeat: {i + 1}/{repeat} ---" + ) + + lora_name = REUSED_LORA_NAME if reuse_lora_name else lora_path + context = ( + dynamically_loaded_adapter(srt_runner, lora_path, lora_name) + if reuse_lora_name + else contextlib.nullcontext() ) - for prompt in PROMPTS: - print("\nprompt:\n", prompt) - srt_outputs = srt_runner.forward( - [prompt], - max_new_tokens=max_new_tokens, - lora_paths=[adapter], - ) - output = srt_outputs.output_strs[0].strip() - print("\noutput:\n", output) - - prev_output = output_history.get((adapter, prompt)) - if prev_output is not None: - self.assertEqual( - prev_output, - 
output, - f"Output mismatch for adapter {adapter} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.", + with context: + for prompt in PROMPTS: + print("\nprompt:\n", prompt) + srt_outputs = srt_runner.forward( + [prompt], + max_new_tokens=max_new_tokens, + lora_paths=[lora_name], ) - else: - output_history[(adapter, prompt)] = output + output = srt_outputs.output_strs[0].strip() + print("\noutput:\n", output) + + prev_output = output_history.get((lora_path, prompt)) + if prev_output is not None: + self.assertEqual( + prev_output, + output, + f"Output mismatch for adapter {lora_path} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.", + ) + else: + output_history[(lora_path, prompt)] = output if __name__ == "__main__": diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 0e62760ab72f..6a96cf598648 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -14,7 +14,7 @@ class TestFile: suites = { "per-commit": [ TestFile("models/lora/test_lora.py", 200), - TestFile("models/lora/test_lora_eviction.py", 120), + TestFile("models/lora/test_lora_eviction.py", 200), TestFile("models/lora/test_lora_backend.py", 99), TestFile("models/lora/test_multi_lora_backend.py", 60), TestFile("models/lora/test_lora_cuda_graph.py", 250), From b43263307f40a206f1371e4064d410a136d4e004 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Wed, 23 Jul 2025 01:49:03 -0700 Subject: [PATCH 096/396] Hicache IO kernel refactoring (#8264) --- sgl-kernel/csrc/common_extension.cc | 37 +- sgl-kernel/csrc/kvcacheio/transfer.cu | 415 ++++++++++++++-------- sgl-kernel/include/sgl_kernel_ops.h | 61 ++-- sgl-kernel/python/sgl_kernel/kvcacheio.py | 160 +++++++-- sgl-kernel/tests/test_kvcacheio.py | 110 +++--- 5 files changed, 524 insertions(+), 259 deletions(-) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 070fe4bd2f60..20b9a804872d 100644 --- 
a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -249,34 +249,39 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()"); m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer); m.def( - "transfer_kv_per_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor " - "dst_indices, int page_size) -> ()"); - m.impl("transfer_kv_per_layer_direct", torch::kCUDA, &transfer_kv_per_layer_direct); + "transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor " + "dst_indices, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"); + m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf); m.def( - "transfer_kv_all_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor " - "dst_indices, int item_size, int num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int " + "transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, " + "Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int " "num_warps_per_block) -> ()"); m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer); m.def( - "transfer_kv_all_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor " - "dst_indices, int page_size, int num_layers) -> ()"); - m.impl("transfer_kv_all_layer_direct", torch::kCUDA, &transfer_kv_all_layer_direct); + "transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, " + "Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int " + "num_warps_per_block) -> ()"); + m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf); m.def( 
"transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int " "block_quota, int num_warps_per_block) -> ()"); m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla); m.def( - "transfer_kv_per_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size) " - "-> ()"); - m.impl("transfer_kv_per_layer_mla_direct", torch::kCUDA, &transfer_kv_per_layer_mla_direct); + "transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, " + "int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"); + m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf); m.def( - "transfer_kv_all_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int " - "num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int num_warps_per_block) -> ()"); + "transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int " + "item_size, int num_layers, int block_quota, int num_warps_per_block) -> ()"); m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla); m.def( - "transfer_kv_all_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size, " - "int num_layers) -> ()"); - m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct); + "transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, " + "int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()"); + m.impl("transfer_kv_all_layer_mla_lf_pf", torch::kCUDA, &transfer_kv_all_layer_mla_lf_pf); + m.def( + "transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int " + "page_size) -> ()"); + m.impl("transfer_kv_direct", torch::kCUDA, 
&transfer_kv_direct); /* * From csrc/moe/cutlass_moe/w4a8 diff --git a/sgl-kernel/csrc/kvcacheio/transfer.cu b/sgl-kernel/csrc/kvcacheio/transfer.cu index 6c939dd55c4c..cc6942e67731 100644 --- a/sgl-kernel/csrc/kvcacheio/transfer.cu +++ b/sgl-kernel/csrc/kvcacheio/transfer.cu @@ -22,17 +22,40 @@ transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_ } } -// todo, structs for different memory layout -__device__ __forceinline__ int64_t -get_global_offset_lf(int64_t layer_id, int64_t layer_dim, int64_t page_id, int64_t item_size_bytes) { +template +__device__ __forceinline__ T* get_global_offset_lf( + T* base, + const uintptr_t* __restrict__ /*unused*/, + int64_t layer_id, + int64_t layer_dim, + int64_t page_id, + int64_t item_size_bytes) { // layer first - return layer_id * layer_dim + page_id * item_size_bytes; + return base + layer_id * layer_dim + page_id * item_size_bytes; } -__device__ __forceinline__ int64_t -get_global_offset_pf(int64_t layer_id, int64_t page_dim, int64_t page_id, int64_t item_size_bytes) { +template +__device__ __forceinline__ T* get_global_offset_pf( + T* base, + const uintptr_t* __restrict__ /*unused*/, + int64_t layer_id, + int64_t page_dim, + int64_t page_id, + int64_t item_size_bytes) { // page first - return page_id * page_dim + layer_id * item_size_bytes; + return base + page_id * page_dim + layer_id * item_size_bytes; +} + +// get offset from layer base table when layers are not contiguous +template +__device__ __forceinline__ T* get_global_offset_lf_tbl( + T* /*unused*/, + const uintptr_t* __restrict__ layer_base_tbl, + int64_t layer_id, + int64_t /*unused*/, + int64_t page_id, + int64_t item_size_bytes) { + return reinterpret_cast(layer_base_tbl[layer_id]) + page_id * item_size_bytes; } template @@ -49,42 +72,37 @@ __global__ void transfer_kernel_impl( int64_t items_per_warp, int64_t item_size_bytes, int64_t src_layout_dim, - int64_t dst_layout_dim) { + int64_t dst_layout_dim, + const uintptr_t* 
__restrict__ src_k_layer_tbl, + const uintptr_t* __restrict__ dst_k_layer_tbl, + const uintptr_t* __restrict__ src_v_layer_tbl, + const uintptr_t* __restrict__ dst_v_layer_tbl) { int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; int32_t lane_id = tid % 32; int32_t warp_id = tid / 32; for (int i = 0; i < items_per_warp; ++i) { - int32_t item_id = warp_id * items_per_warp + i; + int64_t item_id = warp_id * items_per_warp + i; if (item_id >= num_items) { - return; + break; } const int64_t src_page_id = src_indices[item_id]; const int64_t dst_page_id = dst_indices[item_id]; // Loop over layers if necessary for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) { - // Calculate offsets using the provided function pointers - const int64_t src_offset = SrcOffsetFn(layer_id, src_layout_dim, src_page_id, item_size_bytes); - const int64_t dst_offset = DstOffsetFn(layer_id, dst_layout_dim, dst_page_id, item_size_bytes); + const char* src_ptr = SrcOffsetFn( + static_cast(src_k), src_k_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes); + char* dst_ptr = DstOffsetFn( + static_cast(dst_k), dst_k_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes); + transfer_item_warp(lane_id, src_ptr, dst_ptr, item_size_bytes); - if constexpr (IsMLA) { - transfer_item_warp( - lane_id, - static_cast(src_k) + src_offset, - static_cast(dst_k) + dst_offset, - item_size_bytes); - } else { - transfer_item_warp( - lane_id, - static_cast(src_k) + src_offset, - static_cast(dst_k) + dst_offset, - item_size_bytes); - transfer_item_warp( - lane_id, - static_cast(src_v) + src_offset, - static_cast(dst_v) + dst_offset, - item_size_bytes); + if constexpr (!IsMLA) { + const char* src_v_ptr = SrcOffsetFn( + static_cast(src_v), src_v_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes); + char* dst_v_ptr = DstOffsetFn( + static_cast(dst_v), dst_v_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes); + 
transfer_item_warp(lane_id, src_v_ptr, dst_v_ptr, item_size_bytes); } } } @@ -103,44 +121,54 @@ void transfer_kv_launcher( int64_t item_size, int64_t src_layout_dim, int64_t dst_layout_dim, + const at::Tensor& src_k_layers, + const at::Tensor& dst_k_layers, + const at::Tensor& src_v_layers, + const at::Tensor& dst_v_layers, int64_t block_quota, int64_t num_warps_per_block) { - TORCH_CHECK(src_k.scalar_type() == dst_k.scalar_type(), "Source and destination keys must have the same type"); TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor"); TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor"); TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long"); TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long"); TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length"); + TORCH_CHECK(item_size % 8 == 0, "Item byte size must be divisible by 8"); - if (!IsMLA) { - TORCH_CHECK(src_v.scalar_type() == dst_v.scalar_type(), "Source and destination values must have the same type"); - } - - int dtype_size = src_k.element_size(); - TORCH_CHECK((item_size * dtype_size) % 8 == 0, "Item byte size must be divisible by 8"); - - auto div_up = [](int32_t x, int32_t y) { return (x + y - 1) / y; }; + auto div_up = [](int64_t x, int64_t y) { return (x + y - 1) / y; }; const int64_t num_items = src_indices.numel(); const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block); const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block); dim3 grid_dim(num_blocks, 1, 1); const int32_t threads_per_block = num_warps_per_block * 32; + const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr; + void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr; + const void* src_v_ptr = IsMLA || !src_v.defined() ? 
nullptr : src_v.data_ptr(); + void* dst_v_ptr = IsMLA || !dst_v.defined() ? nullptr : dst_v.data_ptr(); + const uintptr_t* src_k_tbl_ptr = src_k_layers.defined() ? src_k_layers.data_ptr() : nullptr; + const uintptr_t* dst_k_tbl_ptr = dst_k_layers.defined() ? dst_k_layers.data_ptr() : nullptr; + const uintptr_t* src_v_tbl_ptr = IsMLA || !src_v_layers.defined() ? nullptr : src_v_layers.data_ptr(); + const uintptr_t* dst_v_tbl_ptr = IsMLA || !dst_v_layers.defined() ? nullptr : dst_v_layers.data_ptr(); + cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream(); transfer_kernel_impl<<>>( - src_k.data_ptr(), - dst_k.data_ptr(), - (IsMLA ? nullptr : src_v.data_ptr()), - (IsMLA ? nullptr : dst_v.data_ptr()), + src_k_ptr, + dst_k_ptr, + src_v_ptr, + dst_v_ptr, src_indices.data_ptr(), dst_indices.data_ptr(), start_layer_id, num_layers_to_process, num_items, items_per_warp, - item_size * dtype_size, - src_layout_dim * dtype_size, - dst_layout_dim * dtype_size); + item_size, + src_layout_dim, + dst_layout_dim, + src_k_tbl_ptr, + dst_k_tbl_ptr, + src_v_tbl_ptr, + dst_v_tbl_ptr); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -154,11 +182,28 @@ void transfer_kv_per_layer( int64_t item_size, int64_t block_quota, int64_t num_warps_per_block) { - transfer_kv_launcher( - src_k, dst_k, src_v, dst_v, src_indices, dst_indices, 0, 1, item_size, 0, 0, block_quota, num_warps_per_block); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf, false>( + src_k, + dst_k, + src_v, + dst_v, + src_indices, + dst_indices, + 0, + 1, + item_size, + 0, + 0, + empty, + empty, + empty, + empty, + block_quota, + num_warps_per_block); } -void transfer_kv_all_layer( +void transfer_kv_per_layer_pf_lf( const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, @@ -166,12 +211,11 @@ void transfer_kv_all_layer( const at::Tensor src_indices, const at::Tensor dst_indices, int64_t item_size, - int64_t num_layers, - int64_t src_layer_offset, - int64_t dst_layer_offset, + int64_t 
src_layout_dim, int64_t block_quota, int64_t num_warps_per_block) { - transfer_kv_launcher( + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf, false>( src_k, dst_k, src_v, @@ -179,10 +223,81 @@ void transfer_kv_all_layer( src_indices, dst_indices, 0, + 1, + item_size, + src_layout_dim, + 0, + empty, + empty, + empty, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_all_layer( + const at::Tensor src_k_layers, + const at::Tensor dst_k_layers, + const at::Tensor src_v_layers, + const at::Tensor dst_v_layers, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf_tbl, false>( + empty, + empty, + empty, + empty, + src_indices, + dst_indices, + 0, + num_layers, + item_size, + 0, + 0, + src_k_layers, + dst_k_layers, + src_v_layers, + dst_v_layers, + block_quota, + num_warps_per_block); +} + +void transfer_kv_all_layer_lf_pf( + const at::Tensor src_k_layers, + at::Tensor dst_k, + const at::Tensor src_v_layers, + at::Tensor dst_v, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t dst_layout_dim, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_pf, false>( + empty, + dst_k, + empty, + dst_v, + src_indices, + dst_indices, + 0, num_layers, item_size, - src_layer_offset, - dst_layer_offset, + 0, + dst_layout_dim, + src_k_layers, + empty, + src_v_layers, + empty, block_quota, num_warps_per_block); } @@ -195,12 +310,12 @@ void transfer_kv_per_layer_mla( int64_t item_size, int64_t block_quota, int64_t 
num_warps_per_block) { - at::Tensor empty_tensor = at::Tensor(); - transfer_kv_launcher( + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf, true>( src, dst, - empty_tensor, - empty_tensor, + empty, + empty, src_indices, dst_indices, 0, @@ -208,41 +323,110 @@ void transfer_kv_per_layer_mla( item_size, 0, 0, + empty, + empty, + empty, + empty, block_quota, num_warps_per_block); } -void transfer_kv_all_layer_mla( +void transfer_kv_per_layer_mla_pf_lf( const at::Tensor src, at::Tensor dst, const at::Tensor src_indices, const at::Tensor dst_indices, int64_t item_size, - int64_t num_layers, - int64_t src_layer_offset, - int64_t dst_layer_offset, + int64_t src_layout_dim, int64_t block_quota, int64_t num_warps_per_block) { - at::Tensor empty_tensor = at::Tensor(); - transfer_kv_launcher( + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf, true>( src, dst, - empty_tensor, - empty_tensor, + empty, + empty, + src_indices, + dst_indices, + 0, + 1, + item_size, + src_layout_dim, + 0, + empty, + empty, + empty, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_all_layer_mla( + const at::Tensor src_layers, + const at::Tensor dst_layers, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf_tbl, true>( + empty, + empty, + empty, + empty, + src_indices, + dst_indices, + 0, + num_layers, + item_size, + 0, + 0, + src_layers, + dst_layers, + empty, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_all_layer_mla_lf_pf( + const at::Tensor src_layers, + at::Tensor dst, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t dst_layout_dim, + int64_t num_layers, + int64_t block_quota, + int64_t 
num_warps_per_block) { + TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_pf, true>( + empty, + dst, + empty, + empty, src_indices, dst_indices, 0, num_layers, item_size, - src_layer_offset, - dst_layer_offset, + 0, + dst_layout_dim, + src_layers, + empty, + empty, + empty, block_quota, num_warps_per_block); } inline void transfer_page_direct( - const at::Tensor src_buffer, - at::Tensor dst_buffer, + const at::Tensor& src_buffer, + at::Tensor& dst_buffer, int64_t src_page_index, int64_t dst_page_index, int64_t page_size) { @@ -252,16 +436,14 @@ inline void transfer_page_direct( /* non_blocking= */ true); } -template -inline void transfer_kv_direct_impl( - const at::Tensor& src_k, - at::Tensor& dst_k, - const at::Tensor& src_v_opt, // Only used when IsMLA is false (for src_v) - at::Tensor& dst_v_opt, // Only used when IsMLA is false (for dst_v) - const at::Tensor& src_indices, - const at::Tensor& dst_indices, - int64_t page_size, - int64_t num_layers = 1) { +void transfer_kv_direct( + const std::vector& src_layers, + std::vector dst_layers, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t page_size) { + TORCH_CHECK( + src_layers.size() == dst_layers.size(), "Source and destination layers must have the same number of layers"); TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length"); TORCH_CHECK(page_size > 0, "Page size must be positive"); TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size"); @@ -270,73 +452,14 @@ inline void transfer_kv_direct_impl( auto dst_indices_cpu = dst_indices.cpu(); const int64_t num_pages = src_indices_cpu.size(0) / page_size; + const int64_t num_layers = src_layers.size(); - for (const auto i : c10::irange(num_pages)) { - auto s_index = src_indices_cpu[i * page_size].item(); - auto 
d_index = dst_indices_cpu[i * page_size].item(); + for (int64_t i = 0; i < num_pages; ++i) { + auto src_index = src_indices_cpu[i * page_size].item(); + auto dst_index = dst_indices_cpu[i * page_size].item(); - if constexpr (AllLayers) { - for (const auto j : c10::irange(num_layers)) { - if constexpr (IsMLA) { - transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size); - } else { - transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size); - transfer_page_direct(src_v_opt.select(0, j), dst_v_opt.select(0, j), s_index, d_index, page_size); - } - } - } else { // Per-layer - if constexpr (IsMLA) { - transfer_page_direct(src_k, dst_k, s_index, d_index, page_size); - } else { - transfer_page_direct(src_k, dst_k, s_index, d_index, page_size); - transfer_page_direct(src_v_opt, dst_v_opt, s_index, d_index, page_size); - } + for (int64_t j = 0; j < num_layers; ++j) { + transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, page_size); } } } - -void transfer_kv_per_layer_direct( - const at::Tensor src_k, - at::Tensor dst_k, - const at::Tensor src_v, - at::Tensor dst_v, - const at::Tensor src_indices, - const at::Tensor dst_indices, - int64_t page_size) { - transfer_kv_direct_impl(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size); -} - -void transfer_kv_all_layer_direct( - const at::Tensor src_k, - at::Tensor dst_k, - const at::Tensor src_v, - at::Tensor dst_v, - const at::Tensor src_indices, - const at::Tensor dst_indices, - int64_t page_size, - int64_t num_layers) { - transfer_kv_direct_impl(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers); -} - -void transfer_kv_per_layer_mla_direct( - const at::Tensor src, - at::Tensor dst, - const at::Tensor src_indices, - const at::Tensor dst_indices, - int64_t page_size) { - at::Tensor empty_tensor = at::Tensor(); - - transfer_kv_direct_impl(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, 
page_size); -} - -void transfer_kv_all_layer_mla_direct( - const at::Tensor src, - at::Tensor dst, - const at::Tensor src_indices, - const at::Tensor dst_indices, - int64_t page_size, - int64_t num_layers) { - at::Tensor empty_tensor = at::Tensor(); - transfer_kv_direct_impl( - src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size, num_layers); -} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index df06bd3cdcf3..6b589101feaa 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -399,38 +399,42 @@ void transfer_kv_per_layer( int64_t block_quota, int64_t num_warps_per_block); -void transfer_kv_per_layer_direct( +void transfer_kv_per_layer_pf_lf( const at::Tensor src_k, at::Tensor dst_k, const at::Tensor src_v, at::Tensor dst_v, const at::Tensor src_indices, const at::Tensor dst_indices, - int64_t page_size); + int64_t item_size, + int64_t src_layout_dim, + int64_t block_quota, + int64_t num_warps_per_block); void transfer_kv_all_layer( - const at::Tensor src_k, - at::Tensor dst_k, - const at::Tensor src_v, - at::Tensor dst_v, + const at::Tensor src_k_layers, + const at::Tensor dst_k_layers, + const at::Tensor src_v_layers, + const at::Tensor dst_v_layers, const at::Tensor src_indices, const at::Tensor dst_indices, int64_t item_size, int64_t num_layers, - int64_t src_layer_offset, - int64_t dst_layer_offset, int64_t block_quota, int64_t num_warps_per_block); -void transfer_kv_all_layer_direct( - const at::Tensor src_k, +void transfer_kv_all_layer_lf_pf( + const at::Tensor src_k_layers, at::Tensor dst_k, - const at::Tensor src_v, + const at::Tensor src_v_layers, at::Tensor dst_v, const at::Tensor src_indices, const at::Tensor dst_indices, - int64_t page_size, - int64_t num_layers); + int64_t item_size, + int64_t dst_layout_dim, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block); void transfer_kv_per_layer_mla( const at::Tensor src, @@ -441,32 
+445,43 @@ void transfer_kv_per_layer_mla( int64_t block_quota, int64_t num_warps_per_block); -void transfer_kv_per_layer_mla_direct( +void transfer_kv_per_layer_mla_pf_lf( const at::Tensor src, at::Tensor dst, const at::Tensor src_indices, const at::Tensor dst_indices, - int64_t page_size); + int64_t item_size, + int64_t src_layout_dim, + int64_t block_quota, + int64_t num_warps_per_block); void transfer_kv_all_layer_mla( - const at::Tensor src, - at::Tensor dst, + const at::Tensor src_layers, + const at::Tensor dst_layers, const at::Tensor src_indices, const at::Tensor dst_indices, int64_t item_size, int64_t num_layers, - int64_t src_layer_offset, - int64_t dst_layer_offset, int64_t block_quota, int64_t num_warps_per_block); -void transfer_kv_all_layer_mla_direct( - const at::Tensor src, +void transfer_kv_all_layer_mla_lf_pf( + const at::Tensor src_layers, at::Tensor dst, const at::Tensor src_indices, const at::Tensor dst_indices, - int64_t page_size, - int64_t num_layers); + int64_t item_size, + int64_t dst_layout_dim, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block); + +void transfer_kv_direct( + const std::vector& src_layers, + std::vector dst_layers, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t page_size); /* * From csrc/moe/cutlass_moe/w4a8 diff --git a/sgl-kernel/python/sgl_kernel/kvcacheio.py b/sgl-kernel/python/sgl_kernel/kvcacheio.py index 5350e49ddbcf..1440c2ca35ec 100644 --- a/sgl-kernel/python/sgl_kernel/kvcacheio.py +++ b/sgl-kernel/python/sgl_kernel/kvcacheio.py @@ -1,3 +1,5 @@ +from typing import List + import torch @@ -22,57 +24,116 @@ def transfer_kv_per_layer( dst_v, src_indices, dst_indices, - item_size, + item_size * src_k.element_size(), # todo, hot fix for compatibility block_quota, num_warps_per_block, ) elif io_backend == "direct": - torch.ops.sgl_kernel.transfer_kv_per_layer_direct( - src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size + 
torch.ops.sgl_kernel.transfer_kv_direct( + [src_k, src_v], [dst_k, dst_v], src_indices, dst_indices, page_size ) else: raise ValueError(f"Unsupported io backend") -def transfer_kv_all_layer( +def transfer_kv_per_layer_pf_lf( src_k: torch.Tensor, dst_k: torch.Tensor, src_v: torch.Tensor, dst_v: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, + item_size: int, + src_layout_dim: int, + block_quota: int = 2, + num_warps_per_block: int = 32, +): + torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf( + src_k, + dst_k, + src_v, + dst_v, + src_indices, + dst_indices, + item_size, + src_layout_dim, + block_quota, + num_warps_per_block, + ) + + +def transfer_kv_all_layer( + src_k_layers: torch.Tensor, + dst_k_layers: torch.Tensor, + src_v_layers: torch.Tensor, + dst_v_layers: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, io_backend: str, - page_size: int, item_size: int, num_layers: int, - src_layer_offset: int, - dst_layer_offset: int, block_quota: int = 2, num_warps_per_block: int = 32, ): if io_backend == "kernel": torch.ops.sgl_kernel.transfer_kv_all_layer( - src_k, - dst_k, - src_v, - dst_v, + src_k_layers, + dst_k_layers, + src_v_layers, + dst_v_layers, src_indices, dst_indices, item_size, num_layers, - src_layer_offset, - dst_layer_offset, block_quota, num_warps_per_block, ) elif io_backend == "direct": - torch.ops.sgl_kernel.transfer_kv_all_layer_direct( - src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers - ) + raise NotImplementedError("Deprecated interface") else: raise ValueError(f"Unsupported io backend") +def transfer_kv_all_layer_lf_pf( + src_k_layers: torch.Tensor, + dst_k: torch.Tensor, + src_v_layers: torch.Tensor, + dst_v: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + item_size: int, + dst_layout_dim: int, + num_layers: int, + block_quota: int = 2, + num_warps_per_block: int = 32, +): + torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf( + src_k_layers, + dst_k, + 
src_v_layers, + dst_v, + src_indices, + dst_indices, + item_size, + dst_layout_dim, + num_layers, + block_quota, + num_warps_per_block, + ) + + +def transfer_kv_direct( + src_layers: List[torch.Tensor], + dst_layers: List[torch.Tensor], + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + page_size: int, +): + torch.ops.sgl_kernel.transfer_kv_direct( + src_layers, dst_layers, src_indices, dst_indices, page_size + ) + + def transfer_kv_per_layer_mla( src: torch.Tensor, dst: torch.Tensor, @@ -90,48 +151,87 @@ def transfer_kv_per_layer_mla( dst, src_indices, dst_indices, - item_size, + item_size * src.element_size(), # todo, hot fix for compatibility block_quota, num_warps_per_block, ) elif io_backend == "direct": - torch.ops.sgl_kernel.transfer_kv_per_layer_mla_direct( - src, dst, src_indices, dst_indices, page_size + torch.ops.sgl_kernel.transfer_kv_direct( + [src], [dst], src_indices, dst_indices, page_size ) else: raise ValueError(f"Unsupported io backend") -def transfer_kv_all_layer_mla( +def transfer_kv_per_layer_mla_pf_lf( src: torch.Tensor, dst: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, + item_size: int, + src_layout_dim: int, + block_quota: int = 2, + num_warps_per_block: int = 32, +): + torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf( + src, + dst, + src_indices, + dst_indices, + item_size, + src_layout_dim, + block_quota, + num_warps_per_block, + ) + + +def transfer_kv_all_layer_mla( + src_layers: torch.Tensor, + dst_layers: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, io_backend: str, - page_size: int, item_size: int, num_layers: int, - src_layer_offset: int, - dst_layer_offset: int, block_quota: int = 2, num_warps_per_block: int = 32, ): if io_backend == "kernel": torch.ops.sgl_kernel.transfer_kv_all_layer_mla( - src, - dst, + src_layers, + dst_layers, src_indices, dst_indices, item_size, num_layers, - src_layer_offset, - dst_layer_offset, block_quota, num_warps_per_block, ) elif io_backend 
== "direct": - torch.ops.sgl_kernel.transfer_kv_all_layer_mla_direct( - src, dst, src_indices, dst_indices, page_size, num_layers - ) + raise NotImplementedError("Deprecated interface") else: raise ValueError(f"Unsupported io backend") + + +def transfer_kv_all_layer_mla_lf_pf( + src_layers: torch.Tensor, + dst: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + item_size: int, + dst_layout_dim: int, + num_layers: int, + block_quota: int = 2, + num_warps_per_block: int = 32, +): + torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf( + src_layers, + dst, + src_indices, + dst_indices, + item_size, + dst_layout_dim, + num_layers, + block_quota, + num_warps_per_block, + ) diff --git a/sgl-kernel/tests/test_kvcacheio.py b/sgl-kernel/tests/test_kvcacheio.py index 635b5ba507ab..171fc4ca4793 100644 --- a/sgl-kernel/tests/test_kvcacheio.py +++ b/sgl-kernel/tests/test_kvcacheio.py @@ -3,6 +3,7 @@ from sgl_kernel.kvcacheio import ( transfer_kv_all_layer, transfer_kv_all_layer_mla, + transfer_kv_direct, transfer_kv_per_layer, transfer_kv_per_layer_mla, ) @@ -104,14 +105,12 @@ def test_transfer_kv( page_size=page_size, item_size=item_size, ) - transfer_kv_per_layer_mla( - src_pool_host[layer_idx_to_test], - dst_pool_direct[layer_idx_to_test], + transfer_kv_direct( + [src_pool_host[layer_idx_to_test]], + [dst_pool_direct[layer_idx_to_test]], src_indices_host, dst_indices_device, - io_backend="direct", page_size=page_size, - item_size=item_size, ) else: for layer_id in range(num_layers): @@ -121,29 +120,34 @@ def test_transfer_kv( src_indices_host, dst_indices_device, ) + src_layers_device = torch.tensor( + [src_pool_host[layer_id].data_ptr() for layer_id in range(num_layers)], + dtype=torch.uint64, + device=device, + ) + dst_layers_device = torch.tensor( + [ + dst_pool_kernel[layer_id].data_ptr() + for layer_id in range(num_layers) + ], + dtype=torch.uint64, + device=device, + ) transfer_kv_all_layer_mla( - src_pool_host, - dst_pool_kernel, + 
src_layers_device, + dst_layers_device, src_indices_device, dst_indices_device, io_backend="kernel", - page_size=page_size, - item_size=item_size, + item_size=item_size * dtype.itemsize, num_layers=num_layers, - src_layer_offset=total_items_in_pool * item_size, - dst_layer_offset=total_items_in_pool * item_size, ) - transfer_kv_all_layer_mla( - src_pool_host, - dst_pool_direct, + transfer_kv_direct( + [src_pool_host[layer_id] for layer_id in range(num_layers)], + [dst_pool_direct[layer_id] for layer_id in range(num_layers)], src_indices_host, dst_indices_device, - io_backend="direct", page_size=page_size, - item_size=item_size, - num_layers=num_layers, - src_layer_offset=total_items_in_pool * item_size, - dst_layer_offset=total_items_in_pool * item_size, ) torch.cuda.synchronize() torch.testing.assert_close(dst_pool_kernel, dst_pool_ref) @@ -173,16 +177,15 @@ def test_transfer_kv( page_size=page_size, item_size=item_size, ) - transfer_kv_per_layer( - src_k_pool[layer_idx_to_test], - dst_k_pool_direct[layer_idx_to_test], - src_v_pool[layer_idx_to_test], - dst_v_pool_direct[layer_idx_to_test], + transfer_kv_direct( + [src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]], + [ + dst_k_pool_direct[layer_idx_to_test], + dst_v_pool_direct[layer_idx_to_test], + ], src_indices_host, dst_indices_device, - io_backend="direct", page_size=page_size, - item_size=item_size, ) else: for layer_id in range(num_layers): @@ -198,33 +201,52 @@ def test_transfer_kv( src_indices_host, dst_indices_device, ) + + src_k_layers_device = torch.tensor( + [src_k_pool[layer_id].data_ptr() for layer_id in range(num_layers)], + dtype=torch.uint64, + device=device, + ) + src_v_layers_device = torch.tensor( + [src_v_pool[layer_id].data_ptr() for layer_id in range(num_layers)], + dtype=torch.uint64, + device=device, + ) + dst_k_layers_device = torch.tensor( + [ + dst_k_pool_kernel[layer_id].data_ptr() + for layer_id in range(num_layers) + ], + dtype=torch.uint64, + device=device, + ) + 
dst_v_layers_device = torch.tensor( + [ + dst_v_pool_kernel[layer_id].data_ptr() + for layer_id in range(num_layers) + ], + dtype=torch.uint64, + device=device, + ) transfer_kv_all_layer( - src_k_pool, - dst_k_pool_kernel, - src_v_pool, - dst_v_pool_kernel, + src_k_layers_device, + dst_k_layers_device, + src_v_layers_device, + dst_v_layers_device, src_indices_device, dst_indices_device, io_backend="kernel", - page_size=page_size, - item_size=item_size, + item_size=item_size * dtype.itemsize, num_layers=num_layers, - src_layer_offset=total_items_in_pool * item_size, - dst_layer_offset=total_items_in_pool * item_size, ) - transfer_kv_all_layer( - src_k_pool, - dst_k_pool_direct, - src_v_pool, - dst_v_pool_direct, + transfer_kv_direct( + [src_k_pool[layer_id] for layer_id in range(num_layers)] + + [src_v_pool[layer_id] for layer_id in range(num_layers)], + [dst_k_pool_direct[layer_id] for layer_id in range(num_layers)] + + [dst_v_pool_direct[layer_id] for layer_id in range(num_layers)], src_indices_host, dst_indices_device, - io_backend="direct", page_size=page_size, - item_size=item_size, - num_layers=num_layers, - src_layer_offset=total_items_in_pool * item_size, - dst_layer_offset=total_items_in_pool * item_size, ) torch.cuda.synchronize() torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref) From ce86e201df7f2c60677c975f107e080687c07996 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Wed, 23 Jul 2025 01:50:31 -0700 Subject: [PATCH 097/396] bug fix and tag (#8282) --- benchmark/hicache/bench_multiturn.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py index 5e954ecd6466..311632525172 100644 --- a/benchmark/hicache/bench_multiturn.py +++ b/benchmark/hicache/bench_multiturn.py @@ -121,6 +121,12 @@ def parse_args(): default="random", help="Policy for popping requests from the ready queue (random or fifo)", ) + parser.add_argument( + "--tag", + 
type=str, + default="", + help="Tag of a certain run in the log file", + ) parser.add_argument("--seed", type=int, default=1, help="The random seed.") return parser.parse_args() @@ -202,9 +208,9 @@ def gen_payload(prompt, output_len): return payload -def log_to_jsonl_file(data, file_path="performance_metrics.jsonl"): - """Append the data with a timestamp to the specified JSONL file.""" - timestamped_data = {"timestamp": datetime.now().isoformat(), **data} +def log_to_jsonl_file(data, file_path="performance_metrics.jsonl", tag=""): + """Append the data with a timestamp and tag to the specified JSONL file.""" + timestamped_data = {"timestamp": datetime.now().isoformat(), "tag": tag, **data} try: with open(file_path, "a") as file: file.write( @@ -360,7 +366,7 @@ def response_handler(self): # append new request to client's history self.client_records[client_id][ "history" - ] += self.sub_question_inputs.pop() + ] += self.sub_question_inputs.pop().prompt self.ready_queue.append( ( client_id, @@ -428,7 +434,7 @@ def run(self): print( f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second" ) - log_to_jsonl_file(performance_data, args.log_file) + log_to_jsonl_file(performance_data, args.log_file, tag=args.tag) if __name__ == "__main__": From f39037fffbeb463595a1e31d72c85e53b6e7d355 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Wed, 23 Jul 2025 01:51:32 -0700 Subject: [PATCH 098/396] HiCache Fix (#8288) Co-authored-by: pansicheng --- python/sglang/srt/managers/cache_controller.py | 1 + python/sglang/srt/model_executor/model_runner.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 5f43a5e9a033..a94fdec78c32 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -358,6 +358,7 @@ def write( if host_indices is None: return None 
self.mem_pool_host.protect_write(host_indices) + torch.cuda.current_stream().synchronize() self.write_queue.put( CacheOperation(host_indices, device_indices, node_id, priority) ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9e6d14aaca55..919622cc77d1 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -378,6 +378,7 @@ def model_specific_adjustment(self): is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(server_args) and is_fa3_default_architecture(self.model_config.hf_config) + and (not server_args.enable_hierarchical_cache) ): server_args.attention_backend = "fa3" elif _is_hip: @@ -390,7 +391,9 @@ def model_specific_adjustment(self): ) else: # MLA architecture - if is_hopper_with_cuda_12_3(): + if is_hopper_with_cuda_12_3() and ( + not server_args.enable_hierarchical_cache + ): server_args.attention_backend = "fa3" elif is_sm100_supported(): server_args.attention_backend = "flashinfer" From 0c8dab9e67b1fe0d274a27af03540b2ce5525a37 Mon Sep 17 00:00:00 2001 From: Yuan Luo Date: Wed, 23 Jul 2025 21:22:59 +0800 Subject: [PATCH 099/396] [sgl-kernel] Opt per_token_quant_fp8 with warp reduce (#8130) Co-authored-by: luoyuan.luo --- sgl-kernel/csrc/gemm/per_token_quant_fp8.cu | 122 +++++++++++++++++--- 1 file changed, 106 insertions(+), 16 deletions(-) diff --git a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu index db09483ce9b0..9367f1584362 100644 --- a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu +++ b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu @@ -1,18 +1,95 @@ #include #include -#include #include #include "utils.h" -template +static constexpr int kWarpSize = 32; + +// --------------------------------------------------------------------------- +// 1. Warp‑local, no shared memory +// • One warp handles one token. +// • Eight tokens per 256‑thread CTA. 
+// --------------------------------------------------------------------------- +template __global__ void per_token_quant_fp8_kernel( const T* __restrict__ input, FP8_TYPE* __restrict__ output_q, float* __restrict__ output_s, const int64_t hidden_dim, const int64_t num_tokens) { + const int warp_id = threadIdx.x / kWarpSize; // 0‑7 (8 warps) + const int lane_id = threadIdx.x & (kWarpSize - 1); // 0‑31 + const int token_id = blockIdx.x * kTokensPerCTA + warp_id; + if (token_id >= num_tokens) return; + + // Global tensors for this token + const T* token_input = input + token_id * hidden_dim; + FP8_TYPE* token_output = output_q + token_id * hidden_dim; + float* token_scale = output_s + token_id; + + // + // Pass-1: Perform a warp reduce to find the max_value of a token's hidden_dim + // + float max_value = 0.f; + using vec_t = flashinfer::vec_t; + const int32_t num_vec_elems = hidden_dim / kVecSize; + + for (int32_t i = lane_id; i < num_vec_elems; i += kWarpSize) { + vec_t input_vec; + input_vec.cast_load(token_input + i * kVecSize); + +#pragma unroll + for (uint32_t j = 0; j < kVecSize; ++j) { + max_value = fmaxf(max_value, fabsf(static_cast(input_vec[j]))); + } + } + + float warp_max = warpReduceMax(max_value); + + __shared__ float scale; + scale = warp_max / FP8_E4M3_MAX; + // Broadcast scale + if (lane_id == 0) { + token_scale[0] = scale; + } + float scale_inv = (scale == 0.f) ? 
0.f : 1.0f / scale; + + // + // Pass-2: quantize and write back + // + for (int i = lane_id; i < num_vec_elems; i += kWarpSize) { + vec_t input_vec; + input_vec.cast_load(token_input + i * kVecSize); + FP8_TYPE output_arr[kVecSize]; +#pragma unroll + for (uint32_t j = 0; j < kVecSize; ++j) { + float val = static_cast(input_vec[j]) * scale_inv; + val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX); + +#ifndef USE_ROCM + output_arr[j] = static_cast(val); +#else + output_arr[j] = c10::Float8_e4m3fnuz( + __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), + c10::Float8_e4m3fnuz::from_bits()); +#endif + } + *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr; + } +} + +// --------------------------------------------------------------------------- +// 2. Baseline kernel (1 token / CTA, CUB block reduce) +// --------------------------------------------------------------------------- +template +__global__ void per_token_quant_fp8_small_batch_kernel( + const T* __restrict__ input, + FP8_TYPE* __restrict__ output_q, + float* __restrict__ output_s, + const int64_t hidden_dim, + const int64_t num_tokens) { const int token_idx = blockIdx.x; if (token_idx >= num_tokens) return; @@ -79,28 +156,41 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: CHECK_INPUT(input); CHECK_INPUT(output_q); CHECK_INPUT(output_s); - const auto input_sizes = input.sizes(); const int64_t num_tokens = input_sizes[0]; const int64_t hidden_dim = input_sizes[1]; - TORCH_CHECK(hidden_dim % 16 == 0, "Hidden dimension must be divisible by 16, but got ", hidden_dim); - const int block_size = 256; - const int num_blocks = num_tokens; - - dim3 grid(num_blocks); - dim3 block(block_size); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // Hard-code sm_count + int sm_count = 132; + constexpr int TOKENS_PER_CTA = 8; + const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA); 
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { - per_token_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), - static_cast(output_q.data_ptr()), - static_cast(output_s.data_ptr()), - hidden_dim, - num_tokens); + if (use_warp_kernel) { + // -------- warp‑local --------------------------------------------------- + constexpr int THREADS = TOKENS_PER_CTA * kWarpSize; // 256 + dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA); + dim3 block(THREADS); + per_token_quant_fp8_kernel<<>>( + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + hidden_dim, + num_tokens); + } else { + // -------- baseline ----------------------------------------------------- + constexpr int THREADS = 256; + dim3 grid(num_tokens); + dim3 block(THREADS); + per_token_quant_fp8_small_batch_kernel<<>>( + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + hidden_dim, + num_tokens); + } return true; }); } From 6f8f4aeea458ae7ba5a54619b1f108aab6076726 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 23 Jul 2025 10:07:51 -0700 Subject: [PATCH 100/396] [router] add common ut infra to mock worker and app (#8295) --- sgl-router/Cargo.toml | 3 + sgl-router/tests/common/mock_worker.rs | 650 +++++++++++++++++++++++++ sgl-router/tests/common/mod.rs | 56 +++ 3 files changed, 709 insertions(+) create mode 100644 sgl-router/tests/common/mock_worker.rs create mode 100644 sgl-router/tests/common/mod.rs diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index b23b6d7ac3e4..74b1ed129026 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -42,6 +42,9 @@ url = "2.5.4" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } +tokio-stream = "0.1" +actix-http = "3.0" +futures = "0.3" [[bench]] name = "request_processing" diff --git a/sgl-router/tests/common/mock_worker.rs b/sgl-router/tests/common/mock_worker.rs new file 
mode 100644 index 000000000000..c5129febc895 --- /dev/null +++ b/sgl-router/tests/common/mock_worker.rs @@ -0,0 +1,650 @@ +use actix_web::{middleware, web, App, HttpRequest, HttpResponse, HttpServer}; +use futures_util::StreamExt; +use serde_json::json; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; +use tokio::sync::RwLock; +use uuid; + +/// Configuration for mock worker behavior +#[derive(Clone)] +pub struct MockWorkerConfig { + pub port: u16, + pub worker_type: WorkerType, + pub health_status: HealthStatus, + pub response_delay_ms: u64, + pub fail_rate: f32, +} + +#[derive(Clone, Debug)] +pub enum WorkerType { + Regular, + Prefill, + Decode, +} + +#[derive(Clone, Debug)] +pub enum HealthStatus { + Healthy, + Unhealthy, + Degraded, +} + +/// Mock worker server for testing +pub struct MockWorker { + config: Arc>, + server_handle: Option, +} + +impl MockWorker { + pub fn new(config: MockWorkerConfig) -> Self { + Self { + config: Arc::new(RwLock::new(config)), + server_handle: None, + } + } + + /// Start the mock worker server + pub async fn start(&mut self) -> Result> { + let config = self.config.clone(); + let port = config.read().await.port; + + let server = HttpServer::new(move || { + App::new() + .app_data(web::Data::new(config.clone())) + .wrap(middleware::Logger::default()) + .route("/health", web::get().to(health_handler)) + .route("/health_generate", web::get().to(health_generate_handler)) + .route("/get_server_info", web::get().to(server_info_handler)) + .route("/get_model_info", web::get().to(model_info_handler)) + .route("/generate", web::post().to(generate_handler)) + .route( + "/v1/chat/completions", + web::post().to(chat_completions_handler), + ) + .route("/v1/completions", web::post().to(completions_handler)) + .route("/flush_cache", web::post().to(flush_cache_handler)) + .route("/v1/models", web::get().to(v1_models_handler)) + }) + .bind(("127.0.0.1", port))? 
+ .run(); + + let handle = server.handle(); + self.server_handle = Some(handle); + + tokio::spawn(server); + + Ok(format!("http://127.0.0.1:{}", port)) + } + + /// Stop the mock worker server + pub async fn stop(&mut self) { + if let Some(handle) = self.server_handle.take() { + // First try graceful stop with short timeout + handle.stop(false); + // Give it a moment to stop gracefully + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } + } + + /// Update the mock worker configuration + pub async fn update_config(&self, updater: F) + where + F: FnOnce(&mut MockWorkerConfig), + { + let mut config = self.config.write().await; + updater(&mut *config); + } +} + +// Handler implementations + +async fn health_handler(config: web::Data>>) -> HttpResponse { + let config = config.read().await; + + match config.health_status { + HealthStatus::Healthy => HttpResponse::Ok().json(json!({ + "status": "healthy", + "timestamp": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(), + "worker_type": format!("{:?}", config.worker_type), + })), + HealthStatus::Unhealthy => HttpResponse::ServiceUnavailable().json(json!({ + "status": "unhealthy", + "error": "Worker is not responding" + })), + HealthStatus::Degraded => HttpResponse::Ok().json(json!({ + "status": "degraded", + "warning": "High load detected" + })), + } +} + +async fn health_generate_handler(config: web::Data>>) -> HttpResponse { + let config = config.read().await; + + if matches!(config.health_status, HealthStatus::Healthy) { + HttpResponse::Ok().json(json!({ + "status": "ok", + "queue_length": 0, + "processing_time_ms": config.response_delay_ms + })) + } else { + HttpResponse::ServiceUnavailable().json(json!({ + "error": "Generation service unavailable" + })) + } +} + +async fn server_info_handler(config: web::Data>>) -> HttpResponse { + let config = config.read().await; + + // Return response matching actual sglang server implementation + HttpResponse::Ok().json(json!({ + // Server args 
fields + "model_path": "mock-model-path", + "tokenizer_path": "mock-tokenizer-path", + "port": config.port, + "host": "127.0.0.1", + "max_num_batched_tokens": 32768, + "max_prefill_tokens": 16384, + "mem_fraction_static": 0.88, + "tp_size": 1, + "dp_size": 1, + "stream_interval": 8, + "dtype": "float16", + "device": "cuda", + "enable_flashinfer": true, + "enable_p2p_check": true, + "context_length": 32768, + "chat_template": null, + "disable_radix_cache": false, + "enable_torch_compile": false, + "trust_remote_code": false, + "show_time_cost": false, + + // Scheduler info fields + "waiting_queue_size": 0, + "running_queue_size": 0, + "req_to_token_ratio": 1.2, + "min_running_requests": 0, + "max_running_requests": 2048, + "max_req_num": 8192, + "max_batch_tokens": 32768, + "schedule_policy": "lpm", + "schedule_conservativeness": 1.0, + + // Additional fields + "version": "0.3.0", + "internal_states": [{ + "waiting_queue_size": 0, + "running_queue_size": 0 + }] + })) +} + +async fn model_info_handler(_config: web::Data>>) -> HttpResponse { + // Return response matching actual sglang server implementation + HttpResponse::Ok().json(json!({ + "model_path": "mock-model-path", + "tokenizer_path": "mock-tokenizer-path", + "is_generation": true, + "preferred_sampling_params": { + "temperature": 0.7, + "top_p": 0.9, + "top_k": 40, + "max_tokens": 2048 + } + })) +} + +async fn generate_handler( + config: web::Data>>, + _req: HttpRequest, + payload: web::Json, +) -> HttpResponse { + let config = config.read().await; + + // Simulate failure based on fail_rate + if rand::random::() < config.fail_rate { + return HttpResponse::InternalServerError().json(json!({ + "error": "Random failure for testing" + })); + } + + // Simulate processing delay + if config.response_delay_ms > 0 { + tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await; + } + + let is_stream = payload + .get("stream") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + if 
is_stream { + // Return streaming response matching sglang format + let (tx, rx) = tokio::sync::mpsc::channel(10); + let stream_delay = config.response_delay_ms; + let request_id = format!("mock-req-{}", rand::random::()); + + tokio::spawn(async move { + let tokens = vec!["This ", "is ", "a ", "mock ", "response."]; + let timestamp_start = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs_f64(); + + for (i, token) in tokens.iter().enumerate() { + let chunk = json!({ + "text": token, + "meta_info": { + "id": &request_id, + "finish_reason": if i == tokens.len() - 1 { + json!({"type": "stop", "matched_stop": null}) + } else { + json!(null) + }, + "prompt_tokens": 10, + "completion_tokens": i + 1, + "cached_tokens": 0, + "e2e_latency": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs_f64() - timestamp_start + } + }); + + if tx.send(format!("data: {}\n\n", serde_json::to_string(&chunk).unwrap())).await.is_err() { + break; + } + + if stream_delay > 0 { + tokio::time::sleep(tokio::time::Duration::from_millis(stream_delay)).await; + } + } + + let _ = tx.send("data: [DONE]\n\n".to_string()).await; + }); + + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + + HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .streaming(stream.map(|chunk| Ok::<_, actix_web::Error>(bytes::Bytes::from(chunk)))) + } else { + // Return non-streaming response matching sglang format + let request_id = format!("mock-req-{}", rand::random::()); + let timestamp_start = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs_f64(); + + HttpResponse::Ok().json(json!({ + "text": "Mock generated response for the input", + "meta_info": { + "id": request_id, + "finish_reason": { + "type": "stop", + "matched_stop": null + }, + "prompt_tokens": 10, + "completion_tokens": 7, + "cached_tokens": 0, + "e2e_latency": 0.042 + } + })) + } +} + +async fn chat_completions_handler( + config: web::Data>>, + payload: 
web::Json, +) -> HttpResponse { + let config = config.read().await; + + // Simulate failure + if rand::random::() < config.fail_rate { + return HttpResponse::InternalServerError().json(json!({ + "error": "Chat completion failed" + })); + } + + let is_stream = payload + .get("stream") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + if is_stream { + // Return proper streaming response for chat completions + let (tx, rx) = tokio::sync::mpsc::channel(10); + let stream_delay = config.response_delay_ms; + let model = payload + .get("model") + .and_then(|m| m.as_str()) + .unwrap_or("mock-model") + .to_string(); + + tokio::spawn(async move { + let chat_id = format!("chatcmpl-mock{}", rand::random::()); + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + // Send initial chunk with role + let initial_chunk = json!({ + "id": &chat_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": &model, + "choices": [{ + "index": 0, + "delta": { + "role": "assistant" + }, + "finish_reason": null + }] + }); + + let _ = tx + .send(format!( + "data: {}\n\n", + serde_json::to_string(&initial_chunk).unwrap() + )) + .await; + + // Send content chunks + let content_chunks = [ + "This ", + "is ", + "a ", + "mock ", + "streaming ", + "chat ", + "response.", + ]; + for chunk in content_chunks.iter() { + let data = json!({ + "id": &chat_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": &model, + "choices": [{ + "index": 0, + "delta": { + "content": chunk + }, + "finish_reason": null + }] + }); + + if tx + .send(format!( + "data: {}\n\n", + serde_json::to_string(&data).unwrap() + )) + .await + .is_err() + { + break; + } + + if stream_delay > 0 { + tokio::time::sleep(tokio::time::Duration::from_millis(stream_delay)).await; + } + } + + // Send final chunk with finish_reason + let final_chunk = json!({ + "id": &chat_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": &model, + 
"choices": [{ + "index": 0, + "delta": {}, + "finish_reason": "stop" + }] + }); + + let _ = tx + .send(format!( + "data: {}\n\n", + serde_json::to_string(&final_chunk).unwrap() + )) + .await; + let _ = tx.send("data: [DONE]\n\n".to_string()).await; + }); + + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + + HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .streaming(stream.map(|chunk| Ok::<_, actix_web::Error>(bytes::Bytes::from(chunk)))) + } else { + // Non-streaming response matching OpenAI format + let model = payload + .get("model") + .and_then(|m| m.as_str()) + .unwrap_or("mock-model") + .to_string(); + + HttpResponse::Ok().json(json!({ + "id": format!("chatcmpl-{}", uuid::Uuid::new_v4()), + "object": "chat.completion", + "created": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(), + "model": model, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "This is a mock chat completion response." 
+ }, + "logprobs": null, + "finish_reason": "stop", + "matched_stop": null + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 8, + "total_tokens": 18, + "prompt_tokens_details": { + "cached_tokens": 0 + } + } + })) + } +} + +async fn completions_handler( + config: web::Data>>, + payload: web::Json, +) -> HttpResponse { + let config = config.read().await; + + if rand::random::() < config.fail_rate { + return HttpResponse::InternalServerError().json(json!({ + "error": "Completion failed" + })); + } + + // Check if streaming is requested + let is_stream = payload + .get("stream") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + let prompts = payload + .get("prompt") + .map(|p| { + if p.is_array() { + p.as_array().unwrap().len() + } else { + 1 + } + }) + .unwrap_or(1); + + if is_stream { + // Return streaming response for completions + let (tx, rx) = tokio::sync::mpsc::channel(10); + let stream_delay = config.response_delay_ms; + let model = payload + .get("model") + .and_then(|m| m.as_str()) + .unwrap_or("mock-model") + .to_string(); + + tokio::spawn(async move { + let completion_id = format!("cmpl-mock{}", rand::random::()); + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + // Stream completions for each prompt + for prompt_idx in 0..prompts { + let prompt_suffix = format!("{} ", prompt_idx); + let tokens = vec!["This ", "is ", "mock ", "completion ", &prompt_suffix]; + + for (token_idx, token) in tokens.iter().enumerate() { + let data = json!({ + "id": &completion_id, + "object": "text_completion", + "created": timestamp, + "model": &model, + "choices": [{ + "text": token, + "index": prompt_idx, + "logprobs": null, + "finish_reason": if token_idx == tokens.len() - 1 { Some("stop") } else { None } + }] + }); + + if tx + .send(format!( + "data: {}\n\n", + serde_json::to_string(&data).unwrap() + )) + .await + .is_err() + { + return; + } + + if stream_delay > 0 { + 
tokio::time::sleep(tokio::time::Duration::from_millis(stream_delay)).await; + } + } + } + + let _ = tx.send("data: [DONE]\n\n".to_string()).await; + }); + + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + + HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .streaming(stream.map(|chunk| Ok::<_, actix_web::Error>(bytes::Bytes::from(chunk)))) + } else { + // Return non-streaming response + let mut choices = vec![]; + for i in 0..prompts { + choices.push(json!({ + "text": format!("Mock completion {}", i), + "index": i, + "logprobs": null, + "finish_reason": "stop" + })); + } + + HttpResponse::Ok().json(json!({ + "id": format!("cmpl-mock{}", rand::random::()), + "object": "text_completion", + "created": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(), + "model": payload.get("model").and_then(|m| m.as_str()).unwrap_or("mock-model"), + "choices": choices, + "usage": { + "prompt_tokens": 5 * prompts, + "completion_tokens": 10 * prompts, + "total_tokens": 15 * prompts + } + })) + } +} + +async fn flush_cache_handler(_config: web::Data>>) -> HttpResponse { + HttpResponse::Ok().json(json!({ + "status": "success", + "message": "Cache flushed", + "freed_entries": 42 + })) +} + +async fn v1_models_handler(_config: web::Data>>) -> HttpResponse { + HttpResponse::Ok().json(json!({ + "object": "list", + "data": [{ + "id": "mock-model-v1", + "object": "model", + "created": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(), + "owned_by": "sglang", + "permission": [{ + "id": "modelperm-mock", + "object": "model_permission", + "created": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(), + "allow_create_engine": false, + "allow_sampling": true, + "allow_logprobs": true, + "allow_search_indices": false, + "allow_view": true, + "allow_fine_tuning": false, + "organization": "*", + "group": null, + "is_blocking": false + }], + "root": "mock-model-v1", + "parent": null + }] + 
})) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_mock_worker_lifecycle() { + let config = MockWorkerConfig { + port: 18080, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }; + + let mut worker = MockWorker::new(config); + + // Start the worker + let url = worker.start().await.unwrap(); + assert_eq!(url, "http://127.0.0.1:18080"); + + // Give server time to start + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Test health endpoint + let client = reqwest::Client::new(); + let resp = client.get(&format!("{}/health", url)).send().await.unwrap(); + + assert_eq!(resp.status(), 200); + let body: serde_json::Value = resp.json().await.unwrap(); + assert_eq!(body["status"], "healthy"); + + // Update config to unhealthy + worker + .update_config(|c| c.health_status = HealthStatus::Unhealthy) + .await; + + // Test health again + let resp = client.get(&format!("{}/health", url)).send().await.unwrap(); + + assert_eq!(resp.status(), 503); + + // Stop the worker + worker.stop().await; + } +} diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs new file mode 100644 index 000000000000..34467cd0885a --- /dev/null +++ b/sgl-router/tests/common/mod.rs @@ -0,0 +1,56 @@ +pub mod mock_worker; + +use actix_web::web; +use reqwest::Client; +use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; +use sglang_router_rs::server::AppState; + +/// Helper function to create test router configuration +pub fn create_test_config(worker_urls: Vec) -> RouterConfig { + RouterConfig { + mode: RoutingMode::Regular { worker_urls }, + policy: PolicyConfig::Random, + host: "127.0.0.1".to_string(), + port: 3001, + max_payload_size: 256 * 1024 * 1024, // 256MB + request_timeout_secs: 600, + worker_startup_timeout_secs: 300, + worker_startup_check_interval_secs: 10, + discovery: None, + metrics: None, + log_dir: None, + log_level: 
None, + } +} + +/// Helper function to create test router configuration with no health check +pub fn create_test_config_no_workers() -> RouterConfig { + RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec![], + }, // Empty to skip health check + policy: PolicyConfig::Random, + host: "127.0.0.1".to_string(), + port: 3001, + max_payload_size: 256 * 1024 * 1024, // 256MB + request_timeout_secs: 600, + worker_startup_timeout_secs: 0, // No wait + worker_startup_check_interval_secs: 10, + discovery: None, + metrics: None, + log_dir: None, + log_level: None, + } +} + +/// Helper function to create test app state +pub async fn create_test_app_state(config: RouterConfig) -> Result, String> { + // Create a non-blocking client + let client = Client::builder() + .timeout(std::time::Duration::from_secs(config.request_timeout_secs)) + .build() + .map_err(|e| e.to_string())?; + + let app_state = AppState::new(config, client)?; + Ok(web::Data::new(app_state)) +} From 4c605235aa832f259e148dfbdce08d9e471b5099 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 23 Jul 2025 12:01:51 -0700 Subject: [PATCH 101/396] fix: workaround for deepgemm warmup issue (#8302) --- docker/Dockerfile | 2 +- sgl-kernel/CMakeLists.txt | 2 +- sgl-kernel/pyproject.toml | 2 +- sgl-kernel/pyproject_cpu.toml | 2 +- sgl-kernel/pyproject_rocm.toml | 2 +- sgl-kernel/python/sgl_kernel/version.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 1e5f21c9d5f5..5494762150d0 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -60,7 +60,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel html5li && python3 -m pip install --no-cache-dir -e "python[${BUILD_TYPE}]" --extra-index-url https://download.pytorch.org/whl/cu${CUINDEX} \ && if [ "$CUDA_VERSION" = "12.8.1" ]; then \ python3 -m pip install --no-cache-dir nvidia-nccl-cu12==2.27.6 --force-reinstall --no-deps ; \ - python3 -m pip install --no-cache-dir 
https://github.com/sgl-project/whl/releases/download/v0.2.6.post1/sgl_kernel-0.2.6.post1+cu128-cp39-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \ + python3 -m pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.2.7/sgl_kernel-0.2.7+cu128-cp39-abi3-manylinux2014_x86_64.whl --force-reinstall --no-deps ; \ fi # Build and install NVSHMEM + DeepEP diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index e8f9a0839658..739782372909 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -56,7 +56,7 @@ if("${CUDA_VERSION}" VERSION_EQUAL "12.8") set(DeepGEMM_TAG "blackwell") else() set(DeepGEMM_REPO "https://github.com/deepseek-ai/DeepGEMM") - set(DeepGEMM_TAG "8dfa3298274bfe6b242f6f8a3e6f3eff2707dd9f") + set(DeepGEMM_TAG "391755ada0ffefa9a6a52b6f14dcaf22d1a463e0") endif() FetchContent_Declare( diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 3b49eab5d9a8..59f69f628346 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "sgl-kernel" -version = "0.2.6.post1" +version = "0.2.7" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/pyproject_cpu.toml b/sgl-kernel/pyproject_cpu.toml index 6746b212d364..f9d5cb3975aa 100644 --- a/sgl-kernel/pyproject_cpu.toml +++ b/sgl-kernel/pyproject_cpu.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "sgl-kernel" -version = "0.2.6.post1" +version = "0.2.7" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/pyproject_rocm.toml b/sgl-kernel/pyproject_rocm.toml index 0ba8b0399bff..6791bb47b2ce 100644 --- a/sgl-kernel/pyproject_rocm.toml +++ b/sgl-kernel/pyproject_rocm.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.2.6.post1" +version = "0.2.7" 
description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/python/sgl_kernel/version.py b/sgl-kernel/python/sgl_kernel/version.py index e39bc3f224a0..6cd38b746590 100644 --- a/sgl-kernel/python/sgl_kernel/version.py +++ b/sgl-kernel/python/sgl_kernel/version.py @@ -1 +1 @@ -__version__ = "0.2.6.post1" +__version__ = "0.2.7" From a99801e0750f41553fedd02e36f58d835c4d4bd6 Mon Sep 17 00:00:00 2001 From: YiXR <37775155+YiXR@users.noreply.github.com> Date: Thu, 24 Jul 2025 04:28:12 +0800 Subject: [PATCH 102/396] [Performance][PD Disaggregation] optimize TokenToKVPoolAllocator by sorting free pages (#8133) Signed-off-by: Xingrui Yi Co-authored-by: Xingrui Yi --- python/sglang/srt/mem_cache/allocator.py | 74 +++++++++++++++++++++--- 1 file changed, 67 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index 7dd488e9cf18..58afbf312f02 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -51,6 +51,7 @@ def __init__( self._kvcache = kvcache self.free_pages = None + self.release_pages = None self.is_not_in_free_group = True self.free_group = [] @@ -58,16 +59,16 @@ def debug_print(self) -> str: return "" def available_size(self): - return len(self.free_pages) * self.page_size + return (len(self.free_pages) + len(self.release_pages)) * self.page_size def get_kvcache(self): return self._kvcache - def restore_state(self, free_pages): - self.free_pages = free_pages + def restore_state(self, state): + self.free_pages, self.release_pages = state def backup_state(self): - return self.free_pages + return (self.free_pages, self.release_pages) def free_group_begin(self): self.is_not_in_free_group = False @@ -78,6 +79,14 @@ def free_group_end(self): if self.free_group: self.free(torch.cat(self.free_group)) + def merge_and_sort_free(self): + if len(self.release_pages) > 0: + self.free_pages = 
torch.cat((self.free_pages, self.release_pages)) + self.free_pages, _ = torch.sort(self.free_pages) + self.release_pages = torch.empty( + (0,), dtype=self.release_pages.dtype, device=self.device + ) + def get_cpu_copy(self, *args, **kwargs): # FIXME: reuse the get_cpu_copy after paged allocator is implemented raise NotImplementedError() @@ -119,12 +128,15 @@ def clear(self): ) self.is_not_in_free_group = True self.free_group = [] + self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device) def available_size(self): # To avoid minor "len(free_pages) * 1" overhead - return len(self.free_pages) + return len(self.free_pages) + len(self.release_pages) def alloc(self, need_size: int): + if need_size > len(self.free_pages): + self.merge_and_sort_free() if need_size > len(self.free_pages): return None @@ -137,7 +149,7 @@ def free(self, free_index: torch.Tensor): return if self.is_not_in_free_group: - self.free_pages = torch.cat((self.free_pages, free_index)) + self.release_pages = torch.cat((self.release_pages, free_index)) else: self.free_group.append(free_index) @@ -421,6 +433,8 @@ def alloc(self, need_size: int): ), "The allocation size should be page-aligned" num_pages = need_size // self.page_size + if num_pages > len(self.free_pages): + self.merge_and_sort_free() if num_pages > len(self.free_pages): return None @@ -446,6 +460,17 @@ def alloc_extend( (last_loc + 1) % self.page_size == prefix_lens % self.page_size ) + estimated_num_new_pages = ( + ( + (seq_lens + self.page_size - 1) // self.page_size + - (prefix_lens + self.page_size - 1) // self.page_size + ) + .sum() + .item() + ) + if estimated_num_new_pages > len(self.free_pages): + self.merge_and_sort_free() + bs = len(prefix_lens) out_indices = torch.empty( (extend_num_tokens,), dtype=torch.int64, device=self.device @@ -483,6 +508,17 @@ def alloc_decode( (last_loc + 2) % self.page_size == seq_lens % self.page_size ) + estimated_num_new_pages = ( + ( + (seq_lens + self.page_size - 1) // 
self.page_size + - (seq_lens - 1 + self.page_size - 1) // self.page_size + ) + .sum() + .item() + ) + if estimated_num_new_pages > len(self.free_pages): + self.merge_and_sort_free() + bs = len(seq_lens) out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device) alloc_decode_kernel[(bs,)]( @@ -511,7 +547,7 @@ def free(self, free_index: torch.Tensor): if self.is_not_in_free_group: free_page_indices = torch.unique(free_index // self.page_size) - self.free_pages = torch.cat((free_page_indices, self.free_pages)) + self.release_pages = torch.cat((free_page_indices, self.release_pages)) else: self.free_group.append(free_index) @@ -525,6 +561,7 @@ def clear(self): ) self.is_not_in_free_group = True self.free_group = [] + self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device) def get_cpu_copy(self, indices): return self._kvcache.get_cpu_copy(indices) @@ -633,6 +670,17 @@ def alloc_extend( (last_loc + 1) % self.page_size == prefix_lens % self.page_size ) + estimated_num_new_pages = ( + ( + (seq_lens + self.page_size - 1) // self.page_size + - (prefix_lens + self.page_size - 1) // self.page_size + ) + .sum() + .item() + ) + if estimated_num_new_pages > len(self.free_pages): + self.merge_and_sort_free() + bs = len(prefix_lens) out_indices = torch.empty( (extend_num_tokens,), dtype=torch.int32, device=self.device @@ -668,6 +716,17 @@ def alloc_decode( (last_loc + 2) % self.page_size == seq_lens % self.page_size ) + estimated_num_new_pages = ( + ( + (seq_lens + self.page_size - 1) // self.page_size + - (seq_lens - 1 + self.page_size - 1) // self.page_size + ) + .sum() + .item() + ) + if estimated_num_new_pages > len(self.free_pages): + self.merge_and_sort_free() + bs = len(seq_lens) out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device) @@ -692,3 +751,4 @@ def alloc_decode( def clear(self): super().clear() self.free_pages = self.free_pages.to(torch.int32) + self.release_pages = self.release_pages.to(torch.int32) From 
c87d4fec9998d278fb416f2523677e70908f5e11 Mon Sep 17 00:00:00 2001 From: xianzhiT Date: Thu, 24 Jul 2025 04:28:53 +0800 Subject: [PATCH 103/396] Fix the issue of incorrect finish reason in final stream response chunk returned during tool call (#7708) Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> --- .../srt/entrypoints/openai/serving_chat.py | 17 +++++++++++++---- .../test_openai_function_calling.py | 7 +++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index e69587432c12..9889cb2edd66 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -484,7 +484,10 @@ async def _generate_chat_stream( # Handle tool calls if request.tool_choice != "none" and request.tools: - async for chunk in self._process_tool_call_stream( + async for ( + chunk, + tool_call_finish_reason_type, + ) in self._process_tool_call_stream( index, delta, parser_dict, @@ -492,7 +495,10 @@ async def _generate_chat_stream( request, finish_reason_type, ): - yield chunk + if chunk: + yield chunk + finish_reason_type = tool_call_finish_reason_type + else: # Regular content if delta or not ( @@ -865,7 +871,7 @@ async def _process_tool_call_stream( choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json()}\n\n", finish_reason_type # Yield tool calls for call_item in calls: @@ -920,4 +926,7 @@ async def _process_tool_call_stream( choices=[choice_data], model=request.model, ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {chunk.model_dump_json()}\n\n", finish_reason_type + + if finish_reason_type == "stop": + yield None, "tool_calls" diff --git a/test/srt/openai_server/function_call/test_openai_function_calling.py b/test/srt/openai_server/function_call/test_openai_function_calling.py index 
012fc15c5ff3..8b437a8ac910 100644 --- a/test/srt/openai_server/function_call/test_openai_function_calling.py +++ b/test/srt/openai_server/function_call/test_openai_function_calling.py @@ -159,6 +159,13 @@ def test_function_calling_streaming_simple(self): "Target function name 'get_current_weather' was not found in the streaming chunks", ) + finish_reason = chunks[-1].choices[0].finish_reason + self.assertEqual( + finish_reason, + "tool_calls", + "Final response of function calling should have finish_reason 'tool_calls'", + ) + def test_function_calling_streaming_args_parsing(self): """ Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON. From 70251e935e9d466f36e75d74fffeea90af346418 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Date: Wed, 23 Jul 2025 13:29:03 -0700 Subject: [PATCH 104/396] fix: match chat-template for internvl3 (#8262) Signed-off-by: Xinyuan Tong Co-authored-by: Xinyuan Tong --- python/sglang/srt/conversation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index cb4bdbc44a0c..80b706430bf7 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -984,7 +984,7 @@ def generate_chat_conv( @register_conv_template_matching_function def match_internvl(model_path: str): - if re.search(r"internvl2_5", model_path, re.IGNORECASE): + if re.search(r"internvl", model_path, re.IGNORECASE): return "internvl-2-5" From 38000a5f44d16b216f5d1fb476fdad15c3fa4616 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Date: Wed, 23 Jul 2025 13:29:18 -0700 Subject: [PATCH 105/396] Fix gemma3n with hybrid swa (#8240) Signed-off-by: Xinyuan Tong Co-authored-by: Xinyuan Tong --- .../sglang/srt/model_executor/model_runner.py | 8 +++++-- test/srt/run_suite.py | 2 +- 
test/srt/test_vision_openai_server_b.py | 21 +++++++++++++++++++ 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 919622cc77d1..cbb35bf270d3 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -276,6 +276,7 @@ def initialize(self, min_per_gpu_memory: float): self.sampler = Sampler() self.load_model() + # Check if the model is using hybrid SWA if ( not self.server_args.disable_hybrid_swa_memory and self.sliding_window_size is not None @@ -1008,8 +1009,11 @@ def set_num_token_hybrid(self): try: layers = self.model.language_model.model.layers except: - self.is_hybrid = False - return + try: + layers = self.model.language_model.layers + except: + self.is_hybrid = False + return for layer in layers: if ( diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 6a96cf598648..18dcd004ff62 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -105,7 +105,7 @@ class TestFile: TestFile("test_vision_chunked_prefill.py", 175), TestFile("test_vlm_input_format.py", 300), TestFile("test_vision_openai_server_a.py", 584), - TestFile("test_vision_openai_server_b.py", 556), + TestFile("test_vision_openai_server_b.py", 620), TestFile("test_w8a8_quantization.py", 46), TestFile("test_reasoning_parser.py", 5), ], diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index dabf948b3567..f5b33a72e380 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -151,6 +151,27 @@ def test_video_chat_completion(self): pass +class TestGemma3nServer(TestOpenAIVisionServer): + @classmethod + def setUpClass(cls): + cls.model = "google/gemma-3n-E2B-it" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + 
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--mem-fraction-static", + "0.70", + "--cuda-graph-max-bs", + "1", + ], + ) + cls.base_url += "/v1" + + class TestKimiVLServer(TestOpenAIVisionServer): @classmethod def setUpClass(cls): From 4953f4ca9a3a440168cb4a0e9d1e4ae883c97d52 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 23 Jul 2025 15:07:27 -0700 Subject: [PATCH 106/396] chore: upgrade sgl-kernel 0.2.7 (#8304) --- python/pyproject.toml | 2 +- python/sglang/srt/entrypoints/engine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index aa9fc460d977..64915df6b590 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -54,7 +54,7 @@ runtime_common = [ srt = [ "sglang[runtime_common]", - "sgl-kernel==0.2.6.post1", + "sgl-kernel==0.2.7", "torch==2.7.1", "torchaudio==2.7.1", "torchvision==0.22.1", diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index e2cb02cc3014..edf81a79a098 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -654,7 +654,7 @@ def _set_envs_and_config(server_args: ServerArgs): if _is_cuda: assert_pkg_version( "sgl-kernel", - "0.2.6.post1", + "0.2.7", "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", ) From 0e7a5b26945c7a21dbaff10254477d2d3de779ff Mon Sep 17 00:00:00 2001 From: J Date: Wed, 23 Jul 2025 15:30:55 -0700 Subject: [PATCH 107/396] fix: prevent crashes due to logit bias dimension mismatch (#7685) --- python/sglang/srt/sampling/sampling_batch_info.py | 11 ++++++----- python/sglang/srt/speculative/eagle_utils.py | 6 ++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index f88082e690b0..bcdadbe1120f 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ 
b/python/sglang/srt/sampling/sampling_batch_info.py @@ -322,6 +322,12 @@ def merge_batch(self, other: "SamplingBatchInfo"): # Set the flag to True if any of the two has custom logit processor self.has_custom_logit_processor = True + # Merge logit bias - note this has to come before the temperatures tensor update! Otherwise will cause crashes. + # See note below on len(self) and len(other). + self.logit_bias = merge_bias_tensor( + self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0 + ) + # Note: because the __len()__ operator is defined on the temperatures tensor, # please make sure any merge operation with len(self) or len(other) is done before # the merge operation of the temperatures tensor below. @@ -340,11 +346,6 @@ def merge_batch(self, other: "SamplingBatchInfo"): self.need_top_k_sampling |= other.need_top_k_sampling self.need_min_p_sampling |= other.need_min_p_sampling - # Merge logit bias - self.logit_bias = merge_bias_tensor( - self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0 - ) - def merge_bias_tensor( lhs: Optional[torch.Tensor], diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 83724b3851ec..7f7e21e968c1 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import logging import os import time @@ -362,6 +363,11 @@ def verify( ) accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda") + if bs != len(sampling_info): + sampling_info = copy.deepcopy(sampling_info) + # NOTE: retrive_index are the indices of the requests that are kept. + sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index) + # Apply the custom logit processors if registered in the sampling info. 
if sampling_info.has_custom_logit_processor: apply_custom_logit_processor( From 01079e174ff8a7a052b4f8f74b4f8a59edd13f61 Mon Sep 17 00:00:00 2001 From: Chang Su Date: Wed, 23 Jul 2025 17:37:31 -0700 Subject: [PATCH 108/396] feat(function call): complete utility method for KimiK2Detector and enhance documentation (#8043) --- .../srt/function_call/base_format_detector.py | 82 ++++++++++++++++--- .../srt/function_call/deepseekv3_detector.py | 35 +++++--- .../srt/function_call/kimik2_detector.py | 57 +++++++++---- .../srt/function_call/llama32_detector.py | 9 +- .../srt/function_call/mistral_detector.py | 14 +++- .../srt/function_call/pythonic_detector.py | 21 +++-- .../srt/function_call/qwen25_detector.py | 15 +++- test/srt/test_function_call_parser.py | 28 +++++++ 8 files changed, 205 insertions(+), 56 deletions(-) diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py index 3989ec98d95c..d9ac71253e6d 100644 --- a/python/sglang/srt/function_call/base_format_detector.py +++ b/python/sglang/srt/function_call/base_format_detector.py @@ -25,23 +25,49 @@ class BaseFormatDetector(ABC): """Base class providing two sets of interfaces: one-time and streaming incremental.""" def __init__(self): - # initialize properties used for state when parsing tool calls in + # Streaming state management + # Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks self._buffer = "" - # streaming mode + # Stores complete tool call info (name and arguments) for each tool being parsed. + # Used by serving layer for completion handling when streaming ends. + # Format: [{"name": str, "arguments": dict}, ...] self.prev_tool_call_arr: List[Dict] = [] + # Index of currently streaming tool call. Starts at -1 (no active tool), + # increments as each tool completes. Tracks which tool's arguments are streaming. 
self.current_tool_id: int = -1 + # Flag for whether current tool's name has been sent to client. + # Tool names sent first with empty parameters, then arguments stream incrementally. self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: List[str] = ( - [] - ) # map what has been streamed for each tool so far to a list + # Tracks raw JSON string content streamed to client for each tool's arguments. + # Critical for serving layer to calculate remaining content when streaming ends. + # Each index corresponds to a tool_id. Example: ['{"location": "San Francisco"', '{"temp": 72'] + self.streamed_args_for_tool: List[str] = [] + + # Token configuration (override in subclasses) self.bot_token = "" self.eot_token = "" self.tool_call_separator = ", " - def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]: - tool_indices = { + def _get_tool_indices(self, tools: List[Tool]) -> Dict[str, int]: + """ + Get a mapping of tool names to their indices in the tools list. + + This utility method creates a dictionary mapping function names to their + indices in the tools list, which is commonly needed for tool validation + and ToolCallItem creation. 
+ + Args: + tools: List of available tools + + Returns: + Dictionary mapping tool names to their indices + """ + return { tool.function.name: i for i, tool in enumerate(tools) if tool.function.name } + + def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]: + tool_indices = self._get_tool_indices(tools) if not isinstance(action, list): action = [action] @@ -130,11 +156,7 @@ def parse_streaming_increment( # Build tool indices if not already built if not hasattr(self, "_tool_indices"): - self._tool_indices = { - tool.function.name: i - for i, tool in enumerate(tools) - if tool.function and tool.function.name - } + self._tool_indices = self._get_tool_indices(tools) flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR @@ -294,12 +316,48 @@ def parse_streaming_increment( @abstractmethod def has_tool_call(self, text: str) -> bool: + """ + Check if the given text contains function call markers specific to this format. + """ raise NotImplementedError() @abstractmethod def structure_info(self) -> _GetInfoFunc: + """ + Return a function that creates StructureInfo for constrained generation. + + The returned function takes a tool name and returns a StructureInfo object + containing the begin/end patterns and trigger tokens needed for constrained + generation of function calls in this format. + + Returns: + A function that takes a tool name (str) and returns StructureInfo + """ raise NotImplementedError() @abstractmethod def build_ebnf(self, tools: List[Tool]) -> str: + """ + Build an EBNF grammar for constrained generation of function calls. + + This method generates an Extended Backus-Naur Form (EBNF) grammar that + constrains the model's output to valid function calls in this format. + The grammar should include all available tools and their parameter schemas. 
+ + Args: + tools: List of available tools/functions that can be called + + Returns: + A string containing the EBNF grammar for this function call format + + The EBNF grammar should: + - Define the overall structure of function calls in this format + - Include all tool names from the provided tools list + - Define valid JSON structures for function arguments + - Handle multiple function calls if the format supports them + + Note: + Most implementations use EBNFComposer.build_ebnf() utility with + format-specific parameters rather than writing EBNF from scratch. + """ raise NotImplementedError() diff --git a/python/sglang/srt/function_call/deepseekv3_detector.py b/python/sglang/srt/function_call/deepseekv3_detector.py index e3befca5bcf4..35e96c715295 100644 --- a/python/sglang/srt/function_call/deepseekv3_detector.py +++ b/python/sglang/srt/function_call/deepseekv3_detector.py @@ -19,9 +19,28 @@ class DeepSeekV3Detector(BaseFormatDetector): """ - Detector for DeepSeek models. - Assumes function call format: - '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|> + Detector for DeepSeek V3 model function call format. + + The DeepSeek V3 format uses special Unicode tokens to delimit function calls + with JSON code blocks for arguments. 
+ + Format Structure: + ``` + <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>{function_name}\n```json\n{json_arguments}\n```<|tool▁calls▁end|><|end▁of▁sentence|> + ``` + Examples: + ``` + <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|> + ``` + + Key Components: + - Tool Calls Section: Wrapped between `<|tool▁calls▁begin|>` and `<|tool▁calls▁end|>` + - Individual Tool Call: Wrapped between `<|tool▁call▁begin|>` and `<|tool▁call▁end|>` + - Function Declaration: `function<|tool▁sep|>{function_name}` + - Arguments: JSON code block between ````json` and ```` + - Supports multiple tool calls + + Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324?chat_template=default """ def __init__(self): @@ -89,11 +108,7 @@ def parse_streaming_increment( return StreamingParseResult(normal_text=new_text) if not hasattr(self, "_tool_indices"): - self._tool_indices = { - tool.function.name: i - for i, tool in enumerate(tools) - if tool.function and tool.function.name - } + self._tool_indices = self._get_tool_indices(tools) calls: list[ToolCallItem] = [] try: @@ -127,7 +142,7 @@ def parse_streaming_increment( ) ) self.current_tool_name_sent = True - # Store the tool call info for adapter.py + # Store the tool call info for serving layer completions endpoint self.prev_tool_call_arr[self.current_tool_id] = { "name": func_name, "arguments": {}, @@ -153,7 +168,7 @@ def parse_streaming_increment( ] += argument_diff if _is_complete_json(func_args_raw): - # Update the stored arguments for adapter.py + # Update the stored arguments try: parsed_args = json.loads(func_args_raw) self.prev_tool_call_arr[self.current_tool_id][ diff --git a/python/sglang/srt/function_call/kimik2_detector.py b/python/sglang/srt/function_call/kimik2_detector.py 
index 94457ccda15c..54ee777873f1 100644 --- a/python/sglang/srt/function_call/kimik2_detector.py +++ b/python/sglang/srt/function_call/kimik2_detector.py @@ -18,16 +18,21 @@ class KimiK2Detector(BaseFormatDetector): + """ + Detector for Kimi K2 model function call format. + + Format Structure: + ``` + <|tool_calls_section_begin|> + <|tool_call_begin|>functions.{func_name}:{index} <|tool_call_argument_begin|>{json_args}<|tool_call_end|> + <|tool_calls_section_end|> + ``` + + Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md + """ def __init__(self): super().__init__() - self._buffer = "" - self.current_tool_name_sent: bool = False - self.prev_tool_call_arr: list[dict] = [] - self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = ( - [] - ) # map what has been streamed for each tool so far to a list self.bot_token: str = "<|tool_calls_section_begin|>" self.eot_token: str = "<|tool_calls_section_end|>" @@ -114,11 +119,7 @@ def parse_streaming_increment( return StreamingParseResult(normal_text=new_text) if not hasattr(self, "_tool_indices"): - self._tool_indices = { - tool.function.name: i - for i, tool in enumerate(tools) - if tool.function and tool.function.name - } + self._tool_indices = self._get_tool_indices(tools) calls: list[ToolCallItem] = [] try: @@ -150,7 +151,7 @@ def parse_streaming_increment( ) ) self.current_tool_name_sent = True - # Store the tool call info for adapter.py + # Store the tool call info for serving layer completions endpoint self.prev_tool_call_arr[self.current_tool_id] = { "name": function_name, "arguments": {}, @@ -214,7 +215,31 @@ def parse_streaming_increment( return StreamingParseResult(normal_text=current_text) def structure_info(self) -> _GetInfoFunc: - raise NotImplementedError() + """Return function that creates StructureInfo for guided generation.""" + + def get_info(name: str) -> StructureInfo: + return StructureInfo( + 
begin=f"<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:0 <|tool_call_argument_begin|>", + end="<|tool_call_end|><|tool_calls_section_end|>", + trigger="<|tool_calls_section_begin|>", + ) + + return get_info - def build_ebnf(self, tools: List[Tool]): - raise NotImplementedError() + def build_ebnf(self, tools: List[Tool]) -> str: + """ + Build EBNF grammar for KimiK2 tool call format. + + NOTE: The call_rule_fmt uses [0-9]+ for the function index to allow the grammar + to accept any numeric index (0, 1, 2, etc.) for proper sequential indexing in + multiple function call scenarios, while still maintaining the correct KimiK2 + format structure for constrained generation. + """ + return EBNFComposer.build_ebnf( + tools, + sequence_start_token=self.bot_token, + sequence_end_token=self.eot_token, + tool_call_separator="", + call_rule_fmt='"<|tool_call_begin|>functions.{name}:" [0-9]+ " <|tool_call_argument_begin|>" {arguments_rule} "<|tool_call_end|>"', + function_format="json", + ) diff --git a/python/sglang/srt/function_call/llama32_detector.py b/python/sglang/srt/function_call/llama32_detector.py index e7afeddb031f..453bcbc9a75a 100644 --- a/python/sglang/srt/function_call/llama32_detector.py +++ b/python/sglang/srt/function_call/llama32_detector.py @@ -16,9 +16,12 @@ class Llama32Detector(BaseFormatDetector): """ - Detector for Llama 3.2 models. - Assumes function call format: - <|python_tag|>{"name":"xxx", "arguments":{...}} + Detector for Llama 3.2 models with json tool call format. + + Format Structure: + ``` + {"name":"xxx", "arguments":{...}} + ``` """ def __init__(self): diff --git a/python/sglang/srt/function_call/mistral_detector.py b/python/sglang/srt/function_call/mistral_detector.py index 031368006ed9..49767fd53ba0 100644 --- a/python/sglang/srt/function_call/mistral_detector.py +++ b/python/sglang/srt/function_call/mistral_detector.py @@ -17,9 +17,17 @@ class MistralDetector(BaseFormatDetector): """ - Detector for Mistral models. 
- Assumes function call format: - [TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}] + Detector for Mistral model function call format. + + The Mistral format uses a simple bracket-delimited structure with JSON arrays + containing function call objects. + + Format Structure: + ``` + [TOOL_CALLS] [{"name": "function_name", "arguments": {json_args}}, ...] + ``` + + Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default """ def __init__(self): diff --git a/python/sglang/srt/function_call/pythonic_detector.py b/python/sglang/srt/function_call/pythonic_detector.py index d3096d9199ed..85c3cd1359ed 100644 --- a/python/sglang/srt/function_call/pythonic_detector.py +++ b/python/sglang/srt/function_call/pythonic_detector.py @@ -19,10 +19,17 @@ class PythonicDetector(BaseFormatDetector): """ - Detector for Llama-3.2 and Llama-4 models with pythonic tool call format. - Assumes function call format: - [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)] - Arguments are Python literals (not JSON). + Detector for Llama-4 models with Pythonic tool call format. + + The Pythonic format uses Python function call syntax within square brackets, + with arguments as Python literals rather than JSON. 
+ + Format Structure: + ``` + [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)] + ``` + + Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct?chat_template=default """ def __init__(self): @@ -75,11 +82,7 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult return StreamingParseResult(normal_text=normal_text, calls=[]) calls = [] - tool_indices = { - tool.function.name: i - for i, tool in enumerate(tools) - if tool.function.name - } + tool_indices = self._get_tool_indices(tools) for call_index, call in enumerate(parsed.elts): if not isinstance(call.func, ast.Name): continue diff --git a/python/sglang/srt/function_call/qwen25_detector.py b/python/sglang/srt/function_call/qwen25_detector.py index cee3f18eae0a..40a65e5df742 100644 --- a/python/sglang/srt/function_call/qwen25_detector.py +++ b/python/sglang/srt/function_call/qwen25_detector.py @@ -17,9 +17,18 @@ class Qwen25Detector(BaseFormatDetector): """ - Detector for Qwen 2.5 models. - Assumes function call format: - \n{"name":"func1", "arguments":{...}}\n\n\n{"name":"func2", "arguments":{...}}\n + Detector for Qwen 2.5 and Qwen 3 model function call format. 
+ + Format Structure: + ``` + \n{"name":"func1", "arguments":{...}}\n\n\n{"name":"func2", "arguments":{...}}\n + ``` + + Key Components: + - Tool Call Tags: `` and `` wrap each individual call + - Function Call Object: JSON object with "name" and "arguments" fields + + Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default """ def __init__(self): diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py index f9c36a9a2ed2..c2f63e7e4a0c 100644 --- a/test/srt/test_function_call_parser.py +++ b/test/srt/test_function_call_parser.py @@ -507,6 +507,7 @@ def setUp(self): self.llama32_detector = Llama32Detector() self.mistral_detector = MistralDetector() self.qwen25_detector = Qwen25Detector() + self.kimik2_detector = KimiK2Detector() def test_pythonic_detector_ebnf(self): """Test that the PythonicDetector generates valid EBNF.""" @@ -542,6 +543,33 @@ def test_deepseekv3_detector_ebnf(self): except RuntimeError as e: self.fail(f"Failed to compile EBNF: {e}") + def test_kimik2_detector_ebnf(self): + """Test that the KimiK2Detector generates valid EBNF.""" + ebnf = self.kimik2_detector.build_ebnf(self.tools) + self.assertIsNotNone(ebnf) + + # Check that the EBNF contains expected patterns for KimiK2 format + self.assertIn("<|tool_calls_section_begin|>", ebnf) + self.assertIn("<|tool_calls_section_end|>", ebnf) + + # Check for KimiK2-specific function call structure + self.assertIn("<|tool_call_begin|>functions.get_weather:", ebnf) + self.assertIn("<|tool_call_begin|>functions.search:", ebnf) + self.assertIn("<|tool_call_argument_begin|>", ebnf) + self.assertIn("<|tool_call_end|>", ebnf) + + # Check that it uses the correct namespace.function format with numeric index pattern + self.assertIn("functions.get_weather:", ebnf) + self.assertIn("functions.search:", ebnf) + self.assertIn("[0-9]+", ebnf) # Numeric index pattern + + # Validate that the EBNF can be compiled by GrammarCompiler + try: + ctx = 
self.grammar_compiler.compile_grammar(ebnf) + self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully") + except RuntimeError as e: + self.fail(f"Failed to compile EBNF: {e}") + def test_llama32_detector_ebnf(self): """Test that the Llama32Detector generates valid EBNF.""" ebnf = self.llama32_detector.build_ebnf(self.tools) From 624a3b8d1f105a1d9d730a709b73e23bd6f8b482 Mon Sep 17 00:00:00 2001 From: xianzhiT Date: Thu, 24 Jul 2025 08:40:23 +0800 Subject: [PATCH 109/396] Fix incomplete tool call capture issue in streaming response of DeepSeek-V3 when enable MTP (#7562) --- .../srt/function_call/deepseekv3_detector.py | 2 +- test/srt/test_function_call_parser.py | 89 +++++++++++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/function_call/deepseekv3_detector.py b/python/sglang/srt/function_call/deepseekv3_detector.py index 35e96c715295..afd0e3012703 100644 --- a/python/sglang/srt/function_call/deepseekv3_detector.py +++ b/python/sglang/srt/function_call/deepseekv3_detector.py @@ -113,7 +113,7 @@ def parse_streaming_increment( calls: list[ToolCallItem] = [] try: partial_match = re.search( - pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)", + pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```.*", string=current_text, flags=re.DOTALL, ) diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py index c2f63e7e4a0c..26dd24fbb71b 100644 --- a/test/srt/test_function_call_parser.py +++ b/test/srt/test_function_call_parser.py @@ -1375,5 +1375,94 @@ def test_partial_tool_call(self): self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"') +class TestDeepSeekV3Detector(unittest.TestCase): + def setUp(self): + """Set up test tools and detector for DeepSeekV3 format testing.""" + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + 
"properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_tourist_attractions", + description="Get tourist attractions", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + ] + self.detector = DeepSeekV3Detector() + + def test_parse_streaming_multiple_tool_calls_with_multi_token_chunk(self): + """Test parsing multiple tool calls when streaming chunks contains multi-tokens (e.g. DeepSeekV3 enable MTP)""" + # Simulate streaming chunks with multi-tokens for two consecutive tool calls + chunks = [ + "<|tool▁calls▁begin|>", + "<|tool▁call▁begin|>function", + "<|tool▁sep|>get", + "_weather\n", + "```json\n", + '{"city":', + '"Shanghai', + '"}\n```<|tool▁call▁end|>', + "\n<|tool▁call▁begin|>", + "function<|tool▁sep|>", + "get_tour", + "ist_att", + "ractions\n```" 'json\n{"', + 'city": "', + 'Beijing"}\n', + "```<|tool▁call▁end|>", + "<|tool▁calls▁end|>", + ] + + tool_calls_seen = [] + tool_calls_parameters = [] + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + if result.calls: + for call in result.calls: + if call.name: + tool_calls_seen.append(call.name) + if call.parameters: + tool_calls_parameters.append(call.parameters) + + # Should see both tool names + self.assertIn("get_weather", tool_calls_seen, "Should process first tool") + self.assertIn( + "get_tourist_attractions", tool_calls_seen, "Should process second tool" + ) + + # Verify that the parameters are valid JSON and contain the expected content + params1 = json.loads(tool_calls_parameters[0]) + params2 = json.loads(tool_calls_parameters[1]) + self.assertEqual(params1["city"], "Shanghai") + self.assertEqual(params2["city"], "Beijing") + + if __name__ == "__main__": unittest.main() From 0e5fa67773535d8916cf436fc3d1f689d7195b2f Mon 
Sep 17 00:00:00 2001 From: michael-amd Date: Wed, 23 Jul 2025 17:56:14 -0700 Subject: [PATCH 110/396] [AMD] Pull latest image for AMD CI (#8070) --- scripts/amd_ci_start_container.sh | 112 +++++++++++++++++++++++++++++- 1 file changed, 109 insertions(+), 3 deletions(-) diff --git a/scripts/amd_ci_start_container.sh b/scripts/amd_ci_start_container.sh index a6a527380d4f..239fd3770c26 100755 --- a/scripts/amd_ci_start_container.sh +++ b/scripts/amd_ci_start_container.sh @@ -1,6 +1,38 @@ #!/bin/bash set -euo pipefail +# Default base tags (can be overridden by command line arguments) +DEFAULT_MI30X_BASE_TAG="v0.4.9.post2-rocm630-mi30x" +DEFAULT_MI35X_BASE_TAG="v0.4.9.post2-rocm700-mi35x" + +# Parse command line arguments +MI30X_BASE_TAG="$DEFAULT_MI30X_BASE_TAG" +MI35X_BASE_TAG="$DEFAULT_MI35X_BASE_TAG" + +while [[ $# -gt 0 ]]; do + case $1 in + --mi30x-base-tag) + MI30X_BASE_TAG="$2" + shift 2 + ;; + --mi35x-base-tag) + MI35X_BASE_TAG="$2" + shift 2 + ;; + -h|--help) + echo "Usage: $0 [--mi30x-base-tag TAG] [--mi35x-base-tag TAG]" + echo " --mi30x-base-tag TAG Base tag for mi30x images (default: $DEFAULT_MI30X_BASE_TAG)" + echo " --mi35x-base-tag TAG Base tag for mi35x images (default: $DEFAULT_MI35X_BASE_TAG)" + exit 0 + ;; + *) + echo "Unknown option $1" + echo "Use --help for usage information" + exit 1 + ;; + esac +done + # Set up DEVICE_FLAG based on Kubernetes pod info if [ -f "/etc/podinfo/gha-render-devices" ]; then DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices) @@ -8,9 +40,83 @@ else DEVICE_FLAG="--device /dev/dri" fi -# Pull the image -IMAGE="rocm/sgl-dev:v0.4.9.post2-rocm630-mi30x-20250715" -echo "Pulling Docker image: $IMAGE" +# Function to find latest available image for a given GPU architecture +find_latest_image() { + local gpu_arch=$1 + local base_tag + + if [ "$gpu_arch" == "mi30x" ]; then + base_tag="$MI30X_BASE_TAG" + elif [ "$gpu_arch" == "mi35x" ]; then + base_tag="$MI35X_BASE_TAG" + else + echo "Error: Unsupported GPU architecture 
'$gpu_arch'" >&2 + return 1 + fi + + local days_back=0 + + while [ $days_back -lt 30 ]; do + local check_date=$(date -d "$days_back days ago" +%Y%m%d) + local image_tag="${base_tag}-${check_date}" + + echo "Checking for image: rocm/sgl-dev:${image_tag}" >&2 + + # Check if the image exists by trying to get its manifest + if docker manifest inspect "rocm/sgl-dev:${image_tag}" >/dev/null 2>&1; then + echo "Found available image: rocm/sgl-dev:${image_tag}" >&2 + echo "rocm/sgl-dev:${image_tag}" + return 0 + fi + + days_back=$((days_back + 1)) + done + + echo "Error: No ${gpu_arch} image found in the last 30 days" >&2 + return 1 +} + +# Determine image finder and fallback based on runner +# In Kubernetes, the hostname contains the GPU type (e.g., linux-mi300-gpu-1-bgg8r-runner-vknlb) +# Extract the GPU type from hostname +HOSTNAME_VALUE=$(hostname) +RUNNER_NAME="unknown" + +if [[ "${HOSTNAME_VALUE}" =~ ^(linux-mi[0-9]+-gpu-[0-9]+) ]]; then + RUNNER_NAME="${BASH_REMATCH[1]}" + echo "Extracted runner from hostname: ${RUNNER_NAME}" +else + echo "Could not extract runner info from hostname: ${HOSTNAME_VALUE}" +fi + +echo "The runner is: ${RUNNER_NAME}" +GPU_ARCH="mi30x" +FALLBACK_IMAGE="rocm/sgl-dev:${MI30X_BASE_TAG}-20250715" +FALLBACK_MSG="No mi30x image found in last 30 days, using fallback image" + +# Check for mi350/mi355 runners +if [[ "${RUNNER_NAME}" =~ ^linux-mi350-gpu-[0-9]+$ ]] || [[ "${RUNNER_NAME}" =~ ^linux-mi355-gpu-[0-9]+$ ]]; then + echo "Runner is ${RUNNER_NAME}, will find mi35x image." + GPU_ARCH="mi35x" + FALLBACK_IMAGE="rocm/sgl-dev:${MI35X_BASE_TAG}-20250715" + FALLBACK_MSG="No mi35x image found in last 30 days, using fallback image" +# Check for mi300/mi325 runners +elif [[ "${RUNNER_NAME}" =~ ^linux-mi300-gpu-[0-9]+$ ]] || [[ "${RUNNER_NAME}" =~ ^linux-mi325-gpu-[0-9]+$ ]]; then + echo "Runner is ${RUNNER_NAME}, will find mi30x image." 
+else + echo "Runner type not recognized: '${RUNNER_NAME}'" + echo "Defaulting to find mi30x image" +fi + +# Find and pull the latest image +IMAGE=$(find_latest_image "${GPU_ARCH}") +if [ $? -eq 0 ]; then + echo "Pulling Docker image: $IMAGE" +else + echo "$FALLBACK_MSG" >&2 + IMAGE="$FALLBACK_IMAGE" + echo "Pulling fallback Docker image: $IMAGE" +fi docker pull "$IMAGE" # Run the container From f7e102d56af50317b003fa3d3e86fcf4fe53d0d8 Mon Sep 17 00:00:00 2001 From: Haohui Mai Date: Wed, 23 Jul 2025 17:57:20 -0700 Subject: [PATCH 111/396] Pin the version of petit kernel to fix the APIs (#8235) --- python/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 64915df6b590..1cf32215d9ef 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -79,7 +79,7 @@ blackwell = [ srt_hip = [ "sglang[runtime_common]", "torch", - "petit_kernel", + "petit_kernel==0.0.2", ] # xpu is not enabled in public vllm and torch whl, From 5dd0f870ab4f5b8d35efab7acca500c13c3b8419 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 23 Jul 2025 23:18:17 -0700 Subject: [PATCH 112/396] [bug] fix pd completion protocol for batching support (#8317) --- python/sglang/srt/entrypoints/openai/protocol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 7d065b5aaa0d..9c73e5fad19d 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -192,9 +192,9 @@ class CompletionRequest(BaseModel): session_params: Optional[Dict] = None # For PD disaggregation - bootstrap_host: Optional[str] = None - bootstrap_port: Optional[int] = None - bootstrap_room: Optional[int] = None + bootstrap_host: Optional[Union[List[str], str]] = None + bootstrap_port: Optional[Union[List[Optional[int]], int]] = None + bootstrap_room: Optional[Union[List[int], 
int]] = None # For request id rid: Optional[Union[List[str], str]] = None From f6e07f27969c6b55bd5b27316b0c9760ce221c6e Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 23 Jul 2025 23:18:29 -0700 Subject: [PATCH 113/396] [router] fix pd model completion request (#8303) --- sgl-router/benches/request_processing.rs | 1 + sgl-router/src/openai_api_types.rs | 4 + sgl-router/src/routers/pd_router.rs | 90 +++++++-- sgl-router/src/routers/pd_types.rs | 233 ++++++++++++++++++++++ sgl-router/src/routers/request_adapter.rs | 5 + sgl-router/tests/benchmark_integration.rs | 2 + 6 files changed, 320 insertions(+), 15 deletions(-) diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index 576d07d2f79c..db5cdc901154 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -97,6 +97,7 @@ fn create_sample_completion_request() -> CompletionRequest { logit_bias: None, user: None, seed: None, + other: serde_json::Map::new(), } } diff --git a/sgl-router/src/openai_api_types.rs b/sgl-router/src/openai_api_types.rs index 9870fd06b8f0..d57e617675c9 100644 --- a/sgl-router/src/openai_api_types.rs +++ b/sgl-router/src/openai_api_types.rs @@ -91,6 +91,10 @@ pub struct CompletionRequest { /// If specified, our system will make a best effort to sample deterministically #[serde(skip_serializing_if = "Option::is_none")] pub seed: Option, + + /// Additional fields including bootstrap info for PD routing + #[serde(flatten)] + pub other: serde_json::Map, } impl GenerationRequest for CompletionRequest { diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index 7c70a3873fc3..ab9927d244d6 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -420,6 +420,77 @@ impl PDRouter { .await } + // Route a completion request while preserving OpenAI format + pub async fn route_completion( + &self, + client: &reqwest::Client, + req: &HttpRequest, + 
mut typed_req: CompletionRequest, + route: &str, + ) -> HttpResponse { + let start = Instant::now(); + + // Get stream flag and return_logprob flag before moving the request + let is_stream = typed_req.stream; + let return_logprob = typed_req.logprobs.is_some(); + + // Extract text for cache-aware routing from the typed request + let request_text = match &typed_req.prompt { + crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), + crate::openai_api_types::StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()), + }; + + // Select servers + let (prefill, decode) = match self.select_pd_pair(client, request_text).await { + Ok(pair) => pair, + Err(e) => { + error!("Failed to select PD pair: {}", e); + RouterMetrics::record_pd_error("server_selection"); + return HttpResponse::ServiceUnavailable() + .body(format!("No available servers: {}", e)); + } + }; + + // Log routing decision + info!( + "PD routing: {} -> prefill={}, decode={}", + route, + prefill.url(), + decode.url() + ); + + // Add bootstrap info using the trait method + if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { + error!("Failed to add bootstrap info: {}", e); + RouterMetrics::record_pd_error("bootstrap_injection"); + return HttpResponse::InternalServerError() + .body(format!("Bootstrap injection failed: {}", e)); + } + + // Convert to JSON after bootstrap injection + let json_with_bootstrap = match serde_json::to_value(&typed_req) { + Ok(json) => json, + Err(e) => { + error!("Failed to serialize request: {}", e); + return HttpResponse::InternalServerError().body("Failed to serialize request"); + } + }; + + // Execute dual dispatch + self.execute_dual_dispatch( + client, + req, + json_with_bootstrap, + route, + prefill.as_ref(), + decode.as_ref(), + is_stream, + return_logprob, + start, + ) + .await + } + // Execute the dual dispatch to prefill and decode servers #[allow(clippy::too_many_arguments)] async fn execute_dual_dispatch( @@ -1302,23 +1373,12 @@ impl 
RouterTrait for PDRouter { req: &HttpRequest, body: serde_json::Value, ) -> HttpResponse { - match serde_json::from_value::(body.clone()) { + match serde_json::from_value::(body) { Ok(openai_req) => { - // Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput) - let pd_req = openai_req.to_pd_request(); - PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await - } - Err(_) => { - // If that fails, try to deserialize directly as PD format (for backwards compatibility) - match serde_json::from_value::(body) { - Ok(pd_req) => { - PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await - } - Err(e) => { - HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)) - } - } + // Use the new method that preserves OpenAI format + PDRouter::route_completion(self, client, req, openai_req, "/v1/completions").await } + Err(e) => HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)), } } diff --git a/sgl-router/src/routers/pd_types.rs b/sgl-router/src/routers/pd_types.rs index e83ab5b60f5b..993f2bf3d622 100644 --- a/sgl-router/src/routers/pd_types.rs +++ b/sgl-router/src/routers/pd_types.rs @@ -1,6 +1,7 @@ // Essential PDLB types extracted for PD routing use crate::core::{Worker, WorkerType}; +use crate::openai_api_types::{CompletionRequest, StringOrArray}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -233,3 +234,235 @@ impl Bootstrap for ChatReqInput { self.bootstrap_room = Some(bootstrap_room); } } + +// Bootstrap implementation for CompletionRequest to preserve OpenAI format +impl Bootstrap for CompletionRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_batch_size(&self) -> Result, String> { + if let StringOrArray::Array(prompts) = &self.prompt { + if prompts.is_empty() { + return Err("Batch prompt array is empty".to_string()); + } + return Ok(Some(prompts.len())); + } + + // Single string prompt + Ok(None) + } + + fn set_bootstrap_info( + 
&mut self, + bootstrap_host: BootstrapHost, + bootstrap_port: BootstrapPort, + bootstrap_room: BootstrapRoom, + ) { + // Insert bootstrap_host - it serializes correctly whether Single or Batch + if let Ok(host_value) = serde_json::to_value(&bootstrap_host) { + self.other.insert("bootstrap_host".to_string(), host_value); + } + + // Insert bootstrap_port - it serializes correctly whether Single or Batch + if let Ok(port_value) = serde_json::to_value(&bootstrap_port) { + self.other.insert("bootstrap_port".to_string(), port_value); + } + + // Insert bootstrap_room - it serializes correctly whether Single or Batch + if let Ok(room_value) = serde_json::to_value(&bootstrap_room) { + self.other.insert("bootstrap_room".to_string(), room_value); + } + } +} + +#[cfg(test)] +mod bootstrap_tests { + use super::*; + use crate::openai_api_types::StringOrArray; + + #[test] + fn test_completion_batch_size_with_array_prompt() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), + n: None, + other: serde_json::Map::new(), + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + }; + + // Should return batch size for array prompt + assert_eq!(req.get_batch_size().unwrap(), Some(2)); + } + + #[test] + fn test_completion_batch_size_with_single_prompt() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::String("single prompt".to_string()), + n: None, + other: serde_json::Map::new(), + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: 
None, + seed: None, + }; + + // Should return None for single prompt + assert_eq!(req.get_batch_size().unwrap(), None); + } + + #[test] + fn test_completion_batch_size_with_n_parameter() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::String("single prompt".to_string()), + n: Some(3), + other: serde_json::Map::new(), + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + }; + + // Should return None for single string prompt, even with n > 1 + // SGLang handles n parameter differently than batch requests + assert_eq!(req.get_batch_size().unwrap(), None); + } + + #[test] + fn test_completion_bootstrap_single_values() { + let mut req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), + n: None, + other: serde_json::Map::new(), + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + }; + + // Set bootstrap info - should always use single values + req.set_bootstrap_info( + BootstrapHost::Single("test-server".to_string()), + BootstrapPort::Single(Some(5678)), + BootstrapRoom::Single(12345), + ); + + // Verify single values were created + assert!(req.other.get("bootstrap_host").unwrap().is_string()); + assert!(req.other.get("bootstrap_port").unwrap().is_number()); + assert!(req.other.get("bootstrap_room").unwrap().is_number()); + + assert_eq!( + req.other.get("bootstrap_host").unwrap().as_str().unwrap(), + "test-server" + ); + assert_eq!( + req.other.get("bootstrap_port").unwrap().as_u64().unwrap(), + 5678 + 
); + assert_eq!( + req.other.get("bootstrap_room").unwrap().as_u64().unwrap(), + 12345 + ); + } + + #[test] + fn test_completion_bootstrap_array_values() { + let mut req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), + n: None, + other: serde_json::Map::new(), + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + }; + + // Set bootstrap info with arrays + req.set_bootstrap_info( + BootstrapHost::Batch(vec!["test-server".to_string(); 2]), + BootstrapPort::Batch(vec![Some(5678); 2]), + BootstrapRoom::Batch(vec![12345, 67890]), + ); + + // Verify arrays were created correctly + assert!(req.other.get("bootstrap_host").unwrap().is_array()); + assert!(req.other.get("bootstrap_port").unwrap().is_array()); + assert!(req.other.get("bootstrap_room").unwrap().is_array()); + + let hosts = req.other.get("bootstrap_host").unwrap().as_array().unwrap(); + assert_eq!(hosts.len(), 2); + assert_eq!(hosts[0].as_str().unwrap(), "test-server"); + + let ports = req.other.get("bootstrap_port").unwrap().as_array().unwrap(); + assert_eq!(ports.len(), 2); + assert_eq!(ports[0].as_u64().unwrap(), 5678); + + let rooms = req.other.get("bootstrap_room").unwrap().as_array().unwrap(); + assert_eq!(rooms.len(), 2); + assert_eq!(rooms[0].as_u64().unwrap(), 12345); + assert_eq!(rooms[1].as_u64().unwrap(), 67890); + } +} diff --git a/sgl-router/src/routers/request_adapter.rs b/sgl-router/src/routers/request_adapter.rs index 201c61aa55c8..f29bcecc9ea6 100644 --- a/sgl-router/src/routers/request_adapter.rs +++ b/sgl-router/src/routers/request_adapter.rs @@ -648,6 +648,7 @@ mod tests { user: None, seed: None, suffix: None, + other: serde_json::Map::new(), }; let pd_req = req.to_pd_request(); @@ 
-687,6 +688,7 @@ mod tests { user: None, seed: None, suffix: None, + other: serde_json::Map::new(), }; let pd_req = req.to_pd_request(); @@ -725,6 +727,7 @@ mod tests { user: Some("user123".to_string()), seed: Some(42), suffix: Some("...".to_string()), + other: serde_json::Map::new(), }; let pd_req = req.to_pd_request(); @@ -768,6 +771,7 @@ mod tests { user: None, seed: None, suffix: None, + other: serde_json::Map::new(), }; let pd_req = req.to_pd_request(); @@ -799,6 +803,7 @@ mod tests { user: None, seed: None, suffix: None, + other: serde_json::Map::new(), }; let pd_req = req.to_pd_request(); diff --git a/sgl-router/tests/benchmark_integration.rs b/sgl-router/tests/benchmark_integration.rs index 31785900011f..b7876e22398b 100644 --- a/sgl-router/tests/benchmark_integration.rs +++ b/sgl-router/tests/benchmark_integration.rs @@ -86,6 +86,7 @@ fn test_benchmark_request_creation() { logit_bias: None, user: None, seed: None, + other: serde_json::Map::new(), }; // Test serialization works @@ -181,6 +182,7 @@ fn test_benchmark_request_adaptation() { logit_bias: None, user: None, seed: None, + other: serde_json::Map::new(), }; // Test PD adaptation (should not panic) From bfb118c01e38fb7865742dcd9cf9075270283e9e Mon Sep 17 00:00:00 2001 From: Minho Ryu Date: Thu, 24 Jul 2025 15:18:47 +0900 Subject: [PATCH 114/396] fix bug when eos_ids==0 (#8315) --- python/sglang/srt/configs/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 84c96d91df0b..cea455a24ed4 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -475,7 +475,7 @@ def _verify_quantization(self) -> None: def get_hf_eos_token_id(self) -> Optional[Set[int]]: eos_ids = getattr(self.hf_config, "eos_token_id", None) - if eos_ids: + if eos_ids is not None: # it can be either int or list of int eos_ids = {eos_ids} if isinstance(eos_ids, int) else 
set(eos_ids) if eos_ids is None: From 2f86f3ad62c175ff3f41e87fef6431cfb97a8083 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 23 Jul 2025 23:26:44 -0700 Subject: [PATCH 115/396] [router] add endpoint unit test (#8298) --- sgl-router/tests/api_endpoints_test.rs | 1309 ++++++++++++++++++++++++ sgl-router/tests/common/mock_worker.rs | 72 +- 2 files changed, 1374 insertions(+), 7 deletions(-) create mode 100644 sgl-router/tests/api_endpoints_test.rs diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs new file mode 100644 index 000000000000..12e8dd2d2b88 --- /dev/null +++ b/sgl-router/tests/api_endpoints_test.rs @@ -0,0 +1,1309 @@ +mod common; + +use actix_web::{http::StatusCode, rt::System, test as actix_test, web, App}; +use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; +use reqwest::Client; +use serde_json::json; +use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; +use sglang_router_rs::server::{ + add_worker, flush_cache, generate, get_loads, get_model_info, get_server_info, health, + health_generate, list_workers, liveness, readiness, remove_worker, v1_chat_completions, + v1_completions, v1_models, AppState, +}; + +/// Test context that manages mock workers +struct TestContext { + workers: Vec, + app_state: web::Data, +} + +impl TestContext { + async fn new(worker_configs: Vec) -> Self { + // Create default router config + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec![], + }, + policy: PolicyConfig::Random, + host: "127.0.0.1".to_string(), + port: 3002, + max_payload_size: 256 * 1024 * 1024, + request_timeout_secs: 600, + worker_startup_timeout_secs: 1, + worker_startup_check_interval_secs: 1, + discovery: None, + metrics: None, + log_dir: None, + log_level: None, + }; + + Self::new_with_config(config, worker_configs).await + } + + async fn new_with_config(config: RouterConfig, worker_configs: Vec) -> Self { + let mut workers = 
Vec::new(); + let mut worker_urls = Vec::new(); + + // Start mock workers if any + for worker_config in worker_configs { + let mut worker = MockWorker::new(worker_config); + let url = worker.start().await.unwrap(); + worker_urls.push(url); + workers.push(worker); + } + + if !workers.is_empty() { + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + } + + let client = Client::builder() + .timeout(std::time::Duration::from_secs(config.request_timeout_secs)) + .build() + .unwrap(); + + let app_state = AppState::new(config, client).unwrap(); + let app_state = web::Data::new(app_state); + + // Add workers if any + if !worker_urls.is_empty() { + let app = actix_test::init_service( + App::new().app_data(app_state.clone()).service(add_worker), + ) + .await; + + for url in &worker_urls { + let req = actix_test::TestRequest::post() + .uri(&format!("/add_worker?url={}", url)) + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert!(resp.status().is_success()); + } + + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + } + + Self { workers, app_state } + } + + async fn create_app( + &self, + ) -> impl actix_web::dev::Service< + actix_http::Request, + Response = actix_web::dev::ServiceResponse, + Error = actix_web::Error, + > { + actix_test::init_service( + App::new() + .app_data(self.app_state.clone()) + .service(liveness) + .service(readiness) + .service(health) + .service(health_generate) + .service(get_server_info) + .service(get_model_info) + .service(v1_models) + .service(generate) + .service(v1_chat_completions) + .service(v1_completions) + .service(add_worker) + .service(list_workers) + .service(remove_worker) + .service(flush_cache) + .service(get_loads), + ) + .await + } + + async fn shutdown(mut self) { + for worker in &mut self.workers { + worker.stop().await; + } + } +} + +#[cfg(test)] +mod health_tests { + use super::*; + + #[test] + fn test_liveness_endpoint() { + System::new().block_on(async { + let 
ctx = TestContext::new(vec![]).await; + let app = ctx.create_app().await; + + let req = actix_test::TestRequest::get().uri("/liveness").to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_readiness_with_healthy_workers() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18001, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let req = actix_test::TestRequest::get() + .uri("/readiness") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_readiness_with_unhealthy_workers() { + System::new().block_on(async { + // Create an empty context (no workers) + let ctx = TestContext::new(vec![]).await; + + let app = ctx.create_app().await; + + let req = actix_test::TestRequest::get() + .uri("/readiness") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + // With no workers, readiness should return SERVICE_UNAVAILABLE + assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_health_endpoint_details() { + System::new().block_on(async { + let ctx = TestContext::new(vec![ + MockWorkerConfig { + port: 18003, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }, + MockWorkerConfig { + port: 18004, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }, + ]) + .await; + + let app = ctx.create_app().await; + + let req = actix_test::TestRequest::get().uri("/health").to_request(); + + let resp = actix_test::call_service(&app, req).await; + 
assert_eq!(resp.status(), StatusCode::OK); + + // The health endpoint returns plain text, not JSON + let body = actix_test::read_body(resp).await; + let body_str = String::from_utf8_lossy(&body); + assert!(body_str.contains("All servers healthy")); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_health_generate_endpoint() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18005, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let req = actix_test::TestRequest::get() + .uri("/health_generate") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = actix_test::read_body_json(resp).await; + assert!(body.is_object()); + + ctx.shutdown().await; + }); + } +} + +#[cfg(test)] +mod generation_tests { + use super::*; + + #[test] + fn test_generate_success() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18101, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "text": "Hello, world!", + "stream": false + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = actix_test::read_body_json(resp).await; + assert!(body.get("text").is_some()); + assert!(body.get("meta_info").is_some()); + let meta_info = &body["meta_info"]; + assert!(meta_info.get("finish_reason").is_some()); + assert_eq!(meta_info["finish_reason"]["type"], "stop"); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_generate_streaming() { + 
System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18102, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "text": "Stream test", + "stream": true + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + // Check that it's a streaming response + let content_type = resp.headers().get("content-type"); + assert!(content_type.is_some()); + assert_eq!(content_type.unwrap(), "text/event-stream"); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_generate_with_worker_failure() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18103, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 1.0, // Always fail + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "text": "This should fail", + "stream": false + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_v1_chat_completions_success() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18104, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello!"} + ], + "stream": false + }); + + let req = actix_test::TestRequest::post() + 
.uri("/v1/chat/completions") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = actix_test::read_body_json(resp).await; + assert!(body.get("choices").is_some()); + + ctx.shutdown().await; + }); + } +} + +#[cfg(test)] +mod model_info_tests { + use super::*; + + #[test] + fn test_get_server_info() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18201, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let req = actix_test::TestRequest::get() + .uri("/get_server_info") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = actix_test::read_body_json(resp).await; + assert!(body.is_object()); + // Check for actual sglang server fields + assert!(body.get("version").is_some()); + assert!(body.get("model_path").is_some()); + assert!(body.get("tokenizer_path").is_some()); + assert!(body.get("port").is_some()); + assert!(body.get("max_num_batched_tokens").is_some()); + assert!(body.get("schedule_policy").is_some()); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_get_model_info() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18202, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let req = actix_test::TestRequest::get() + .uri("/get_model_info") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = actix_test::read_body_json(resp).await; + assert!(body.is_object()); + // Check for actual sglang model 
info fields + assert_eq!( + body.get("model_path").and_then(|v| v.as_str()), + Some("mock-model-path") + ); + assert_eq!( + body.get("tokenizer_path").and_then(|v| v.as_str()), + Some("mock-tokenizer-path") + ); + assert_eq!( + body.get("is_generation").and_then(|v| v.as_bool()), + Some(true) + ); + assert!(body.get("preferred_sampling_params").is_some()); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_v1_models() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18203, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let req = actix_test::TestRequest::get() + .uri("/v1/models") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = actix_test::read_body_json(resp).await; + assert!(body.get("object").is_some()); + assert_eq!(body.get("object").and_then(|v| v.as_str()), Some("list")); + + let data = body.get("data").and_then(|v| v.as_array()); + assert!(data.is_some()); + + let models = data.unwrap(); + assert!(!models.is_empty()); + + let first_model = &models[0]; + assert_eq!( + first_model.get("id").and_then(|v| v.as_str()), + Some("mock-model-v1") + ); + assert_eq!( + first_model.get("object").and_then(|v| v.as_str()), + Some("model") + ); + assert!(first_model.get("created").is_some()); + assert_eq!( + first_model.get("owned_by").and_then(|v| v.as_str()), + Some("sglang") + ); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_model_info_with_no_workers() { + System::new().block_on(async { + let ctx = TestContext::new(vec![]).await; + let app = ctx.create_app().await; + + // Test server info with no workers + let req = actix_test::TestRequest::get() + .uri("/get_server_info") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + // Router may return 
various error codes when no workers + assert!( + resp.status() == StatusCode::OK + || resp.status() == StatusCode::SERVICE_UNAVAILABLE + || resp.status() == StatusCode::NOT_FOUND + || resp.status() == StatusCode::INTERNAL_SERVER_ERROR, + "Unexpected status code: {:?}", + resp.status() + ); + + // Test model info with no workers + let req = actix_test::TestRequest::get() + .uri("/get_model_info") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + // Router may return various error codes when no workers + assert!( + resp.status() == StatusCode::OK + || resp.status() == StatusCode::SERVICE_UNAVAILABLE + || resp.status() == StatusCode::NOT_FOUND + || resp.status() == StatusCode::INTERNAL_SERVER_ERROR, + "Unexpected status code: {:?}", + resp.status() + ); + + // Test v1/models with no workers + let req = actix_test::TestRequest::get() + .uri("/v1/models") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + // Router may return various error codes when no workers + assert!( + resp.status() == StatusCode::OK + || resp.status() == StatusCode::SERVICE_UNAVAILABLE + || resp.status() == StatusCode::NOT_FOUND + || resp.status() == StatusCode::INTERNAL_SERVER_ERROR, + "Unexpected status code: {:?}", + resp.status() + ); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_model_info_with_multiple_workers() { + System::new().block_on(async { + let ctx = TestContext::new(vec![ + MockWorkerConfig { + port: 18204, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }, + MockWorkerConfig { + port: 18205, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }, + ]) + .await; + + let app = ctx.create_app().await; + + // Test that model info is consistent across workers + for _ in 0..5 { + let req = actix_test::TestRequest::get() + .uri("/get_model_info") + .to_request(); + + let resp = 
actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = actix_test::read_body_json(resp).await; + assert_eq!( + body.get("model_path").and_then(|v| v.as_str()), + Some("mock-model-path") + ); + } + + ctx.shutdown().await; + }); + } + + #[test] + fn test_model_info_with_unhealthy_worker() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18206, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 1.0, // Always fail + }]) + .await; + + let app = ctx.create_app().await; + + let req = actix_test::TestRequest::get() + .uri("/get_model_info") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + // Worker with fail_rate: 1.0 should always return an error status + assert!( + resp.status() == StatusCode::INTERNAL_SERVER_ERROR + || resp.status() == StatusCode::SERVICE_UNAVAILABLE, + "Expected error status for always-failing worker, got: {:?}", + resp.status() + ); + + ctx.shutdown().await; + }); + } +} + +#[cfg(test)] +mod worker_management_tests { + use super::*; + + #[test] + fn test_add_new_worker() { + System::new().block_on(async { + let ctx = TestContext::new(vec![]).await; + let app = ctx.create_app().await; + + // Start a mock worker + let mut worker = MockWorker::new(MockWorkerConfig { + port: 18301, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + let url = worker.start().await.unwrap(); + + // Add the worker + let req = actix_test::TestRequest::post() + .uri(&format!("/add_worker?url={}", url)) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + // List workers to verify + let req = actix_test::TestRequest::get() + .uri("/list_workers") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + 
assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = actix_test::read_body_json(resp).await; + let workers = body["urls"].as_array().unwrap(); + assert!(workers.iter().any(|w| w.as_str().unwrap() == url)); + + worker.stop().await; + ctx.shutdown().await; + }); + } + + #[test] + fn test_remove_existing_worker() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18302, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // Get the worker URL + let req = actix_test::TestRequest::get() + .uri("/list_workers") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + let body: serde_json::Value = actix_test::read_body_json(resp).await; + let workers = body["urls"].as_array().unwrap(); + let worker_url = workers[0].as_str().unwrap(); + + // Remove the worker + let req = actix_test::TestRequest::post() + .uri(&format!("/remove_worker?url={}", worker_url)) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + // Verify it's removed + let req = actix_test::TestRequest::get() + .uri("/list_workers") + .to_request(); + let resp = actix_test::call_service(&app, req).await; + let body: serde_json::Value = actix_test::read_body_json(resp).await; + let workers = body["urls"].as_array().unwrap(); + assert!(workers.is_empty()); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_add_worker_invalid_url() { + System::new().block_on(async { + let ctx = TestContext::new(vec![]).await; + let app = ctx.create_app().await; + + // Invalid URL format + let req = actix_test::TestRequest::post() + .uri("/add_worker?url=not-a-valid-url") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + // Missing URL parameter + let req = 
actix_test::TestRequest::post() + .uri("/add_worker") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + // Empty URL + let req = actix_test::TestRequest::post() + .uri("/add_worker?url=") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_add_duplicate_worker() { + System::new().block_on(async { + // Start a mock worker + let mut worker = MockWorker::new(MockWorkerConfig { + port: 18303, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + let url = worker.start().await.unwrap(); + + let ctx = TestContext::new(vec![]).await; + let app = ctx.create_app().await; + + // Add worker first time + let req = actix_test::TestRequest::post() + .uri(&format!("/add_worker?url={}", url)) + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + + // Try to add same worker again + let req = actix_test::TestRequest::post() + .uri(&format!("/add_worker?url={}", url)) + .to_request(); + let resp = actix_test::call_service(&app, req).await; + // Should return error for duplicate + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + worker.stop().await; + ctx.shutdown().await; + }); + } + + #[test] + fn test_add_unhealthy_worker() { + System::new().block_on(async { + // Start unhealthy worker + let mut worker = MockWorker::new(MockWorkerConfig { + port: 18304, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Unhealthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + let url = worker.start().await.unwrap(); + + let ctx = TestContext::new(vec![]).await; + let app = ctx.create_app().await; + + // Try to add unhealthy worker + let req = 
actix_test::TestRequest::post() + .uri(&format!("/add_worker?url={}", url)) + .to_request(); + let resp = actix_test::call_service(&app, req).await; + + // Router should reject unhealthy workers + assert!( + resp.status() == StatusCode::BAD_REQUEST + || resp.status() == StatusCode::SERVICE_UNAVAILABLE + ); + + worker.stop().await; + ctx.shutdown().await; + }); + } +} + +#[cfg(test)] +mod error_tests { + use super::*; + + #[test] + fn test_404_not_found() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18401, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // Test unknown endpoint + let req = actix_test::TestRequest::get() + .uri("/unknown_endpoint") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + // Test POST to unknown endpoint + let req = actix_test::TestRequest::post() + .uri("/api/v2/generate") + .set_json(&json!({"text": "test"})) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_method_not_allowed() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18402, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // GET request to POST-only endpoint + let req = actix_test::TestRequest::get().uri("/generate").to_request(); + + let resp = actix_test::call_service(&app, req).await; + // Note: actix-web returns 404 for unmatched methods in some configurations + assert!( + resp.status() == StatusCode::METHOD_NOT_ALLOWED + || resp.status() == StatusCode::NOT_FOUND + ); + + // POST request to GET-only endpoint + let req 
= actix_test::TestRequest::post() + .uri("/health") + .set_json(&json!({})) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + // Note: actix-web returns 404 for unmatched methods in some configurations + assert!( + resp.status() == StatusCode::METHOD_NOT_ALLOWED + || resp.status() == StatusCode::NOT_FOUND + ); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_payload_too_large() { + System::new().block_on(async { + // Create context with small payload limit + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec![], + }, + policy: PolicyConfig::Random, + host: "127.0.0.1".to_string(), + port: 3010, + max_payload_size: 1024, // 1KB limit + request_timeout_secs: 600, + worker_startup_timeout_secs: 1, + worker_startup_check_interval_secs: 1, + discovery: None, + metrics: None, + log_dir: None, + log_level: None, + }; + + let ctx = TestContext::new_with_config( + config, + vec![MockWorkerConfig { + port: 18403, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }], + ) + .await; + + let app = ctx.create_app().await; + + // Create large payload (> 1KB) + let large_text = "x".repeat(2000); + let payload = json!({ + "text": large_text, + "stream": false + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + // Note: The test framework may not enforce payload size limits the same way as the full server + // In production, the server middleware would reject large payloads before reaching handlers + assert!( + resp.status() == StatusCode::PAYLOAD_TOO_LARGE || resp.status() == StatusCode::OK + ); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_invalid_json_payload() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18404, + worker_type: WorkerType::Regular, + health_status: 
HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // Send invalid JSON + let req = actix_test::TestRequest::post() + .uri("/generate") + .insert_header(("content-type", "application/json")) + .set_payload("{invalid json}") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + // Send empty body + let req = actix_test::TestRequest::post() + .uri("/generate") + .insert_header(("content-type", "application/json")) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_missing_required_fields() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18405, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // Missing messages in chat completion + let payload = json!({ + "model": "test-model" + // missing "messages" + }); + + let req = actix_test::TestRequest::post() + .uri("/v1/chat/completions") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + // Note: Mock worker might accept this, but real implementation would return 400 + // The status depends on the actual router implementation + assert!(resp.status() == StatusCode::OK || resp.status() == StatusCode::BAD_REQUEST); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_invalid_model() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18406, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "model": 
"invalid-model-name-that-does-not-exist", + "messages": [{"role": "user", "content": "Hello"}], + "stream": false + }); + + let req = actix_test::TestRequest::post() + .uri("/v1/chat/completions") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + // Mock worker accepts any model, but real implementation might return 400 + assert!(resp.status().is_success() || resp.status() == StatusCode::BAD_REQUEST); + + ctx.shutdown().await; + }); + } +} + +#[cfg(test)] +mod cache_tests { + use super::*; + + #[test] + fn test_flush_cache() { + System::new().block_on(async { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18501, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = actix_test::init_service( + App::new() + .app_data(ctx.app_state.clone()) + .service(flush_cache), + ) + .await; + + let req = actix_test::TestRequest::post() + .uri("/flush_cache") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + // The response might be empty or contain a message + let body_bytes = actix_test::read_body(resp).await; + if !body_bytes.is_empty() { + if let Ok(body) = serde_json::from_slice::(&body_bytes) { + // Check that we got a successful response with expected fields + assert!(body.is_object()); + assert!(body.get("message").is_some() || body.get("status").is_some()); + } + } + + ctx.shutdown().await; + }); + } + + #[test] + fn test_get_loads() { + System::new().block_on(async { + let ctx = TestContext::new(vec![ + MockWorkerConfig { + port: 18502, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }, + MockWorkerConfig { + port: 18503, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }, + ]) + .await; + + let app = 
actix_test::init_service( + App::new() + .app_data(ctx.app_state.clone()) + .service(get_loads), + ) + .await; + + let req = actix_test::TestRequest::get() + .uri("/get_loads") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = actix_test::read_body_json(resp).await; + + // Verify the response contains load information + assert!(body.is_object()); + // The exact structure depends on the implementation + // but should contain worker load information + + ctx.shutdown().await; + }); + } + + #[test] + fn test_flush_cache_no_workers() { + System::new().block_on(async { + let ctx = TestContext::new(vec![]).await; + + let app = actix_test::init_service( + App::new() + .app_data(ctx.app_state.clone()) + .service(flush_cache), + ) + .await; + + let req = actix_test::TestRequest::post() + .uri("/flush_cache") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + // Should either succeed (no-op) or return service unavailable + assert!( + resp.status() == StatusCode::OK || resp.status() == StatusCode::SERVICE_UNAVAILABLE + ); + + ctx.shutdown().await; + }); + } +} + +#[cfg(test)] +mod load_balancing_tests { + use super::*; + + #[test] + fn test_request_distribution() { + System::new().block_on(async { + // Create multiple workers + let ctx = TestContext::new(vec![ + MockWorkerConfig { + port: 18601, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }, + MockWorkerConfig { + port: 18602, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }, + ]) + .await; + + let app = ctx.create_app().await; + + // Send multiple requests and track distribution + let mut request_count = 0; + for _ in 0..10 { + let payload = json!({ + "text": format!("Request {}", request_count), + "stream": false + }); + + let req = 
actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + if resp.status() == StatusCode::OK { + request_count += 1; + } + } + + // With random policy, all requests should succeed + assert_eq!(request_count, 10); + + ctx.shutdown().await; + }); + } +} + +#[cfg(test)] +mod pd_mode_tests { + use super::*; + + #[test] + fn test_pd_mode_routing() { + System::new().block_on(async { + // Create PD mode configuration with prefill and decode workers + let mut prefill_worker = MockWorker::new(MockWorkerConfig { + port: 18701, + worker_type: WorkerType::Prefill, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + + let mut decode_worker = MockWorker::new(MockWorkerConfig { + port: 18702, + worker_type: WorkerType::Decode, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }); + + let prefill_url = prefill_worker.start().await.unwrap(); + let decode_url = decode_worker.start().await.unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + + // For PD mode, we'll skip the test for now since it requires special handling + // TODO: Implement PD mode testing with proper worker management + let _prefill_url = prefill_url; + let _decode_url = decode_url; + prefill_worker.stop().await; + decode_worker.stop().await; + }); + } +} diff --git a/sgl-router/tests/common/mock_worker.rs b/sgl-router/tests/common/mock_worker.rs index c5129febc895..3aba2b3b439c 100644 --- a/sgl-router/tests/common/mock_worker.rs +++ b/sgl-router/tests/common/mock_worker.rs @@ -99,9 +99,17 @@ impl MockWorker { // Handler implementations +/// Check if request should fail based on configured fail_rate +async fn should_fail(config: &MockWorkerConfig) -> bool { + rand::random::() < config.fail_rate +} + async fn health_handler(config: web::Data>>) -> HttpResponse { let config = config.read().await; + // Note: We don't apply 
fail_rate to health endpoint to allow workers to be added successfully + // fail_rate is only applied to actual request endpoints + match config.health_status { HealthStatus::Healthy => HttpResponse::Ok().json(json!({ "status": "healthy", @@ -122,6 +130,13 @@ async fn health_handler(config: web::Data>>) -> Htt async fn health_generate_handler(config: web::Data>>) -> HttpResponse { let config = config.read().await; + // Simulate failure based on fail_rate + if should_fail(&config).await { + return HttpResponse::InternalServerError().json(json!({ + "error": "Random failure for testing" + })); + } + if matches!(config.health_status, HealthStatus::Healthy) { HttpResponse::Ok().json(json!({ "status": "ok", @@ -138,6 +153,13 @@ async fn health_generate_handler(config: web::Data> async fn server_info_handler(config: web::Data>>) -> HttpResponse { let config = config.read().await; + // Simulate failure based on fail_rate + if should_fail(&config).await { + return HttpResponse::InternalServerError().json(json!({ + "error": "Random failure for testing" + })); + } + // Return response matching actual sglang server implementation HttpResponse::Ok().json(json!({ // Server args fields @@ -182,7 +204,16 @@ async fn server_info_handler(config: web::Data>>) - })) } -async fn model_info_handler(_config: web::Data>>) -> HttpResponse { +async fn model_info_handler(config: web::Data>>) -> HttpResponse { + let config = config.read().await; + + // Simulate failure based on fail_rate + if should_fail(&config).await { + return HttpResponse::InternalServerError().json(json!({ + "error": "Random failure for testing" + })); + } + // Return response matching actual sglang server implementation HttpResponse::Ok().json(json!({ "model_path": "mock-model-path", @@ -205,7 +236,7 @@ async fn generate_handler( let config = config.read().await; // Simulate failure based on fail_rate - if rand::random::() < config.fail_rate { + if should_fail(&config).await { return 
HttpResponse::InternalServerError().json(json!({ "error": "Random failure for testing" })); @@ -229,7 +260,10 @@ async fn generate_handler( tokio::spawn(async move { let tokens = vec!["This ", "is ", "a ", "mock ", "response."]; - let timestamp_start = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs_f64(); + let timestamp_start = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs_f64(); for (i, token) in tokens.iter().enumerate() { let chunk = json!({ @@ -248,7 +282,14 @@ async fn generate_handler( } }); - if tx.send(format!("data: {}\n\n", serde_json::to_string(&chunk).unwrap())).await.is_err() { + if tx + .send(format!( + "data: {}\n\n", + serde_json::to_string(&chunk).unwrap() + )) + .await + .is_err() + { break; } @@ -269,7 +310,6 @@ async fn generate_handler( } else { // Return non-streaming response matching sglang format let request_id = format!("mock-req-{}", rand::random::()); - let timestamp_start = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs_f64(); HttpResponse::Ok().json(json!({ "text": "Mock generated response for the input", @@ -567,7 +607,16 @@ async fn completions_handler( } } -async fn flush_cache_handler(_config: web::Data>>) -> HttpResponse { +async fn flush_cache_handler(config: web::Data>>) -> HttpResponse { + let config = config.read().await; + + // Simulate failure based on fail_rate + if should_fail(&config).await { + return HttpResponse::InternalServerError().json(json!({ + "error": "Random failure for testing" + })); + } + HttpResponse::Ok().json(json!({ "status": "success", "message": "Cache flushed", @@ -575,7 +624,16 @@ async fn flush_cache_handler(_config: web::Data>>) })) } -async fn v1_models_handler(_config: web::Data>>) -> HttpResponse { +async fn v1_models_handler(config: web::Data>>) -> HttpResponse { + let config = config.read().await; + + // Simulate failure based on fail_rate + if should_fail(&config).await { + return HttpResponse::InternalServerError().json(json!({ + 
"error": "Random failure for testing" + })); + } + HttpResponse::Ok().json(json!({ "object": "list", "data": [{ From a167fd0bcb9ef4b0f4331a109e40c8cdc770b026 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Thu, 24 Jul 2025 14:38:30 +0800 Subject: [PATCH 116/396] [code style] Clean dead triton kernel code in fused_moe and useless vllm_ops import (#8310) --- .../layers/moe/fused_moe_triton/fused_moe.py | 249 ++---------------- .../compressed_tensors_moe.py | 11 +- .../sglang/srt/layers/quantization/utils.py | 9 - 3 files changed, 27 insertions(+), 242 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 9c13c7e9dcb5..267b594c0a7b 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -53,9 +53,7 @@ from aiter import moe_sum except ImportError: raise ImportError("aiter is required when SGLANG_USE_AITER is set to True") -else: - from vllm import _custom_ops as vllm_ops - from vllm._custom_ops import scaled_fp8_quant + if _is_cuda or _is_hip: from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size @@ -63,9 +61,6 @@ logger = logging.getLogger(__name__) padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0 -enable_moe_align_block_size_triton = bool( - int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0")) -) @triton.jit @@ -533,190 +528,6 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) -@triton.jit -def moe_align_block_size_stage1( - topk_ids_ptr, - tokens_cnts_ptr, - num_experts: tl.constexpr, - numel: tl.constexpr, - tokens_per_thread: tl.constexpr, -): - pid = tl.program_id(0) - - start_idx = pid * tokens_per_thread - - off_c = (pid + 1) * num_experts - - for i in range(tokens_per_thread): - if start_idx + i < numel: - idx = tl.load(topk_ids_ptr + start_idx + i) - token_cnt = 
tl.load(tokens_cnts_ptr + off_c + idx) - tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) - - -@triton.jit -def moe_align_block_size_stage2( - tokens_cnts_ptr, - num_experts: tl.constexpr, -): - pid = tl.program_id(0) - - last_cnt = 0 - for i in range(1, num_experts + 1): - token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) - last_cnt = last_cnt + token_cnt - tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) - - -@triton.jit -def moe_align_block_size_stage3( - total_tokens_post_pad_ptr, - tokens_cnts_ptr, - cumsum_ptr, - num_experts: tl.constexpr, - block_size: tl.constexpr, -): - last_cumsum = 0 - off_cnt = num_experts * num_experts - for i in range(1, num_experts + 1): - token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) - last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size - tl.store(cumsum_ptr + i, last_cumsum) - tl.store(total_tokens_post_pad_ptr, last_cumsum) - - -@triton.jit -def moe_align_block_size_stage4( - topk_ids_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - tokens_cnts_ptr, - cumsum_ptr, - num_experts: tl.constexpr, - block_size: tl.constexpr, - numel: tl.constexpr, - tokens_per_thread: tl.constexpr, -): - pid = tl.program_id(0) - start_idx = tl.load(cumsum_ptr + pid) - end_idx = tl.load(cumsum_ptr + pid + 1) - - for i in range(start_idx, end_idx, block_size): - tl.store(expert_ids_ptr + i // block_size, pid) - - start_idx = pid * tokens_per_thread - off_t = pid * num_experts - - for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): - expert_id = tl.load(topk_ids_ptr + i) - token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) - rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) - tl.store(sorted_token_ids_ptr + rank_post_pad, i) - tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) - - -def moe_align_block_size_triton( - topk_ids: torch.Tensor, - num_experts: int, - block_size: int, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - 
num_tokens_post_pad: torch.Tensor, -) -> None: - numel = topk_ids.numel() - grid = (num_experts,) - tokens_cnts = torch.zeros( - (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device - ) - cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) - tokens_per_thread = ceil_div(numel, num_experts) - - moe_align_block_size_stage1[grid]( - topk_ids, - tokens_cnts, - num_experts, - numel, - tokens_per_thread, - ) - moe_align_block_size_stage2[grid]( - tokens_cnts, - num_experts, - ) - moe_align_block_size_stage3[(1,)]( - num_tokens_post_pad, - tokens_cnts, - cumsum, - num_experts, - block_size, - ) - moe_align_block_size_stage4[grid]( - topk_ids, - sorted_token_ids, - expert_ids, - tokens_cnts, - cumsum, - num_experts, - block_size, - numel, - tokens_per_thread, - ) - - -@triton.jit -def init_sorted_ids_and_cumsum_buffer_kernel( - sorted_ids_ptr, - cumsum_buffer_ptr, - max_num_tokens_padded, - topk_ids_numel, - num_experts: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - ALIGNED_NUM_EXPERTS_P1: tl.constexpr, -): - pid = tl.program_id(0) - offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - - sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE) - - if pid < sorted_ids_blocks: - mask = offsets < max_num_tokens_padded - tl.store( - sorted_ids_ptr + offsets, - tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32), - mask=mask, - ) - elif pid == sorted_ids_blocks: - offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1) - mask_e = offset_e < num_experts + 1 - tl.store( - cumsum_buffer_ptr + offset_e, - tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32), - mask=mask_e, - ) - - -def init_sorted_ids_and_cumsum_buffer( - max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda" -): - sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device) - cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device) - - BLOCK_SIZE = 1024 - sorted_ids_blocks = 
triton.cdiv(max_num_tokens_padded, BLOCK_SIZE) - grid = (sorted_ids_blocks + 1,) - - init_sorted_ids_and_cumsum_buffer_kernel[grid]( - sorted_ids, - cumsum_buffer, - max_num_tokens_padded, - topk_ids_numel, - num_experts, - BLOCK_SIZE, - next_power_of_2(num_experts + 1), - ) - - return sorted_ids, cumsum_buffer - - def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -766,42 +577,32 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - if enable_moe_align_block_size_triton: - sorted_ids.fill_(topk_ids.numel()) - moe_align_block_size_triton( - topk_ids, - num_experts, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - ) - else: - cumsum_buffer = torch.empty( - (num_experts + 1,), dtype=torch.int32, device=topk_ids.device - ) - token_cnts_buffer = torch.empty( - (num_experts + 1) * num_experts, - dtype=torch.int32, - device=topk_ids.device, - ) - # Threshold based on benchmark results - fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096 - if not fuse_sorted_ids_padding: - sorted_ids.fill_(topk_ids.numel()) + cumsum_buffer = torch.empty( + (num_experts + 1,), dtype=torch.int32, device=topk_ids.device + ) + token_cnts_buffer = torch.empty( + (num_experts + 1) * num_experts, + dtype=torch.int32, + device=topk_ids.device, + ) - sgl_moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - token_cnts_buffer, - cumsum_buffer, - fuse_sorted_ids_padding, - ) + # Threshold based on benchmark results + fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096 + if not fuse_sorted_ids_padding: + sorted_ids.fill_(topk_ids.numel()) + + sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + token_cnts_buffer, + cumsum_buffer, + 
fuse_sorted_ids_padding, + ) return sorted_ids, expert_ids, num_tokens_post_pad diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index af1f6cbf7cc2..525a75069fe0 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -28,15 +28,6 @@ CompressedTensorsConfig, ) -_is_cuda = is_cuda() -_is_npu = is_npu() -_is_cpu_amx_available = cpu_has_amx_support() -_is_cpu = is_cpu() -_is_hip = is_hip() - -if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip): - from vllm import _custom_ops as vllm_ops - from vllm._custom_ops import scaled_fp8_quant try: import vllm @@ -568,6 +559,8 @@ def marlin_moe_permute_scales( requires_grad=False, ) + from vllm import _custom_ops as vllm_ops + marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack( layer.w13_weight_packed, layer.w13_g_idx_sort_indices, diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 8904247a6a8f..9b19e0309047 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -17,15 +17,6 @@ if TYPE_CHECKING: from sglang.srt.layers.quantization.base_config import QuantizationConfig -_is_cuda = is_cuda() -_is_npu = is_npu() -_is_cpu_amx_available = cpu_has_amx_support() -_is_cpu = is_cpu() -_is_hip = is_hip() - -if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip): - from vllm._custom_ops import scaled_fp8_quant - def is_layer_skipped( prefix: str, From 8d1c5b948ed095fab7e0d4c0a7d31855d8fb8c0b Mon Sep 17 00:00:00 2001 From: Swipe4057 <106391009+Swipe4057@users.noreply.github.com> Date: Fri, 25 Jul 2025 01:29:56 +0400 Subject: [PATCH 117/396] chore: upgrade flashinfer v0.2.9rc1 (#8301) Co-authored-by: Yineng 
Zhang --- python/pyproject.toml | 4 ++-- python/sglang/srt/entrypoints/engine.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 1cf32215d9ef..7a18ee94ddaf 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -60,7 +60,7 @@ srt = [ "torchvision==0.22.1", "cuda-python", "einops", - "flashinfer_python==0.2.7.post1", + "flashinfer_python==0.2.9rc1", ] blackwell = [ @@ -71,7 +71,7 @@ blackwell = [ "torchvision==0.22.1", "cuda-python", "einops", - "flashinfer_python==0.2.7.post1", + "flashinfer_python==0.2.9rc1", ] # HIP (Heterogeneous-computing Interface for Portability) for AMD diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index edf81a79a098..fd59624bcb56 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -646,7 +646,7 @@ def _set_envs_and_config(server_args: ServerArgs): if server_args.attention_backend == "flashinfer": assert_pkg_version( "flashinfer_python", - "0.2.7.post1", + "0.2.9rc1", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", From 33c4b4d04e50db11ebd1a81b37217da97a379044 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Thu, 24 Jul 2025 14:30:27 -0700 Subject: [PATCH 118/396] [router] add streaming unit test (#8299) --- sgl-router/tests/streaming_tests.rs | 579 ++++++++++++++++++++++++++++ 1 file changed, 579 insertions(+) create mode 100644 sgl-router/tests/streaming_tests.rs diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs new file mode 100644 index 000000000000..47a1326ae575 --- /dev/null +++ b/sgl-router/tests/streaming_tests.rs @@ -0,0 +1,579 @@ +mod common; + +use actix_web::{http::StatusCode, rt::System, test as actix_test, web, App}; +use bytes::Bytes; +use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, 
WorkerType}; +use reqwest::Client; +use serde_json::json; +use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; +use sglang_router_rs::server::{ + add_worker, generate, list_workers, v1_chat_completions, v1_completions, AppState, +}; +use std::time::Instant; + +/// Test context for streaming tests +struct StreamingTestContext { + workers: Vec, + app_state: web::Data, +} + +impl StreamingTestContext { + async fn new(worker_configs: Vec) -> Self { + let mut workers = Vec::new(); + let mut worker_urls = Vec::new(); + + // Start mock workers + for config in worker_configs { + let mut worker = MockWorker::new(config); + let url = worker.start().await.unwrap(); + worker_urls.push(url); + workers.push(worker); + } + + // Give workers time to start + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + // Create router config with empty worker URLs initially + // We'll add workers via the /add_worker endpoint + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec![], + }, + policy: PolicyConfig::Random, + host: "127.0.0.1".to_string(), + port: 3003, + max_payload_size: 256 * 1024 * 1024, + request_timeout_secs: 600, + worker_startup_timeout_secs: 1, + worker_startup_check_interval_secs: 1, + discovery: None, + metrics: None, + log_dir: None, + log_level: None, + }; + + let client = Client::builder() + .timeout(std::time::Duration::from_secs(config.request_timeout_secs)) + .build() + .unwrap(); + + let app_state = AppState::new(config, client).unwrap(); + let app_state = web::Data::new(app_state); + + // Add workers via HTTP API + let app = + actix_test::init_service(App::new().app_data(app_state.clone()).service(add_worker)) + .await; + + for url in &worker_urls { + let req = actix_test::TestRequest::post() + .uri(&format!("/add_worker?url={}", url)) + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert!(resp.status().is_success()); + } + + 
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + Self { workers, app_state } + } + + async fn create_app( + &self, + ) -> impl actix_web::dev::Service< + actix_http::Request, + Response = actix_web::dev::ServiceResponse, + Error = actix_web::Error, + > { + actix_test::init_service( + App::new() + .app_data(self.app_state.clone()) + .service(generate) + .service(v1_chat_completions) + .service(v1_completions) + .service(list_workers), + ) + .await + } + + async fn shutdown(mut self) { + for worker in &mut self.workers { + worker.stop().await; + } + } +} + +/// Parse SSE (Server-Sent Events) from response body +async fn parse_sse_stream(body: Bytes) -> Vec { + let text = String::from_utf8_lossy(&body); + let mut events = Vec::new(); + + for line in text.lines() { + if line.starts_with("data: ") { + let data = &line[6..]; + if data == "[DONE]" { + continue; + } + if let Ok(json) = serde_json::from_str::(data) { + events.push(json); + } + } + } + + events +} + +#[cfg(test)] +mod basic_streaming_tests { + use super::*; + + #[test] + fn test_router_uses_mock_workers() { + System::new().block_on(async { + let ctx = StreamingTestContext::new(vec![MockWorkerConfig { + port: 19000, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // Verify workers are registered with the router + let req = actix_test::TestRequest::get() + .uri("/list_workers") + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = actix_test::read_body_json(resp).await; + let urls = body["urls"].as_array().unwrap(); + assert_eq!(urls.len(), 1); + assert!(urls[0].as_str().unwrap().contains("19000")); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_generate_streaming() { + System::new().block_on(async { + let ctx = 
StreamingTestContext::new(vec![MockWorkerConfig { + port: 19001, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "text": "Hello, streaming world!", + "stream": true, + "max_new_tokens": 50 + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + // Check content type + let content_type = resp.headers().get("content-type").unwrap(); + assert_eq!(content_type, "text/event-stream"); + + // Read streaming body + let body = actix_test::read_body(resp).await; + let events = parse_sse_stream(body).await; + + // Verify we got multiple chunks + assert!(events.len() > 1); + + // Verify first chunk has text + assert!(events[0].get("text").is_some()); + + // Verify last chunk has finish_reason in meta_info + let last_event = events.last().unwrap(); + assert!(last_event.get("meta_info").is_some()); + let meta_info = &last_event["meta_info"]; + assert!(meta_info.get("finish_reason").is_some()); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_chat_completion_streaming() { + System::new().block_on(async { + let ctx = StreamingTestContext::new(vec![MockWorkerConfig { + port: 19002, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello, streaming!"} + ], + "stream": true + }); + + let req = actix_test::TestRequest::post() + .uri("/v1/chat/completions") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + 
resp.headers().get("content-type").unwrap(), + "text/event-stream" + ); + + let body = actix_test::read_body(resp).await; + let events = parse_sse_stream(body).await; + + // Verify we got streaming events + // Note: Mock doesn't provide full OpenAI format, just verify we got chunks + assert!(!events.is_empty(), "Should have received streaming events"); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_completion_streaming() { + System::new().block_on(async { + let ctx = StreamingTestContext::new(vec![MockWorkerConfig { + port: 19003, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "model": "test-model", + "prompt": "Once upon a time", + "stream": true, + "max_tokens": 30 + }); + + let req = actix_test::TestRequest::post() + .uri("/v1/completions") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers().get("content-type").unwrap(), + "text/event-stream" + ); + + let _body = actix_test::read_body(resp).await; + + ctx.shutdown().await; + }); + } +} + +#[cfg(test)] +mod streaming_performance_tests { + use super::*; + + #[test] + fn test_streaming_first_token_latency() { + System::new().block_on(async { + let ctx = StreamingTestContext::new(vec![MockWorkerConfig { + port: 19010, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 10, // Small delay to simulate processing + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "text": "Measure latency", + "stream": true + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let start = Instant::now(); + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + 
+ // Note: actix_test framework doesn't provide easy access to streaming chunks. + // The ideal solution would be to: + // 1. Start the router as a real HTTP server + // 2. Use reqwest::Client to make streaming requests + // 3. Measure time to first chunk properly + // + // For now, we verify that streaming responses work correctly, + // but cannot accurately measure TTFT with actix_test. + let body = actix_test::read_body(resp).await; + let total_time = start.elapsed(); + + // Verify we got streaming data + let events = parse_sse_stream(body).await; + assert!(!events.is_empty(), "Should receive streaming events"); + + // With mock worker delay of 10ms, total time should still be reasonable + assert!( + total_time.as_millis() < 1000, + "Total response took {}ms", + total_time.as_millis() + ); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_concurrent_streaming_requests() { + System::new().block_on(async { + // Test basic concurrent streaming functionality + let ctx = StreamingTestContext::new(vec![ + MockWorkerConfig { + port: 19050, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }, + MockWorkerConfig { + port: 19051, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }, + ]) + .await; + + let app = ctx.create_app().await; + + // Send a moderate number of concurrent requests for unit testing + use futures::future::join_all; + let mut futures = Vec::new(); + + for i in 0..20 { + let app_ref = &app; + let future = async move { + let payload = json!({ + "text": format!("Concurrent request {}", i), + "stream": true, + "max_new_tokens": 5 + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(app_ref, req).await; + resp.status() == StatusCode::OK + }; + + futures.push(future); + } + + let results = join_all(futures).await; + 
let successful = results.iter().filter(|&&r| r).count(); + + // All requests should succeed in a unit test environment + assert_eq!( + successful, 20, + "Expected all 20 requests to succeed, got {}", + successful + ); + + ctx.shutdown().await; + }); + } + + // Note: Extreme load testing has been moved to benches/streaming_load_test.rs + // Run with: cargo run --release --bin streaming_load_test 10000 10 + // Or: cargo bench streaming_load_test +} + +#[cfg(test)] +mod streaming_error_tests { + use super::*; + + #[test] + fn test_streaming_with_worker_failure() { + System::new().block_on(async { + let ctx = StreamingTestContext::new(vec![MockWorkerConfig { + port: 19020, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 1.0, // Always fail + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "text": "This should fail", + "stream": true + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_streaming_with_invalid_payload() { + System::new().block_on(async { + let ctx = StreamingTestContext::new(vec![MockWorkerConfig { + port: 19021, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + // Missing required fields + "stream": true + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + // TODO: Router should validate payload and reject requests with missing content fields + // Currently, the router accepts requests with no prompt/text/input_ids which is a bug + // This should return 
StatusCode::BAD_REQUEST once proper validation is implemented + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } +} + +#[cfg(test)] +mod streaming_content_tests { + use super::*; + + #[test] + fn test_unicode_streaming() { + System::new().block_on(async { + let ctx = StreamingTestContext::new(vec![MockWorkerConfig { + port: 19030, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "text": "Test Unicode: 你好世界 🌍 émojis", + "stream": true + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body = actix_test::read_body(resp).await; + let events = parse_sse_stream(body).await; + + // Verify events were parsed correctly (Unicode didn't break parsing) + assert!(!events.is_empty()); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_incremental_text_building() { + System::new().block_on(async { + let ctx = StreamingTestContext::new(vec![MockWorkerConfig { + port: 19031, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "text": "Build text incrementally", + "stream": true + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body = actix_test::read_body(resp).await; + let events = parse_sse_stream(body).await; + + // Build complete text from chunks + let mut complete_text = String::new(); + for event in &events { + if let Some(text) = event.get("text").and_then(|t| t.as_str()) { + complete_text.push_str(text); + 
} + } + + // Verify we got some text + assert!(!complete_text.is_empty()); + + ctx.shutdown().await; + }); + } +} From 39fe1e880d55157179ff8e57d8fe385ef03d51e8 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Thu, 24 Jul 2025 14:30:37 -0700 Subject: [PATCH 119/396] [router] add request format unit test (#8300) --- sgl-router/tests/request_formats_test.rs | 573 +++++++++++++++++++++++ 1 file changed, 573 insertions(+) create mode 100644 sgl-router/tests/request_formats_test.rs diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs new file mode 100644 index 000000000000..40045a0f7b15 --- /dev/null +++ b/sgl-router/tests/request_formats_test.rs @@ -0,0 +1,573 @@ +mod common; + +use actix_web::{http::StatusCode, rt::System, test as actix_test, web, App}; +use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; +use reqwest::Client; +use serde_json::json; +use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode}; +use sglang_router_rs::server::{ + add_worker, generate, v1_chat_completions, v1_completions, AppState, +}; + +/// Test context for request type testing +struct RequestTestContext { + workers: Vec, + app_state: web::Data, +} + +impl RequestTestContext { + async fn new(worker_configs: Vec) -> Self { + let mut workers = Vec::new(); + let mut worker_urls = Vec::new(); + + // Start mock workers + for config in worker_configs { + let mut worker = MockWorker::new(config); + let url = worker.start().await.unwrap(); + worker_urls.push(url); + workers.push(worker); + } + + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + + // Create router config + let config = RouterConfig { + mode: RoutingMode::Regular { + worker_urls: vec![], + }, + policy: PolicyConfig::Random, + host: "127.0.0.1".to_string(), + port: 3006, + max_payload_size: 256 * 1024 * 1024, + request_timeout_secs: 600, + worker_startup_timeout_secs: 1, + worker_startup_check_interval_secs: 1, + discovery: None, 
+ metrics: None, + log_dir: None, + log_level: None, + }; + + let client = Client::builder() + .timeout(std::time::Duration::from_secs(config.request_timeout_secs)) + .build() + .unwrap(); + + let app_state = AppState::new(config, client).unwrap(); + let app_state = web::Data::new(app_state); + + // Add workers via HTTP API + let app = + actix_test::init_service(App::new().app_data(app_state.clone()).service(add_worker)) + .await; + + for url in &worker_urls { + let req = actix_test::TestRequest::post() + .uri(&format!("/add_worker?url={}", url)) + .to_request(); + let resp = actix_test::call_service(&app, req).await; + assert!(resp.status().is_success()); + } + + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + + Self { workers, app_state } + } + + async fn create_app( + &self, + ) -> impl actix_web::dev::Service< + actix_http::Request, + Response = actix_web::dev::ServiceResponse, + Error = actix_web::Error, + > { + actix_test::init_service( + App::new() + .app_data(self.app_state.clone()) + .service(generate) + .service(v1_chat_completions) + .service(v1_completions), + ) + .await + } + + async fn shutdown(mut self) { + for worker in &mut self.workers { + worker.stop().await; + } + } +} + +#[cfg(test)] +mod generate_input_format_tests { + use super::*; + + #[test] + fn test_generate_with_text_input() { + System::new().block_on(async { + let ctx = RequestTestContext::new(vec![MockWorkerConfig { + port: 21001, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // Standard text input + let payload = json!({ + "text": "Hello world", + "stream": false + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = 
actix_test::read_body_json(resp).await; + assert!(body.get("text").is_some()); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_generate_with_prompt_input() { + System::new().block_on(async { + let ctx = RequestTestContext::new(vec![MockWorkerConfig { + port: 21002, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // Prompt input (alternative to text) + let payload = json!({ + "prompt": "Once upon a time", + "stream": false + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_generate_with_input_ids() { + System::new().block_on(async { + let ctx = RequestTestContext::new(vec![MockWorkerConfig { + port: 21003, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // Input IDs (tokenized input) + let payload = json!({ + "input_ids": [1, 2, 3, 4, 5], + "stream": false + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_generate_with_all_parameters() { + System::new().block_on(async { + let ctx = RequestTestContext::new(vec![MockWorkerConfig { + port: 21004, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // All generation parameters + let payload = json!({ + "text": "Complete this", + "temperature": 0.7, + "top_p": 0.9, + "top_k": 50, + "max_new_tokens": 100, 
+ "min_new_tokens": 10, + "frequency_penalty": 0.5, + "presence_penalty": 0.3, + "repetition_penalty": 1.1, + "stop": [".", "!", "?"], + "stream": false + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } +} + +#[cfg(test)] +mod chat_completion_format_tests { + use super::*; + + #[test] + fn test_chat_with_system_message() { + System::new().block_on(async { + let ctx = RequestTestContext::new(vec![MockWorkerConfig { + port: 21010, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "model": "test-model", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"} + ] + }); + + let req = actix_test::TestRequest::post() + .uri("/v1/chat/completions") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } + + // Note: Function calling and tools tests are commented out because + // they require special handling in the mock worker that's not implemented yet. + // In production, these would be forwarded to the actual model. 
+ + // #[test] + // fn test_chat_with_function_calling() { + // // Test would go here when mock worker supports function calling + // } + + // #[test] + // fn test_chat_with_tools() { + // // Test would go here when mock worker supports tools + // } + + #[test] + fn test_chat_with_response_format() { + System::new().block_on(async { + let ctx = RequestTestContext::new(vec![MockWorkerConfig { + port: 21013, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Return JSON"} + ], + "response_format": { + "type": "json_object" + } + }); + + let req = actix_test::TestRequest::post() + .uri("/v1/chat/completions") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } +} + +#[cfg(test)] +mod completion_format_tests { + use super::*; + + #[test] + fn test_completion_with_single_prompt() { + System::new().block_on(async { + let ctx = RequestTestContext::new(vec![MockWorkerConfig { + port: 21020, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "model": "test-model", + "prompt": "Once upon a time", + "max_tokens": 50 + }); + + let req = actix_test::TestRequest::post() + .uri("/v1/completions") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = actix_test::read_body_json(resp).await; + assert!(body.get("choices").is_some()); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_completion_with_batch_prompts() { + System::new().block_on(async { + let ctx = 
RequestTestContext::new(vec![MockWorkerConfig { + port: 21021, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "model": "test-model", + "prompt": ["First prompt", "Second prompt", "Third prompt"], + "max_tokens": 30 + }); + + let req = actix_test::TestRequest::post() + .uri("/v1/completions") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_completion_with_echo() { + System::new().block_on(async { + let ctx = RequestTestContext::new(vec![MockWorkerConfig { + port: 21022, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "model": "test-model", + "prompt": "Echo this prompt", + "echo": true, + "max_tokens": 20 + }); + + let req = actix_test::TestRequest::post() + .uri("/v1/completions") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_completion_with_logprobs() { + System::new().block_on(async { + let ctx = RequestTestContext::new(vec![MockWorkerConfig { + port: 21023, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "model": "test-model", + "prompt": "Calculate probability", + "logprobs": 5, + "max_tokens": 10 + }); + + let req = actix_test::TestRequest::post() + .uri("/v1/completions") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); 
+ + ctx.shutdown().await; + }); + } + + #[test] + fn test_completion_with_suffix() { + System::new().block_on(async { + let ctx = RequestTestContext::new(vec![MockWorkerConfig { + port: 21024, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "model": "test-model", + "prompt": "Insert text here: ", + "suffix": " and continue from here.", + "max_tokens": 20 + }); + + let req = actix_test::TestRequest::post() + .uri("/v1/completions") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } +} + +#[cfg(test)] +mod stop_sequence_tests { + use super::*; + + #[test] + fn test_stop_sequences_array() { + System::new().block_on(async { + let ctx = RequestTestContext::new(vec![MockWorkerConfig { + port: 21030, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "text": "Generate until stop", + "stop": [".", "!", "?", "\n"], + "stream": false + }); + + let req = actix_test::TestRequest::post() + .uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } + + #[test] + fn test_stop_sequences_string() { + System::new().block_on(async { + let ctx = RequestTestContext::new(vec![MockWorkerConfig { + port: 21031, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "text": "Generate until stop", + "stop": "\n\n", + "stream": false + }); + + let req = actix_test::TestRequest::post() + 
.uri("/generate") + .set_json(&payload) + .to_request(); + + let resp = actix_test::call_service(&app, req).await; + assert_eq!(resp.status(), StatusCode::OK); + + ctx.shutdown().await; + }); + } +} From 145482f422117eb5710bd2052679f0ceab8444f5 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Thu, 24 Jul 2025 17:31:47 -0700 Subject: [PATCH 120/396] HiCache Storage TP Refinement (#8307) Co-authored-by: pansicheng --- .../sglang/srt/managers/cache_controller.py | 58 +++++++++++++++++-- .../sglang/srt/mem_cache/hicache_storage.py | 18 +++++- python/sglang/srt/mem_cache/hiradix_cache.py | 46 ++++++++++----- .../sglang/srt/mem_cache/memory_pool_host.py | 3 + 4 files changed, 102 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index a94fdec78c32..9ef860f632c6 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -219,6 +219,7 @@ def __init__( token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, mem_pool_host: HostKVCache, page_size: int, + tp_group: torch.distributed.ProcessGroup, load_cache_event: threading.Event = None, write_policy: str = "write_through_selective", io_backend: str = "", @@ -244,11 +245,17 @@ def __init__( self.enable_storage = False # todo: move backend initialization to storage backend module if storage_backend is not None: + # create a new communication group for synchronizing storage operations across TP workers + self.tp_world_size = torch.distributed.get_world_size(group=tp_group) + if self.tp_world_size > 1: + group_ranks = torch.distributed.get_process_group_ranks(tp_group) + self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo") + if storage_backend == "file": self.storage_backend = HiCacheFile() self.enable_storage = True # todo: threshold policy for prefetching - self.prefetch_threshold = prefetch_threshold + self.prefetch_threshold = max(prefetch_threshold, 
self.page_size) else: raise NotImplementedError( f"Unsupported storage backend: {storage_backend}" @@ -568,13 +575,32 @@ def prefetch_thread_func(self): else: break + if self.tp_world_size > 1: + storage_hit_count_tensor = torch.tensor( + storage_hit_count, dtype=torch.int + ) + torch.distributed.all_reduce( + storage_hit_count_tensor, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_group, + ) + storage_hit_count = storage_hit_count_tensor.item() + if storage_hit_count < self.prefetch_threshold: # not to prefetch if not enough benefits self.prefetch_revoke_queue.put(operation.request_id) + logger.debug( + f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})." + ) else: - operation.hash_value = hash_value + operation.hash_value = hash_value[ + : (storage_hit_count // self.page_size) + ] + # free the pre-allocated memory for pages that are not hit + self.mem_pool_host.free(operation.host_indices[storage_hit_count:]) + operation.host_indices = operation.host_indices[:storage_hit_count] logger.debug( - f"Prefetching {len(hash_value)} pages for request {operation.request_id}." + f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}." 
) self.prefetch_buffer.put(operation) @@ -611,17 +637,37 @@ def backup_thread_func(self): last_hash = get_hash_str( tokens_to_backup[i : i + self.page_size], last_hash ) - # todo, handle failures in storage backend - self.storage_backend.set( + success = self.storage_backend.set( last_hash, self.mem_pool_host.get_flat_data_page( operation.host_indices[i] ), ) + if not success: + logger.warning(f"Failed to write page {last_hash} to storage.") + break operation.completed_tokens += self.page_size operation.hash_value.append(last_hash) - self.ack_backup_queue.put((operation.id, operation.hash_value)) + min_completed_tokens = operation.completed_tokens + if self.tp_world_size > 1: + completed_tokens_tensor = torch.tensor( + min_completed_tokens, dtype=torch.int + ) + torch.distributed.all_reduce( + completed_tokens_tensor, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_group, + ) + min_completed_tokens = completed_tokens_tensor.item() + + self.ack_backup_queue.put( + ( + operation.id, + operation.hash_value[: min_completed_tokens // self.page_size], + min_completed_tokens, + ) + ) except Empty: continue diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index 1dfe661ab5c9..45b26d10008b 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -9,6 +9,12 @@ logger = logging.getLogger(__name__) +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + + def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str: hasher = hashlib.sha256() @@ -80,13 +86,20 @@ class HiCacheFile(HiCacheStorage): def __init__(self, file_path: str = "/tmp/hicache"): self.file_path = file_path - if not os.path.exists(self.file_path): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else "" + if 
not os.path.exists(self.file_path) and tp_rank == 0: os.makedirs(self.file_path) logger.info(f"Created HiCacheFile storage directory at {self.file_path}") + def _get_suffixed_key(self, key: str) -> str: + return key + self.tp_suffix + def get( self, key: str, target_location: Optional[torch.Tensor] = None ) -> torch.Tensor | None: + key = self._get_suffixed_key(key) tensor_path = os.path.join(self.file_path, f"{key}.bin") try: # todo: fixing the target_location logic to enable in-place loading @@ -112,6 +125,7 @@ def batch_get( ] def set(self, key: str, value: torch.Tensor) -> bool: + key = self._get_suffixed_key(key) tensor_path = os.path.join(self.file_path, f"{key}.bin") if self.exists(key): logger.debug(f"Key {key} already exists. Skipped.") @@ -130,10 +144,12 @@ def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool: return True def exists(self, key: str) -> bool: + key = self._get_suffixed_key(key) tensor_path = os.path.join(self.file_path, f"{key}.bin") return os.path.exists(tensor_path) def delete(self, key: str) -> None: + key = self._get_suffixed_key(key) tensor_path = os.path.join(self.file_path, f"{key}.bin") try: os.remove(tensor_path) diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 796f0553ceca..05248a1deb22 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -50,6 +50,7 @@ def __init__( raise ValueError(f"HiRadixCache only supports MHA and MLA yet") self.tp_group = tp_cache_group + self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) self.enable_storage = hicache_storage_backend is not None # todo: customizable storage prefetch threshold self.prefetch_threshold = 256 @@ -59,6 +60,7 @@ def __init__( token_to_kv_pool_allocator, self.token_to_kv_pool_host, page_size, + self.tp_group, load_cache_event=self.load_cache_event, write_policy=hicache_write_policy, io_backend=hicache_io_backend, @@ 
-153,7 +155,7 @@ def writing_check(self, write_back=False): queue_size = torch.tensor( self.cache_controller.ack_write_queue.qsize(), dtype=torch.int ) - if torch.distributed.get_world_size(group=self.tp_group) > 1: + if self.tp_world_size > 1: # synchrnoize TP workers to make the same update to radix cache torch.distributed.all_reduce( queue_size, @@ -353,7 +355,7 @@ def check_revoked_prefetch(self): queue_size = torch.tensor( self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int ) - if torch.distributed.get_world_size(group=self.tp_group) > 1: + if self.tp_world_size > 1: # synchrnoize TP workers to make the same update to hiradix cache torch.distributed.all_reduce( queue_size, @@ -372,7 +374,7 @@ def check_backup_progress(self): queue_size = torch.tensor( self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int ) - if torch.distributed.get_world_size(group=self.tp_group) > 1: + if self.tp_world_size > 1: # synchrnoize TP workers to make the same update to hiradix cache torch.distributed.all_reduce( queue_size, @@ -380,9 +382,15 @@ def check_backup_progress(self): group=self.tp_group, ) for _ in range(queue_size.item()): - ack_id, hash_value = self.cache_controller.ack_backup_queue.get() - self.ongoing_backup[ack_id].hash_value = hash_value - self.ongoing_backup[ack_id].release_host() + ack_id, hash_value, completed_tokens = ( + self.cache_controller.ack_backup_queue.get() + ) + host_node = self.ongoing_backup[ack_id] + if completed_tokens < len(host_node.key): + # backup is only partially successful, split the node + new_node = self._split_node(host_node.key, host_node, completed_tokens) + new_node.hash_value = hash_value + host_node.release_host() del self.ongoing_backup[ack_id] def check_prefetch_progress(self, req_id: str): @@ -400,15 +408,18 @@ def check_prefetch_progress(self, req_id: str): ) logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens") - min_completed_tokens = torch.tensor(completed_tokens, 
dtype=torch.int) - if torch.distributed.get_world_size(group=self.tp_group) > 1: + min_completed_tokens = completed_tokens + if self.tp_world_size > 1: # synchrnoize TP workers to make the same update to hiradix cache + completed_tokens_tensor = torch.tensor( + min_completed_tokens, dtype=torch.int + ) torch.distributed.all_reduce( - min_completed_tokens, + completed_tokens_tensor, op=torch.distributed.ReduceOp.MIN, group=self.tp_group, ) - min_completed_tokens = min_completed_tokens.item() + min_completed_tokens = completed_tokens_tensor.item() fetched_token_ids = token_ids[:min_completed_tokens] written_indices = host_indices[:min_completed_tokens] matched_length = self._insert_helper_host( @@ -465,16 +476,19 @@ def prefetch_from_storage( new_input_tokens: List[int], last_hash: Optional[str] = None, ): - if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold: + # align the number of fetching tokens to the page size + prefetch_length = len(new_input_tokens) - ( + len(new_input_tokens) % self.page_size + ) + new_input_tokens = new_input_tokens[:prefetch_length] + if not self.enable_storage or prefetch_length < self.prefetch_threshold: return last_host_node.protect_host() - host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens)) + host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length) if host_indices is None: - self.evict_host(len(new_input_tokens)) - host_indices = self.cache_controller.mem_pool_host.alloc( - len(new_input_tokens) - ) + self.evict_host(prefetch_length) + host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length) if host_indices is None: last_host_node.release_host() # no sufficient host memory to prefetch diff --git a/python/sglang/srt/mem_cache/memory_pool_host.py b/python/sglang/srt/mem_cache/memory_pool_host.py index f503479628a9..0116e7141a38 100644 --- a/python/sglang/srt/mem_cache/memory_pool_host.py +++ b/python/sglang/srt/mem_cache/memory_pool_host.py @@ -126,6 
+126,9 @@ def available_size(self): @synchronized() def alloc(self, need_size: int) -> torch.Tensor: + assert ( + need_size % self.page_size == 0 + ), "The requested size should be a multiple of the page size." if need_size > self.available_size(): return None From d40846d456ecc930c04538778ed11f67cc793c23 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Thu, 24 Jul 2025 17:33:17 -0700 Subject: [PATCH 121/396] breakdown kernel update (#8334) --- sgl-kernel/python/sgl_kernel/kvcacheio.py | 114 ++++++++-------------- sgl-kernel/tests/test_kvcacheio.py | 10 +- 2 files changed, 44 insertions(+), 80 deletions(-) diff --git a/sgl-kernel/python/sgl_kernel/kvcacheio.py b/sgl-kernel/python/sgl_kernel/kvcacheio.py index 1440c2ca35ec..83a611dd5873 100644 --- a/sgl-kernel/python/sgl_kernel/kvcacheio.py +++ b/sgl-kernel/python/sgl_kernel/kvcacheio.py @@ -10,30 +10,21 @@ def transfer_kv_per_layer( dst_v: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, - io_backend: str, - page_size: int, item_size: int, block_quota: int = 2, num_warps_per_block: int = 32, ): - if io_backend == "kernel": - torch.ops.sgl_kernel.transfer_kv_per_layer( - src_k, - dst_k, - src_v, - dst_v, - src_indices, - dst_indices, - item_size * src_k.element_size(), # todo, hot fix for compatibility - block_quota, - num_warps_per_block, - ) - elif io_backend == "direct": - torch.ops.sgl_kernel.transfer_kv_direct( - [src_k, src_v], [dst_k, dst_v], src_indices, dst_indices, page_size - ) - else: - raise ValueError(f"Unsupported io backend") + torch.ops.sgl_kernel.transfer_kv_per_layer( + src_k, + dst_k, + src_v, + dst_v, + src_indices, + dst_indices, + item_size, + block_quota, + num_warps_per_block, + ) def transfer_kv_per_layer_pf_lf( @@ -69,29 +60,23 @@ def transfer_kv_all_layer( dst_v_layers: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, - io_backend: str, item_size: int, num_layers: int, block_quota: int = 2, num_warps_per_block: int = 32, ): - if io_backend == 
"kernel": - torch.ops.sgl_kernel.transfer_kv_all_layer( - src_k_layers, - dst_k_layers, - src_v_layers, - dst_v_layers, - src_indices, - dst_indices, - item_size, - num_layers, - block_quota, - num_warps_per_block, - ) - elif io_backend == "direct": - raise NotImplementedError("Deprecated interface") - else: - raise ValueError(f"Unsupported io backend") + torch.ops.sgl_kernel.transfer_kv_all_layer( + src_k_layers, + dst_k_layers, + src_v_layers, + dst_v_layers, + src_indices, + dst_indices, + item_size, + num_layers, + block_quota, + num_warps_per_block, + ) def transfer_kv_all_layer_lf_pf( @@ -139,28 +124,19 @@ def transfer_kv_per_layer_mla( dst: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, - io_backend: str, - page_size: int, item_size: int, block_quota: int = 2, num_warps_per_block: int = 32, ): - if io_backend == "kernel": - torch.ops.sgl_kernel.transfer_kv_per_layer_mla( - src, - dst, - src_indices, - dst_indices, - item_size * src.element_size(), # todo, hot fix for compatibility - block_quota, - num_warps_per_block, - ) - elif io_backend == "direct": - torch.ops.sgl_kernel.transfer_kv_direct( - [src], [dst], src_indices, dst_indices, page_size - ) - else: - raise ValueError(f"Unsupported io backend") + torch.ops.sgl_kernel.transfer_kv_per_layer_mla( + src, + dst, + src_indices, + dst_indices, + item_size, + block_quota, + num_warps_per_block, + ) def transfer_kv_per_layer_mla_pf_lf( @@ -190,27 +166,21 @@ def transfer_kv_all_layer_mla( dst_layers: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, - io_backend: str, item_size: int, num_layers: int, block_quota: int = 2, num_warps_per_block: int = 32, ): - if io_backend == "kernel": - torch.ops.sgl_kernel.transfer_kv_all_layer_mla( - src_layers, - dst_layers, - src_indices, - dst_indices, - item_size, - num_layers, - block_quota, - num_warps_per_block, - ) - elif io_backend == "direct": - raise NotImplementedError("Deprecated interface") - else: - raise 
ValueError(f"Unsupported io backend") + torch.ops.sgl_kernel.transfer_kv_all_layer_mla( + src_layers, + dst_layers, + src_indices, + dst_indices, + item_size, + num_layers, + block_quota, + num_warps_per_block, + ) def transfer_kv_all_layer_mla_lf_pf( diff --git a/sgl-kernel/tests/test_kvcacheio.py b/sgl-kernel/tests/test_kvcacheio.py index 171fc4ca4793..d2b5be111973 100644 --- a/sgl-kernel/tests/test_kvcacheio.py +++ b/sgl-kernel/tests/test_kvcacheio.py @@ -101,9 +101,7 @@ def test_transfer_kv( dst_pool_kernel[layer_idx_to_test], src_indices_device, dst_indices_device, - io_backend="kernel", - page_size=page_size, - item_size=item_size, + item_size=item_size * dtype.itemsize, ) transfer_kv_direct( [src_pool_host[layer_idx_to_test]], @@ -138,7 +136,6 @@ def test_transfer_kv( dst_layers_device, src_indices_device, dst_indices_device, - io_backend="kernel", item_size=item_size * dtype.itemsize, num_layers=num_layers, ) @@ -173,9 +170,7 @@ def test_transfer_kv( dst_v_pool_kernel[layer_idx_to_test], src_indices_device, dst_indices_device, - io_backend="kernel", - page_size=page_size, - item_size=item_size, + item_size=item_size * dtype.itemsize, ) transfer_kv_direct( [src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]], @@ -235,7 +230,6 @@ def test_transfer_kv( dst_v_layers_device, src_indices_device, dst_indices_device, - io_backend="kernel", item_size=item_size * dtype.itemsize, num_layers=num_layers, ) From f4674df646ca8a5515dfdc93677f7bdc052416c6 Mon Sep 17 00:00:00 2001 From: ZhichenJiang <1147802470@qq.com> Date: Fri, 25 Jul 2025 11:43:52 +0800 Subject: [PATCH 122/396] support idle batch for TBO (#8233) --- python/sglang/srt/two_batch_overlap.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index 3fdf2a1f77a6..74bc1ba8572e 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -341,15 +341,18 @@ 
def _compute_local_forward_mode(local_batch): @staticmethod def _compute_global_forward_mode(forward_modes): - converted_forward_modes = [ - ForwardMode.DECODE.value if x == ForwardMode.IDLE.value else x - for x in forward_modes + forward_modes_excluding_idle = [ + x for x in forward_modes if x != ForwardMode.IDLE.value ] + + if not forward_modes_excluding_idle: + return ForwardMode.IDLE, False + forward_mode_agree = TboDPAttentionPreparer._is_all_same( - converted_forward_modes + forward_modes_excluding_idle ) global_forward_mode = ( - ForwardMode(converted_forward_modes[0]) if forward_mode_agree else None + ForwardMode(forward_modes_excluding_idle[0]) if forward_mode_agree else None ) return global_forward_mode, forward_mode_agree From 28d4d4728088f551f13edfcafadf12484b32ee64 Mon Sep 17 00:00:00 2001 From: li haoyang Date: Fri, 25 Jul 2025 11:48:42 +0800 Subject: [PATCH 123/396] [Feature] Integrate quick allreduce and select the best allreduce implementation (#6619) Signed-off-by: Haoyang Li Co-authored-by: ilmarkov --- python/sglang/srt/_custom_ops.py | 30 +- .../device_communicators/custom_all_reduce.py | 94 +-- .../custom_all_reduce_utils.py | 97 ++- .../device_communicators/quick_all_reduce.py | 273 ++++++++ .../sglang/srt/distributed/parallel_state.py | 76 ++- sgl-kernel/csrc/allreduce/quick_all_reduce.cu | 111 +++ .../csrc/allreduce/quick_all_reduce.cuh | 633 ++++++++++++++++++ sgl-kernel/csrc/allreduce/quick_all_reduce.h | 233 +++++++ .../csrc/allreduce/quick_all_reduce_base.h | 318 +++++++++ sgl-kernel/csrc/torch_extension_rocm.cc | 19 + sgl-kernel/include/sgl_kernel_ops.h | 9 + sgl-kernel/python/sgl_kernel/allreduce.py | 34 +- sgl-kernel/setup_rocm.py | 1 + test/srt/test_quick_allreduce.py | 212 ++++++ 14 files changed, 2031 insertions(+), 109 deletions(-) create mode 100644 python/sglang/srt/distributed/device_communicators/quick_all_reduce.py create mode 100644 sgl-kernel/csrc/allreduce/quick_all_reduce.cu create mode 100644 
sgl-kernel/csrc/allreduce/quick_all_reduce.cuh create mode 100644 sgl-kernel/csrc/allreduce/quick_all_reduce.h create mode 100644 sgl-kernel/csrc/allreduce/quick_all_reduce_base.h create mode 100644 test/srt/test_quick_allreduce.py diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 1c232d19f8c2..5ed175312c9b 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -1,6 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py import logging -from typing import List, Tuple +from typing import List, Optional, Tuple import torch @@ -114,6 +114,34 @@ def allocate_meta_buffer(size: int) -> torch.Tensor: def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp) + # ROCM custom quick allreduce + + def init_custom_qr( + rank: int, world_size: int, qr_max_size: Optional[int] = None + ) -> int: + return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size) + + def qr_get_handle(fa: int) -> torch.Tensor: + return sgl_kernel.allreduce.qr_get_handle(fa) + + def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None: + sgl_kernel.allreduce.qr_open_handles(fa, handles) + + def qr_all_reduce( + fa: int, + inp: torch.Tensor, + out: torch.Tensor, + quant_level: int, + cast_bf2half: bool, + ) -> None: + sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half) + + def qr_destroy(fa: int) -> None: + sgl_kernel.allreduce.qr_destroy(fa) + + def qr_max_size() -> int: + return sgl_kernel.allreduce.qr_max_size() + def mscclpp_generate_unique_id() -> bytes: return sgl_kernel.allreduce.mscclpp_generate_unique_id() diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index 9faff648c039..a1d28f2fc1d1 100644 --- 
a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -4,18 +4,18 @@ import logging import os from contextlib import contextmanager -from functools import wraps -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, List, Optional, Union import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from typing_extensions import ParamSpec from sglang.srt import _custom_ops as ops from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check, + is_full_nvlink, + is_weak_contiguous, ) from sglang.srt.distributed.parallel_state import in_the_same_node_as from sglang.srt.utils import is_cuda, is_hip @@ -25,23 +25,6 @@ _is_cuda = is_cuda() _is_hip = is_hip() -if _is_cuda: - try: - import pynvml - except ImportError as e: - logger.warning("Failed to import pynvml with %r", e) - -if _is_hip: - try: - from amdsmi import ( - AmdSmiException, - amdsmi_get_processor_handles, - amdsmi_init, - amdsmi_shut_down, - amdsmi_topo_get_link_type, - ) - except ImportError as e: - logger.warning("Failed to import amdsmi with %r", e) try: if ops.use_vllm_custom_allreduce and not _is_hip: @@ -57,70 +40,6 @@ logger = logging.getLogger(__name__) -_P = ParamSpec("_P") -_R = TypeVar("_R") - - -def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: - @wraps(fn) - def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: - if _is_hip: - try: - amdsmi_init() - return fn(*args, **kwargs) - finally: - amdsmi_shut_down() - else: - pynvml.nvmlInit() - try: - return fn(*args, **kwargs) - finally: - pynvml.nvmlShutdown() - - return wrapper - - -@with_nvml_context -def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool: - if _is_hip: - """ - query if the set of gpus are fully connected by 
xgmi (1 hop) - """ - handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids] - for i, handle in enumerate(handles): - for j, peer_handle in enumerate(handles): - if i < j: - try: - link_type = amdsmi_topo_get_link_type(handle, peer_handle) - # type is 2 for XGMI - if link_type["hops"] != 1 or link_type["type"] != 2: - return False - except AmdSmiException as error: - logger.error("AMD 1 hop XGMI detection failed.", exc_info=error) - return False - return True - else: - """ - query if the set of gpus are fully connected by nvlink (1 hop) - """ - handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] - for i, handle in enumerate(handles): - for j, peer_handle in enumerate(handles): - if i < j: - try: - p2p_status = pynvml.nvmlDeviceGetP2PStatus( - handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK - ) - if p2p_status != pynvml.NVML_P2P_STATUS_OK: - return False - except pynvml.NVMLError: - logger.exception( - "NVLink detection failed. This is normal if your" - " machine has no NVLink equipped." 
- ) - return False - return True - def _can_p2p(rank: int, world_size: int) -> bool: # SGLANG_SKIP_P2P_CHECK can be set to False in sglang @@ -136,13 +55,6 @@ def _can_p2p(rank: int, world_size: int) -> bool: return True -def is_weak_contiguous(inp: torch.Tensor): - return inp.is_contiguous() or ( - inp.storage().nbytes() - inp.storage_offset() * inp.element_size() - == inp.numel() * inp.element_size() - ) - - class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] _MAX_CAR_SIZE = 8192 * 1024 diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py index 86121ac976ee..c7baac845287 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py @@ -8,17 +8,44 @@ import subprocess import sys import tempfile +from functools import wraps from itertools import product -from typing import Dict, List, Optional, Sequence +from typing import Callable, Dict, List, Optional, Sequence, TypeVar import torch import torch.distributed as dist import torch.multiprocessing as mp +from typing_extensions import ParamSpec from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary +from sglang.srt.utils import is_cuda, is_hip logger = logging.getLogger(__name__) +_is_cuda = is_cuda() +_is_hip = is_hip() + +if _is_cuda: + try: + import pynvml + except ImportError as e: + logger.warning("Failed to import pynvml with %r", e) + +if _is_hip: + try: + from amdsmi import ( + AmdSmiException, + amdsmi_get_processor_handles, + amdsmi_init, + amdsmi_shut_down, + amdsmi_topo_get_link_type, + ) + except ImportError as e: + logger.warning("Failed to import amdsmi with %r", e) + +_P = ParamSpec("_P") +_R = TypeVar("_R") + def update_environment_variables(envs: Dict[str, str]): for k, v in envs.items(): @@ -282,6 +309,74 @@ def 
gpu_p2p_access_check(src: int, tgt: int) -> bool: return _gpu_p2p_access_cache[f"{src}->{tgt}"] +def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: + @wraps(fn) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + if _is_hip: + try: + amdsmi_init() + return fn(*args, **kwargs) + finally: + amdsmi_shut_down() + else: + pynvml.nvmlInit() + try: + return fn(*args, **kwargs) + finally: + pynvml.nvmlShutdown() + + return wrapper + + +@with_nvml_context +def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool: + if _is_hip: + """ + query if the set of gpus are fully connected by xgmi (1 hop) + """ + handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + link_type = amdsmi_topo_get_link_type(handle, peer_handle) + # type is 2 for XGMI + if link_type["hops"] != 1 or link_type["type"] != 2: + return False + except AmdSmiException as error: + logger.error("AMD 1 hop XGMI detection failed.", exc_info=error) + return False + return True + else: + """ + query if the set of gpus are fully connected by nvlink (1 hop) + """ + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK + ) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError: + logger.exception( + "NVLink detection failed. This is normal if your" + " machine has no NVLink equipped." 
+ ) + return False + return True + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or ( + inp.storage().nbytes() - inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size() + ) + + __all__ = ["gpu_p2p_access_check"] if __name__ == "__main__": diff --git a/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py b/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py new file mode 100644 index 000000000000..0113c432df85 --- /dev/null +++ b/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +from enum import Enum +from typing import Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from sglang.srt import _custom_ops as ops +from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import ( + is_full_nvlink, + is_weak_contiguous, +) +from sglang.srt.distributed.parallel_state import in_the_same_node_as +from sglang.srt.utils import is_cuda, is_hip + +logger = logging.getLogger(__name__) + +_is_cuda = is_cuda() +_is_hip = is_hip() + + +try: + ops.qr_max_size() + quick_ar = True +except Exception: + # For CPUs and CUDA + quick_ar = False + + +def qr_rocm_arch_available(): + if not _is_hip: + return False + try: + props = torch.cuda.get_device_properties(0) + gcn_arch = getattr(props, "gcnArchName", "") + supported_archs = ["gfx94", "gfx95"] + return any(gfx in gcn_arch for gfx in supported_archs) + except Exception as e: + logger.warning("Failed to determine ROCm for quick allreduce: %s", e) + return False + + +class QuickReduceRegime(Enum): + FP = 0 + INT8 = 1 + INT6 = 2 + INT4 = 3 + NONE = 4 + + +MB = 1024 * 1024 + + +class QuickAllReduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 8] + _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] + # The following data is based on kernel tests. 
+ # In this order [FP, INT8, INT6, INT4]. + _QR_MIN_SIZE = { + (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB], + (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB], + (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB], + (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB], + (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB], + (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB], + } + + def __init__( + self, group: ProcessGroup, device: Union[int, str, torch.device] + ) -> None: + """ + Custom allreduce provides non-destructive acceleration and is + available for CUDA and ROCm MI300 series. + Custom quick allreduce leverages quantization for further + acceleration on ROCm. It currently supports Q8, Q6, and Q4 + quantization formats and FP(float16, bfloat16). + Quick allreduce is designed as a complement to custom allreduce. + Its initialization requires even stricter conditions. + Only the ROCm MI300 series is supported for quick allreduce at + this time. + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self.disabled = True + if not qr_rocm_arch_available(): + logger.debug( + "Custom quick allreduce is only supported on ROCm MI300 series." + ) + return + + if not quick_ar: + # disable because of missing quick reduce library + # e.g. in a cuda environment + logger.info( + "Custom quick allreduce is disabled because " + "of missing custom quick allreduce library" + ) + return + + self.group = group + assert ( + dist.get_backend(group) != dist.Backend.NCCL + ), "Custom quick allreduce should be attached to a non-NCCL group." 
+ if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize custom quick allreduce for + # multi-node case. + logger.warning( + "Custom quick allreduce is disabled because this " + "process group spans across nodes." + ) + return + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + self.rank = rank + self.world_size = world_size + if world_size == 1: + # No need to initialize QuickReduce for single GPU case. + return + + if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom quick allreduce is disabled due to an " + "unsupported world size: %d. Supported world sizes: %s.", + world_size, + str(QuickAllReduce._SUPPORTED_WORLD_SIZES), + ) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(torch.cuda.device_count())) + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(self.world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom quick allreduce is not supported + # this checks hardware and driver support for NVLink + if _is_cuda or _is_hip: + self.fully_connected = is_full_nvlink(physical_device_ids, self.world_size) + if self.world_size > 2 and not self.fully_connected: + logger.debug( + "Custom quick allreduce is disabled because it's not supported " + "on more than two PCIe-only GPUs. 
" + ) + return + + self.init_quick_all_reduce() + + def init_quick_all_reduce(self): + # On RocM, bfloat16 kernels are slower than fp16 + # due to slower match operations + # If environment variable is set to 1, we convert input to fp16 + self.use_fp16_kernels = int( + os.environ.get("ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", 1) + ) + regime_str = os.environ.get("ROCM_QUICK_REDUCE_QUANTIZATION", "NONE") + if regime_str not in QuickReduceRegime.__members__: + logger.warning( + "Custom quick allreduce:", + f"Invalid quantization level: {regime_str}. " + "Supported levels: " + f"{list(QuickReduceRegime.__members__.keys())}", + ) + return + + if regime_str == "NONE": + logger.debug( + "Custom quick allreduce is disabled based " + "on env variable " + "ROCM_QUICK_REDUCE_QUANTIZATION='NONE'" + ) + return + self.qr_quant_level = QuickReduceRegime[regime_str] + + # TODO: If the dtype is not bfloat16 or then float16, + # quickallreduce should not be created. + + # ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB + qr_max_size = int(os.environ.get("ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", 0)) + if qr_max_size > 0: + if qr_max_size < 1: + logger.info( + "You should not set a max_size smaller than 1MB, which can " + "lead to error or degradation to custom allreduce or rccl." + ) + qr_max_size = qr_max_size * MB + # If qr_max_size is None, then 2GB is used by default. + self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size) + self.qr_max_size = qr_max_size if qr_max_size > 0 else ops.qr_max_size() + self.create_shared_buffer() + self.disabled = False + + def create_shared_buffer(self): + """ + Creates a shared buffer for quickreduce. 
+ Has to be called after init_custom_qr + """ + handle = ops.qr_get_handle(self._ptr) + world_size = dist.get_world_size(group=self.group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=self.group) + ops.qr_open_handles(self._ptr, handles) + + def should_quick_allreduce(self, inp: torch.Tensor): + """ + Check if quickreduce is available + """ + if self.disabled: + return False + if inp.dtype not in self._SUPPORTED_DTYPES: + return False + inp_size = inp.numel() * inp.element_size() + # custom quick allreduce requires input byte size to be + # multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + dtype = inp.dtype + if self.use_fp16_kernels: + dtype = torch.float16 + return ( + inp_size <= self.qr_max_size + and inp_size + >= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value] + ) + + def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): + """Performs an out-of-place custom quick all reduce.""" + # quick allreduce doesn't require a separate graph mode, + # as QR uses static IPC buffer. 
+ if out is None: + out = torch.empty_like(inp) + ops.qr_all_reduce( + self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels + ) + return out + + def close(self): + if not self.disabled and getattr(self, "_ptr", None): + if ops is not None: + ops.qr_destroy(self._ptr) + self._ptr = 0 + self.disabled = True + + def __del__(self): + self.close() diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 509c71531062..130bc53c7ed9 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -44,6 +44,7 @@ get_bool_env_var, get_int_env_var, is_cuda_alike, + is_hip, is_npu, is_shm_available, supports_custom_op, @@ -126,14 +127,18 @@ def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: fake_impl=inplace_all_reduce_fake, ) - def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + def outplace_all_reduce( + tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str + ) -> torch.Tensor: assert group_name in _groups, f"Group {group_name} is not found." 
group = _groups[group_name]() if group is None: raise ValueError(f"Group {group_name} is destroyed.") - return group._all_reduce_out_place(tensor) + return group._all_reduce_out_place(tensor, outplace_all_reduce_method) - def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + def outplace_all_reduce_fake( + tensor: torch.Tensor, group_name: str, outplace_all_reduce_method: str + ) -> torch.Tensor: return torch.empty_like(tensor) direct_register_custom_op( @@ -264,6 +269,12 @@ def __init__( PyNcclCommunicator, ) + if is_hip(): + from sglang.srt.distributed.device_communicators.quick_all_reduce import ( + QuickAllReduce, + qr_rocm_arch_available, + ) + self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( @@ -283,6 +294,7 @@ def __init__( ) self.ca_comm: Optional[CustomAllreduce] = None + self.qr_comm: Optional[QuickAllReduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. try: @@ -295,6 +307,18 @@ def __init__( f"Setup Custom allreduce failed with {e}. To silence this " "warning, specify --disable-custom-all-reduce explicitly." ) + if is_hip(): + try: + # Initialize a custom quick all-reduce implementation for AMD + # when rocm >= gfx942. Quick reduce is designed as a + # complement to custom allreduce. + # Based on quickreduce (https://github.com/mk1-project/quickreduce). 
+ if qr_rocm_arch_available(): + self.qr_comm = QuickAllReduce( + group=self.cpu_group, device=self.device + ) + except Exception as e: + logger.warning(f"Failed to initialize QuickAllReduce: {e}") from sglang.srt.distributed.device_communicators.hpu_communicator import ( HpuCommunicator, @@ -373,7 +397,8 @@ def graph_capture( graph_capture_context = GraphCaptureContext(stream) else: stream = graph_capture_context.stream - + # We don't need the context of custom quick allreduce because the ipc access + # is already collected in init() and we can capture the quick allreduce directly. ca_comm = self.ca_comm maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() @@ -388,23 +413,24 @@ def graph_capture( # operations. The current status is: # allreduce \ Mode | Eager | Graph | # -------------------------------------------- + # quick allreduce | enabled | enabled | # custom allreduce | enabled | enabled | # PyNccl | disabled| enabled | # PyMscclpp | disabled| enabled | # torch.distributed | enabled | disabled| # + # Note: When custom quick allreduce is enabled, a runtime check + # will be performed. If the tensor size is too small, it will + # automatically fall back to the next available option. # Note that custom allreduce will have a runtime check, if the # tensor size is too large, it will fallback to the next # available option. # Note that the PyMsccl needs to register the tensor in ahead, # which will introduce large overhead in the eager case, # therefore it is only supported in the graph case. - # In summary: When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. When not using - # CUDA graph, we use either custom all-reduce kernel or - # PyTorch NCCL. We always prioritize using custom all-reduce - # kernel but fall back to PyTorch or pynccl if it is - # disabled or not supported. 
+ # In summary: We select the appropriate allreduce method for + # each mode based on the algorithm order in the table and + # their usage conditions. pynccl_comm = self.pynccl_comm maybe_pynccl_context: Any if not pynccl_comm: @@ -464,27 +490,47 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.npu_communicator is not None and not self.npu_communicator.disabled: return self.npu_communicator.all_reduce(input_) + outplace_all_reduce_method = None if ( + self.qr_comm is not None + and not self.qr_comm.disabled + and self.qr_comm.should_quick_allreduce(input_) + ): + outplace_all_reduce_method = "qr" + elif ( self.ca_comm is not None and not self.ca_comm.disabled and self.ca_comm.should_custom_ar(input_) - ) or ( + ): + outplace_all_reduce_method = "ca" + elif ( self.pymscclpp_comm is not None and not self.pymscclpp_comm.disabled and self.pymscclpp_comm.should_mscclpp_allreduce(input_) ): + outplace_all_reduce_method = "pymscclpp" + if outplace_all_reduce_method is not None: return torch.ops.sglang.outplace_all_reduce( - input_, group_name=self.unique_name + input_, + group_name=self.unique_name, + outplace_all_reduce_method=outplace_all_reduce_method, ) else: torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name) return input_ - def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: + def _all_reduce_out_place( + self, input_: torch.Tensor, outplace_all_reduce_method: str + ) -> torch.Tensor: + qr_comm = self.qr_comm ca_comm = self.ca_comm pymscclpp_comm = self.pymscclpp_comm - assert ca_comm is not None or pymscclpp_comm is not None - if ca_comm is not None and not ca_comm.disabled: + assert any([qr_comm, ca_comm, pymscclpp_comm]) + if outplace_all_reduce_method == "qr": + assert not qr_comm.disabled + out = qr_comm.quick_all_reduce(input_) + elif outplace_all_reduce_method == "ca": + assert not ca_comm.disabled out = ca_comm.custom_all_reduce(input_) else: assert not pymscclpp_comm.disabled diff --git 
a/sgl-kernel/csrc/allreduce/quick_all_reduce.cu b/sgl-kernel/csrc/allreduce/quick_all_reduce.cu new file mode 100644 index 000000000000..757c05d2bddc --- /dev/null +++ b/sgl-kernel/csrc/allreduce/quick_all_reduce.cu @@ -0,0 +1,111 @@ +#include +#include +#include +#include + +#ifdef USE_ROCM + +#include "quick_all_reduce.h" + +quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional qr_max_size) { + if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); + if (world_size == 6) throw std::invalid_argument("world size == 6 is not supported"); + if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in"); + quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms(); + fptr->init(world_size, rank, qr_max_size); + return (quickreduce::fptr_t)fptr; +} + +void qr_destroy(quickreduce::fptr_t _fa) { + if (_fa) { + auto fa = reinterpret_cast(_fa); + fa->destroy(); + delete fa; + } +} + +torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + hipIpcMemHandle_t handle = fa->get_handle(); + auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = torch::empty({static_cast(sizeof(hipIpcMemHandle_t))}, options); + std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); + return data_handle; +} + +void qr_open_handles(quickreduce::fptr_t _fa, const std::vector& handles) { + auto fa = reinterpret_cast(_fa); + std::vector ipc_handles; + ipc_handles.reserve(handles.size()); + for (auto& handle : handles) { + // Ensure the tensor is on the same device as the current device. 
+ hipIpcMemHandle_t ipc_handle; + std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t)); + ipc_handles.push_back(ipc_handle); + } + fa->open_ipc_handles(ipc_handles); +} + +void qr_all_reduce( + quickreduce::fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half) { + auto fa = reinterpret_cast(_fa); + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize); + if (out.scalar_type() == at::ScalarType::Half) { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), + quant_level, + stream); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + if (cast_bf2half) { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), + quant_level, + stream); + } else { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), + quant_level, + stream); + } + } else { + throw std::runtime_error("quick allreduce only supports float16 and bfloat16"); + } +} + +int64_t qr_max_size() { + // The default is 2GB (2,147,483,648 bytes) + return static_cast(std::numeric_limits::max()) + 1; +} + +#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; + +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false) 
+INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true) + +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false) + +#endif // USE_ROCM diff --git a/sgl-kernel/csrc/allreduce/quick_all_reduce.cuh b/sgl-kernel/csrc/allreduce/quick_all_reduce.cuh new file mode 100644 index 000000000000..bd9e7b10fa19 --- /dev/null +++ b/sgl-kernel/csrc/allreduce/quick_all_reduce.cuh @@ -0,0 +1,633 @@ +#pragma once + +#include + +#include "quick_all_reduce_base.h" + +namespace quickreduce { + +struct CodecBase { + const int thread; + const int rank; + const int group_leader; + __quickreduce_device_inline__ CodecBase(int thread, int rank) + : thread(thread), rank(rank), group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) { + set_fp16_ovfl(true); + } +}; + +// Default full precision codec. +template +struct CodecFP : public CodecBase { + static constexpr int kWorldSize = world_size; + static constexpr int kRankAtoms = kAtoms / kWorldSize; + + // Codec tile size process by this workgroup. + // Each thread processes atoms of f16x8_t (16B). + static constexpr int kRankTransmittedTileSize = kBlockSize * kRankAtoms * sizeof(int32x4_t); + static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned."); + + // Total tile size for the collective communication. 
+ static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize; + + __quickreduce_device_inline__ CodecFP(int thread, int rank) : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + __builtin_nontemporal_store(data[i], send_buffer + thread); + send_buffer += kAtomStride; + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + data[i] = __builtin_nontemporal_load(*recv_buffer + thread); + *recv_buffer += kAtomStride; + } + } +}; + +// Int4 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int4 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ4 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int4x8_t (4B) and a fp16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 1152; + static constexpr int kRankTileScaleOffset = 1024; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/8.0h, -1/8.0h}, f16x2_t + static constexpr int kScaleFactor = std::is_same::value ? 0xB000B000 : 0xBE00BE00; + + // {1e-7, 1e-7}, f16x2_t + static constexpr int kScaleEpsilon = std::is_same::value ? 
0x00010001 : 0x33D733D7; + + // {-8, -8}, f16x2_t + static constexpr int kRangeMin = std::is_same::value ? 0xC800C800 : 0xC100C100; + + // {+7, +7}, f16x2_t + static constexpr int kRangeMax = std::is_same::value ? 0x47004700 : 0x40E040E0; + + // {+8, +8}, int16x2_t + static constexpr int kRangeBias = 0x00080008; + + __quickreduce_device_inline__ CodecQ4(int thread, int rank) : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) + qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q4 into int32_t + int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + 
__builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + + int32_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q4 into f16x8_t + int32x4_t w; + { + static constexpr uint kMask000F = 0x000F000F; + static constexpr uint kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1032 = 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t + + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; + w[i] = packed_add(q4, kHalf2_1032); + } else { + int32_t int16_2 = (qw >> (i * 4)) & kMask000F; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + data[k] = w; + } + } +}; + +// Int6 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int6 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ6 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. 
+ // Each threads processes a fragment of fp16x8_t (16B), + // into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 1664; + static constexpr int kRankTileQ2Offset = 1024; + static constexpr int kRankTileScaleOffset = 1536; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/32.0h, -1/32.0h}, fp16x2_t + static constexpr int kScaleFactor = std::is_same::value ? 0xA800A800 : 0xBD00BD00; + + // {1e-7, 1e-7}, fp16x2_t + static constexpr int kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-32, -32}, fp16x2_t + static constexpr int kRangeMin = std::is_same::value ? 0xD000D000 : 0xC200C200; + + // {+31, +31}, fp16x2_t + static constexpr int kRangeMax = std::is_same::value ? 
0x4FC04FC0 : 0x41F841F8; + + // {+32, +32}, int16x2_t + static constexpr int kRangeBias = 0x00200020; + + __quickreduce_device_inline__ CodecQ6(int thread, int rank) : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) + qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q6 into int32_t + int16_t + uint32_t q4w; + uint16_t q2w = 0; + q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12); + { + int16_t* tw = reinterpret_cast(&q); +#pragma unroll + for (int i = 0; i < 8; i++) { + q2w |= (tw[i] >> 4) << (i * 2); + } + } + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = reinterpret_cast(send_buffer + k * kRankBufferTileStride); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + 
(thread / 8); + + __builtin_nontemporal_store(q4w, q4w_ptr); + __builtin_nontemporal_store(q2w, q2w_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + + uint32_t q4w = __builtin_nontemporal_load(q4w_ptr); + uint16_t q2w = __builtin_nontemporal_load(q2w_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q6 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask000F = 0x000F000F; + static uint constexpr kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1056 = 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t + +#pragma unroll + for (int i = 0; i < 4; i++) { + int32_t q4 = q4w & kMask000F; + int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14); + q4w >>= 4; + q2w >>= 4; + if constexpr (std::is_same::value) { + int32_t q6 = q4 | (q2 << 4) | kHalf2_1024; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(w[i]) : "v"(q6), "v"(kHalf2_1056)); + } else { + int32_t int16_2 = q4 | (q2 << 4); + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], 
qs); + } + + // That's pretty much it... + data[k] = w; + } + } +}; + +// Int8 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int8 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ8 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of f16x8_t (16B), + // into a int8x8_t (8B) and a f16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 2176; + static constexpr int kRankTileScaleOffset = 2048; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/128.0h, -1/128.0h}, f16x2_t + static constexpr int kScaleFactor = std::is_same::value ? 0xA000A000 : 0xBC00BC00; + + // {1e-7, 1e-7}, f16x2_t + static constexpr int kScaleEpsilon = std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-128, -128}, f16x2_t + static constexpr int kRangeMin = std::is_same::value ? 0xD800D800 : 0xC300C300; + // {+127, +127}, f16x2_t + static constexpr int kRangeMax = std::is_same::value ? 
0x57F057F0 : 0x42FE42FE; + + // {+128, +128}, int16x2_t + static constexpr int kRangeBias = 0x00800080; + + __quickreduce_device_inline__ CodecQ8(int thread, int rank) : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) + qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q8 into int32x2_t + int32x2_t qw; + qw[0] = q[0] | (q[1] << 8); + qw[1] = q[2] | (q[3] << 8); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) { + for (int k = 0; 
k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + (thread / 8); + + int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q8 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask00FF = 0x00FF00FF; + + // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1024 = 0x64006400; + + // {-1152.0, -1152.0}, fp16x2_t + static uint constexpr kHalf2_1152 = 0xE480E480; + +#pragma unroll + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q8 = ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; + w[i] = packed_add(q8, kHalf2_1152); + } else { + int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + data[k] = w; + } + } +}; + +// Twoshot All Reduce +template +struct AllReduceTwoshot { + static_assert(sizeof(T) == 2); + + static constexpr int kWorldSize = Codec::kWorldSize; + + __device__ static void + run(T const* __restrict__ input, + T* __restrict__ output, + uint32_t const N, // number of elements + int const block, // block index + int const rank, // rank index + uint8_t** __restrict__ buffer_list, // communication buffers + uint32_t const data_offset, // offset to start of the data buffer + uint32_t flag_color) { + // Topology + int thread = 
threadIdx.x + threadIdx.y * kWavefront; + uint8_t* rank_buffer = buffer_list[rank]; + Codec codec(thread, rank); + int block_id = blockIdx.x; + int grid_size = gridDim.x; + // -------------------------------------------------------- + // Read input into registers + int32x4_t tA[kAtoms]; + + BufferResource src_buffer(const_cast(input), N * sizeof(T)); + uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); + src_offset += kAtomStride * sizeof(int32x4_t); + if constexpr (cast_bf2half) { + const nv_bfloat162* bf_buf = reinterpret_cast(&tA[i]); + half2 half_buf[4]; +#pragma unroll + for (int j = 0; j < 4; ++j) { + float2 f = __bfloat1622float2(bf_buf[j]); + half_buf[j] = __float22half2_rn(f); + } + tA[i] = *reinterpret_cast(half_buf); + } + } + + // -------------------------------------------------------- + // Phase-1A: Write segment data into the communication buffer of the target + // rank responsible for this segment. + uint32_t comm_data0_offset = data_offset + block_id * Codec::kTransmittedTileSize; + uint32_t comm_data1_offset = grid_size * Codec::kTransmittedTileSize + comm_data0_offset; + + uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t)); + uint32_t comm_flags1_offset = grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset; + + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data0_offset + rank * Codec::kRankTransmittedTileSize); + codec.send(send_buffer, &tA[r * Codec::kRankAtoms]); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + uint32_t* flag_ptr = reinterpret_cast(buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t)); + set_sync_flag(flag_ptr, flag_color); + } + // -------------------------------------------------------- + // Phase-1B: Reduce the segment data from the communication buffers. 
+ int32x4_t tR[Codec::kRankAtoms] = {}; + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = reinterpret_cast(rank_buffer + comm_data0_offset); + uint32_t* flag_ptr = reinterpret_cast(rank_buffer + comm_flags0_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. + if (thread == 0) { + wait_sync_flag(&flag_ptr[r], flag_color); + } + __syncthreads(); + + // note: we reuse tA as temp buffer here + codec.recv(&recv_buffer, tA); + + for (int i = 0; i < Codec::kRankAtoms; i++) { + packed_assign_add(&tR[i], &tA[i]); + } + } + } + + // Phase-2: Write the reduced segment to every other rank + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data1_offset + rank * Codec::kRankTransmittedTileSize); + codec.send(send_buffer, tR); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + uint32_t* flag_ptr = reinterpret_cast(buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t)); + set_sync_flag(flag_ptr, flag_color); + } + + // Phase-2: Read the gather segments from the rank's communication buffer. + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = reinterpret_cast(rank_buffer + comm_data1_offset); + uint32_t* flag_ptr = reinterpret_cast(rank_buffer + comm_flags1_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. + if (thread == 0) { + wait_sync_flag(&flag_ptr[r], flag_color); + } + __syncthreads(); + + // Gather all reduced and final rank segments into tA. + codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]); + } + } + + // -------------------------------------------------------- + // Write the result to output. 
+ BufferResource dst_buffer(output, N * sizeof(T)); + uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + if constexpr (cast_bf2half) { + const half2* half_buf = reinterpret_cast(&tA[i]); + nv_bfloat162 bf16_buf[4]; +#pragma unroll + for (int j = 0; j < 4; ++j) { + float2 f = __half22float2(half_buf[j]); + bf16_buf[j] = __float22bfloat162_rn(f); + } + buffer_store_dwordx4(*reinterpret_cast(bf16_buf), dst_buffer.descriptor, dst_offset, 0, 0); + } else { + buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0); + } + dst_offset += kAtomStride * sizeof(int32x4_t); + } + } +}; + +} // namespace quickreduce diff --git a/sgl-kernel/csrc/allreduce/quick_all_reduce.h b/sgl-kernel/csrc/allreduce/quick_all_reduce.h new file mode 100644 index 000000000000..1d629e018241 --- /dev/null +++ b/sgl-kernel/csrc/allreduce/quick_all_reduce.h @@ -0,0 +1,233 @@ +#pragma once + +#include + +#include + +#include "quick_all_reduce.cuh" + +#define HIP_CHECK(err) \ + do { \ + hipError_t err_ = (err); \ + if (err_ != hipSuccess) { \ + std::printf("HIP error %d at %s:%d. 
%s\n", err_, __FILE__, __LINE__, hipGetErrorString(err_)); \ + throw std::runtime_error("HIP error"); \ + } \ + } while (0) + +namespace quickreduce { +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +template +__global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot( + T const* A, + T* B, + uint32_t N, + uint32_t num_blocks, + int rank, + uint8_t** dbuffer_list, + uint32_t data_offset, + uint32_t flag_color) { + int block = blockIdx.x; + int grid = gridDim.x; + + while (block < num_blocks) { + AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, flag_color); + block += grid; + flag_color++; + } +} + +#define TWOSHOT_DISPATCH(__codec) \ + if (world_size == 2) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL( \ + (allreduce_prototype_twoshot), \ + dim3(grid), \ + dim3(kBlockTwoShot), \ + 0, \ + stream, \ + A, \ + B, \ + N, \ + num_blocks, \ + rank, \ + dbuffer_list, \ + data_offset, \ + flag_color); \ + } else if (world_size == 4) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL( \ + (allreduce_prototype_twoshot), \ + dim3(grid), \ + dim3(kBlockTwoShot), \ + 0, \ + stream, \ + A, \ + B, \ + N, \ + num_blocks, \ + rank, \ + dbuffer_list, \ + data_offset, \ + flag_color); \ + } else if (world_size == 8) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL( \ + (allreduce_prototype_twoshot), \ + dim3(grid), \ + dim3(kBlockTwoShot), \ + 0, \ + stream, \ + A, \ + B, \ + N, \ + num_blocks, \ + rank, \ + dbuffer_list, \ + data_offset, \ + flag_color); \ + } + +enum QuickReduceQuantLevel { + F16 = 0, + INT8 = 1, + INT6 = 2, + INT4 = 3, +}; + +struct DeviceComms { + // Max problem size is 2GB (in bytes) or half of uint32_t max value. 
+ int64_t kMaxProblemSize = static_cast(std::numeric_limits::max()) + 1; + + // Max TP-8 + static int constexpr kMaxWorldSize = 8; + + bool initialized = false; + uint32_t flag_color = 1; + int world_size; + int rank; + + uint8_t* dbuffer; + uint8_t** dbuffer_list; + hipIpcMemHandle_t buffer_ipc_handle; + std::vector all_buffer_ipc_handles; + std::vector buffer_list; + uint32_t data_offset; + + DeviceComms() : initialized(false), world_size(1), rank(0) {} + ~DeviceComms() { + destroy(); + } + + void init(int world_size, int rank, std::optional max_problem_size = std::nullopt) { + destroy(); + this->world_size = world_size; + this->rank = rank; + if (max_problem_size.has_value() && max_problem_size.value() > 0) { + this->kMaxProblemSize = max_problem_size.value(); + } + // Allocate buffer size for worst case: F16 2-stage buffer. + uint32_t flags_buffer_size = 2 * world_size * kMaxNumBlocks * sizeof(uint32_t); + static int64_t data_buffer_size = 2 * this->kMaxProblemSize; + int64_t total_buffer_size = flags_buffer_size + data_buffer_size; + data_offset = flags_buffer_size; + HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached)); + + // Clear the flags buffer. + HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size)); + + // Device-side list of IPC buffers. + buffer_list.resize(world_size); + HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*))); + + // Create IPC handles for rank's communication buffer. 
+ all_buffer_ipc_handles.resize(world_size); + HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer)); + + initialized = true; + } + int get_world_size() { + return world_size; + } + int get_rank() { + return rank; + } + bool status() { + return initialized; + } + hipIpcMemHandle_t const get_handle() { + return buffer_ipc_handle; + } + + void destroy() { + if (initialized) { + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i])); + } + } + + HIP_CHECK(hipFree(dbuffer)); + HIP_CHECK(hipFree(dbuffer_list)); + + initialized = false; + } + } + + void open_ipc_handles(std::vector const& ipc_handles) { + assert(ipc_handles.size() == all_buffer_ipc_handles.size()); + for (int i = 0; i < world_size; i++) { + all_buffer_ipc_handles[i] = ipc_handles[i]; + } + + // Open device memory access to the IPC communication buffers. + // Note: For our own rank, we do not need to open a handle. + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK( + hipIpcOpenMemHandle((void**)&buffer_list[i], all_buffer_ipc_handles[i], hipIpcMemLazyEnablePeerAccess)); + } else { + buffer_list[i] = dbuffer; + } + } + + HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), world_size * sizeof(uint8_t*), hipMemcpyHostToDevice)); + } + + template + void allreduce(T const* A, T* B, uint32_t N, int quant_level, hipStream_t stream) { + if (world_size != 2 && world_size != 4 && world_size != 8) { + throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size)); + } + + // Configuration. 
+ uint32_t msg_size = N * sizeof(T); + uint32_t num_blocks = divceil(msg_size, kTileSize); + uint32_t grid = min(kMaxNumBlocks, num_blocks); + auto quant_level_ = static_cast(quant_level); + switch (quant_level_) { + case QuickReduceQuantLevel::INT8: + TWOSHOT_DISPATCH(CodecQ8) + break; + case QuickReduceQuantLevel::INT6: + TWOSHOT_DISPATCH(CodecQ6) + break; + case QuickReduceQuantLevel::INT4: + TWOSHOT_DISPATCH(CodecQ4) + break; + default: + TWOSHOT_DISPATCH(CodecFP) + break; + } + HIP_CHECK(cudaGetLastError()); + // Rotate the flag color. + flag_color += divceil(N, grid); + } +}; + +} // namespace quickreduce diff --git a/sgl-kernel/csrc/allreduce/quick_all_reduce_base.h b/sgl-kernel/csrc/allreduce/quick_all_reduce_base.h new file mode 100644 index 000000000000..759b28f38ef9 --- /dev/null +++ b/sgl-kernel/csrc/allreduce/quick_all_reduce_base.h @@ -0,0 +1,318 @@ +#pragma once + +#include +#include +#include + +#include + +#define __quickreduce_device_inline__ __device__ __forceinline__ +#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4) +#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4) + +namespace quickreduce { + +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; +using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + +// Setup acquire-release semantics for vector memory reads (mubuf instruction) +// as per architecture. 
+#if defined(__gfx942__) +// CDNA3: Scope bits sc0, sc1 +#define MUBUF_ACQUIRE 16 +#define MUBUF_RELEASE 16 +#elif (defined(__gfx908__) || defined(__gfx90a__)) +// CDNA1 and CDNA2 - glc bit +#define MUBUF_ACQUIRE 1 +#define MUBUF_RELEASE 0 +#endif + +static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t + +// Number of atoms (4xf16x2_t) processed by a single thread +static constexpr int kAtoms = 8; + +// We use a workgroup of 256 threads +static constexpr int kBlockSize = 256; +static constexpr int kAtomStride = kBlockSize; + +// Size and atom stride of source/destination data that the block will +// process. +// Workgroup scope = Tile = (256 threads x 8 atoms x 16B) +static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t); + +// Max number of blocks. 304 CUs on MI300 +static constexpr int kMaxNumBlocks = 304 * 4; + +// Standard CDNA wavefront size. +static constexpr int kWavefront = 64; + +// 256 thread, 4 wavefronts. +static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1}; + +// Number of threads in a group for quantization +// It corresponds to 32 F16 elements in quantization block +static constexpr int kThreadGroupSize = 8; + +// Methods +__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x, unsigned long y) { + return ((x + y - 1) / y); +} + +union BufferResource { + __quickreduce_device_inline__ constexpr BufferResource() : config(0x00020000U) {} + + __quickreduce_device_inline__ constexpr BufferResource(void* buffer_address, uint32_t buffer_size) + : address(buffer_address), range(buffer_size), config(0x00020000U) {} + + int32x4_t descriptor; + struct { + void* address; // 8B, out of which first 48b is address, and 16b is stride + // (unused) + uint32_t range; // Byte range for the buffer resource + uint32_t config; // Constant, DFMT=32b + }; +}; + +__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4( + int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) 
__asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +__quickreduce_device_inline__ static void +buffer_store_dwordx4(int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm( + "llvm.amdgcn.raw.buffer.store.v4i32"); + +__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) { +#if defined(__gfx942__) + if (value) { + asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::); + } else { + asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::); + } +#endif +} +union bf162_int_union { + int i; + nv_bfloat162 bf2; +}; + +template +__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B); + +template <> +__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B) { + int32x4_t& tR_fragment = A[0]; + int32x4_t& tA_fragment = B[0]; + + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[0]) : "v"(tR_fragment[0]), "v"(tA_fragment[0])); + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[1]) : "v"(tR_fragment[1]), "v"(tA_fragment[1])); + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[2]) : "v"(tR_fragment[2]), "v"(tA_fragment[2])); + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[3]) : "v"(tR_fragment[3]), "v"(tA_fragment[3])); +} + +template <> +__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B) { + nv_bfloat162* tA = reinterpret_cast(A); + nv_bfloat162* tB = reinterpret_cast(B); +#pragma unroll + for (int i = 0; i < 4; i++) { + tA[i] = __hadd2(tA[i], tB[i]); + } +} + +template +__quickreduce_device_inline__ int packed_max(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_max(int a, int b) { + int result; + asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_max(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hmax2(A.bf2, B.bf2); + return R.i; +} + +template +__quickreduce_device_inline__ int 
packed_min(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_min(int a, int b) { + int result; + asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_min(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hmin2(A.bf2, B.bf2); + return R.i; +} + +template +__quickreduce_device_inline__ int packed_abs_max(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_abs_max(int a, int b) { + half2 wmaxh2 = __builtin_bit_cast(half2, a); + half2 wminh2 = __builtin_bit_cast(half2, b); + half2 wblockmaxh2; + + wblockmaxh2.x = __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x; + wblockmaxh2.y = __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y; + return __builtin_bit_cast(int, wblockmaxh2); +} + +template <> +__quickreduce_device_inline__ int packed_abs_max(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x; + R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? 
A.bf2.y : B.bf2.y; + return R.i; +} + +template +__quickreduce_device_inline__ int packed_add(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_add(int a, int b) { + int result; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_add(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hadd2(A.bf2, B.bf2); + return R.i; +} + +template <> +__quickreduce_device_inline__ int packed_add(int a, int b) { + int result; + asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template +__quickreduce_device_inline__ int packed_sub(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_sub(int a, int b) { + int result; + + // MI300 lacks packed fp16 sub instruction. So we do -1 * min + max + asm volatile("v_pk_fma_f16 %0, %1, %2 %3" : "=v"(result) : "v"(kNegOne), "v"(b), "v"(a)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_sub(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hsub2(A.bf2, B.bf2); + return R.i; +} + +template +__quickreduce_device_inline__ int packed_mul(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_mul(int a, int b) { + int result; + asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_mul(int a, int b) { + nv_bfloat162* tA = reinterpret_cast(&a); + nv_bfloat162* tB = reinterpret_cast(&b); + nv_bfloat162 tR = __hmul2(*tA, *tB); + return *(reinterpret_cast(&tR)); +} + +template +__quickreduce_device_inline__ int packed_rcp(int a); + +template <> +__quickreduce_device_inline__ int packed_rcp(int a) { + return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a))); +} + +template <> +__quickreduce_device_inline__ int packed_rcp(int a) { + bf162_int_union A, R; + A.i = a; + 
R.bf2 = h2rcp(A.bf2); + return R.i; +} + +// changes dtype +__quickreduce_device_inline__ float T2float_cast(half a) { + return __half2float(a); +} + +__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) { + return __bfloat162float(a); +} + +template +__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) { + const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize; + + int wmax, wmin, wblockmax; + int a, b; + a = packed_max(atom[0], atom[1]); + b = packed_max(atom[2], atom[3]); + + wmax = packed_max(a, b); + + a = packed_min(atom[0], atom[1]); + b = packed_min(atom[2], atom[3]); + + wmin = packed_min(a, b); + + // Reduce the max among a group of threads + // Note: This is basically 2 blocks of values setup as the + // upper/lower halves of the f16x2_t + for (int i = 1; i < kThreadGroupSize; i <<= 1) { + int x = __shfl_down(wmax, i); + wmax = packed_max(wmax, x); + + int y = __shfl_down(wmin, i); + wmin = packed_min(wmin, y); + } + wblockmax = packed_abs_max(wmax, wmin); + // Share with the cohort + wblockmax = __shfl(wblockmax, group_leader); + return wblockmax; +} + +__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, uint32_t flag) { + __atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE); +} + +__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr, uint32_t flag) { + while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) { + } +} + +} // namespace quickreduce diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc index 84f9d1e7a4d8..46a50ca6b969 100644 --- a/sgl-kernel/csrc/torch_extension_rocm.cc +++ b/sgl-kernel/csrc/torch_extension_rocm.cc @@ -54,6 +54,25 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle); m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle); + // quick allreduce +#ifdef USE_ROCM + m.def( + "qr_all_reduce(int fa, Tensor inp, Tensor out, int 
quant_level, bool " + "cast_bf2half) -> ()"); + m.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); + + m.def("init_custom_qr", &init_custom_qr); + m.def("qr_destroy", &qr_destroy); + + m.def("qr_get_handle", &qr_get_handle); + + m.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()"); + m.impl("qr_open_handles", torch::kCPU, &qr_open_handles); + + // Max input size in bytes + m.def("qr_max_size", &qr_max_size); +#endif + /* * From csrc/moe */ diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 6b589101feaa..ffd240a04dd0 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -66,6 +66,13 @@ void register_graph_buffers( fptr_t _fa, const std::vector& handles, const std::vector>& offsets); torch::Tensor allocate_meta_buffer(int64_t size); torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp); +// quick allreduce +fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional qr_max_size = std::nullopt); +void qr_destroy(fptr_t _fa); +torch::Tensor qr_get_handle(fptr_t _fa); +void qr_open_handles(fptr_t _fa, const std::vector& handles); +void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false); +int64_t qr_max_size(); #else // custom allreduce fptr_t @@ -77,6 +84,8 @@ std::tuple, std::vector> get_graph_buffer_ipc_meta void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); void register_graph_buffers( fptr_t _fa, const std::vector>& handles, const std::vector>& offsets); + +// mscclpp torch::Tensor mscclpp_generate_unique_id(); fptr_t mscclpp_init_context( const torch::Tensor& unique_id, diff --git a/sgl-kernel/python/sgl_kernel/allreduce.py b/sgl-kernel/python/sgl_kernel/allreduce.py index 317b2f1a7813..544fc1d77e27 100644 --- a/sgl-kernel/python/sgl_kernel/allreduce.py +++ b/sgl-kernel/python/sgl_kernel/allreduce.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, 
Optional, Tuple import torch @@ -49,6 +49,38 @@ def allocate_meta_buffer(size: int) -> torch.Tensor: def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp) + # ROCM quick allreduce + def init_custom_qr( + rank: int, world_size: int, qr_max_size: Optional[int] = None + ) -> int: + return torch.ops.sgl_kernel.init_custom_qr.default( + world_size, rank, qr_max_size + ) + + def qr_get_handle(fa: int) -> torch.Tensor: + return torch.ops.sgl_kernel.qr_get_handle.default(fa) + + def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None: + torch.ops.sgl_kernel.qr_open_handles.default(fa, handles) + + def qr_all_reduce( + fa: int, + profile: int, + inp: torch.Tensor, + out: torch.Tensor, + cast_bf162half: bool, + ) -> None: + torch.ops.sgl_kernel.qr_all_reduce.default( + fa, profile, inp, out, cast_bf162half + ) + + def qr_destroy(fa: int) -> None: + torch.ops.sgl_kernel.qr_destroy.default(fa) + + def qr_max_size() -> int: + return torch.ops.sgl_kernel.qr_max_size.default() + + # mscclpp def mscclpp_generate_unique_id() -> bytes: raise NotImplementedError() diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index 4ab8635a83ea..a814b819689a 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -41,6 +41,7 @@ def _get_version(): sources = [ "csrc/allreduce/custom_all_reduce.hip", + "csrc/allreduce/quick_all_reduce.cu", "csrc/moe/moe_align_kernel.cu", "csrc/moe/moe_topk_softmax_kernels.cu", "csrc/torch_extension_rocm.cc", diff --git a/test/srt/test_quick_allreduce.py b/test/srt/test_quick_allreduce.py new file mode 100644 index 000000000000..ed081255f683 --- /dev/null +++ b/test/srt/test_quick_allreduce.py @@ -0,0 +1,212 @@ +import os +import random +import socket +import unittest +from typing import Any + +import ray +import torch +import torch.distributed as dist + +from sglang.srt.distributed import init_distributed_environment +from 
sglang.srt.distributed.communication_op import ( # noqa + tensor_model_parallel_all_reduce, +) +from sglang.srt.distributed.device_communicators.quick_all_reduce import ( + qr_rocm_arch_available, +) +from sglang.srt.distributed.parallel_state import ( + get_tensor_model_parallel_group, + graph_capture, + initialize_model_parallel, +) +from sglang.test.test_utils import CustomTestCase + +torch.manual_seed(42) +random.seed(44) # keep the deterministic seed + + +def get_open_port() -> int: + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def multi_process_parallel( + world_size: int, cls: Any, test_target: Any, quant_mode: str +) -> None: + + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. + # NOTE: We need to set working_dir for distributed tests, + # otherwise we may get import errors on ray workers + + ray.init(log_to_driver=True) + + distributed_init_port = get_open_port() + refs = [] + for rank in range(world_size): + refs.append( + test_target.remote(cls, world_size, rank, distributed_init_port, quant_mode) + ) + ray.get(refs) + + ray.shutdown() + + +class TestQuickAllReduce(CustomTestCase): + TEST_SIZES = [ + 2 * 1024 * 1024, + 4 * 1024 * 1024, + 8 * 1024 * 1024, + 16 * 1024 * 1024, + 32 * 1024 * 1024, + ] + TEST_LOOP = 5 + # Too many configurations can lead to a test grid that is too large + # The tp takes too long to boot,let's just choose 4 out of 12 configurations + # WORLD_SIZES = [2, 4, 8] + # QUANT_MODE = ["FP", "INT8", "INT6", "INT4"] + QUANT_MODE_WORLD_SIZE_PART = [["FP", 8], ["INT4", 4], ["INT8", 2], ["INT6", 2]] + + @unittest.skipIf( + not qr_rocm_arch_available(), + "Only test Quick AllReduce on ROCm architectures >= gfx94*", + ) + def test_graph_allreduce(self): + for 
quant_mode_world_size_part in self.QUANT_MODE_WORLD_SIZE_PART: + quant_mode = quant_mode_world_size_part[0] + world_size = quant_mode_world_size_part[1] + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.graph_allreduce, quant_mode) + + @unittest.skipIf( + not qr_rocm_arch_available(), + "Only test Quick AllReduce on ROCm architectures >= gfx94*", + ) + def test_eager_allreduce(self): + for quant_mode_world_size_part in self.QUANT_MODE_WORLD_SIZE_PART: + quant_mode = quant_mode_world_size_part[0] + world_size = quant_mode_world_size_part[1] + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.eager_allreduce, quant_mode) + + @ray.remote(num_gpus=1, max_calls=1) + def graph_allreduce(self, world_size, rank, distributed_init_port, quant_mode): + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = quant_mode + os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "0" + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment( + world_size=world_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=rank, + ) + initialize_model_parallel(tensor_model_parallel_size=world_size) + group = get_tensor_model_parallel_group().device_group + + # A small all_reduce for warmup. + # this is needed because device communicators might be created lazily + # (e.g. NCCL). This will ensure that the communicator is initialized + # before any communication happens, so that this group can be used for + # graph capture immediately. 
+ data = torch.zeros(1) + data = data.to(device=device) + torch.distributed.all_reduce(data, group=group) + torch.cuda.synchronize() + del data + + for sz in self.TEST_SIZES: + for dtype in [torch.float16, torch.bfloat16]: + for _ in range(self.TEST_LOOP): + with graph_capture() as graph_capture_context: + # use integers so result matches NCCL exactly + inp1 = torch.randint( + 1, + 23, + (sz,), + dtype=dtype, + device=torch.cuda.current_device(), + ) + inp2 = torch.randint( + -23, + 1, + (sz,), + dtype=dtype, + device=torch.cuda.current_device(), + ) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph( + graph, stream=graph_capture_context.stream + ): + out1 = tensor_model_parallel_all_reduce(inp1) + # the input buffer is immediately modified to test + # synchronization + dist.all_reduce(inp1, group=group) + out2 = tensor_model_parallel_all_reduce(inp2) + dist.all_reduce(inp2, group=group) + graph.replay() + atol = 1.25 * world_size + rtol = 0.5 * world_size + for inp, out in [[inp1, out1], [inp2, out2]]: + torch.testing.assert_close(out, inp, atol=atol, rtol=rtol) + # try: + # torch.testing.assert_close(out, inp, atol=atol, rtol=rtol) + # except AssertionError as e: + # print("Max abs diff:", (out - inp).abs().max()) + # print("Max rel diff:", ((out - inp).abs() / inp.abs().clamp(min=1e-5)).max()) + + @ray.remote(num_gpus=1, max_calls=1) + def eager_allreduce(self, world_size, rank, distributed_init_port, quant_mode): + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = quant_mode + os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "0" + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment( + world_size=world_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=rank, + ) + initialize_model_parallel(tensor_model_parallel_size=world_size) + 
group = get_tensor_model_parallel_group().device_group + + for sz in self.TEST_SIZES: + for dtype in [torch.float16, torch.bfloat16]: + for _ in range(self.TEST_LOOP): + inp1 = torch.randint( + 1, + 23, + (sz,), + dtype=dtype, + device=torch.cuda.current_device(), + ) + out1 = tensor_model_parallel_all_reduce(inp1) + dist.all_reduce(inp1, group=group) + atol = 1.25 * world_size + rtol = 0.5 * world_size + torch.testing.assert_close(out1, inp1, atol=atol, rtol=rtol) + # try: + # torch.testing.assert_close(out1, inp1, atol=atol, rtol=rtol) + # except AssertionError as e: + # print("Max abs diff:", (out1 - inp1).abs().max()) + # print("Max rel diff:", ((out1 - inp1).abs() / inp1.abs().clamp(min=1e-5)).max()) + + +if __name__ == "__main__": + unittest.main() From c0fb25e9493927cfdf09f49fbe2638584600aae3 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Thu, 24 Jul 2025 21:36:21 -0700 Subject: [PATCH 124/396] DP Enhancement (#8280) --- .../sglang/srt/distributed/parallel_state.py | 9 + .../srt/layers/attention/base_attn_backend.py | 4 +- python/sglang/srt/layers/communicator.py | 24 +- python/sglang/srt/layers/dp_attention.py | 96 +- python/sglang/srt/layers/logits_processor.py | 58 +- python/sglang/srt/layers/radix_attention.py | 8 +- python/sglang/srt/managers/schedule_batch.py | 5 +- .../srt/model_executor/cuda_graph_runner.py | 86 +- .../srt/model_executor/forward_batch_info.py | 215 +++- .../sglang/srt/model_executor/model_runner.py | 25 +- python/sglang/srt/models/deepseek_v2.py | 3 +- python/sglang/srt/models/qwen2_moe.py | 4 - python/sglang/srt/models/qwen3_moe.py | 7 +- .../eagle_draft_cuda_graph_runner.py | 60 +- .../eagle_draft_extend_cuda_graph_runner.py | 73 +- python/sglang/srt/speculative/eagle_utils.py | 68 +- python/sglang/srt/speculative/eagle_worker.py | 103 +- python/sglang/srt/two_batch_overlap.py | 1 + test/srt/test_deepep_small.py | 12 +- test/srt/test_hybrid_dp_ep_tp_mtp.py | 920 ++---------------- 20 files 
changed, 665 insertions(+), 1116 deletions(-) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 130bc53c7ed9..45a1a42093cd 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -545,6 +545,15 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None: else: torch.distributed.all_reduce(input_, group=self.device_group) + def reduce_scatter_tensor( + self, + output: torch.Tensor, + input: torch.Tensor, + ) -> None: + # TODO(ch-wan): support other backends + torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group) + return output + def reduce_scatter( self, output: torch.Tensor, diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py index bddd7891f924..3025d0b118f9 100644 --- a/python/sglang/srt/layers/attention/base_attn_backend.py +++ b/python/sglang/srt/layers/attention/base_attn_backend.py @@ -65,7 +65,9 @@ def forward( **kwargs, ): """Run forward on an attention layer.""" - if forward_batch.forward_mode.is_decode(): + if forward_batch.forward_mode.is_idle(): + return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim) + elif forward_batch.forward_mode.is_decode(): return self.forward_decode( q, k, diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 5e0931ead0b9..aeb8449a17d7 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -24,8 +24,8 @@ tensor_model_parallel_all_reduce, ) from sglang.srt.layers.dp_attention import ( - attn_tp_all_gather, - attn_tp_reduce_scatter, + attn_tp_all_gather_into_tensor, + attn_tp_reduce_scatter_tensor, dp_gather_partial, dp_scatter, get_attention_dp_size, @@ -309,8 +309,8 @@ def _scattered_to_tp_attn_full( forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - 
attn_tp_all_gather( - list(hidden_states.tensor_split(context.attn_tp_size)), + attn_tp_all_gather_into_tensor( + hidden_states, local_hidden_states, ) return hidden_states @@ -400,9 +400,7 @@ def _gather_hidden_states_and_residual( ].clone(), residual, ) - attn_tp_all_gather( - list(residual.tensor_split(context.attn_tp_size)), local_residual - ) + attn_tp_all_gather_into_tensor(residual, local_residual) if context.attn_dp_size != 1: if context.attn_tp_rank == 0: hidden_states += residual @@ -442,9 +440,11 @@ def _scatter_hidden_states_and_residual( *, residual_input_mode, ): - tensor_list = list(hidden_states.tensor_split(context.attn_tp_size)) - hidden_states = tensor_list[context.attn_tp_rank] - attn_tp_reduce_scatter(hidden_states, tensor_list) + input_hidden_states = hidden_states + hidden_states = hidden_states.tensor_split(context.attn_tp_size)[ + context.attn_tp_rank + ] + attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states) if residual_input_mode == ScatterMode.TP_ATTN_FULL: residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank] if hidden_states.shape[0] != 0: @@ -547,8 +547,8 @@ def _gather( forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - attn_tp_all_gather( - list(hidden_states.tensor_split(context.attn_tp_size)), + attn_tp_all_gather_into_tensor( + hidden_states, local_hidden_states, ) return hidden_states, residual diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index ae4041956d9b..55db1333663e 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -3,7 +3,8 @@ import functools import logging from contextlib import contextmanager -from typing import TYPE_CHECKING, List +from enum import IntEnum, auto +from typing import TYPE_CHECKING, List, Tuple import torch import triton @@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_RANK = None +class DPPaddingMode(IntEnum): + + # Padding tokens to max length 
and then gather tokens using `all_gather_into_tensor` + MAX_LEN = auto() + # Padding tokens to sum length and then gather tokens using `all_reduce` + SUM_LEN = auto() + + def is_max_len(self): + return self == DPPaddingMode.MAX_LEN + + def is_sum_len(self): + return self == DPPaddingMode.SUM_LEN + + @classmethod + def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode: + # we choose the mode that minimizes the communication cost + max_len = max(global_num_tokens) + sum_len = sum(global_num_tokens) + if sum_len * 2 > max_len * get_attention_dp_size(): + return cls.MAX_LEN + else: + return cls.SUM_LEN + + @classmethod + def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode: + return cls.MAX_LEN + + def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size): if not enable_dp_attention: return tp_rank, tp_size, 0 @@ -162,7 +191,7 @@ def disable_dp_size(): _ATTN_DP_SIZE = old_dp_size -def get_dp_local_info(forward_batch: ForwardBatch): +def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]: # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here. dp_rank = get_attention_dp_rank() @@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src): memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE) -def _dp_gather( +def _dp_gather_via_all_reduce( global_tokens: torch.Tensor, local_tokens: torch.Tensor, forward_batch: ForwardBatch, @@ -238,13 +267,6 @@ def _dp_gather( local_tokens.untyped_storage() is not global_tokens.untyped_storage() ), "aliasing between global_tokens and local_tokens not allowed" - # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1). - # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the - # actual size of the accepted tokens. 
- if forward_batch.forward_mode.is_draft_extend(): - shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0]) - local_num_tokens = torch.minimum(local_num_tokens, shape_tensor) - memcpy_triton( global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False ) @@ -263,6 +285,38 @@ def _dp_gather( global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens) +def _dp_gather_via_all_gather( + global_tokens: torch.Tensor, + local_tokens: torch.Tensor, + forward_batch: ForwardBatch, + is_partial: bool, +): + if not is_partial: + if get_attention_tp_rank() != 0: + local_tokens.fill_(0) + scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[ + get_attention_tp_rank() + ] + get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens) + get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens) + + +def _dp_gather( + global_tokens: torch.Tensor, + local_tokens: torch.Tensor, + forward_batch: ForwardBatch, + is_partial: bool, +): + if forward_batch.dp_padding_mode.is_max_len(): + _dp_gather_via_all_gather( + global_tokens, local_tokens, forward_batch, is_partial + ) + else: + _dp_gather_via_all_reduce( + global_tokens, local_tokens, forward_batch, is_partial + ) + + def dp_gather_partial( global_tokens: torch.Tensor, local_tokens: torch.Tensor, @@ -296,24 +350,18 @@ def dp_scatter( local_tokens.untyped_storage() is not global_tokens.untyped_storage() ), "aliasing between local_tokens and global_tokens not allowed" - # NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1). - # But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the - # actual size of the accepted tokens. 
- if forward_batch.forward_mode.is_draft_extend(): - shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0]) - local_num_tokens = torch.minimum(local_num_tokens, shape_tensor) - memcpy_triton( local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True ) -def attn_tp_reduce_scatter( - output: torch.Tensor, - input_list: List[torch.Tensor], -): - return get_attention_tp_group().reduce_scatter(output, input_list) +def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor): + return get_attention_tp_group().reduce_scatter_tensor(output, input) + + +def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor): + return get_attention_tp_group().all_gather_into_tensor(output, input) -def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor): - return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list) +def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor): + return get_attention_tp_group().all_gather(input, output_tensor_list=output_list) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 79d38193e6aa..0aee86f68a28 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -27,7 +27,9 @@ tensor_model_parallel_all_gather, ) from sglang.srt.layers.dp_attention import ( + DPPaddingMode, attn_tp_all_gather, + attn_tp_all_gather_into_tensor, dp_gather_replicate, dp_scatter, get_attention_dp_rank, @@ -111,7 +113,8 @@ class LogitsMetadata: # Number of tokens to sample per DP rank global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None - + # The gather mode for DP attention + dp_padding_mode: Optional[DPPaddingMode] = None # for padding padded_static_len: int = -1 @@ -163,12 +166,12 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): 
forward_batch_gathered_buffer=forward_batch.gathered_buffer, global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu, global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu, + dp_padding_mode=DPPaddingMode.SUM_LEN, ) - def compute_dp_attention_metadata(self, hidden_states: torch.Tensor): - if self.global_num_tokens_for_logprob_cpu is None: - # we are capturing cuda graph - return + def compute_dp_attention_metadata(self): + # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend, + # we may use a smaller buffer in draft extend. cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0) dp_rank = get_attention_dp_rank() @@ -179,18 +182,9 @@ def compute_dp_attention_metadata(self, hidden_states: torch.Tensor): else: dp_local_start_pos = cumtokens[dp_rank - 1] dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank] - gathered_buffer = torch.zeros( - ( - sum(self.global_num_tokens_for_logprob_cpu), - hidden_states.shape[1], - ), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) self.dp_local_start_pos = dp_local_start_pos self.dp_local_num_tokens = dp_local_num_tokens - self.gathered_buffer = gathered_buffer class LogitsProcessor(nn.Module): @@ -434,7 +428,7 @@ def _get_logits( guarantee the given hidden_states follow this constraint. 
""" if self.do_tensor_parallel_all_gather_dp_attn: - logits_metadata.compute_dp_attention_metadata(hidden_states) + logits_metadata.compute_dp_attention_metadata() hidden_states, local_hidden_states = ( torch.empty_like(logits_metadata.gathered_buffer), hidden_states, @@ -463,15 +457,31 @@ def _get_logits( if self.do_tensor_parallel_all_gather: if self.use_attn_tp_group: - global_logits = torch.empty( - (self.config.vocab_size, logits.shape[0]), - device=logits.device, - dtype=logits.dtype, - ) - global_logits = global_logits.T - attn_tp_all_gather( - list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits - ) + if self.config.vocab_size % self.attn_tp_size == 0: + global_logits = torch.empty( + ( + self.attn_tp_size, + logits.shape[0], + self.config.vocab_size // self.attn_tp_size, + ), + device=logits.device, + dtype=logits.dtype, + ) + attn_tp_all_gather_into_tensor(global_logits, logits) + global_logits = global_logits.permute(1, 0, 2).reshape( + logits.shape[0], self.config.vocab_size + ) + else: + global_logits = torch.empty( + (self.config.vocab_size, logits.shape[0]), + device=logits.device, + dtype=logits.dtype, + ) + global_logits = global_logits.T + attn_tp_all_gather( + list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), + logits, + ) logits = global_logits else: logits = tensor_model_parallel_all_gather(logits) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 322704ca9f78..8004fc7c9c4e 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -12,14 +12,16 @@ # limitations under the License. 
# ============================================================================== """Radix attention.""" +from __future__ import annotations from enum import Enum -from typing import Optional +from typing import TYPE_CHECKING, Optional from torch import nn -from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +if TYPE_CHECKING: + from sglang.srt.layers.quantization.base_config import QuantizationConfig + from sglang.srt.model_executor.forward_batch_info import ForwardBatch class AttentionType(Enum): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 714af6fba588..ea7cad98be90 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -45,7 +45,6 @@ import triton.language as tl from sglang.global_config import global_config -from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.disaggregation.base import BaseKVSender from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( @@ -68,6 +67,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton if TYPE_CHECKING: + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -1880,7 +1880,7 @@ class ModelWorkerBatch: sampling_info: SamplingBatchInfo # The input Embeds - input_embeds: Optional[torch.tensor] = None + input_embeds: Optional[torch.Tensor] = None # For corss-encoder model token_type_ids: Optional[torch.Tensor] = None @@ -1890,7 +1890,6 @@ class ModelWorkerBatch: spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None # If set, the output of the batch contains the hidden states of the run. 
capture_hidden_mode: CaptureHiddenMode = None - spec_num_draft_tokens: Optional[int] = None hicache_consumer_index: int = 0 # Overlap event diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 520a631c5ecf..eef7fba14734 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -29,9 +29,9 @@ from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture +from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.torchao_utils import save_gemlite_cache -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, @@ -167,8 +167,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): # is very small. We add more values here to make sure we capture the maximum bs. 
capture_bs += [model_runner.req_to_token_pool.size] + mul_base = 1 + if server_args.enable_two_batch_overlap: - capture_bs = [bs for bs in capture_bs if bs % 2 == 0] + mul_base *= 2 + + if require_gathered_buffer(server_args): + mul_base *= get_attention_tp_size() + + capture_bs = [bs for bs in capture_bs if bs % mul_base == 0] if server_args.cuda_graph_max_bs: capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs] @@ -306,20 +313,37 @@ def __init__(self, model_runner: ModelRunner): self.encoder_lens = None if self.require_gathered_buffer: - self.gathered_buffer = torch.zeros( - ( - self.max_num_token, - self.model_runner.model_config.hidden_size, - ), - dtype=self.model_runner.dtype, - ) if self.require_mlp_tp_gather: self.global_num_tokens_gpu = torch.zeros( (self.dp_size,), dtype=torch.int32 ) + self.global_num_tokens_for_logprob_gpu = torch.zeros( + (self.dp_size,), dtype=torch.int32 + ) + self.gathered_buffer = torch.zeros( + ( + self.max_num_token * self.dp_size, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) else: assert self.require_attn_tp_gather self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32) + self.global_num_tokens_for_logprob_gpu = torch.zeros( + (1,), dtype=torch.int32 + ) + self.gathered_buffer = torch.zeros( + ( + self.max_num_token, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) + else: + self.global_num_tokens_gpu = None + self.global_num_tokens_for_logprob_gpu = None + self.gathered_buffer = None self.custom_mask = torch.ones( ( @@ -342,9 +366,9 @@ def __init__(self, model_runner: ModelRunner): def can_run(self, forward_batch: ForwardBatch): if self.require_mlp_tp_gather: cuda_graph_bs = ( - sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() - else sum(forward_batch.global_num_tokens_cpu) + else 
max(forward_batch.global_num_tokens_cpu) ) else: cuda_graph_bs = forward_batch.batch_size @@ -480,16 +504,19 @@ def capture_one_batch_size(self, bs: int, forward: Callable): if self.require_mlp_tp_gather: self.global_num_tokens_gpu.copy_( torch.tensor( - [ - num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) - for i in range(self.dp_size) - ], + [num_tokens] * self.dp_size, dtype=torch.int32, device=input_ids.device, ) ) - global_num_tokens = self.global_num_tokens_gpu - gathered_buffer = self.gathered_buffer[:num_tokens] + self.global_num_tokens_for_logprob_gpu.copy_( + torch.tensor( + [num_tokens] * self.dp_size, + dtype=torch.int32, + device=input_ids.device, + ) + ) + gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size] elif self.require_attn_tp_gather: self.global_num_tokens_gpu.copy_( torch.tensor( @@ -498,10 +525,15 @@ def capture_one_batch_size(self, bs: int, forward: Callable): device=input_ids.device, ) ) - global_num_tokens = self.global_num_tokens_gpu + self.global_num_tokens_for_logprob_gpu.copy_( + torch.tensor( + [num_tokens], + dtype=torch.int32, + device=input_ids.device, + ) + ) gathered_buffer = self.gathered_buffer[:num_tokens] else: - global_num_tokens = None gathered_buffer = None spec_info = self.get_spec_info(num_tokens) @@ -531,7 +563,9 @@ def capture_one_batch_size(self, bs: int, forward: Callable): encoder_lens=encoder_lens, return_logprob=False, positions=positions, - global_num_tokens_gpu=global_num_tokens, + global_num_tokens_gpu=self.global_num_tokens_gpu, + global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu, + dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(), gathered_buffer=gathered_buffer, mrope_positions=mrope_positions, spec_algorithm=self.model_runner.spec_algorithm, @@ -635,12 +669,13 @@ def replay_prepare( # Pad if self.require_mlp_tp_gather: - total_batch_size = ( - sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs + max_num_tokens = 
max(forward_batch.global_num_tokens_cpu) + max_batch_size = ( + max_num_tokens / self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() - else sum(forward_batch.global_num_tokens_cpu) + else max_num_tokens ) - index = bisect.bisect_left(self.capture_bs, total_batch_size) + index = bisect.bisect_left(self.capture_bs, max_batch_size) else: index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] @@ -670,7 +705,8 @@ def replay_prepare( if forward_batch.mrope_positions is not None: self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) if self.require_gathered_buffer: - self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) + self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs) + self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs) if enable_num_token_non_padded(self.model_runner.server_args): self.num_token_non_padded.copy_(forward_batch.num_token_non_padded) if self.enable_two_batch_overlap: diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 6f3ea547477f..d6850aabd8be 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -38,6 +38,11 @@ import triton import triton.language as tl +from sglang.srt.layers.dp_attention import ( + DPPaddingMode, + get_attention_dp_rank, + get_attention_tp_size, +) from sglang.srt.layers.rotary_embedding import MRotaryEmbedding from sglang.srt.utils import ( flatten_nested_list, @@ -48,6 +53,7 @@ if TYPE_CHECKING: from sglang.srt.layers.attention.base_attn_backend import AttentionBackend + from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner @@ -242,7 +248,7 @@ 
class ForwardBatch: lora_paths: Optional[List[str]] = None # For input embeddings - input_embeds: Optional[torch.tensor] = None + input_embeds: Optional[torch.Tensor] = None # For cross-encoder model token_type_ids: Optional[torch.Tensor] = None @@ -261,6 +267,8 @@ class ForwardBatch: # Has to be None when cuda graph is captured. global_num_tokens_for_logprob_cpu: Optional[List[int]] = None global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None + # The padding mode for DP attention + dp_padding_mode: Optional[DPPaddingMode] = None # for extend, local start pos and num tokens is different in logits processor # this will be computed in get_dp_local_info # this will be recomputed in LogitsMetadata.from_forward_batch @@ -286,7 +294,7 @@ class ForwardBatch: # For two-batch overlap tbo_split_seq_index: Optional[int] = None tbo_parent_token_range: Optional[Tuple[int, int]] = None - tbo_children: Optional[List["ForwardBatch"]] = None + tbo_children: Optional[List[ForwardBatch]] = None @classmethod def init_new( @@ -340,20 +348,38 @@ def init_new( len(batch.input_ids), dtype=torch.int32 ).to(device, non_blocking=True) - # For DP attention + # For MLP sync if batch.global_num_tokens is not None: - - spec_num_draft_tokens = ( - batch.spec_num_draft_tokens - if batch.spec_num_draft_tokens is not None - else 1 + from sglang.srt.speculative.eagle_utils import ( + EagleDraftInput, + EagleVerifyInput, ) - global_num_tokens = [ - x * spec_num_draft_tokens for x in batch.global_num_tokens - ] - global_num_tokens_for_logprob = [ - x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob - ] + + assert batch.global_num_tokens_for_logprob is not None + # process global_num_tokens and global_num_tokens_for_logprob + if batch.spec_info is not None: + if isinstance(batch.spec_info, EagleDraftInput): + global_num_tokens = [ + x * batch.spec_info.num_tokens_per_batch + for x in batch.global_num_tokens + ] + global_num_tokens_for_logprob = [ + x * 
batch.spec_info.num_tokens_for_logprob_per_batch + for x in batch.global_num_tokens_for_logprob + ] + else: + assert isinstance(batch.spec_info, EagleVerifyInput) + global_num_tokens = [ + x * batch.spec_info.draft_token_num + for x in batch.global_num_tokens + ] + global_num_tokens_for_logprob = [ + x * batch.spec_info.draft_token_num + for x in batch.global_num_tokens_for_logprob + ] + else: + global_num_tokens = batch.global_num_tokens + global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob ret.global_num_tokens_cpu = global_num_tokens ret.global_num_tokens_gpu = torch.tensor( @@ -365,15 +391,8 @@ def init_new( global_num_tokens_for_logprob, dtype=torch.int64 ).to(device, non_blocking=True) - sum_len = sum(global_num_tokens) - ret.gathered_buffer = torch.zeros( - (sum_len, model_runner.model_config.hidden_size), - dtype=model_runner.dtype, - device=device, - ) - if ret.forward_mode.is_idle(): - ret.positions = torch.empty((0,), device=device) + ret.positions = torch.empty((0,), dtype=torch.int64, device=device) TboForwardBatchPreparer.prepare( ret, is_draft_worker=model_runner.is_draft_worker ) @@ -573,6 +592,158 @@ def prepare_chunked_kv_indices(self, device: torch.device): ) self.prefix_chunk_kv_indices.append(chunk_kv_indices) + def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0): + if value == 0: + return torch.cat( + [tensor, tensor.new_zeros(size - tensor.shape[0], *tensor.shape[1:])], + dim=0, + ) + else: + return torch.cat( + [ + tensor, + tensor.new_full((size - tensor.shape[0], *tensor.shape[1:]), value), + ], + dim=0, + ) + + def prepare_mlp_sync_batch(self, model_runner: ModelRunner): + + from sglang.srt.speculative.eagle_utils import EagleDraftInput + + assert self.global_num_tokens_cpu is not None + assert self.global_num_tokens_for_logprob_cpu is not None + + global_num_tokens = self.global_num_tokens_cpu + sync_group_size = len(global_num_tokens) + attn_tp_size = get_attention_tp_size() + + for i in 
range(sync_group_size): + # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim. + # there is no reduce-scatter in LM logprob, so we do not need to adjust the padded length for logprob + global_num_tokens[i] = ( + (global_num_tokens[i] - 1) // attn_tp_size + 1 + ) * attn_tp_size + + dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens) + self.dp_padding_mode = dp_padding_mode + + if dp_padding_mode.is_max_len(): + # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states, + # where transferred tokens should be padded to the same length. + max_num_tokens = max(global_num_tokens) + global_num_tokens = [max_num_tokens] * sync_group_size + buffer_len = max_num_tokens * sync_group_size + else: + buffer_len = sum(global_num_tokens) + + self.gathered_buffer = torch.zeros( + (buffer_len, model_runner.model_config.hidden_size), + dtype=model_runner.dtype, + device=model_runner.device, + ) + + bs = self.batch_size + if len(global_num_tokens) > 1: + num_tokens = global_num_tokens[get_attention_dp_rank()] + else: + num_tokens = global_num_tokens[0] + + # padding + self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens) + self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs) + + seq_len_fill_value = ( + model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() + ) + self.seq_lens = self._pad_tensor_to_size( + self.seq_lens, bs, value=seq_len_fill_value + ) + if self.seq_lens_cpu is not None: + self.seq_lens_cpu = self._pad_tensor_to_size( + self.seq_lens_cpu, bs, value=seq_len_fill_value + ) + + self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens) + if self.encoder_lens is not None: + self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs) + self.positions = self._pad_tensor_to_size(self.positions, num_tokens) + self.global_num_tokens_cpu = global_num_tokens + 
self.global_num_tokens_gpu = self.global_num_tokens_gpu.new_tensor( + global_num_tokens + ) + + if self.mrope_positions is not None: + self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs) + + if self.extend_seq_lens is not None: + self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs) + + if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput): + spec_info = self.spec_info + self.output_cache_loc_backup = self.out_cache_loc + self.hidden_states_backup = spec_info.hidden_states + if spec_info.topk_p is not None: + spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs) + if spec_info.topk_index is not None: + spec_info.topk_index = self._pad_tensor_to_size( + spec_info.topk_index, bs + ) + if spec_info.accept_length is not None: + spec_info.accept_length = self._pad_tensor_to_size( + spec_info.accept_length, bs + ) + spec_info.hidden_states = self._pad_tensor_to_size( + spec_info.hidden_states, num_tokens + ) + + def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput): + + bs = self.batch_size + + if self.spec_info is not None: + if self.forward_mode.is_decode(): # draft + num_tokens = self.hidden_states_backup.shape[0] + self.positions = self.positions[:num_tokens] + self.seq_lens = self.seq_lens[:bs] + self.req_pool_indices = self.req_pool_indices[:bs] + if self.seq_lens_cpu is not None: + self.seq_lens_cpu = self.seq_lens_cpu[:bs] + logits_output.next_token_logits = logits_output.next_token_logits[ + :num_tokens + ] + logits_output.hidden_states = logits_output.hidden_states[:num_tokens] + elif self.forward_mode.is_target_verify(): # verify + num_tokens = bs * self.spec_info.draft_token_num + logits_output.next_token_logits = logits_output.next_token_logits[ + :num_tokens + ] + logits_output.hidden_states = logits_output.hidden_states[:num_tokens] + elif self.forward_mode.is_draft_extend(): # draft extend + self.spec_info.accept_length = self.spec_info.accept_length[:bs] + 
logits_output.next_token_logits = logits_output.next_token_logits[:bs] + logits_output.hidden_states = logits_output.hidden_states[:bs] + elif self.forward_mode.is_extend() or self.forward_mode.is_idle(): + logits_output.next_token_logits = logits_output.next_token_logits[:bs] + logits_output.hidden_states = logits_output.hidden_states[:bs] + + if hasattr(self, "hidden_states_backup"): + self.spec_info.hidden_states = self.hidden_states_backup + if hasattr(self, "output_cache_loc_backup"): + self.out_cache_loc = self.output_cache_loc_backup + + elif self.forward_mode.is_decode() or self.forward_mode.is_idle(): + logits_output.next_token_logits = logits_output.next_token_logits[:bs] + if logits_output.hidden_states is not None: + logits_output.hidden_states = logits_output.hidden_states[:bs] + elif self.forward_mode.is_extend(): + num_tokens = self.seq_lens_sum + logits_output.next_token_logits = logits_output.next_token_logits[ + :num_tokens + ] + if logits_output.hidden_states is not None: + logits_output.hidden_states = logits_output.hidden_states[:num_tokens] + # Here we suppose the length of each chunk is equal # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256 # num_prefix_chunks = cdiv(1024, 256) = 4 diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index cbb35bf270d3..3d3be71f1b82 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1464,9 +1464,13 @@ def apply_torch_tp(self): tensor_parallel(self.model, device_mesh) def forward_decode( - self, forward_batch: ForwardBatch, pp_proxy_tensors=None + self, + forward_batch: ForwardBatch, + skip_attn_backend_init: bool = False, + pp_proxy_tensors=None, ) -> LogitsProcessorOutput: - self.attn_backend.init_forward_metadata(forward_batch) + if not skip_attn_backend_init: + self.attn_backend.init_forward_metadata(forward_batch) # FIXME: 
add pp_proxy_tensors arg to all models kwargs = {} if self.support_pp: @@ -1578,8 +1582,18 @@ def _forward_raw( skip_attn_backend_init=skip_attn_backend_init, pp_proxy_tensors=pp_proxy_tensors, ) - elif forward_batch.forward_mode.is_decode(): - ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors) + return ret, can_run_cuda_graph + + # For MLP sync + if forward_batch.global_num_tokens_cpu is not None: + forward_batch.prepare_mlp_sync_batch(self) + + if forward_batch.forward_mode.is_decode(): + ret = self.forward_decode( + forward_batch, + skip_attn_backend_init=skip_attn_backend_init, + pp_proxy_tensors=pp_proxy_tensors, + ) elif forward_batch.forward_mode.is_extend(): ret = self.forward_extend( forward_batch, @@ -1597,6 +1611,9 @@ def _forward_raw( else: raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") + if forward_batch.global_num_tokens_cpu is not None: + forward_batch.post_forward_mlp_sync_batch(ret) + return ret, can_run_cuda_graph def _preprocess_logits( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e02d30839007..7c627bc090f6 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -550,9 +550,8 @@ def forward_cpu( def forward_deepep( self, hidden_states: torch.Tensor, forward_batch: ForwardBatch ) -> torch.Tensor: - forward_mode = forward_batch.forward_mode shared_output = None - if is_non_idle_and_non_empty(forward_mode, hidden_states): + if hidden_states.shape[0] > 0: # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) shared_output = self._forward_shared_experts(hidden_states) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index e033424cf023..291678652939 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -43,10 +43,6 @@ ScatterMode, ) from sglang.srt.layers.dp_attention import ( - 
attn_tp_all_gather, - attn_tp_reduce_scatter, - dp_gather_partial, - dp_scatter, get_attention_tp_rank, get_attention_tp_size, get_local_attention_dp_size, diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index c75a384990e8..8eeee74fad1e 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -38,10 +38,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.dp_attention import ( - attn_tp_all_gather, - attn_tp_reduce_scatter, - dp_gather_partial, - dp_scatter, get_attention_tp_rank, get_attention_tp_size, get_local_attention_dp_size, @@ -193,8 +189,7 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward_deepep( self, hidden_states: torch.Tensor, forward_batch: ForwardBatch ) -> torch.Tensor: - forward_mode = forward_batch.forward_mode - if is_non_idle_and_non_empty(forward_mode, hidden_states): + if hidden_states.shape[0] > 0: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) topk_weights, topk_idx, _ = self.topk( diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 6b6c1a777aaa..2c8cdf255e4e 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -5,6 +5,7 @@ import torch +from sglang.srt.layers.dp_attention import DPPaddingMode from sglang.srt.model_executor.cuda_graph_runner import ( CUDA_GRAPH_CAPTURE_FAILED_MSG, CudaGraphRunner, @@ -97,13 +98,6 @@ def __init__(self, eagle_worker: EAGLEWorker): ) if self.require_gathered_buffer: - self.gathered_buffer = torch.zeros( - ( - self.max_num_token, - self.model_runner.model_config.hidden_size, - ), - dtype=self.model_runner.dtype, - ) if self.require_mlp_tp_gather: 
self.global_num_tokens_gpu = torch.zeros( (self.dp_size,), dtype=torch.int32 @@ -111,12 +105,30 @@ def __init__(self, eagle_worker: EAGLEWorker): self.global_num_tokens_for_logprob_gpu = torch.zeros( (self.dp_size,), dtype=torch.int32 ) + self.gathered_buffer = torch.zeros( + ( + self.max_num_token * self.dp_size, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) else: assert self.require_attn_tp_gather self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32) self.global_num_tokens_for_logprob_gpu = torch.zeros( (1,), dtype=torch.int32 ) + self.gathered_buffer = torch.zeros( + ( + self.max_num_token, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) + else: + self.global_num_tokens_gpu = None + self.global_num_tokens_for_logprob_gpu = None + self.gathered_buffer = None # Capture try: @@ -130,9 +142,9 @@ def __init__(self, eagle_worker: EAGLEWorker): def can_run(self, forward_batch: ForwardBatch): if self.require_mlp_tp_gather: cuda_graph_bs = ( - sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() - else sum(forward_batch.global_num_tokens_cpu) + else max(forward_batch.global_num_tokens_cpu) ) else: cuda_graph_bs = forward_batch.batch_size @@ -168,26 +180,20 @@ def capture_one_batch_size(self, num_seqs: int, forward: Callable): if self.require_mlp_tp_gather: self.global_num_tokens_gpu.copy_( torch.tensor( - [ - num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) - for i in range(self.dp_size) - ], + [num_tokens] * self.dp_size, dtype=torch.int32, device=self.input_ids.device, ) ) self.global_num_tokens_for_logprob_gpu.copy_( torch.tensor( - [ - num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) - for i in range(self.dp_size) - ], + [num_tokens] * self.dp_size, dtype=torch.int32, device=self.input_ids.device, ) ) global_num_tokens 
= self.global_num_tokens_gpu - gathered_buffer = self.gathered_buffer[:num_tokens] + gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size] global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu elif self.require_attn_tp_gather: self.global_num_tokens_gpu.copy_( @@ -233,6 +239,7 @@ def capture_one_batch_size(self, num_seqs: int, forward: Callable): return_logprob=False, positions=positions, global_num_tokens_gpu=global_num_tokens, + dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(), gathered_buffer=gathered_buffer, spec_algorithm=self.model_runner.spec_algorithm, spec_info=spec_info, @@ -290,12 +297,13 @@ def replay(self, forward_batch: ForwardBatch): # Pad if self.require_mlp_tp_gather: - total_batch_size = ( - sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + max_num_tokens = max(forward_batch.global_num_tokens_cpu) + max_batch_size = ( + max_num_tokens // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() - else sum(forward_batch.global_num_tokens_cpu) + else max_num_tokens ) - index = bisect.bisect_left(self.capture_bs, total_batch_size) + index = bisect.bisect_left(self.capture_bs, max_batch_size) else: index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] @@ -316,12 +324,10 @@ def replay(self, forward_batch: ForwardBatch): self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) + # TODO(ch-wan): support num_token_non_padded if self.require_gathered_buffer: - self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) - self.global_num_tokens_for_logprob_gpu.copy_( - forward_batch.global_num_tokens_for_logprob_gpu - ) - forward_batch.gathered_buffer = self.gathered_buffer + self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs) + self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs) # Attention backend if bs != raw_bs: diff --git 
a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index 7057c502da0e..f4ed31d7e995 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -5,6 +5,7 @@ import torch +from sglang.srt.layers.dp_attention import DPPaddingMode from sglang.srt.model_executor.cuda_graph_runner import ( CUDA_GRAPH_CAPTURE_FAILED_MSG, CudaGraphRunner, @@ -109,13 +110,6 @@ def __init__(self, eagle_worker: EAGLEWorker): ) if self.require_gathered_buffer: - self.gathered_buffer = torch.zeros( - ( - self.max_num_token, - self.model_runner.model_config.hidden_size, - ), - dtype=self.model_runner.dtype, - ) if self.require_mlp_tp_gather: self.global_num_tokens_gpu = torch.zeros( (self.dp_size,), dtype=torch.int32 @@ -123,12 +117,31 @@ def __init__(self, eagle_worker: EAGLEWorker): self.global_num_tokens_for_logprob_gpu = torch.zeros( (self.dp_size,), dtype=torch.int32 ) + self.gathered_buffer = torch.zeros( + ( + self.max_num_token * self.dp_size, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) else: assert self.require_attn_tp_gather self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32) self.global_num_tokens_for_logprob_gpu = torch.zeros( (1,), dtype=torch.int32 ) + self.gathered_buffer = torch.zeros( + ( + self.max_num_token, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) + else: + self.global_num_tokens_gpu = None + self.global_num_tokens_for_logprob_gpu = None + self.gathered_buffer = None + # Capture try: with model_capture_mode(): @@ -141,9 +154,9 @@ def __init__(self, eagle_worker: EAGLEWorker): def can_run(self, forward_batch: ForwardBatch): if self.require_mlp_tp_gather: cuda_graph_bs = ( - sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + max(forward_batch.global_num_tokens_cpu) 
// self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() - else sum(forward_batch.global_num_tokens_cpu) + else max(forward_batch.global_num_tokens_cpu) ) else: cuda_graph_bs = forward_batch.seq_lens.numel() @@ -180,27 +193,19 @@ def capture_one_batch_size(self, bs: int, forward: Callable): if self.require_mlp_tp_gather: self.global_num_tokens_gpu.copy_( torch.tensor( - [ - num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) - for i in range(self.dp_size) - ], + [num_tokens] * self.dp_size, dtype=torch.int32, device=self.input_ids.device, ) ) self.global_num_tokens_for_logprob_gpu.copy_( torch.tensor( - [ - num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) - for i in range(self.dp_size) - ], + [bs] * self.dp_size, dtype=torch.int32, device=self.input_ids.device, ) ) - global_num_tokens = self.global_num_tokens_gpu - gathered_buffer = self.gathered_buffer[:num_tokens] - global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu + gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size] elif self.require_attn_tp_gather: self.global_num_tokens_gpu.copy_( torch.tensor( @@ -211,18 +216,14 @@ def capture_one_batch_size(self, bs: int, forward: Callable): ) self.global_num_tokens_for_logprob_gpu.copy_( torch.tensor( - [num_tokens], + [bs], dtype=torch.int32, device=self.input_ids.device, ) ) - global_num_tokens = self.global_num_tokens_gpu gathered_buffer = self.gathered_buffer[:num_tokens] - global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu else: - global_num_tokens = None gathered_buffer = None - global_num_tokens_for_logprob = None spec_info = EagleDraftInput( hidden_states=hidden_states, @@ -243,8 +244,9 @@ def capture_one_batch_size(self, bs: int, forward: Callable): seq_lens_sum=seq_lens.sum().item(), return_logprob=False, positions=positions, - global_num_tokens_gpu=global_num_tokens, - global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob, + 
global_num_tokens_gpu=self.global_num_tokens_gpu, + global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu, + dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(), gathered_buffer=gathered_buffer, spec_algorithm=self.model_runner.spec_algorithm, spec_info=spec_info, @@ -306,12 +308,13 @@ def replay(self, forward_batch: ForwardBatch): raw_bs = forward_batch.batch_size num_tokens = forward_batch.input_ids.shape[0] if self.require_mlp_tp_gather: - total_batch_size = ( - sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + max_num_tokens = max(forward_batch.global_num_tokens_cpu) + max_batch_size = ( + max_num_tokens // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() - else sum(forward_batch.global_num_tokens_cpu) + else max_num_tokens ) - index = bisect.bisect_left(self.capture_bs, total_batch_size) + index = bisect.bisect_left(self.capture_bs, max_batch_size) else: index = bisect.bisect_left(self.capture_bs, raw_bs) @@ -334,12 +337,10 @@ def replay(self, forward_batch: ForwardBatch): self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length) self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + # TODO(ch-wan): support num_token_non_padded if self.require_gathered_buffer: - self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) - self.global_num_tokens_for_logprob_gpu.copy_( - forward_batch.global_num_tokens_for_logprob_gpu - ) - forward_batch.gathered_buffer = self.gathered_buffer + self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs) + self.global_num_tokens_for_logprob_gpu.fill_(bs) if forward_batch.seq_lens_cpu is not None: if bs != raw_bs: diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 7f7e21e968c1..aa49e4fc753e 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -71,9 +71,20 @@ class EagleDraftInput: kv_indptr: 
torch.Tensor = None kv_indices: torch.Tensor = None + # Shape info for padding + num_tokens_per_batch: int = -1 + num_tokens_for_logprob_per_batch: int = -1 + + # Inputs for draft extend + # shape: (b,) + seq_lens_for_draft_extend: torch.Tensor = None + req_pool_indices_for_draft_extend: torch.Tensor = None + def prepare_for_extend(self, batch: ScheduleBatch): + if batch.forward_mode.is_idle(): return + # Prefill only generate 1 token. assert len(self.verified_id) == len(batch.seq_lens) @@ -95,7 +106,7 @@ def create_idle_input( capture_hidden_mode: CaptureHiddenMode, ): return cls( - verified_id=None, + verified_id=torch.empty((0,), device=device, dtype=torch.int32), hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype), topk_p=torch.empty((0, topk), device=device, dtype=torch.float32), topk_index=torch.empty((0, topk), device=device, dtype=torch.int64), @@ -109,7 +120,10 @@ def prepare_extend_after_decode( batch: ScheduleBatch, speculative_num_steps: int, ): - batch.forward_mode = ForwardMode.DRAFT_EXTEND + + if batch.forward_mode.is_idle(): + return + batch.input_ids = self.verified_id batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu] batch.extend_num_tokens = sum(batch.extend_lens) @@ -316,7 +330,7 @@ def generate_attn_arg_prefill( def verify( self, batch: ScheduleBatch, - logits_output: torch.Tensor, + logits_output: LogitsProcessorOutput, token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, page_size: int, vocab_mask: Optional[torch.Tensor] = None, # For grammar @@ -599,13 +613,14 @@ def verify( batch.out_cache_loc = tgt_cache_loc batch.seq_lens.add_(accept_length + 1) - draft_input = EagleDraftInput() - draft_input.hidden_states = batch.spec_info.hidden_states[accept_index] - draft_input.verified_id = verified_id - draft_input.accept_length = accept_length - draft_input.accept_length_cpu = accept_length.tolist() - draft_input.seq_lens_for_draft_extend = batch.seq_lens - draft_input.req_pool_indices_for_draft_extend = 
batch.req_pool_indices + draft_input = EagleDraftInput( + hidden_states=batch.spec_info.hidden_states[accept_index], + verified_id=verified_id, + accept_length=accept_length, + accept_length_cpu=accept_length.tolist(), + seq_lens_for_draft_extend=batch.seq_lens, + req_pool_indices_for_draft_extend=batch.req_pool_indices, + ) return EagleVerifyOutput( draft_input=draft_input, @@ -628,7 +643,6 @@ def verify( batch.seq_lens.add_(accept_length + 1) accept_length_cpu = accept_length.tolist() - draft_input = EagleDraftInput() if len(unfinished_accept_index) > 0: unfinished_accept_index = torch.cat(unfinished_accept_index) unfinished_index_device = torch.tensor( @@ -659,18 +673,26 @@ def verify( next_power_of_2(self.draft_token_num), ) - draft_input.hidden_states = batch.spec_info.hidden_states[ - unfinished_accept_index - ] - draft_input.verified_id = predict[unfinished_accept_index] - draft_input.accept_length_cpu = draft_input_accept_length_cpu - draft_input.accept_length = accept_length[unfinished_index_device] - draft_input.seq_lens_for_draft_extend = batch.seq_lens[ - unfinished_index_device - ] - draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[ - unfinished_index_device - ] + draft_input = EagleDraftInput( + hidden_states=batch.spec_info.hidden_states[ + unfinished_accept_index + ], + verified_id=predict[unfinished_accept_index], + accept_length_cpu=draft_input_accept_length_cpu, + accept_length=accept_length[unfinished_index_device], + seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device], + req_pool_indices_for_draft_extend=batch.req_pool_indices[ + unfinished_index_device + ], + ) + else: + draft_input = EagleDraftInput.create_idle_input( + device=batch.device, + hidden_size=batch.model_config.hidden_size, + dtype=batch.model_config.dtype, + topk=self.topk, + capture_hidden_mode=CaptureHiddenMode.LAST, + ) return EagleVerifyOutput( draft_input=draft_input, diff --git a/python/sglang/srt/speculative/eagle_worker.py 
b/python/sglang/srt/speculative/eagle_worker.py index b6a6dace64ae..2d2e23a01066 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -297,7 +297,7 @@ def draft_model_runner(self): def forward_batch_speculative_generation( self, batch: ScheduleBatch - ) -> Tuple[LogitsProcessorOutput, List[int], int, int]: + ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]: """Run speculative decoding forward. NOTE: Many states of batch is modified as you go through. It is not guaranteed that @@ -325,11 +325,16 @@ def forward_batch_speculative_generation( self.verify(batch, spec_info) ) - if self.check_forward_draft_extend_after_decode(batch): - with self.draft_tp_context(self.draft_model_runner.tp_group): - self.forward_draft_extend_after_decode( - batch, - ) + with self.draft_tp_context(self.draft_model_runner.tp_group): + # NOTE: We should use `check_forward_draft_extend_after_decode` + # when DP attention is enabled, but it is slow. Skip it for now. + if ( + self.server_args.enable_dp_attention + or batch.spec_info.verified_id.shape[0] > 0 + ): + # decode is not finished + self.forward_draft_extend_after_decode(batch) + return ( logits_output, verify_output.verified_id, @@ -339,10 +344,7 @@ def forward_batch_speculative_generation( ) def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch): - local_need_forward = ( - batch.spec_info.verified_id is not None - and batch.spec_info.verified_id.shape[0] > 0 - ) + local_need_forward = batch.spec_info.verified_id.shape[0] > 0 if not self.server_args.enable_dp_attention: return local_need_forward @@ -361,7 +363,7 @@ def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch): def forward_target_extend( self, batch: ScheduleBatch - ) -> Tuple[LogitsProcessorOutput, List[int], int]: + ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]: """Run the target extend. 
Args: @@ -376,7 +378,6 @@ def forward_target_extend( # We need the full hidden states to prefill the KV cache of the draft model. model_worker_batch = batch.get_model_worker_batch() model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL - model_worker_batch.spec_num_draft_tokens = 1 logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation( model_worker_batch ) @@ -508,13 +509,15 @@ def draft(self, batch: ScheduleBatch): self._draft_preprocess_decode(batch) spec_info = batch.spec_info + assert isinstance(spec_info, EagleDraftInput) spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + spec_info.num_tokens_per_batch = self.topk + spec_info.num_tokens_for_logprob_per_batch = self.topk batch.return_hidden_states = False # Get forward batch model_worker_batch = batch.get_model_worker_batch() - model_worker_batch.spec_num_draft_tokens = self.topk assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner @@ -527,6 +530,7 @@ def draft(self, batch: ScheduleBatch): forward_batch ) else: + forward_batch.can_run_dp_cuda_graph = False if not forward_batch.forward_mode.is_idle(): # Initialize attention backend self.draft_attn_backend.init_forward_metadata(forward_batch) @@ -578,6 +582,7 @@ def draft(self, batch: ScheduleBatch): def draft_forward(self, forward_batch: ForwardBatch): # Parse args spec_info = forward_batch.spec_info + assert isinstance(spec_info, EagleDraftInput) out_cache_loc = forward_batch.out_cache_loc topk_p, topk_index, hidden_states = ( spec_info.topk_p, @@ -621,8 +626,8 @@ def draft_forward(self, forward_batch: ForwardBatch): spec_info.hidden_states = hidden_states # Run forward - logits_output = self.draft_model_runner.model.forward( - forward_batch.input_ids, forward_batch.positions, forward_batch + logits_output, _ = self.draft_model_runner.forward( + forward_batch, skip_attn_backend_init=True ) 
self._detect_nan_if_needed(logits_output) probs = torch.softmax(logits_output.next_token_logits, dim=-1) @@ -642,10 +647,10 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): else ForwardMode.IDLE ) batch.spec_info = spec_info + model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=spec_info.seq_lens_cpu ) - model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode if batch.has_grammar: @@ -782,8 +787,8 @@ def forward_draft_extend( self, batch: ScheduleBatch, hidden_states: torch.Tensor, - next_token_ids: List[int], - seq_lens_cpu: torch.Tensor, + next_token_ids: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], ): """Run draft model extend. This API modifies the states of the batch. @@ -795,6 +800,8 @@ def forward_draft_extend( batch.spec_info = EagleDraftInput( hidden_states=hidden_states, verified_id=next_token_ids, + num_tokens_per_batch=1, + num_tokens_for_logprob_per_batch=1, ) batch.return_hidden_states = False batch.spec_info.prepare_for_extend(batch) @@ -802,7 +809,6 @@ def forward_draft_extend( model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=seq_lens_cpu ) - model_worker_batch.spec_num_draft_tokens = 1 forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -814,37 +820,45 @@ def forward_draft_extend( self.capture_for_decode(logits_output, forward_batch.spec_info) def forward_draft_extend_after_decode(self, batch: ScheduleBatch): + assert isinstance(batch.spec_info, EagleDraftInput) # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() req_pool_indices_backup = batch.req_pool_indices accept_length_backup = batch.spec_info.accept_length return_logprob_backup = batch.return_logprob + input_is_idle = batch.forward_mode.is_idle() - if not input_is_idle: - # Prepare metadata - if batch.spec_info.verified_id is not None: - 
batch.spec_info.prepare_extend_after_decode( - batch, - self.speculative_num_steps, - ) - else: - batch = batch.copy() - batch.prepare_for_idle() - hidden_size = ( - self.model_config.hidden_size * 3 - if self.speculative_algorithm.is_eagle3() - else self.model_config.hidden_size - ) - batch.spec_info = EagleDraftInput.create_idle_input( - device=self.device, - hidden_size=hidden_size, - dtype=self.model_config.dtype, - topk=self.topk, - capture_hidden_mode=CaptureHiddenMode.LAST, - ) + + if not input_is_idle and batch.spec_info.verified_id.numel() == 0: + batch = batch.copy() + batch.prepare_for_idle() + hidden_size = ( + self.model_config.hidden_size * 3 + if self.speculative_algorithm.is_eagle3() + else self.model_config.hidden_size + ) + batch.spec_info = EagleDraftInput.create_idle_input( + device=self.device, + hidden_size=hidden_size, + dtype=self.model_config.dtype, + topk=self.topk, + capture_hidden_mode=CaptureHiddenMode.LAST, + ) + + batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1 + batch.spec_info.num_tokens_for_logprob_per_batch = 1 + batch.spec_info.prepare_extend_after_decode( + batch, + self.speculative_num_steps, + ) + batch.forward_mode = ( + ForwardMode.DRAFT_EXTEND + if not batch.forward_mode.is_idle() + else ForwardMode.IDLE + ) + batch.return_hidden_states = False model_worker_batch = batch.get_model_worker_batch() - model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1 assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner @@ -869,12 +883,13 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch): ) forward_batch.spec_info.hidden_states = logits_output.hidden_states else: + forward_batch.can_run_dp_cuda_graph = False if not forward_batch.forward_mode.is_idle(): self.draft_model_runner.attn_backend.init_forward_metadata( forward_batch ) - logits_output = 
self.draft_model_runner.model.forward( - forward_batch.input_ids, forward_batch.positions, forward_batch + logits_output, _ = self.draft_model_runner.forward( + forward_batch, skip_attn_backend_init=True ) self.capture_for_decode(logits_output, forward_batch.spec_info) diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index 74bc1ba8572e..e802a7254d40 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -545,6 +545,7 @@ def filter_batch( tbo_children=None, global_num_tokens_gpu=None, global_num_tokens_cpu=None, + dp_padding_mode=None, gathered_buffer=gathered_buffer, global_num_tokens_for_logprob_gpu=None, global_num_tokens_for_logprob_cpu=None, diff --git a/test/srt/test_deepep_small.py b/test/srt/test_deepep_small.py index e26017ade608..0f6ccb9553b4 100644 --- a/test/srt/test_deepep_small.py +++ b/test/srt/test_deepep_small.py @@ -35,7 +35,7 @@ def setUpClass(cls): "--cuda-graph-max-bs", "128", "--max-running-requests", - "128", + "512", "--mem-fraction-static", "0.5", ], @@ -81,7 +81,7 @@ def setUpClass(cls): "--cuda-graph-max-bs", "128", "--max-running-requests", - "128", + "256", ], ) @@ -170,7 +170,7 @@ def setUpClass(cls): "--cuda-graph-max-bs", "32", "--max-running-requests", - "128", + "512", ], ) @@ -217,7 +217,7 @@ def setUpClass(cls): "--cuda-graph-max-bs", "128", "--max-running-requests", - "128", + "512", ], ) @@ -273,7 +273,7 @@ def setUpClass(cls): "--cuda-graph-max-bs", "32", "--max-running-requests", - "32", + "64", ], ) @@ -343,7 +343,7 @@ def setUpClass(cls): "--cuda-graph-max-bs", "32", "--max-running-requests", - "32", + "128", ], ) diff --git a/test/srt/test_hybrid_dp_ep_tp_mtp.py b/test/srt/test_hybrid_dp_ep_tp_mtp.py index a3d44a67adcb..74363649a1f1 100644 --- a/test/srt/test_hybrid_dp_ep_tp_mtp.py +++ b/test/srt/test_hybrid_dp_ep_tp_mtp.py @@ -16,7 +16,7 @@ ) -class Test0(CustomTestCase): +class Test00(CustomTestCase): @classmethod def setUpClass(cls): 
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -47,23 +47,10 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test1(CustomTestCase): +class Test01(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -97,23 +84,10 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test2(CustomTestCase): +class Test02(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -147,23 +121,10 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) + self.assertGreater(metrics["score"], 0.48) - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) - -class Test3(CustomTestCase): +class Test03(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -196,23 +157,10 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - 
model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test4(CustomTestCase): +class Test04(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -248,23 +196,10 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) + self.assertGreater(metrics["score"], 0.48) - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) - - -class Test5(CustomTestCase): +class Test05(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -300,23 +235,10 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test6(CustomTestCase): +class Test06(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -351,23 +273,10 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) + self.assertGreater(metrics["score"], 0.48) - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) - - -class Test7(CustomTestCase): 
+class Test07(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -402,23 +311,10 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) -class Test8(CustomTestCase): +class Test08(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -455,23 +351,10 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) + self.assertGreater(metrics["score"], 0.48) - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) - - -class Test9(CustomTestCase): +class Test09(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -508,20 +391,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test10(CustomTestCase): @@ -560,20 +430,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - 
num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test11(CustomTestCase): @@ -615,20 +472,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test12(CustomTestCase): @@ -670,20 +514,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test13(CustomTestCase): @@ -724,20 +555,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test14(CustomTestCase): @@ -781,20 +599,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - 
self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test15(CustomTestCase): @@ -838,20 +643,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test16(CustomTestCase): @@ -894,20 +686,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test17(CustomTestCase): @@ -950,20 +729,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test18(CustomTestCase): @@ -1008,20 +774,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class 
Test19(CustomTestCase): @@ -1066,20 +819,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test20(CustomTestCase): @@ -1114,20 +854,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test21(CustomTestCase): @@ -1165,20 +892,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test22(CustomTestCase): @@ -1216,20 +930,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test23(CustomTestCase): @@ -1266,20 +967,7 @@ def test_mmlu(self): metrics = run_eval(args) 
print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test24(CustomTestCase): @@ -1319,20 +1007,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test25(CustomTestCase): @@ -1372,20 +1047,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test26(CustomTestCase): @@ -1424,20 +1086,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test27(CustomTestCase): @@ -1476,20 +1125,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = 
SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test28(CustomTestCase): @@ -1530,20 +1166,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test29(CustomTestCase): @@ -1584,20 +1207,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test30(CustomTestCase): @@ -1641,20 +1251,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test31(CustomTestCase): @@ -1701,20 +1298,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, 
- num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test32(CustomTestCase): @@ -1761,20 +1345,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test33(CustomTestCase): @@ -1820,20 +1391,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test34(CustomTestCase): @@ -1882,20 +1440,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test35(CustomTestCase): @@ -1944,20 +1489,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - 
self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test36(CustomTestCase): @@ -2005,20 +1537,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test37(CustomTestCase): @@ -2066,20 +1585,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test38(CustomTestCase): @@ -2129,20 +1635,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test39(CustomTestCase): @@ -2192,20 +1685,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class 
Test40(CustomTestCase): @@ -2256,20 +1736,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test41(CustomTestCase): @@ -2323,20 +1790,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test42(CustomTestCase): @@ -2390,20 +1844,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test43(CustomTestCase): @@ -2456,20 +1897,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test44(CustomTestCase): @@ -2525,20 +1953,7 @@ def test_mmlu(self): metrics = run_eval(args) 
print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test45(CustomTestCase): @@ -2594,20 +2009,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test46(CustomTestCase): @@ -2662,20 +2064,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test47(CustomTestCase): @@ -2730,20 +2119,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test48(CustomTestCase): @@ -2800,20 +2176,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = 
SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test49(CustomTestCase): @@ -2870,20 +2233,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test50(CustomTestCase): @@ -2928,20 +2278,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test51(CustomTestCase): @@ -2989,20 +2326,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test52(CustomTestCase): @@ -3050,20 +2374,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, 
- num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test53(CustomTestCase): @@ -3110,20 +2421,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test54(CustomTestCase): @@ -3173,20 +2471,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test55(CustomTestCase): @@ -3236,20 +2521,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test56(CustomTestCase): @@ -3298,20 +2570,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - 
self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test57(CustomTestCase): @@ -3360,20 +2619,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test58(CustomTestCase): @@ -3424,20 +2670,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) class Test59(CustomTestCase): @@ -3488,20 +2721,7 @@ def test_mmlu(self): metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.5) - - def test_mgsm_en(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mgsm_en", - num_examples=None, - num_threads=1024, - ) - - metrics = run_eval(args) - print(f"{metrics=}") - self.assertGreater(metrics["score"], 0.8) + self.assertGreater(metrics["score"], 0.48) if __name__ == "__main__": From 7ad6b766c589cc51f4716b1d2052d66ac1a135fb Mon Sep 17 00:00:00 2001 From: Ying Wang <83981870+ynwang007@users.noreply.github.com> Date: Thu, 24 Jul 2025 23:11:32 -0700 Subject: [PATCH 125/396] fix: Fix failed functional tests https://github.com/meta-llama/llama-stack-evals (#8266) --- .../sglang/srt/entrypoints/openai/serving_chat.py | 14 ++++++++++++++ python/sglang/srt/utils.py | 10 +++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git 
a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 9889cb2edd66..ca090e06074f 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -55,6 +55,20 @@ def __init__( def _request_id_prefix(self) -> str: return "chatcmpl-" + def _validate_request(self, request: ChatCompletionRequest) -> Optional[str]: + """Validate that the input is valid.""" + if not request.messages: + return "Messages cannot be empty." + + if ( + isinstance(request.tool_choice, str) + and request.tool_choice.lower() == "required" + and not request.tools + ): + return "Tools cannot be empty if tool choice is set to required." + + return None + def _convert_to_internal_request( self, request: ChatCompletionRequest, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 23960a8c1123..01e54392ac65 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -744,9 +744,13 @@ def load_image( image = Image.open(BytesIO(image_file)) elif image_file.startswith("http://") or image_file.startswith("https://"): timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) - response = requests.get(image_file, stream=True, timeout=timeout).raw - image = Image.open(response) - response.close() + response = requests.get(image_file, stream=True, timeout=timeout) + try: + response.raise_for_status() + image = Image.open(response.raw) + image.load() # Force loading to avoid issues after closing the stream + finally: + response.close() elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")): image = Image.open(image_file) elif image_file.startswith("data:"): From af4b9bae95cc992712980bf83d1dce5f3ed33023 Mon Sep 17 00:00:00 2001 From: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com> Date: Thu, 24 Jul 2025 23:44:28 -0700 Subject: [PATCH 126/396] [AMD] Add silu_and_mul, gelu_and_mul, gelu_tanh_and_mul, and gelu_quick kernels for AMD GPUs 
(#7135) Co-authored-by: yiakwy-xpu-ml-framework-team <961186938@qq.com> Co-authored-by: HAI --- python/sglang/srt/layers/activation.py | 14 +- python/sglang/test/test_activation.py | 51 +++++- sgl-kernel/benchmark/bench_activation.py | 153 +++++++++++++++++ sgl-kernel/csrc/common_extension.cc | 6 +- sgl-kernel/csrc/elementwise/activation.cu | 128 +++++++++++--- sgl-kernel/csrc/torch_extension_rocm.cc | 14 ++ sgl-kernel/include/hip_act_and_mul.cuh | 87 ++++++++++ sgl-kernel/include/hip_math_def.h | 94 +++++++++++ sgl-kernel/include/hip_vec_dtypes.h | 101 +++++++++++ sgl-kernel/include/impl/hip_vec_bf16_impl.h | 177 ++++++++++++++++++++ sgl-kernel/include/impl/hip_vec_fp32_impl.h | 129 ++++++++++++++ sgl-kernel/include/impl/hip_vec_half_impl.h | 172 +++++++++++++++++++ sgl-kernel/include/sgl_kernel_ops.h | 10 +- sgl-kernel/include/utils.h | 110 +++++++++--- sgl-kernel/python/sgl_kernel/__init__.py | 4 + sgl-kernel/python/sgl_kernel/elementwise.py | 30 +++- sgl-kernel/setup_rocm.py | 5 +- 17 files changed, 1225 insertions(+), 60 deletions(-) create mode 100644 sgl-kernel/benchmark/bench_activation.py create mode 100644 sgl-kernel/include/hip_act_and_mul.cuh create mode 100644 sgl-kernel/include/hip_math_def.h create mode 100644 sgl-kernel/include/hip_vec_dtypes.h create mode 100644 sgl-kernel/include/impl/hip_vec_bf16_impl.h create mode 100644 sgl-kernel/include/impl/hip_vec_fp32_impl.h create mode 100644 sgl-kernel/include/impl/hip_vec_half_impl.h diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 63e9fcdd3cc9..15c2ba077272 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -33,6 +33,7 @@ cpu_has_amx_support, is_cpu, is_cuda, + is_hip, is_npu, set_weight_attrs, ) @@ -42,9 +43,12 @@ _is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() +_is_hip = is_hip() if _is_cuda: from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul +elif 
_is_hip: + from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul if is_npu(): import torch_npu @@ -126,9 +130,13 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor: return x * torch.sigmoid(1.702 * x) def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: - # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel return self.forward_native(x) + def forward_hip(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty(x.shape, dtype=x.dtype, device=x.device) + gelu_quick(x, out) + return out + class ScaledActivation(nn.Module): """An activation function with post-scale parameters. @@ -222,8 +230,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): return nn.Identity() -if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)): +if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip): logger.info( - "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries." + "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries." 
) from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul diff --git a/python/sglang/test/test_activation.py b/python/sglang/test/test_activation.py index 38366e92be78..dd5c668cfce0 100644 --- a/python/sglang/test/test_activation.py +++ b/python/sglang/test/test_activation.py @@ -3,9 +3,12 @@ import torch -from sglang.srt.layers.activation import GeluAndMul +from sglang.srt.layers.activation import GeluAndMul, QuickGELU +from sglang.srt.utils import is_hip from sglang.test.test_utils import CustomTestCase +_is_hip = is_hip() + class TestGeluAndMul(CustomTestCase): DTYPES = [torch.half, torch.bfloat16] @@ -52,5 +55,51 @@ def test_gelu_and_mul(self): self._run_gelu_and_mul_test(*params) +class TestQuickGELU(CustomTestCase): + DTYPES = [torch.half, torch.bfloat16] + NUM_TOKENS = [7, 83, 2048] # batch = sequence length + DIMS = [512, 4096, 5120, 13824] # all multiples of 16 bytes + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _run_gelu_quick_test(self, n_tok: int, dim: int, dtype: torch.dtype, seed: int): + torch.manual_seed(seed) + + layer = QuickGELU().to(dtype=dtype) + + x = torch.randn(n_tok, dim, dtype=dtype, device="cuda") + + with torch.inference_mode(): + ref = layer.forward_native(x) # x * sigmoid(1.702 * x), fp32 math + if _is_hip: + out = layer.forward_hip(x) # 128-bit vectorised kernel from sgl-kernel + else: + out = layer.forward_cuda(x) + + tol = 1e-2 if dtype is torch.bfloat16 else 1e-3 + self.assertTrue( + torch.allclose(out, ref, atol=tol, rtol=tol), + msg=f"Mismatch @ B={n_tok}, D={dim}, dtype={dtype}", + ) + print(f"Match @ B={n_tok}, D={dim}, dtype={dtype}") + + def test_quick_gelu(self): + for params in itertools.product( + self.NUM_TOKENS, self.DIMS, self.DTYPES, self.SEEDS + ): + with self.subTest( + num_tokens=params[0], + dim=params[1], + dtype=params[2], + seed=params[3], + ): + 
self._run_gelu_quick_test(*params) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/sgl-kernel/benchmark/bench_activation.py b/sgl-kernel/benchmark/bench_activation.py new file mode 100644 index 000000000000..cfea789158b8 --- /dev/null +++ b/sgl-kernel/benchmark/bench_activation.py @@ -0,0 +1,153 @@ +# Benchmarks SGLang kernels versus vLLM across +# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up. +import argparse +import itertools +import re +from typing import List, Tuple + +import sgl_kernel +import torch +import torch.nn.functional as F +import triton +import triton.testing +from sgl_kernel import gelu_quick # activation-only kernel +from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul +from vllm import _custom_ops as vllm_ops + +if not hasattr(vllm_ops, "silu_and_mul"): + vllm_ops = torch.ops._C + + +def str2int_list(arg: str) -> List[int]: + if arg in ("", None): + return [] + if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None: + raise argparse.ArgumentTypeError(f"Bad int list: {arg}") + return [int(x) for x in arg.split(",")] + + +def calculate_diff( + kernel: str, dtype: torch.dtype, batch_size: int, seq_len: int, dim: int +) -> bool: + """Compare vLLM with SGLang for one shape.""" + device = torch.device("cuda") + + # activation-only quick GELU + if kernel == "gelu_quick": + x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device) + ref_out = torch.zeros_like(x) + getattr(vllm_ops, kernel)(ref_out, x) + test_out = getattr(sgl_kernel, kernel)(x) + # fused activation x mul kernels + else: + x = torch.randn(batch_size, seq_len, 2 * dim, dtype=dtype, device=device) + ref_out = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device) + getattr(vllm_ops, kernel)(ref_out, x) + test_out = getattr(sgl_kernel, kernel)(x) + + ok = torch.allclose(ref_out, test_out, rtol=1e-3, atol=1e-5) + tag = "✅ match" if ok else "❌ mismatch" + print( + f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | 
" + f"L={seq_len:3d} | D={dim:5d}] {tag}" + ) + return ok + + +kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"] +dtypes = [torch.float16, torch.bfloat16] + + +def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]: + return list(itertools.product(kernels, dtypes, bsizes, slens, dims_)) + + +default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16 +default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64 +default_dims = [2**i for i in range(7, 15)] # 128...16384 + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["kernel", "dtype", "batch_size", "seq_len", "dim"], + x_vals=[], + line_arg="provider", + line_vals=["vllm", "sglang", "speedup"], + line_names=["vLLM", "SGL Kernel", "Speed-up (x)"], + styles=[("blue", "-"), ("green", "-"), ("red", "--")], + ylabel="µs (median) or × (speed-up)", + plot_name="activation-performance", + args={}, + ) +) +def benchmark(kernel, dtype, batch_size, seq_len, dim, provider): + device = torch.device("cuda") + in_mult = 1 if kernel == "gelu_quick" else 2 + x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device) + y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device) + + vllm_kernel = getattr(vllm_ops, kernel) + sglang_kernel = getattr(sgl_kernel, kernel) + + def baseline(): + tmp = y0.clone() + vllm_kernel(tmp, x) + return tmp + + def sglang(): + return sglang_kernel(x) + + # one-time correctness check + if provider == "vllm" and not calculate_diff( + kernel, dtype, batch_size, seq_len, dim + ): + raise ValueError("Mismatch – abort benchmark") + + # timing helper + def timed(fn): + for _ in range(5): + fn() + torch.cuda.synchronize() + ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8]) + return 1000 * ms, 1000 * qmax, 1000 * qmin + + if provider == "vllm": + return timed(baseline) + if provider == "sglang": + return timed(sglang) + + # provider == "speedup" + t_ref, _, _ = 
timed(baseline) + t_sgl, _, _ = timed(sglang) + spd = t_ref / t_sgl + return (spd, spd, spd) + + +if __name__ == "__main__": + p = argparse.ArgumentParser("Activation kernel benchmark") + p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes) + p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens) + p.add_argument("--dims", type=str2int_list, default=default_dims) + p.add_argument("--verify_only", action="store_true") + args = p.parse_args() + + # coerce lists + if isinstance(args.batch_sizes, str): + args.batch_sizes = str2int_list(args.batch_sizes) + if isinstance(args.seq_lens, str): + args.seq_lens = str2int_list(args.seq_lens) + if isinstance(args.dims, str): + args.dims = str2int_list(args.dims) + + # patch perf_report grid + benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.dims) + if hasattr(benchmark, "benchmarks"): + benchmark.benchmarks.x_vals = benchmark_grid + else: + benchmark.benchmark.x_vals = benchmark_grid + + if args.verify_only: + ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0]) + print("✅ sanity pass" if ok else "❌ mismatch") + else: + benchmark.run(print_data=True) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 20b9a804872d..623fbefb514b 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -78,13 +78,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()"); m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); - m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); - m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.def("gelu_tanh_and_mul(Tensor! 
out, Tensor input) -> ()"); m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); - m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); m.def( diff --git a/sgl-kernel/csrc/elementwise/activation.cu b/sgl-kernel/csrc/elementwise/activation.cu index 242281fd9ddc..20b889530146 100644 --- a/sgl-kernel/csrc/elementwise/activation.cu +++ b/sgl-kernel/csrc/elementwise/activation.cu @@ -13,70 +13,158 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#include +#include +#include + +#ifndef USE_ROCM + #include -#include "pytorch_extension_utils.h" +#include "utils.h" + +#else +#include "hip_act_and_mul.cuh" +#endif + +// Adapted from flashinfer activation +// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44 + +namespace detail { + +template +__device__ __forceinline__ float to_f32(const T& x) { +#if USE_ROCM + return castToFloat(x); +#else + return static_cast(x); +#endif +} + +template +__device__ __forceinline__ T from_f32(float f32) { +#if USE_ROCM + return castFromFloat(f32); +#else + return static_cast(f32); +#endif +} -using namespace flashinfer; +} // namespace detail -__device__ __forceinline__ float silu(const float& val) { - return val / (1.0f + __expf(-val)); +template +__device__ __forceinline__ T silu(const T& x) { + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val / (1.0f + expf(-f32_val))); } -__device__ __forceinline__ float gelu(const float& val) { +template +__device__ __forceinline__ T gelu(const T& x) { constexpr float kAlpha = M_SQRT1_2; - return val * 0.5f * (1.0f + ::erf(val * kAlpha)); + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha)))); +} + +// gelu_quick(x) = x * torch.sigmoid(1.702 * x) +template 
+__device__ __forceinline__ T gelu_quick_act(const T& x) { + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val / (1.0f + expf(-f32_val * 1.702f))); } -__device__ __forceinline__ float gelu_tanh(const float& val) { - const float cdf = 0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val)))); - return val * cdf; +template +__device__ __forceinline__ T gelu_tanh(const T& x) { + constexpr float kAlpha = 0.044715f; + constexpr float kBeta = 0.7978845608028654f; + float f32_val = detail::to_f32(x); + const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val)))); + return detail::from_f32(f32_val * cdf); } -void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) { +void silu_and_mul(at::Tensor& out, at::Tensor& input) { int d = input.size(-1) / 2; int64_t num_tokens = input.numel() / input.size(-1); dim3 grid(num_tokens); - cudaStream_t stream = reinterpret_cast(cuda_stream); - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); dim3 block(std::min(d / vec_size, 1024U)); +#if USE_ROCM + sgl_hip::activation::act_and_mul_kernel + <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#else flashinfer::activation::act_and_mul_kernel <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); - +#endif return true; }); } -void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) { +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) { int d = input.size(-1) / 2; int64_t num_tokens = input.numel() / input.size(-1); dim3 grid(num_tokens); - cudaStream_t stream = reinterpret_cast(cuda_stream); - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), 
c_type, [&] { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); dim3 block(std::min(d / vec_size, 1024U)); +#if USE_ROCM + sgl_hip::activation::act_and_mul_kernel + <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#else flashinfer::activation::act_and_mul_kernel <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); - +#endif return true; }); } -void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream) { +void gelu_and_mul(at::Tensor& out, at::Tensor& input) { int d = input.size(-1) / 2; int64_t num_tokens = input.numel() / input.size(-1); dim3 grid(num_tokens); - cudaStream_t stream = reinterpret_cast(cuda_stream); - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); dim3 block(std::min(d / vec_size, 1024U)); +#if USE_ROCM + sgl_hip::activation::act_and_mul_kernel + <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#else flashinfer::activation::act_and_mul_kernel <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#endif + + return true; + }); +} + +#if USE_ROCM +void gelu_quick(at::Tensor& out, const at::Tensor& input) { + int d = input.size(-1); + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(std::min(d 
/ vec_size, 1024U)); + sgl_hip::activation::act_only_kernel + <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); return true; }); } +#endif diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc index 46a50ca6b969..9010d0b260f0 100644 --- a/sgl-kernel/csrc/torch_extension_rocm.cc +++ b/sgl-kernel/csrc/torch_extension_rocm.cc @@ -19,6 +19,20 @@ limitations under the License. #include "sgl_kernel_ops.h" TORCH_LIBRARY_EXPAND(sgl_kernel, m) { + /* + * From csrc/activation + */ + m.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); + m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + + m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); + m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); + + m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); + m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + + m.def("gelu_quick(Tensor! out, Tensor input) -> ()"); + m.impl("gelu_quick", torch::kCUDA, &gelu_quick); /* * From csrc/allreduce */ diff --git a/sgl-kernel/include/hip_act_and_mul.cuh b/sgl-kernel/include/hip_act_and_mul.cuh new file mode 100644 index 000000000000..ddb1b702d92d --- /dev/null +++ b/sgl-kernel/include/hip_act_and_mul.cuh @@ -0,0 +1,87 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include "utils.h" + +#define kBitsToLoad 128 +#define kBytesToLoad (kBitsToLoad / 8) + +// Adapted from +// [flashinfer::activation::act_and_mul_kernel](https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/include/flashinfer/activation.cuh#L29) + +namespace sgl_hip { +namespace activation { + +template +__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) { + constexpr uint32_t vec_size = kBytesToLoad / sizeof(T); + const int64_t token_idx = blockIdx.x; + const int64_t thread_idx = threadIdx.x; + const int64_t stride = blockDim.x; + const int64_t offset = token_idx * 2 * d; + +#pragma unroll 1 + for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) { + sgl_hip::vec_t x_vec, y_vec, out_vec; + x_vec.cast_load(input + offset + idx * vec_size); + y_vec.cast_load(input + offset + d + idx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + out_vec[i] = Activation(x_vec[i]) * y_vec[i]; + } + out_vec.cast_store(out + token_idx * d + idx * vec_size); + } + + const int64_t remaining_offset = d - d % (stride * vec_size); + // process the remaining elements +#pragma unroll 1 + for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) { + T x = input[offset + remaining_offset + idx], y = input[offset + remaining_offset + d + idx]; + out[token_idx * d + remaining_offset + idx] = Activation(x) * y; + } +} + +template +__global__ void act_only_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) { + constexpr uint32_t vec_size = kBytesToLoad / sizeof(T); + const int64_t token_idx = blockIdx.x; + const int64_t thread_idx = threadIdx.x; + const int64_t stride = blockDim.x; + const int64_t offset = token_idx * d; + +#pragma unroll 1 + for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) { + sgl_hip::vec_t x_vec, y_vec, 
out_vec; + x_vec.cast_load(input + offset + idx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + out_vec[i] = Activation(x_vec[i]); + } + out_vec.cast_store(out + token_idx * d + idx * vec_size); + } + + const int64_t remaining_offset = d - d % (stride * vec_size); + // process the remaining elements +#pragma unroll 1 + for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) { + T x = input[offset + remaining_offset + idx]; + out[token_idx * d + remaining_offset + idx] = Activation(x); + } +} + +} // namespace activation +} // namespace sgl_hip diff --git a/sgl-kernel/include/hip_math_def.h b/sgl-kernel/include/hip_math_def.h new file mode 100644 index 000000000000..21cc67456ee7 --- /dev/null +++ b/sgl-kernel/include/hip_math_def.h @@ -0,0 +1,94 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#if defined(__HIP_PLATFORM_AMD__) + +#include +#include +#include + +// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491) + +namespace amdgpu { + +template +__forceinline__ __device__ T shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize); + +template +__forceinline__ __device__ destDtype cast(srcDtype val); + +// specialization +template <> +__forceinline__ __device__ float shfl_xor_sync(unsigned mask, float var, int laneMask, int width) { + return __shfl_xor(var, laneMask, width); +} + +template <> +__forceinline__ __device__ int shfl_xor_sync(unsigned mask, int var, int laneMask, int width) { + return __shfl_xor(var, laneMask, width); +} + +template <> +__forceinline__ __device__ float cast(float val) { + return val; +} + +template <> +__forceinline__ __device__ float cast(__half val) { + return __half2float(val); +} + +template <> +__forceinline__ __device__ float cast(__hip_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__forceinline__ __device__ __half cast(float fval) { + return __float2half(fval); +} + +template <> +__forceinline__ __device__ __hip_bfloat16 cast(float fval) { + return __float2bfloat16(fval); +} + +} // namespace amdgpu + +template +__forceinline__ __device__ T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize) { + return amdgpu::shfl_xor_sync(mask, var, laneMask, width); +} + +template +__device__ __forceinline__ float castToFloat(srcDtype val) { + return amdgpu::cast(val); +} + +template +__device__ __forceinline__ dstDtype castFromFloat(float val) { + return amdgpu::cast(val); +} + +// operator overload to support flashinfer +__host__ __device__ __forceinline__ __half operator*(const __half& x, const __half& y) { + __half h_x = x; + __half h_y = y; + return __hmul(h_x, h_y); +} + +#endif diff --git a/sgl-kernel/include/hip_vec_dtypes.h 
b/sgl-kernel/include/hip_vec_dtypes.h new file mode 100644 index 000000000000..a68a6986e027 --- /dev/null +++ b/sgl-kernel/include/hip_vec_dtypes.h @@ -0,0 +1,101 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#if USE_ROCM + +#include +#include +#include + +// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)d + +#define SGL_HIP_INLINE inline __attribute__((always_inline)) __device__ + +namespace sgl_hip { + +template +struct vec_t; + +template +SGL_HIP_INLINE void cast_load_impl(vec_t& dst, const srcDtype* src); + +template +SGL_HIP_INLINE void cast_store_impl(dstDtype* dst_ptr, const vec_t& src); + +template +struct vec_t { + SGL_HIP_INLINE float_t& operator[](size_t i); + SGL_HIP_INLINE const float_t& operator[](size_t i) const; + SGL_HIP_INLINE float_t* ptr(); + + SGL_HIP_INLINE void load(const float_t* ptr); + SGL_HIP_INLINE void store(float_t* ptr) const; + + template + SGL_HIP_INLINE void cast_from(const vec_t& src); + template + SGL_HIP_INLINE void cast_load(const T* ptr); + template + SGL_HIP_INLINE void cast_store(T* ptr) const; +}; + +} // namespace sgl_hip + +// **** impl ***** + +namespace sgl_hip { + +template +SGL_HIP_INLINE void cast_load_impl(vec_t& dst, const srcDtype* src_ptr) { + if constexpr (std::is_same::value) { + dst.load(src_ptr); + } else { + vec_t tmp; + 
tmp.load(src_ptr); + dst.cast_from(tmp); + } +} + +template +SGL_HIP_INLINE void cast_store_impl(dstDtype* dst_ptr, const vec_t& src) { + if constexpr (std::is_same::value) { + src.store(dst_ptr); + } else { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } +} + +template +template +SGL_HIP_INLINE void vec_t::cast_load(const T* ptr) { + cast_load_impl(*this, ptr); +} + +template +template +SGL_HIP_INLINE void vec_t::cast_store(T* ptr) const { + cast_store_impl(ptr, *this); +} + +} // namespace sgl_hip + +#include "impl/hip_vec_bf16_impl.h" +#include "impl/hip_vec_fp32_impl.h" +#include "impl/hip_vec_half_impl.h" +#endif diff --git a/sgl-kernel/include/impl/hip_vec_bf16_impl.h b/sgl-kernel/include/impl/hip_vec_bf16_impl.h new file mode 100644 index 000000000000..b783f3f43fa8 --- /dev/null +++ b/sgl-kernel/include/impl/hip_vec_bf16_impl.h @@ -0,0 +1,177 @@ +#pragma once + +#if USE_ROCM + +#include +#include + +// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491) + +using nv_bfloat16 = __hip_bfloat16; +using nv_bfloat162 = __hip_bfloat162; + +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) { + __hip_bfloat162 t; + t.x = x; + t.y = y; + return t; +} + +namespace sgl_hip { + +// nv_bfloat16 x 1 +template <> +struct vec_t { + nv_bfloat16 data; + SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) { + return ((nv_bfloat16*)(&data))[i]; + } + SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const { + return ((const nv_bfloat16*)(&data))[i]; + } + SGL_HIP_INLINE nv_bfloat16* ptr() { + return reinterpret_cast(&data); + } + SGL_HIP_INLINE void load(const nv_bfloat16* ptr); + SGL_HIP_INLINE void store(nv_bfloat16* ptr) const; + template + SGL_HIP_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SGL_HIP_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + SGL_HIP_INLINE void cast_store(T* ptr) 
const { + cast_store_impl(ptr, *this); + } +}; + +SGL_HIP_INLINE void vec_t::load(const nv_bfloat16* ptr) { + data = *ptr; +} + +SGL_HIP_INLINE void vec_t::store(nv_bfloat16* ptr) const { + *ptr = data; +} + +// nv_bfloat16 x 2 +template <> +struct vec_t { + nv_bfloat162 data; + + SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) { + return ((nv_bfloat16*)(&data))[i]; + } + SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const { + return ((const nv_bfloat16*)(&data))[i]; + } + SGL_HIP_INLINE nv_bfloat16* ptr() { + return reinterpret_cast(&data); + } + SGL_HIP_INLINE void load(const nv_bfloat16* ptr); + SGL_HIP_INLINE void store(nv_bfloat16* ptr) const; + template + SGL_HIP_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SGL_HIP_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + SGL_HIP_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } +}; + +SGL_HIP_INLINE void vec_t::load(const nv_bfloat16* ptr) { + data = *((nv_bfloat162*)ptr); +} + +SGL_HIP_INLINE void vec_t::store(nv_bfloat16* ptr) const { + *((nv_bfloat162*)ptr) = data; +} + +template <> +struct vec_t { + uint2 data; + + SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) { + return ((nv_bfloat16*)(&data))[i]; + } + SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const { + return ((const nv_bfloat16*)(&data))[i]; + } + SGL_HIP_INLINE nv_bfloat16* ptr() { + return reinterpret_cast(&data); + } + SGL_HIP_INLINE void load(const nv_bfloat16* ptr); + SGL_HIP_INLINE void store(nv_bfloat16* ptr) const; + template + SGL_HIP_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SGL_HIP_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + SGL_HIP_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } +}; + +SGL_HIP_INLINE void vec_t::load(const nv_bfloat16* ptr) { + data = *((uint2*)ptr); +} + +SGL_HIP_INLINE void 
vec_t::store(nv_bfloat16* ptr) const { + *((uint2*)ptr) = data; +} + +// nv_bfloat16 x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + + SGL_HIP_INLINE nv_bfloat16& operator[](size_t i) { + return ((nv_bfloat16*)data)[i]; + } + SGL_HIP_INLINE const nv_bfloat16& operator[](size_t i) const { + return ((const nv_bfloat16*)data)[i]; + } + SGL_HIP_INLINE nv_bfloat16* ptr() { + return reinterpret_cast(&data); + } + SGL_HIP_INLINE void load(const nv_bfloat16* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + SGL_HIP_INLINE void store(nv_bfloat16* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + SGL_HIP_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SGL_HIP_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + SGL_HIP_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } +}; + +} // namespace sgl_hip + +#endif diff --git a/sgl-kernel/include/impl/hip_vec_fp32_impl.h b/sgl-kernel/include/impl/hip_vec_fp32_impl.h new file mode 100644 index 000000000000..97cba6320d38 --- /dev/null +++ b/sgl-kernel/include/impl/hip_vec_fp32_impl.h @@ -0,0 +1,129 @@ +#pragma once + +#if USE_ROCM + +#include + +// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491) + +namespace sgl_hip { + +template <> +struct vec_t { + float data; + + SGL_HIP_INLINE float& operator[](size_t i) { + return ((float*)(&data))[i]; + } + SGL_HIP_INLINE const float& operator[](size_t i) const { + return ((const float*)(&data))[i]; + } + SGL_HIP_INLINE float* ptr() { + return reinterpret_cast(&data); + } + SGL_HIP_INLINE void load(const float* ptr); + SGL_HIP_INLINE void store(float* ptr) const; + template + SGL_HIP_INLINE void 
cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + SGL_HIP_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } +}; + +SGL_HIP_INLINE void vec_t::load(const float* ptr) { + data = *ptr; +} + +SGL_HIP_INLINE void vec_t::store(float* ptr) const { + *ptr = data; +} + +// float x 2 + +template <> +struct vec_t { + float2 data; + + SGL_HIP_INLINE float& operator[](size_t i) { + return ((float*)(&data))[i]; + } + SGL_HIP_INLINE const float& operator[](size_t i) const { + return ((const float*)(&data))[i]; + } + SGL_HIP_INLINE float* ptr() { + return reinterpret_cast(&data); + } + SGL_HIP_INLINE void load(const float* ptr); + SGL_HIP_INLINE void store(float* ptr) const; + template + SGL_HIP_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SGL_HIP_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + SGL_HIP_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } +}; + +SGL_HIP_INLINE void vec_t::load(const float* ptr) { + data = *((float2*)ptr); +} + +SGL_HIP_INLINE void vec_t::store(float* ptr) const { + *((float2*)ptr) = data; +} + +// float x 4 or more +template +struct vec_t { + float4 data[vec_size / 4]; + + SGL_HIP_INLINE float& operator[](size_t i) { + return ((float*)(data))[i]; + } + SGL_HIP_INLINE const float& operator[](size_t i) const { + return ((const float*)(data))[i]; + } + SGL_HIP_INLINE float* ptr() { + return reinterpret_cast(&data); + } + SGL_HIP_INLINE void load(const float* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4*)ptr)[i]; + } + } + SGL_HIP_INLINE void store(float* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)ptr)[i] = data[i]; + } + } + template + SGL_HIP_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SGL_HIP_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + 
template + SGL_HIP_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } +}; + +} // namespace sgl_hip + +#endif diff --git a/sgl-kernel/include/impl/hip_vec_half_impl.h b/sgl-kernel/include/impl/hip_vec_half_impl.h new file mode 100644 index 000000000000..767b9c62f9b9 --- /dev/null +++ b/sgl-kernel/include/impl/hip_vec_half_impl.h @@ -0,0 +1,172 @@ +#pragma once + +#if USE_ROCM + +#include +#include + +// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491) + +using half = __half; +using half2 = __half2; + +namespace sgl_hip { + +// half x 1 +template <> +struct vec_t { + half data; + + SGL_HIP_INLINE half& operator[](size_t i) { + return ((half*)(&data))[i]; + } + SGL_HIP_INLINE const half& operator[](size_t i) const { + return ((const half*)(&data))[i]; + } + SGL_HIP_INLINE half* ptr() { + return reinterpret_cast(&data); + } + SGL_HIP_INLINE void load(const half* ptr); + SGL_HIP_INLINE void store(half* ptr) const; + template + SGL_HIP_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SGL_HIP_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + SGL_HIP_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } +}; + +SGL_HIP_INLINE void vec_t::load(const half* ptr) { + data = *ptr; +} + +SGL_HIP_INLINE void vec_t::store(half* ptr) const { + *ptr = data; +} + +// half x 2 +template <> +struct vec_t { + half2 data; + + SGL_HIP_INLINE half& operator[](size_t i) { + return ((half*)(&data))[i]; + } + SGL_HIP_INLINE const half& operator[](size_t i) const { + return ((const half*)(&data))[i]; + } + SGL_HIP_INLINE half* ptr() { + return reinterpret_cast(&data); + } + SGL_HIP_INLINE void load(const half* ptr); + SGL_HIP_INLINE void store(half* ptr) const; + template + SGL_HIP_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SGL_HIP_INLINE void cast_load(const T* ptr) { + 
cast_load_impl(*this, ptr); + } + template + SGL_HIP_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } +}; + +SGL_HIP_INLINE void vec_t::load(const half* ptr) { + data = *((half2*)ptr); +} + +SGL_HIP_INLINE void vec_t::store(half* ptr) const { + *((half2*)ptr) = data; +} + +// half x 4 + +template <> +struct vec_t { + uint2 data; + + SGL_HIP_INLINE half& operator[](size_t i) { + return ((half*)(&data))[i]; + } + SGL_HIP_INLINE const half& operator[](size_t i) const { + return ((const half*)(&data))[i]; + } + SGL_HIP_INLINE half* ptr() { + return reinterpret_cast(&data); + } + SGL_HIP_INLINE void load(const half* ptr); + SGL_HIP_INLINE void store(half* ptr) const; + template + SGL_HIP_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SGL_HIP_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + SGL_HIP_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } +}; + +SGL_HIP_INLINE void vec_t::load(const half* ptr) { + data = *((uint2*)ptr); +} + +SGL_HIP_INLINE void vec_t::store(half* ptr) const { + *((uint2*)ptr) = data; +} + +// half x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + + SGL_HIP_INLINE half& operator[](size_t i) { + return ((half*)data)[i]; + } + SGL_HIP_INLINE const half& operator[](size_t i) const { + return ((const half*)data)[i]; + } + SGL_HIP_INLINE half* ptr() { + return reinterpret_cast(&data); + } + SGL_HIP_INLINE void load(const half* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + SGL_HIP_INLINE void store(half* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + SGL_HIP_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + SGL_HIP_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + SGL_HIP_INLINE void 
cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } +}; + +} // namespace sgl_hip +#endif diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index ffd240a04dd0..ca82760500ce 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -138,9 +138,10 @@ void sgl_fused_add_rmsnorm( torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl); void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl); void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl); -void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); -void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); -void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); +void silu_and_mul(at::Tensor& out, at::Tensor& input); +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input); +void gelu_and_mul(at::Tensor& out, at::Tensor& input); + void apply_rope_pos_ids_cos_sin_cache( at::Tensor q, at::Tensor k, @@ -151,6 +152,9 @@ void apply_rope_pos_ids_cos_sin_cache( bool interleave, int64_t cuda_stream); +#ifdef USE_ROCM +void gelu_quick(at::Tensor& out, const at::Tensor& input); +#endif /* * From csrc/gemm */ diff --git a/sgl-kernel/include/utils.h b/sgl-kernel/include/utils.h index 1054dbc5286a..d7d0d5d1fc83 100644 --- a/sgl-kernel/include/utils.h +++ b/sgl-kernel/include/utils.h @@ -19,7 +19,20 @@ limitations under the License. #include #include -#include +#ifdef USE_ROCM +// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491) +#define _DISPATCH_CASE_F16(c_type, ...) \ + case at::ScalarType::Half: { \ + using c_type = __half; \ + return __VA_ARGS__(); \ + } + +#define _DISPATCH_CASE_BF16(c_type, ...) 
\ + case at::ScalarType::BFloat16: { \ + using c_type = __hip_bfloat16; \ + return __VA_ARGS__(); \ + } +#endif // USE_ROCM #ifndef USE_ROCM // Adapt from FlashInfer @@ -31,7 +44,7 @@ limitations under the License. } #else #define _DISPATCH_CASE_F16(c_type, ...) -#endif +#endif // FLASHINFER_ENABLE_F16 #ifdef FLASHINFER_ENABLE_BF16 #define _DISPATCH_CASE_BF16(c_type, ...) \ @@ -41,7 +54,7 @@ limitations under the License. } #else #define _DISPATCH_CASE_BF16(c_type, ...) -#endif +#endif // FLASHINFER_ENABLE_BF16 #ifdef FLASHINFER_ENABLE_FP8_E4M3 #define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \ @@ -51,7 +64,7 @@ limitations under the License. } #else #define _DISPATCH_CASE_FP8_E4M3(c_type, ...) -#endif +#endif // FLASHINFER_ENABLE_FP8_E4M3 #ifdef FLASHINFER_ENABLE_FP8_E5M2 #define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \ @@ -61,7 +74,7 @@ limitations under the License. } #else #define _DISPATCH_CASE_FP8_E5M2(c_type, ...) -#endif +#endif // FLASHINFER_ENABLE_FP8_E5M2 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ @@ -197,7 +210,7 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { inline bool is_float8_tensor(const at::Tensor& tensor) { return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2; } -#endif +#endif // USE_ROCM struct cuda_error : public std::runtime_error { /** @@ -267,7 +280,6 @@ inline bool getEnvEnablePDL() { #define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width)) #endif -#ifndef USE_ROCM #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ @@ -284,7 +296,6 @@ inline bool getEnvEnablePDL() { return false; \ } \ }() -#endif #define DISPATCH_CASE_INTEGRAL_TYPES(...) 
\ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ @@ -297,52 +308,99 @@ inline bool getEnvEnablePDL() { AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) #define CEILDIV(x, y) (((x) + (y) - 1) / (y)) + +#ifndef USE_ROCM #define WARP_SIZE 32 +#else +#define WARP_SIZE warpSize // 64 +#endif + +#if defined(__HIP_PLATFORM_AMD__) + +#include "hip_math_def.h" +#include "hip_vec_dtypes.h" + +#else + +template +__device__ __forceinline__ float castToFloat(srcDtype val) { + return static_cast(val); +} + +template +__device__ __forceinline__ dstDtype castFromFloat(float val) { + return static_cast(val); +} + +#endif + +// add FP8 support #ifndef USE_ROCM #include using FP8_TYPE = c10::Float8_e4m3fn; C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); -#else -#include +#else // USE_ROCM + +#if HIP_FP8_TYPE_FNUZ +#include using FP8_TYPE = c10::Float8_e4m3fnuz; constexpr auto FP8_E4M3_MAX = 224.0f; -#endif +#else +#if HIP_FP8_TYPE_E4M3 +#include +using FP8_TYPE = c10::Float8_e4m3fn; +C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); +#else +#error "fp8 is not supported in this processor (arch < gfx942)." +#endif // HIP_FP8_TYPE_E4M3 +#endif // HIP_FP8_TYPE_FNUZ +#endif // USE_ROCM + +#define FULL_MASK 0xffffffff -#ifndef USE_ROCM __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { +#ifndef USE_ROCM float old; old = (value >= 0) ? 
__int_as_float(atomicMax((int*)addr, __float_as_int(value))) : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); return old; +#else + int* addr_as_i = (int*)addr; + int old = *addr_as_i, assumed; + do { + assumed = old; + old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +#endif } -__device__ __forceinline__ float warpReduceMax(float max_value) { - max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16)); - max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8)); - max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4)); - max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2)); - max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1)); - return max_value; +__device__ __forceinline__ float warpReduceMax(float value) { + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 16)); + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 8)); + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 4)); + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 2)); + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 1)); + return value; } -__device__ __forceinline__ float blockReduceMax(float max_value) { +__device__ __forceinline__ float blockReduceMax(float value) { static __shared__ float warpLevelMaxs[WARP_SIZE]; const int laneId = threadIdx.x % WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE; - max_value = warpReduceMax(max_value); + value = warpReduceMax(value); - if (laneId == 0) warpLevelMaxs[warpId] = max_value; + if (laneId == 0) warpLevelMaxs[warpId] = value; __syncthreads(); - max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0; - if (warpId == 0) max_value = warpReduceMax(max_value); + value = (threadIdx.x < blockDim.x / WARP_SIZE) ? 
warpLevelMaxs[laneId] : 0; + if (warpId == 0) value = warpReduceMax(value); - return max_value; + return value; } -#endif // Pads to a multiple of `alignment` rows. inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) { diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 5cecfc3c041e..2a4656aea21b 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -31,6 +31,10 @@ silu_and_mul, ) from sgl_kernel.fused_moe import fused_marlin_moe + +if torch.version.hip is not None: + from sgl_kernel.elementwise import gelu_quick + from sgl_kernel.gemm import ( awq_dequantize, bmm_fp8, diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index 0e2bbc9904dd..01ee718606ba 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -179,7 +179,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream()) + torch.ops.sgl_kernel.silu_and_mul.default(out, input) return out @@ -194,7 +194,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream()) + torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input) return out @@ -209,10 +209,34 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream()) + torch.ops.sgl_kernel.gelu_and_mul.default(out, input) return out +if torch.version.hip is not None: + + def gelu_quick(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + """ + Quick-GELU: y = x * 
sigmoid(1.702 * x) + + The CUDA/HIP kernel uses 128-bit (16-byte) vector loads & stores, + so the last-dimension byte length must be a multiple of 16 bytes. + """ + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError( + f"The last dimension ({input.shape[-1]}) x itemsize " + f"({input.dtype.itemsize}) must be a multiple of 16 bytes." + ) + + if out is not None: + assert input.shape == out.shape, f"{input.shape} != {out.shape}" + else: + out = torch.empty_like(input) + + torch.ops.sgl_kernel.gelu_quick(out, input) + return out + + def apply_rope_with_cos_sin_cache_inplace( positions: torch.Tensor, query: torch.Tensor, diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index a814b819689a..47f59071f4d3 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -36,16 +36,18 @@ def _get_version(): operator_namespace = "sgl_kernel" include_dirs = [ root / "include", + root / "include" / "impl", root / "csrc", ] sources = [ "csrc/allreduce/custom_all_reduce.hip", "csrc/allreduce/quick_all_reduce.cu", + "csrc/elementwise/activation.cu", "csrc/moe/moe_align_kernel.cu", "csrc/moe/moe_topk_softmax_kernels.cu", - "csrc/torch_extension_rocm.cc", "csrc/speculative/eagle_utils.cu", + "csrc/torch_extension_rocm.cc", ] cxx_flags = ["-O3"] @@ -69,6 +71,7 @@ def _get_version(): ) sys.exit(1) + hipcc_flags = [ "-DNDEBUG", f"-DOPERATOR_NAMESPACE={operator_namespace}", From 15d275917431648a85cfa8b06c6471cbf2ffbd8b Mon Sep 17 00:00:00 2001 From: Zaili Wang <109502517+ZailiWang@users.noreply.github.com> Date: Fri, 25 Jul 2025 15:03:16 +0800 Subject: [PATCH 127/396] [CPU] Add tutorial docs for SGL on CPU (#8000) --- docs/references/cpu.md | 197 +++++++++++++++++++++++++++++++++++ docs/references/deepseek.md | 3 + docs/references/hardware.rst | 1 + docs/start/install.md | 6 ++ 4 files changed, 207 insertions(+) create mode 100644 docs/references/cpu.md diff --git a/docs/references/cpu.md b/docs/references/cpu.md new file mode 100644 index 
000000000000..5aa76af32c41 --- /dev/null +++ b/docs/references/cpu.md @@ -0,0 +1,197 @@ +# SGLang on CPU + +The document addresses how to set up the [SGLang](https://github.com/sgl-project/sglang) environment and run LLM inference on CPU servers. +Specifically, SGLang is well optimized on the CPUs equipped with Intel® AMX® Instructions, +which are 4th generation or newer Intel® Xeon® Scalable Processors. + +## Optimized Model List + +A list of popular LLMs are optimized and run efficiently on CPU, +including the most notable open-source models like Llama series, Qwen series, +and the phenomenal high-quality reasoning model DeepSeek-R1. + +| Model Name | BF16 | w8a8_int8 | FP8 | +|:---:|:---:|:---:|:---:| +| DeepSeek-R1 | | [meituan/DeepSeek-R1-Channel-INT8](https://huggingface.co/meituan/DeepSeek-R1-Channel-INT8) | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) | +| Llama-3.2-3B | [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) | [RedHatAI/Llama-3.2-3B-quantized.w8a8](https://huggingface.co/RedHatAI/Llama-3.2-3B-Instruct-quantized.w8a8) | | +| Llama-3.1-8B | [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) | [RedHatAI/Meta-Llama-3.1-8B-quantized.w8a8](https://huggingface.co/RedHatAI/Meta-Llama-3.1-8B-quantized.w8a8) | | +| QwQ-32B | | [RedHatAI/QwQ-32B-quantized.w8a8](https://huggingface.co/RedHatAI/QwQ-32B-quantized.w8a8) | | +| DeepSeek-Distilled-Llama | | [RedHatAI/DeepSeek-R1-Distill-Llama-70B-quantized.w8a8](https://huggingface.co/RedHatAI/DeepSeek-R1-Distill-Llama-70B-quantized.w8a8) | | +| Qwen3-235B | | | [Qwen/Qwen3-235B-A22B-FP8](https://huggingface.co/Qwen/Qwen3-235B-A22B-FP8) | + +**Note:** The model identifiers listed in the table above +have been verified on 6th Gen Intel® Xeon® P-core platforms. + +## Installation + +### Install Using Docker + +It is recommended to use Docker for setting up the SGLang environment. 
+A [Dockerfile](https://github.com/sgl-project/sglang/blob/main/docker/Dockerfile.xeon) is provided to facilitate the installation. +Replace `` below with your [HuggingFace access token](https://huggingface.co/docs/hub/en/security-tokens). + +```bash +# Clone the SGLang repository +git clone https://github.com/sgl-project/sglang.git +cd sglang/docker + +# Build the docker image +docker build -t sglang-cpu:main -f Dockerfile.xeon . + +# Initiate a docker container +docker run \ + -it \ + --privileged \ + --ipc=host \ + --network=host \ + -v /dev/shm:/dev/shm \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + -p 30000:30000 \ + -e "HF_TOKEN=" \ + sglang-cpu:main /bin/bash +``` + +### Install From Source + +If you'd prefer to install SGLang in a bare metal environment, +the command list is as below. +It is worth noting that the environment variable `SGLANG_USE_CPU_ENGINE=1` +is required to enable SGLang service with CPU engine. + +```bash +# Create and activate a conda environment +conda create -n sgl-cpu python=3.12 -y +conda activate sgl-cpu + +# Optional: Set PyTorch CPU as primary pip install channel to avoid installing CUDA version +pip config set global.index-url https://download.pytorch.org/whl/cpu +pip config set global.extra-index-url https://pypi.org/simple + +# Check if some conda related environment variables have been set +env | grep -i conda +# The following environment variable settings are required +# if they have not been set properly +export CONDA_EXE=$(which conda) +export CONDA_ROOT=${CONDA_EXE}/../.. 
+export CONDA_PREFIX=${CONDA_ROOT}/envs/sgl-cpu +export PATH=${PATH}:${CONDA_ROOT}/bin:${CONDA_ROOT}/condabin + +# Clone the SGLang code +git clone https://github.com/sgl-project/sglang.git +cd sglang +git checkout + +# Install SGLang dependent libs, and build SGLang main package +pip install --upgrade pip setuptools +conda install -y libsqlite==3.48.0 gperftools tbb libnuma numactl +pip install intel-openmp +pip install -e "python[all_cpu]" + +# Build the CPU backend kernels +cd sgl-kernel +cp pyproject_cpu.toml pyproject.toml +pip install -v . + +# Other required environment variables +# Recommend to set these in ~/.bashrc in order not to set every time in a new terminal +export SGLANG_USE_CPU_ENGINE=1 +export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libiomp5.so:${CONDA_PREFIX}/lib/libtcmalloc.so:${CONDA_PREFIX}/lib/libtbbmalloc.so.2 +``` + +## Launch of the Serving Engine + +Example command to launch SGLang serving: + +```bash +python -m sglang.launch_server \ + --model \ + --trust-remote-code \ + --disable-overlap-schedule \ + --device cpu \ + --host 0.0.0.0 \ + --tp 6 +``` + +Notes: + +1. For running W8A8 quantized models, please add the flag `--quantization w8a8_int8`. + +2. The flag `--tp 6` specifies that tensor parallelism will be applied using 6 ranks (TP6). + The number of TP specified is how many TP ranks will be used during the execution. + In a CPU platform, a TP rank means a sub-NUMA cluster (SNC). + Usually we can get the SNC information (How many available) from Operation System. + User can specify TP to be no more than the total available SNCs in current system. + + If the specified TP rank number differs from the total SNC count, + the system will automatically utilize the first `n` SNCs. + Note that `n` cannot exceed the total SNC number, doing so will result in an error. + + To specify the cores to be used, we need to explicitly set the environment variable `SGLANG_CPU_OMP_THREADS_BIND`. 
+ For example, if we want to run the SGLang service using the first 40 cores of each SNC on a Xeon® 6980P server, + which has 43-43-42 cores on the 3 SNCs of a socket, we should set: + + ```bash + export SGLANG_CPU_OMP_THREADS_BIND="0-39|43-82|86-125|128-167|171-210|214-253" + ``` + +3. A warmup step is automatically triggered when the service is started. +The server is ready when you see the log `The server is fired up and ready to roll!`. + +## Benchmarking with Requests + +You can benchmark the performance via the `bench_serving` script. +Run the command in another terminal. + +```bash +python -m sglang.bench_serving \ + --dataset-name random \ + --random-input-len 1024 \ + --random-output-len 1024 \ + --num-prompts 1 \ + --request-rate inf \ + --random-range-ratio 1.0 +``` + +The detailed explanations of the parameters can be looked up by the command: + +```bash +python -m sglang.bench_serving -h +``` + +Additionally, the requests can be formed with +[OpenAI Completions API](https://docs.sglang.ai/backend/openai_api_completions.html) +and sent via the command line (e.g. using `curl`) or via your own script. 
+ +## Example: Running DeepSeek-R1 + +An example command to launch service for W8A8 DeepSeek-R1 on a Xeon® 6980P server + +```bash +python -m sglang.launch_server \ + --model meituan/DeepSeek-R1-Channel-INT8 \ + --trust-remote-code \ + --disable-overlap-schedule \ + --device cpu \ + --quantization w8a8_int8 \ + --host 0.0.0.0 \ + --mem-fraction-static 0.8 \ + --max-total-token 65536 \ + --tp 6 +``` + +Similarly, an example command to launch service for FP8 DeepSeek-R1 would be + +```bash +python -m sglang.launch_server \ + --model deepseek-ai/DeepSeek-R1 \ + --trust-remote-code \ + --disable-overlap-schedule \ + --device cpu \ + --host 0.0.0.0 \ + --mem-fraction-static 0.8 \ + --max-total-token 65536 \ + --tp 6 +``` + +Then you can test with `bench_serving` command or construct your own command or script +following [the benchmarking example](#benchmarking-with-requests). diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index efa4f1928616..8b6d688d1507 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -14,6 +14,7 @@ To run DeepSeek V3/R1 models, the requirements are as follows: | **Full precision FP8**
*(recommended)* | 8 x H200 | | | 8 x MI300X | | | 2 x 8 x H100/800/20 | +| | Xeon 6980P CPU | | **Full precision BF16** | 2 x 8 x H200 | | | 2 x 8 x MI300X | | | 4 x 8 x H100/800/20 | @@ -22,6 +23,7 @@ To run DeepSeek V3/R1 models, the requirements are as follows: | | 8 x A100/A800 | | **Quantized weights (int8)** | 16 x A100/800 | | | 32 x L40S | +| | Xeon 6980P CPU |