Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions python/sglang/srt/sampling/custom_logit_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,26 @@ def __call__(
"""Define the callable behavior."""
raise NotImplementedError

def to_str(self) -> str:
@classmethod
def to_str(cls) -> str:
"""Serialize the callable function to a JSON-compatible string."""
return json.dumps({"callable": dill.dumps(self).hex()})
return json.dumps({"callable": dill.dumps(cls).hex()})

@classmethod
def from_str(cls, json_str: str):
"""Deserialize a callable function from a JSON string."""
return _cache_from_str(json_str)
return _cache_from_str(json_str)()


class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
def __call__(
self,
logits: torch.Tensor,
custom_param_list: Optional[List[Dict[str, Any]]] = None,
) -> torch.Tensor:
disallowed_token_ids = custom_param_list[0]["token_ids"]
assert all(
disallowed_token_ids == c["token_ids"] for c in custom_param_list
), f"{custom_param_list=}"
logits[..., disallowed_token_ids] = -float("inf")
return logits
17 changes: 13 additions & 4 deletions test/srt/test_srt_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,7 @@ def __call__(self, logits, custom_param_list):
custom_json = base_json.copy()
# Only set the custom logit processor if target_token_id is not None.
if target_token_id is not None:
custom_json["custom_logit_processor"] = (
DeterministicLogitProcessor().to_str()
)
custom_json["custom_logit_processor"] = DeterministicLogitProcessor.to_str()
custom_json["sampling_params"]["custom_params"] = custom_params

custom_response = requests.post(
Expand All @@ -373,7 +371,6 @@ def run_stateful_custom_logit_processor(
Should sample the first `delay` tokens normally, then output first_token_id and consecutive tokens after that.
If first_token_id is None, the custom logit processor won't be passed in.
"""

custom_params = {"token_id": first_token_id, "delay": 2}

class DeterministicStatefulLogitProcessor(CustomLogitProcessor):
Expand Down Expand Up @@ -447,10 +444,22 @@ def test_custom_logit_processor_batch_mixed(self):
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))

@unittest.skip("Skip this test because this feature has a bug. See comments below.")
def test_stateful_custom_logit_processor(self):
"""Test custom logit processor with a single request."""

"""
NOTE: This feature has a race condition bug.
This line https://github.com/sgl-project/sglang/blob/ef8ec07b2ce4c70c2a33ec5acda4ce529bc3cda4/test/srt/test_srt_endpoint.py#L395-L396 can be accessed by two concurrent threads at the same time. The access order is not guaranteed.
In sglang, we use two python threads to overlap the GPU computation and CPU scheduling.
Thread 1 (the CPU scheduling thread) will update the `param_dict["__req__"].output_ids`.
Thread 2 (the GPU computation thread) will call `DeterministicStatefulLogitProcessor` because sampling is considered as GPU computation.
We can fix this by moving the call of DeterministicStatefulLogitProcessor to the CPU scheduling thread.
"""

self.run_stateful_custom_logit_processor(first_token_id=5)

@unittest.skip("Skip this test because this feature has a bug. See comments above.")
def test_stateful_custom_logit_processor_batch_mixed(self):
"""Test a batch of requests mixed of requests with and without custom logit processor."""
target_token_ids = list(range(32)) + [None] * 16
Expand Down
Loading