Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 52 additions & 41 deletions docker/patch/latest/sglang.patch
Original file line number Diff line number Diff line change
Expand Up @@ -814,27 +814,23 @@ index 7f6f6a010..c4a673145 100644
if not get_global_server_args().sampling_backend == "ascend" or (
return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 87922077e..8cb6bad8d 100644
index 87922077e..6507d8bf5 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -247,6 +247,16 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
@@ -247,6 +247,12 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
s.sent_offset = len(output_str)
output_strs.append(incremental_output)

+ output_routed_experts = []
+ if recv_obj.output_routed_experts is not None:
+ output_routed_experts = [
+ (
+ output_routed_experts.tolist()
+ if output_routed_experts is not None
+ else []
+ )
+ output_routed_experts
+ for output_routed_experts in recv_obj.output_routed_experts
+ ]
return BatchStrOutput(
rids=recv_obj.rids,
http_worker_ipcs=recv_obj.http_worker_ipcs,
@@ -272,6 +282,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
@@ -272,6 +278,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
output_token_entropy_val=recv_obj.output_token_entropy_val,
output_hidden_states=recv_obj.output_hidden_states,
Expand Down Expand Up @@ -1165,29 +1161,72 @@ index f8ebfc1f4..48b9a1a3b 100644
return ResumeMemoryOccupationReqOutput()

def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput):
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index edbc52526..2cdc42755 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -421,6 +421,11 @@ class TokenizerCommunicatorMixin:
result = (await self.update_weights_from_distributed_communicator(obj))[
0
]
+ if result.success and obj.weight_version is not None:
+ self._update_weight_version_if_provided(obj.weight_version)
+ result.message += (
+ f" Weight version updated to {obj.weight_version}."
+ )
return result.success, result.message

# This means that weight sync
@@ -480,6 +485,11 @@ class TokenizerCommunicatorMixin:
async with self.is_pause_cond:
if self.is_pause:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
+ if result.success and obj.weight_version is not None:
+ self._update_weight_version_if_provided(obj.weight_version)
+ result.message += (
+ f" Weight version updated to {obj.weight_version}."
+ )
return result.success, result.message

# This means that weight sync
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b90cf0616..98d71d896 100644
index b90cf0616..9b0992655 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -888,6 +888,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
@@ -20,6 +20,7 @@ import logging
import math
import os
import pickle
+import pybase64
import signal
import sys
import threading
@@ -888,6 +889,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
session_params=session_params,
custom_logit_processor=obj.custom_logit_processor,
return_hidden_states=obj.return_hidden_states,
+ return_routed_experts=obj.return_routed_experts,
data_parallel_rank=obj.data_parallel_rank,
priority=obj.priority,
extra_key=obj.extra_key,
@@ -1621,6 +1622,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
@@ -1621,6 +1623,16 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if getattr(recv_obj, "output_hidden_states", None):
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

+ if getattr(recv_obj, "output_routed_experts", None):
+ meta_info["routed_experts"] = recv_obj.output_routed_experts[i]
+ if recv_obj.output_routed_experts[i] is not None:
+ # print(f"{recv_obj.output_routed_experts[i].shape=}, {recv_obj.output_routed_experts[i].dtype=}")
+ # torch.save(recv_obj.output_routed_experts[i], f"/root/{recv_obj.output_routed_experts[i].shape[0]}.pt")
+ meta_info["routed_experts"] = pybase64.b64encode(
+ recv_obj.output_routed_experts[i].contiguous().numpy().tobytes(order="C")
+ ).decode("ascii")
+ else:
+ meta_info["routed_experts"] = None
+
if isinstance(recv_obj, BatchStrOutput):
state.text += recv_obj.output_strs[i]
if self.server_args.stream_output and state.obj.stream:
@@ -1747,12 +1751,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
@@ -1747,12 +1759,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
return

if len(recv_obj.input_token_logprobs_val) > 0:
Expand Down Expand Up @@ -1975,31 +2014,3 @@ index b3d72df05..ddfe0b178 100644


@dataclass
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index edbc52526..2cdc42755 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -421,6 +421,11 @@ class TokenizerCommunicatorMixin:
result = (await self.update_weights_from_distributed_communicator(obj))[
0
]
+ if result.success and obj.weight_version is not None:
+ self._update_weight_version_if_provided(obj.weight_version)
+ result.message += (
+ f" Weight version updated to {obj.weight_version}."
+ )
return result.success, result.message

# This means that weight sync
@@ -480,6 +485,11 @@ class TokenizerCommunicatorMixin:
async with self.is_pause_cond:
if self.is_pause:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
+ if result.success and obj.weight_version is not None:
+ self._update_weight_version_if_provided(obj.weight_version)
+ result.message += (
+ f" Weight version updated to {obj.weight_version}."
+ )
return result.success, result.message

# This means that weight sync
2 changes: 1 addition & 1 deletion docker/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
nightly-dev-20251215a
nightly-dev-20251216a
20 changes: 15 additions & 5 deletions examples/multi_agent/agent_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,21 @@ async def generate_response(args, prompt, key):

url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"

prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
if args.apply_chat_template:
assert isinstance(prompt, list), "prompt should be a list when apply_chat_template is True"
prompt_text = tokenizer.apply_chat_template(
prompt,
tokenize=False,
add_generation_prompt=True, # Add generation prompt for the assistant
**(args.apply_chat_template_kwargs or {}),
)
sample.prompt = prompt_text
else:
assert isinstance(prompt, str), "prompt should be a string when apply_chat_template is False"
sample.prompt = prompt
prompt_token_ids = tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]
sample.tokens = prompt_token_ids
sample.prompt = prompt
input_token_ids = prompt_token_ids
prompt_length = len(input_token_ids)
prompt_length = len(prompt_token_ids)
current_sampling_params = deepcopy(sampling_params)
current_sampling_params["max_new_tokens"] = min(
sampling_params["max_new_tokens"], max_context_length - prompt_length
Expand All @@ -33,7 +43,7 @@ async def generate_response(args, prompt, key):
if current_sampling_params["max_new_tokens"] <= 0:
return None

payload = {"input_ids": input_token_ids, "sampling_params": current_sampling_params, "return_logprob": True}
payload = {"input_ids": prompt_token_ids, "sampling_params": current_sampling_params, "return_logprob": True}

output = await post(url, payload)

Expand Down
15 changes: 13 additions & 2 deletions examples/search-r1/generate_with_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,26 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:

# Handle partial rollout samples: continue generation from existing response
prompt = sample.prompt
prompt_tokens_ids = state.tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]
if args.apply_chat_template:
assert isinstance(prompt, list), "prompt should be a list when apply_chat_template is True"
prompt_text = state.tokenizer.apply_chat_template(
prompt,
tokenize=False,
add_generation_prompt=True, # Add generation prompt for the assistant
**(args.apply_chat_template_kwargs or {}),
)
else:
assert isinstance(prompt, str), "prompt should be a string when apply_chat_template is False"
prompt_text = prompt
prompt_tokens_ids = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
response = ""
response_token_ids = []
loss_mask = []
rollout_log_probs = [] if SEARCH_R1_CONFIGS["return_logprob"] else None

