4 changes: 3 additions & 1 deletion docs/content/docs/algorithms/custom_algorithms.mdx
@@ -98,7 +98,9 @@ We show the outline of creating a custom trainer below, and you can find a full
```python
class CustomTrainer(RayPPOTrainer):
    @torch.no_grad()
-    def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
+    def postprocess_generator_output(
+        self, generator_output: GeneratorOutput, uids: List[str]
+    ) -> Tuple[GeneratorOutput, List[str]]:
        # apply custom reward penalties
        ...
        # use base class impl for metrics and per-token reward conversion
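For reference, a self-contained sketch of the new contract is below. The penalty logic, the `rewards` field layout, and the `GeneratorOutput` import path are illustrative assumptions, not SkyRL's actual API; only the signature and the deferral to the base class mirror the snippet above.

```python
import torch
from typing import List, Tuple

from skyrl.train.trainer import RayPPOTrainer
# Import path for GeneratorOutput is an assumption for this sketch.
from skyrl.train.generators.base import GeneratorOutput


class CustomTrainer(RayPPOTrainer):
    @torch.no_grad()
    def postprocess_generator_output(
        self, generator_output: GeneratorOutput, uids: List[str]
    ) -> Tuple[GeneratorOutput, List[str]]:
        # Hypothetical custom penalty: nudge every per-sample reward down by a
        # small constant (assumes a flat list of scalar rewards).
        generator_output["rewards"] = [r - 0.01 for r in generator_output["rewards"]]
        # Defer to the base class for metrics, per-token reward conversion, and
        # (when enabled) step-wise merging, which may shorten `uids`.
        return super().postprocess_generator_output(generator_output, uids)
```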
4 changes: 3 additions & 1 deletion docs/content/docs/algorithms/dapo.mdx
@@ -80,7 +80,9 @@ We provide an example of this in `examples/train/algorithms/dapo/main_dapo.py`,
```python title="examples/train/algorithms/dapo/main_dapo.py"
class DAPOTrainer(RayPPOTrainer):
    @torch.no_grad()
-    def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
+    def postprocess_generator_output(
+        self, generator_output: GeneratorOutput, uids: List[str]
+    ) -> Tuple[GeneratorOutput, List[str]]:
        # apply soft overlong punishment
        overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer.len
        overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer.penalty_factor
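To make the penalty concrete: the DAPO paper's soft overlong punishment leaves rewards untouched up to an expected length, then ramps linearly to the full penalty inside a buffer zone near the maximum response length. A minimal sketch of just the penalty term, assuming scalar per-response rewards (names are illustrative):

```python
def soft_overlong_penalty(
    response_len: int,
    max_response_len: int,
    overlong_buffer_len: int,
    penalty_factor: float,
) -> float:
    """Zero penalty up to (max - buffer), a linear ramp down to -penalty_factor
    at the max length, and the full -penalty_factor beyond it."""
    expected_len = max_response_len - overlong_buffer_len
    if response_len <= expected_len:
        return 0.0
    if response_len <= max_response_len:
        return (expected_len - response_len) / overlong_buffer_len * penalty_factor
    return -penalty_factor


# With max=8192, buffer=512, factor=1.0, an 8000-token response is penalized
# by (7680 - 8000) / 512 * 1.0 = -0.625.
assert soft_overlong_penalty(8000, 8192, 512, 1.0) == -0.625
```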
2 changes: 1 addition & 1 deletion docs/content/docs/tutorials/agent-integration.mdx
@@ -84,6 +84,6 @@ Your agent harness can still use `/chat/completions` with tool call parsing, sin

**Cons:**
- Training time can grow: O(T^2) vs O(T), since each trajectory of T turns becomes T sequences to forward (each with a growing prefix), as opposed to 1 sequence.
-- SkyRL will support prefix-aware merging of per-step sequences when the prefix matches (WIP), which brings the cost back to O(T) in the common case.
+- SkyRL supports prefix-aware merging of per-step sequences when the prefix matches, enabled via the config flag `generator.merge_stepwise_output`. This can reduce the O(T^2) cost when the chat history grows by linear appends across turns and there is no token mismatch. See https://github.com/NovaSky-AI/SkyRL/pull/1532.

For the full details on how to structure the `GeneratorOutput` for step-wise training, including the required fields, invariants, and a concrete example, see: [Step-Wise Training](step-wise-training).
2 changes: 1 addition & 1 deletion docs/content/docs/tutorials/step-wise-training.mdx
@@ -28,7 +28,7 @@ When step-wise is enabled, a batch of T trajectories with an average of M turns

- **Each mini-batch contains the sequences for exactly `policy_mini_batch_size` prompts**, regardless of how many turns those prompts produced. This means the number of mini-batches (and hence optimizer steps) per training batch is always `train_batch_size / policy_mini_batch_size`, independent of the number of turns. This also means that the actual mini batch size (number of sequences) trained in each mini batch can vary. Each mini batch always leads to a single optimizer step.
- **Advantages are computed on last steps only**, then broadcast to all steps of the same trajectory. This is mathematically equivalent to non-step-wise advantage computation for GRPO.
-- **Training time grows as O(T²) vs O(T)**, since each trajectory of T turns becomes T sequences to forward (each with a growing prompt prefix), as opposed to 1 sequence. SkyRL will support prefix-aware merging of per-step sequences when the prefix matches (WIP), which brings the cost back to O(T) in the common case.
+- **Training time grows as O(T²) vs O(T)**, since each trajectory of T turns becomes T sequences to forward (each with a growing prompt prefix), as opposed to 1 sequence. SkyRL supports prefix-aware merging of per-step sequences when the prefix matches, enabled via the config flag `generator.merge_stepwise_output`. This can reduce the O(T²) cost when the chat history grows by linear appends across turns and there is no token mismatch. See https://github.com/NovaSky-AI/SkyRL/pull/1532.
- **Metrics** like `generate/avg_num_tokens` and `generate/avg_response_length` are per-turn rather than per-trajectory, since each training sample is a single turn.

Some algorithms have their behavior altered by step-wise decomposition, since each turn is now treated as its own sequence:
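To make the merging mentioned in the doc change concrete: when each turn's token sequence extends the previous turn's sequence (the chat history only appends), the T per-step sequences collapse into the single longest one. A minimal sketch over plain token-id lists; the function and its structure are illustrative, not SkyRL's actual `GeneratorOutput` handling:

```python
from typing import List, Optional


def merge_stepwise_sequences(step_token_ids: List[List[int]]) -> Optional[List[int]]:
    """Collapse per-step sequences into one when each step is a strict
    extension of the previous step's tokens; return None on any mismatch,
    in which case the caller keeps the unmerged per-step sequences."""
    merged = step_token_ids[0]
    for step in step_token_ids[1:]:
        if step[: len(merged)] != merged:
            return None  # token mismatch, e.g. history was rewritten or truncated
        merged = step
    return merged


# Three turns that linearly append tokens become one sequence to forward,
# which is what brings the O(T^2) forward cost back down toward O(T).
steps = [[1, 2, 3], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6]]
assert merge_stepwise_sequences(steps) == [1, 2, 3, 4, 5, 6]
```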
4 changes: 3 additions & 1 deletion examples/train/algorithms/dapo/README.md
@@ -74,7 +74,9 @@ To enable soft overlong punishment, you can create a custom trainer class and ov
```python
class DAPOTrainer(RayPPOTrainer):
    @torch.no_grad()
-    def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
+    def postprocess_generator_output(
+        self, generator_output: GeneratorOutput, uids: List[str]
+    ) -> Tuple[GeneratorOutput, List[str]]:
        # apply soft overlong punishment
        overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer_len
        overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer_penalty_factor
6 changes: 4 additions & 2 deletions examples/train/algorithms/dapo/main_dapo.py
@@ -7,7 +7,7 @@
import ray
import torch
from dataclasses import dataclass
-from typing import List
+from typing import List, Tuple

from skyrl.train.config import AlgorithmConfig, make_config
from skyrl.train.trainer import RayPPOTrainer
@@ -36,7 +36,9 @@ class DAPOTrainer(RayPPOTrainer):
"""

@torch.no_grad()
def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
def postprocess_generator_output(
self, generator_output: GeneratorOutput, uids: List[str]
) -> Tuple[GeneratorOutput, List[str]]:
# NOTE (sumanthrh): Given the usage of `make_config`, the algorithm config subclass for DAPO is
# created dynamically and thus IDEs will not be able to resolve the attributes
# For better typing, you can always define a custom subclass of DAPOConfig manually.
6 changes: 4 additions & 2 deletions examples/train/algorithms/dapo/main_dapo_fully_async.py
@@ -6,7 +6,7 @@

import ray
import torch
-from typing import List
+from typing import List, Tuple

from skyrl.train.fully_async_trainer import FullyAsyncRayPPOTrainer
from skyrl.train.utils import initialize_ray, validate_cfg
@@ -19,7 +19,9 @@

class FullyAsyncDAPOTrainer(FullyAsyncRayPPOTrainer):
    @torch.no_grad()
-    def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
+    def postprocess_generator_output(
+        self, generator_output: GeneratorOutput, uids: List[str]
+    ) -> Tuple[GeneratorOutput, List[str]]:
        """
        Overrides the postprocess_generator_output method to additionally apply DAPO specific soft overlong punishment to rewards.

2 changes: 1 addition & 1 deletion examples/train/async/async_trainer.py
@@ -175,7 +175,7 @@ async def _run_generate_loop(self, generation_buffer: asyncio.Queue):
# generation phase
async with Timer("generate", self.all_timings):
    generator_output: GeneratorOutput = await self.generate(generator_input)
-    generator_output = self.postprocess_generator_output(generator_output, uids)
+    generator_output, uids = self.postprocess_generator_output(generator_output, uids)

# Add to generation buffer
await generation_buffer.put((generator_output, uids))
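The call-site change above is the reason for the new tuple return type: when step-wise merging collapses several per-step sequences of a trajectory into one, the returned `uids` list can be shorter than the input, so every caller must rebind both values rather than only the generator output.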
8 changes: 5 additions & 3 deletions examples/train/flash_rl/main_dapo_flashrl.py
@@ -7,7 +7,7 @@
import ray
import torch
from dataclasses import dataclass
-from typing import List
+from typing import List, Tuple

from skyrl.train.config import SkyRLTrainConfig, AlgorithmConfig, make_config
from skyrl.train.trainer import RayPPOTrainer
@@ -64,7 +64,9 @@ class DAPOTrainer(RayPPOTrainer):
"""

@torch.no_grad()
def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
def postprocess_generator_output(
self, generator_output: GeneratorOutput, uids: List[str]
) -> Tuple[GeneratorOutput, List[str]]:
"""
Overrides the postprocess_generator_output method to additionally apply DAPO specific soft overlong punishment to rewards.

@@ -73,7 +75,7 @@ def postprocess_generator_output(self, generator_output: GeneratorOutput, uids:
        uids: List[str]

    Returns:
-        GeneratorOutput
+        (GeneratorOutput, uids): uids may be shortened if the base class applies step-wise merging.
    """
    overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer_len
    overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer_penalty_factor
5 changes: 5 additions & 0 deletions examples/train/search/run_search.sh
@@ -48,6 +48,8 @@ else
    MULTI_TURN_ARGS="generator.use_conversation_multi_turn=false"
fi

+: "${MERGE_STEPWISE:=false}"
+
STEP_WISE_ARGS=""
if [ "$STEP_WISE" = "true" ]; then
    STEP_WISE_ARGS="generator.step_wise_trajectories=true"
@@ -56,6 +58,9 @@ if [ "$STEP_WISE" = "true" ]; then
echo "WARNING: STEP_WISE=true requires USE_CONVERSATION_MULTI_TURN=true. Enabling it automatically."
MULTI_TURN_ARGS="generator.use_conversation_multi_turn=true generator.append_eos_token_after_stop_str_in_multi_turn=true"
fi
if [ "$MERGE_STEPWISE" = "true" ]; then
STEP_WISE_ARGS="$STEP_WISE_ARGS generator.merge_stepwise_output=true"
fi
fi

uv run --isolated --frozen --extra fsdp -m skyrl.train.entrypoints.main_base \
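With this toggle in place, a step-wise run with merging enabled would presumably look like `STEP_WISE=true MERGE_STEPWISE=true bash examples/train/search/run_search.sh`, which appends `generator.merge_stepwise_output=true` to the step-wise arguments (the `bash` invocation is an assumption; the script itself launches training via `uv run`).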
8 changes: 5 additions & 3 deletions examples/train/tis_correction/main_tis_dapo.py
@@ -7,7 +7,7 @@
import ray
import torch
from dataclasses import dataclass
-from typing import List
+from typing import List, Tuple

from skyrl.train.config import AlgorithmConfig, make_config
from skyrl.train.trainer import RayPPOTrainer
@@ -36,7 +36,9 @@ class DAPOTrainer(RayPPOTrainer):
"""

@torch.no_grad()
def postprocess_generator_output(self, generator_output: GeneratorOutput, uids: List[str]) -> GeneratorOutput:
def postprocess_generator_output(
self, generator_output: GeneratorOutput, uids: List[str]
) -> Tuple[GeneratorOutput, List[str]]:
"""
Overrides the postprocess_generator_output method to additionally apply DAPO specific soft overlong punishment to rewards.

@@ -45,7 +47,7 @@ def postprocess_generator_output(self, generator_output: GeneratorOutput, uids:
        uids: List[str]

    Returns:
-        GeneratorOutput
+        (GeneratorOutput, uids): uids may be shortened if the base class applies step-wise merging.
    """
    overlong_buffer_len = self.cfg.trainer.algorithm.overlong_buffer_len
    overlong_buffer_penalty_factor = self.cfg.trainer.algorithm.overlong_buffer_penalty_factor
@@ -354,7 +354,7 @@ async def train(self):

# 1.2 postprocess rewards
with Timer("postprocess_generator_output", self.all_timings):
-    generator_output = self.postprocess_generator_output(generator_output, uids)
+    generator_output, uids = self.postprocess_generator_output(generator_output, uids)

# 2. print example just for debugging
vis = self.tokenizer.decode(generator_output["response_ids"][0])
3 changes: 3 additions & 0 deletions skyrl/train/config/config.py
@@ -536,6 +536,9 @@ class GeneratorConfig(BaseConfig):
"""Can differ from the trainer's ``rope_scaling``, useful for thinking models."""
rope_theta: Optional[float] = None
step_wise_trajectories: bool = False
merge_stepwise_output: bool = False
"""When True (and step_wise_trajectories is True), apply prefix-aware merging
to collapse multi-turn step-wise sequences into single sequences before training."""

def __post_init__(self):

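Since `merge_stepwise_output` only has an effect when `step_wise_trajectories` is also enabled (per the docstring above), one might add a consistency guard; a minimal sketch on a stand-in dataclass, not SkyRL's actual `GeneratorConfig`:

```python
from dataclasses import dataclass


@dataclass
class GeneratorConfigSketch:
    step_wise_trajectories: bool = False
    merge_stepwise_output: bool = False

    def __post_init__(self) -> None:
        # Fail fast on an inconsistent combination: merging presupposes
        # step-wise trajectories.
        if self.merge_stepwise_output and not self.step_wise_trajectories:
            raise ValueError(
                "generator.merge_stepwise_output=true requires "
                "generator.step_wise_trajectories=true"
            )
```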
1 change: 1 addition & 0 deletions skyrl/train/config/ppo_base_config.yaml
@@ -382,6 +382,7 @@ generator:
  rope_theta: ${trainer.rope_theta}

  step_wise_trajectories: false
+  merge_stepwise_output: false

environment:
  env_class: "gsm8k"
2 changes: 1 addition & 1 deletion skyrl/train/fully_async_trainer.py
@@ -662,7 +662,7 @@ def convert_generation_group_mini_batch_to_training_input(
)

# Convert rewards to per-token form and compute reward metrics before training conversion
-generator_output = self.postprocess_generator_output(generator_output, uids)
+generator_output, uids = self.postprocess_generator_output(generator_output, uids)

# print example just for debugging
vis = self.tokenizer.decode(generator_output["response_ids"][0])