From 00b648a8f62a4c618a62c60db8b0b87ca39d028b Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 24 Mar 2026 11:17:31 +0800 Subject: [PATCH 01/20] Add repetition_penalties for #5703 --- .../sampling/penaltylib/repetition_penalty.py | 73 +++++++++++++++++++ .../srt/sampling/sampling_batch_info.py | 29 +++++++- python/sglang/srt/speculative/eagle_info.py | 22 ++++++ 3 files changed, 120 insertions(+), 4 deletions(-) create mode 100644 python/sglang/srt/sampling/penaltylib/repetition_penalty.py diff --git a/python/sglang/srt/sampling/penaltylib/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/repetition_penalty.py new file mode 100644 index 000000000000..f29dc255de1b --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/repetition_penalty.py @@ -0,0 +1,73 @@ +import torch + +from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer +from sglang.srt.utils import get_compiler_backend + + +@torch.compile(dynamic=True, backend=get_compiler_backend()) +def apply_scaling_penalties(logits, scaling_penalties): + logits[:] = torch.where( + logits < 0, + logits * scaling_penalties, + logits / scaling_penalties, + ) + + +class BatchedRepetitionPenalizer(_BatchedPenalizer): + """ + Repetition penalizer penalizes tokens based on their repetition in the input and output. + """ + + def _is_required(self) -> bool: + return any( + req.sampling_params.repetition_penalty != 1.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.cumulated_repetition_penalties = torch.ones( + (len(self.orchestrator.reqs()), self.orchestrator.vocab_size), + dtype=torch.float32, + device=self.orchestrator.device, + ) + self.repetition_penalties = ( + torch.tensor( + data=[ + req.sampling_params.repetition_penalty + for req in self.orchestrator.reqs() + ], + dtype=torch.float32, + device=self.orchestrator.device, + ) + ).unsqueeze_(1) + + def _cumulate_output_tokens(self, output_ids: torch.Tensor): + self.cumulated_repetition_penalties.scatter_( + dim=1, + index=output_ids.unsqueeze(1), + src=self.repetition_penalties, + ) + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + apply_scaling_penalties(logits, self.cumulated_repetition_penalties) + return logits + + def _filter(self, keep_indices: torch.Tensor): + self.repetition_penalties = self.repetition_penalties[keep_indices] + self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[ + keep_indices + ] + + def _merge(self, their: "BatchedRepetitionPenalizer"): + self.repetition_penalties = torch.cat( + [self.repetition_penalties, their.repetition_penalties], dim=0 + ) + self.cumulated_repetition_penalties = torch.cat( + [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties], + dim=0, + ) + + def _teardown(self) -> None: + for name in ("repetition_penalties", "cumulated_repetition_penalties"): + if hasattr(self, name): + delattr(self, name) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index a8f17c754d8e..17c878526900 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -8,6 +8,7 @@ import sglang.srt.sampling.penaltylib as penaltylib from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor +from sglang.srt.sampling.penaltylib.repetition_penalty import apply_scaling_penalties from sglang.srt.sampling.sampling_params import TOP_K_ALL from sglang.srt.server_args import get_global_server_args @@ -46,7 +47,10 @@ class SamplingBatchInfo: # Penalizer penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None - acc_linear_penalties: torch.Tensor = None # Used in the overlap mode + acc_linear_penalties: Optional[torch.Tensor] = None # Used in the overlap mode + acc_scaling_penalties: Optional[torch.Tensor] = ( + None # Used in the overlap mode for repetition penalty + ) # Whether any request has custom logit processor has_custom_logit_processor: bool = False @@ -159,6 +163,7 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): penaltylib.BatchedFrequencyPenalizer, penaltylib.BatchedMinNewTokensPenalizer, penaltylib.BatchedPresencePenalizer, + penaltylib.BatchedRepetitionPenalizer, }, ) @@ -228,17 +233,33 @@ def update_penalties(self): dtype=torch.float32, device=self.temperatures.device, ) - self.penalizer_orchestrator.apply(self.acc_linear_penalties) + self.acc_scaling_penalties = None + + for pen in self.penalizer_orchestrator.penalizers.values(): + if not pen._is_prepared: + continue + if isinstance(pen, penaltylib.BatchedRepetitionPenalizer): + # Snapshot the multiplicative penalty for overlap-safe forwarding + self.acc_scaling_penalties = ( + pen.cumulated_repetition_penalties.clone() + ) + else: + pen.apply(self.acc_linear_penalties) else: self.acc_linear_penalties = None + self.acc_scaling_penalties = None def apply_logits_bias(self, logits: torch.Tensor): + # Overlap mode: additive penalties (frequency, presence, min_new_tokens) if self.acc_linear_penalties is not None: - # Used in the overlap mode logits.add_(self.acc_linear_penalties) + # Overlap mode: multiplicative penalties (repetition) + if self.acc_scaling_penalties is not None: + apply_scaling_penalties(logits, self.acc_scaling_penalties) + + # Non-overlap mode: apply all penalties directly if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required: - # Used in the non-overlap mode self.penalizer_orchestrator.apply(logits) if self.vocab_mask is not None: diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 572a76f6a140..e1a8f330ed8c 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -294,6 +294,28 @@ def verify( torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) ) + # Apply multiplicative penalties (repetition) directly on logits + import sglang.srt.sampling.penaltylib as penaltylib + + rep_cls = penaltylib.BatchedRepetitionPenalizer + if rep_cls in sampling_info.penalizer_orchestrator.penalizers: + rep_pen = sampling_info.penalizer_orchestrator.penalizers[rep_cls] + if rep_pen._is_prepared: + + # Expand per-request penalties to match draft_token_num layout + expanded_penalties = torch.repeat_interleave( + rep_pen.cumulated_repetition_penalties, + self.draft_token_num, + dim=0, + ) + mask = expanded_penalties != 1.0 + if mask.any(): + logits = logits_output.next_token_logits + pos = (logits > 0) & mask + neg = (logits <= 0) & mask + logits[pos] = logits[pos] / expanded_penalties[pos] + logits[neg] = logits[neg] * expanded_penalties[neg] + # Apply grammar mask if vocab_mask is not None: assert self.grammar is not None From fde0313504f68766fbe3819b2e8405e0bfdb9a25 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 24 Mar 2026 11:25:36 +0800 Subject: [PATCH 02/20] add init --- python/sglang/srt/sampling/penaltylib/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/sampling/penaltylib/__init__.py b/python/sglang/srt/sampling/penaltylib/__init__.py index 26a780517ce7..9ba6d73ac68f 100644 --- a/python/sglang/srt/sampling/penaltylib/__init__.py +++ b/python/sglang/srt/sampling/penaltylib/__init__.py @@ -2,10 +2,12 @@ from sglang.srt.sampling.penaltylib.min_new_tokens import BatchedMinNewTokensPenalizer from sglang.srt.sampling.penaltylib.orchestrator import BatchedPenalizerOrchestrator from sglang.srt.sampling.penaltylib.presence_penalty import BatchedPresencePenalizer +from sglang.srt.sampling.penaltylib.repetition_penalty import BatchedRepetitionPenalizer __all__ = [ "BatchedFrequencyPenalizer", "BatchedMinNewTokensPenalizer", "BatchedPresencePenalizer", "BatchedPenalizerOrchestrator", + "BatchedRepetitionPenalizer", ] From 0b12c1de020eaf3354a2c61ec89d9702e597c281 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 24 Mar 2026 11:35:29 +0800 Subject: [PATCH 03/20] change comment --- python/sglang/srt/sampling/sampling_batch_info.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 17c878526900..e335f9712bbb 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -250,16 +250,16 @@ def update_penalties(self): self.acc_scaling_penalties = None def apply_logits_bias(self, logits: torch.Tensor): - # Overlap mode: additive penalties (frequency, presence, min_new_tokens) if self.acc_linear_penalties is not None: + # Used in the overlap mode logits.add_(self.acc_linear_penalties) - # Overlap mode: multiplicative penalties (repetition) if self.acc_scaling_penalties is not None: + # Used in the overlap mode apply_scaling_penalties(logits, self.acc_scaling_penalties) - # Non-overlap mode: apply all penalties directly if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required: + # Used in the overlap mode self.penalizer_orchestrator.apply(logits) if self.vocab_mask is not None: From fe26403b3d12884aad2ac766813a8c35f8cbeeca Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 24 Mar 2026 11:36:06 +0800 Subject: [PATCH 04/20] Update sampling_batch_info.py --- python/sglang/srt/sampling/sampling_batch_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index e335f9712bbb..174c17b89d78 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -259,7 +259,7 @@ def apply_logits_bias(self, logits: torch.Tensor): apply_scaling_penalties(logits, self.acc_scaling_penalties) if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required: - # Used in the overlap mode + # Used in the non-overlap mode self.penalizer_orchestrator.apply(logits) if self.vocab_mask is not None: From 0d6c418c0ed96808de8c1449d8daa930a0e1ab39 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 24 Mar 2026 12:55:30 +0800 Subject: [PATCH 05/20] Update glm4_moe.py --- python/sglang/srt/models/glm4_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 0c1e1d00084a..da0a19484454 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -328,7 +328,9 @@ def __init__( ) def forward(self, hidden_states): - logits = F.linear(hidden_states, self.weight, None) + # Cast to FP32 before lm_head projection to avoid BF16 accumulated rounding errors + # in the large matmul, which can flip top-1 token rankings under greedy decoding. + logits = F.linear(hidden_states.to(torch.float32), self.weight.to(torch.float32), None) return logits From 9d4efe163421b2eac746a82503580ffbcc762b1e Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 24 Mar 2026 13:03:44 +0800 Subject: [PATCH 06/20] format --- python/sglang/srt/models/glm4_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index da0a19484454..658b5a8a8e4a 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -330,7 +330,9 @@ def __init__( def forward(self, hidden_states): # Cast to FP32 before lm_head projection to avoid BF16 accumulated rounding errors # in the large matmul, which can flip top-1 token rankings under greedy decoding. - logits = F.linear(hidden_states.to(torch.float32), self.weight.to(torch.float32), None) + logits = F.linear( + hidden_states.to(torch.float32), self.weight.to(torch.float32), None + ) return logits From e27aa9212ee6b865339e12b9653ee423b8936f62 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 24 Mar 2026 13:05:49 +0800 Subject: [PATCH 07/20] Update glm4_moe.py --- python/sglang/srt/models/glm4_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 658b5a8a8e4a..3a1c0127297e 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -328,8 +328,7 @@ def __init__( ) def forward(self, hidden_states): - # Cast to FP32 before lm_head projection to avoid BF16 accumulated rounding errors - # in the large matmul, which can flip top-1 token rankings under greedy decoding. + # Cast to FP32 before gate projection for GLM-V model. logits = F.linear( hidden_states.to(torch.float32), self.weight.to(torch.float32), None ) From 885052687f7071e75faf3c306d4a2e9053fff8aa Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 25 Mar 2026 14:31:02 +0800 Subject: [PATCH 08/20] update with glm interleaved --- .../srt/layers/rotary_embedding/factory.py | 3 ++ .../srt/layers/rotary_embedding/mrope.py | 33 +++++++++++++++++++ .../layers/rotary_embedding/triton_kernels.py | 18 ++++++++-- 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding/factory.py b/python/sglang/srt/layers/rotary_embedding/factory.py index e95e9543f7f6..27e28577c96e 100644 --- a/python/sglang/srt/layers/rotary_embedding/factory.py +++ b/python/sglang/srt/layers/rotary_embedding/factory.py @@ -171,6 +171,9 @@ def get_rope( dtype, mrope_section=rope_scaling["mrope_section"], mrope_interleaved=rope_scaling.get("mrope_interleaved", False), + mrope_interleaved_glm=rope_scaling.get( + "mrope_interleaved_glm", False + ), ) elif rope_scaling.get("use_fope", False): rotary_emb = FourierRotaryEmbedding( diff --git a/python/sglang/srt/layers/rotary_embedding/mrope.py b/python/sglang/srt/layers/rotary_embedding/mrope.py index 237528fd1d47..0e4e8b5e4f05 100644 --- a/python/sglang/srt/layers/rotary_embedding/mrope.py +++ b/python/sglang/srt/layers/rotary_embedding/mrope.py @@ -52,12 +52,14 @@ def __init__( dtype: torch.dtype, mrope_section: Optional[List[int]] = None, mrope_interleaved: bool = False, + mrope_interleaved_glm: bool = False, ) -> None: super().__init__( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) self.mrope_section = mrope_section self.mrope_interleaved = mrope_interleaved + self.mrope_interleaved_glm = mrope_interleaved_glm if self.mrope_section: expected_sum = rotary_dim // 2 actual_sum = sum(self.mrope_section) @@ -86,6 +88,35 @@ def __init__( f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})" ) + # MRoPE axis_map interleaving pattern depends on mrope_section sizes. + # The algorithm cycles through axes [0(T), 1(H), 2(W)] round-robin, + # skipping any axis that has exhausted its allocated pairs. + # + # For GLM-V (mrope_section=[8,12,12]): + # T(8) < H(12) = W(12), so T exhausts first at pair 24. + # Result: [0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 1,1,2, 1,1,2, 2,2] + # After T runs out, only H and W fill the remaining slots. + # + # For Qwen3-VL (mrope_section=[24,20,20]): + # T(24) > H(20) = W(20), so H and W exhaust first near the tail. + # Result: [0,1,2, 0,1,2, ...repeated evenly..., 0,1, 0,1, 0,0] + # After H/W run out, T fills the remaining slots. + + num_pairs = rotary_dim // 2 + axis_map = torch.empty(num_pairs, dtype=torch.long) + assert sum(self.mrope_section) == num_pairs + counts = [0, 0, 0] + current_ax = 0 + + for i in range(num_pairs): + current_ax = i % 3 + while counts[current_ax] >= self.mrope_section[current_ax]: + current_ax = (current_ax + 1) % 3 + + axis_map[i] = current_ax + counts[current_ax] += 1 + + self.register_buffer("axis_map", axis_map, persistent=False) if get_global_server_args().rl_on_policy_target is not None: self._forward_method = self.forward_native @@ -214,7 +245,9 @@ def forward_triton( self.head_size, self.rotary_dim, self.mrope_interleaved, + self.mrope_interleaved_glm, self.is_neox_style, + self.axis_map, ) return query, key diff --git a/python/sglang/srt/layers/rotary_embedding/triton_kernels.py b/python/sglang/srt/layers/rotary_embedding/triton_kernels.py index 9a3d21bf83bb..0a8dc2c33c7b 100644 --- a/python/sglang/srt/layers/rotary_embedding/triton_kernels.py +++ b/python/sglang/srt/layers/rotary_embedding/triton_kernels.py @@ -29,7 +29,9 @@ def _triton_mrope_forward_fused( mrope_section_h: tl.constexpr, mrope_section_w: tl.constexpr, is_interleaved: tl.constexpr, + is_interleaved_glm: tl.constexpr, is_neox_style: tl.constexpr, + axis_map_ptr, ): pid = tl.program_id(0) q_ptr = q_ptr + pid * q_stride @@ -46,9 +48,15 @@ def _triton_mrope_forward_fused( w_sin = w_cos + half_rd cos_offsets = tl.arange(0, pad_hd // 2) if is_interleaved: - h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) - w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w) - t_mask = ~(h_mask | w_mask) + if is_interleaved_glm: + axes = tl.load(axis_map_ptr + cos_offsets, mask=cos_offsets < (pad_hd // 2)) + t_mask = axes == 0 + h_mask = axes == 1 + w_mask = axes == 2 + else: + h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) + w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w) + t_mask = ~(h_mask | w_mask) else: t_end = mrope_section_t h_end = t_end + mrope_section_h @@ -109,7 +117,9 @@ def triton_mrope_fused( head_size: int, rotary_dim: int, mrope_interleaved: bool, + mrope_interleaved_glm: bool, is_neox_style: bool, + axis_map: torch.Tensor, ) -> None: num_tokens, n_q_dim = q.shape n_k_dim = k.shape[1] @@ -137,7 +147,9 @@ def triton_mrope_fused( mrope_section[1], mrope_section[2], mrope_interleaved, + mrope_interleaved_glm, is_neox_style, + axis_map, ) From 2d24f494f4114df5c45caa1b9530205462c48195 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 26 Mar 2026 14:57:47 +0800 Subject: [PATCH 09/20] fix comment --- .../sampling/penaltylib/repetition_penalty.py | 2 +- .../srt/sampling/sampling_batch_info.py | 3 +++ python/sglang/srt/speculative/eagle_info.py | 20 +++++++++---------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/sampling/penaltylib/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/repetition_penalty.py index f29dc255de1b..aa98784e172f 100644 --- a/python/sglang/srt/sampling/penaltylib/repetition_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/repetition_penalty.py @@ -15,7 +15,7 @@ def apply_scaling_penalties(logits, scaling_penalties): class BatchedRepetitionPenalizer(_BatchedPenalizer): """ - Repetition penalizer penalizes tokens based on their repetition in the input and output. + Repetition penalizer penalizes tokens based on their presence in the generated output. """ def _is_required(self) -> bool: diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index e25ac323de45..4ee957602510 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -244,6 +244,9 @@ def update_penalties(self): for pen in self.penalizer_orchestrator.penalizers.values(): if not pen._is_prepared: continue + # NOTE: BatchedRepetitionPenalizer is a multiplicative penalizer and must + # be handled separately via acc_scaling_penalties. If a new multiplicative + # penalizer is added in the future, it must also be handled here explicitly. if isinstance(pen, penaltylib.BatchedRepetitionPenalizer): # Snapshot the multiplicative penalty for overlap-safe forwarding self.acc_scaling_penalties = ( diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index e1a8f330ed8c..0d468c5d82aa 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -20,6 +20,10 @@ get_last_loc, ) from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode +from sglang.srt.sampling.penaltylib import BatchedRepetitionPenalizer +from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import ( + apply_scaling_penalties, +) from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.eagle_info_v2 import ( EagleDraftInputV2Mixin, @@ -295,26 +299,20 @@ def verify( ) # Apply multiplicative penalties (repetition) directly on logits - import sglang.srt.sampling.penaltylib as penaltylib - - rep_cls = penaltylib.BatchedRepetitionPenalizer + rep_cls = BatchedRepetitionPenalizer if rep_cls in sampling_info.penalizer_orchestrator.penalizers: rep_pen = sampling_info.penalizer_orchestrator.penalizers[rep_cls] if rep_pen._is_prepared: - # Expand per-request penalties to match draft_token_num layout expanded_penalties = torch.repeat_interleave( rep_pen.cumulated_repetition_penalties, self.draft_token_num, dim=0, ) - mask = expanded_penalties != 1.0 - if mask.any(): - logits = logits_output.next_token_logits - pos = (logits > 0) & mask - neg = (logits <= 0) & mask - logits[pos] = logits[pos] / expanded_penalties[pos] - logits[neg] = logits[neg] * expanded_penalties[neg] + apply_scaling_penalties( + logits_output.next_token_logits, + expanded_penalties, + ) # Apply grammar mask if vocab_mask is not None: From 4be9b6cd8722f07b6050c460bfccc49038a73c6c Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 26 Mar 2026 15:11:39 +0800 Subject: [PATCH 10/20] fix --- .../srt/sampling/penaltylib/orchestrator.py | 2 ++ .../sampling/penaltylib/repetition_penalty.py | 2 ++ .../sglang/srt/sampling/sampling_batch_info.py | 16 ++++++++-------- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/sampling/penaltylib/orchestrator.py b/python/sglang/srt/sampling/penaltylib/orchestrator.py index 7ef123f554f9..80264e29d26f 100644 --- a/python/sglang/srt/sampling/penaltylib/orchestrator.py +++ b/python/sglang/srt/sampling/penaltylib/orchestrator.py @@ -132,6 +132,8 @@ class _BatchedPenalizer(abc.ABC): An abstract class for a batched penalizer. """ + is_multiplicative: bool = False + def __init__(self, orchestrator: BatchedPenalizerOrchestrator): self._orchestrator_ref: weakref.ReferenceType[BatchedPenalizerOrchestrator] = ( weakref.ref(orchestrator) diff --git a/python/sglang/srt/sampling/penaltylib/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/repetition_penalty.py index aa98784e172f..9e4c88721a5f 100644 --- a/python/sglang/srt/sampling/penaltylib/repetition_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/repetition_penalty.py @@ -18,6 +18,8 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer): Repetition penalizer penalizes tokens based on their presence in the generated output. """ + is_multiplicative: bool = True + def _is_required(self) -> bool: return any( req.sampling_params.repetition_penalty != 1.0 diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 4ee957602510..6a0ef9d5ef4d 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -244,14 +244,14 @@ def update_penalties(self): for pen in self.penalizer_orchestrator.penalizers.values(): if not pen._is_prepared: continue - # NOTE: BatchedRepetitionPenalizer is a multiplicative penalizer and must - # be handled separately via acc_scaling_penalties. If a new multiplicative - # penalizer is added in the future, it must also be handled here explicitly. - if isinstance(pen, penaltylib.BatchedRepetitionPenalizer): - # Snapshot the multiplicative penalty for overlap-safe forwarding - self.acc_scaling_penalties = ( - pen.cumulated_repetition_penalties.clone() - ) + if getattr(pen, "is_multiplicative", False): + # Accumulate multiplicative penalties (e.g. repetition penalty) + if self.acc_scaling_penalties is None: + self.acc_scaling_penalties = ( + pen.cumulated_repetition_penalties.clone() + ) + else: + self.acc_scaling_penalties *= pen.cumulated_repetition_penalties else: pen.apply(self.acc_linear_penalties) else: From ae96759ad221bb2a00253c08a2b98cad3ab338b2 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 26 Mar 2026 15:16:06 +0800 Subject: [PATCH 11/20] Update eagle_info.py --- python/sglang/srt/speculative/eagle_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 0d468c5d82aa..25b711f7c8de 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -21,7 +21,7 @@ ) from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.sampling.penaltylib import BatchedRepetitionPenalizer -from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import ( +from sglang.srt.sampling.penaltylib.repetition_penalty import ( apply_scaling_penalties, ) from sglang.srt.server_args import get_global_server_args From 095121edb7ba9558cba136c9c60fd7a63b37f638 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 26 Mar 2026 16:09:41 +0800 Subject: [PATCH 12/20] Update ngram_info.py --- python/sglang/srt/speculative/ngram_info.py | 22 ++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/speculative/ngram_info.py b/python/sglang/srt/speculative/ngram_info.py index c18cf79658d7..0e1849e03c79 100644 --- a/python/sglang/srt/speculative/ngram_info.py +++ b/python/sglang/srt/speculative/ngram_info.py @@ -26,6 +26,8 @@ alloc_token_slots, get_last_loc, ) +from sglang.srt.sampling.penaltylib import BatchedRepetitionPenalizer +from sglang.srt.sampling.penaltylib.repetition_penalty import apply_scaling_penalties from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.speculative.spec_utils import ( @@ -398,7 +400,10 @@ def verify( ) # Apply penalty - if sampling_info.penalizer_orchestrator.is_required: + if ( + sampling_info.penalizer_orchestrator.is_required + or sampling_info.logit_bias is not None + ): # This is a relaxed version of penalties for speculative decoding. linear_penalty = torch.zeros( (bs, logits_output.next_token_logits.shape[1]), @@ -410,6 +415,21 @@ def verify( torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) ) + # Apply multiplicative penalties (repetition) directly on logits + rep_cls = BatchedRepetitionPenalizer + if rep_cls in sampling_info.penalizer_orchestrator.penalizers: + rep_pen = sampling_info.penalizer_orchestrator.penalizers[rep_cls] + if rep_pen._is_prepared: + expanded_penalties = torch.repeat_interleave( + rep_pen.cumulated_repetition_penalties, + self.draft_token_num, + dim=0, + ) + apply_scaling_penalties( + logits_output.next_token_logits, + expanded_penalties, + ) + # Apply grammar mask if vocab_mask is not None: assert self.grammar is not None From ad05b34765061dd85d236017ebd7c8a8732d18a7 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Fri, 27 Mar 2026 18:54:38 +0000 Subject: [PATCH 13/20] Fix test_required_creates_penalties_tensor to match new per-penalizer apply logic The update_penalties method now iterates individual penalizers instead of calling orchestrator.apply(). Update test to mock a linear penalizer and assert pen.apply is called with acc_linear_penalties. --- .../registered/unit/sampling/test_sampling_batch_info.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/registered/unit/sampling/test_sampling_batch_info.py b/test/registered/unit/sampling/test_sampling_batch_info.py index bc923932cd87..59e18d970af3 100644 --- a/test/registered/unit/sampling/test_sampling_batch_info.py +++ b/test/registered/unit/sampling/test_sampling_batch_info.py @@ -194,13 +194,18 @@ def test_no_bias_no_change(self): class TestUpdatePenalties(CustomTestCase): def test_required_creates_penalties_tensor(self): - """Test that update_penalties allocates a zero tensor and calls orchestrator.apply.""" + """Test that update_penalties allocates a zero tensor and calls penalizer apply.""" + # Create a mock linear penalizer (non-multiplicative, prepared) + linear_pen = MagicMock(_is_prepared=True) + linear_pen.is_multiplicative = False + orch = MagicMock(is_required=True) + orch.penalizers = {"linear": linear_pen} info = _make_info(batch_size=2, penalizer_orchestrator=orch) info.update_penalties() self.assertIsNotNone(info.acc_linear_penalties) self.assertEqual(info.acc_linear_penalties.shape, (2, VOCAB_SIZE)) - orch.apply.assert_called_once() + linear_pen.apply.assert_called_once_with(info.acc_linear_penalties) def test_not_required_sets_none(self): """Test that update_penalties sets acc_linear_penalties to None when not required.""" From abf2d96823cefcda808900a955604874f28a7200 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Mon, 30 Mar 2026 14:44:25 +0800 Subject: [PATCH 14/20] fall back1 --- python/sglang/srt/models/glm4_moe.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 71897b1153fc..121a31630ae3 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -329,10 +329,7 @@ def __init__( ) def forward(self, hidden_states): - # Cast to FP32 before gate projection for GLM-V model. - logits = F.linear( - hidden_states.to(torch.float32), self.weight.to(torch.float32), None - ) + logits = F.linear(hidden_states, self.weight, None) return logits From 3f55e0d72a41ee66e75eacf31b14a5474369ee54 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Mon, 30 Mar 2026 14:55:31 +0800 Subject: [PATCH 15/20] fix comment --- python/sglang/srt/speculative/eagle_info.py | 12 ++++++++++-- python/sglang/srt/speculative/ngram_info.py | 11 ++++++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 25b711f7c8de..6c83642fb7c1 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -293,7 +293,16 @@ def verify( dtype=torch.float32, device=batch.device, ) - sampling_info.apply_logits_bias(linear_penalty) + + # Only apply non-multiplicative (additive) penalizers to the capture tensor. + # Multiplicative penalizers (e.g. repetition penalty) would corrupt additive + # penalty values, so they are applied directly on logits below. + for pen in sampling_info.penalizer_orchestrator.penalizers.values(): + if not getattr(pen, "is_multiplicative", False): + pen.apply(linear_penalty) + if sampling_info.logit_bias is not None: + linear_penalty.add_(sampling_info.logit_bias) + logits_output.next_token_logits.add_( torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) ) @@ -303,7 +312,6 @@ def verify( if rep_cls in sampling_info.penalizer_orchestrator.penalizers: rep_pen = sampling_info.penalizer_orchestrator.penalizers[rep_cls] if rep_pen._is_prepared: - # Expand per-request penalties to match draft_token_num layout expanded_penalties = torch.repeat_interleave( rep_pen.cumulated_repetition_penalties, self.draft_token_num, diff --git a/python/sglang/srt/speculative/ngram_info.py b/python/sglang/srt/speculative/ngram_info.py index 0e1849e03c79..69e6b0a42451 100644 --- a/python/sglang/srt/speculative/ngram_info.py +++ b/python/sglang/srt/speculative/ngram_info.py @@ -410,7 +410,16 @@ def verify( dtype=torch.float32, device=self.device, ) - sampling_info.apply_logits_bias(linear_penalty) + + # Only apply non-multiplicative (additive) penalizers to the capture tensor. + # Multiplicative penalizers (e.g. repetition penalty) would corrupt additive + # penalty values, so they are applied directly on logits below. + for pen in sampling_info.penalizer_orchestrator.penalizers.values(): + if not getattr(pen, "is_multiplicative", False): + pen.apply(linear_penalty) + if sampling_info.logit_bias is not None: + linear_penalty.add_(sampling_info.logit_bias) + logits_output.next_token_logits.add_( torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) ) From af64b362caa89105a72917d51ce32466a9e842f3 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Mon, 30 Mar 2026 14:58:42 +0800 Subject: [PATCH 16/20] fall back --- .../srt/layers/rotary_embedding/factory.py | 3 -- .../srt/layers/rotary_embedding/mrope.py | 33 ------------------- .../layers/rotary_embedding/triton_kernels.py | 18 ++-------- 3 files changed, 3 insertions(+), 51 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding/factory.py b/python/sglang/srt/layers/rotary_embedding/factory.py index 27e28577c96e..e95e9543f7f6 100644 --- a/python/sglang/srt/layers/rotary_embedding/factory.py +++ b/python/sglang/srt/layers/rotary_embedding/factory.py @@ -171,9 +171,6 @@ def get_rope( dtype, mrope_section=rope_scaling["mrope_section"], mrope_interleaved=rope_scaling.get("mrope_interleaved", False), - mrope_interleaved_glm=rope_scaling.get( - "mrope_interleaved_glm", False - ), ) elif rope_scaling.get("use_fope", False): rotary_emb = FourierRotaryEmbedding( diff --git a/python/sglang/srt/layers/rotary_embedding/mrope.py b/python/sglang/srt/layers/rotary_embedding/mrope.py index 0e4e8b5e4f05..237528fd1d47 100644 --- a/python/sglang/srt/layers/rotary_embedding/mrope.py +++ b/python/sglang/srt/layers/rotary_embedding/mrope.py @@ -52,14 +52,12 @@ def __init__( dtype: torch.dtype, mrope_section: Optional[List[int]] = None, mrope_interleaved: bool = False, - mrope_interleaved_glm: bool = False, ) -> None: super().__init__( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) self.mrope_section = mrope_section self.mrope_interleaved = mrope_interleaved - self.mrope_interleaved_glm = mrope_interleaved_glm if self.mrope_section: expected_sum = rotary_dim // 2 actual_sum = sum(self.mrope_section) @@ -88,35 +86,6 @@ def __init__( f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})" ) - # MRoPE axis_map interleaving pattern depends on mrope_section sizes. - # The algorithm cycles through axes [0(T), 1(H), 2(W)] round-robin, - # skipping any axis that has exhausted its allocated pairs. - # - # For GLM-V (mrope_section=[8,12,12]): - # T(8) < H(12) = W(12), so T exhausts first at pair 24. - # Result: [0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 0,1,2, 1,1,2, 1,1,2, 2,2] - # After T runs out, only H and W fill the remaining slots. - # - # For Qwen3-VL (mrope_section=[24,20,20]): - # T(24) > H(20) = W(20), so H and W exhaust first near the tail. - # Result: [0,1,2, 0,1,2, ...repeated evenly..., 0,1, 0,1, 0,0] - # After H/W run out, T fills the remaining slots. - - num_pairs = rotary_dim // 2 - axis_map = torch.empty(num_pairs, dtype=torch.long) - assert sum(self.mrope_section) == num_pairs - counts = [0, 0, 0] - current_ax = 0 - - for i in range(num_pairs): - current_ax = i % 3 - while counts[current_ax] >= self.mrope_section[current_ax]: - current_ax = (current_ax + 1) % 3 - - axis_map[i] = current_ax - counts[current_ax] += 1 - - self.register_buffer("axis_map", axis_map, persistent=False) if get_global_server_args().rl_on_policy_target is not None: self._forward_method = self.forward_native @@ -245,9 +214,7 @@ def forward_triton( self.head_size, self.rotary_dim, self.mrope_interleaved, - self.mrope_interleaved_glm, self.is_neox_style, - self.axis_map, ) return query, key diff --git a/python/sglang/srt/layers/rotary_embedding/triton_kernels.py b/python/sglang/srt/layers/rotary_embedding/triton_kernels.py index 0a8dc2c33c7b..9a3d21bf83bb 100644 --- a/python/sglang/srt/layers/rotary_embedding/triton_kernels.py +++ b/python/sglang/srt/layers/rotary_embedding/triton_kernels.py @@ -29,9 +29,7 @@ def _triton_mrope_forward_fused( mrope_section_h: tl.constexpr, mrope_section_w: tl.constexpr, is_interleaved: tl.constexpr, - is_interleaved_glm: tl.constexpr, is_neox_style: tl.constexpr, - axis_map_ptr, ): pid = tl.program_id(0) q_ptr = q_ptr + pid * q_stride @@ -48,15 +46,9 @@ def _triton_mrope_forward_fused( w_sin = w_cos + half_rd cos_offsets = tl.arange(0, pad_hd // 2) if is_interleaved: - if is_interleaved_glm: - axes = tl.load(axis_map_ptr + cos_offsets, mask=cos_offsets < (pad_hd // 2)) - t_mask = axes == 0 - h_mask = axes == 1 - w_mask = axes == 2 - else: - h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) - w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w) - t_mask = ~(h_mask | w_mask) + h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) + w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w) + t_mask = ~(h_mask | w_mask) else: t_end = mrope_section_t h_end = t_end + mrope_section_h @@ -117,9 +109,7 @@ def triton_mrope_fused( head_size: int, rotary_dim: int, mrope_interleaved: bool, - mrope_interleaved_glm: bool, is_neox_style: bool, - axis_map: torch.Tensor, ) -> None: num_tokens, n_q_dim = q.shape n_k_dim = k.shape[1] @@ -147,9 +137,7 @@ def triton_mrope_fused( mrope_section[1], mrope_section[2], mrope_interleaved, - mrope_interleaved_glm, is_neox_style, - axis_map, ) From 9a9568bc9f84dc7e0776acd4fef12154e08b1527 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Mon, 30 Mar 2026 15:20:14 -0700 Subject: [PATCH 17/20] Use get_scaling_penalties() interface instead of direct attribute access --- .../srt/sampling/penaltylib/orchestrator.py | 60 ++++++++++++++++--- .../sampling/penaltylib/repetition_penalty.py | 3 + .../srt/sampling/sampling_batch_info.py | 19 ++---- python/sglang/srt/speculative/eagle_info.py | 40 ++----------- python/sglang/srt/speculative/ngram_info.py | 38 ++---------- .../unit/sampling/test_sampling_batch_info.py | 11 ++-- 6 files changed, 75 insertions(+), 96 deletions(-) diff --git a/python/sglang/srt/sampling/penaltylib/orchestrator.py b/python/sglang/srt/sampling/penaltylib/orchestrator.py index 80264e29d26f..7bb2421cce0b 100644 --- a/python/sglang/srt/sampling/penaltylib/orchestrator.py +++ b/python/sglang/srt/sampling/penaltylib/orchestrator.py @@ -52,19 +52,56 @@ def cumulate_output_tokens(self, output_ids: torch.Tensor): for penalizer in self.penalizers.values(): penalizer.cumulate_output_tokens(output_ids=output_ids) - def apply(self, logits: torch.Tensor) -> torch.Tensor: + def apply(self, logits: torch.Tensor, repeat: Optional[int] = None): """ - Apply the penalizers to the logits. - Note that it may apply the penalizers in-place. + Apply all penalizers to the logits in-place. Args: - logits (torch.Tensor): The logits to apply the penalizers to. - - Returns: - torch.Tensor: The logits after applying the penalizers. + logits: The logits tensor to apply penalties to. + repeat: If set (speculative decoding), per-request penalties are + expanded via repeat_interleave to match the draft token layout. + Additive penalties are captured into a zeros tensor, expanded, + then added; scaling penalties are accumulated, expanded, then + applied directly. """ + if repeat is None: + for penalizer in self.penalizers.values(): + penalizer.apply(logits) + else: + # Additive: capture into zeros, expand, add + bs = logits.shape[0] // repeat + linear = torch.zeros( + (bs, logits.shape[1]), dtype=torch.float32, device=logits.device + ) + self.apply_additive(linear) + logits.add_(torch.repeat_interleave(linear, repeat, dim=0)) + # Scaling: accumulate, expand, apply + accumulated = self.accumulate_scaling_penalties() + if accumulated is not None: + from sglang.srt.sampling.penaltylib.repetition_penalty import ( + apply_scaling_penalties, + ) + + expanded = torch.repeat_interleave(accumulated, repeat, dim=0) + apply_scaling_penalties(logits, expanded) + + def apply_additive(self, logits: torch.Tensor): + """Apply only additive (non-multiplicative) penalizers.""" + for penalizer in self.penalizers.values(): + if not penalizer.is_multiplicative: + penalizer.apply(logits) + + def accumulate_scaling_penalties(self) -> Optional[torch.Tensor]: + """Accumulate all multiplicative penalty tensors into one, or None if none active.""" + result = None for penalizer in self.penalizers.values(): - penalizer.apply(logits) + if not penalizer._is_prepared or not penalizer.is_multiplicative: + continue + if result is None: + result = penalizer.get_scaling_penalties().clone() + else: + result *= penalizer.get_scaling_penalties() + return result def filter(self, keep_indices: torch.Tensor): """ @@ -229,6 +266,13 @@ def _apply(self, logits: torch.Tensor) -> torch.Tensor: """ pass + def get_scaling_penalties(self) -> torch.Tensor: + """ + Return the accumulated scaling penalty tensor for multiplicative penalizers. + Only meaningful when is_multiplicative is True. Subclasses should override. + """ + raise NotImplementedError + @abc.abstractmethod def _filter(self, keep_indices: torch.Tensor): """ diff --git a/python/sglang/srt/sampling/penaltylib/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/repetition_penalty.py index 9e4c88721a5f..fd03fb2b5c89 100644 --- a/python/sglang/srt/sampling/penaltylib/repetition_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/repetition_penalty.py @@ -54,6 +54,9 @@ def _apply(self, logits: torch.Tensor) -> torch.Tensor: apply_scaling_penalties(logits, self.cumulated_repetition_penalties) return logits + def get_scaling_penalties(self) -> torch.Tensor: + return self.cumulated_repetition_penalties + def _filter(self, keep_indices: torch.Tensor): self.repetition_penalties = self.repetition_penalties[keep_indices] self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[ diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 6a0ef9d5ef4d..d0c0358bcda8 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -239,21 +239,10 @@ def update_penalties(self): dtype=torch.float32, device=self.temperatures.device, ) - self.acc_scaling_penalties = None - - for pen in self.penalizer_orchestrator.penalizers.values(): - if not pen._is_prepared: - continue - if getattr(pen, "is_multiplicative", False): - # Accumulate multiplicative penalties (e.g. repetition penalty) - if self.acc_scaling_penalties is None: - self.acc_scaling_penalties = ( - pen.cumulated_repetition_penalties.clone() - ) - else: - self.acc_scaling_penalties *= pen.cumulated_repetition_penalties - else: - pen.apply(self.acc_linear_penalties) + self.penalizer_orchestrator.apply_additive(self.acc_linear_penalties) + self.acc_scaling_penalties = ( + self.penalizer_orchestrator.accumulate_scaling_penalties() + ) else: self.acc_linear_penalties = None self.acc_scaling_penalties = None diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 6c83642fb7c1..dbb91f555ecf 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -20,10 +20,6 @@ get_last_loc, ) from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode -from sglang.srt.sampling.penaltylib import BatchedRepetitionPenalizer -from sglang.srt.sampling.penaltylib.repetition_penalty import ( - apply_scaling_penalties, -) from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.eagle_info_v2 import ( EagleDraftInputV2Mixin, @@ -288,39 +284,15 @@ def verify( or sampling_info.logit_bias is not None ): # This is a relaxed version of penalties for speculative decoding. - linear_penalty = torch.zeros( - (bs, logits_output.next_token_logits.shape[1]), - dtype=torch.float32, - device=batch.device, + sampling_info.penalizer_orchestrator.apply( + logits_output.next_token_logits, repeat=self.draft_token_num ) - - # Only apply non-multiplicative (additive) penalizers to the capture tensor. - # Multiplicative penalizers (e.g. repetition penalty) would corrupt additive - # penalty values, so they are applied directly on logits below. - for pen in sampling_info.penalizer_orchestrator.penalizers.values(): - if not getattr(pen, "is_multiplicative", False): - pen.apply(linear_penalty) if sampling_info.logit_bias is not None: - linear_penalty.add_(sampling_info.logit_bias) - - logits_output.next_token_logits.add_( - torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) - ) - - # Apply multiplicative penalties (repetition) directly on logits - rep_cls = BatchedRepetitionPenalizer - if rep_cls in sampling_info.penalizer_orchestrator.penalizers: - rep_pen = sampling_info.penalizer_orchestrator.penalizers[rep_cls] - if rep_pen._is_prepared: - expanded_penalties = torch.repeat_interleave( - rep_pen.cumulated_repetition_penalties, - self.draft_token_num, - dim=0, - ) - apply_scaling_penalties( - logits_output.next_token_logits, - expanded_penalties, + logits_output.next_token_logits.add_( + torch.repeat_interleave( + sampling_info.logit_bias, self.draft_token_num, dim=0 ) + ) # Apply grammar mask if vocab_mask is not None: diff --git a/python/sglang/srt/speculative/ngram_info.py b/python/sglang/srt/speculative/ngram_info.py index 69e6b0a42451..7aafe9870769 100644 --- a/python/sglang/srt/speculative/ngram_info.py +++ b/python/sglang/srt/speculative/ngram_info.py @@ -26,8 +26,6 @@ alloc_token_slots, get_last_loc, ) -from sglang.srt.sampling.penaltylib import BatchedRepetitionPenalizer -from sglang.srt.sampling.penaltylib.repetition_penalty import apply_scaling_penalties from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.speculative.spec_utils import ( @@ -405,39 +403,15 @@ def verify( or sampling_info.logit_bias is not None ): # This is a relaxed version of penalties for speculative decoding. - linear_penalty = torch.zeros( - (bs, logits_output.next_token_logits.shape[1]), - dtype=torch.float32, - device=self.device, + sampling_info.penalizer_orchestrator.apply( + logits_output.next_token_logits, repeat=self.draft_token_num ) - - # Only apply non-multiplicative (additive) penalizers to the capture tensor. - # Multiplicative penalizers (e.g. repetition penalty) would corrupt additive - # penalty values, so they are applied directly on logits below. - for pen in sampling_info.penalizer_orchestrator.penalizers.values(): - if not getattr(pen, "is_multiplicative", False): - pen.apply(linear_penalty) if sampling_info.logit_bias is not None: - linear_penalty.add_(sampling_info.logit_bias) - - logits_output.next_token_logits.add_( - torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) - ) - - # Apply multiplicative penalties (repetition) directly on logits - rep_cls = BatchedRepetitionPenalizer - if rep_cls in sampling_info.penalizer_orchestrator.penalizers: - rep_pen = sampling_info.penalizer_orchestrator.penalizers[rep_cls] - if rep_pen._is_prepared: - expanded_penalties = torch.repeat_interleave( - rep_pen.cumulated_repetition_penalties, - self.draft_token_num, - dim=0, - ) - apply_scaling_penalties( - logits_output.next_token_logits, - expanded_penalties, + logits_output.next_token_logits.add_( + torch.repeat_interleave( + sampling_info.logit_bias, self.draft_token_num, dim=0 ) + ) # Apply grammar mask if vocab_mask is not None: diff --git a/test/registered/unit/sampling/test_sampling_batch_info.py b/test/registered/unit/sampling/test_sampling_batch_info.py index 59e18d970af3..d269fec73b36 100644 --- a/test/registered/unit/sampling/test_sampling_batch_info.py +++ b/test/registered/unit/sampling/test_sampling_batch_info.py @@ -194,18 +194,15 @@ def test_no_bias_no_change(self): class TestUpdatePenalties(CustomTestCase): def test_required_creates_penalties_tensor(self): - """Test that update_penalties allocates a zero tensor and calls penalizer apply.""" - # Create a mock linear penalizer (non-multiplicative, prepared) - linear_pen = MagicMock(_is_prepared=True) - linear_pen.is_multiplicative = False - + """Test that update_penalties allocates a zero tensor and calls orchestrator methods.""" orch = MagicMock(is_required=True) - orch.penalizers = {"linear": linear_pen} + orch.accumulate_scaling_penalties.return_value = None info = _make_info(batch_size=2, penalizer_orchestrator=orch) info.update_penalties() self.assertIsNotNone(info.acc_linear_penalties) self.assertEqual(info.acc_linear_penalties.shape, (2, VOCAB_SIZE)) - linear_pen.apply.assert_called_once_with(info.acc_linear_penalties) + orch.apply_additive.assert_called_once_with(info.acc_linear_penalties) + orch.accumulate_scaling_penalties.assert_called_once() def test_not_required_sets_none(self): """Test that update_penalties sets acc_linear_penalties to None when not required.""" From eef407d0938fa131529122f0cde249f617e4e893 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 31 Mar 2026 20:51:25 -0700 Subject: [PATCH 18/20] renaming --- .../sglang/srt/sampling/penaltylib/orchestrator.py | 8 ++++---- python/sglang/srt/sampling/sampling_batch_info.py | 14 ++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/sampling/penaltylib/orchestrator.py b/python/sglang/srt/sampling/penaltylib/orchestrator.py index 7bb2421cce0b..650c719f37ca 100644 --- a/python/sglang/srt/sampling/penaltylib/orchestrator.py +++ b/python/sglang/srt/sampling/penaltylib/orchestrator.py @@ -70,11 +70,11 @@ def apply(self, logits: torch.Tensor, repeat: Optional[int] = None): else: # Additive: capture into zeros, expand, add bs = logits.shape[0] // repeat - linear = torch.zeros( + additive = torch.zeros( (bs, logits.shape[1]), dtype=torch.float32, device=logits.device ) - self.apply_additive(linear) - logits.add_(torch.repeat_interleave(linear, repeat, dim=0)) + self.accumulate_additive_penalties(additive) + logits.add_(torch.repeat_interleave(additive, repeat, dim=0)) # Scaling: accumulate, expand, apply accumulated = self.accumulate_scaling_penalties() if accumulated is not None: @@ -85,7 +85,7 @@ def apply(self, logits: torch.Tensor, repeat: Optional[int] = None): expanded = torch.repeat_interleave(accumulated, repeat, dim=0) apply_scaling_penalties(logits, expanded) - def apply_additive(self, logits: torch.Tensor): + def accumulate_additive_penalties(self, logits: torch.Tensor): """Apply only additive (non-multiplicative) penalizers.""" for penalizer in self.penalizers.values(): if not penalizer.is_multiplicative: diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index d0c0358bcda8..885936b0ec95 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -47,7 +47,7 @@ class SamplingBatchInfo: # Penalizer penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None - acc_linear_penalties: Optional[torch.Tensor] = None # Used in the overlap mode + acc_additive_penalties: Optional[torch.Tensor] = None # Used in the overlap mode acc_scaling_penalties: Optional[torch.Tensor] = ( None # Used in the overlap mode for repetition penalty ) @@ -234,23 +234,25 @@ def update_regex_vocab_mask(self): def update_penalties(self): if self.penalizer_orchestrator.is_required: - self.acc_linear_penalties = torch.zeros( + self.acc_additive_penalties = torch.zeros( (len(self.temperatures), self.vocab_size), dtype=torch.float32, device=self.temperatures.device, ) - self.penalizer_orchestrator.apply_additive(self.acc_linear_penalties) + self.penalizer_orchestrator.accumulate_additive_penalties( + self.acc_additive_penalties + ) self.acc_scaling_penalties = ( self.penalizer_orchestrator.accumulate_scaling_penalties() ) else: - self.acc_linear_penalties = None + self.acc_additive_penalties = None self.acc_scaling_penalties = None def apply_logits_bias(self, logits: torch.Tensor): - if self.acc_linear_penalties is not None: + if self.acc_additive_penalties is not None: # Used in the overlap mode - logits.add_(self.acc_linear_penalties) + logits.add_(self.acc_additive_penalties) if self.acc_scaling_penalties is not None: # Used in the overlap mode From 2295478325b558e8aa036a13cf36552cf5482d90 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 31 Mar 2026 20:56:09 -0700 Subject: [PATCH 19/20] fix naming: linear -> additive in penalties --- .../unit/sampling/test_sampling_batch_info.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/registered/unit/sampling/test_sampling_batch_info.py b/test/registered/unit/sampling/test_sampling_batch_info.py index d269fec73b36..478218b34bb2 100644 --- a/test/registered/unit/sampling/test_sampling_batch_info.py +++ b/test/registered/unit/sampling/test_sampling_batch_info.py @@ -199,17 +199,19 @@ def test_required_creates_penalties_tensor(self): orch.accumulate_scaling_penalties.return_value = None info = _make_info(batch_size=2, penalizer_orchestrator=orch) info.update_penalties() - self.assertIsNotNone(info.acc_linear_penalties) - self.assertEqual(info.acc_linear_penalties.shape, (2, VOCAB_SIZE)) - orch.apply_additive.assert_called_once_with(info.acc_linear_penalties) + self.assertIsNotNone(info.acc_additive_penalties) + self.assertEqual(info.acc_additive_penalties.shape, (2, VOCAB_SIZE)) + orch.accumulate_additive_penalties.assert_called_once_with( + info.acc_additive_penalties + ) orch.accumulate_scaling_penalties.assert_called_once() def test_not_required_sets_none(self): - """Test that update_penalties sets acc_linear_penalties to None when not required.""" + """Test that update_penalties sets acc_additive_penalties to None when not required.""" orch = MagicMock(is_required=False) info = _make_info(batch_size=2, penalizer_orchestrator=orch) info.update_penalties() - self.assertIsNone(info.acc_linear_penalties) + self.assertIsNone(info.acc_additive_penalties) # update_regex_vocab_mask From 1adb874411f04cb008d4787607e61076c117881c Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 31 Mar 2026 21:05:53 -0700 Subject: [PATCH 20/20] fix remaining linear -> additive renames in test --- test/registered/unit/sampling/test_sampling_batch_info.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/registered/unit/sampling/test_sampling_batch_info.py b/test/registered/unit/sampling/test_sampling_batch_info.py index 478218b34bb2..7b018915381f 100644 --- a/test/registered/unit/sampling/test_sampling_batch_info.py +++ b/test/registered/unit/sampling/test_sampling_batch_info.py @@ -142,10 +142,10 @@ def test_lhs_none_rhs_present(self): # apply_logits_bias class TestApplyLogitsBias(CustomTestCase): - def test_applies_linear_penalties(self): - """Test that pre-accumulated linear penalties are added to logits.""" + def test_applies_additive_penalties(self): + """Test that pre-accumulated additive penalties are added to logits.""" info = _make_info(batch_size=1) - info.acc_linear_penalties = torch.tensor([[-1.0] * VOCAB_SIZE]) + info.acc_additive_penalties = torch.tensor([[-1.0] * VOCAB_SIZE]) logits = torch.zeros(1, VOCAB_SIZE) info.apply_logits_bias(logits) self.assertAlmostEqual(logits[0, 0].item(), -1.0, places=5) @@ -181,7 +181,7 @@ def test_applies_penalizer_orchestrator(self): def test_no_bias_no_change(self): """Test that logits stay unchanged when no bias sources are set.""" info = _make_info(batch_size=1) - info.acc_linear_penalties = None + info.acc_additive_penalties = None info.logit_bias = None info.vocab_mask = None logits = torch.zeros(1, VOCAB_SIZE)