diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..c9e09d1b7c --- /dev/null +++ b/.gitmodules @@ -0,0 +1,8 @@ +[submodule "examples/swe-agent/nemo-gym"] + path = examples/swe-agent/nemo-gym + url = https://github.com/yueming-yuan/Gym + branch = slime-swe-agent +[submodule "examples/swe-agent/mini-swe-agent"] + path = examples/swe-agent/mini-swe-agent + url = https://github.com/yueming-yuan/nv-mini-swe-agent + branch = slime-swe-agent diff --git a/examples/formal_math/single_round/run_minimal.py b/examples/formal_math/single_round/run_minimal.py index bb734a219f..0c5f2c0005 100644 --- a/examples/formal_math/single_round/run_minimal.py +++ b/examples/formal_math/single_round/run_minimal.py @@ -96,10 +96,14 @@ ) wandb_args = ( - "--use-wandb " - "--wandb-project slime-formal-math-run-minimal " - "--wandb-group demo " - "--wandb-key ${WANDB_API_KEY} " + ( + "--use-wandb " + "--wandb-project slime-formal-math-run-minimal " + "--wandb-group demo " + f"--wandb-key '{wandb_api_key}' " + ) + if (wandb_api_key := os.environ.get("WANDB_API_KEY")) + else "" ) train_args = ( diff --git a/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn.py b/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn.py index 7e77593ba2..bc47d65a0d 100644 --- a/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn.py +++ b/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn.py @@ -43,10 +43,14 @@ def execute(): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " wandb_args = ( - "--use-wandb " - "--wandb-project slime-dev " - "--wandb-group geo3k_vlm_multi_turn " - "--wandb-key ${WANDB_API_KEY} " + ( + "--use-wandb " + "--wandb-project slime-dev " + "--wandb-group geo3k_vlm_multi_turn " + f"--wandb-key '{wandb_api_key}' " + ) + if (wandb_api_key := os.environ.get("WANDB_API_KEY")) + else "" ) rollout_args = ( diff --git a/examples/swe-agent/README.md b/examples/swe-agent/README.md new file mode 100644 index 0000000000..b655d449ca --- /dev/null +++ 
b/examples/swe-agent/README.md @@ -0,0 +1,130 @@ +### Introduction + +This is an example for SWE-agent training. This example uses NVIDIA's Nemo-Gym as the Gym environment implement, SWE-Gym as the training data, and SWE-bench as the evaluation. + +This implementation of this example is partially in submodules below: +- Nemo-Gym: https://github.com/yueming-yuan/Gym/tree/slime-swe-agent +- mini-swe-agent: https://github.com/yueming-yuan/nv-mini-swe-agent/tree/slime-swe-agent + + +### Prepare environment +#### Update submodules +```bash +git submodule update --init --recursive . +``` +#### Docker settings +```bash +# 1. create a docker network +docker network create swe-net + +# 2. create environment docker +docker run -itd \ + --name swe_env \ + --shm-size 16g \ + -v /var/run/docker.sock:/var/run/docker.sock \ + -v /mnt/data:/data \ + -v /home/sglang-rl/:/workspace \ + --ipc=host \ + --ulimit nofile=65536:65536 \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + --network swe-net \ + ubuntu:latest \ + /bin/bash + +# 3. create slime docker +docker run -itd \ + --shm-size 32g \ + --gpus all \ + -v /mnt/data/cache/huggingface:/root/.cache/huggingface \ + -v /mnt/data:/data \ + -v /home/sglang-rl/:/workspace \ + --ipc=host \ + --ulimit nofile=65536:65536 \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + --privileged \ + --network swe-net \ + --name slime_ \ + slimerl/slime:latest \ + /bin/zsh + +# 4. install utils in environment docker +docker exec -it swe_env /bin/bash +apt update && apt install -y zsh curl git python3 python3-pip docker.io +``` +note: `-v /var/run/docker.sock:/var/run/docker.sock` is required for Docker-in-Docker SWE environment execution; use `--network swe-net` to enable communication between training & environment. 
+
+#### Installation
+
+In the **environment docker**, install Gym:
+```bash
+git clone https://github.com/yueming-yuan/Gym
+cd Gym
+
+curl -LsSf https://astral.sh/uv/install.sh | sh
+source $HOME/.local/bin/env
+uv venv --python 3.12 && source .venv/bin/activate
+uv sync --extra dev --group docs
+
+# configure env.yaml
+echo "policy_base_url: https://api.openai.com/v1
+policy_api_key: your-openai-api-key
+policy_model_name: gpt-4.1-2025-04-14
+default_host: 0.0.0.0" > env.yaml
+```
+note: set the host IP to `0.0.0.0` to enable communication between dockers.
+
+then set up the SWE-agent server:
+```bash
+cd responses_api_agents/mini_swe_agent
+uv pip install -r requirements.txt
+```
+Now you should be able to run the SWE-agent server.
+
+For the **slime docker** setup, please follow the standard setup process.
+
+### Preparing data
+In the **slime docker**, download the **SWE-Gym** data from Hugging Face and convert it to Slime's prompt data format with this script:
+```
+cd slime/examples/swe-agent
+python download_and_process_data.py --input SWE-Gym/SWE-Gym --output /root/swe_train.jsonl
+```
+
+### Running training
+1. In the environment docker, launch the agent server:
+```bash
+cd Gym
+source .venv/bin/activate
+cd responses_api_agents/mini_swe_agent
+./start_server.sh
+```
+
+
+2. In the slime docker,
+(1) export `SWE_AGENT_GYM_URL` to point at the second server started in Gym in the environment docker (the one whose `server_type` is `responses_api_agents`). `swe_env` is the environment docker's name; replace it if you changed the name.
+(minor TODO: modify the port selection to avoid setting this every time.) (2) launch the training:
+```bash
+export SWE_AGENT_GYM_URL="http://swe_env:<port>"
+bash examples/swe-agent/run-qwen3-4b-instruct.sh
+```
+
+
+### Troubleshooting
+1. The first run of each SWE environment can be slow, and generation may have to wait, because each SWE-Gym task uses its own docker image and `docker pull` takes time.
+2.
Sometimes the environment may also be slow at evaluation. The timeout of evaluation is 10 minutes by default. If the server is stuck at `[EVAL] Running eval`, you may need to wait for it. + +## Metrics +``` +agent/turns_mean, agent/turns_sum - Turn counts +agent/tool_calls_mean, agent/tool_calls_sum - Tool call counts +agent/total_time_mean/max/min - Total time statistics +agent/model_query_time_sum_mean - Avg total model time per rollout +agent/env_execution_time_sum_mean - Avg total env time per rollout +agent/eval_time_mean - Avg evaluation time +agent/overhead_time_mean - Avg overhead time +agent/time_per_turn - Avg time per turn +agent/model_query_time_avg - Avg model query time per turn +agent/env_execution_time_avg - Avg env execution time per turn +agent/model_time_ratio, agent/env_time_ratio - Time ratios +``` diff --git a/examples/swe-agent/download_and_process_data.py b/examples/swe-agent/download_and_process_data.py new file mode 100755 index 0000000000..ae90e0ad22 --- /dev/null +++ b/examples/swe-agent/download_and_process_data.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +"""Download and process data to Slime format.""" + +import argparse +import json +import tempfile +from pathlib import Path +from datasets import load_dataset + + +def convert_to_slime_format(input_path: str, output_path: str, limit: int = None, split: str = "train"): + """Convert JSONL to Slime format. 
+ + Args: + input_path: Path to input JSONL file + output_path: Path to output JSONL file in Slime format + limit: Optional limit on number of samples + split: Dataset split name (used in metadata) + """ + count = 0 + with open(input_path) as fin, open(output_path, "w") as fout: + for line in fin: + if limit and count >= limit: + break + + instance = json.loads(line) + + # Add subset and split to metadata for Gym API + metadata = dict(instance) + metadata["subset"] = "gym" + metadata["split"] = split + + slime_sample = { + "prompt": instance.get("problem_statement", ""), + "metadata": metadata, + } + + fout.write(json.dumps(slime_sample) + "\n") + count += 1 + + print(f"Converted {count} samples: {input_path} -> {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Download HuggingFace dataset and convert to Slime format") + parser.add_argument("--input", type=str, required=True, help="HuggingFace dataset path or local JSONL file") + parser.add_argument("--output", type=str, required=True, help="Output JSONL file path") + parser.add_argument( + "--split", type=str, default="train", help="Dataset split (default: train, only for HF datasets)" + ) + parser.add_argument("--limit", type=int, help="Limit number of samples") + + args = parser.parse_args() + + input_path = Path(args.input) + + if input_path.exists() and input_path.suffix == ".jsonl": + print(f"Processing local file: {args.input}") + convert_to_slime_format(args.input, args.output, args.limit, args.split) + else: + print(f"Loading HuggingFace dataset: {args.input} (split={args.split})") + ds = load_dataset(args.input, split=args.split) + + if args.limit: + ds = ds.select(range(min(args.limit, len(ds)))) + + tmp_path = None + try: + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as tmp: + tmp_path = tmp.name + + print(f"Downloading to temporary file: {tmp_path}") + ds.to_json(tmp_path) + + print(f"Converting to Slime format: {args.output}") + 
convert_to_slime_format(tmp_path, args.output, split=args.split) + finally: + if tmp_path and Path(tmp_path).exists(): + Path(tmp_path).unlink() + + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/swe-agent/generate_with_swe_agent.py b/examples/swe-agent/generate_with_swe_agent.py new file mode 100644 index 0000000000..170d2f8464 --- /dev/null +++ b/examples/swe-agent/generate_with_swe_agent.py @@ -0,0 +1,242 @@ +import logging +import os +from argparse import Namespace +from collections.abc import Callable +from typing import Any + +from slime.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput +from slime.rollout.filter_hub.base_types import DynamicFilterOutput +from slime.rollout.sglang_rollout import GenerateState, eval_rollout +from slime.utils.async_utils import run +from slime.utils.http_utils import post +from slime.utils.types import Sample + +logger = logging.getLogger(__name__) + + +def build_tokens_and_mask_from_messages( + messages: list[dict], + tokenizer, +) -> tuple[list[int], list[int], str, int]: + + if not messages or len(messages) < 2: + return [], [], "", 0 + + prompt_msgs = messages[:2] + response_msgs = messages[2:] + + prompt_tokens = [] + for msg in prompt_msgs: + content = msg.get("content", "") + if content: + prompt_tokens.extend(tokenizer(content, add_special_tokens=False)["input_ids"]) + + response_tokens = [] + loss_mask = [] + response_text_parts = [] + + for msg in response_msgs: + content = msg.get("content", "") + if not content: + continue + + tokens = tokenizer(content, add_special_tokens=False)["input_ids"] + token_len = len(tokens) + + response_tokens.extend(tokens) + response_text_parts.append(content) + + mask_val = 1 if msg.get("role") == "assistant" else 0 + loss_mask.extend([mask_val] * token_len) + + all_tokens = prompt_tokens + response_tokens + response_text = "".join(response_text_parts) + response_length = len(response_tokens) + + return all_tokens, loss_mask, 
response_text, response_length + + +async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, Any]) -> Sample: + """ + Custom generation function for SWE-Agent integration. + + Orchestrates the interaction with the external Gym environment: + 1. Sends prompt/metadata to Gym. + 2. Receives execution trace (messages) and rewards. + 3. Formats data for Slime training format. + + Note: Performs in-place modification of `sample` for memory efficiency. + """ + # Prepare request for Gym /run endpoint + request = { + "responses_create_params": { + "input": [], + }, + "sampling_params": sampling_params, + **sample.metadata, + "sglang_url": f"http://{args.sglang_router_ip}:{args.sglang_router_port}/v1", + } + + gym_url = os.getenv("SWE_AGENT_GYM_URL", "http://localhost:11000") + response = await post(f"{gym_url}/run", request) + + exit_status = response.get("info", {}).get("exit_status", "") + logger.debug(f"exit_status: {exit_status}, reward: {response.get('reward', 0.0)}") + + messages = response.get("messages", []) + + if len(messages) >= 2: + sample.prompt = messages[:2] + + state = GenerateState(args) + tokens, loss_mask, response_text, response_length = build_tokens_and_mask_from_messages( + messages=messages, + tokenizer=state.tokenizer, + ) + + sample.rollout_log_probs = None # TODO + sample.tokens = tokens + sample.loss_mask = loss_mask + sample.response = response_text + sample.response_length = response_length + sample.metadata["reward"] = response.get("reward", 0.0) + sample.metadata["eval_report"] = response.get("metadata", {}) + sample.metadata["messages"] = messages + + agent_metrics = response.get("info", {}).get("agent_metrics", {}) + sample.metadata["agent_metrics"] = agent_metrics + + if exit_status == "Submitted": + sample.status = Sample.Status.COMPLETED + elif exit_status in ("RolloutTruncated", "LimitsExceeded", "CollapseContinued"): + sample.status = Sample.Status.TRUNCATED + else: + sample.status = Sample.Status.ABORTED + 
sample.reward = 0.0 + + return sample + + +async def reward_func(args, sample: Sample, **kwargs) -> float: + """Reward function - already computed in generate()""" + reward = sample.metadata.get("reward", 0.0) + return reward + + +def dynamic_filter(args, samples: list[Sample], **kwargs) -> DynamicFilterOutput: + """Filter out groups with any aborted samples from training""" + has_aborted = any(sample.status == Sample.Status.ABORTED for sample in samples) + if has_aborted: + return DynamicFilterOutput(keep=False, reason="group_has_aborted") + return DynamicFilterOutput(keep=True) + + +def aggregate_agent_metrics(samples: list[Sample]) -> dict: + """Aggregate agent metrics across samples for logging""" + metrics = {} + + all_metrics = [] + for sample in samples: + if hasattr(sample, "metadata") and sample.metadata: + agent_metrics = sample.metadata.get("agent_metrics", {}) + if agent_metrics: + all_metrics.append(agent_metrics) + + if not all_metrics: + return {} + + # Count metrics - mean and sum + for key in ["turns", "tool_calls"]: + values = [m.get(key, 0) for m in all_metrics] + if values: + metrics[f"agent/{key}_mean"] = sum(values) / len(values) + metrics[f"agent/{key}_sum"] = sum(values) + + # Time sum metrics - mean across rollouts + for key in ["model_query_time_sum", "env_execution_time_sum", "eval_time", "agent_run_time"]: + values = [m.get(key, 0) for m in all_metrics] + if values: + metrics[f"agent/{key}_mean"] = sum(values) / len(values) + + # Time avg metrics - mean of means + for key in ["time_per_turn", "model_query_time_avg", "env_execution_time_avg"]: + values = [m.get(key, 0) for m in all_metrics] + if values: + metrics[f"agent/{key}"] = sum(values) / len(values) + + # Ratio metrics (all based on total_time which includes eval) + for key in ["model_time_ratio", "env_time_ratio", "eval_time_ratio"]: + values = [m.get(key, 0) for m in all_metrics] + if values: + metrics[f"agent/{key}"] = sum(values) / len(values) + + # Total time stats + values = 
[m.get("total_time", 0) for m in all_metrics] + if values: + metrics["agent/total_time_mean"] = sum(values) / len(values) + metrics["agent/total_time_max"] = max(values) + metrics["agent/total_time_min"] = min(values) + + return metrics + + +async def generate_rollout_async( + args: Namespace, rollout_id: int, data_source: Callable[[int], list[list[Sample]]] +) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]: + """ + Custom rollout function that wraps sglang_rollout.generate_rollout_async + and adds agent metrics aggregation. + """ + from slime.rollout.sglang_rollout import generate_rollout_async as base_generate_rollout_async + + rollout_output, aborted_samples = await base_generate_rollout_async(args, rollout_id, data_source) + + all_samples = [] + for group in rollout_output.samples: + if isinstance(group[0], list): + for sample_list in group: + all_samples.extend(sample_list) + else: + all_samples.extend(group) + + agent_metrics = aggregate_agent_metrics(all_samples) + + metrics = rollout_output.metrics or {} + metrics.update(agent_metrics) + + logger.info(f"Aggregated agent metrics for rollout {rollout_id}: {agent_metrics}") + + return RolloutFnTrainOutput(samples=rollout_output.samples, metrics=metrics), aborted_samples + + +def generate_rollout( + args: Namespace, rollout_id: int, data_buffer: Any, evaluation: bool = False +) -> RolloutFnTrainOutput | RolloutFnEvalOutput: + """An example to implement the generate_rollout function for an rule based rm rollout generation. 
+ + Args: + args: the whole args + rollout_id: int, the id of the rollout, used for deterministic data generation + data_buffer: the data buffer to store the generated samples + evaluation: bool, whether the rollout is for evaluation or not + + Returns: + list[list[Sample]]: a list of list of samples generated by the rollout + """ + output, aborted_samples = generate_abortable_samples( + args, rollout_id, data_buffer.get_samples, evaluation=evaluation + ) + data_buffer.add_samples(aborted_samples) + return output + + +def generate_abortable_samples( + args: Namespace, + rollout_id: int, + data_source: Callable[[int], list[list[Sample]]], + evaluation: bool = False, +) -> tuple[Any, list[list[Sample]]]: + assert args.rollout_global_dataset + if evaluation: + return run(eval_rollout(args, rollout_id)) + return run(generate_rollout_async(args, rollout_id, data_source)) diff --git a/examples/swe-agent/run-qwen3-4b-instruct.sh b/examples/swe-agent/run-qwen3-4b-instruct.sh new file mode 100755 index 0000000000..5160e75658 --- /dev/null +++ b/examples/swe-agent/run-qwen3-4b-instruct.sh @@ -0,0 +1,166 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=1 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" + +export SWE_AGENT_GYM_URL="${SWE_AGENT_GYM_URL:-http://swe_env:11000}" + +source "${SCRIPT_DIR}/../../scripts/models/qwen3-4B-Instruct-2507.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/qwen3-4B-Instruct-2507 + --ref-load /root/qwen3-4B-Instruct-2507_torch_dist + # --load /path/to/checkpoint/ + --save 
/root/qwen3-4B-Instruct-2507_slime/ + --save-interval 100 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 2048 +) + +ROLLOUT_ARGS=( + --prompt-data /root/swe_train.jsonl + --input-key prompt + --metadata-key metadata + --rollout-shuffle + --num-rollout 3000 + --rollout-batch-size 8 + --n-samples-per-prompt 8 + --rollout-temperature 0.8 + --rollout-max-response-len 8192 + + --global-batch-size 64 + --balance-data +) + +EVAL_ARGS=( + # --eval-interval 50 + # --eval-prompt-data /workspace/data/swe_gym_val.jsonl + # --eval-input-key prompt + # --eval-metadata-key metadata + # --n-samples-per-eval-prompt 1 + # --eval-max-response-len 4096 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.01 + --kl-loss-type low_var_kl + --entropy-coef 0.0 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=() +if [ -n "$WANDB_KEY" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project slime-swe-agent + --wandb-group swe-agent-qwen2.5-3b + --wandb-key ${WANDB_KEY} + ) +fi + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + +CUSTOM_ARGS=( + --custom-generate-function-path generate_with_swe_agent.generate + --custom-rm-path generate_with_swe_agent.reward_func + --rollout-function-path 
generate_with_swe_agent.generate_rollout + --dynamic-sampling-filter-path generate_with_swe_agent.dynamic_filter +) + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +echo "Starting Ray cluster at ${MASTER_ADDR}..." +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 --port=8899 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}:/root/slime\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"SWE_AGENT_GYM_URL\": \"${SWE_AGENT_GYM_URL}\" + } +}" +# \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + +echo "Launching training..." +echo " SWE Agent URL: ${SWE_AGENT_GYM_URL}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 4 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${CUSTOM_ARGS[@]} + +echo "Training completed!" 
diff --git a/scripts/run-qwen3-4B-fsdp.sh b/scripts/run-qwen3-4B-fsdp.sh index 495d2196d3..71e48e21a7 100644 --- a/scripts/run-qwen3-4B-fsdp.sh +++ b/scripts/run-qwen3-4B-fsdp.sh @@ -75,12 +75,16 @@ OPTIMIZER_ARGS=( --adam-beta2 0.98 ) -WANDB_ARGS=( - --use-wandb - --wandb-project slime-dev-mcore-fsdp - --wandb-group qwen3-4B-fsdp-1130-ref - --wandb-key ${WANDB_API_KEY} -) +if [ -z "${WANDB_API_KEY}" ]; then + WANDB_ARGS=() +else + WANDB_ARGS=( + --use-wandb + --wandb-project slime-dev-mcore-fsdp + --wandb-group qwen3-4B-fsdp-1130-ref + --wandb-key "${WANDB_API_KEY}" + ) +fi SGLANG_ARGS=( --rollout-num-gpus-per-engine 1 @@ -128,15 +132,15 @@ RUNTIME_ENV_JSON="{ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${TRAIN_BACKEND_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${MISC_ARGS[@]} + "${CKPT_ARGS[@]}" \ + "${ROLLOUT_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${GRPO_ARGS[@]}" \ + "${WANDB_ARGS[@]}" \ + "${SGLANG_ARGS[@]}" \ + "${TRAIN_BACKEND_ARGS[@]}" \ + "${PERF_ARGS[@]}" \ + "${MISC_ARGS[@]}" diff --git a/slime/backends/sglang_utils/arguments.py b/slime/backends/sglang_utils/arguments.py index 0b6870b70b..c734b1bf63 100644 --- a/slime/backends/sglang_utils/arguments.py +++ b/slime/backends/sglang_utils/arguments.py @@ -41,7 +41,6 @@ def add_sglang_arguments(parser): skipped_args = [ "model_path", - "dtype", "trust_remote_code", "random_seed", # memory diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index 2bd6e75c66..263a8ee307 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -497,4 +497,6 @@ def _compute_server_args( "nccl_port", "dist_init_addr", "skip_server_warmup", + "enable_draft_weights_cpu_backup", + "mem_fraction_static", ] diff --git 
a/slime/ray/rollout.py b/slime/ray/rollout.py index 862ca67e53..17715b9210 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -423,10 +423,11 @@ def init_rollout_engines(args, pg, all_rollout_engines): def _allocate_rollout_engine_addr_and_ports_external(args, rollout_engines): addr_and_ports = [] for rank, _ in rollout_engines: - [host, port] = args.rollout_external_engine_addrs[rank].split(":") + addr = args.rollout_external_engine_addrs[rank] + [host, port] = addr.split(":") addr_and_ports.append( dict( - dist_init_addr=None, + dist_init_addr=addr, nccl_port=None, host=host, port=int(port), diff --git a/slime/router/router.py b/slime/router/router.py index a0866ad169..87c61c2db1 100644 --- a/slime/router/router.py +++ b/slime/router/router.py @@ -1,5 +1,7 @@ import argparse +import asyncio import json +import logging import httpx import uvicorn @@ -9,6 +11,8 @@ from slime.utils.misc import load_function +logger = logging.getLogger(__name__) + def run_router(args): """ @@ -28,9 +32,14 @@ def __init__(self, args, verbose=False): self.verbose = verbose self.app = FastAPI() - - # Worker information - self.worker_urls: dict[str, int] = {} + self.app.add_event_handler("startup", self._start_background_health_check) + + # URL -> Active Request Count (load state) + self.worker_request_counts: dict[str, int] = {} + # URL -> Consecutive Failures + self.worker_failure_counts: dict[str, int] = {} + # Quarantined workers excluded from routing pool + self.dead_workers: set[str] = set() self.max_weight_version = None max_connections = getattr(args, "slime_router_max_connections", None) @@ -63,9 +72,61 @@ def _setup_routes(self): # Catch-all route for proxying to SGLang - must be registered LAST self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy) - async def health_check(self, request: Request): - # TODO: do health check in background - pass + async def _start_background_health_check(self): + 
asyncio.create_task(self._health_check_loop()) + + async def _check_worker_health(self, url): + """Encapsulated health check logic for better maintainability.""" + try: + response = await self.client.get(f"{url}/health", timeout=5.0) + if response.status_code == 200: + return url, True + logger.debug(f"[slime-router] Worker {url} is unhealthy (Status: {response.status_code})") + except Exception as e: + logger.debug(f"[slime-router] Worker {url} health check failed: {e}") + return url, False + + async def _health_check_loop(self): + """Background loop to monitor worker health and adjust routing pool.""" + interval = self.args.rollout_health_check_interval + threshold = self.args.slime_router_health_check_failure_threshold + + while True: + try: + await asyncio.sleep(interval) + + urls = [u for u in self.worker_request_counts if u not in self.dead_workers] + if not urls: + continue + + results = await asyncio.gather(*(self._check_worker_health(url) for url in urls)) + + for url, is_healthy in results: + if not is_healthy: + failures = self.worker_failure_counts.get(url, 0) + 1 + self.worker_failure_counts[url] = failures + + if failures >= threshold: + logger.warning( + f"[slime-router] Worker {url} failed {threshold} consecutive health checks. Marking as DEAD." + ) + self.dead_workers.add(url) + # TODO (chenyang): Connect back 'dead' workers requires a mechanism to sync + # model versions to avoid off-policy issues from stale weights, since these + # dead workers' parameters may not be refitted. + else: + self.worker_failure_counts[url] = 0 + + logger.debug( + f"[slime-router] Health check complete. {len(self.worker_request_counts) - len(self.dead_workers)} workers healthy." 
+ ) + + except asyncio.CancelledError: + logger.warning("[slime-router] Background health check loop is being cancelled.") + raise + except Exception as e: + logger.error(f"[slime-router] Unexpected error in health check loop: {e}", exc_info=True) + await asyncio.sleep(5) async def proxy(self, request: Request, path: str): """Proxy all other requests to the SGLang router""" @@ -124,16 +185,17 @@ async def add_worker(self, request: Request): ) # Add if new, keep a simple request count per worker - if worker_url not in self.worker_urls: - self.worker_urls[worker_url] = 0 + if worker_url not in self.worker_request_counts: + self.worker_request_counts[worker_url] = 0 + self.worker_failure_counts[worker_url] = 0 if self.verbose: print(f"[slime-router] Added new worker: {worker_url}") - return {"status": "success", "worker_urls": self.worker_urls} + return {"status": "success", "worker_urls": self.worker_request_counts} async def list_workers(self, request: Request): """List all registered workers""" - return {"urls": list(self.worker_urls.keys())} + return {"urls": list(self.worker_request_counts.keys())} async def retrieve_from_text(self, request: Request): """Get token information from text input""" @@ -158,19 +220,27 @@ async def retrieve_from_text(self, request: Request): return result def _use_url(self): - """Select a worker URL using round-robin strategy""" - assert len(self.worker_urls) > 0, "No workers available" + """Select worker URL with minimal active requests.""" + + if not self.dead_workers: + # Healthy path: select from all workers + url = min(self.worker_request_counts, key=self.worker_request_counts.get) + else: + # Degraded path: select from workers not in dead_workers + valid_workers = (w for w in self.worker_request_counts if w not in self.dead_workers) + try: + url = min(valid_workers, key=self.worker_request_counts.get) + except ValueError: + raise RuntimeError("No healthy workers available in the pool") from None - # get the url with mininal count 
- url = min(self.worker_urls, key=self.worker_urls.get) - self.worker_urls[url] += 1 + self.worker_request_counts[url] += 1 return url def _finish_url(self, url): """Mark the request to the given URL as finished""" - assert url in self.worker_urls, f"URL {url} not recognized" - self.worker_urls[url] -= 1 - assert self.worker_urls[url] >= 0, f"URL {url} count went negative" + assert url in self.worker_request_counts, f"URL {url} not recognized" + self.worker_request_counts[url] -= 1 + assert self.worker_request_counts[url] >= 0, f"URL {url} count went negative" if __name__ == "__main__": diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index d3ebf1d454..37cfd1337d 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -919,6 +919,12 @@ def add_router_arguments(parser): default=None, help="Max connections for SlimeRouter HTTP client.", ) + parser.add_argument( + "--slime-router-health-check-failure-threshold", + type=int, + default=3, + help="Number of consecutive failures before marking a worker as unhealthy.", + ) RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True) return parser diff --git a/slime/utils/debug_utils/display_debug_rollout_data.py b/slime/utils/debug_utils/display_debug_rollout_data.py index 608bcd6db2..a954e679a4 100644 --- a/slime/utils/debug_utils/display_debug_rollout_data.py +++ b/slime/utils/debug_utils/display_debug_rollout_data.py @@ -6,7 +6,7 @@ import torch import typer -from slime.ray.rollout import compute_metrics_from_samples +from slime.ray.rollout import compute_perf_metrics_from_samples from slime.utils.types import Sample _WHITELIST_KEYS = [ @@ -47,7 +47,7 @@ def main( log_reward_category=None, ) sample_objects = [Sample.from_dict(s) for s in sample_dicts] - metrics = compute_metrics_from_samples(args, sample_objects) + metrics = compute_perf_metrics_from_samples(args, sample_objects) print("metrics", metrics) if show_samples: diff --git 
a/slime/utils/external_utils/command_utils.py b/slime/utils/external_utils/command_utils.py index c93e4aa4c1..72c3dcb559 100644 --- a/slime/utils/external_utils/command_utils.py +++ b/slime/utils/external_utils/command_utils.py @@ -213,12 +213,13 @@ def get_default_wandb_args(test_file: str, run_name_prefix: str | None = None, r if (x := run_name_prefix) is not None: wandb_run_name = f"{x}_{wandb_run_name}" - # do not put wandb_api_key value here to avoid leaking to logs explicitly + # Use the actual key value from environment to avoid shell expansion issues + wandb_key = os.environ.get("WANDB_API_KEY") return ( "--use-wandb " f"--wandb-project slime-{test_name} " f"--wandb-group {wandb_run_name} " - f"--wandb-key ${{WANDB_API_KEY}} " + f"--wandb-key '{wandb_key}' " "--disable-wandb-random-suffix " ) diff --git a/tests/test_moonlight_16B_A3B.py b/tests/test_moonlight_16B_A3B.py index aa0d399ce8..1bcfa1731e 100644 --- a/tests/test_moonlight_16B_A3B.py +++ b/tests/test_moonlight_16B_A3B.py @@ -118,8 +118,6 @@ def execute(): if __name__ == "__main__": prepare() - os.environ.pop("http_proxy") - os.environ.pop("https_proxy") - os.environ.pop("HTTP_PROXY") - os.environ.pop("HTTPS_PROXY") + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_quick_start_glm4_9B.py b/tests/test_quick_start_glm4_9B.py index 2a928c6e3f..bd63a9d878 100644 --- a/tests/test_quick_start_glm4_9B.py +++ b/tests/test_quick_start_glm4_9B.py @@ -121,8 +121,6 @@ def execute(): if __name__ == "__main__": # TODO also use typer prepare() - os.environ.pop("http_proxy") - os.environ.pop("https_proxy") - os.environ.pop("HTTP_PROXY") - os.environ.pop("HTTPS_PROXY") + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index 4ca8e8553b..cc6912d219 100644 --- 
a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -125,8 +125,6 @@ def execute(): if __name__ == "__main__": prepare() - os.environ.pop("http_proxy") - os.environ.pop("https_proxy") - os.environ.pop("HTTP_PROXY") - os.environ.pop("HTTPS_PROXY") + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index 8c86a1410c..12625cfe25 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -125,8 +125,6 @@ def execute(): if __name__ == "__main__": prepare() - os.environ.pop("http_proxy") - os.environ.pop("https_proxy") - os.environ.pop("HTTP_PROXY") - os.environ.pop("HTTPS_PROXY") + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 193f2c7166..da1310da31 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -98,8 +98,6 @@ def execute(): if __name__ == "__main__": prepare() - os.environ.pop("http_proxy") - os.environ.pop("https_proxy") - os.environ.pop("HTTP_PROXY") - os.environ.pop("HTTPS_PROXY") + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index f3eaa437fb..c858c2f86a 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -100,8 +100,6 @@ def execute(): if __name__ == "__main__": prepare() - os.environ.pop("http_proxy") - os.environ.pop("https_proxy") - os.environ.pop("HTTP_PROXY") - os.environ.pop("HTTPS_PROXY") + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", 
"HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen3_0.6B_parallel_check.py b/tests/test_qwen3_0.6B_parallel_check.py index f513afcd56..ab08ad5968 100644 --- a/tests/test_qwen3_0.6B_parallel_check.py +++ b/tests/test_qwen3_0.6B_parallel_check.py @@ -131,8 +131,6 @@ def execute(): if __name__ == "__main__": # TODO also use typer prepare() - os.environ.pop("http_proxy") - os.environ.pop("https_proxy") - os.environ.pop("HTTP_PROXY") - os.environ.pop("HTTPS_PROXY") + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py index 309d4e486d..4b3ac796b0 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/test_qwen3_30B_A3B.py @@ -145,8 +145,6 @@ def execute(): if __name__ == "__main__": # TODO also use typer prepare() - os.environ.pop("http_proxy") - os.environ.pop("https_proxy") - os.environ.pop("HTTP_PROXY") - os.environ.pop("HTTPS_PROXY") + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen3_4B_ckpt.py b/tests/test_qwen3_4B_ckpt.py index 1de8759282..b745bf58a3 100644 --- a/tests/test_qwen3_4B_ckpt.py +++ b/tests/test_qwen3_4B_ckpt.py @@ -131,9 +131,7 @@ def execute(mode: str = ""): args = parser.parse_args() # TODO also use typer prepare() - os.environ.pop("http_proxy") - os.environ.pop("https_proxy") - os.environ.pop("HTTP_PROXY") - os.environ.pop("HTTPS_PROXY") + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute("save" if not args.async_save else "async_save") execute("load") diff --git a/tests/test_qwen3_4B_fsdp_true_on_policy.py b/tests/test_qwen3_4B_fsdp_true_on_policy.py index 1dc102c4e1..b3c9c7007d 100644 --- a/tests/test_qwen3_4B_fsdp_true_on_policy.py +++ b/tests/test_qwen3_4B_fsdp_true_on_policy.py @@ -107,8 +107,6 @@ def 
execute(): if __name__ == "__main__": prepare() - os.environ.pop("http_proxy") - os.environ.pop("https_proxy") - os.environ.pop("HTTP_PROXY") - os.environ.pop("HTTPS_PROXY") + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute() diff --git a/tests/test_qwen3_4B_ppo.py b/tests/test_qwen3_4B_ppo.py index fecdd1b147..702bb7cd20 100644 --- a/tests/test_qwen3_4B_ppo.py +++ b/tests/test_qwen3_4B_ppo.py @@ -128,8 +128,6 @@ def execute(): if __name__ == "__main__": # TODO also use typer prepare() - os.environ.pop("http_proxy") - os.environ.pop("https_proxy") - os.environ.pop("HTTP_PROXY") - os.environ.pop("HTTPS_PROXY") + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) execute()