for _turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]):
payload = {
"text": prompt + response,
"text": prompt_text + response,
"sampling_params": sampling_params,
}
# Add log probability collection if enabled
Expand Down
47 changes: 47 additions & 0 deletions examples/train_infer_mismatch_helper/mis_fsdp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Any

import torch

from .mis import compute_mis_weights


def compute_mis_weights_fsdp(
    args,
    *,
    pg_loss: torch.Tensor,
    train_log_probs: list[torch.Tensor],
    rollout_log_probs: list[torch.Tensor],
    loss_masks: list[torch.Tensor],
    **kwargs: Any,
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
    """Apply masked importance-sampling (MIS) weights to an FSDP policy-gradient loss.

    Thin FSDP front-end over :func:`compute_mis_weights`; there is no context
    parallelism here, so per-sequence tensors are simply concatenated into the
    flat token dimension of ``pg_loss``.

    Args:
        args: Arguments carrying the MIS settings (use_tis, tis_mode, etc.).
        pg_loss: Flattened per-token policy-gradient loss, shape [total_tokens].
        train_log_probs: Per-sequence 1D tensors of training log-probs.
        rollout_log_probs: Per-sequence 1D tensors of rollout log-probs.
        loss_masks: Per-sequence 1D loss-mask tensors.
        **kwargs: Ignored; accepted for signature compatibility with the
            context-parallel variants (cp_rank, cp_size, ...).

    Returns:
        A ``(pg_loss, modified_masks, mis_metrics)`` tuple: the loss with IS
        weights multiplied in (unchanged when no weights were produced), the
        masks after rejection sampling, and a dict of flattened metric tensors
        keyed with a ``mis_`` prefix.
    """
    weights, adjusted_masks, raw_metrics = compute_mis_weights(
        args=args,
        train_log_probs=train_log_probs,
        rollout_log_probs=rollout_log_probs,
        loss_masks=loss_masks,
    )

    # Only scale the loss when the helper actually produced IS weights;
    # flatten the per-sequence weights to line up with pg_loss's token axis.
    if weights is not None:
        pg_loss = pg_loss * torch.cat(weights, dim=0)

    # Metrics are always reported, weights or not: flatten each metric's
    # per-sequence tensors and namespace the keys under "mis_".
    flattened_metrics = {
        f"mis_{name}": torch.cat(per_seq, dim=0) for name, per_seq in raw_metrics.items()
    }

    return pg_loss, adjusted_masks, flattened_metrics
Loading