Merged

30 commits
00b648a
Add repetition_penalties for #5703
zRzRzRzRzRzRzR Mar 24, 2026
fde0313
add init
zRzRzRzRzRzRzR Mar 24, 2026
0b12c1d
change comment
zRzRzRzRzRzRzR Mar 24, 2026
fe26403
Update sampling_batch_info.py
zRzRzRzRzRzRzR Mar 24, 2026
0d6c418
Update glm4_moe.py
zRzRzRzRzRzRzR Mar 24, 2026
9d4efe1
format
zRzRzRzRzRzRzR Mar 24, 2026
e27aa92
Update glm4_moe.py
zRzRzRzRzRzRzR Mar 24, 2026
6b92feb
Merge branch 'main' into glm
JustinTong0323 Mar 24, 2026
5e59699
Merge branch 'main' into glm
JustinTong0323 Mar 25, 2026
8850526
update with glm interleaved
zRzRzRzRzRzRzR Mar 25, 2026
2d6e696
Merge branch 'sgl-project:main' into glm
zRzRzRzRzRzRzR Mar 26, 2026
2d24f49
fix comment
zRzRzRzRzRzRzR Mar 26, 2026
4be9b6c
fix
zRzRzRzRzRzRzR Mar 26, 2026
ae96759
Update eagle_info.py
zRzRzRzRzRzRzR Mar 26, 2026
095121e
Update ngram_info.py
zRzRzRzRzRzRzR Mar 26, 2026
40ab489
Merge branch 'main' into glm
zRzRzRzRzRzRzR Mar 27, 2026
ce86c70
Merge branch 'main' into glm
JustinTong0323 Mar 27, 2026
ad05b34
Fix test_required_creates_penalties_tensor to match new per-penalizer…
JustinTong0323 Mar 27, 2026
3295a9c
Merge branch 'sgl-project:main' into glm
zRzRzRzRzRzRzR Mar 28, 2026
72d6f92
Merge branch 'main' into glm
zRzRzRzRzRzRzR Mar 29, 2026
671d350
Merge branch 'sgl-project:main' into glm
zRzRzRzRzRzRzR Mar 30, 2026
abf2d96
fall back1
zRzRzRzRzRzRzR Mar 30, 2026
3f55e0d
fix comment
zRzRzRzRzRzRzR Mar 30, 2026
af64b36
fall back
zRzRzRzRzRzRzR Mar 30, 2026
9a9568b
Use get_scaling_penalties() interface instead of direct attribute access
hnyls2002 Mar 30, 2026
5f539a2
Merge branch 'main' into glm
hnyls2002 Mar 31, 2026
700fb3d
Merge branch 'main' into glm
hnyls2002 Mar 31, 2026
eef407d
renaming
hnyls2002 Apr 1, 2026
2295478
fix naming: linear -> additive in penalties
hnyls2002 Apr 1, 2026
1adb874
fix remaining linear -> additive renames in test
hnyls2002 Apr 1, 2026
2 changes: 2 additions & 0 deletions python/sglang/srt/sampling/penaltylib/__init__.py
@@ -2,10 +2,12 @@
from sglang.srt.sampling.penaltylib.min_new_tokens import BatchedMinNewTokensPenalizer
from sglang.srt.sampling.penaltylib.orchestrator import BatchedPenalizerOrchestrator
from sglang.srt.sampling.penaltylib.presence_penalty import BatchedPresencePenalizer
from sglang.srt.sampling.penaltylib.repetition_penalty import BatchedRepetitionPenalizer

__all__ = [
"BatchedFrequencyPenalizer",
"BatchedMinNewTokensPenalizer",
"BatchedPresencePenalizer",
"BatchedPenalizerOrchestrator",
"BatchedRepetitionPenalizer",
]
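
For context, a hedged usage sketch (the /generate route, port, and request field names are assumptions about SGLang's HTTP API, not part of this diff): a repetition_penalty different from 1.0 in a request's sampling params is what activates the new penalizer, per _is_required() below.

    import requests

    # Hypothetical request against a locally running SGLang server; the
    # endpoint, port, and field names are illustrative assumptions only.
    resp = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": "The quick brown fox",
            "sampling_params": {
                "max_new_tokens": 32,
                "repetition_penalty": 1.2,  # != 1.0 -> BatchedRepetitionPenalizer is required
            },
        },
    )
    print(resp.json())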
62 changes: 54 additions & 8 deletions python/sglang/srt/sampling/penaltylib/orchestrator.py
@@ -52,19 +52,56 @@ def cumulate_output_tokens(self, output_ids: torch.Tensor):
for penalizer in self.penalizers.values():
penalizer.cumulate_output_tokens(output_ids=output_ids)

def apply(self, logits: torch.Tensor) -> torch.Tensor:
def apply(self, logits: torch.Tensor, repeat: Optional[int] = None):
"""
Apply the penalizers to the logits.
Note that it may apply the penalizers in-place.
Apply all penalizers to the logits in-place.

Args:
logits (torch.Tensor): The logits to apply the penalizers to.

Returns:
torch.Tensor: The logits after applying the penalizers.
logits: The logits tensor to apply penalties to.
repeat: If set (speculative decoding), per-request penalties are
expanded via repeat_interleave to match the draft token layout.
Additive penalties are captured into a zeros tensor, expanded,
then added; scaling penalties are accumulated, expanded, then
applied directly.
"""
if repeat is None:
for penalizer in self.penalizers.values():
penalizer.apply(logits)
else:
# Additive: capture into zeros, expand, add
bs = logits.shape[0] // repeat
additive = torch.zeros(
(bs, logits.shape[1]), dtype=torch.float32, device=logits.device
)
self.accumulate_additive_penalties(additive)
logits.add_(torch.repeat_interleave(additive, repeat, dim=0))
# Scaling: accumulate, expand, apply
accumulated = self.accumulate_scaling_penalties()
if accumulated is not None:
from sglang.srt.sampling.penaltylib.repetition_penalty import (
apply_scaling_penalties,
)

expanded = torch.repeat_interleave(accumulated, repeat, dim=0)
apply_scaling_penalties(logits, expanded)

def accumulate_additive_penalties(self, logits: torch.Tensor):
"""Apply only additive (non-multiplicative) penalizers."""
for penalizer in self.penalizers.values():
penalizer.apply(logits)
if not penalizer.is_multiplicative:
penalizer.apply(logits)

def accumulate_scaling_penalties(self) -> Optional[torch.Tensor]:
"""Accumulate all multiplicative penalty tensors into one, or None if none active."""
result = None
for penalizer in self.penalizers.values():
if not penalizer._is_prepared or not penalizer.is_multiplicative:
continue
if result is None:
result = penalizer.get_scaling_penalties().clone()
else:
result *= penalizer.get_scaling_penalties()
return result

def filter(self, keep_indices: torch.Tensor):
"""
@@ -132,6 +169,8 @@ class _BatchedPenalizer(abc.ABC):
An abstract class for a batched penalizer.
"""

is_multiplicative: bool = False

def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
self._orchestrator_ref: weakref.ReferenceType[BatchedPenalizerOrchestrator] = (
weakref.ref(orchestrator)
@@ -227,6 +266,13 @@ def _apply(self, logits: torch.Tensor) -> torch.Tensor:
"""
pass

def get_scaling_penalties(self) -> torch.Tensor:
"""
Return the accumulated scaling penalty tensor for multiplicative penalizers.
Only meaningful when is_multiplicative is True. Subclasses should override.
"""
raise NotImplementedError

@abc.abstractmethod
def _filter(self, keep_indices: torch.Tensor):
"""
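To make the repeat path in apply() above concrete, a minimal sketch with toy tensors (batch size, vocabulary, and values are illustrative, not from the PR):

    import torch

    # 2 requests, vocab of 4, 3 draft tokens per request (speculative layout:
    # rows are [r0d0, r0d1, r0d2, r1d0, r1d1, r1d2]).
    bs, vocab, repeat = 2, 4, 3
    logits = torch.randn(bs * repeat, vocab)

    # Additive penalties (frequency/presence) are captured per request into zeros...
    additive = torch.zeros(bs, vocab)
    additive[0, 2] = -0.5  # pretend request 0 penalized token 2

    # ...then expanded to the draft-token layout and added.
    logits.add_(torch.repeat_interleave(additive, repeat, dim=0))

    # Scaling penalties (repetition) are accumulated per request, expanded, and
    # applied with the sign-dependent rule used by apply_scaling_penalties.
    scaling = torch.ones(bs, vocab)
    scaling[1, 1] = 1.3  # pretend request 1 already emitted token 1
    expanded = torch.repeat_interleave(scaling, repeat, dim=0)
    logits[:] = torch.where(logits < 0, logits * expanded, logits / expanded)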
78 changes: 78 additions & 0 deletions python/sglang/srt/sampling/penaltylib/repetition_penalty.py
@@ -0,0 +1,78 @@
import torch

from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer
from sglang.srt.utils import get_compiler_backend


@torch.compile(dynamic=True, backend=get_compiler_backend())
def apply_scaling_penalties(logits, scaling_penalties):
logits[:] = torch.where(
logits < 0,
logits * scaling_penalties,
logits / scaling_penalties,
)


class BatchedRepetitionPenalizer(_BatchedPenalizer):
"""
Penalizes tokens that have already appeared in the generated output by scaling their logits (multiplicative repetition penalty).
"""

is_multiplicative: bool = True

def _is_required(self) -> bool:
return any(
req.sampling_params.repetition_penalty != 1.0
for req in self.orchestrator.reqs()
)

def _prepare(self):
self.cumulated_repetition_penalties = torch.ones(
(len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
dtype=torch.float32,
device=self.orchestrator.device,
)
self.repetition_penalties = (
torch.tensor(
data=[
req.sampling_params.repetition_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
).unsqueeze_(1)

def _cumulate_output_tokens(self, output_ids: torch.Tensor):
self.cumulated_repetition_penalties.scatter_(
dim=1,
index=output_ids.unsqueeze(1),
src=self.repetition_penalties,
)

def _apply(self, logits: torch.Tensor) -> torch.Tensor:
apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
return logits

def get_scaling_penalties(self) -> torch.Tensor:
return self.cumulated_repetition_penalties

def _filter(self, keep_indices: torch.Tensor):
self.repetition_penalties = self.repetition_penalties[keep_indices]
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
keep_indices
]

def _merge(self, their: "BatchedRepetitionPenalizer"):
self.repetition_penalties = torch.cat(
[self.repetition_penalties, their.repetition_penalties], dim=0
)
self.cumulated_repetition_penalties = torch.cat(
[self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
dim=0,
)

def _teardown(self) -> None:
for name in ("repetition_penalties", "cumulated_repetition_penalties"):
if hasattr(self, name):
delattr(self, name)
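
For intuition, a small worked example of the scaling rule (values illustrative): with a penalty of 2.0, a positive logit is divided and a negative logit is multiplied, so a previously seen token becomes less likely regardless of sign. Note also that _cumulate_output_tokens uses scatter_, which writes the same per-request penalty value each time a token reappears, so the penalty does not compound across repeats (matching the usual HF-style repetition penalty semantics).

    import torch

    logits = torch.tensor([[2.0, -2.0, 1.0]])
    penalties = torch.tensor([[2.0, 2.0, 1.0]])  # 1.0 leaves a token untouched

    # Same rule as apply_scaling_penalties, without the torch.compile wrapper.
    out = torch.where(logits < 0, logits * penalties, logits / penalties)
    print(out)  # tensor([[ 1., -4.,  1.]])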
27 changes: 21 additions & 6 deletions python/sglang/srt/sampling/sampling_batch_info.py
@@ -8,6 +8,7 @@

import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.penaltylib.repetition_penalty import apply_scaling_penalties
from sglang.srt.sampling.sampling_params import TOP_K_ALL
from sglang.srt.server_args import get_global_server_args

Expand Down Expand Up @@ -46,7 +47,10 @@ class SamplingBatchInfo:

# Penalizer
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
acc_linear_penalties: torch.Tensor = None # Used in the overlap mode
acc_additive_penalties: Optional[torch.Tensor] = None # Used in the overlap mode
acc_scaling_penalties: Optional[torch.Tensor] = (
None # Used in the overlap mode for repetition penalty
)

# Whether any request has custom logit processor
has_custom_logit_processor: bool = False
@@ -159,6 +163,7 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,
penaltylib.BatchedPresencePenalizer,
penaltylib.BatchedRepetitionPenalizer,
},
)

@@ -229,19 +234,29 @@ def update_regex_vocab_mask(self):

def update_penalties(self):
if self.penalizer_orchestrator.is_required:
self.acc_linear_penalties = torch.zeros(
self.acc_additive_penalties = torch.zeros(
(len(self.temperatures), self.vocab_size),
dtype=torch.float32,
device=self.temperatures.device,
)
self.penalizer_orchestrator.apply(self.acc_linear_penalties)
self.penalizer_orchestrator.accumulate_additive_penalties(
self.acc_additive_penalties
)
self.acc_scaling_penalties = (
self.penalizer_orchestrator.accumulate_scaling_penalties()
)
else:
self.acc_linear_penalties = None
self.acc_additive_penalties = None
self.acc_scaling_penalties = None

def apply_logits_bias(self, logits: torch.Tensor):
if self.acc_linear_penalties is not None:
if self.acc_additive_penalties is not None:
# Used in the overlap mode
logits.add_(self.acc_additive_penalties)

if self.acc_scaling_penalties is not None:
# Used in the overlap mode
logits.add_(self.acc_linear_penalties)
apply_scaling_penalties(logits, self.acc_scaling_penalties)

if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
# Used in the non-overlap mode
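A minimal sketch of the overlap-mode ordering in apply_logits_bias above, with mocked tensors (values illustrative): additive penalties are added first, then scaling penalties are applied.

    import torch

    logits = torch.tensor([[1.0, -1.0, 0.5, 2.0]])
    acc_additive_penalties = torch.tensor([[0.0, 0.0, -0.3, 0.0]])
    acc_scaling_penalties = torch.tensor([[1.0, 1.5, 1.0, 1.5]])

    logits.add_(acc_additive_penalties)  # additive first
    logits[:] = torch.where(             # then the sign-dependent scaling rule
        logits < 0,
        logits * acc_scaling_penalties,
        logits / acc_scaling_penalties,
    )
    print(logits)  # tensor([[ 1.0000, -1.5000,  0.2000,  1.3333]])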
16 changes: 8 additions & 8 deletions python/sglang/srt/speculative/eagle_info.py
@@ -284,15 +284,15 @@ def verify(
or sampling_info.logit_bias is not None
):
# This is a relaxed version of penalties for speculative decoding.
linear_penalty = torch.zeros(
(bs, logits_output.next_token_logits.shape[1]),
dtype=torch.float32,
device=batch.device,
)
sampling_info.apply_logits_bias(linear_penalty)
logits_output.next_token_logits.add_(
torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
sampling_info.penalizer_orchestrator.apply(
logits_output.next_token_logits, repeat=self.draft_token_num
)
if sampling_info.logit_bias is not None:
logits_output.next_token_logits.add_(
torch.repeat_interleave(
sampling_info.logit_bias, self.draft_token_num, dim=0
)
)

# Apply grammar mask
if vocab_mask is not None:
21 changes: 12 additions & 9 deletions python/sglang/srt/speculative/ngram_info.py
@@ -398,17 +398,20 @@ def verify(
)

# Apply penalty
if sampling_info.penalizer_orchestrator.is_required:
if (
sampling_info.penalizer_orchestrator.is_required
or sampling_info.logit_bias is not None
):
# This is a relaxed version of penalties for speculative decoding.
linear_penalty = torch.zeros(
(bs, logits_output.next_token_logits.shape[1]),
dtype=torch.float32,
device=self.device,
)
sampling_info.apply_logits_bias(linear_penalty)
logits_output.next_token_logits.add_(
torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
sampling_info.penalizer_orchestrator.apply(
logits_output.next_token_logits, repeat=self.draft_token_num
)
if sampling_info.logit_bias is not None:
logits_output.next_token_logits.add_(
torch.repeat_interleave(
sampling_info.logit_bias, self.draft_token_num, dim=0
)
)

# Apply grammar mask
if vocab_mask is not None:
24 changes: 14 additions & 10 deletions test/registered/unit/sampling/test_sampling_batch_info.py
@@ -142,10 +142,10 @@ def test_lhs_none_rhs_present(self):
# apply_logits_bias
class TestApplyLogitsBias(CustomTestCase):

def test_applies_linear_penalties(self):
"""Test that pre-accumulated linear penalties are added to logits."""
def test_applies_additive_penalties(self):
"""Test that pre-accumulated additive penalties are added to logits."""
info = _make_info(batch_size=1)
info.acc_linear_penalties = torch.tensor([[-1.0] * VOCAB_SIZE])
info.acc_additive_penalties = torch.tensor([[-1.0] * VOCAB_SIZE])
logits = torch.zeros(1, VOCAB_SIZE)
info.apply_logits_bias(logits)
self.assertAlmostEqual(logits[0, 0].item(), -1.0, places=5)
@@ -181,7 +181,7 @@ def test_applies_penalizer_orchestrator(self):
def test_no_bias_no_change(self):
"""Test that logits stay unchanged when no bias sources are set."""
info = _make_info(batch_size=1)
info.acc_linear_penalties = None
info.acc_additive_penalties = None
info.logit_bias = None
info.vocab_mask = None
logits = torch.zeros(1, VOCAB_SIZE)
@@ -194,20 +194,24 @@ def test_no_bias_no_change(self):
class TestUpdatePenalties(CustomTestCase):

def test_required_creates_penalties_tensor(self):
"""Test that update_penalties allocates a zero tensor and calls orchestrator.apply."""
"""Test that update_penalties allocates a zero tensor and calls orchestrator methods."""
orch = MagicMock(is_required=True)
orch.accumulate_scaling_penalties.return_value = None
info = _make_info(batch_size=2, penalizer_orchestrator=orch)
info.update_penalties()
self.assertIsNotNone(info.acc_linear_penalties)
self.assertEqual(info.acc_linear_penalties.shape, (2, VOCAB_SIZE))
orch.apply.assert_called_once()
self.assertIsNotNone(info.acc_additive_penalties)
self.assertEqual(info.acc_additive_penalties.shape, (2, VOCAB_SIZE))
orch.accumulate_additive_penalties.assert_called_once_with(
info.acc_additive_penalties
)
orch.accumulate_scaling_penalties.assert_called_once()

def test_not_required_sets_none(self):
"""Test that update_penalties sets acc_linear_penalties to None when not required."""
"""Test that update_penalties sets acc_additive_penalties to None when not required."""
orch = MagicMock(is_required=False)
info = _make_info(batch_size=2, penalizer_orchestrator=orch)
info.update_penalties()
self.assertIsNone(info.acc_linear_penalties)
self.assertIsNone(info.acc_additive_penalties)


# update_regex_vocab_mask
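A hedged sketch of a direct unit test for the new scaling rule, kept separate from the orchestrator fixtures (the import path comes from this PR; the test itself is illustrative, not part of the diff):

    import unittest

    import torch

    from sglang.srt.sampling.penaltylib.repetition_penalty import apply_scaling_penalties


    class TestApplyScalingPenalties(unittest.TestCase):
        def test_both_signs_become_less_likely(self):
            logits = torch.tensor([[2.0, -2.0]])
            penalties = torch.tensor([[2.0, 2.0]])
            apply_scaling_penalties(logits, penalties)  # modifies logits in place
            # Positive logits are divided, negative logits are multiplied.
            self.assertAlmostEqual(logits[0, 0].item(), 1.0, places=5)
            self.assertAlmostEqual(logits[0, 1].item(), -4.0, places=5)


    if __name__ == "__main__":
        unittest.main()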