diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch
index 54e88eaca7..7a019455ee 100644
--- a/docker/patch/latest/sglang.patch
+++ b/docker/patch/latest/sglang.patch
@@ -814,27 +814,23 @@ index 7f6f6a010..c4a673145 100644
          if not get_global_server_args().sampling_backend == "ascend" or (
              return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB
 diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
-index 87922077e..8cb6bad8d 100644
+index 87922077e..6507d8bf5 100644
 --- a/python/sglang/srt/managers/detokenizer_manager.py
 +++ b/python/sglang/srt/managers/detokenizer_manager.py
-@@ -247,6 +247,16 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
+@@ -247,6 +247,12 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                      s.sent_offset = len(output_str)
                  output_strs.append(incremental_output)
 
 +        output_routed_experts = []
 +        if recv_obj.output_routed_experts is not None:
 +            output_routed_experts = [
-+                (
-+                    output_routed_experts.tolist()
-+                    if output_routed_experts is not None
-+                    else []
-+                )
++                output_routed_experts
 +                for output_routed_experts in recv_obj.output_routed_experts
 +            ]
          return BatchStrOutput(
              rids=recv_obj.rids,
              http_worker_ipcs=recv_obj.http_worker_ipcs,
-@@ -272,6 +282,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
+@@ -272,6 +278,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
              output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
              output_token_entropy_val=recv_obj.output_token_entropy_val,
              output_hidden_states=recv_obj.output_hidden_states,
@@ -1165,11 +1161,47 @@ index f8ebfc1f4..48b9a1a3b 100644
          return ResumeMemoryOccupationReqOutput()
 
      def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput):
+diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+index edbc52526..2cdc42755 100644
+--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
++++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+@@ -421,6 +421,11 @@ class TokenizerCommunicatorMixin:
+                 result = (await self.update_weights_from_distributed_communicator(obj))[
+                     0
+                 ]
++                if result.success and obj.weight_version is not None:
++                    self._update_weight_version_if_provided(obj.weight_version)
++                    result.message += (
++                        f" Weight version updated to {obj.weight_version}."
++                    )
+                 return result.success, result.message
+
+         # This means that weight sync
+@@ -480,6 +485,11 @@ class TokenizerCommunicatorMixin:
+         async with self.is_pause_cond:
+             if self.is_pause:
+                 result = (await self.update_weights_from_tensor_communicator(obj))[0]
++                if result.success and obj.weight_version is not None:
++                    self._update_weight_version_if_provided(obj.weight_version)
++                    result.message += (
++                        f" Weight version updated to {obj.weight_version}."
++                    )
+                 return result.success, result.message
+
+         # This means that weight sync
 diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
-index b90cf0616..98d71d896 100644
+index b90cf0616..9b0992655 100644
 --- a/python/sglang/srt/managers/tokenizer_manager.py
 +++ b/python/sglang/srt/managers/tokenizer_manager.py
-@@ -888,6 +888,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
+@@ -20,6 +20,7 @@ import logging
+ import math
+ import os
+ import pickle
++import pybase64
+ import signal
+ import sys
+ import threading
+@@ -888,6 +889,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
              session_params=session_params,
              custom_logit_processor=obj.custom_logit_processor,
              return_hidden_states=obj.return_hidden_states,
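[editor's note] The weight_version handling appears twice in this patch because the tokenizer_communicator_mixin.py section is being moved here from the end of the file (it is deleted again in the hunk at old line 1975 below). A minimal, self-contained sketch of the success-path bookkeeping the added lines implement — the _Result class and finish_update name are illustrative; only _update_weight_version_if_provided and the message text come from the patch:

    class _Result:
        def __init__(self):
            self.success = True
            self.message = "Weights updated."

    class _Mixin:
        weight_version = None

        def _update_weight_version_if_provided(self, version):
            self.weight_version = version  # record the trainer-reported version

        def finish_update(self, result, weight_version=None):
            if result.success and weight_version is not None:
                self._update_weight_version_if_provided(weight_version)
                result.message += f" Weight version updated to {weight_version}."
            return result.success, result.message

    m = _Mixin()
    ok, msg = m.finish_update(_Result(), weight_version="3")
    assert ok and m.weight_version == "3" and msg.endswith("updated to 3.")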
@@ -1177,17 +1209,22 @@ index b90cf0616..98d71d896 100644
              data_parallel_rank=obj.data_parallel_rank,
              priority=obj.priority,
              extra_key=obj.extra_key,
-@@ -1621,6 +1622,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
+@@ -1621,6 +1623,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
          if getattr(recv_obj, "output_hidden_states", None):
              meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
 
 +        if getattr(recv_obj, "output_routed_experts", None):
-+            meta_info["routed_experts"] = recv_obj.output_routed_experts[i]
++            if recv_obj.output_routed_experts[i] is not None:
++                meta_info["routed_experts"] = pybase64.b64encode(
++                    recv_obj.output_routed_experts[i].contiguous().numpy().tobytes(order="C")
++                ).decode("ascii")
++            else:
++                meta_info["routed_experts"] = None
 +
          if isinstance(recv_obj, BatchStrOutput):
              state.text += recv_obj.output_strs[i]
              if self.server_args.stream_output and state.obj.stream:
-@@ -1747,12 +1751,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
+@@ -1747,12 +1757,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
              return
 
          if len(recv_obj.input_token_logprobs_val) > 0:
@@ -1975,31 +2012,3 @@ index b3d72df05..ddfe0b178 100644
 
 
 @dataclass
-diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
-index edbc52526..2cdc42755 100644
---- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
-+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
-@@ -421,6 +421,11 @@ class TokenizerCommunicatorMixin:
-                 result = (await self.update_weights_from_distributed_communicator(obj))[
-                     0
-                 ]
-+                if result.success and obj.weight_version is not None:
-+                    self._update_weight_version_if_provided(obj.weight_version)
-+                    result.message += (
-+                        f" Weight version updated to {obj.weight_version}."
-+                    )
-                 return result.success, result.message
-
-         # This means that weight sync
-@@ -480,6 +485,11 @@ class TokenizerCommunicatorMixin:
-         async with self.is_pause_cond:
-             if self.is_pause:
-                 result = (await self.update_weights_from_tensor_communicator(obj))[0]
-+                if result.success and obj.weight_version is not None:
-+                    self._update_weight_version_if_provided(obj.weight_version)
-+                    result.message += (
-+                        f" Weight version updated to {obj.weight_version}."
-+                    )
-                 return result.success, result.message
-
-         # This means that weight sync
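[editor's note] Between the tokenizer_manager.py hunk above and the slime/rollout/sglang_rollout.py change further down, the routed-experts tensor crosses the HTTP boundary as base64-encoded raw bytes. A self-contained round-trip of the two transforms, assuming the [len(tokens) - 1, num_layers, moe_router_topk] shape and int32 dtype that the decode site expects (the sizes below are made up):

    import numpy as np
    import pybase64
    import torch

    num_tokens, num_layers, topk = 5, 4, 2  # toy dimensions
    experts = torch.randint(0, 64, (num_tokens - 1, num_layers, topk), dtype=torch.int32)

    # Encode (tokenizer_manager side): C-order raw bytes -> ASCII base64.
    payload = pybase64.b64encode(experts.contiguous().numpy().tobytes(order="C")).decode("ascii")

    # Decode (sglang_rollout side): base64 -> flat int32 buffer -> original shape.
    decoded = np.frombuffer(pybase64.b64decode(payload.encode("ascii")), dtype=np.int32).reshape(
        num_tokens - 1, num_layers, topk
    )
    assert (decoded == experts.numpy()).all()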
diff --git a/docker/version.txt b/docker/version.txt
index 51694bc9e3..81449aa2b7 100644
--- a/docker/version.txt
+++ b/docker/version.txt
@@ -1 +1 @@
-nightly-dev-20251215a
\ No newline at end of file
+nightly-dev-20251216a
\ No newline at end of file
diff --git a/examples/multi_agent/agent_system.py b/examples/multi_agent/agent_system.py
index b7e2b3a898..49b62ca6d3 100644
--- a/examples/multi_agent/agent_system.py
+++ b/examples/multi_agent/agent_system.py
@@ -20,11 +20,21 @@ async def generate_response(args, prompt, key):
     url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
 
-    prompt_token_ids = tokenizer.encode(prompt, add_special_tokens=False)
+    if args.apply_chat_template:
+        assert isinstance(prompt, list), "prompt should be a list when apply_chat_template is True"
+        prompt_text = tokenizer.apply_chat_template(
+            prompt,
+            tokenize=False,
+            add_generation_prompt=True,  # Add generation prompt for the assistant
+            **(args.apply_chat_template_kwargs or {}),
+        )
+        sample.prompt = prompt_text
+    else:
+        assert isinstance(prompt, str), "prompt should be a string when apply_chat_template is False"
+        sample.prompt = prompt
+    prompt_token_ids = tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]
     sample.tokens = prompt_token_ids
-    sample.prompt = prompt
-    input_token_ids = prompt_token_ids
-    prompt_length = len(input_token_ids)
+    prompt_length = len(prompt_token_ids)
 
     current_sampling_params = deepcopy(sampling_params)
     current_sampling_params["max_new_tokens"] = min(
         sampling_params["max_new_tokens"], max_context_length - prompt_length
@@ -33,7 +43,7 @@ async def generate_response(args, prompt, key):
 
     if current_sampling_params["max_new_tokens"] <= 0:
         return None
 
-    payload = {"input_ids": input_token_ids, "sampling_params": current_sampling_params, "return_logprob": True}
+    payload = {"input_ids": prompt_token_ids, "sampling_params": current_sampling_params, "return_logprob": True}
 
     output = await post(url, payload)
diff --git a/examples/search-r1/generate_with_search.py b/examples/search-r1/generate_with_search.py
index 968c3bebbd..65ea9e399c 100644
--- a/examples/search-r1/generate_with_search.py
+++ b/examples/search-r1/generate_with_search.py
@@ -151,7 +151,18 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
 
     # Handle partial rollout samples: continue generation from existing response
     prompt = sample.prompt
-    prompt_tokens_ids = state.tokenizer(sample.prompt, add_special_tokens=False)["input_ids"]
+    if args.apply_chat_template:
+        assert isinstance(prompt, list), "prompt should be a list when apply_chat_template is True"
+        prompt_text = state.tokenizer.apply_chat_template(
+            prompt,
+            tokenize=False,
+            add_generation_prompt=True,  # Add generation prompt for the assistant
+            **(args.apply_chat_template_kwargs or {}),
+        )
+    else:
+        assert isinstance(prompt, str), "prompt should be a string when apply_chat_template is False"
+        prompt_text = prompt
+    prompt_tokens_ids = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
     response = ""
     response_token_ids = []
     loss_mask = []
@@ -159,7 +170,7 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
 
     for _turn_idx in range(SEARCH_R1_CONFIGS["max_turns"]):
         payload = {
-            "text": prompt + response,
+            "text": prompt_text + response,
             "sampling_params": sampling_params,
         }
         # Add log probability collection if enabled
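[editor's note] Both example scripts above now follow the same recipe: branch on --apply-chat-template to obtain a flat prompt string, tokenize it, then clamp max_new_tokens to the remaining context budget. A minimal sketch of that recipe with stand-in tokenize/render callables (all names here are illustrative; the real code uses the HuggingFace tokenizer):

    def build_request(prompt, apply_chat_template, tokenize, render_chat, max_new_tokens, max_context_length):
        if apply_chat_template:
            assert isinstance(prompt, list), "expected a list of messages"
            prompt_text = render_chat(prompt)
        else:
            assert isinstance(prompt, str), "expected a plain string"
            prompt_text = prompt
        input_ids = tokenize(prompt_text)
        budget = min(max_new_tokens, max_context_length - len(input_ids))
        if budget <= 0:
            return None  # the prompt alone exhausts the context window
        return {"input_ids": input_ids, "sampling_params": {"max_new_tokens": budget}}

    req = build_request(
        [{"role": "user", "content": "hi"}],
        apply_chat_template=True,
        tokenize=lambda s: list(s.encode()),  # toy tokenizer: one id per byte
        render_chat=lambda msgs: f"<user>{msgs[0]['content']}</user><assistant>",
        max_new_tokens=64,
        max_context_length=128,
    )
    assert req is not None and req["sampling_params"]["max_new_tokens"] <= 64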
diff --git a/examples/train_infer_mismatch_helper/mis_fsdp.py b/examples/train_infer_mismatch_helper/mis_fsdp.py
new file mode 100644
index 0000000000..0aa8275cb2
--- /dev/null
+++ b/examples/train_infer_mismatch_helper/mis_fsdp.py
@@ -0,0 +1,47 @@
+from typing import Any
+
+import torch
+
+from .mis import compute_mis_weights
+
+
+def compute_mis_weights_fsdp(
+    args,
+    *,
+    pg_loss: torch.Tensor,
+    train_log_probs: list[torch.Tensor],
+    rollout_log_probs: list[torch.Tensor],
+    loss_masks: list[torch.Tensor],
+    **kwargs: Any,
+) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
+    """Compute masked importance sampling weights for FSDP. No context parallelism.
+
+    Args:
+        args: Arguments containing MIS settings (use_tis, tis_mode, etc.)
+        pg_loss: Policy gradient loss, flattened tensor [total_tokens]
+        train_log_probs: Training log probs, list of 1D tensors per sequence
+        rollout_log_probs: Rollout log probs, list of 1D tensors per sequence
+        loss_masks: Loss masks, list of 1D tensors per sequence
+        **kwargs: Additional arguments (cp_rank, cp_size, etc.) for compatibility
+
+    Returns:
+        pg_loss: Policy gradient loss with IS weights applied
+        modified_masks: Modified loss masks after rejection sampling
+        mis_metrics: Metrics dict with flattened tensors
+    """
+    is_weights, modified_masks, is_metrics = compute_mis_weights(
+        args=args,
+        train_log_probs=train_log_probs,
+        rollout_log_probs=rollout_log_probs,
+        loss_masks=loss_masks,
+    )
+
+    result_metrics = {}
+    if is_weights is not None:
+        is_weights_flat = torch.cat(is_weights, dim=0)
+        pg_loss = pg_loss * is_weights_flat
+
+    for key, values in is_metrics.items():
+        result_metrics[f"mis_{key}"] = torch.cat(values, dim=0)
+
+    return pg_loss, modified_masks, result_metrics
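[editor's note] The adapter above delegates the actual weighting to compute_mis_weights (in ./mis.py, not shown here) and only flattens the per-sequence results. A toy check of that flatten-and-apply step, with hand-written weights standing in for compute_mis_weights output:

    import torch

    is_weights = [torch.tensor([1.2, 0.8, 1.0]), torch.tensor([0.5, 2.0])]  # per sequence
    pg_loss = torch.ones(5)  # flattened [total_tokens], matching the adapter's contract

    pg_loss = pg_loss * torch.cat(is_weights, dim=0)
    assert torch.allclose(pg_loss, torch.tensor([1.2, 0.8, 1.0, 0.5, 2.0]))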
diff --git a/examples/train_infer_mismatch_helper/run-qwen3-4b-fsdp-mis.sh b/examples/train_infer_mismatch_helper/run-qwen3-4b-fsdp-mis.sh
new file mode 100644
index 0000000000..e6d2d5bc3a
--- /dev/null
+++ b/examples/train_infer_mismatch_helper/run-qwen3-4b-fsdp-mis.sh
@@ -0,0 +1,148 @@
+#!/bin/bash
+
+# For rerunning the task: kill any leftover processes from a previous run.
+pkill -9 sglang
+sleep 3
+ray stop --force
+pkill -9 ray
+pkill -9 python
+sleep 3
+pkill -9 ray
+pkill -9 python
+
+set -ex
+
+# Prevent Python from buffering stdout/stderr so ray logs stream promptly.
+export PYTHONUNBUFFERED=1
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
+if [ "$NVLINK_COUNT" -gt 0 ]; then
+    HAS_NVLINK=1
+else
+    HAS_NVLINK=0
+fi
+echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
+
+SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
+
+RUN_ID=${RUN_ID:-"run_$(date +%Y%m%d_%H%M%S)"}
+LOAD_SAVE_PATH="/root/shared_data/${RUN_ID}/checkpoints"
+
+CKPT_ARGS=(
+   --hf-checkpoint /root/Qwen3-4B
+   --load /root/Qwen3-4B
+   --ref-load /root/Qwen3-4B
+)
+
+ROLLOUT_ARGS=(
+   --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
+   --input-key prompt
+   --label-key label
+   --apply-chat-template
+   --rollout-shuffle
+   --balance-data
+   --rm-type deepscaler
+   --num-rollout 100
+   --rollout-batch-size 8
+   --n-samples-per-prompt 8
+   --rollout-max-response-len 4096
+   --rollout-temperature 0.8
+   --global-batch-size 64
+)
+
+GRPO_ARGS=(
+   --use-kl-loss
+   --advantage-estimator grpo
+   --kl-loss-coef 0.00
+   --kl-loss-type low_var_kl
+   --kl-coef 0.00
+   --entropy-coef 0.00
+   --eps-clip 0.2
+   --eps-clip-high 0.28
+   --use-tis
+)
+
+OPTIMIZER_ARGS=(
+   --optimizer adam
+   --lr 1e-6
+   --lr-decay-style constant
+   --weight-decay 0.1
+   --adam-beta1 0.9
+   --adam-beta2 0.98
+)
+
+WANDB_ARGS=(
+   --use-wandb
+   --wandb-project slime-dev-mcore-fsdp
+   --wandb-group qwen3-4B-fsdp-1130-ref
+   --wandb-key ${WANDB_API_KEY}
+)
+
+SGLANG_ARGS=(
+   --rollout-num-gpus-per-engine 1
+   --sglang-mem-fraction-static 0.75
+   --sglang-decode-log-interval 1000
+   --sglang-chunked-prefill-size 4096
+   --sglang-attention-backend fa3
+)
+
+TRAIN_BACKEND_ARGS=(
+   --train-backend fsdp
+   --update-weight-buffer-size 536870912
+   --gradient-checkpointing
+   --attn-implementation flash_attention_3
+   --train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}'
+)
+
+PERF_ARGS=(
+   --use-dynamic-batch-size
+   --max-tokens-per-gpu 9216
+)
+
+MISC_ARGS=(
+   --actor-num-nodes 1
+   --actor-num-gpus-per-node 8
+   --colocate
+   --use-fault-tolerance
+   --dump-details /root/shared_data/qwen3-4B-fsdp-1116-noref/dump_details
+   # --fsdp-cpu-offload
+)
+
+CUSTOM_ARGS=(
+   --custom-config-path examples/train_infer_mismatch_helper/mis.yaml
+   --custom-tis-function-path examples.train_infer_mismatch_helper.mis_fsdp.compute_mis_weights_fsdp
+)
+
+# Launch the master node of ray in the container - 8 GPUs for training.
+export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats
+
+RUNTIME_ENV_JSON="{
+  \"env_vars\": {
+    \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\",
+    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\"
+  }
+}"
+
+ray job submit --address="http://127.0.0.1:8265" \
+   --runtime-env-json="${RUNTIME_ENV_JSON}" \
+   -- python3 train.py \
+   ${CKPT_ARGS[@]} \
+   ${ROLLOUT_ARGS[@]} \
+   ${OPTIMIZER_ARGS[@]} \
+   ${GRPO_ARGS[@]} \
+   ${WANDB_ARGS[@]} \
+   ${SGLANG_ARGS[@]} \
+   ${TRAIN_BACKEND_ARGS[@]} \
+   ${PERF_ARGS[@]} \
+   ${MISC_ARGS[@]} \
+   ${CUSTOM_ARGS[@]}
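[editor's note] --custom-tis-function-path is a dotted import path that the FSDP actor below resolves at runtime via slime.utils.misc.load_function. A stand-in for that resolution step, assuming the conventional importlib split (this is an illustration, not slime's exact implementation):

    import importlib

    def load_function_sketch(path: str):
        module_path, _, attr = path.rpartition(".")  # "pkg.mod.fn" -> ("pkg.mod", "fn")
        return getattr(importlib.import_module(module_path), attr)

    # e.g. load_function_sketch("examples.train_infer_mismatch_helper.mis_fsdp.compute_mis_weights_fsdp")
    json_dumps = load_function_sketch("json.dumps")  # stdlib path keeps the demo self-contained
    assert json_dumps({"ok": True}) == '{"ok": true}'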
diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py
index c39d77f149..71ea553d15 100644
--- a/slime/backends/fsdp_utils/actor.py
+++ b/slime/backends/fsdp_utils/actor.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import random
 from argparse import Namespace
 from itertools import accumulate
 
@@ -18,6 +19,7 @@
 from slime.utils.distributed_utils import get_gloo_group
 from slime.utils.memory_utils import clear_memory, print_memory
 from slime.utils.metric_utils import compute_rollout_step
+from slime.utils.misc import load_function
 from slime.utils.ppo_utils import compute_approx_kl, compute_gspo_kl, compute_opsm_mask, compute_policy_loss
 from slime.utils.processing_utils import load_processor, load_tokenizer
 from slime.utils.ray_utils import Box
@@ -654,26 +656,41 @@ def _has_rollout_log_probs(batch) -> bool:
             else None
         )
 
-        # Apply TIS before sample mean calculation
+        tis_metrics = {}
         if self.args.use_tis:
-            # Apply TIS off-policy correction using importance sampling
             assert (
                 has_rollout_log_probs and rollout_log_probs is not None
-            ), "rollout_log_probs must be provided as non-empty torch.Tensor for TIS"
-            tis_clip, tis, tis_clipfrac = self._compute_tis_weights(
-                old_log_probs=old_log_probs,
-                rollout_log_probs=rollout_log_probs,
-                loss_masks=loss_masks,
-                response_lengths=response_lengths,
-            )
-            ois = (-ppo_kl).exp()
-
-            pg_loss = pg_loss * tis_clip
-
-        assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet implemented"
-        pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
-        pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
-        ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)
+            ), "rollout_log_probs must be provided as non-empty torch.Tensor for TIS/MIS"
+
+            train_log_probs_list = list(log_probs.split(response_lengths, dim=0))
+            rollout_log_probs_list = list(rollout_log_probs.split(response_lengths, dim=0))
+
+            tis_kwargs = {
+                "args": self.args,
+                "pg_loss": pg_loss,
+                "train_log_probs": train_log_probs_list,
+                "rollout_log_probs": rollout_log_probs_list,
+                "loss_masks": loss_masks,
+                "response_lengths": response_lengths,
+                "cp_rank": self.cp_rank,
+                "cp_size": self.cp_size,
+                "cp_group": self.cp_group,
+            }
+
+            if self.args.custom_tis_function_path is not None:
+                tis_func = load_function(self.args.custom_tis_function_path)
+            else:
+                tis_func = vanilla_tis_function_fsdp
+            pg_loss, loss_masks, tis_metrics = tis_func(**tis_kwargs)
+
+        if getattr(self.args, "calculate_per_token_loss", False):
+            pg_loss = sum_of_token(pg_loss, response_lengths, loss_masks)
+            pg_clipfrac = sum_of_token(pg_clipfrac, response_lengths, loss_masks)
+            ppo_kl = sum_of_token(ppo_kl.abs(), response_lengths, loss_masks)
+        else:
+            pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
+            pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
+            ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)
 
         # Only compare rollout vs. train log probs when they originate from different stages.
         train_rollout_logprob_abs_diff = None
@@ -720,10 +737,12 @@ def _has_rollout_log_probs(batch) -> bool:
         if self.args.use_opsm:
             reported["opsm_clipfrac"] = opsm_clipfrac
 
-        if self.args.use_tis and tis is not None:
-            reported["tis"] = sum_of_sample_mean(tis, response_lengths, loss_masks).detach()
-            reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach()
-            reported["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac.float(), response_lengths, loss_masks).detach()
+        if self.args.use_tis and tis_metrics:
+            for k, v in tis_metrics.items():
+                if getattr(self.args, "calculate_per_token_loss", False):
+                    reported[k] = sum_of_token(v, response_lengths, loss_masks).detach()
+                else:
+                    reported[k] = sum_of_sample_mean(v, response_lengths, loss_masks).detach()
 
         # Scale loss for gradient accumulation
         loss = loss * self.dp_size / self.args.global_batch_size
@@ -791,6 +810,15 @@ def update_weights(self) -> None:  # type: ignore[override]
 
         dist.barrier(group=get_gloo_group())
         self.weight_updater.update_weights()
+
+        if self.args.ci_test and len(rollout_engines) > 0:
+            engine = random.choice(rollout_engines)
+            engine_version = ray.get(engine.get_weight_version.remote())
+            if str(engine_version) != str(self.weight_updater.weight_version):
+                raise RuntimeError(
+                    f"Weight version mismatch! Engine: {engine_version}, Updater: {self.weight_updater.weight_version}"
+                )
+
         clear_memory()
 
     def _compute_tis_weights(
@@ -1143,3 +1171,51 @@ def apply_fsdp2(model, mesh=None, cpu_offload=False, args=None):
         fully_shard(model, **fsdp_kwargs)
 
     return model
+
+
+def sum_of_token(x: torch.Tensor, response_lengths: list[int], loss_masks: list[torch.Tensor]) -> torch.Tensor:
+    return sum(
+        [
+            (x_i * loss_mask_i).sum()
+            for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)
+        ]
+    )
+
+
+def vanilla_tis_function_fsdp(
+    args,
+    *,
+    pg_loss: torch.Tensor,
+    train_log_probs: list[torch.Tensor],
+    rollout_log_probs: list[torch.Tensor],
+    loss_masks: list[torch.Tensor],
+    **kwargs,
+) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
+    """Apply TIS off-policy correction using importance sampling.
+
+    Parameters:
+        args: Arguments containing TIS settings.
+        pg_loss: Policy gradient loss tensor of shape [total_seq_len - 1].
+        train_log_probs: List of tensors containing training log-probabilities
+            for each sequence.
+        rollout_log_probs: List of tensors containing rollout log-probabilities
+            for each sequence.
+        loss_masks: List of tensors containing loss masks for each sequence.
+    """
+    rollout_log_probs_flat = torch.cat(rollout_log_probs, dim=0)
+    train_log_probs_flat = torch.cat(train_log_probs, dim=0)
+
+    tis = torch.exp(train_log_probs_flat - rollout_log_probs_flat)
+    tis_abs = (tis - 1).abs()
+
+    tis_clip = torch.clamp(tis, min=getattr(args, "tis_clip_low", 0.1), max=getattr(args, "tis_clip", 2.0))
+    tis_clipfrac = (tis_clip != tis).float()
+
+    metrics = {
+        "tis": tis.clone().detach(),
+        "tis_clipfrac": tis_clipfrac.clone().detach(),
+        "tis_abs": tis_abs.clone().detach(),
+    }
+    pg_loss = pg_loss * tis_clip
+
+    return pg_loss, loss_masks, metrics
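[editor's note] A hand-checked instance of the truncation in vanilla_tis_function_fsdp above, using its getattr defaults tis_clip_low=0.1 and tis_clip=2.0 (the log-prob values are made up):

    import torch

    train = torch.tensor([-1.0, -1.5])
    rollout = torch.tensor([-2.0, -1.0])

    tis = torch.exp(train - rollout)        # ratios: [e^1.0, e^-0.5] ~= [2.718, 0.607]
    tis_clip = tis.clamp(min=0.1, max=2.0)  # -> [2.0, 0.607]; only the first token is truncated
    assert torch.allclose(tis_clip, torch.tensor([2.0, 0.6065]), atol=1e-3)
    assert (tis_clip != tis).float().mean() == 0.5  # tis_clipfrac: half the tokens clipped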
diff --git a/slime/backends/fsdp_utils/update_weight_utils.py b/slime/backends/fsdp_utils/update_weight_utils.py
index e1970ddf66..4c0ce54781 100644
--- a/slime/backends/fsdp_utils/update_weight_utils.py
+++ b/slime/backends/fsdp_utils/update_weight_utils.py
@@ -33,6 +33,7 @@ class UpdateWeight(abc.ABC):
     def __init__(self, args: Namespace, model: torch.nn.Module) -> None:
         self.args = args
         self.model = model
+        self.weight_version = 0
 
     @abc.abstractmethod
     def connect_rollout_engines(
@@ -43,6 +44,7 @@ def connect_rollout_engines(
         pass
 
     def update_weights(self) -> None:
+        self.weight_version += 1
         bucket = []
         bucket_size = 0
         for name, param in self.model.state_dict().items():
@@ -71,10 +73,10 @@ def update_weights(self) -> None:
 
     def wait_and_update_bucket_weights(self, bucket):
         bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket]
-        self.update_bucket_weights(bucket)
+        self.update_bucket_weights(bucket, weight_version=self.weight_version)
 
     @abc.abstractmethod
-    def update_bucket_weights(self, named_tensors) -> None:
+    def update_bucket_weights(self, named_tensors, weight_version=None) -> None:
         pass
 
 
@@ -114,7 +116,7 @@ def connect_rollout_engines(
         # Calculate TP rank within this SGLang engine group
         self.tp_rank = dist.get_rank() - start_rank
 
-    def update_bucket_weights(self, named_tensors) -> None:
+    def update_bucket_weights(self, named_tensors, weight_version=None) -> None:
         monkey_patch_torch_reductions()
         # Use flattened bucket approach similar to Megatron
         logger.info("Using flattened tensor bucket")
@@ -162,6 +164,7 @@ def update_bucket_weights(self, named_tensors) -> None:
             "serialized_named_tensors": [tensors[i] for tensors in gathered_serialized_batches],
             "load_format": "flattened_bucket",
             "flush_cache": False,
+            "weight_version": str(weight_version),
         }
         ref = self._ipc_engine.update_weights_from_tensor.remote(**kwargs)
         ray.get(ref)
@@ -174,10 +177,6 @@ def update_bucket_weights(self, named_tensors) -> None:
 class UpdateWeightFromDistributed(UpdateWeight):
     """Broadcast weights via a temporary NCCL group to rollout engines."""
 
-    def __init__(self, args: Namespace, model: torch.nn.Module) -> None:
-        self.args = args
-        self.model = model
-
     def connect_rollout_engines(
         self,
         rollout_engines: Sequence[ActorHandle],
@@ -220,7 +219,7 @@ def connect_rollout_engines(
         )
         ray.get(refs)
 
-    def update_bucket_weights(self, named_tensors) -> None:
+    def update_bucket_weights(self, named_tensors, weight_version=None) -> None:
         """Send names/dtypes/shapes metadata to engines, then broadcast tensors.
 
         Ensures tensors are contiguous; when `world_size == 1`, converts DTensors
@@ -235,6 +234,7 @@ def update_bucket_weights(self, named_tensors) -> None:
                 dtypes=[param.dtype for _, param in named_tensors],
                 shapes=[param.shape for _, param in named_tensors],
                 group_name=self._group_name,
+                weight_version=str(weight_version),
             )
             for engine in self.rollout_engines
         ]
diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py
index 3ba954ea48..8c4f700c81 100644
--- a/slime/backends/megatron_utils/actor.py
+++ b/slime/backends/megatron_utils/actor.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import random
 import socket
 from argparse import Namespace
 from contextlib import nullcontext
@@ -474,6 +475,14 @@ def update_weights(self) -> None:
         self.weight_updater.update_weights()
         print_memory("after update_weights")
 
+        if self.args.ci_test and len(rollout_engines) > 0:
+            engine = random.choice(rollout_engines)
+            engine_version = ray.get(engine.get_weight_version.remote())
+            if str(engine_version) != str(self.weight_updater.weight_version):
+                raise RuntimeError(
+                    f"Weight version mismatch! Engine: {engine_version}, Updater: {self.weight_updater.weight_version}"
+                )
+
         if getattr(self.args, "keep_old_actor", False):
             if self.args.update_weights_interval == 1:
                 logger.info("updating model queue: rollout_actor -> old_actor, actor -> rollout_actor")
diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py
index d4a4ba6605..db97004817 100644
--- a/slime/rollout/sglang_rollout.py
+++ b/slime/rollout/sglang_rollout.py
@@ -7,6 +7,7 @@ from typing import Any
 
 import numpy as np
+import pybase64
 import sglang_router
 from packaging.version import parse
 from tqdm import tqdm
@@ -101,6 +102,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]):
         state.tokenizer,
         state.processor,
         sample.metadata,
+        args.apply_chat_template,
         args.apply_chat_template_kwargs,
     )
 
@@ -185,8 +187,14 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]):
             sample.weight_versions.append(output["meta_info"]["weight_version"])
 
         if "routed_experts" in output["meta_info"]:
-            assert len(output["meta_info"]["routed_experts"]) == len(sample.tokens) - 1
-            sample.rollout_routed_experts = np.array(output["meta_info"]["routed_experts"])
+            sample.rollout_routed_experts = np.frombuffer(
+                pybase64.b64decode(output["meta_info"]["routed_experts"].encode("ascii")),
+                dtype=np.int32,
+            ).reshape(
+                len(sample.tokens) - 1,
+                args.num_layers,
+                args.moe_router_topk,
+            )
 
         match output["meta_info"]["finish_reason"]["type"]:
             case "length":
diff --git a/slime/utils/data.py b/slime/utils/data.py
index 3b3e6f4b21..45ee103fa4 100644
--- a/slime/utils/data.py
+++ b/slime/utils/data.py
@@ -49,23 +49,32 @@ def _parse_generalized_path(s: str):
     return s, None
 
 
-def _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs):
+def _should_skip_prompt(
+    prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
+):
     if max_length is None:
         return False
 
     from slime.utils.processing_utils import prepare_model_inputs
 
-    input_ids, _ = prepare_model_inputs(prompt, tokenizer, processor, None, apply_chat_template_kwargs)
+    input_ids, _ = prepare_model_inputs(
+        prompt, tokenizer, processor, metadata, apply_chat_template, apply_chat_template_kwargs
+    )
     return len(input_ids) > max_length
 
 
-def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
-    messages = data.get(prompt_key)
+def _build_messages(data: dict, prompt_key: str, as_conversation: bool, multimodal_keys: dict = None):
+    prompt = data.get(prompt_key)
 
-    if isinstance(messages, str):
-        messages = [{"role": "user", "content": messages}]
+    if isinstance(prompt, str):
+        # If prompt is a string and we don't apply chat template, return the prompt as is.
+        if not as_conversation:
+            return prompt
+        else:
+            prompt = [{"role": "user", "content": prompt}]
 
     if multimodal_keys:
+        assert as_conversation, "as_conversation must be True when multimodal_keys is not None"
         # Build mapping: placeholder -> (MultimodalType, content_list)
         multimodals = {}
         for type_name, data_key in multimodal_keys.items():
@@ -75,7 +84,7 @@ def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
 
     pattern = "(" + "|".join(re.escape(p) for p in multimodals.keys()) + ")"
 
-    for message in messages:
+    for message in prompt:
         if isinstance(message["content"], str):
             content_list = []
             for segment in re.split(pattern, message["content"]):
@@ -105,7 +114,7 @@ def _build_messages(data: dict, prompt_key: str, multimodal_keys: dict = None):
                     f"Unsupported content type: {type(message['content'])}, expected str or list of dicts"
                 )
 
-    return messages
+    return prompt
 
 
 class Dataset:
@@ -127,7 +136,8 @@ def __init__(
     ):
         self.origin_samples = []
         for data in read_file(path):
-            prompt = _build_messages(data, prompt_key, multimodal_keys)
+            as_conversation = apply_chat_template
+            prompt = _build_messages(data, prompt_key, as_conversation, multimodal_keys)
 
             metadata = data.get(metadata_key) or {}
             if tool_key is not None and tool_key in data:
@@ -140,7 +150,9 @@ def __init__(
                 metadata["tools"] = tools
 
             # TODO: this is slow.
-            if _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs):
+            if _should_skip_prompt(
+                prompt, tokenizer, processor, metadata, max_length, apply_chat_template, apply_chat_template_kwargs
+            ):
                 continue
 
             self.origin_samples.append(
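[editor's note] The as_conversation flag above decides whether a raw string stays a plain prompt or is wrapped as a single-turn conversation. A condensed, self-contained rendition of that branch (multimodal placeholder handling omitted):

    def build_messages_sketch(data: dict, prompt_key: str, as_conversation: bool):
        prompt = data.get(prompt_key)
        if isinstance(prompt, str):
            if not as_conversation:
                return prompt  # plain-string path: tokenized verbatim downstream
            prompt = [{"role": "user", "content": prompt}]  # wrap as one user turn
        return prompt

    assert build_messages_sketch({"q": "2+2?"}, "q", as_conversation=False) == "2+2?"
    assert build_messages_sketch({"q": "2+2?"}, "q", as_conversation=True) == [
        {"role": "user", "content": "2+2?"}
    ]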
diff --git a/slime/utils/processing_utils.py b/slime/utils/processing_utils.py
index fc837e613c..f5952d5d0c 100644
--- a/slime/utils/processing_utils.py
+++ b/slime/utils/processing_utils.py
@@ -2,6 +2,7 @@
 import io
 import logging
 
+import numpy as np
 from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin
 
 logger = logging.getLogger(__name__)
@@ -25,7 +26,9 @@ def load_processor(name_or_path: str, **kwargs):
     return proc
 
 
-def prepare_model_inputs(prompt, tokenizer, processor=None, metadata=None, apply_chat_template_kwargs=None):
+def prepare_model_inputs(
+    prompt, tokenizer, processor=None, metadata=None, apply_chat_template=False, apply_chat_template_kwargs=None
+):
     """Prepare all inputs for model inference.
 
     Returns:
@@ -34,13 +37,24 @@ def prepare_model_inputs(prompt, tokenizer, processor=None, metadata=None, apply_chat_template_kwargs=None):
         - extra_info: Dict with 'images', 'videos', 'multimodal_inputs' (or empty dict)
     """
     tools = metadata.get("tools") if metadata else None
-    text_prompt = tokenizer.apply_chat_template(
-        prompt,
-        tools=tools,
-        tokenize=False,
-        add_generation_prompt=True,
-        **(apply_chat_template_kwargs or {}),
-    )
+    if isinstance(prompt, (list, np.ndarray)):
+        assert (
+            apply_chat_template
+        ), f"apply_chat_template must be True when prompt is a list or numpy array, current prompt is {prompt}"
+        text_prompt = tokenizer.apply_chat_template(
+            prompt,
+            tools=tools,
+            tokenize=False,
+            add_generation_prompt=True,
+            **(apply_chat_template_kwargs or {}),
+        )
+    elif isinstance(prompt, str):
+        assert (
+            not apply_chat_template
+        ), f"apply_chat_template must be False when prompt is a string, current prompt is {prompt}"
+        text_prompt = prompt
+    else:
+        raise ValueError(f"Invalid prompt type: {type(prompt)}, current prompt is {prompt}")
 
     if not processor:
         input_ids = tokenizer.encode(text_prompt, add_special_tokens=False)
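[editor's note] The new apply_chat_template parameter makes the prompt-type contract explicit: list (or array) prompts must go through the chat template, string prompts must not. A stub-tokenizer demonstration of the accepted combinations (the stub stands in for a HuggingFace tokenizer; only the branching mirrors the function above):

    class StubTokenizer:
        def apply_chat_template(self, messages, tools=None, tokenize=False, add_generation_prompt=True, **kwargs):
            return "".join(f"<{m['role']}>{m['content']}" for m in messages) + "<assistant>"

    def render(prompt, apply_chat_template, tokenizer=StubTokenizer()):
        if isinstance(prompt, list):
            assert apply_chat_template, "list prompts require apply_chat_template=True"
            return tokenizer.apply_chat_template(prompt)
        if isinstance(prompt, str):
            assert not apply_chat_template, "string prompts require apply_chat_template=False"
            return prompt
        raise ValueError(f"Invalid prompt type: {type(prompt)}")

    assert render("raw text", apply_chat_template=False) == "raw text"
    assert render([{"role": "user", "content": "hi"}], apply_chat_template=True) == "<user>hi<assistant>"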
diff --git a/tests/test_fsdp_mis.py b/tests/test_fsdp_mis.py
new file mode 100644
index 0000000000..2d1d384d41
--- /dev/null
+++ b/tests/test_fsdp_mis.py
@@ -0,0 +1,239 @@
+"""
+Usage:
+    PYTHONPATH=/root/Megatron-LM python -m pytest tests/test_fsdp_mis.py -v
+"""
+from argparse import Namespace
+
+import pytest
+import torch
+
+from examples.train_infer_mismatch_helper.mis_fsdp import compute_mis_weights_fsdp
+from slime.backends.fsdp_utils.actor import vanilla_tis_function_fsdp
+
+
+def create_mis_args(**overrides):
+    defaults = {
+        "use_tis": True,
+        "tis_mode": "truncate",
+        "tis_level": "token",
+        "tis_upper_bound": 2.0,
+        "tis_lower_bound": 0.5,
+        "tis_batch_normalize": False,
+        "use_rs": False,
+        "rs_lower_bound": None,
+        "rs_upper_bound": None,
+        "rs_level": "token",
+        "rs_veto_threshold": None,
+    }
+    defaults.update(overrides)
+    return Namespace(**defaults)
+
+
+@pytest.mark.parametrize(
+    "use_tis,tis_clip,tis_clip_low",
+    [
+        (True, 2.0, 0.5),
+        (True, 5.0, 0.1),
+        (True, 1.5, 0.8),
+    ],
+)
+def test_vanilla_tis_clipping(use_tis, tis_clip, tis_clip_low):
+    args = Namespace(
+        use_tis=use_tis,
+        tis_clip=tis_clip,
+        tis_clip_low=tis_clip_low,
+    )
+
+    train_log_probs = [
+        torch.tensor([-1.0, -1.5, -2.0]),
+        torch.tensor([-0.5, -1.0, -1.5, -2.0, -2.5]),
+    ]
+    rollout_log_probs = [
+        torch.tensor([-2.0, -1.0, -1.5]),
+        torch.tensor([-1.0, -0.5, -1.0, -1.5, -2.0]),
+    ]
+    loss_masks = [torch.ones(3), torch.ones(5)]
+    pg_loss = torch.ones(8)
+
+    pg_loss_out, masks_out, metrics = vanilla_tis_function_fsdp(
+        args,
+        pg_loss=pg_loss,
+        train_log_probs=train_log_probs,
+        rollout_log_probs=rollout_log_probs,
+        loss_masks=loss_masks,
+    )
+
+    raw_ratios = torch.exp(torch.cat(train_log_probs) - torch.cat(rollout_log_probs))
+    expected_weights = raw_ratios.clamp(min=tis_clip_low, max=tis_clip)
+
+    assert torch.allclose(pg_loss_out, expected_weights, atol=1e-5)
+    assert torch.allclose(raw_ratios[0], torch.exp(torch.tensor(1.0)), atol=1e-5)
+    assert torch.allclose(
+        expected_weights[0], torch.tensor(min(tis_clip, torch.exp(torch.tensor(1.0)).item())), atol=1e-5
+    )
+
+
+@pytest.mark.parametrize(
+    "tis_mode,tis_upper_bound,tis_lower_bound",
+    [
+        ("mask", 2.0, 0.5),
+        ("truncate", 2.0, 0.5),
+        ("clip", 2.0, 0.5),
+    ],
+)
+def test_mis_modes(tis_mode, tis_upper_bound, tis_lower_bound):
+    args = create_mis_args(
+        tis_mode=tis_mode,
+        tis_upper_bound=tis_upper_bound,
+        tis_lower_bound=tis_lower_bound,
+    )
+
+    train_log_probs = [
+        torch.tensor([-1.0, -1.5, -2.0]),
+        torch.tensor([-0.1, -0.2, -0.3]),
+    ]
+    rollout_log_probs = [
+        torch.tensor([-2.0, -1.0, -1.5]),
+        torch.tensor([-5.0, -4.5, -4.0]),
+    ]
+    loss_masks = [torch.ones(3), torch.ones(3)]
+    pg_loss = torch.ones(6)
+
+    pg_loss_out, masks_out, metrics = compute_mis_weights_fsdp(
+        args,
+        pg_loss=pg_loss,
+        train_log_probs=train_log_probs,
+        rollout_log_probs=rollout_log_probs,
+        loss_masks=loss_masks,
+    )
+
+    raw_ratios = [torch.exp(t - r) for t, r in zip(train_log_probs, rollout_log_probs, strict=False)]
+
+    if tis_mode == "mask":
+        expected_masks = [((r >= tis_lower_bound) & (r <= tis_upper_bound)).float() for r in raw_ratios]
+        for mask_out, expected_mask in zip(masks_out, expected_masks, strict=False):
+            assert torch.equal(mask_out, expected_mask)
+    elif tis_mode == "truncate":
+        expected_weights = torch.cat([r.clamp(0, tis_upper_bound) for r in raw_ratios])
+        assert torch.allclose(pg_loss_out, expected_weights, atol=1e-5)
+    elif tis_mode == "clip":
+        expected_weights = torch.cat([r.clamp(tis_lower_bound, tis_upper_bound) for r in raw_ratios])
+        assert torch.allclose(pg_loss_out, expected_weights, atol=1e-5)
+
+
+@pytest.mark.parametrize("batch_normalize", [True, False])
+def test_mis_batch_normalization(batch_normalize):
+    args = create_mis_args(tis_batch_normalize=batch_normalize)
+
+    train_log_probs = [
+        torch.tensor([-1.0, -1.5, -2.0]),
+        torch.tensor([-0.5, -1.0, -1.5]),
+    ]
+    rollout_log_probs = [
+        torch.tensor([-2.0, -1.0, -1.5]),
+        torch.tensor([-1.0, -0.5, -1.0]),
+    ]
+    loss_masks = [torch.ones(3), torch.ones(3)]
+    pg_loss = torch.ones(6)
+
+    pg_loss_out, masks_out, metrics = compute_mis_weights_fsdp(
+        args,
+        pg_loss=pg_loss,
+        train_log_probs=train_log_probs,
+        rollout_log_probs=rollout_log_probs,
+        loss_masks=loss_masks,
+    )
+
+    raw_weights_flat = torch.exp(torch.cat(train_log_probs) - torch.cat(rollout_log_probs))
+    clipped_weights = raw_weights_flat.clamp(0.5, 2.0)
+
+    if batch_normalize:
+        weights_mean = clipped_weights.mean()
+        expected_weights = clipped_weights / weights_mean
+        assert torch.allclose(pg_loss_out, expected_weights, atol=1e-5)
+        assert torch.allclose(pg_loss_out.mean(), torch.tensor(1.0), atol=1e-5)
+    else:
+        assert torch.allclose(pg_loss_out, clipped_weights, atol=1e-5)
+
+
+def test_mis_rejection_sampling():
+    args = create_mis_args(
+        use_rs=True,
+        rs_lower_bound=0.5,
+        rs_upper_bound=3.0,
+        rs_veto_threshold=0.01,
+        rs_level="token",
+    )
+
+    train_log_probs = [
+        torch.tensor([-1.0, -1.5, -2.0]),
+        torch.tensor([-10.0, -0.5, -1.0]),
+    ]
+    rollout_log_probs = [
+        torch.tensor([-1.5, -1.2, -1.8]),
+        torch.tensor([-5.0, -0.8, -0.9]),
+    ]
+    loss_masks = [torch.ones(3), torch.ones(3)]
+    pg_loss = torch.ones(6)
+
+    pg_loss_out, masks_out, metrics = compute_mis_weights_fsdp(
+        args,
+        pg_loss=pg_loss,
+        train_log_probs=train_log_probs,
+        rollout_log_probs=rollout_log_probs,
+        loss_masks=loss_masks,
+    )
+
+    raw_ratios = [torch.exp(t - r) for t, r in zip(train_log_probs, rollout_log_probs, strict=False)]
+
+    assert torch.allclose(raw_ratios[1][0], torch.exp(torch.tensor(-5.0)), atol=1e-5)
+    assert raw_ratios[1][0] < 0.01
+
+    rs_mask_seq0 = ((raw_ratios[0] >= 0.5) & (raw_ratios[0] <= 3.0)).float()
+    assert torch.equal(masks_out[0], rs_mask_seq0)
+    assert torch.equal(masks_out[1], torch.zeros(3))
+
+
+@pytest.mark.parametrize("tis_level", ["token", "sequence", "geometric"])
+def test_mis_aggregation_levels(tis_level):
+    args = create_mis_args(
+        tis_level=tis_level,
+        tis_mode="truncate",
+    )
+
+    train_log_probs = [
+        torch.tensor([-1.0, -1.5, -2.0]),
+        torch.tensor([-0.5, -1.0, -1.5]),
+    ]
+    rollout_log_probs = [
+        torch.tensor([-2.0, -1.0, -1.5]),
+        torch.tensor([-1.0, -0.5, -1.0]),
+    ]
+    loss_masks = [torch.ones(3), torch.ones(3)]
+    pg_loss = torch.ones(6)
+
+    pg_loss_out, masks_out, metrics = compute_mis_weights_fsdp(
+        args,
+        pg_loss=pg_loss,
+        train_log_probs=train_log_probs,
+        rollout_log_probs=rollout_log_probs,
+        loss_masks=loss_masks,
+    )
+
+    log_diffs = [t - r for t, r in zip(train_log_probs, rollout_log_probs, strict=False)]
+
+    if tis_level == "token":
+        expected_weights = torch.cat([torch.exp(ld).clamp(0, 2.0) for ld in log_diffs])
+        assert torch.allclose(pg_loss_out, expected_weights, atol=1e-5)
+    elif tis_level == "sequence":
+        seq_weights = [torch.exp(ld.sum()).clamp(0, 2.0).expand_as(ld) for ld in log_diffs]
+        expected_weights = torch.cat(seq_weights)
+        assert torch.allclose(pg_loss_out, expected_weights, atol=1e-5)
+        assert torch.allclose(pg_loss_out[0], pg_loss_out[1])
+        assert torch.allclose(pg_loss_out[0], pg_loss_out[2])
+    elif tis_level == "geometric":
+        geo_weights = [torch.exp(ld.mean()).clamp(0, 2.0).expand_as(ld) for ld in log_diffs]
+        expected_weights = torch.cat(geo_weights)
+        assert torch.allclose(pg_loss_out, expected_weights, atol=1e-5)
+        assert torch.allclose(pg_loss_out[3], pg_loss_out[4])
+        assert torch.allclose(pg_loss_out[3], pg_loss_out[5])