diff --git a/examples/train/search/run_search.sh b/examples/train/search/run_search.sh index 678e1f93c0..554f94f7e9 100755 --- a/examples/train/search/run_search.sh +++ b/examples/train/search/run_search.sh @@ -1,10 +1,32 @@ set -x -# Colocated GRPO training+generation for Qwen2.5-Coder-3B-Instruct on SearchR1 data. -# follow the instructions in examples/search/README.md for setting up the dataset -# and for starting the local search server -# export WANDB_API_KEY= -# bash examples/train/search/run_search.sh +# Colocated GRPO training+generation for Qwen2.5-3B-Instruct on SearchR1 data. +# Follow the instructions in docs/content/docs/recipes/searchr1.mdx for setup. +# +# Usage: +# export WANDB_API_KEY= +# bash examples/train/search/run_search.sh +# +# Configurable knobs (override via env vars or command-line args): +# USE_CONVERSATION_MULTI_TURN - set to "true" to use conversation multi-turn format (default: false) +# When true, also enables append_eos_token_after_stop_str_in_multi_turn=true so that +# each turn's response ends with the model's EOS token (required for correct behavior +# when stop strings like or terminate generation instead of EOS). +# STEP_WISE - set to "true" to enable step-wise training (default: false) +# Requires USE_CONVERSATION_MULTI_TURN=true. +# +# Examples: +# # Default (non-conversation, non-step-wise): +# bash examples/train/search/run_search.sh +# +# # Conversation multi-turn format: +# USE_CONVERSATION_MULTI_TURN=true bash examples/train/search/run_search.sh +# +# # Step-wise with conversation multi-turn: +# USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh +# +# # Override any config via positional args (passed to Hydra): +# bash examples/train/search/run_search.sh trainer.epochs=2 trainer.eval_interval=10 # path for dataset (.parquet files) containing the prompts and metadata for each question DATA_DIR="$HOME/data/searchR1" @@ -14,6 +36,28 @@ RUN_NAME="skyrl-search_4turns_maxgeneratelen_500-multiturn-sync-TIS_2.0" TIS_TYPE=token TIS_IMP_RATIO_CAP=2.0 +# Configurable knobs with defaults +: "${USE_CONVERSATION_MULTI_TURN:=false}" +: "${STEP_WISE:=false}" + +# Build conditional args +MULTI_TURN_ARGS="" +if [ "$USE_CONVERSATION_MULTI_TURN" = "true" ]; then + MULTI_TURN_ARGS="generator.use_conversation_multi_turn=true generator.append_eos_token_after_stop_str_in_multi_turn=true" +else + MULTI_TURN_ARGS="generator.use_conversation_multi_turn=false" +fi + +STEP_WISE_ARGS="" +if [ "$STEP_WISE" = "true" ]; then + STEP_WISE_ARGS="generator.step_wise_trajectories=true" + # Step-wise requires conversation multi-turn + if [ "$USE_CONVERSATION_MULTI_TURN" != "true" ]; then + echo "WARNING: STEP_WISE=true requires USE_CONVERSATION_MULTI_TURN=true. Enabling it automatically." + MULTI_TURN_ARGS="generator.use_conversation_multi_turn=true generator.append_eos_token_after_stop_str_in_multi_turn=true" + fi +fi + uv run --isolated --frozen --extra fsdp -m skyrl.train.entrypoints.main_base \ data.train_data="['${DATA_DIR}/train.parquet']" \ data.val_data="['${DATA_DIR}/validation.parquet']" \ @@ -49,7 +93,8 @@ uv run --isolated --frozen --extra fsdp -m skyrl.train.entrypoints.main_base \ generator.sampling_params.max_generate_length=500 \ generator.inference_engine.async_engine=true \ generator.batched=false \ - generator.use_conversation_multi_turn=false \ + $MULTI_TURN_ARGS \ + $STEP_WISE_ARGS \ generator.n_samples_per_prompt=5 \ generator.max_turns=4 \ generator.sampling_params.temperature=1.0 \ diff --git a/examples/train/search/run_search_conversation_format.sh b/examples/train/search/run_search_conversation_format.sh deleted file mode 100755 index 0d55dd733d..0000000000 --- a/examples/train/search/run_search_conversation_format.sh +++ /dev/null @@ -1,86 +0,0 @@ -set -x - -# The exact same script as `run_search.sh` but with `use_conversation_multi_turn=true` -# and hence `append_eos_token_after_stop_str_in_multi_turn=true` -# See https://docs.skyrl.ai/docs/tutorials/skyrl_gym_generator on the -# difference between the two options. You might want to change the data generation prompt -# to let the model know that we are doing multi-turn conversations (i.e. user will provide -# the search result for each turn). - -# Colocated GRPO training+generation for Qwen2.5-Coder-3B-Instruct on SearchR1 data. -# follow the instructions in examples/train/search/README.md for setting up the dataset -# and for starting the local search server -# export WANDB_API_KEY= -# bash examples/train/search/run_search_conversation_format.sh - -# path for dataset (.parquet files) containing the prompts and metadata for each question -DATA_DIR="$HOME/data/searchR1" - -RUN_NAME="skyrl-search_4turns_maxgeneratelen_500" - -TIS_TYPE=token -TIS_IMP_RATIO_CAP=2.0 - -uv run --isolated --frozen --extra fsdp -m skyrl.train.entrypoints.main_base \ - data.train_data="['${DATA_DIR}/train.parquet']" \ - data.val_data="['${DATA_DIR}/validation.parquet']" \ - trainer.algorithm.advantage_estimator="grpo" \ - trainer.policy.optimizer_config.lr=1.0e-6 \ - trainer.policy.optimizer_config.max_grad_norm=0.5 \ - trainer.policy.optimizer_config.num_warmup_steps=94 \ - trainer.algorithm.use_kl_loss=true \ - trainer.algorithm.kl_loss_coef=0.001 \ - trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ - trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ - trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \ - trainer.placement.colocate_all=true \ - trainer.strategy=fsdp2 \ - trainer.policy.fsdp_config.cpu_offload=false \ - trainer.ref.fsdp_config.cpu_offload=true \ - trainer.placement.policy_num_gpus_per_node=8 \ - trainer.placement.ref_num_gpus_per_node=8 \ - generator.inference_engine.num_engines=4 \ - generator.inference_engine.tensor_parallel_size=2 \ - generator.inference_engine.backend=vllm \ - generator.inference_engine.run_engines_locally=true \ - generator.inference_engine.weight_sync_backend=nccl \ - generator.inference_engine.gpu_memory_utilization=0.5 \ - trainer.epochs=1 \ - trainer.update_epochs_per_batch=1 \ - trainer.train_batch_size=512 \ - trainer.policy_mini_batch_size=256 \ - trainer.micro_forward_batch_size_per_gpu=4 \ - trainer.micro_train_batch_size_per_gpu=4 \ - trainer.max_prompt_length=2048 \ - generator.max_input_length=4096 \ - generator.sampling_params.max_generate_length=500 \ - generator.inference_engine.async_engine=true \ - generator.batched=false \ - generator.use_conversation_multi_turn=true \ - generator.n_samples_per_prompt=5 \ - generator.max_turns=4 \ - generator.sampling_params.temperature=1.0 \ - generator.sampling_params.top_p=1.0 \ - generator.sampling_params.stop='["", ""]' \ - generator.append_eos_token_after_stop_str_in_multi_turn=true \ - environment.env_class="search" \ - environment.skyrl_gym.max_env_workers=16 \ - environment.skyrl_gym.search.log_requests=false \ - environment.skyrl_gym.search.search_url="http://127.0.0.1:8000/retrieve" \ - environment.skyrl_gym.search.topk=3 \ - trainer.logger="wandb" \ - trainer.project_name="skyrl-search" \ - trainer.run_name="${RUN_NAME}" \ - trainer.ckpt_interval=20 \ - trainer.hf_save_interval=100 \ - trainer.max_ckpts_to_keep=5 \ - trainer.resume_mode=latest \ - trainer.ckpt_path="$HOME/${RUN_NAME}" \ - trainer.eval_batch_size=256 \ - trainer.eval_before_train=false \ - generator.eval_sampling_params.temperature=0 \ - generator.eval_sampling_params.stop='["", ""]' \ - trainer.export_path="$HOME/${RUN_NAME}/exports" \ - trainer.eval_interval=50 \ - $@ - \ No newline at end of file diff --git a/examples/train_integrations/harbor/HANDOFF.md b/examples/train_integrations/harbor/HANDOFF.md new file mode 100644 index 0000000000..6c89ae530d --- /dev/null +++ b/examples/train_integrations/harbor/HANDOFF.md @@ -0,0 +1,185 @@ +# Stepwise Training Handoff — Harbor + SkyRL + +## Current State (2026-03-12 ~23:00 UTC) + +- **Training is STOPPED** (killed manually) +- **Monitoring cron is CANCELLED** +- **Latest checkpoint**: `global_step_30` at `/home/ray/codecontest-stepwise/ckpts/global_step_30` +- **Checkpoints available**: global_step_24, global_step_27, global_step_30 +- **Total effective training steps**: ~32 (rewards went from 0.32 to peak 0.58) +- **W&B run**: `codecontest-stepwise` in project `harbor` at `sky-posttraining-uc-berkeley` + +## How to Resume Training + +```bash +cd /home/ray/default/SkyRL +# 1. Kill any lingering sandboxes +uv run --isolated --extra fsdp --extra harbor examples/train_integrations/harbor/kill_daytona_sandboxes.py + +# 2. Clean up stale Ray placement groups +python3 -c " +import ray +from ray._raylet import PlacementGroupID +from ray.util.placement_group import PlacementGroup +ray.init(address='auto') +for pg_id, pg_info in ray.util.placement_group_table().items(): + if pg_info.get('state') == 'CREATED': + try: + ray.util.remove_placement_group(PlacementGroup(PlacementGroupID.from_hex(pg_id))) + except: pass +" + +# 3. Free port 8000 if occupied +ss -tlnp | grep 8000 && fuser -k 8000/tcp + +# 4. Launch training (resumes from latest checkpoint automatically) +nohup bash examples/train_integrations/harbor/run_codecontest_stepwise.sh > /tmp/skyrl-logs/codecontest-stepwise-launch.log 2>&1 & +``` + +## How to Monitor + +Set up a 15-minute cron using Claude Code's `/loop` command: +``` +/loop 15m Check the stepwise training job status: 1) Check if the process is still running (ps aux | grep main_harbor). 2) Check the last 30 lines of /tmp/skyrl-logs/codecontest-stepwise-launch.log for errors or progress. 3) Look for training step progress ("Training Batches Processed", "step", "ckpt", "generate"). 4) If the process has crashed: a) Check the error in the log. b) If it's a transient error, run `cd /home/ray/default/SkyRL && uv run --isolated --extra fsdp --extra harbor examples/train_integrations/harbor/kill_daytona_sandboxes.py` to clean up sandboxes. c) Clean up stale Ray placement groups. d) Free port 8000 if occupied. e) Relaunch with: `cd /home/ray/default/SkyRL && nohup bash examples/train_integrations/harbor/run_codecontest_stepwise.sh > /tmp/skyrl-logs/codecontest-stepwise-launch.log 2>&1 &`. 5) Report status summary. +``` + +## Key Files Modified (from main branch) + +### SkyRL repo (`/home/ray/default/SkyRL`, branch `harbor-step-wise`) + +1. **`pyproject.toml`** — Harbor dependency changed from git commit to local path: + ``` + harbor = { path = "/home/ray/default/harbor" } + ``` + +2. **`examples/train_integrations/harbor/run_codecontest_stepwise.sh`** — NEW file, launch script with: + - 8 GPUs, step-wise enabled, dual_clip policy loss + - `ckpt_interval=2`, `max_ckpts_to_keep=3`, `resume_mode=latest` + - Export dir: `/mnt/local_storage/codecontest-stepwise/exports` + - `eval_interval=900`, `eval_before_train=false` (skip eval) + - `max_concurrency=500` rate limiting + - `PYTORCH_ALLOC_CONF=expandable_segments:True` + - `TORCH_NCCL_AVOID_RECORD_STREAMS=1` + +3. **`examples/train_integrations/harbor/harbor_generator.py`** — Bug fixes: + - `_worker()`: Added `except BaseException` to catch `CancelledError` and return zeroed output + - Added safety fill loop after `TaskGroup` to handle `None` entries in `all_outputs` + - Added `_make_error_output()` helper + +4. **`skyrl/train/trainer.py`** — Added `gc.collect()` before all `empty_cache()` calls (lines ~959, 967, 1130) + +5. **`skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py`** — Added `gc.collect()` before both `empty_cache()` calls in `broadcast_to_inference_engines()` + +6. **`examples/train_integrations/harbor/OOM_ANALYSIS.md`** — Comprehensive crash analysis doc +7. **`examples/train_integrations/harbor/HANDOFF.md`** — This file +8. **`test_vllm_sleep_wake.py`** — Minimal vLLM sleep/wake repro script + +### vLLM patch (applied to uv cache) + +**File**: `/home/ray/.cache/uv/archive-v0/qNofxNgHzcv1I1Qbx76wG/vllm/device_allocator/cumem.py` + +Applied Python-side changes from [vLLM PR #36535](https://github.com/vllm-project/vllm/pull/36535): +- `AllocationData.mapped: bool = True` field +- `CuMemAllocator.sleeping = False` flag +- `_python_free_callback`: returns dummy handle `(0, 0, 0, 0)` during sleep +- `sleep()`: tracks `data.mapped = False`, sets `self.sleeping = True` +- `wake_up()`: checks `data.mapped`, sets `self.sleeping = False` + +**NOTE**: This patch is in the uv archive and all build dirs. It helps prevent Python-level double-free but doesn't fix the C++ side. The full fix requires recompiling vLLM with PR #36535's C++ changes or upgrading to vLLM 0.17+ when the PR is merged. + +## The Crash Bug + +### Root Cause +**vLLM issue [#36651](https://github.com/vllm-project/vllm/issues/36651)**: cumem allocator double-free and stale error codes during sleep/wake cycles. The crash manifests as `cudaErrorInvalidValue` at `flash_attn.py:484` (`self.scheduler_metadata[:n] = scheduler_metadata`). + +### Key Evidence +- **NOT OOM**: Zero OOM/SIGKILL messages in any infra log. Memory is stable at ~21 GiB per GPU across sleep cycles. +- **Probabilistic**: Crashes after 2-11 sleep/wake cycles (mode: 3-5) +- **Standalone repro**: Pure sleep/wake cycles (no weight updates) do NOT reproduce — 15 cycles passed. The crash requires the full training loop with NCCL weight broadcast + colocated FSDP. +- **Not step-wise specific**: Non-step-wise training was never run long enough to compare, but the sleep/wake cycle is identical in both modes. + +### Fix Status +- [PR #36535](https://github.com/vllm-project/vllm/pull/36535) is OPEN, not merged (filed 2026-03-10) +- Python-side patch applied (insufficient alone) +- C++ side requires recompiling vLLM or upgrading to 0.17+ + +### What Was Tested + +| Fix | Result | +|-----|--------| +| `PYTORCH_ALLOC_CONF=expandable_segments:True` | No effect (crash is not OOM) | +| `gc.collect()` before `empty_cache()` | Slight improvement (5 vs 3-4 steps) | +| `enforce_eager=true` | **Worse** — Xid 31 MMU Fault | +| `gpu_memory_utilization=0.7` | **Worse** — crashed after 1 step | +| cumem Python patch (PR #36535) | Insufficient — C++ bugs still trigger | +| `TORCH_NCCL_AVOID_RECORD_STREAMS=1` | Untested in isolation | +| `ckpt_interval=2` | **Key mitigation** — checkpoints before crash window | + +### Current Mitigation Strategy +With `ckpt_interval=2`, the training checkpoints every 2 steps. Since crashes happen at 2-5 cycles from resume, we usually get 1 checkpoint per run. The monitoring cron auto-restarts, losing ~1 step per crash. Net efficiency: ~70-80%. + +## Reward Curve + +| Global Step | Avg Reward | Notes | +|------------|-----------|-------| +| 1 | 0.324 | Baseline | +| 5 | 0.258 | | +| 10 | 0.262 | | +| 13 | 0.398 | | +| 17 | 0.445 | | +| 22 | 0.430 | | +| 25 | 0.449 | | +| 28 | 0.508-0.582 | Peak range | +| 31 | 0.348-0.383 | | + +## Saved Logs + +Key infra/experiment logs preserved at `examples/train_integrations/harbor/logs/`: + +| File | Run | Description | +|------|-----|-------------| +| `infra-260311_122019.log` | Run 4 | First `cudaErrorInvalidValue` crash (resumed from step 5) | +| `infra-260311_163209.log` | Run 6 | Longest pre-mitigation run (11 sleep/wake cycles, steps 5→15) | +| `infra-260312_092526.log` | Run 13 | `enforce_eager=true` test — Xid 31 MMU Fault | +| `infra-260312_100154.log` | Run 14 | Best run with `gc.collect()` fix (5 steps, step 24→29) | +| `infra-260312_204717.log` | Run 20 | Most recent run | +| `launch-last.log` | Run 20 | Main stdout/stderr from last training launch | +| `vllm_repro_test.log` | — | Standalone vLLM sleep/wake repro (15 cycles, all passed) | + +To grep for crash patterns: `grep "cudaErrorInvalidValue\|scheduler_metadata\|still in use" logs/infra-*.log` + +## Data Setup + +- **Dataset**: CodeContests from HuggingFace (`open-thoughts/CodeContests`) +- **Location**: `/home/ray/data/harbor/CodeContests` (9644 tasks) +- **Prepared by**: `examples/train_integrations/harbor/prepare_harbor_dataset.py` + +## Config Summary + +```yaml +model: Qwen/Qwen3-8B +num_gpus: 8 +colocate_all: true +strategy: fsdp2 +train_batch_size: 32 +n_samples_per_prompt: 8 +max_model_len: 32768 +ckpt_interval: 2 +max_ckpts_to_keep: 3 +resume_mode: latest +step_wise_trajectories: true +policy_loss_type: dual_clip +loss_reduction: seq_mean_token_sum_norm +environment: daytona +agent: terminus-2 +max_turns: 32 +max_concurrency: 500 +``` + +## Next Steps + +1. **Upgrade vLLM to 0.17+** when PR #36535 is merged — this is the real fix +2. **Or** rebuild vLLM 0.16.0 from source with PR #36535's C++ changes applied +3. **Or** continue with `ckpt_interval=2` + auto-restart as mitigation +4. **Consider**: Running a non-step-wise comparison to confirm the crash is not step-wise specific +5. **Consider**: Filing a vLLM issue with our specific repro (SkyRL + FSDP colocated + NCCL weight sync) diff --git a/examples/train_integrations/harbor/OOM_ANALYSIS.md b/examples/train_integrations/harbor/OOM_ANALYSIS.md new file mode 100644 index 0000000000..5d89eecca7 --- /dev/null +++ b/examples/train_integrations/harbor/OOM_ANALYSIS.md @@ -0,0 +1,390 @@ +# Crash Analysis: Step-wise Training with Harbor + vLLM Colocated Inference + +## Observed Behavior + +When running step-wise training on CodeContests with Qwen3-8B (8x H100, colocate_all=true, 8 vLLM engines TP=1), the job consistently crashes after ~3-4 steps from a fresh process start. The crash manifests as `EngineDeadError` from vLLM during weight sync / generation wake-up. + +## CORRECTED Root Cause: `cudaErrorInvalidValue` in FlashAttention (NOT OOM) + +After analyzing all 12 infra logs, the root cause is **NOT out-of-memory**. Every crash shows the same stack trace: + +``` +File "vllm/v1/attention/backends/flash_attn.py", line 484, in build + self.scheduler_metadata[:n] = scheduler_metadata +torch.AcceleratorError: CUDA error: invalid argument (cudaErrorInvalidValue) +``` + +**There are ZERO OOM/memory-cgroup/SIGKILL messages in any infra log.** The "Memory cgroup out of memory" messages in the main launch log are from unrelated system processes (e.g., `vector` logging daemon), not from the training workers. + +### Evidence + +```bash +# Across ALL 12 infra logs: +grep -l "out of memory|OOM|memory cgroup|SIGKILL" infra-*.log → NONE +grep -l "cudaErrorInvalidValue" infra-*.log → 7 out of 12 (all crash runs) +``` + +### What Happens + +1. vLLM's FlashAttention backend has a pre-allocated `scheduler_metadata` tensor +2. After multiple sleep/wake cycles, the tensor becomes invalid (wrong size, stale pointer, or corrupted state) +3. `self.scheduler_metadata[:n] = scheduler_metadata` fails with `cudaErrorInvalidValue` +4. The EngineCore process dies, which triggers `EngineDeadError` in the main process +5. Any in-flight generation requests get `CancelledError` + +### Memory Is NOT Leaking + +The vLLM sleep logs prove memory is stable across cycles: + +| Sleep Cycle | Memory Still In Use (avg) | Notes | +|-------------|--------------------------|-------| +| 1 (initial) | 17.1 GiB | Before any training | +| 2 | 19.25 GiB | After first FSDP optimizer load | +| 3 | ~20.5-21.3 GiB | Stabilizes | +| 4+ | ~20.8-21.7 GiB | **Flat — no leak** | + +## Run-by-Run Log + +### Run 1 — No mitigations +- **Infra log**: `infra-260311_093550.log` +- **Env**: No `expandable_segments`, no `gc.collect()` fix +- **Resumed from**: Fresh start +- **Crashed at**: Step ~5 (EngineDeadError, port 8000 was occupied) +- **Checkpoint saved**: None + +### Run 2 — Port 8000 freed +- **Infra log**: `infra-260311_094551.log` +- **Env**: Same as Run 1 +- **Resumed from**: Fresh start +- **Crashed at**: Step ~3 (`NoneType` error — `_worker` didn't catch exceptions, `all_outputs` had None entries) +- **Checkpoint saved**: None +- **Fix applied**: Added `except Exception` in `_worker` (v1) + +### Run 3 — Worker exception handling v1 +- **Infra log**: `infra-260311_110304.log` +- **Env**: `except Exception` in `_worker` (didn't catch `CancelledError`) +- **Resumed from**: Fresh start (old Ray package cache still used) +- **Crashed at**: Step ~3 (same `NoneType` — fix wasn't in Ray package) +- **Checkpoint saved**: None +- **Fix applied**: Cleared Ray package cache, changed to `except BaseException` + safety fill for None entries (v2) + +### Run 4 — Worker exception handling v2 +- **Infra log**: `infra-260311_122019.log` +- **Env**: `except BaseException` + None safety fill +- **Resumed from**: Step 5 checkpoint +- **Crashed at**: Step ~8 (`cudaErrorInvalidValue` in flash_attn.py:484) +- **Checkpoint saved**: `global_step_5` (already existed) + +### Run 5 — Same config, resumed +- **Infra log**: `infra-260311_144709.log` +- **Resumed from**: Step 5 +- **Crashed at**: Step ~8 (same `cudaErrorInvalidValue` + NCCL watchdog SIGABRT) +- **Checkpoint saved**: `global_step_10` + +### Run 6 — Longest pre-mitigation run +- **Infra log**: `infra-260311_163209.log` +- **Resumed from**: Step 5 +- **Steps completed**: 10 (steps 6-15) +- **Crashed at**: Step ~15→16 transition (`cudaErrorInvalidValue`) +- **Checkpoint saved**: `global_step_10`, `global_step_15` +- **Notable**: CancelledError batch at step 8 (all 205 workers cancelled — Daytona outage?), but training continued + +### Run 7 — Resumed from step 15 +- **Infra log**: `infra-260311_210212.log` +- **Resumed from**: Step 15 +- **Crashed at**: Step ~19→20 (`cudaErrorInvalidValue`) +- **Checkpoint saved**: None new (step 20 not reached) + +### Run 8 — Same, second attempt +- **Infra log**: `infra-260311_230207.log` +- **Resumed from**: Step 15 +- **Crashed at**: Step ~19→20 (`cudaErrorInvalidValue`) +- **Checkpoint saved**: None new + +### Run 9 — Same, third attempt +- **Infra log**: `infra-260312_011715.log` +- **Resumed from**: Step 15 +- **Crashed at**: Step ~19→20 (`cudaErrorInvalidValue`) +- **Checkpoint saved**: None new +- **Fix applied**: `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` (wrong env var name) + `ckpt_interval=3` + +### Run 10 — Wrong env var name +- **Infra log**: `infra-260312_034712.log` +- **Env**: `PYTORCH_CUDA_ALLOC_CONF` (deprecated in PyTorch 2.9+, showed warning) +- **Resumed from**: Step 15 +- **Crashed at**: Step ~19→20 (`cudaErrorInvalidValue`) +- **Checkpoint saved**: `global_step_18` (first progress past step 15 thanks to ckpt_interval=3!) +- **Fix applied**: Changed to `PYTORCH_ALLOC_CONF` (correct name) + +### Run 11 — Correct env var name +- **Infra log**: `infra-260312_054724.log` +- **Env**: `PYTORCH_ALLOC_CONF=expandable_segments:True` (correct) +- **Resumed from**: Step 18 +- **Crashed at**: Step ~21→22 (`cudaErrorInvalidValue`) +- **Checkpoint saved**: `global_step_21` +- **Fix applied**: `gc.collect()` before all `empty_cache()` calls + `TORCH_NCCL_AVOID_RECORD_STREAMS=1` + +### Run 12 — Current run (gc.collect + NCCL fix pending) +- **Infra log**: `infra-260312_071710.log` +- **Env**: `PYTORCH_ALLOC_CONF=expandable_segments:True` (gc.collect fix in local code but NOT in Ray package yet) +- **Resumed from**: Step 21 +- **Steps completed**: 4 (steps 22-25) +- **Crashed at**: Step ~25→26 (`cudaErrorInvalidValue` in flash_attn.py:484) +- **Checkpoint saved**: `global_step_24` + +### Run 13 — enforce_eager=true test (FAILED — made it worse) +- **Infra log**: `infra-260312_092414.log` (approx) +- **Env**: `enforce_eager=true` + `gc.collect()` + `TORCH_NCCL_AVOID_RECORD_STREAMS=1` + `expandable_segments` +- **Resumed from**: Step 24 +- **Steps completed**: 1 (step 25 only, rewards=0.031 — mostly zeroed) +- **Crashed at**: Step 25 with **NVIDIA Xid 31 MMU Fault** — GPU hardware page fault +- **Error**: `FAULT_PDE ACCESS_TYPE_VIRT_WRITE` across ALL 8 GPUs simultaneously +- **Checkpoint saved**: None (crashed at step 25, next ckpt was step 27) +- **Conclusion**: `enforce_eager=true` made the crash WORSE — exposed a raw GPU page fault instead of the `cudaErrorInvalidValue`. The root cause is dangling GPU pointers after vLLM sleep, not CUDA graph state. Reverted. + +### Run 14 — gc.collect + NCCL fix (enforce_eager reverted) — BEST RUN +- **Env**: `gc.collect()` + `TORCH_NCCL_AVOID_RECORD_STREAMS=1` + `expandable_segments` + `enforce_eager=false` +- **Resumed from**: Step 24 +- **Steps completed**: 5 (steps 25-29) — **longest since Run 6** +- **Crashed at**: Step ~29→30 (`cudaErrorInvalidValue` in flash_attn.py:484) +- **Checkpoint saved**: `global_step_27` +- **Rewards**: 0.449, 0.383, 0.0 (zeroed), **0.508** (all-time high!), 0.387 +- **Conclusion**: `gc.collect()` helped extend from 3-4 steps to 5 steps per run + +### Run 15 — Same config, continuing +- **Env**: Same as Run 14 (gc.collect + NCCL + expandable_segments) +- **Resumed from**: Step 27 +- **Steps completed**: 2 (steps 28-29) +- **Crashed at**: Step ~29→30 (`cudaErrorInvalidValue`) +- **Checkpoint saved**: None new (step 30 not reached) +- **Rewards**: **0.555** (all-time high!), 0.371 + +### Run 16 — Same config, continuing +- **Env**: Same as Run 14-15 +- **Resumed from**: Step 27 +- **Steps completed**: 4 (steps 28-31) +- **Crashed at**: Step ~31→32 (`cudaErrorInvalidValue`) +- **Checkpoint saved**: `global_step_30` +- **Rewards**: **0.582** (ATH!), 0.387, 0.344, 0.348 + +### Run 17 — Testing gpu_memory_utilization=0.7 (FAILED — worse) +- **Env**: gc.collect + NCCL + expandable_segments + **gpu_memory_utilization=0.7** +- **Resumed from**: Step 30 +- **Steps completed**: 1 (step 31 only) +- **Crashed at**: Step ~31→32 (`cudaErrorInvalidValue`) +- **Checkpoint saved**: None +- **Conclusion**: Lower gpu_memory_utilization made it WORSE (1 step vs 4-5). Reverted to 0.8. + +### Run 18 — Back to best config (gpu_util=0.8) + cumem patch pending +- **Env**: gc.collect + NCCL + expandable_segments + gpu_memory_utilization=0.8 +- **Note**: Python-side cumem patch (PR #36535) applied to vLLM archive mid-run. Won't take effect until next restart. +- **Resumed from**: Step 30 +- **Status**: Running + +## Applied Fix: vLLM cumem Allocator Python Patch (PR #36535) + +Applied the Python-side changes from [PR #36535](https://github.com/vllm-project/vllm/pull/36535) to `vllm/device_allocator/cumem.py`: + +1. **`AllocationData.mapped: bool = True`** — tracks whether each allocation is currently GPU-mapped +2. **`CuMemAllocator.sleeping = False`** flag — set during sleep +3. **`_python_free_callback`**: returns dummy handle `(0, 0, 0, 0)` when sleeping, preventing C++ from trying to unmap already-freed memory (the double-free bug) +4. **`sleep()`**: skips already-unmapped allocations, sets `data.mapped = False`, sets `self.sleeping = True` +5. **`wake_up()`**: skips already-mapped allocations, sets `self.sleeping = False`, sets `data.mapped = True` + +The C++ side (`csrc/cumem_allocator.cpp`) changes from the PR cannot be applied without recompiling. However, the Python-side changes should prevent the most common crash path (GC-triggered double-free during sleep). + +**Result: Cumem Python patch alone insufficient.** Run 18 crashed after 2 steps (same `cudaErrorInvalidValue`). The C++ side bugs are still triggering. + +### Run 18 — cumem Python patch test (INSUFFICIENT) +- **Steps completed**: 2 (steps 31-32) +- **Crashed at**: Step ~33 (`cudaErrorInvalidValue`) +- **Conclusion**: Python patch prevents Python-level double-free but C++ `cumem_allocator.cpp` still has stale error codes + +### Run 19 — Same config +- **Steps completed**: 2 (steps 31-32) +- **Crashed at**: Step ~33 + +### Run 20 — Same config, ckpt_interval changed to 2 +- **Steps completed**: ~2 (steps 31-32) +- **Crashed at**: Step ~33 + +## Minimal Reproduction Attempt + +**Script**: `test_vllm_sleep_wake.py` — standalone vLLM sleep/wake test (no SkyRL) +- 15 cycles of sleep → wake_up(weights) → wake_up(kv_cache) → generate → sleep +- Single GPU, Qwen3-8B, max_model_len=8192 +- **Result: ALL 15 CYCLES PASSED** — basic sleep/wake does NOT reproduce the crash + +**Conclusion**: The crash requires something present in the full SkyRL training loop but absent in standalone vLLM: +- Actual weight updates via NCCL broadcast during wake +- 8 colocated engines sharing GPU memory with FSDP training +- High memory pressure (256 concurrent trajectories, 32K context) +- Or interaction between PyTorch's allocator and vLLM's cumem allocator + +## Crash Variability + +Sleep/wake cycles before crash (per-engine, across all runs): +- Min: 2 cycles +- Max: 11 cycles +- Mode: 3-5 cycles +- Not deterministic — crash is probabilistic, likely depends on memory layout + +## Step-wise vs Non-step-wise + +The trainer sleep/wake cycle is **identical** for both modes. Each training step does exactly: +1. `wake_up(tags=["weights"])` → weight sync → `wake_up(tags=["kv_cache"])` → generate → `sleep()` + +The only difference is inside the generation phase (Harbor multi-turn agent loop). Step-wise collects per-turn rollout_details but doesn't change the sleep/wake pattern. + +**Non-step-wise was never run on this setup**, so we can't confirm whether it also crashes. The crash is likely NOT step-wise-specific but a general vLLM cumem allocator bug that triggers under high memory pressure with colocated training. + +## Key Insight: Crash Timing Is Steps-Since-Resume, Not Global Step + +| Run | Resumed From | Steps Until Crash | Crash At Global Step | +|-----|-------------|-------------------|---------------------| +| 4 | 5 | 3 | ~8 | +| 5 | 5 | 3-5 | ~8-10 | +| 6 | 5 | 10 | ~15 | +| 7 | 15 | 4-5 | ~19-20 | +| 8 | 15 | 4-5 | ~19-20 | +| 9 | 15 | 4-5 | ~19-20 | +| 10 | 15 | 4-5 | ~19-20 | +| 11 | 18 | 3-4 | ~21-22 | +| 12 | 21 | 4 | ~25 | +| 13 | 24 | 1 | ~25 (enforce_eager — Xid 31, WORSE) | +| 14 | 24 | **5** | ~29 (gc.collect + NCCL fix — best) | +| 15 | 27 | 2 | ~29 (same config, same crash point) | +| 16 | 27 | 4 | ~31 | +| 17 | 30 | **1** | ~31 (gpu_util=0.7 — WORSE) | + +The crash happens after ~3-5 sleep/wake cycles from a fresh process, regardless of global step. This strongly suggests **state corruption in vLLM's FlashAttention metadata that accumulates over sleep/wake transitions**. + +## The Actual Bug: Known vLLM cumem Allocator Bug (#36651) + +The crash is at `flash_attn.py:484`: +```python +self.scheduler_metadata[:n] = scheduler_metadata +``` + +This is a **known vLLM bug** documented in multiple open issues: + +### [vLLM Issue #36651](https://github.com/vllm-project/vllm/issues/36651) — cumem allocator: double-free and stale error codes during sleep/wake cycles (filed 2026-03-10, OPEN) + +Documents **five bugs** in the cumem allocator: +1. **Double `cuMemRelease`** on already-unmapped allocations during `sleep()` +2. **CUDA ops on freed memory** when PyTorch's GC triggers `my_free` during sleep +3. **Stale global `error_code`** that persists across operations, causing wrong code paths +4. **Size mismatch** in `my_free` passing wrong size to `unmap_and_release` +5. **Flash Attention 4 import failure** due to module restructuring + +**Fix**: [PR #36535](https://github.com/vllm-project/vllm/pull/36535) — OPEN, not yet merged. Tracks per-allocation mapped state, adds a `sleeping` flag, clears stale error codes. + +### [vLLM Issue #31016](https://github.com/vllm-project/vllm/issues/31016) — FlashInfer metadata not restored after wake (OPEN) + +Same class of bug: attention backend metadata tensors are stateful and get invalidated during sleep. FlashInfer's `block_table_arange` and FlashAttention's `scheduler_metadata` both suffer from this — tensors tagged with `"kv_cache"` get discarded during sleep but their handles aren't invalidated. + +### [vLLM Issue #35463](https://github.com/vllm-project/vllm/issues/35463) — Sleep mode broken on vLLM 0.16.0+ (OPEN) + +Reports the exact `CUDA Error: invalid argument` on basic sleep/wake cycles. Same root cause. + +### [vLLM Issue #36753](https://github.com/vllm-project/vllm/issues/36753) — POST /wake_up causes crash (OPEN, filed 2026-03-11) + +vLLM process crashes entirely on wake_up after sleep. Under active investigation. + +## What Has NOT Helped (Fixing the Crash) + +| Mitigation | Effect on Crash | +|-----------|----------------| +| `PYTORCH_ALLOC_CONF=expandable_segments:True` | No change (crash is not OOM) | +| `PYTORCH_CUDA_ALLOC_CONF` (wrong name) | No effect (deprecated, ignored) | +| `enforce_eager=true` | **Made it WORSE** — exposed Xid 31 MMU Fault instead of cudaErrorInvalidValue | +| `gpu_memory_utilization=0.7` | **Made it WORSE** — crashed after 1 step vs 4-5 | +| `gc.collect()` before `empty_cache()` | Slight improvement (~5 steps vs ~3-4). Good hygiene. | +| `TORCH_NCCL_AVOID_RECORD_STREAMS=1` | Untested in isolation | + +## What HAS Helped (Recovery/Resilience) + +| Mitigation | Effect | +|-----------|--------| +| `ckpt_interval=3` | **Key** — checkpoints before crash window, lose ≤1 step per restart | +| `except BaseException` in `_worker` | CancelledError batches produce zeroed output instead of crash | +| None safety fill for `all_outputs` | Prevents `NoneType` crash from cancelled tasks | +| Monitoring cron (15 min) | Auto-detects crash and relaunches | + +## What Might Help (Untested) + +### 1. `gc.collect()` before `empty_cache()` (Applied, not yet tested) +Added to `trainer.py` and `fsdp_worker.py`. Unlikely to fix the root cause (not OOM) but good hygiene. + +### 2. `TORCH_NCCL_AVOID_RECORD_STREAMS=1` (Applied, not yet tested) +Added to launch script. Unlikely to fix root cause but may help with NCCL-related memory issues. + +### 3. Investigate vLLM sleep/wake `scheduler_metadata` lifecycle +The real fix is likely in vLLM's FlashAttention backend — the `scheduler_metadata` tensor needs to be re-allocated or validated after each wake-up cycle. + +### 4. File a vLLM issue +The `cudaErrorInvalidValue` at `flash_attn.py:484` after N sleep/wake cycles is likely a vLLM bug. Should file at https://github.com/vllm-project/vllm/issues with the reproduction steps. + +### 5. Try `enforce_eager=true` +Disabling CUDA graphs might avoid the stale tensor issue since eager mode doesn't cache compiled kernels. Trade-off: slower inference. + +### 6. Try reducing `gpu_memory_utilization` +Currently 0.8. Reducing to 0.7 gives more headroom, potentially avoiding the tensor corruption trigger. + +## Current Mitigation Strategy + +The crash is non-fatal thanks to: +1. **`ckpt_interval=3`**: Checkpoints every 3 steps, crash happens at step ~3-5, so we usually get 1 checkpoint per run +2. **`resume_mode=latest`**: Auto-resumes from latest checkpoint +3. **Monitoring cron**: Checks every 15 minutes, auto-cleans Ray PGs + Daytona sandboxes + port 8000, relaunches +4. **`except BaseException` in `_worker`**: CancelledError batches produce zeroed output (loss=0, no model update) instead of crashing + +Net effect: Training makes ~2-3 steps of progress per restart, with ~5 min overhead per restart. This is ~85% efficient compared to a crash-free run. + +## Comparison with Other Frameworks + +### veRL +- Has similar OOM-after-N-steps issues (Issues #3293, #2260, #3902) +- Recommends `expandable_segments:True` +- Uses FSDP2 reserved memory workaround (PR #1667) +- Calls `gc.collect()` + `torch.cuda.empty_cache()` at every phase transition + +### SLIME +- Uses CUDA IPC for zero-copy weight sync (avoids temp allocations entirely) +- Does NOT aggressively call `empty_cache()` +- Memory partitioning via `--sglang-mem-fraction-static` + +### TRL (HuggingFace) +- Standard `gc.collect()` → `torch.cuda.empty_cache()` between phases + +## Checkpoint History + +``` +global_step_5 — Run 1-3 (initial training) [pruned] +global_step_10 — Run 5 [pruned] +global_step_15 — Run 6 [pruned] +global_step_18 — Run 10 (first with ckpt_interval=3) [pruned] +global_step_21 — Run 11 (active) +global_step_24 — Run 12 (active) +global_step_27 — Run 14 (active, latest) +``` + +## Reward Curve Summary + +| Global Step | Avg Reward | Notes | +|------------|-----------|-------| +| 1 | 0.324 | Baseline | +| 5 | 0.258 | First checkpoint | +| 10 | 0.262 | | +| 12 | 0.371 | | +| 13 | 0.398 | | +| 16 | 0.418 | | +| 17 | 0.445 | | +| 22 | 0.430 | | +| 23 | 0.434 | | +| 25 | 0.449 | | +| 28 | 0.508 → **0.555** | All-time high (Run 15) | +| 29 | 0.387 → 0.371 | | + +Training is making steady progress with rewards increasing from ~0.32 to ~0.40-0.55 over 29 effective steps. +``` diff --git a/examples/train_integrations/harbor/STEPWISE_TRAINING.md b/examples/train_integrations/harbor/STEPWISE_TRAINING.md new file mode 100644 index 0000000000..6449be7f0b --- /dev/null +++ b/examples/train_integrations/harbor/STEPWISE_TRAINING.md @@ -0,0 +1,944 @@ +# Step-Wise Training with TIS for Harbor in SkyRL + +## Table of Contents +- [Motivation](#motivation) +- [How Step-Wise Training Works in SkyRL](#how-step-wise-training-works-in-skyrl) +- [Comparison with Other Frameworks](#comparison-with-other-frameworks) +- [Implementation](#implementation) +- [Caveats and Design Decisions](#caveats-and-design-decisions) +- [Configuration](#configuration) +- [Running](#running) +- [Files Changed](#files-changed) + +--- + +## Motivation + +### The Re-Tokenization Problem + +Currently, Harbor's `HarborGenerator` re-tokenizes the final chat history (string) after the agent finishes. The flow is: + +1. vLLM generates tokens → Harbor agent executes tool calls → environment returns observations +2. Harbor returns the full chat history as a list of message dicts (strings) +3. `HarborGenerator` re-tokenizes the entire chat history using `get_response_ids_and_loss_mask_from_messages()` + +This re-tokenization can produce **different token IDs** than what the model actually generated — a phenomenon called **"retokenization drift"** (see [vLLM Agent Lightning blog post](https://blog.vllm.ai/2025/10/22/agent-lightning.html)). Causes include: +- Non-unique tokenization (e.g., `"HAVING"` → `H`+`AVING` vs `HAV`+`ING`) +- Tool-call serialization changes during parsing/re-rendering +- Chat template differences across frameworks + +### Why This Breaks TIS + +TIS (Truncated Importance Sampling) corrects for off-policy drift between the rollout policy and the current policy: + +``` +TIS ratio = π_current(token) / π_rollout(token) + = exp(current_logprobs - rollout_logprobs) +``` + +If the training tokens differ from the generation tokens due to retokenization, then `rollout_logprobs` (recorded during generation) don't correspond to the actual tokens being trained on. The TIS ratios become meaningless. + +### The Solution: Step-Wise Training + +Instead of re-tokenizing, use the **exact per-turn token IDs and logprobs from vLLM** via Harbor's `collect_rollout_details` feature. Each agent turn becomes a separate (prompt, response) training sample, where: +- `prompt_ids` = the full context vLLM saw (from `rollout_details.prompt_token_ids[turn]`) +- `response_ids` = the exact tokens vLLM generated (from `rollout_details.completion_token_ids[turn]`) +- `logprobs` = the exact per-token logprobs from vLLM (from `rollout_details.logprobs[turn]`) + +This eliminates retokenization drift entirely and enables correct TIS computation. + +--- + +## How Step-Wise Training Works in SkyRL + +### Per-Step Reward Assignment + +Each multi-turn trajectory of N turns is decomposed into N separate training samples. Rewards are assigned as per-token lists: + +``` +Step 1: reward = [0.0, 0.0, ..., 0.0] # all zeros (intermediate step) +Step 2: reward = [0.0, 0.0, ..., 0.0] # all zeros (intermediate step) +... +Step N: reward = [0.0, 0.0, ..., final_reward] # reward at last token only +``` + +Only the final step of each trajectory receives the actual reward (from Harbor's verifier), placed at the last token position. This follows the same pattern as `SkyRLGymGenerator` (see `skyrl_gym_generator.py:446-451`). + +### Advantage Computation + +Advantages are computed **only for last steps**, then broadcast to all steps in the same trajectory: + +```python +# trainer.py:784-815 +# 1. Filter to last steps only +last_step_advantages = compute_advantages(rewards[is_last_step], ...) + +# 2. Build trajectory ID mapping +traj_ids = cumsum(shifted_is_last_step) # maps each step to its trajectory + +# 3. Broadcast: all steps in trajectory i get the same advantage +advantages = last_step_advantages[traj_ids] +``` + +**This is mathematically equivalent to normal (non-step-wise) training** from an advantage perspective — the advantage signal comes entirely from the final trajectory reward. The difference is purely operational: each step gets its own (prompt, response) pair with exact token IDs. + +### TIS (Truncated Importance Sampling) + +TIS is implemented in `off_policy_correction_utils.py`. Two modes: + +- **Token-level TIS** (`tis_ratio_type="token"`): Clamp per-token `exp(old_logprobs - rollout_logprobs)` to `[0, token_tis_ratio_clip_high]`, multiply with loss. Recommended clip: 1.5-5.0. +- **Sequence-level TIS** (`tis_ratio_type="sequence"`): Product of all token ratios (sum in log space), clamped. Recommended clip: 2.0-10.0. + +Additionally, **outlier token masking** rejects entire sequences where any token has an extreme importance ratio (configurable thresholds). + +Step-wise training enables correct TIS because `rollout_logprobs` are the exact logprobs from generation, matching the exact `response_ids` used for training. + +### Batch Expansion + +A batch of N trajectories with M average turns produces N×M training samples: + +``` +Input: 4 prompts × 2 samples = 8 trajectories +Output: 8 trajectories × ~3 turns avg = ~24 step-samples +``` + +The trainer handles this transparently — `mini_batch_size` and `micro_train_batch_size_per_gpu` control memory as before. Step-wise only increases the number of gradient accumulation steps per optimizer update, not peak GPU memory. + +### Limitation: Step-Wise Rewards Are Dropped + +SkyRL's `SkyRLGymGenerator` supports per-step rewards — `env.step()` can return a non-zero `step_reward` at each turn. In the **non-step-wise** path, these are correctly placed as per-token rewards at turn boundaries (`_build_per_token_rewards()`), and the advantage estimator (GRPO or GAE) sees all of them. + +However, in the **step-wise** path, intermediate step rewards are **silently dropped**. The trainer filters to last steps only before computing advantages: + +```python +# trainer.py:794 — only last-step rewards are used +last_step_rewards = token_level_rewards[is_last_step] +last_step_advantages, last_step_returns = compute_advantages_and_returns( + token_level_rewards=last_step_rewards, ... +) +# Broadcast back: all steps get the same advantage from last step's reward +advantages = last_step_advantages[traj_ids] +``` + +This means if `env.step()` returns `reward=0.5` at step 2 and `reward=1.0` at the final step, only `reward=1.0` contributes to the advantage. The `0.5` is placed in the per-token reward list for step 2 but never read during advantage computation. + +**No framework supports both step-wise decomposition AND per-step advantages:** + +| Framework | Per-step rewards in return? | Advantage granularity | +|-----------|---------------------------|----------------------| +| **SkyRL (step-wise)** | No — dropped by `[is_last_step]` filter | Scalar per trajectory, broadcast | +| **SkyRL-Agent** | No — `Transition.reward=0.0` always | Scalar per trajectory, broadcast | +| **Prime-RL** | N/A — no per-step reward concept | Scalar per rollout, broadcast | +| **veRL/rLLM** | No — `assert mode == "broadcast"` | Scalar per trajectory, broadcast | +| **tinker-cookbook** | **Yes** — `get_total_rewards()` sums all `transition.reward` + `final_reward` | But still scalar per trajectory, broadcast | +| **SLIME** | No step-wise decomposition | Trajectory-level | + +tinker-cookbook comes closest: it correctly **sums** per-step rewards into the total return before computing group-centered advantages. But the advantage is still one scalar per trajectory, broadcast to all action tokens. There is no per-step advantage. + +The fundamental reason: **GRPO has no natural per-step formulation.** GRPO groups trajectories by prompt and computes `advantage = reward - mean(group_rewards)`. This is inherently trajectory-level — intermediate steps from different trajectories of the same prompt aren't directly comparable. + +**GAE could in principle do per-step advantages** — it uses a value function `V(s)` to estimate advantage at each token: `δ_t = r_t + γV(s_{t+1}) - V(s_t)`. SkyRL's non-step-wise path already supports GAE with per-token rewards at turn boundaries. But no one has combined GAE with step-wise decomposition — the `[is_last_step]` filter is applied regardless of which advantage estimator is configured. + +### Step-Wise Training Is NOT Mathematically Equivalent to Non-Step-Wise + +Step-wise decomposition changes the loss in ways that depend on the loss reduction method. Consider a trajectory with 3 turns: + +- **Non-step-wise**: 1 training sample. `response = [A1, O2, A2, O3, A3]` with `loss_mask = [1,1,0,0,1,1,0,0,1,1]`. +- **Step-wise**: 3 training samples, each with pure completion tokens and `loss_mask = [1,1]`. + +The per-token loss values (PPO surrogate) are also not identical because the model processes different contexts — step-wise has shorter sequences per forward pass, and prompt left-truncation can alter conditioning. But even assuming identical per-token losses, the reduction differs: + +**`token_mean`**: `masked_mean(loss, loss_mask)` — sum of valid token losses / count of valid tokens. This is the closest to equivalent because both approaches weight each valid token equally. But the batch composition differs: step-wise has N×M step-samples per mini-batch vs N trajectories, and `mini_batch_size` doesn't account for step expansion. **Approximately equivalent.** + +**`sequence_mean`**: `masked_mean(loss, loss_mask, dim=-1).mean()` — per-sequence token-mean, then batch-mean. **NOT equivalent.** Each step-sample is a separate "sequence" getting equal weight. A trajectory with 10 turns produces 10 sequences and gets 10× the gradient contribution of a 1-turn trajectory. In non-step-wise, every trajectory is 1 sequence regardless of turn count. + +**`seq_mean_token_sum_norm`** (Dr. GRPO): `sum(loss * mask, dim=-1) / max_seq_len`, then `.mean()`. **NOT equivalent.** Same per-sequence weighting issue as `sequence_mean`. Additionally, short step-samples (common — early turns have short completions) get their token sum divided by a large `max_seq_len`, then receive equal weight in the `.mean()`. + +#### Cross-Framework: Prefix Merging Preserves Loss Semantics + +This weighting issue is not unique to SkyRL. The key distinction is between frameworks that **always decompose** every turn (always inequivalent) versus those that use **prefix-aware merging** (equivalent when prefixes hold): + +| Framework | Decomposition strategy | Equivalent to non-step-wise? | +|-----------|----------------------|------------------------------| +| **SLIME** | No decomposition (single sequence, loss mask) | Always equivalent — it IS non-step-wise | +| **Prime-RL** | Prefix merging: merge when extension holds, split when breaks | **Equivalent when extension holds** (common case). Diverges only on context resets. | +| **Agent Lightning** (`trajectory`) | Prefix merging (same logic as Prime-RL) | **Equivalent when prefix holds.** Splits on mismatch (retoken/template/post-processing). | +| **Agent Lightning** (`transition`) | Always decomposes every turn | **Never equivalent** — same as SkyRL step-wise | +| **tinker-cookbook** | Prefix merging (same logic as Prime-RL) | **Equivalent when extension holds.** Diverges for MemAgent context resets. | +| **SkyRL-Agent** | Prefix merging via `transitions_to_training_data()` | **Equivalent for standard ReAct** (all turns merge into 1 datum). Diverges for MemAgent. | +| **veRL/rLLM** (experimental) | Always decomposes every turn | **Never equivalent** — same weighting issue | +| **SkyRL step-wise** | Always decomposes every turn | **Never equivalent** | +| **Harbor step-wise** (ours) | Always decomposes every turn | **Never equivalent** | + +The insight: prefix-aware merging (Prime-RL, tinker-cookbook, SkyRL-Agent) isn't just a compute optimization (O(T) vs O(T²)) — it also **preserves loss reduction semantics**. When all turns merge into one sample, the loss reduction treats the trajectory as a single sequence, identical to non-step-wise. + +SkyRL/Harbor's per-turn decomposition trades loss weighting consistency for the benefit of exact token IDs and logprobs (avoiding retokenization drift for TIS). Whether that trade-off is worth it depends on whether TIS correctness matters more than loss weighting equivalence. + +### Full Equivalence Analysis: Beyond Loss Reduction + +Loss reduction is the most visible difference, but step-wise decomposition also affects other components: + +| Component | Equivalent? | Root cause | Fixable? | +|-----------|------------|------------|----------| +| Advantage computation (GRPO) | **Yes** | `[is_last_step]` filter ensures same inputs to GRPO | N/A | +| Per-token reward placement | **Doesn't matter for GRPO** | GRPO does `scores = token_level_rewards.sum(dim=-1)` — position is irrelevant. Only matters for GAE where `δ_t = r_t + γV(t+1) - V(t)` uses per-position rewards. | N/A | +| Loss reduction | **No** | Per-sample weighting (more steps = more weight) | **Yes** — per-trajectory weighting (see below) | +| `advantage_batch_normalize` | **No** | Mean/std computed across all step-samples equally | **Yes** — same per-trajectory weighting | +| Forward pass logprobs | **No** (tiny) | Each step is a separate shorter sequence. Logprobs differ from the single long-sequence non-step-wise forward pass due to different prompt/response boundary placement and prompt left-truncation. Numerically small (same model, same causal attention). | **No** — fundamental to decomposition | +| KL loss/penalty | **No** (tiny) | Inherits from logprob difference | **No** — inherits from above | +| Entropy | **No** (tiny) | Inherits from logprob difference | **No** — inherits from above | +| Metrics | **Mostly yes** | Filtered by `is_last_step` for reward metrics | N/A | + +The logprob difference is fundamental and unfixable — it's inherent to processing shorter sequences independently vs one long sequence. But it's numerically tiny. The loss reduction and normalization differences are significant and fixable. + +### Proposed Fix: Per-Trajectory Weighting + +The loss reduction inequivalence can be fixed by weighting each step-sample by `1 / n_steps_in_its_trajectory`: + +```python +# Current (inequivalent): +per_seq_loss = masked_mean(loss, loss_mask, dim=-1) # [N_samples] +total_loss = per_seq_loss.mean() # each sample weight = 1/N_samples + +# Fixed: +per_seq_loss = masked_mean(loss, loss_mask, dim=-1) # [N_samples] +weights = 1.0 / steps_per_trajectory[traj_ids] # e.g. [1/3, 1/3, 1/3, 1/2, 1/2, 1, 1] +weights = weights / weights.sum() # normalize +total_loss = (per_seq_loss * weights).sum() +``` + +A 3-turn trajectory's 3 samples each get weight `1/3`, so their combined contribution equals one trajectory. Same idea applies to `seq_mean_token_sum_norm` and `advantage_batch_normalize`. + +The information needed is already available in all frameworks: +- SkyRL step-wise: `is_last_step` → `cumsum` gives trajectory boundaries +- SkyRL-Agent: `episode_nums` directly +- Prime-RL / Agent Lightning: `len(merged_trace_idx)` per rollout + +**No framework currently implements this fix.** For `token_mean`, the fix is less critical — it already weights each valid token equally regardless of grouping, so it's approximately equivalent. + +### Policy Loss Functions: Step-Wise Equivalence + +All policy loss functions compute per-token loss values then call `reduce_loss`. The per-token computation is element-wise (operates on `log_probs[i][t]`, `old_log_probs[i][t]`, `advantages[i][t]` independently) — **except GSPO**: + +| Policy Loss | Per-token computation identical? | Additional step-wise issue? | +|---|---|---| +| **REGULAR** (PPO clip) | Yes (modulo tiny logprob diff) | Only `reduce_loss` | +| **DUAL_CLIP** | Yes | Only `reduce_loss` | +| **SAPO** | Yes | Only `reduce_loss` (recommends `sequence_mean`) | +| **GSPO** | **No** | Sequence-level IS weight + `reduce_loss` | +| **CISPO** | Yes | Only `reduce_loss` | +| **CLIP_COV** | Yes | Only `reduce_loss` | +| **KL_COV** | Yes | Only `reduce_loss` | + +**GSPO** computes a **sequence-level importance weight**: `log_importance_weights = masked_mean(log_ratio, loss_mask, dim=-1)` — the mean log-ratio across all tokens in the same sample. In non-step-wise, "one sample" = full trajectory. In step-wise, "one sample" = one turn. The IS weight is computed over different scopes, producing different values and noisier estimates (fewer tokens to average over per turn). + +### Off-Policy Correction: Step-Wise Equivalence + +The off-policy correction utilities (`off_policy_correction_utils.py`) have multiple sequence-level operations that are affected by step-wise decomposition: + +| Component | Affected? | Issue | +|---|---|---| +| **Token-level TIS** (`tis_ratio_type="token"`) | **No** | Purely per-token: `clamp(exp(old - rollout))` | +| **Sequence-level TIS** (`tis_ratio_type="sequence"`) | **Yes** | `sum(log_ratio, dim=-1)` — product of token IS ratios across the sequence. Different scope: full trajectory vs one turn. | +| **Outlier token mask** | **Yes** | Masks entire sequence if *any* token has outlier ratio (`.all(dim=-1)`). In non-step-wise, one bad token masks 100 tokens. In step-wise, it only masks that turn's ~5 tokens. More granular but different behavior. | +| **Geometric sequence mask** | **Yes** | Geometric mean of IS ratios per sequence: `exp(sum(log_ratio) / num_tokens)`. Different `num_tokens` (full trajectory vs one turn) gives different means. | +| **Product sequence mask** | **Yes** | Product of IS ratios: `sum(log_ratio, dim=-1)`. Same issue as sequence-level TIS. | + +All sequence-level off-policy corrections that aggregate across `dim=-1` are affected — they compute statistics over a "sequence" which is the full trajectory in non-step-wise but a single turn in step-wise. Token-level TIS is the only mode that is fully equivalent. + +--- + +## Comparison with Other Frameworks + +### SkyRL (SkyRLGymGenerator) + +SkyRL's built-in `SkyRLGymGenerator` already supports step-wise via `generator.step_wise_trajectories=True`. It uses token-in-token-out: the generator directly controls tokenization at each turn, so there's no retokenization problem. Harbor's case is different because Harbor runs an external agent loop (Terminus 2) that returns strings, requiring either retokenization or rollout_details. + +Key code: `skyrl_gym_generator.py:353-371` (per-step output), `skyrl_gym_generator.py:704-774` (flattening). + +### SkyRL-Agent (Full Flow) + +SkyRL-Agent has explicit step-wise training built around three core abstractions: `Transition`, `transitions_to_training_data()`, and prefix-aware merging. Here is the full pipeline: + +#### Step 1: Transition Recording (During Agent Execution) + +Each LLM call is captured by the `@record_transition` decorator (`skyrl_agent/functional/utils.py:62`): + +```python +@record_transition +async def _generate_with_recording(self, input_ids=[], ...): + ... +``` + +This creates one `Transition` per LLM call: + +```python +Transition( + ob=Observation(input_ids=[...]), # Full token IDs sent TO the LLM + ac=TokensWithLogprobs( + token_ids=[...], # Tokens generated BY the LLM + logprobs=[...], # Per-token logprobs + text="...", + ), + reward=0.0, # Placeholder — set later at trajectory level + episode_done=False, +) +``` + +All transitions accumulate in `self.transitions: List[Transition]` during the agent run. + +#### Step 2: `transitions_to_training_data()` — Prefix-Aware Merging + +**Input**: `List[Transition]` from a **single trajectory** (one agent run). +**Output**: `List[TrainingDatum]` — potentially **fewer** items than input transitions. + +The function (`utils.py:136-235`) maintains an accumulator and processes transitions one by one. The key logic: if the current transition's observation is a **prefix extension** of the accumulated sequence, the transition is **merged** into the current datum. Otherwise, a new datum starts: + +``` +Transition 1: ob=[O1], ac=[A1] +Transition 2: ob=[O1,A1,O2], ac=[A2] ← ob extends full_sequence → MERGE +Transition 3: ob=[O3], ac=[A3] ← ob is NOT a prefix → FLUSH, start new datum +``` + +Result of merging transitions 1+2 into one `TrainingDatum`: + +```python +TrainingDatum( + input_tokens=[O1], # First observation = "prompt" + response_tokens=[A1, O2, A2], # Everything after = "response" + response_logprobs=[lp1..., 0.0..., lp2...], # Real logprobs for actions, 0.0 for obs + response_mask=[1,1,..., 0,0,..., 1,1,...], # 1 for action tokens, 0 for obs tokens +) +``` + +**When does merging happen?** In a standard ReAct agent, each LLM call receives the full conversation history. So `ob` for turn 2 = `[O1, A1, O2]`, which is a prefix extension of `[O1, A1]`. All turns merge into **one** `TrainingDatum` → the trajectory produces **1 step**. + +**When does it NOT merge?** When context resets occur — e.g., MemAgent's `next_with_summary` tool replaces the context with a summary. The new observation has no prefix relationship with the previous sequence, so a new datum starts. This is how MemAgent produces **multiple** `TrainingDatum`s (= multiple steps) from one trajectory. + +#### Step 3: Post-Processing (`AgentRunner._post_process_results()`, `base.py:406-418`) + +Iterates over all trajectories in the batch: + +```python +for result in matched_results: # For each trajectory + transitions = result.get("transitions", []) + data_list = transitions_to_training_data(transitions) # → List[TrainingDatum] + + for data in data_list: # For each step (datum) + prompt_input_ids.append(data.input_tokens) + response_ids.append(data.response_tokens) + logprobs.append(data.response_logprobs) + response_assistant_mask.append(data.response_mask) + is_last_episode_list.append(False) + + is_last_episode_list[-1] = True # Mark last step + steps_per_trajectory.append(len(data_list)) # Track steps per traj + + # Broadcast trajectory-level reward to ALL steps (scalar, same value) + reward_list.extend([result.get("reward", False)] * len(data_list)) +``` + +**Reward handling**: The trajectory-level reward (scalar from the environment) is **replicated identically** to every step. All steps of the same trajectory get the same scalar reward. This is different from SkyRLGymGenerator/Harbor which use per-token reward lists with the reward placed at a specific token position. + +#### Step 4: Output Format + +```python +output = { + "prompt_token_ids": prompt_input_ids, # Per-step prompts + "response_ids": response_ids, # Per-step responses (interleaved obs+action tokens) + "rewards": reward_list, # Per-step scalar rewards (same for all steps of a traj) + "traj_rewards": traj_reward_list, # Per-trajectory scalar rewards (not expanded) + "loss_masks": loss_mask, # Per-step masks (0 for obs tokens, 1 for action tokens) + "episode_nums": steps_per_trajectory, # [3, 2, 4, ...] — num steps each trajectory produced + "is_last_episode": is_last_episode_list, # [F, F, T, F, T, F, F, F, T, ...] + "traj_idx": traj_idx_list, # Trajectory ID per step + "rollout_logprobs": logprobs, # Per-step logprobs (aligned with response_ids) + "rollout_metrics": rollout_metrics, +} +``` + +#### Step 5: Tinker Integration Consumes This (`tinker_train.py:357-443`) + +```python +rollouts = await agent_generator.run(input_batch) + +# Use traj_rewards (one per trajectory, NOT step-expanded) +all_returns = [float(r) for r in rollouts["traj_rewards"]] + +# Compute GRPO advantages at trajectory level +all_advantages = compute_advantages_grpo(all_returns, group_size=group_size) + +# Broadcast advantages to steps using episode_nums +step_advantages = [] +for idx, num_steps in enumerate(num_steps_per_trajectory): + step_advantages.extend([all_advantages[idx]] * num_steps) +``` + +Then for each step, it builds a Tinker `Datum` with the full sequence (prompt + response), logprobs, and the broadcasted advantage value. + +#### Key Differences: SkyRL-Agent vs Harbor Step-Wise + +| Aspect | SkyRL-Agent | Harbor Step-Wise | +|--------|-------------|-----------------| +| **What is a "step"?** | A `TrainingDatum` from prefix-aware merging. For standard ReAct = 1 step (all turns merge). For MemAgent = N steps (context resets create breaks). | One LLM turn from `rollout_details.completion_token_ids[i]`. Every turn is always a separate step. | +| **Reward format** | Scalar per step; all steps get same trajectory reward | Per-token list; reward only at last token of last step | +| **Response content** | Interleaved obs+action tokens with `response_mask` distinguishing them | Pure completion tokens only (obs is in the prompt of the next step) | +| **Logprobs** | Aligned with response_tokens: real logprobs for actions, 0.0 for obs tokens | Aligned with completion tokens only | +| **Advantage computation** | Done in Tinker integration using `episode_nums` to broadcast | Done in SkyRL trainer using `is_last_step` to broadcast | +| **Why multiple steps?** | Context resets (MemAgent summarization) break prefix continuity | Every LLM turn is inherently a separate step | + +#### Key Files + +| File | Purpose | +|------|---------| +| `skyrl_agent/functional/utils.py:12-36` | `Transition`, `Observation`, `TokensWithLogprobs` dataclasses | +| `skyrl_agent/functional/utils.py:62-115` | `@record_transition` decorator | +| `skyrl_agent/functional/utils.py:136-235` | `transitions_to_training_data()` — prefix-aware merging | +| `skyrl_agent/agents/base.py:406-627` | `_post_process_results()` — flattening + reward broadcast | +| `skyrl_agent/integrations/tinker/tinker_train.py:357-443` | Tinker training loop — advantage broadcast + datum creation | + +### Prime-RL + +Prime-RL takes a fundamentally different approach: **prefix-aware trajectory merging with whole-sample scalar advantages** rather than per-step decomposition. + +#### The Extension Property + +Each multi-turn trajectory consists of steps with `(prompt_ids, completion_ids, completion_logprobs, completion_mask)`. The function `interleave_rollout()` (`prime_rl/orchestrator/trajectories.py:38-180`) processes them: + +For each step, it checks if the step's `prompt_ids` is a **prefix extension** of any active sample's accumulated sequence. If yes → **merge** into that sample. If not → **start a new sample**. + +``` +5-step trajectory where extension breaks at step 4: + +Steps 1-3: extension holds → merged into Sample 1 + completion_ids = [A1, delta_O2, A2, delta_O3, A3] + completion_mask = [1.., 0......., 1.., 0......., 1..] + logprobs = [lp1, 0.0....., lp2, 0.0....., lp3] + +Step 4: extension breaks (e.g., thinking stripped by chat template) +Steps 4-5: merged into Sample 2 +``` + +The `extend_sample()` function appends new prompt delta tokens with `mask=False, logprobs=0.0` (not trainable) and new completion tokens with `mask=True, logprobs=actual` (trainable). This is structurally identical to SkyRL-Agent's prefix-aware merging in `transitions_to_training_data()`. + +**When does extension break?** Models like Qwen3 that strip `` tags across turns, context compaction/summarization, sub-agent handoffs where context is discontinuous. + +#### Advantages: One Scalar Per Rollout + +Advantages are computed at the **rollout level**, not per-step (`advantage.py:65-91`): + +```python +# GRPO: advantage = reward - mean(group_rewards) +# One scalar per rollout, broadcast to ALL tokens: +advantages = [training_example.advantage] * len(input_ids) +``` + +There is no `is_last_step` or `episode_nums` — every merged sample gets a single scalar advantage from the rollout-level reward. + +#### Loss: IPO with Token-Level Importance Sampling + +The loss (`loss.py:107-163`) implements IPO (INTELLECT Policy Optimization): + +```python +log_importance_ratio = trainer_logprobs - inference_logprobs +importance_ratio = exp(log_importance_ratio) + +# Trust region via probability difference masking (not ratio clipping like SkyRL's TIS) +probs_diff = exp(trainer_logprobs) - exp(inference_logprobs) +keep_mask = loss_mask & (|probs_diff| < threshold) + +pg_loss = keep_mask * advantages * importance_ratio +kl_loss = loss_mask * log_importance_ratio² +loss = -pg_loss + kl_tau * kl_loss +``` + +The `inference_logprobs` are exact generation logprobs (stored per-step, aligned during merging). No re-tokenization needed because the extension property guarantees prefix alignment. + +#### Key Differences from SkyRL Step-Wise + +| Aspect | Prime-RL | SkyRL / Harbor Step-Wise | +|--------|----------|--------------------------| +| **When turns become separate samples** | Only when extension breaks | Always — every turn is separate | +| **Advantage** | One scalar per rollout | Computed on last step, broadcast | +| **Off-policy correction** | Probability-difference masking (IPO) | Ratio clipping (TIS) | +| **Compute scaling** | O(T) when extension holds | O(T²) always | +| **Philosophy** | Merge when possible, split when forced | Always decompose into per-turn units | + +### Agent Lightning + +Agent Lightning (Microsoft Research) supports two modes: **transition-level** (each turn = separate sample, like SkyRL step-wise) and **trajectory-level** (merge turns into one sample, like Prime-RL). Configured via `trace_aggregator.level: "trajectory"` or `"transition"`. + +#### Trajectory-Level Aggregation + +The trajectory path (`daemon.py:915-1023`) uses prefix matching identical in concept to Prime-RL: + +```python +for turn_index, trace in enumerate(sample_info["trace_list"]): + is_prefix, diagnostic = ids_startswith( + trace["prompt_ids"] + trace["response_ids"], + current_context, tokenizer, debug, + ) + if is_prefix: + current_context = trace["prompt_ids"] + trace["response_ids"] + current_merged_trace_idx.append(turn_index) + else: + # Start new group — soft fallback, no retry + merged_trace_idx.append(current_merged_trace_idx) + current_merged_trace_idx = [turn_index] + current_context = trace["prompt_ids"] + trace["response_ids"] +``` + +Merged turns produce one sample with interleaved response tokens and a `response_mask` (`1` for agent responses, `0` for prompt/observation delta tokens) — structurally identical to Prime-RL's `extend_sample()`. + +#### Failure Handling — Same as Prime-RL (Soft Fallback) + +When prefix matching fails, the trajectory is **split into multiple samples** starting at the mismatch point. No retry, no re-tokenization. This is exactly Prime-RL's behavior. + +The blog post ([trajectory_level_aggregation](https://agent-lightning.github.io/posts/trajectory_level_aggregation/)) documents five failure modes: + +1. **Retoken mismatch** (BPE artifacts): Same text tokenizes differently in generation vs re-tokenization. E.g., `` → `["<", "think", ">"]` during generation but `[""]` during template application. + +2. **Template mismatch**: Chat templates insert/strip special tokens at turn boundaries. E.g., `` generated explicitly by the model but stripped by the template on the next turn. + +3. **Post-processing modifications**: If agents truncate outputs (e.g., removing chain-of-thought) before feeding history to subsequent turns, the stored rollout won't match the prompt prefix. + +4. **Normalization artifacts**: Whitespace, escape character, or unicode normalization shifts token boundaries. + +5. **Structural alignment**: Manual string concatenation bypasses chat templates, causing missing role headers. + +A debug mode (`trace_aggregator.debug: true`) classifies mismatches into three categories and logs them: +- `template_mismatch`: Special token sequence differs +- `retoken_mismatch`: Token IDs differ but decoded text matches (BPE non-determinism) +- `others_mismatch`: Content itself differs + +#### Rewards and Advantages + +Scalar reward placed at the last token of the merged sample (`daemon.py:1070`): +```python +token_level_scores[torch.arange(n_transition), eos_mask_idx] = scores +``` + +Uses VERL's `compute_advantage()` with GRPO/GAE. One scalar advantage per merged sample, broadcast to all response tokens. Same pattern as Prime-RL. + +#### Transition-Level (Alternative) + +When `level: "transition"`, each turn becomes a separate sample with only its own response tokens — equivalent to SkyRL step-wise. No `response_mask` needed (all response tokens are trainable). Same per-sample weighting issues as SkyRL step-wise with `sequence_mean`. + +#### Key Difference from Prime-RL + +Structurally very similar to Prime-RL. The main addition is the **diagnostic system**: categorizing mismatches into template/retoken/others and logging them, which Prime-RL doesn't do. Both use the same soft-fallback strategy (split into new sample on mismatch). + +### SLIME + +- **Multi-turn**: Yes, via async rollout loops (`examples/geo3k_vlm_multi_turn/rollout.py`) +- **TIS**: Full support via `--use-tis` flag, with custom TIS functions loadable via `--custom-tis-function-path` +- **Advantage**: Trajectory-level (not per-step broadcast). Single advantage per trajectory replicated across all tokens. +- **Key difference**: Accumulates tokens in a single sequence with loss masking (observation tokens masked out). Does NOT decompose into separate per-step (prompt, response) pairs. + +### veRL + +- **Multi-turn**: Experimental support via `examples/data_preprocess/multiturn.py` +- **Step-wise advantages**: Experimental `_stepwise_advantage_broadcast` in `rllm/experimental/verl/verl_advantage.py` — same pattern as SkyRL (compute at last step, broadcast back) +- **TIS**: No explicit TIS; uses response mask for selective loss computation +- **Key difference**: Step-wise is experimental/external (in rllm integration), not in core veRL + +### rLLM + +- **Multi-turn**: Dedicated `MultiTurnWorkflow` class (`rllm/workflows/multi_turn_workflow.py`) with step-by-step environment interaction +- **Step-wise tracking**: Full trajectory metadata including `step_nums`, `episode_ids`, `trajectory_ids` +- **Advantage broadcasting**: Leverages veRL's experimental `_stepwise_advantage_broadcast` +- **TIS**: Indirectly through veRL backend +- **Key difference**: Most mature multi-turn workflow abstraction, but relies on veRL backend for training + +### tinker-cookbook + +- **Multi-turn**: Yes, via `Transition` and `Trajectory` abstractions (`tinker_cookbook/rl/types.py`) +- **Per-step rewards**: Each `Transition` has an immediate reward +- **Advantage**: Trajectory-level, group-centered (within-group normalization). Replicated across all action tokens. +- **TIS**: No explicit importance sampling +- **Key difference**: `Transition` is the closest conceptual match to our per-step approach. Clean abstraction but no TIS support. Prefix-aware sequence merging handles observation tokens efficiently. + +### Summary Table + +| Framework | Step-Wise Decomposition | TIS / Off-Policy | Advantage | Re-Tokenization Avoidance | +|-----------|------------------------|------------------|-----------|---------------------------| +| **SkyRL (this)** | Yes (per-turn separate samples) | Yes, ratio clipping (token/seq) | Last-step → broadcast to all steps | Yes (rollout_details) | +| **SkyRL-Agent** | Yes (prefix-aware merging) | Yes | Scalar per traj → broadcast via `episode_nums` | Yes (token-in-token-out) | +| **Prime-RL** | Merge when prefix holds, split when breaks | Yes, probability-diff masking (IPO) | Scalar per rollout → broadcast to all tokens | Yes (exact prefix invariant) | +| **Agent Lightning** | Both modes: `trajectory` (merge) or `transition` (per-turn) | KL penalty (no explicit IS) | Scalar per sample → broadcast | Yes (vLLM `return_token_ids` + prefix matching) | +| **SLIME** | No (single sequence, loss mask) | Yes (`--use-tis`) | Trajectory-level | N/A | +| **veRL** | Experimental | No | Experimental broadcast | N/A | +| **rLLM** | Yes (MultiTurnWorkflow) | Via veRL | Via veRL experimental | N/A | +| **tinker-cookbook** | Yes (Transition objects) | No | Trajectory-level, group-centered | N/A | + +### Mini-Batch Size with Step-Wise Expansion + +All frameworks that do step-wise decomposition face the same question: the effective number of training samples grows by the average number of turns. In SkyRL (and our implementation), the `policy_mini_batch_size` is computed as `policy_mini_batch_size * n_samples`, which doesn't account for step expansion. This means more optimizer steps per batch (e.g., 10× for 10-turn trajectories). + +- **SkyRL/SkyRL-Agent**: Keep as-is (more optimizer steps). This is the current behavior. +- **veRL/rLLM**: Optional `normalize_by_steps` flag to divide advantage by step count. +- **tinker-cookbook**: Group-centered advantages naturally handle this via within-group normalization. + +--- + +## Implementation + +### Architecture + +``` +harbor_agent_loop() (async, per-trajectory) + ├─ Success → _build_step_wise_output() → HarborStepWiseOutput + └─ Failure → HarborAgentOutput (same as non-step-wise failures) + ↓ +_build_step_wise_generator_output() (batch-level) + ├─ _identify_masked_instances() (shared with non-step-wise path) + └─ Flatten to GeneratorOutput with is_last_step, trajectory_ids, rollout_logprobs + ↓ +SkyRL Trainer (unchanged) + ↓ Advantage broadcast, TIS correction, PPO update +``` + +### Key Data Structures + +```python +@dataclass +class HarborAgentOutput: + response_ids: List[int] # Completion token IDs for this turn + reward: Union[float, List[float]] # Per-token rewards (list for step-wise) + stop_reason: str + loss_mask: List[int] # 1 for generated tokens, 0 for masked + prompt_ids: List[int] # Full prompt including chat history + trajectory_id: TrajectoryID + rollout_logprobs: Optional[List[float]] # Per-token logprobs from vLLM + summarization_count: Optional[int] + num_turns: Optional[int] + +@dataclass +class HarborStepWiseOutput: + step_outputs: List[HarborAgentOutput] # One per agent turn + trajectory_id: Optional[TrajectoryID] + summarization_count: Optional[int] + num_turns: Optional[int] +``` + +### GeneratorOutput Format (Step-Wise) + +When `step_wise_trajectories=True`, the generator returns: + +```python +{ + "prompt_token_ids": [[turn1_prompt], [turn2_prompt], ...], # Per-step prompts + "response_ids": [[turn1_completion], [turn2_completion], ...], # Per-step completions + "rewards": [[0,0,...,0], [0,0,...,0], ..., [0,0,...,R]], # Per-token, reward at last token of last step + "loss_masks": [[1,1,...,1], [1,1,...,1], ...], # All completion tokens trainable + "rollout_logprobs": [[lp1], [lp2], ...], # Per-token logprobs from vLLM + "is_last_step": [False, False, ..., True, False, ..., True], # Marks final step per trajectory + "trajectory_ids": [tid1, tid1, ..., tid1, tid2, ..., tid2], # Same ID for all steps of a trajectory + "stop_reasons": [...], + "rollout_metrics": {...}, +} +``` + +The trainer already handles this format (checks for `is_last_step` and `trajectory_ids` in `trainer.py:642-658, 784-815`). + +### Non-Step-Wise Path (Unchanged) + +When `step_wise_trajectories=False` (default), behavior is identical to the original implementation: re-tokenize chat history via `get_response_ids_and_loss_mask_from_messages()`, return single trajectory per prompt, `rollout_logprobs=None`, no `is_last_step`. + +--- + +## Caveats and Design Decisions + +### 1. Padding OOM (Pending Fix in `convert_prompts_responses_to_batch_tensors`) + +**Problem**: In step-wise mode, different steps have very different prompt/response length ratios. Early turns have short prompts (~100 tokens) but potentially long completions (~20K tokens with thinking). Late turns have long prompts (~30K tokens of chat history) but short completions (~200 tokens). The padding function (`convert_prompts_responses_to_batch_tensors`) pads ALL samples to `max(all_prompts) + max(all_responses)`, creating padded sequences far exceeding `max_seq_len`: + +``` +Step 1: prompt=100, response=20000 → actual total = 20100 +Step 5: prompt=30000, response=200 → actual total = 30200 +Padded: every sample = 30000 + 20000 = 50000 tokens ← OOM! +``` + +**Status**: This is a known issue. The fix should be in `convert_prompts_responses_to_batch_tensors` (in `skyrl/train/dataset/preprocess.py`) — each sequence should be capped at `max_seq_len` total length rather than taking `max(all_prompts) + max(all_responses)`. `HarborGenerator` intentionally does NOT truncate/pad prompts or responses itself; that responsibility belongs to the downstream batch tensor construction. + +### 2. Summarization Not Supported + +When `step_wise_trajectories=True`, context summarization (`enable_summarize=True`) is not supported. Summarization causes Harbor to split rollout_details into multiple segments (main + subagent), making per-turn alignment ambiguous. An assertion enforces this: + +```python +if len(rollout_details_list) > 1: + assert summarization_count == 0, "step_wise + summarization not supported" +``` + +The default Harbor config already has `enable_summarize: false`. + +### 3. `collect_rollout_details` Must Be Enabled + +Step-wise training requires Harbor's `collect_rollout_details=True` in the agent kwargs. This tells Terminus 2 to request `logprobs=True` and `return_token_ids=True` from vLLM via LiteLLM's `extra_body`. The generator auto-enables this if not set: + +```python +if self.step_wise: + if not agent_kwargs.get("collect_rollout_details", False): + self._harbor_trial_config_template["agent"]["kwargs"]["collect_rollout_details"] = True +``` + +### 4. Loss Mask: All Completion Tokens Are Trainable + +In step-wise mode, each step's `loss_mask = [1] * len(completion_ids)`. This is because the response consists ONLY of the model's completion tokens (no interleaved observation/user tokens). The prompt already contains the full chat history including previous observations. + +This differs from the non-step-wise path where `get_response_ids_and_loss_mask_from_messages()` interleaves assistant and user/observation tokens in a single response, with loss_mask=0 for non-assistant tokens. + +### 5. Failed Trajectories in Step-Wise Mode + +Failed trajectories (timeout, error, missing rollout_details) return a plain `HarborAgentOutput` with zeroed fields (same as non-step-wise failures): + +```python +HarborAgentOutput(response_ids=[0], reward=0, stop_reason="error", loss_mask=[0], prompt_ids=[0], ...) +``` + +The batch-level `_build_step_wise_generator_output()` identifies these via `_identify_masked_instances()` (shared with the non-step-wise path) and emits a single zeroed-out step for each masked instance. Instance-level masking still applies: if any trajectory for a prompt fails, all trajectories for that prompt are zeroed out. + +### 6. Feature Compatibility + +Step-wise training changes the batch structure: N trajectories become N×M step-samples. Several trainer features assume 1:1 correspondence between batch indices and trajectories. + +| Feature | Compatible? | Issue | +|---------|------------|-------| +| `dynamic_sampling` (replace/filter) | **No** | Reward variance computation is polluted by intermediate zero-reward steps. Index-level replacement breaks trajectory integrity (replaces one step of trajectory A with a step from trajectory B, producing incoherent prompt/response). Filter sampling's reward grouping is wrong. | +| `zero_variance_filter` | **Accidentally bypassed** | Skipped because step-wise rewards are per-token lists (`isinstance(rewards[0], list)` → True). Would break with scalar rewards: groups by uid, computes `np.std(rewards)` on step-level entries mixing zeros and actual rewards. | +| `advantage_batch_normalize` | **Semantically different** | `normalize_advantages_dict` computes mean/std across all step-samples equally. Trajectories with more turns contribute more samples → more influence on normalization statistics. | +| `use_kl_in_reward` | **Partially broken** | `apply_reward_kl_penalty` adds per-token KL penalty to the rewards tensor. Works at the tensor level, but `compute_advantages_and_returns` then filters to `[is_last_step]` only — KL penalty on intermediate steps is dropped, same as intermediate step rewards. | +| `use_kl_loss` | **Yes** | Token-level regularizer: `masked_mean(kl, loss_mask)` per micro-batch. No trajectory structure dependency. | +| `use_entropy_loss` | **Yes** | Token-level: `masked_mean(entropy, loss_mask)`. No trajectory structure dependency. | +| `update_ref_every_epoch` | **Yes** | Weight sync only. No batch structure dependency. | +| `dump_data_batch` | **Yes** | Works, but dumped data has step-wise expanded structure (N×M) which may confuse analysis tools. | +| `batched=True` | **Blocked** | Explicit validation in `SkyRLGymGenerator.__init__`. | +| Custom chat templates | **Blocked** | Explicit validation in `SkyRLGymGenerator.__init__`. | +| `use_conversation_multi_turn=False` | **Blocked** | Explicit validation in `SkyRLGymGenerator.__init__`. | + +--- + +## Configuration + +### Enable Step-Wise Training + +```bash +generator.step_wise_trajectories=true +``` + +### Enable TIS + +```bash +trainer.algorithm.off_policy_correction.tis_ratio_type=token # or "sequence" +trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=2.0 # recommended: 1.5-5.0 +``` + +### Harbor Config (default.yaml) + +Required settings for step-wise: +```yaml +agent: + kwargs: + collect_rollout_details: true # Get per-turn token IDs + logprobs from vLLM + enable_summarize: false # Required: summarization breaks rollout_details + store_all_messages: true # Required: for chat history extraction +``` + +### Full Example + +```bash +bash examples/train_integrations/harbor/run_codecontest_comparison.sh stepwise-tis +``` + +--- + +## Running + +### Prerequisites + +```bash +# Prepare dataset +uv run --isolated --extra harbor python examples/train_integrations/harbor/prepare_harbor_dataset.py --dataset open-thoughts/CodeContests + +# Ensure DAYTONA_API_KEY and WANDB_API_KEY are exported +``` + +### Dev Validation (Small Batch) + +```bash +bash examples/train_integrations/harbor/run_codecontest_stepwise_dev.sh +# batch=4, n_samples=2, step_wise=true, TIS=token, 10 steps +``` + +### Production Three-Way Comparison + +```bash +bash examples/train_integrations/harbor/run_all_comparison.sh +# Runs sequentially: baseline → stepwise → stepwise-tis +# Each: batch=32, n_samples=8, 10 steps +# Kills Daytona sandboxes between runs +``` + +### Individual Runs + +```bash +bash examples/train_integrations/harbor/run_codecontest_comparison.sh baseline +bash examples/train_integrations/harbor/run_codecontest_comparison.sh stepwise +bash examples/train_integrations/harbor/run_codecontest_comparison.sh stepwise-tis +``` + +### Cleanup Sandboxes + +```bash +uv run --isolated --extra harbor python examples/train_integrations/harbor/kill_daytona_sandboxes.py +``` + +--- + +## Files Changed + +| File | Change | +|------|--------| +| `pyproject.toml` (line 253) | Harbor dependency → local path `/home/ray/default/harbor` | +| `harbor_trial_config/default.yaml` | Added `collect_rollout_details: true` | +| `harbor_generator.py` | Main implementation: `HarborStepWiseOutput`, `_build_step_wise_output()`, `_build_step_wise_generator_output()`, `_identify_masked_instances()` (shared masking logic) | +| `run_codecontest_stepwise_dev.sh` | Dev run script (batch=4, n_samples=2) | +| `run_codecontest_comparison.sh` | Parameterized production script (baseline/stepwise/stepwise-tis) | +| `run_all_comparison.sh` | Chains all three comparison runs | +| `kill_daytona_sandboxes.py` | Daytona sandbox cleanup utility | + +No changes to the SkyRL training loop (`trainer.py`) — the existing step-wise and TIS infrastructure handles everything. + +--- + +## Appendix: Trainer Internals Reference + +Context for future iteration on the Harbor step-wise implementation. + +### The 4 Sets of Logprobs + +The training pipeline uses four distinct sets of log probabilities: + +| Logprob | When computed | Where stored | Used for | +|---------|--------------|--------------|----------| +| `π_rollout` | During vLLM generation | `rollout_logprobs` in GeneratorOutput | TIS off-policy correction: `TIS_ratio = exp(π_old - π_rollout)` | +| `π_old` | Stage 3 (`fwd_logprobs_values_reward`), ONCE before training | `action_log_probs` in TrainingInputBatch | PPO ratio denominator: `ratio = exp(π_current - π_old)` | +| `π_ref` | Stage 3, from frozen reference model | `base_action_log_probs` in TrainingInputBatch | KL penalty: `KL(π_current, π_ref)` | +| `π_current` | Stage 5, RECOMPUTED each micro-batch with gradients | Local variable in `_forward_backward_micro` | PPO ratio numerator (this is what gets trained) | + +### Training Call Chain + +See `TRAIN_POLICY_CALL_CHAIN.md` for the full call chain. Key structure: + +``` +train_critic_and_policy(data) + └─ _execute_training_step("policy", data) ← trainer.py + └─ for each epoch: + └─ for each mini-batch (optimizer step): ← trainer.py:1070-1082 + │ dispatch → Worker.forward_backward_from_staged() + │ │ data sharded across DP workers (each gets mini_batch/dp_size samples) + │ └─ Worker.forward_backward() ← worker.py:675 + │ └─ for each micro-batch: ← gradient accumulation + │ └─ _forward_backward_micro() ← worker.py:749, GPU hot path + │ model.forward() → π_current + │ loss_fn(π_current, π_old, advantages) → loss + │ loss.backward() → accumulates .grad + └─ Worker.optim_step() ← worker.py:939 + grad *= 1/N, optimizer.step(), zero_grad() +``` + +### Model Forward Pass + +`HFModelWrapper.forward()` (`model_wrapper.py:261`): +- Input: `sequences [batch, max_prompt + max_response]` — the FULL padded sequence +- Runs the entire transformer on the full sequence length +- Slices output to response portion: `log_probs[:, -num_actions-1 : -1]` +- Returns: `action_log_probs [batch, max_response]` + +Peak GPU memory is determined by `micro_batch_size × (max_prompt + max_response)`, not just response length. + +### Padding: Full Batch, Not Per Mini-Batch + +`convert_prompts_responses_to_batch_tensors()` (`preprocess.py:28`) pads ALL samples to the global max prompt and max response lengths across the ENTIRE batch. This happens once before training, not per mini-batch. Consequence for step-wise: if one step-sample has a 64K prompt (last turn), ALL step-samples are padded to 64K prompt length, including early turns with tiny prompts. This is a major source of memory waste and can cause OOM. **Pending fix**: each sequence should be capped at `max_seq_len` total length in `convert_prompts_responses_to_batch_tensors` rather than using `max(all_prompts) + max(all_responses)`. + +### Three Masks + +| Mask | Shape | What it marks | +|------|-------|---------------| +| `attention_mask` | `[batch, max_prompt + max_response]` | 0 for left-pad, 1 for real tokens. Used by transformer attention. | +| `response_mask` | `[batch, max_response]` | 1 for response tokens, 0 for right-pad. Used to slice logprobs from full-sequence model output. | +| `loss_mask` | `[batch, max_response]` | 1 for trainable tokens, 0 for obs/pad. Subset of response_mask. Used in loss computation. | + +Relationship: `loss_mask ⊆ response_mask`. In non-step-wise, obs tokens have `response_mask=1` but `loss_mask=0`. In step-wise (Harbor), response is pure completion tokens, so `loss_mask = response_mask` (all 1s). + +### GRPO Advantage Details + +`compute_grpo_outcome_advantage()` (`ppo_utils.py:1132`): +1. `scores = token_level_rewards.sum(dim=-1)` — collapse to scalar (position doesn't matter) +2. Group by uid (prompt), compute per-group mean (and optionally std) +3. `advantage = (score - group_mean) / (group_std + ε)` (or just `score - group_mean` without std normalization) +4. `advantages = scores.unsqueeze(-1) × response_mask` — broadcast scalar to all response tokens +5. Singleton groups (1 sample): `mean=0, std=1` → advantage = raw score + +### Advantage Estimator Comparison + +| Estimator | Baseline | Normalization | Per-token variation | Needs critic | +|---|---|---|---|---| +| **GRPO** | mean(group) | optional std(group) | No (scalar broadcast) | No | +| **RLOO** | mean(group) × N/(N-1) leave-one-out | No | No (scalar broadcast) | No | +| **REINFORCE++** | None (batch whitening) | batch-level whiten | Yes (if γ<1) | No | +| **GAE** | V(s) from critic | batch-level whiten | Yes (always) | **Yes** | + +For GRPO: reward token position doesn't matter — `sum()` collapses it. Only GAE cares about position (`δ_t = r_t + γV(t+1) - V(t)`). + +### Memory Analysis + +| Component | Determined by | GPU or CPU | +|---|---|---| +| Model forward/backward activations | `micro_batch_size × (max_prompt + max_response)` | GPU | +| Model parameters | Fixed (8B × 2 bytes / dp_size) | GPU | +| Gradients | Fixed (same as parameters, allocated once) | GPU | +| Optimizer states (Adam) | 2× parameters, offloaded after step | GPU → CPU | +| Full training batch (padded tensors) | `len(data) × (max_prompt + max_response)` | CPU (Ray object store) | +| Per-worker mini-batch slice | `(mini_batch_size / dp_size) × seq_len` | CPU → GPU | + +Step-wise multiplies `len(data)` by avg turns (M), increasing CPU memory M×. GPU peak memory depends on the padded seq_len — currently `max(all_prompts) + max(all_responses)` which can cause OOM (see caveat #1). Once `convert_prompts_responses_to_batch_tensors` is fixed to cap at `max_seq_len`, GPU peak memory will be bounded. + +### SkyRLGymGenerator vs Harbor Step-Wise: Structural Difference + +| Aspect | SkyRLGymGenerator step-wise | Harbor step-wise | +|--------|---------------------------|-----------------| +| Response content | `action_tokens + obs_tokens` | Pure completion tokens only | +| Loss mask | `[1,1,...,0,0,...,1,1,...]` (action=1, obs=0) | `[1,1,...,1]` (all trainable) | +| Logprobs | Real for actions, 0.0 for obs | Real for all completion tokens | +| Obs tokens | In response, masked by loss_mask | In next step's prompt | +| Overlong filtering | **Broken** — intermediate steps don't end with EOS, so `apply_overlong_filtering` zeros their loss_mask | Works correctly — completion tokens may end with EOS | + +**Potential fix for SkyRLGymGenerator**: Remove obs tokens from `turn_response_ids` (only include `output_ids`, not `output_ids + obs_ids`). This would make it structurally identical to Harbor step-wise, fix the overlong filtering bug, and reduce padding waste. + +### Chunked MDP Perspective (ROLL Team / IPA) + +The ROLL team's IPA algorithm argues that chunk-level (= per-turn) decomposition is **better** than both token-level and trajectory-level for agentic RL: + +- **Token-level problem**: Most tokens don't change environment state. A 500-token thinking block + 10-token tool call all get same advantage, but only the tool call mattered. +- **Trajectory-level problem**: One scalar advantage for 30 turns. Turn 3 was the critical mistake, turns 4-30 were wasted, but all get identical gradient signal. +- **Chunk-level (step-wise)**: Each turn is the natural "decision unit." Enables per-chunk credit assignment, per-chunk IS masking, per-chunk returns. + +With terminal-only reward (no per-step env rewards), chunk-level advantages still equal trajectory-level broadcast. The real benefit is chunk-level IS masking — selectively dropping turns where policy has drifted, rather than all-or-nothing. See: [ROLL Team IPA paper](https://arxiv.org/pdf/2512.24873). + +### Contiguity No Longer Required + +The trainer's advantage broadcast was updated to use trajectory-id-based mapping instead of the `cumsum(shifted_is_last_step)` trick. Steps from the same trajectory no longer need to be adjacent in the batch. See `feature/stepwise-traj-id-broadcast` branch in `/home/ray/default/SkyRL-stepwise-validation`. + +### Inspection Scripts + +All in `examples/train_integrations/harbor/`: + +| Script | What it shows | +|--------|--------------| +| `inspect_trainer_dataflow.py` | Full trainer pipeline stages 0-5 with formulas, non-step-wise multi-turn | +| `inspect_stepwise_dataflow.py` | Harbor-style step-wise generator output with dummy data | +| `inspect_stepwise_skyrl_gym.py` | Actual SkyRLGymGenerator step-wise with mocked LLM/env | +| `inspect_stepwise_vs_nonstepwise.py` | Side-by-side comparison: same data through both paths, all stages | +| `TRAIN_POLICY_CALL_CHAIN.md` | Full call chain for the training phase | +` diff --git a/examples/train_integrations/harbor/harbor_generator.py b/examples/train_integrations/harbor/harbor_generator.py index b6b6279475..175b5ce670 100644 --- a/examples/train_integrations/harbor/harbor_generator.py +++ b/examples/train_integrations/harbor/harbor_generator.py @@ -1,7 +1,7 @@ import asyncio from copy import deepcopy -from dataclasses import dataclass -from typing import List, Optional +from dataclasses import dataclass, field +from typing import List, Optional, Union from loguru import logger from uuid import uuid4 from skyrl.train.generators.base import GeneratorInterface, GeneratorInput, GeneratorOutput, TrajectoryID @@ -10,9 +10,10 @@ from skyrl.backends.skyrl_train.inference_engines.base import ConversationType from skyrl.train.utils.rate_limiter import create_rate_limiter from tqdm import tqdm -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig from harbor.trial.trial import Trial from harbor.models.trial.config import TrialConfig +from harbor.models.trial.result import TrialResult # Suppress LiteLLM verbose logging @@ -31,13 +32,28 @@ @dataclass class HarborAgentOutput: response_ids: List[int] - reward: float + reward: Union[float, List[float]] stop_reason: str loss_mask: List[int] prompt_ids: List[int] trajectory_id: TrajectoryID summarization_count: Optional[int] = None num_turns: Optional[int] = None + rollout_logprobs: Optional[List[float]] = None + + +@dataclass +class HarborStepWiseOutput: + """Step-wise output from a single Harbor trajectory. + + Each step_output corresponds to one agent turn (LLM call), + using the exact token IDs and logprobs from vLLM (no re-tokenization). + """ + + step_outputs: List[HarborAgentOutput] = field(default_factory=list) + trajectory_id: Optional[TrajectoryID] = None + summarization_count: Optional[int] = None + num_turns: Optional[int] = None class HarborGenerator(GeneratorInterface): @@ -62,6 +78,7 @@ def __init__( self.generator_cfg = generator_cfg self.tokenizer = tokenizer self.max_seq_len = max_seq_len + self.step_wise = getattr(generator_cfg, "step_wise_trajectories", False) # Harbor config template - users can specify any Harbor TrialConfig options in YAML or command line. # SkyRL injects: model_name and api_base (once at init), task.path and session_id (per trial) @@ -77,10 +94,29 @@ def __init__( ] = f"hosted_vllm/{ie_cfg.served_model_name}" self._harbor_trial_config_template["agent"].setdefault("kwargs", {})["api_base"] = f"{self.base_url}/v1" + # Config post-processings + agent_kwargs = self._harbor_trial_config_template["agent"].get("kwargs", {}) + + # Summarization is not supproted yet + if agent_kwargs.get("enable_summarize", False): + raise ValueError( + "step_wise_trajectories=True is incompatible with enable_summarize=True. " + "Summarization invalidates rollout_details. Set enable_summarize=false." + ) + + # Step-wise training requires collect_rollout_details to get per-turn token IDs and logprobs + if self.step_wise and not agent_kwargs.get("collect_rollout_details", False): + logger.warning( + "step_wise_trajectories=True but collect_rollout_details is not enabled in Harbor config. " + "Enabling it automatically." + ) + self._harbor_trial_config_template["agent"]["kwargs"]["collect_rollout_details"] = True + logger.info( f"HarborGenerator initialized with Harbor config. " f"Agent: {self._harbor_trial_config_template.get('agent', {}).get('name')}, " - f"Trials dir: {self._harbor_trial_config_template.get('trials_dir', 'trials')}" + f"Trials dir: {self._harbor_trial_config_template.get('trials_dir', 'trials')}, " + f"Step-wise: {self.step_wise}" ) # Read custom chat template @@ -107,7 +143,7 @@ async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput: f"Prompt count ({len(prompts)}) doesn't match " f"trajectory_ids count ({len(trajectory_ids)})" ) - all_outputs: List[HarborAgentOutput] = [None] * len(prompts) # type: ignore[list-item] + all_outputs: List[Union[HarborAgentOutput, HarborStepWiseOutput]] = [None] * len(prompts) # type: ignore[list-item] progress = tqdm( total=len(prompts), desc="Generating Trajectories", @@ -115,8 +151,25 @@ async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput: mininterval=5, ) + def _make_error_output(trajectory_id): + return HarborAgentOutput( + response_ids=[0], + reward=0, + stop_reason="error", + loss_mask=[0], + prompt_ids=[0], + trajectory_id=trajectory_id, + ) + async def _worker(idx, prompt, trajectory_id): - result = await self.harbor_agent_loop(prompt=prompt, trajectory_id=trajectory_id) + try: + result = await self.harbor_agent_loop(prompt=prompt, trajectory_id=trajectory_id) + except BaseException as e: + logger.warning( + f"Trajectory {trajectory_id} raised unhandled exception in harbor_agent_loop: {type(e).__name__}: {e}. " + "Returning zeroed-out output." + ) + result = _make_error_output(trajectory_id) all_outputs[idx] = result progress.update(1) @@ -124,18 +177,179 @@ async def _worker(idx, prompt, trajectory_id): async with asyncio.TaskGroup() as tg: for idx, (prompt, trajectory_id) in enumerate(zip(prompts, trajectory_ids)): tg.create_task(_worker(idx, prompt, trajectory_id)) + except BaseException as e: + logger.warning(f"TaskGroup raised {type(e).__name__}: {e}. Some workers may have failed.") finally: progress.close() - all_outputs, rollout_metrics = self._mask_failed_instances_and_compute_metrics(all_outputs) + + # Safety: fill any remaining None entries (e.g. from cancelled tasks) + for idx in range(len(all_outputs)): + if all_outputs[idx] is None: + all_outputs[idx] = _make_error_output(trajectory_ids[idx]) + + if self.step_wise: + return self._build_step_wise_generator_output(all_outputs) + else: + all_outputs, rollout_metrics = self._mask_failed_instances_and_compute_metrics(all_outputs) + generator_output: GeneratorOutput = { + "prompt_token_ids": [output.prompt_ids for output in all_outputs], + "response_ids": [output.response_ids for output in all_outputs], + "rewards": [output.reward for output in all_outputs], + "loss_masks": [output.loss_mask for output in all_outputs], + "stop_reasons": [output.stop_reason for output in all_outputs], + "rollout_metrics": rollout_metrics, + "rollout_logprobs": None, + } + return generator_output + + @staticmethod + def _identify_masked_instances( + all_outputs: List[Union[HarborAgentOutput, HarborStepWiseOutput]], + ) -> tuple[set[str], int, int, set[str]]: + """Identify instances that should be masked (zeroed out) due to failures. + + For a group of trajectories (n_samples_per_prompt for the same prompt), + if one trajectory fails we skip training the entire group. + + Returns: + (masked_instance_ids, num_timeout_trajectories, num_error_trajectories, all_instance_ids) + """ + timeout_instance_ids: set[str] = set() + error_instance_ids: set[str] = set() + all_instance_ids: set[str] = set() + num_timeout = 0 + num_error = 0 + + for output in all_outputs: + instance_id = output.trajectory_id.instance_id + all_instance_ids.add(instance_id) + + # For HarborStepWiseOutput, check the last step's stop reason + if isinstance(output, HarborStepWiseOutput): + stop = output.step_outputs[-1].stop_reason if output.step_outputs else "error" + else: + stop = output.stop_reason + + if stop == "agent_timeout": + num_timeout += 1 + timeout_instance_ids.add(instance_id) + elif stop == "error": + num_error += 1 + error_instance_ids.add(instance_id) + + return timeout_instance_ids | error_instance_ids, num_timeout, num_error, all_instance_ids + + def _build_step_wise_generator_output( + self, + all_outputs: List[Union[HarborAgentOutput, HarborStepWiseOutput]], + ) -> GeneratorOutput: + """Flatten step-wise outputs into the GeneratorOutput format. + + Each multi-turn trajectory becomes N separate (prompt, response) samples, + with `is_last_step` marking the final step of each trajectory. + """ + masked_ids, num_timeout, num_error, all_ids = self._identify_masked_instances(all_outputs) + + responses = [] + rewards = [] + stop_reasons = [] + loss_masks = [] + prompt_token_ids = [] + is_last_step_list = [] + out_trajectory_ids = [] + rollout_logprobs_list = [] + successful_last_step_outputs = [] + + for output in all_outputs: + tid = output.trajectory_id + + if tid.instance_id in masked_ids: + # Emit a single zeroed-out step for masked instances + prompt_token_ids.append([0]) + responses.append([0]) + rewards.append([0.0]) + stop_reasons.append("error") + loss_masks.append([0]) + is_last_step_list.append(True) + out_trajectory_ids.append(tid) + rollout_logprobs_list.append([0.0]) + continue + + # Non-masked outputs in step-wise mode are always HarborStepWiseOutput + # (failures return HarborAgentOutput which gets masked above) + assert isinstance(output, HarborStepWiseOutput), ( + f"Expected HarborStepWiseOutput for non-masked trajectory {tid}, got {type(output).__name__}" + ) + for j, step in enumerate(output.step_outputs): + is_last = j == len(output.step_outputs) - 1 + prompt_token_ids.append(step.prompt_ids) + responses.append(step.response_ids) + rewards.append(step.reward) + stop_reasons.append(step.stop_reason) + loss_masks.append(step.loss_mask) + is_last_step_list.append(is_last) + out_trajectory_ids.append(tid) + rollout_logprobs_list.append(step.rollout_logprobs) + if is_last: + successful_last_step_outputs.append(step) + + # Compute rollout metrics from successful last-step outputs + if successful_last_step_outputs: + metric_rewards = [] + for o in successful_last_step_outputs: + metric_rewards.append(sum(o.reward) if isinstance(o.reward, list) else o.reward) + rollout_metrics = get_rollout_metrics( + [o.response_ids for o in successful_last_step_outputs], + metric_rewards, + ) + # Harbor-specific metrics from trajectory-level outputs + summarization_counts = [] + num_turns_list = [] + context_exceeded = 0 + for output in all_outputs: + if isinstance(output, HarborStepWiseOutput) and output.trajectory_id.instance_id not in masked_ids: + if output.summarization_count is not None: + summarization_counts.append(output.summarization_count) + if output.num_turns is not None: + num_turns_list.append(output.num_turns) + if output.step_outputs and output.step_outputs[-1].stop_reason == "context_length": + context_exceeded += 1 + rollout_metrics["generate/trajectories_summarized"] = sum(1 for c in summarization_counts if c > 0) + rollout_metrics["generate/trajectories_context_length_exceeded"] = context_exceeded + if num_turns_list: + rollout_metrics["generate/avg_num_turns"] = sum(num_turns_list) / len(num_turns_list) + else: + rollout_metrics = {} + + rollout_metrics["generate/num_timeout_trajectories"] = num_timeout + rollout_metrics["generate/num_error_trajectories"] = num_error + rollout_metrics["generate/num_masked_instances"] = len(masked_ids) + + logger.info( + f"\n# of masked instances: {len(masked_ids)} / {len(all_ids)}\n" + f"# of timeout trajectories: {num_timeout}\n" + f"# of error trajectories: {num_error}\n" + f"# of flattened step-samples: {len(responses)}" + ) + + # Check if any rollout_logprobs are available (non-zero-length lists) + has_logprobs = any(lp is not None and len(lp) > 0 for lp in rollout_logprobs_list) + if has_logprobs: + # Ensure all entries are lists (replace None with zero-filled lists matching response length) + for i, lp in enumerate(rollout_logprobs_list): + if lp is None: + rollout_logprobs_list[i] = [0.0] * len(responses[i]) generator_output: GeneratorOutput = { - "prompt_token_ids": [output.prompt_ids for output in all_outputs], - "response_ids": [output.response_ids for output in all_outputs], - "rewards": [output.reward for output in all_outputs], - "loss_masks": [output.loss_mask for output in all_outputs], - "stop_reasons": [output.stop_reason for output in all_outputs], + "prompt_token_ids": prompt_token_ids, + "response_ids": responses, + "rewards": rewards, + "loss_masks": loss_masks, + "stop_reasons": stop_reasons, "rollout_metrics": rollout_metrics, - "rollout_logprobs": None, + "rollout_logprobs": rollout_logprobs_list if has_logprobs else None, + "is_last_step": is_last_step_list, + "trajectory_ids": out_trajectory_ids, } return generator_output @@ -153,28 +367,12 @@ def _mask_failed_instances_and_compute_metrics( all_outputs: The same list, with failed-instance outputs zeroed out. rollout_metrics: Dict of rollout metrics for logging. """ - # Count failures by type before grouping overwrites stop_reason. - num_timeout_trajectories = 0 - num_error_trajectories = 0 - timeout_instance_ids = set() - error_instance_ids = set() - all_instance_ids = set() - for output in all_outputs: - cur_instance_id = output.trajectory_id.instance_id - all_instance_ids.add(cur_instance_id) - if output.stop_reason == "agent_timeout": - num_timeout_trajectories += 1 - timeout_instance_ids.add(cur_instance_id) - elif output.stop_reason == "error": - num_error_trajectories += 1 - error_instance_ids.add(cur_instance_id) - - masked_instance_ids = timeout_instance_ids | error_instance_ids + masked_ids, num_timeout, num_error, all_ids = HarborGenerator._identify_masked_instances(all_outputs) # Zero-out all outputs belonging to any timeout or error instance so we skip training on them. successful_outputs: List[HarborAgentOutput] = [] for output in all_outputs: - if output.trajectory_id.instance_id in masked_instance_ids: + if output.trajectory_id.instance_id in masked_ids: output.response_ids = [0] output.stop_reason = "error" output.loss_mask = [0] @@ -202,14 +400,14 @@ def _mask_failed_instances_and_compute_metrics( rollout_metrics = {} # Failure metrics: timeout vs unknown error trajectories, and masked instances. - rollout_metrics["generate/num_timeout_trajectories"] = num_timeout_trajectories - rollout_metrics["generate/num_error_trajectories"] = num_error_trajectories - rollout_metrics["generate/num_masked_instances"] = len(masked_instance_ids) + rollout_metrics["generate/num_timeout_trajectories"] = num_timeout + rollout_metrics["generate/num_error_trajectories"] = num_error + rollout_metrics["generate/num_masked_instances"] = len(masked_ids) logger.info( - f"\n# of masked instances: {len(masked_instance_ids)} / {len(all_instance_ids)}\n" - f"# of timeout trajectories: {num_timeout_trajectories}\n" - f"# of error trajectories: {num_error_trajectories}" + f"\n# of masked instances: {len(masked_ids)} / {len(all_ids)}\n" + f"# of timeout trajectories: {num_timeout}\n" + f"# of error trajectories: {num_error}" ) return all_outputs, rollout_metrics @@ -218,9 +416,12 @@ async def harbor_agent_loop( self, prompt: ConversationType, trajectory_id: TrajectoryID, - ) -> HarborAgentOutput: + ) -> Union[HarborAgentOutput, HarborStepWiseOutput]: """ Run a single harbor agent. + + Returns HarborStepWiseOutput when step_wise_trajectories=True, + HarborAgentOutput otherwise (or on failure). """ # Run the trial to get `reward`, `chat_history`, `summarization_count`, and `num_turns` reward = None @@ -230,6 +431,7 @@ async def harbor_agent_loop( successful = False is_context_length_error = False is_agent_timeout_error = False + results: Optional[TrialResult] = None for i in range(MAX_NUM_RETRIES_PER_TRIAL): prefix = f"Trajectory {trajectory_id} attempt {i+1}/{MAX_NUM_RETRIES_PER_TRIAL}" results = None @@ -271,13 +473,25 @@ async def harbor_agent_loop( chat_history = results.agent_result.metadata["all_messages"] summarization_count = results.agent_result.metadata["summarization_count"] num_turns = results.agent_result.metadata["n_episodes"] - if len(chat_history) > 1 and chat_history[0]["role"] == "user": - successful = True - logger.debug(f"{prefix} successful: reward={reward}. Results: {results}") - break + if self.step_wise: + # For step-wise, success is defined as having rollout_details + if results.agent_result.rollout_details is not None and len(results.agent_result.rollout_details[0].get("prompt_token_ids", [])) > 0: + successful = True + logger.debug(f"{prefix} successful: reward={reward}. Results: {results}") + break + else: + logger.warning( + f"{prefix} failed: No rollout_details (or empty one). Results: {results}" + ) else: - logger.warning( - f"{prefix} failed: Did not return a chat history with a user message. chat_history: {chat_history}\nResults: {results}" + # For non-step-wise, success is defined as having a chat history starting with a user message + if len(chat_history) > 1 and chat_history[0]["role"] == "user": + successful = True + logger.debug(f"{prefix} successful: reward={reward}. Results: {results}") + break + else: + logger.warning( + f"{prefix} failed: Did not return a chat history with a user message. chat_history: {chat_history}\nResults: {results}" ) except Exception as e: logger.warning(f"{prefix} failed: Error running trial: {e}. Results: {results}") @@ -299,6 +513,18 @@ async def harbor_agent_loop( trajectory_id=trajectory_id, ) + # --- Step-wise path: use rollout_details from Harbor --- + if self.step_wise: + return self._build_step_wise_output( + results=results, + reward=reward, + trajectory_id=trajectory_id, + summarization_count=summarization_count, + num_turns=num_turns, + is_context_length_error=is_context_length_error, + ) + + # --- Non-step-wise path: re-tokenize chat history (original behavior) --- # Use the first message as the prompt. We assume to be no systems messages. assert chat_history[0]["role"] == "user", "The first message should be a user message" prompt = [chat_history[0]] @@ -346,3 +572,94 @@ async def harbor_agent_loop( summarization_count=summarization_count, num_turns=num_turns, ) + + def _build_step_wise_output( + self, + results: TrialResult, + reward: float, + trajectory_id: TrajectoryID, + summarization_count: Optional[int], + num_turns: Optional[int], + is_context_length_error: bool, + ) -> HarborStepWiseOutput: + """Build a HarborStepWiseOutput from Harbor trial results using rollout_details. + + Uses the exact per-turn token IDs and logprobs from vLLM (no re-tokenization). + This avoids retokenization drift and enables correct TIS computation. + """ + # 1. Extract needed information + rollout_details_list = results.agent_result.rollout_details + assert rollout_details_list is not None, "rollout_details_list is required for step-wise training" + + # Use the first (main) rollout detail — this is the main agent's conversation. + # Additional entries (index 1+) are subagent rollout details (e.g., summarization). + main_rollout = rollout_details_list[0] + prompt_token_ids_per_turn = main_rollout.get("prompt_token_ids", []) + completion_token_ids_per_turn = main_rollout.get("completion_token_ids", []) + logprobs_per_turn = main_rollout.get("logprobs", []) + n_turns = len(completion_token_ids_per_turn) + + # 2. Validate the data + # Assert no summarization occurred (not supported yet) + if len(rollout_details_list) > 1: + assert summarization_count == 0, ( + f"Trajectory {trajectory_id}: step_wise_trajectories=True but summarization occurred " + f"({summarization_count} summarizations, {len(rollout_details_list)} rollout detail segments). " + f"This is not supported. Set enable_summarize=false." + ) + assert n_turns > 0, "rollout_details has no completion turns" + assert len(prompt_token_ids_per_turn) == n_turns and len(logprobs_per_turn) == n_turns, ( + f"Trajectory {trajectory_id}: Expect prompt_token_ids, completion_token_ids, and " + f"logprobs to have the same length, but respectively got: {len(prompt_token_ids_per_turn)} turns, ", + f"{len(completion_token_ids_per_turn)} turns, and {len(logprobs_per_turn)} turns." + ) + for turn_idx in range(n_turns): + assert len(logprobs_per_turn[turn_idx]) == len(completion_token_ids_per_turn[turn_idx]), ( + f"Trajectory {trajectory_id} turn {turn_idx}: " + f"logprobs length ({len(logprobs_per_turn[turn_idx])}) != " + f"completion_ids length ({len(completion_token_ids_per_turn[turn_idx])}). " + ) + + step_outputs = [] + for turn_idx in range(n_turns): + completion_ids = completion_token_ids_per_turn[turn_idx] + turn_prompt_ids = prompt_token_ids_per_turn[turn_idx] + turn_logprobs = logprobs_per_turn[turn_idx] + is_last = turn_idx == n_turns - 1 + + # Per-token reward: zeros for all but the last token of the LAST step + turn_reward = [0.0] * len(completion_ids) + if is_last and len(turn_reward) > 0: + turn_reward[-1] = float(reward) + + # Determine stop reason for this step + if is_context_length_error: + turn_stop_reason = "context_length" + else: + turn_stop_reason = "complete" + + # Loss mask: all completion tokens are trainable (they are the model's generation) + turn_loss_mask = [1] * len(completion_ids) + # Apply overlong filtering. + if self.generator_cfg.apply_overlong_filtering and turn_stop_reason == "context_length": + turn_loss_mask = [0] * len(completion_ids) + + step_output = HarborAgentOutput( + response_ids=completion_ids, + reward=turn_reward, + stop_reason=turn_stop_reason, + loss_mask=turn_loss_mask, + prompt_ids=turn_prompt_ids, + trajectory_id=trajectory_id, + rollout_logprobs=turn_logprobs, + summarization_count=summarization_count if is_last else 0, + num_turns=num_turns if is_last else 0, + ) + step_outputs.append(step_output) + + return HarborStepWiseOutput( + step_outputs=step_outputs, + trajectory_id=trajectory_id, + summarization_count=summarization_count, + num_turns=num_turns, + ) diff --git a/examples/train_integrations/harbor/harbor_trial_config/default.yaml b/examples/train_integrations/harbor/harbor_trial_config/default.yaml index e38d069294..0c7f23578b 100644 --- a/examples/train_integrations/harbor/harbor_trial_config/default.yaml +++ b/examples/train_integrations/harbor/harbor_trial_config/default.yaml @@ -47,6 +47,9 @@ agent: # Store all messages in the trial output (required for SkyRL training) store_all_messages: true + # Collect per-turn token IDs and logprobs from the LLM provider (required for step-wise training / TIS) + collect_rollout_details: true + # The only sampling param that directly gets passed to Terminus temperature: 1.0 diff --git a/examples/train_integrations/harbor/kill_daytona_sandboxes.py b/examples/train_integrations/harbor/kill_daytona_sandboxes.py new file mode 100644 index 0000000000..a612e657b8 --- /dev/null +++ b/examples/train_integrations/harbor/kill_daytona_sandboxes.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +"""Kill all Daytona sandboxes. Run between training iterations to clean up orphaned sandboxes.""" + +import asyncio +from daytona import AsyncDaytona + + +async def main(): + async with AsyncDaytona() as daytona: + page = await daytona.list() + sandboxes = page.items or [] + if not sandboxes: + print("No sandboxes found.") + return + print(f"Found {len(sandboxes)} sandbox(es) (page 1/{page.total_pages}). Deleting...") + deleted = 0 + for sb in sandboxes: + try: + await daytona.delete(sb) + deleted += 1 + except Exception as e: + print(f" Failed to delete sandbox {sb.id}: {e}") + # Handle additional pages + for p in range(2, (page.total_pages or 1) + 1): + next_page = await daytona.list() + for sb in (next_page.items or []): + try: + await daytona.delete(sb) + deleted += 1 + except Exception as e: + print(f" Failed to delete sandbox {sb.id}: {e}") + print(f"Done. Deleted {deleted} sandbox(es).") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/train_integrations/harbor/run_codecontest_comparison.sh b/examples/train_integrations/harbor/run_codecontest_comparison.sh new file mode 100755 index 0000000000..45c6691c77 --- /dev/null +++ b/examples/train_integrations/harbor/run_codecontest_comparison.sh @@ -0,0 +1,140 @@ +set -ex + +# Three-way comparison: baseline vs step-wise vs step-wise+TIS +# Usage: +# ./run_codecontest_comparison.sh baseline +# ./run_codecontest_comparison.sh stepwise +# ./run_codecontest_comparison.sh stepwise-tis + +MODE="${1:?Usage: $0 }" + +#----------------------- +# Dataset setup +#----------------------- +DATA_DIR="$HOME/data/harbor" +TRAIN_DATA="['$DATA_DIR/CodeContests']" + +#----------------------- +# Directory setup +#----------------------- +RUN_NAME="codecontest-${MODE}" +TRIALS_DIR="$HOME/$RUN_NAME/trials_run" +CKPTS_DIR="$HOME/$RUN_NAME/ckpts" +EXPORTS_DIR="$HOME/$RUN_NAME/exports" +LOG_DIR="/tmp/skyrl-logs/$RUN_NAME" + +#----------------------- +# Training setup +#----------------------- +MINI_BATCH_SIZE=32 +MAX_MODEL_LEN=32768 +APPLY_OVERLONG_FILTERING=true + +# Dr. GRPO parameters +LOSS_REDUCTION="seq_mean_token_sum_norm" +GRPO_NORM_BY_STD=false +USE_KL_LOSS=false + +# Chat template for interleaved thinking +CHAT_TEMPLATE_PATH="$(dirname "$0")/../../../skyrl/train/utils/templates/qwen3_acc_thinking.jinja2" + +#---------------- +# Infrastructure setup +#---------------- +NUM_GPUS=8 +ENABLE_RATE_LIMITING=true +TRAJECTORIES_PER_SECOND=5 +MAX_CONCURRENCY=512 + +# Mode-specific settings +STEP_WISE="false" +TIS_RATIO_TYPE="null" +N_SAMPLES=8 + +case "$MODE" in + baseline) + STEP_WISE="false" + TIS_RATIO_TYPE="null" + ;; + stepwise) + STEP_WISE="true" + TIS_RATIO_TYPE="null" + ;; + stepwise-tis) + STEP_WISE="true" + TIS_RATIO_TYPE="token" + ;; + *) + echo "Unknown mode: $MODE. Use: baseline, stepwise, or stepwise-tis" + exit 1 + ;; +esac + +echo "Running mode: $MODE (step_wise=$STEP_WISE, tis=$TIS_RATIO_TYPE)" + +# Prevent CUDA OOM fragmentation +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# Kill any lingering Daytona sandboxes first +echo "Cleaning up Daytona sandboxes..." +uv run --isolated --extra harbor python examples/train_integrations/harbor/kill_daytona_sandboxes.py || true + +# Run SkyRL command +uv run --isolated --extra fsdp --extra harbor -m examples.train_integrations.harbor.entrypoints.main_harbor \ + data.train_data=$TRAIN_DATA \ + trainer.policy.model.path=Qwen/Qwen3-8B \ + generator.inference_engine.served_model_name=Qwen3-8B \ + harbor_trial_config.trials_dir=$TRIALS_DIR \ + trainer.export_path=$EXPORTS_DIR \ + trainer.ckpt_path=$CKPTS_DIR \ + trainer.log_path=$LOG_DIR \ + trainer.algorithm.advantage_estimator=grpo \ + trainer.algorithm.loss_reduction=$LOSS_REDUCTION \ + trainer.algorithm.grpo_norm_by_std=$GRPO_NORM_BY_STD \ + trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ + trainer.placement.colocate_all=true \ + trainer.strategy=fsdp2 \ + trainer.placement.policy_num_nodes=1 \ + trainer.placement.ref_num_nodes=1 \ + trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ + trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \ + generator.inference_engine.num_engines=$NUM_GPUS \ + generator.inference_engine.tensor_parallel_size=1 \ + generator.inference_engine.engine_init_kwargs.chat_template=$CHAT_TEMPLATE_PATH \ + generator.inference_engine.engine_init_kwargs.max_model_len=$MAX_MODEL_LEN \ + generator.inference_engine.engine_init_kwargs.enable_log_requests=false \ + trainer.epochs=10 \ + trainer.eval_before_train=false \ + trainer.eval_interval=999999 \ + trainer.eval_batch_size=128 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=$MINI_BATCH_SIZE \ + trainer.policy_mini_batch_size=$MINI_BATCH_SIZE \ + trainer.micro_forward_batch_size_per_gpu=1 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.ckpt_interval=999999 \ + trainer.hf_save_interval=999999 \ + trainer.algorithm.max_seq_len=$MAX_MODEL_LEN \ + trainer.policy.optimizer_config.lr=1.0e-6 \ + generator.n_samples_per_prompt=$N_SAMPLES \ + generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \ + generator.inference_engine.gpu_memory_utilization=0.8 \ + generator.step_wise_trajectories=$STEP_WISE \ + trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_RATIO_TYPE \ + trainer.logger=wandb \ + trainer.project_name=harbor \ + trainer.run_name=$RUN_NAME \ + trainer.resume_mode=latest \ + generator.inference_engine.backend=vllm \ + generator.inference_engine.run_engines_locally=true \ + generator.inference_engine.weight_sync_backend=nccl \ + generator.inference_engine.async_engine=true \ + generator.batched=false \ + generator.inference_engine.enforce_eager=false \ + generator.inference_engine.enable_http_endpoint=true \ + generator.inference_engine.http_endpoint_host=127.0.0.1 \ + generator.inference_engine.http_endpoint_port=8000 \ + generator.rate_limit.enabled=$ENABLE_RATE_LIMITING \ + generator.rate_limit.trajectories_per_second=$TRAJECTORIES_PER_SECOND \ + generator.rate_limit.max_concurrency=$MAX_CONCURRENCY \ + "${@:2}" diff --git a/examples/train_integrations/harbor/run_codecontest_stepwise.sh b/examples/train_integrations/harbor/run_codecontest_stepwise.sh new file mode 100755 index 0000000000..c33b83dd8c --- /dev/null +++ b/examples/train_integrations/harbor/run_codecontest_stepwise.sh @@ -0,0 +1,104 @@ +set -ex + +# wandb api key. +# export WANDB_API_KEY=YOUR_KEY_HERE + +# Pick the sandbox provider and provide the credentials. +# export DAYTONA_API_KEY=YOUR_KEY_HERE + +#----------------------- +# Dataset setup +#----------------------- +DATA_DIR="$HOME/data/harbor" +TRAIN_DATA="['$DATA_DIR/CodeContests']" + +#----------------------- +# Directory setup +#----------------------- +RUN_NAME="codecontest-stepwise" +TRIALS_DIR="$HOME/$RUN_NAME/trials_run" +CKPTS_DIR="$HOME/$RUN_NAME/ckpts" +EXPORTS_DIR="/mnt/local_storage/$RUN_NAME/exports" +LOG_DIR="/tmp/skyrl-logs/$RUN_NAME" + +#----------------------- +# Training setup +#----------------------- +MINI_BATCH_SIZE=32 +MAX_MODEL_LEN=32768 +APPLY_OVERLONG_FILTERING=true + +# Dr. GRPO parameters +LOSS_REDUCTION="seq_mean_token_sum_norm" +GRPO_NORM_BY_STD=false +USE_KL_LOSS=false + +#---------------- +# Infrastructure setup +#---------------- +NUM_GPUS=8 +ENABLE_RATE_LIMITING=true +MAX_CONCURRENCY=500 + +# Reduce CUDA memory fragmentation — may help avoid OOM during weight sync +export PYTORCH_ALLOC_CONF=expandable_segments:True +# Prevent FSDP2 async collectives from pinning memory with record_stream +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 + +# Run SkyRL command +uv run --isolated --extra fsdp --extra harbor -m examples.train_integrations.harbor.entrypoints.main_harbor \ + data.train_data=$TRAIN_DATA \ + trainer.policy.model.path=Qwen/Qwen3-8B \ + generator.inference_engine.served_model_name=Qwen3-8B \ + harbor_trial_config.trials_dir=$TRIALS_DIR \ + trainer.export_path=$EXPORTS_DIR \ + trainer.ckpt_path=$CKPTS_DIR \ + trainer.log_path=$LOG_DIR \ + trainer.algorithm.advantage_estimator=grpo \ + trainer.algorithm.loss_reduction=$LOSS_REDUCTION \ + trainer.algorithm.grpo_norm_by_std=$GRPO_NORM_BY_STD \ + trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ + trainer.algorithm.policy_loss_type=dual_clip \ + trainer.placement.colocate_all=true \ + trainer.strategy=fsdp2 \ + trainer.placement.policy_num_nodes=1 \ + trainer.placement.ref_num_nodes=1 \ + trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ + trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \ + generator.inference_engine.num_engines=$NUM_GPUS \ + generator.inference_engine.tensor_parallel_size=1 \ + generator.inference_engine.engine_init_kwargs.max_model_len=$MAX_MODEL_LEN \ + generator.inference_engine.engine_init_kwargs.enable_log_requests=false \ + trainer.epochs=3 \ + trainer.eval_before_train=false \ + trainer.eval_interval=900 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=$MINI_BATCH_SIZE \ + trainer.policy_mini_batch_size=$MINI_BATCH_SIZE \ + trainer.micro_forward_batch_size_per_gpu=1 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.ckpt_interval=2 \ + trainer.hf_save_interval=2 \ + trainer.max_ckpts_to_keep=3 \ + trainer.algorithm.max_seq_len=$MAX_MODEL_LEN \ + trainer.policy.optimizer_config.lr=1.0e-6 \ + generator.n_samples_per_prompt=8 \ + generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \ + generator.inference_engine.gpu_memory_utilization=0.8 \ + trainer.logger=wandb \ + trainer.project_name=harbor \ + trainer.run_name=$RUN_NAME \ + trainer.resume_mode=latest \ + generator.inference_engine.backend=vllm \ + generator.inference_engine.run_engines_locally=true \ + generator.inference_engine.weight_sync_backend=nccl \ + generator.inference_engine.async_engine=true \ + generator.batched=false \ + generator.inference_engine.enforce_eager=false \ + generator.inference_engine.enable_http_endpoint=true \ + generator.inference_engine.http_endpoint_host=127.0.0.1 \ + generator.inference_engine.http_endpoint_port=8000 \ + generator.rate_limit.enabled=$ENABLE_RATE_LIMITING \ + generator.rate_limit.max_concurrency=$MAX_CONCURRENCY \ + generator.step_wise_trajectories=true \ + "$@" diff --git a/pyproject.toml b/pyproject.toml index fb8e501e7a..86f04ceb8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -250,7 +250,7 @@ torchvision = [ ] # pin megatron bridge commit to fix for MoE + LoRA merging. Update this when an official release is cut megatron-bridge = {git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge", rev = "02b5fccab5e5b21856d36c2e357839e0123b4b8f", marker = "sys_platform == 'linux'"} -harbor = { git = "https://github.com/laude-institute/harbor", rev = "8c040e1bb010201fd3c75bee3dede2407b9f57cd" } +harbor = { path = "/home/ray/default/harbor" } [tool.black] line-length = 120 diff --git a/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py index c30d0b7ee9..1344e6ab60 100644 --- a/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py +++ b/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py @@ -1,4 +1,5 @@ from skyrl.train.utils.trainer_utils import get_rope_scaling_config, get_rope_theta_config +import gc import ray import torch import torch.distributed @@ -237,6 +238,7 @@ async def broadcast_to_inference_engines(self, inference_engine_client, inferenc # clear prefix cache cache_reset_task = inference_engine_client.reset_prefix_cache() + gc.collect() torch.cuda.empty_cache() # Check if this is a LoRA model @@ -260,6 +262,7 @@ async def broadcast_to_inference_engines(self, inference_engine_client, inferenc if cache_reset_task is not None: await cache_reset_task + gc.collect() torch.cuda.empty_cache() torch.distributed.barrier() diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py index a17cea1b91..ef78b69d47 100644 --- a/skyrl/train/dataset/preprocess.py +++ b/skyrl/train/dataset/preprocess.py @@ -1,9 +1,13 @@ +import logging from typing import List, Tuple, Optional import torch from transformers import AutoTokenizer from jaxtyping import Float +logger = logging.getLogger(__name__) + + def _verify_inputs( prompts: List[List[int]], responses: List[List[int]], @@ -32,6 +36,7 @@ def convert_prompts_responses_to_batch_tensors( rewards: List[List[float]], loss_masks: List[List[int]], logprobs: Optional[List[List[float]]] = None, + max_seq_len: Optional[int] = None, ) -> Tuple[ Float[torch.Tensor, "batch seq_len"], Float[torch.Tensor, "batch seq_len"], @@ -43,12 +48,33 @@ def convert_prompts_responses_to_batch_tensors( """ Convert prompts and responses to batch tensors for training. - This function concatenates all prompts and responses to the following format: + Each sequence is laid out as a single left-padded block: + + | [PAD] [PAD] prompt prompt prompt respon respon | + | [PAD] prompt prompt prompt respon respon respon | + | prompt prompt prompt respon respon respon respon | + |<---- max_response_len ---->| + + The padded sequence length is ``max(prompt_len_i + response_len_i)``. + This way, the max padded sequence length is ``max_seq_len``. + + This makes the response-level tensors (action_mask, rewards, loss_masks, logprobs): + | prompt prompt respon respon | + | prompt respon respon respon | + | respon respon respon respon | + + So the action_mask is: + | 0 0 1 1 | + | 0 1 1 1 | + | 1 1 1 1 | - | [PAD] [PAD] token token token | token token [PAD] [PAD] | - | token token token token token | token token [PAD] [PAD] | - | [PAD] [PAD] [PAD] token token | token token token [PAD] | - |<---------- prompt ----------->|<-------- answer ------->| + Attention mask is 1 for all real tokens, 0 for padding. + Action mask is 1 for the last ``response_len_i`` positions, 0 for padding. + + Response-level tensors are **right-aligned** within ``(batch, max_response_len)``: non-padded + values occupy the last ``response_len_i`` positions, with leading zeros. This matches the model + forward pass which extracts ``log_probs[:, -num_actions-1:-1]`` —- response tokens are always at + the end of the sequence, so their logprobs are right-aligned in the slice. Assumes that the responses already contain an eos token at index -1. @@ -59,74 +85,76 @@ def convert_prompts_responses_to_batch_tensors( rewards: List of rewards for each response loss_masks: List of loss masks for each response logprobs: List of rollout log probs for each response + max_seq_len: Optional. If provided and ``max(prompt_i + response_i)`` + exceeds it, a warning is logged (no truncation is performed). Returns: - sequences: Full trajectories (padded and concatenated prompts and responses). Size: (batch, seq_len). - attention_mask: Attention mask for the model. Size: (batch, seq_len) - action_mask: Response mask for the model. Size: (batch, response_len) - rewards: Rewards for each output. Size: (batch, response_len) - loss_masks: Loss masks for each output. Size: (batch, response_len) + sequences: ``(batch, max_total)`` where ``max_total = max(prompt_i + response_i)``. + attention_mask: ``(batch, max_total)`` + action_mask: ``(batch, max_response)`` — right-aligned response indicator. + rewards: ``(batch, max_response)`` — right-aligned. + loss_masks: ``(batch, max_response)`` — right-aligned. + logprobs: ``(batch, max_response)`` — right-aligned, or ``None``. """ _verify_inputs(prompts, responses, rewards, loss_masks) - max_input_len, max_output_len = 0, 0 - prompt_token_lens, response_token_lens = [], [] - inputs_token_ids, outputs_token_ids = [], [] - for prompt, response in zip(prompts, responses): - - inputs_token_ids.append(prompt) - outputs_token_ids.append(response) + prompt_token_lens = [len(p) for p in prompts] + response_token_lens = [len(r) for r in responses] - prompt_token_len = len(prompt) - response_token_len = len(response) - prompt_token_lens.append(prompt_token_len) - response_token_lens.append(response_token_len) + max_response = max(response_token_lens) + # Pad to the tightest bound: max per-sample total. + max_total = max(p + r for p, r in zip(prompt_token_lens, response_token_lens)) - max_input_len = max(max_input_len, prompt_token_len) - max_output_len = max(max_output_len, response_token_len) + if max_seq_len is not None and max_total > max_seq_len: + logger.warning( + f"Max sequence length in batch ({max_total}) exceeds max_seq_len ({max_seq_len}). " + f"No truncation is performed; consider checking generator settings." + ) pad_token_id = tokenizer.pad_token_id sequences = [] attention_masks = [] action_masks = [] - for i, prompt in enumerate(prompts): - # left padding input - input_len = prompt_token_lens[i] - input_ids = [pad_token_id] * (max_input_len - input_len) + list(inputs_token_ids[i]) - input_attention_mask = [0] * (max_input_len - input_len) + [1] * input_len - - # right padding output - output_len = response_token_lens[i] - output_ids = list(outputs_token_ids[i]) + [pad_token_id] * (max_output_len - output_len) - output_attention_mask = [1] * output_len + [0] * (max_output_len - output_len) - - # concat input and output - sequences.append(input_ids + output_ids) - attention_masks.append(input_attention_mask + output_attention_mask) - action_masks.append(output_attention_mask) + for i in range(len(prompts)): + total_real = prompt_token_lens[i] + response_token_lens[i] + pad_len = max_total - total_real + + # Unified left-pad: [PAD ... PAD PROMPT RESPONSE] + seq = [pad_token_id] * pad_len + prompts[i] + responses[i] + attention_mask_i = [0] * pad_len + [1] * total_real + + # Response indicator within the last max_response positions (right-aligned). + resp_pad = max_response - response_token_lens[i] + action_mask_i = [0] * resp_pad + [1] * response_token_lens[i] + + sequences.append(seq) + attention_masks.append(attention_mask_i) + action_masks.append(action_mask_i) sequences = torch.tensor(sequences) attention_mask = torch.tensor(attention_masks, dtype=torch.int64) action_mask = torch.tensor(action_masks, dtype=torch.int64) - # initialize ret loss masks to be the same as action mask - ret_loss_masks = torch.zeros_like(action_mask, dtype=torch.float) - for i, loss_mask in enumerate(loss_masks): - ret_loss_masks[i, : len(loss_mask)] = torch.tensor(loss_mask) + # Response-level tensors are RIGHT-ALIGNED to match the model output. + # The model's log_probs[:, -num_actions-1:-1] returns logprobs where + # response tokens occupy the last response_len_i positions. + ret_loss_masks = torch.zeros(len(prompts), max_response, dtype=torch.float) + for i, lm in enumerate(loss_masks): + ret_loss_masks[i, max_response - len(lm) :] = torch.tensor(lm, dtype=torch.float) - # do the same for custom rewards - ret_rewards = torch.zeros_like(action_mask, dtype=torch.float) + # Same thing for rewards. + ret_rewards = torch.zeros(len(prompts), max_response, dtype=torch.float) for i, custom_reward in enumerate(rewards): if isinstance(custom_reward, list): custom_reward = torch.tensor(custom_reward) - ret_rewards[i, : len(custom_reward)] = custom_reward + ret_rewards[i, max_response - len(custom_reward) :] = custom_reward + # Same thing for logprobs. logprobs_tensor = None if logprobs: - max_output_len = action_mask.size(1) - padded_logprobs = [ - sample_logprobs + [0.0] * (max_output_len - len(sample_logprobs)) for sample_logprobs in logprobs - ] - logprobs_tensor = torch.tensor(padded_logprobs, dtype=torch.float) + logprobs_tensor = torch.zeros(len(prompts), max_response, dtype=torch.float) + for i, sample_logprobs in enumerate(logprobs): + lp = torch.tensor(sample_logprobs, dtype=torch.float) + logprobs_tensor[i, max_response - len(sample_logprobs) :] = lp return sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, logprobs_tensor diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 51541bfc77..fc832acbeb 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -1,4 +1,5 @@ import copy +import gc import math import os import shutil @@ -256,7 +257,6 @@ async def train(self): # 3. Convert GeneratorOutput to TrainingInputBatch with Timer("convert_to_training_input", self.all_timings): training_input: TrainingInputBatch = self.convert_to_training_input(generator_output, uids) - logger.info(f"Number of sequences: {len(training_input['sequences'])}") # 4. Inference and calculate values, log probs, rewards, kl divergence with Timer("fwd_logprobs_values_reward", self.all_timings): @@ -620,6 +620,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis rewards, loss_masks, logprobs, + max_seq_len=self.cfg.trainer.algorithm.max_seq_len, ) # sanity check for off_policy_correction @@ -650,6 +651,14 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis training_input.metadata = {"uids": uids} # padded response length training_input.metadata["response_length"] = response_masks_tensor.shape[1] + batch_num_seq, batch_padded_seq_len = sequences_tensor.shape + logger.info(f"batch_num_seq: {batch_num_seq}, batch_padded_seq_len: {batch_padded_seq_len}") + self.all_metrics.update( + { + "generate/batch_num_seq": batch_num_seq, + "generate/batch_padded_seq_len": batch_padded_seq_len, + } + ) if self.cfg.generator.step_wise_trajectories: assert ( "trajectory_ids" in generator_output @@ -693,8 +702,11 @@ async def generate( if generator_output["rollout_metrics"] is not None: self.all_metrics.update(generator_output["rollout_metrics"]) - if not self.cfg.generator.step_wise_trajectories: - validate_generator_output(len(input_batch["prompts"]), generator_output) + validate_generator_output( + len(input_batch["prompts"]), + generator_output, + step_wise=self.cfg.generator.step_wise_trajectories, + ) return generator_output @@ -944,6 +956,7 @@ def fwd_logprobs_values_reward( if self.ref_model is not None: ref_output = self.dispatch.forward("ref", data_fwd_pass) base_log_probs = ref_output["output"] + gc.collect() self.dispatch.empty_cache("ref") # Policy forward @@ -951,6 +964,7 @@ def fwd_logprobs_values_reward( action_log_probs = policy_output["output"] # Empty cache after all forward passes + gc.collect() self.dispatch.empty_cache() sequences_all: torch.Tensor = training_input["sequences"] @@ -1113,6 +1127,7 @@ def train_critic_and_policy(self, data: TrainingInputBatch): for k, v in policy_status.items(): self.all_metrics.update({f"policy/{k}": v}) + gc.collect() self.dispatch.empty_cache() return policy_status diff --git a/skyrl/train/utils/trainer_utils.py b/skyrl/train/utils/trainer_utils.py index 567c2b0ba3..1c54393617 100644 --- a/skyrl/train/utils/trainer_utils.py +++ b/skyrl/train/utils/trainer_utils.py @@ -589,20 +589,25 @@ def zero_variance_filter(rewards: List[float], uids: List[str]) -> List[int]: return [i for i, uid in enumerate(uids) if uid in kept_uids_set] -def validate_generator_output(num_prompts: int, generator_output: GeneratorOutput): +def validate_generator_output(num_prompts: int, generator_output: GeneratorOutput, step_wise: bool = False): """Validate the generator output. Args: num_prompts: Number of input prompts used to produce this output. generator_output: The generated output batch to validate. + step_wise: If True, validate step-wise specific fields (is_last_step, trajectory_ids, + contiguous ordering). In step-wise mode, num_responses may exceed num_prompts + because each trajectory is expanded into multiple per-turn samples. """ if len(generator_output["response_ids"]) <= 0: raise RuntimeError("No outputs generated") - # check that input prompts, response ids, and prompt token ids are all the same length num_responses = len(generator_output["response_ids"]) num_prompt_tokens = len(generator_output["prompt_token_ids"]) - assert num_prompts == num_responses, f"Mismatch between prompts ({num_prompts}) and responses ({num_responses})" + + if not step_wise: + assert num_prompts == num_responses, f"Mismatch between prompts ({num_prompts}) and responses ({num_responses})" + assert ( num_responses == num_prompt_tokens ), f"Mismatch between responses ({num_responses}) and prompt_token_ids ({num_prompt_tokens})" @@ -656,6 +661,72 @@ def validate_generator_output(num_prompts: int, generator_output: GeneratorOutpu not isinstance(reward, list) for reward in rewards ), "rewards must be `List[float]` or `List[List[float]]`" + if step_wise: + _validate_step_wise_fields(generator_output, num_responses) + + +def _validate_step_wise_fields(generator_output: GeneratorOutput, num_responses: int): + """Validate step-wise specific fields in the generator output. + + Checks that is_last_step and trajectory_ids are present, correctly sized, + contiguously ordered, and that is_last_step boundaries align with trajectory_id changes. + + The contiguity check is critical: the trainer's advantage broadcast uses + ``cumsum(shifted_is_last_step)`` to map each step to its trajectory, which + silently produces wrong results if steps from the same trajectory are interleaved + with steps from other trajectories. + """ + assert generator_output.get("is_last_step") is not None, ( + "step_wise=True but `is_last_step` is missing from generator output" + ) + assert generator_output.get("trajectory_ids") is not None, ( + "step_wise=True but `trajectory_ids` is missing from generator output" + ) + + is_last_step = generator_output["is_last_step"] + trajectory_ids = generator_output["trajectory_ids"] + + assert len(is_last_step) == num_responses, ( + f"is_last_step length ({len(is_last_step)}) must equal response_ids length ({num_responses})" + ) + assert len(trajectory_ids) == num_responses, ( + f"trajectory_ids length ({len(trajectory_ids)}) must equal response_ids length ({num_responses})" + ) + + assert is_last_step[-1] is True, ( + "is_last_step[-1] must be True (the last sample must be the final step of a trajectory)" + ) + + num_trajectories = sum(1 for x in is_last_step if x) + assert num_trajectories >= 1, "is_last_step must contain at least one True value" + + # Validate contiguous ordering: all steps of the same trajectory must be adjacent. + seen_trajectory_ids = set() + prev_tid = None + for i, tid in enumerate(trajectory_ids): + tid_key = tid.to_string() if hasattr(tid, "to_string") else str(tid) + if tid_key != prev_tid: + assert tid_key not in seen_trajectory_ids, ( + f"Non-contiguous trajectory at index {i}: trajectory '{tid_key}' appeared before " + f"(at earlier indices), then a different trajectory, then again here. " + f"Step-wise training requires all steps of the same trajectory to be adjacent." + ) + if prev_tid is not None: + seen_trajectory_ids.add(prev_tid) + prev_tid = tid_key + if prev_tid is not None: + seen_trajectory_ids.add(prev_tid) + + # Validate is_last_step aligns with trajectory boundaries + for i in range(num_responses - 1): + tid_cur = trajectory_ids[i].to_string() if hasattr(trajectory_ids[i], "to_string") else str(trajectory_ids[i]) + tid_next = trajectory_ids[i + 1].to_string() if hasattr(trajectory_ids[i + 1], "to_string") else str(trajectory_ids[i + 1]) + if tid_cur != tid_next: + assert is_last_step[i] is True, ( + f"Trajectory boundary at index {i} ('{tid_cur}' → '{tid_next}') " + f"but is_last_step[{i}] is False. Must be True at trajectory boundaries." + ) + def build_dataloader( cfg: SkyRLTrainConfig, dataset: PromptDataset, is_train=True, is_fully_async=False diff --git a/test_vllm_sleep_wake.py b/test_vllm_sleep_wake.py new file mode 100644 index 0000000000..10c2a05f2a --- /dev/null +++ b/test_vllm_sleep_wake.py @@ -0,0 +1,139 @@ +""" +Minimal reproduction script for vLLM cudaErrorInvalidValue crash during sleep/wake cycles. + +Tests whether repeated sleep → wake_up(weights) → wake_up(kv_cache) → generate → sleep +causes the crash at flash_attn.py:484 (scheduler_metadata[:n] = scheduler_metadata). + +Usage: + CUDA_VISIBLE_DEVICES=0 uv run --isolated --extra fsdp python test_vllm_sleep_wake.py --cycles 15 + CUDA_VISIBLE_DEVICES=0 uv run --isolated --extra fsdp python test_vllm_sleep_wake.py --cycles 15 --with-weight-update +""" + +import argparse +import asyncio +import logging +import os +import sys +import time + +os.environ["VLLM_USE_V1"] = "1" + +import torch +import vllm +from vllm import LLM, SamplingParams + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +log = logging.getLogger("sleep_wake_repro") + + +def run_test(model: str, max_model_len: int, gpu_mem_util: float, + num_cycles: int, with_weight_update: bool): + """Test sleep/wake cycles with synchronous LLM API.""" + + log.info(f"vLLM version: {vllm.__version__}") + log.info(f"Config: model={model}, max_model_len={max_model_len}, " + f"gpu_mem={gpu_mem_util}, cycles={num_cycles}, " + f"with_weight_update={with_weight_update}") + + # Create engine + llm = LLM( + model=model, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_mem_util, + tensor_parallel_size=1, + enforce_eager=False, + enable_prefix_caching=True, + ) + log.info("Engine created successfully") + + sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=128) + test_prompts = [ + "Write a Python function that checks if a number is prime.", + "What is the capital of France? Explain briefly.", + "Solve: 2 + 2 * 3 = ?", + "Tell me a short joke about programming.", + ] + + for cycle in range(1, num_cycles + 1): + log.info(f"{'='*60}") + log.info(f"CYCLE {cycle}/{num_cycles}") + log.info(f"{'='*60}") + + try: + # Phase 1: Sleep (free GPU memory, simulates training phase occupying GPU) + log.info(f"[Cycle {cycle}] sleep(level=2)...") + t0 = time.time() + llm.sleep(level=2) + log.info(f"[Cycle {cycle}] sleep done in {time.time()-t0:.2f}s") + + # Simulate training phase (just wait a bit) + time.sleep(0.5) + + # Phase 2: Wake up weights + log.info(f"[Cycle {cycle}] wake_up(tags=['weights'])...") + t0 = time.time() + llm.wake_up(tags=["weights"]) + log.info(f"[Cycle {cycle}] wake_up(weights) done in {time.time()-t0:.2f}s") + + # Phase 3: Optionally simulate weight update (like SkyRL's broadcast_to_inference_engines) + if with_weight_update: + log.info(f"[Cycle {cycle}] Simulating weight update (noop load_state_dict)...") + # In real SkyRL, weights are updated via NCCL broadcast. + # Here we just touch the model to simulate the effect. + # This is a lightweight stand-in — the real weight sync uses + # the NCCL weight transfer sender. + pass + + # Phase 4: Wake up KV cache + log.info(f"[Cycle {cycle}] wake_up(tags=['kv_cache'])...") + t0 = time.time() + llm.wake_up(tags=["kv_cache"]) + log.info(f"[Cycle {cycle}] wake_up(kv_cache) done in {time.time()-t0:.2f}s") + + # Phase 5: Generate completions + log.info(f"[Cycle {cycle}] Generating {len(test_prompts)} completions...") + t0 = time.time() + outputs = llm.generate(test_prompts, sampling_params) + gen_time = time.time() - t0 + + for i, output in enumerate(outputs): + text = output.outputs[0].text[:60] + log.info(f" Prompt {i}: {text!r}") + + log.info(f"[Cycle {cycle}] Generation done in {gen_time:.2f}s") + log.info(f"[Cycle {cycle}] PASSED ✓") + + except Exception as e: + log.error(f"[Cycle {cycle}] CRASH: {type(e).__name__}: {e}") + import traceback + log.error(traceback.format_exc()) + log.error(f"RESULT: Crashed on cycle {cycle}/{num_cycles}") + return cycle + + log.info(f"RESULT: All {num_cycles} cycles completed without crash ✓") + return 0 + + +def main(): + parser = argparse.ArgumentParser(description="vLLM sleep/wake crash reproduction") + parser.add_argument("--model", default="Qwen/Qwen3-8B", help="Model name/path") + parser.add_argument("--max-model-len", type=int, default=32768) + parser.add_argument("--gpu-mem-util", type=float, default=0.8) + parser.add_argument("--cycles", type=int, default=15, help="Number of sleep/wake cycles") + parser.add_argument("--with-weight-update", action="store_true", + help="Simulate weight updates between sleep/wake") + args = parser.parse_args() + + crash_cycle = run_test( + args.model, args.max_model_len, args.gpu_mem_util, + args.cycles, args.with_weight_update, + ) + sys.exit(1 if crash_cycle else 0) + + +if __name__ == "__main__": + main() diff --git a/test_vllm_sleep_wake_with_weights.py b/test_vllm_sleep_wake_with_weights.py new file mode 100644 index 0000000000..b3c9d1e410 --- /dev/null +++ b/test_vllm_sleep_wake_with_weights.py @@ -0,0 +1,261 @@ +""" +Variant 2: Sleep/wake crash reproduction WITH simulated weight updates. + +This simulates the full SkyRL training loop: + sleep -> wake_up(weights) -> update_weights -> wake_up(kv_cache) -> generate -> sleep + +In vLLM 0.16.0, update_weights is available via collective_rpc or the engine's +load_weights/update_weights API. Here we use collective_rpc("update_weights") +to simulate a weight update by reloading the same checkpoint weights, which +exercises the same CUDA memory paths as a real training update. + +Usage: python test_vllm_sleep_wake_with_weights.py [--cycles 10] [--model Qwen/Qwen3-8B] +""" + +import argparse +import asyncio +import logging +import os +import sys +import time +import traceback + +os.environ["VLLM_USE_V1"] = "1" + +import torch +import vllm +from vllm import SamplingParams +from vllm.inputs import TokensPrompt + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +log = logging.getLogger("sleep_wake_weights_repro") + + +def create_engine(model: str, max_model_len: int, gpu_mem_util: float): + """Create an AsyncLLMEngine.""" + engine_args = vllm.AsyncEngineArgs( + model=model, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_mem_util, + tensor_parallel_size=1, + enforce_eager=False, + enable_prefix_caching=True, + enable_log_requests=False, + ) + engine = vllm.AsyncLLMEngine.from_engine_args(engine_args) + log.info(f"Engine created: vllm {vllm.__version__}, model={model}") + return engine + + +async def generate_completions(engine, num_prompts: int = 4, max_tokens: int = 64): + """Generate a small batch of completions.""" + sampling_params = SamplingParams( + temperature=0.7, + top_p=0.9, + max_tokens=max_tokens, + ) + + test_prompts = [ + [9707, 11, 1917, 0], + [791, 4059, 315], + [25, 220, 16, 489, 220, 16, 284], + [3923, 374, 279, 6864, 315, 9822, 30], + ] + + tasks = [] + for i in range(num_prompts): + prompt_tokens = test_prompts[i % len(test_prompts)] + request_id = f"req-{int(time.time_ns())}-{i}" + + async def collect(rid, tokens): + final = None + async for output in engine.generate( + prompt=TokensPrompt(prompt_token_ids=tokens), + sampling_params=sampling_params, + request_id=rid, + ): + final = output + return final + + tasks.append(asyncio.create_task(collect(request_id, prompt_tokens))) + + outputs = await asyncio.gather(*tasks, return_exceptions=True) + successes = 0 + for i, out in enumerate(outputs): + if isinstance(out, Exception): + log.error(f" Prompt {i} FAILED: {out}") + else: + text = out.outputs[0].text[:80] if out and out.outputs else "" + successes += 1 + log.info(f" Prompt {i}: {text!r}") + return successes + + +async def simulate_weight_update(engine, model_path: str, cycle: int): + """ + Simulate a weight update between sleep/wake cycles. + + Strategy 1 (preferred): Use engine.collective_rpc("update_weights") if available. + Strategy 2 (fallback): Load model state_dict and apply via collective_rpc("load_weights"). + Strategy 3 (simplest): Use the /update_weights endpoint pattern from vLLM 0.16.0. + + For reproduction purposes, we reload the SAME weights (no actual training), + which still exercises the CUDA memory allocation/deallocation paths. + """ + log.info(f"[Cycle {cycle}] Simulating weight update...") + t0 = time.time() + + try: + # vLLM 0.16.0 has update_weights via collective_rpc + # This reloads weights from the model path (same weights, but exercises the path) + await engine.collective_rpc( + "update_weights", + args=(model_path,), + ) + log.info(f"[Cycle {cycle}] Weight update via collective_rpc done in {time.time()-t0:.2f}s") + return True + except Exception as e1: + log.warning(f"[Cycle {cycle}] collective_rpc('update_weights') failed: {e1}") + + try: + # Fallback: try direct model reload via check_weights_changed API + # This is available in some vLLM versions + await engine.check_and_update_model(model_path) + log.info(f"[Cycle {cycle}] Weight update via check_and_update_model done in {time.time()-t0:.2f}s") + return True + except Exception as e2: + log.warning(f"[Cycle {cycle}] check_and_update_model failed: {e2}") + + try: + # Last resort: perturb a single weight tensor via collective_rpc + # This exercises the weight update CUDA path with minimal overhead + log.info(f"[Cycle {cycle}] Attempting weight perturbation via collective_rpc...") + + # Get a weight name from the model (use a small one like layernorm) + # We pass a dummy perturbation that workers can apply + await engine.collective_rpc( + "apply_weight_delta", + args=(), + ) + log.info(f"[Cycle {cycle}] Weight perturbation done in {time.time()-t0:.2f}s") + return True + except Exception as e3: + log.warning(f"[Cycle {cycle}] All weight update strategies failed: {e3}") + log.warning(f"[Cycle {cycle}] Continuing without weight update (testing sleep/wake only)") + return False + + +async def run_cycles_with_weights(engine, model: str, num_cycles: int, sleep_level: int = 2): + """ + Full training-loop simulation: + wake_up(weights) -> [weight update] -> wake_up(kv_cache) -> generate -> reset_prefix_cache -> sleep + """ + for cycle in range(1, num_cycles + 1): + log.info(f"{'='*60}") + log.info(f"CYCLE {cycle}/{num_cycles}") + log.info(f"{'='*60}") + + try: + # Phase 1: Wake up weights + log.info(f"[Cycle {cycle}] wake_up(tags=['weights'])...") + t0 = time.time() + await engine.wake_up(tags=["weights"]) + log.info(f"[Cycle {cycle}] wake_up(weights) done in {time.time()-t0:.2f}s") + + # Phase 2: Simulate weight update (this is what happens during RL training) + weight_updated = await simulate_weight_update(engine, model, cycle) + if weight_updated: + log.info(f"[Cycle {cycle}] Weights updated successfully") + else: + log.info(f"[Cycle {cycle}] Weight update skipped (testing sleep/wake path only)") + + # Phase 3: Wake up KV cache + log.info(f"[Cycle {cycle}] wake_up(tags=['kv_cache'])...") + t0 = time.time() + await engine.wake_up(tags=["kv_cache"]) + log.info(f"[Cycle {cycle}] wake_up(kv_cache) done in {time.time()-t0:.2f}s") + + # Phase 4: Generate + log.info(f"[Cycle {cycle}] Generating completions...") + t0 = time.time() + successes = await generate_completions(engine, num_prompts=4, max_tokens=64) + log.info(f"[Cycle {cycle}] Generation done in {time.time()-t0:.2f}s, " + f"{successes}/4 succeeded") + + # Phase 5: Reset prefix cache before sleep + log.info(f"[Cycle {cycle}] reset_prefix_cache()...") + await engine.reset_prefix_cache() + + # Phase 6: Sleep + log.info(f"[Cycle {cycle}] sleep(level={sleep_level})...") + t0 = time.time() + await engine.sleep(level=sleep_level) + log.info(f"[Cycle {cycle}] sleep done in {time.time()-t0:.2f}s") + + # Log GPU memory state + if torch.cuda.is_available(): + allocated = torch.cuda.memory_allocated() / 1e9 + reserved = torch.cuda.memory_reserved() / 1e9 + log.info(f"[Cycle {cycle}] GPU memory: allocated={allocated:.2f}GB, reserved={reserved:.2f}GB") + + log.info(f"[Cycle {cycle}] PASSED") + + except Exception as e: + log.error(f"[Cycle {cycle}] CRASH: {type(e).__name__}: {e}") + log.error(traceback.format_exc()) + + # Log GPU memory state at crash time + if torch.cuda.is_available(): + allocated = torch.cuda.memory_allocated() / 1e9 + reserved = torch.cuda.memory_reserved() / 1e9 + log.error(f"[Cycle {cycle}] GPU memory at crash: " + f"allocated={allocated:.2f}GB, reserved={reserved:.2f}GB") + + log.error(f"Crashed on cycle {cycle}/{num_cycles}") + return cycle + + log.info(f"All {num_cycles} cycles completed without crash.") + return 0 + + +async def main(): + parser = argparse.ArgumentParser( + description="vLLM sleep/wake crash reproduction with weight updates" + ) + parser.add_argument("--model", default="Qwen/Qwen3-8B", help="Model name/path") + parser.add_argument("--max-model-len", type=int, default=32768) + parser.add_argument("--gpu-mem-util", type=float, default=0.8) + parser.add_argument("--cycles", type=int, default=10, help="Number of sleep/wake cycles") + parser.add_argument("--sleep-level", type=int, default=2, + help="Sleep level (1=keep KV cache, 2=free all)") + args = parser.parse_args() + + log.info(f"vLLM version: {vllm.__version__}") + log.info(f"PyTorch version: {torch.__version__}") + log.info(f"CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + log.info(f"GPU: {torch.cuda.get_device_name(0)}") + log.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f}GB") + log.info(f"Config: model={args.model}, max_model_len={args.max_model_len}, " + f"gpu_mem={args.gpu_mem_util}, cycles={args.cycles}, sleep_level={args.sleep_level}") + + engine = create_engine(args.model, args.max_model_len, args.gpu_mem_util) + + crash_cycle = await run_cycles_with_weights( + engine, args.model, args.cycles, args.sleep_level + ) + + if crash_cycle: + log.error(f"RESULT: Crashed on cycle {crash_cycle}") + sys.exit(1) + else: + log.info("RESULT: No crash detected") + sys.exit(0) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/train/dataset/test_preprocess.py b/tests/train/dataset/test_preprocess.py index aacb7671c2..fa0efeae71 100644 --- a/tests/train/dataset/test_preprocess.py +++ b/tests/train/dataset/test_preprocess.py @@ -56,6 +56,15 @@ def fake_tokenizer_decode_list(ids, **kwargs): def test_convert_prompts_responses_to_batch_tensors_exact(tokenizer): + """ + Test with inputs of exact lengths. + + | [PAD] [PAD] [PAD] [PAD] prompt prompt prompt respon respon respon | + | prompt prompt prompt prompt prompt respon respon respon respon respon | + |<------- max_response_len ------->| + """ + # prompts: "abc" (3 tokens), "12345" (5 tokens) + # outputs: "def" (3 tokens), "67890" (5 tokens) prompts = ["abc", "12345"] outputs = ["def", "67890"] prompts = tokenizer(prompts)["input_ids"] @@ -74,17 +83,25 @@ def test_convert_prompts_responses_to_batch_tensors_exact(tokenizer): ) ) - # loss mask should be the same length as the action mask (padded to the longest input) + # max_total = max(3+3, 5+5) = 10, max_response = 5 assert sequences.shape[0] == len(prompts) + assert sequences.shape == (2, 10) assert action_mask.shape == ret_loss_masks.shape - assert torch.equal(ret_loss_masks[0], torch.tensor([1, 1, 0, 0, 0])) + # Response data is RIGHT-ALIGNED within (batch, max_response) + # Sample 0: response len=3, so 2 leading zeros then 3 values + assert torch.equal(ret_loss_masks[0], torch.tensor([0, 0, 1, 1, 0])) assert torch.equal(ret_loss_masks[1], torch.tensor([1, 1, 1, 0, 0])) - assert torch.equal(ret_rewards[0], torch.tensor([0, 1, 0, 0, 0])) + assert torch.equal(ret_rewards[0], torch.tensor([0, 0, 0, 1, 0])) assert torch.equal(ret_rewards[1], torch.tensor([1, 0, 0, 0, 0])) + # max_total=10: sample 0 has total=6, so 4 left-pads; sample 1 has total=10, no padding + assert torch.equal(attention_mask[0], torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])) + assert torch.equal(attention_mask[1], torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])) def test_convert_prompts_responses_to_batch_tensors_different_lengths(tokenizer): # Test with inputs of different lengths + # "Short" = 5 tokens, "This is a longer prompt" = 23 tokens + # "Long response here" = 18 tokens, "Short" = 5 tokens prompts = ["Short", "This is a longer prompt"] outputs = ["Long response here", "Short"] prompts = tokenizer(prompts)["input_ids"] @@ -103,20 +120,23 @@ def test_convert_prompts_responses_to_batch_tensors_different_lengths(tokenizer) ) max_response_len = max([len(output) for output in outputs]) + # max_total = max(5+18, 23+5) = 28 + max_total = max(len(p) + len(r) for p, r in zip(prompts, outputs)) # Check shapes - assert sequences.shape[0] == 2 # batch size + assert sequences.shape == (2, max_total) assert attention_mask.shape == sequences.shape - # Tensor.shape can be directly compared with tuples assert action_mask.shape == (2, max_response_len) assert ret_rewards.shape == (2, max_response_len) assert ret_loss_masks.shape == (2, max_response_len) - # Verify padding is applied correctly - # First input is shorter than second input. the input is left padded + # Unified left-padding: shorter total gets left-padded + # Sample 0: total=23, pad=28-23=5 left pads assert sequences[0, 0] == tokenizer.pad_token_id - # second output is shorter than first output. the output is right padded - assert sequences[1, -1] == tokenizer.pad_token_id + assert sequences[1, 0] != tokenizer.pad_token_id + # All sequences end with real tokens (response at end), no right padding + assert sequences[0, -1] != tokenizer.pad_token_id + assert sequences[1, -1] != tokenizer.pad_token_id def test_convert_prompts_responses_to_batch_tensors_empty_input(tokenizer): @@ -153,3 +173,126 @@ def test_convert_prompts_responses_to_batch_tensors_mismatched_lengths(tokenizer rewards, loss_masks, ) + + +# --------------------------------------------------------------------------- +# Unified padding layout tests +# --------------------------------------------------------------------------- + + +def test_unified_left_padding_layout(tokenizer): + """Sequences are laid out as [PAD ... PROMPT RESPONSE] with all padding on the left.""" + # Sample 0: prompt=[1,2], response=[10,11,12] -> total=5 + # Sample 1: prompt=[3,4,5,6], response=[20,21] -> total=6 + # max_total=6, max_response=3 + prompts = [[1, 2], [3, 4, 5, 6]] + responses = [[10, 11, 12], [20, 21]] + rewards = [[0.0] * 3, [0.0] * 2] + loss_masks = [[1] * 3, [1] * 2] + + seq, attn, action, rew, lm, _ = convert_prompts_responses_to_batch_tensors( + tokenizer, + prompts, + responses, + rewards, + loss_masks, + ) + assert seq.shape == (2, 6) + + # Sample 0: pad=1, then [1,2,10,11,12] + assert seq[0].tolist() == [0, 1, 2, 10, 11, 12] + assert attn[0].tolist() == [0, 1, 1, 1, 1, 1] + # Response ends at the end of the sequence (no right-padding in sequences) + assert seq[0, -1] == 12 + + # Sample 1: no pad, [3,4,5,6,20,21] + assert seq[1].tolist() == [3, 4, 5, 6, 20, 21] + assert attn[1].tolist() == [1, 1, 1, 1, 1, 1] + + +def test_right_aligned_response_data(tokenizer): + """Response-level tensors are right-aligned: actual values at the end, zeros at the start.""" + prompts = [[1, 2, 3], [4, 5]] + responses = [[10], [20, 21, 22]] + rewards = [[1.0], [0.5, 0.6, 0.7]] + loss_masks = [[1], [1, 0, 1]] + logprobs = [[-0.1], [-0.2, -0.3, -0.4]] + prompts_copy = [p[:] for p in prompts] + responses_copy = [r[:] for r in responses] + + seq, attn, action, rew, lm, lp = convert_prompts_responses_to_batch_tensors( + tokenizer, + prompts, + responses, + rewards, + loss_masks, + logprobs, + ) + # max_response=3 + assert action.shape == (2, 3) + + # Sample 0: response_len=1, right-aligned -> [0, 0, 1] + assert action[0].tolist() == [0, 0, 1] + assert rew[0].tolist() == [0.0, 0.0, 1.0] + assert lm[0].tolist() == [0.0, 0.0, 1.0] + assert lp[0].tolist() == pytest.approx([0.0, 0.0, -0.1]) + + # Sample 1: response_len=3, right-aligned -> [1, 1, 1] (no padding) + assert action[1].tolist() == [1, 1, 1] + assert rew[1].tolist() == pytest.approx([0.5, 0.6, 0.7]) + assert lm[1].tolist() == [1.0, 0.0, 1.0] + assert lp[1].tolist() == pytest.approx([-0.2, -0.3, -0.4]) + + # Test does not mutate inputs + assert prompts == prompts_copy + assert responses == responses_copy + + +def test_max_seq_len_warns_but_does_not_truncate(tokenizer): + """max_seq_len only warns; no tokens are lost.""" + prompts = [[1] * 50, [2] * 10] + responses = [[3] * 10, [4] * 50] + rewards = [[0.0] * 10, [0.0] * 50] + loss_masks = [[1] * 10, [1] * 50] + + seq, _, action, _, _, _ = convert_prompts_responses_to_batch_tensors( + tokenizer, + prompts, + responses, + rewards, + loss_masks, + max_seq_len=30, + ) + # max_total = max(60, 60) = 60, which exceeds max_seq_len=30 + # But no truncation: all tokens preserved + assert seq.shape == (2, 60) + assert action.shape == (2, 50) + + +def test_stepwise_anti_correlation_no_inflation(tokenizer): + """Step-wise anti-correlated prompt/response lengths: seq_len = max(prompt_i + response_i), + NOT max(prompt_i) + max(response_i).""" + # Early turn: prompt=10, response=90 (total=100) + # Late turn: prompt=90, response=10 (total=100) + prompts = [list(range(10)), list(range(90))] + responses = [list(range(100, 190)), list(range(200, 210))] + rewards = [[0.0] * 90, [0.0] * 10] + loss_masks = [[1] * 90, [1] * 10] + + seq, attn, action, rew, lm, _ = convert_prompts_responses_to_batch_tensors( + tokenizer, + prompts, + responses, + rewards, + loss_masks, + ) + # max(10+90, 90+10) = 100, NOT 90+90=180 + assert seq.shape == (2, 100) + assert action.shape == (2, 90) + + # All real tokens are preserved (no truncation) + assert seq[0].tolist() == list(range(10)) + list(range(100, 190)) + assert seq[1].tolist() == list(range(90)) + list(range(200, 210)) + + # Response data right-aligned: sample 1 has 10 tokens -> [0]*80 + [1]*10 + assert action[1].tolist() == [0] * 80 + [1] * 10 diff --git a/tests/train/test_trainer_utils.py b/tests/train/test_trainer_utils.py index 444ef8108b..9c2d74f115 100644 --- a/tests/train/test_trainer_utils.py +++ b/tests/train/test_trainer_utils.py @@ -922,3 +922,131 @@ def test_validate_generator_output_invalid_rewards(): generator_output["rewards"] = [[0.5, 0.6], [0.7, 0.8]] validate_generator_output(len(input_batch["prompts"]), generator_output) + + +# ============================================================ +# Step-wise validation tests +# ============================================================ + +from skyrl.train.generators.base import TrajectoryID + + +def _make_stepwise_output(n_trajectories=2, steps_per_traj=(2, 3), contiguous=True): + """Helper to build a step-wise GeneratorOutput for testing.""" + items = [] + for traj_idx in range(n_trajectories): + n_steps = steps_per_traj[traj_idx] + tid = TrajectoryID(instance_id=str(traj_idx), repetition_id=0) + for step in range(n_steps): + is_last = step == n_steps - 1 + prompt = list(range(10 + traj_idx * 100, 10 + traj_idx * 100 + 3 + step)) + resp = list(range(50 + traj_idx * 100 + step * 10, 50 + traj_idx * 100 + step * 10 + 3)) + reward = [0.0, 0.0, float(traj_idx + 1) if is_last else 0.0] + items.append((prompt, resp, reward, [1, 1, 1], is_last, tid)) + + if not contiguous: + max_steps = max(steps_per_traj) + reordered = [] + for step in range(max_steps): + for traj_idx in range(n_trajectories): + if step < steps_per_traj[traj_idx]: + idx = sum(steps_per_traj[:traj_idx]) + step + reordered.append(items[idx]) + items = reordered + + prompt_token_ids, response_ids, rewards, loss_masks = [], [], [], [] + is_last_step, trajectory_ids = [], [] + for prompt, resp, reward, mask, is_last, tid in items: + prompt_token_ids.append(prompt) + response_ids.append(resp) + rewards.append(reward) + loss_masks.append(mask) + is_last_step.append(is_last) + trajectory_ids.append(tid) + + return { + "prompt_token_ids": prompt_token_ids, + "response_ids": response_ids, + "rewards": rewards, + "loss_masks": loss_masks, + "stop_reasons": ["complete"] * len(response_ids), + "rollout_metrics": {}, + "rollout_logprobs": None, + "is_last_step": is_last_step, + "trajectory_ids": trajectory_ids, + } + + +def test_validate_stepwise_valid(): + """Valid step-wise output should pass validation.""" + output = _make_stepwise_output(n_trajectories=3, steps_per_traj=(1, 2, 3)) + validate_generator_output(num_prompts=3, generator_output=output, step_wise=True) + + +def test_validate_stepwise_single_step_trajectories(): + """All single-step trajectories should pass.""" + output = _make_stepwise_output(n_trajectories=4, steps_per_traj=(1, 1, 1, 1)) + validate_generator_output(num_prompts=4, generator_output=output, step_wise=True) + + +def test_validate_stepwise_missing_is_last_step(): + """Missing is_last_step should fail.""" + output = _make_stepwise_output() + del output["is_last_step"] + with pytest.raises(AssertionError, match="is_last_step.*missing"): + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_missing_trajectory_ids(): + """Missing trajectory_ids should fail.""" + output = _make_stepwise_output() + del output["trajectory_ids"] + with pytest.raises(AssertionError, match="trajectory_ids.*missing"): + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_is_last_step_length_mismatch(): + """is_last_step length mismatch should fail.""" + output = _make_stepwise_output() + output["is_last_step"] = output["is_last_step"][:-1] + with pytest.raises(AssertionError, match="is_last_step length"): + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_last_element_not_true(): + """is_last_step[-1] must be True.""" + output = _make_stepwise_output() + output["is_last_step"][-1] = False + with pytest.raises(AssertionError, match="is_last_step\\[-1\\] must be True"): + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_non_contiguous(): + """Non-contiguous trajectory ordering should fail.""" + output = _make_stepwise_output(n_trajectories=2, steps_per_traj=(2, 2), contiguous=False) + with pytest.raises(AssertionError, match="Non-contiguous trajectory"): + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_boundary_without_is_last(): + """Trajectory boundary where is_last_step is False should fail.""" + output = _make_stepwise_output(n_trajectories=2, steps_per_traj=(2, 2)) + # Traj 0 has steps at indices 0,1 and traj 1 at 2,3. Corrupt boundary. + output["is_last_step"][1] = False + with pytest.raises(AssertionError, match="Trajectory boundary at index 1"): + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_no_true_in_is_last_step(): + """is_last_step with no True values should fail.""" + output = _make_stepwise_output(n_trajectories=1, steps_per_traj=(3,)) + output["is_last_step"] = [False, False, False] + with pytest.raises(AssertionError, match="is_last_step\\[-1\\] must be True"): + validate_generator_output(num_prompts=1, generator_output=output, step_wise=True) + + +def test_validate_stepwise_num_prompts_not_checked(): + """In step-wise mode, num_prompts != num_responses is allowed (expansion).""" + output = _make_stepwise_output(n_trajectories=2, steps_per_traj=(2, 3)) + # 5 step-samples from 2 prompts + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True)