[Feature Restoration] repetition_penalty is essential for GLM-V models #21258
Status: Merged (30 commits)
Commits:

- 00b648a  Add repetition_penalties for #5703 (zRzRzRzRzRzRzR)
- fde0313  add init (zRzRzRzRzRzRzR)
- 0b12c1d  change comment (zRzRzRzRzRzRzR)
- fe26403  Update sampling_batch_info.py (zRzRzRzRzRzRzR)
- 0d6c418  Update glm4_moe.py (zRzRzRzRzRzRzR)
- 9d4efe1  format (zRzRzRzRzRzRzR)
- e27aa92  Update glm4_moe.py (zRzRzRzRzRzRzR)
- 6b92feb  Merge branch 'main' into glm (JustinTong0323)
- 5e59699  Merge branch 'main' into glm (JustinTong0323)
- 8850526  update with glm interleaved (zRzRzRzRzRzRzR)
- 2d6e696  Merge branch 'sgl-project:main' into glm (zRzRzRzRzRzRzR)
- 2d24f49  fix comment (zRzRzRzRzRzRzR)
- 4be9b6c  fix (zRzRzRzRzRzRzR)
- ae96759  Update eagle_info.py (zRzRzRzRzRzRzR)
- 095121e  Update ngram_info.py (zRzRzRzRzRzRzR)
- 40ab489  Merge branch 'main' into glm (zRzRzRzRzRzRzR)
- ce86c70  Merge branch 'main' into glm (JustinTong0323)
- ad05b34  Fix test_required_creates_penalties_tensor to match new per-penalizer… (JustinTong0323)
- 3295a9c  Merge branch 'sgl-project:main' into glm (zRzRzRzRzRzRzR)
- 72d6f92  Merge branch 'main' into glm (zRzRzRzRzRzRzR)
- 671d350  Merge branch 'sgl-project:main' into glm (zRzRzRzRzRzRzR)
- abf2d96  fall back1 (zRzRzRzRzRzRzR)
- 3f55e0d  fix comment (zRzRzRzRzRzRzR)
- af64b36  fall back (zRzRzRzRzRzRzR)
- 9a9568b  Use get_scaling_penalties() interface instead of direct attribute access (hnyls2002)
- 5f539a2  Merge branch 'main' into glm (hnyls2002)
- 700fb3d  Merge branch 'main' into glm (hnyls2002)
- eef407d  renaming (hnyls2002)
- 2295478  fix naming: linear -> additive in penalties (hnyls2002)
- 1adb874  fix remaining linear -> additive renames in test (hnyls2002)
New file: python/sglang/srt/sampling/penaltylib/repetition_penalty.py (78 additions, 0 deletions)
```python
import torch

from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer
from sglang.srt.utils import get_compiler_backend


@torch.compile(dynamic=True, backend=get_compiler_backend())
def apply_scaling_penalties(logits, scaling_penalties):
    logits[:] = torch.where(
        logits < 0,
        logits * scaling_penalties,
        logits / scaling_penalties,
    )


class BatchedRepetitionPenalizer(_BatchedPenalizer):
    """
    Repetition penalizer penalizes tokens based on their presence in the generated output.
    """

    is_multiplicative: bool = True

    def _is_required(self) -> bool:
        return any(
            req.sampling_params.repetition_penalty != 1.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        self.cumulated_repetition_penalties = torch.ones(
            (len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
            dtype=torch.float32,
            device=self.orchestrator.device,
        )
        self.repetition_penalties = (
            torch.tensor(
                data=[
                    req.sampling_params.repetition_penalty
                    for req in self.orchestrator.reqs()
                ],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
        ).unsqueeze_(1)

    def _cumulate_output_tokens(self, output_ids: torch.Tensor):
        self.cumulated_repetition_penalties.scatter_(
            dim=1,
            index=output_ids.unsqueeze(1),
            src=self.repetition_penalties,
        )

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
        return logits

    def get_scaling_penalties(self) -> torch.Tensor:
        return self.cumulated_repetition_penalties

    def _filter(self, keep_indices: torch.Tensor):
        self.repetition_penalties = self.repetition_penalties[keep_indices]
        self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
            keep_indices
        ]

    def _merge(self, their: "BatchedRepetitionPenalizer"):
        self.repetition_penalties = torch.cat(
            [self.repetition_penalties, their.repetition_penalties], dim=0
        )
        self.cumulated_repetition_penalties = torch.cat(
            [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
            dim=0,
        )

    def _teardown(self) -> None:
        for name in ("repetition_penalties", "cumulated_repetition_penalties"):
            if hasattr(self, name):
                delattr(self, name)
```
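The core of `apply_scaling_penalties` is the asymmetric multiplicative rule: negative logits are multiplied by the penalty (pushed further negative) while positive logits are divided by it (pulled toward zero), so a penalty above 1.0 always makes a previously seen token less likely. A minimal pure-Python sketch of the same rule (the scalar values here are illustrative, not from the PR):

```python
def scale_logit(logit: float, penalty: float) -> float:
    # Mirrors torch.where(logits < 0, logits * p, logits / p):
    # negative logits grow in magnitude, positive logits shrink,
    # so penalty > 1.0 lowers the token's probability either way.
    return logit * penalty if logit < 0 else logit / penalty

print(scale_logit(3.0, 1.5))   # positive logit shrinks: 2.0
print(scale_logit(-2.0, 1.5))  # negative logit grows in magnitude: -3.0
print(scale_logit(0.5, 1.0))   # penalty of 1.0 is a no-op: 0.5
```

Dividing positive logits rather than subtracting is what distinguishes this multiplicative (HF-style `repetition_penalty`) scheme from the additive frequency/presence penalties renamed elsewhere in this PR.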
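One subtlety of `_cumulate_output_tokens` is that `scatter_` *writes* the per-request penalty at the emitted token's index rather than multiplying into it, so each vocabulary entry holds either 1.0 (never emitted) or the request's penalty, and repeating a token does not compound the penalty. A pure-Python sketch of that bookkeeping for a single request (vocab size and values are made up for illustration):

```python
penalty = 1.5
row = [1.0] * 4                 # cumulated penalties for one request, vocab_size=4

for tok in [2, 2, 0]:           # token 2 emitted twice; the write overwrites,
    row[tok] = penalty          # so the penalty does not compound

logits = [3.0, -1.0, 3.0, 3.0]
penalized = [l * p if l < 0 else l / p for l, p in zip(logits, row)]
print(penalized)  # [2.0, -1.0, 2.0, 3.0]: only tokens 0 and 2 are penalized
```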