65 changes: 52 additions & 13 deletions skyrl/train/generators/skyrl_gym_generator.py
@@ -134,6 +134,43 @@ def get_turn_rollout_logprobs(self) -> Optional[List[float]]:
return None
return self.output_logprobs + [0.0] * len(self.obs_ids)

# ----------------------------------------------------------------------------
# Gen-only variants (step-wise path). Step-wise emits one training sample per
# turn, and each sample's response is just the model's generation — the
# observation tokens belong to the *next* turn's prompt. The helpers below
# return aligned loss_mask / rollout_logprobs / rollout_expert_indices for
# the gen-only response so downstream length invariants stay consistent.
# ----------------------------------------------------------------------------
def get_turn_gen_only_loss_mask(self) -> List[int]:
"""Like `get_turn_loss_mask()` but without the trailing observation zeros."""
if self.added_eos:
# The EOS wasn't actually generated, so mask it out.
return [1] * (len(self.output_ids) - 1) + [0]
return [1] * len(self.output_ids)

def get_turn_gen_only_rollout_logprobs(self) -> Optional[List[float]]:
"""Like `get_turn_rollout_logprobs()` but without the trailing observation zeros.

`output_logprobs` already covers the manually appended EOS (agent_loop adds a 0.0
stub there), so this is exactly the list we want.
"""
if not self.output_logprobs:
return None
return list(self.output_logprobs)

def get_turn_gen_only_rollout_expert_indices(self) -> Optional[List[List[List[int]]]]:
"""Like `get_turn_rollout_expert_indices()` but without the trailing observation pads."""
if self.rollout_expert_indices is None:
return None
if not self.rollout_expert_indices:
return self.rollout_expert_indices
indices = list(self.rollout_expert_indices)
if self.added_eos:
layer_num = len(self.rollout_expert_indices[0])
topk = len(self.rollout_expert_indices[0][0]) if layer_num > 0 else 0
indices.append([[0] * topk for _ in range(layer_num)])
return indices
Comment on lines +161 to +172
Contributor

medium

The implementation of get_turn_gen_only_rollout_expert_indices contains a potential IndexError if self.rollout_expert_indices is an empty list (e.g., in certain mock or edge cases where the model generates no tokens). Accessing self.rollout_expert_indices[0] at line 169 will fail. While line 165 provides a safety check, it returns the empty list immediately, which would cause a length mismatch with response_ids if added_eos is True. Consider ensuring that layer_num and topk can be determined or return None if the shape is ambiguous.
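
For illustration only (not part of this PR's diff): a minimal sketch of the defensive variant the comment suggests, written as a standalone helper. The function name and the assumption that each per-token entry has shape `[layer_num][topk]` are mine; it returns `None` instead of a possibly misaligned stub when the shape cannot be inferred.

```python
from typing import List, Optional


def gen_only_rollout_expert_indices_defensive(
    rollout_expert_indices: Optional[List[List[List[int]]]],
    added_eos: bool,
) -> Optional[List[List[List[int]]]]:
    """Hypothetical defensive variant: never index into an empty list."""
    if rollout_expert_indices is None:
        return None
    indices = list(rollout_expert_indices)
    if added_eos:
        if not indices:
            # No generated tokens to infer [layer_num][topk] from, so signal
            # "unavailable" rather than emit a stub of ambiguous shape.
            return None
        layer_num = len(indices[0])
        topk = len(indices[0][0]) if layer_num > 0 else 0
        # Zero stub for the manually appended EOS keeps lengths aligned.
        indices.append([[0] * topk for _ in range(layer_num)])
    return indices
```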



class SkyRLGymGenerator(GeneratorInterface):
def __init__(
@@ -317,10 +354,6 @@ async def agent_loop(

while not agent_loop_state.done:

if len(agent_loop_state.input_ids) > max_input_length:
stop_reason = "length"
break

# 1. Generate output
if is_step_wise or retokenize_chat_history:
# re-apply whole chat template so length check is correct
@@ -335,6 +368,10 @@
agent_loop_state.loss_mask = []
agent_loop_state.rollout_logprobs = None

if len(agent_loop_state.input_ids) > max_input_length:
stop_reason = "length"
break

engine_input = InferenceEngineInput(
prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params
)
@@ -403,13 +440,16 @@ async def agent_loop(
agent_loop_state.rollout_expert_indices = []

if is_step_wise:
# current response + observation ids
turn_response_ids = turn_output.output_ids + turn_output.obs_ids
# Gen-only response: the obs tokens will appear in the *next* turn's prompt
# (or be recovered as `obs_delta` under `merge_stepwise_output`). Keeping obs
# out of `response_ids` caps per-step sequence length at
# ``max_input_length + max_generate_length`` — previously obs inflated each
# training sample by up to several thousand tokens for e.g. Search-R1.
turn_response_ids = turn_output.output_ids
turn_prompt_ids = agent_loop_state.input_ids

# agent loop only tracks loss mask and rollout logprobs for this turn with step_wise training
turn_loss_mask = turn_output.get_turn_loss_mask()
turn_response_logprobs: Optional[List[float]] = turn_output.get_turn_rollout_logprobs()
turn_loss_mask = turn_output.get_turn_gen_only_loss_mask()
turn_response_logprobs: Optional[List[float]] = turn_output.get_turn_gen_only_rollout_logprobs()

per_step_output = TrajectoryOutput(
response_ids=turn_response_ids,
@@ -419,7 +459,7 @@ async def agent_loop(
rollout_logprobs=turn_response_logprobs,
stop_reason=stop_reason,
env_metrics=env.get_metrics() if agent_loop_state.done else {},
rollout_expert_indices=turn_output.get_turn_rollout_expert_indices(),
rollout_expert_indices=turn_output.get_turn_gen_only_rollout_expert_indices(),
)
agent_loop_output.step_outputs.append(per_step_output)
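
An illustrative, self-contained sketch (not part of the diff) of the length bound the gen-only response restores; all numbers are made up stand-ins.

```python
# Illustrative only; not part of the PR. Shows why excluding observation tokens
# from a step-wise sample's response caps its length at
# max_input_length + max_generate_length.
max_input_length = 25
max_generate_length = 4

prompt_ids = [7] * 20         # re-tokenized chat history, already checked <= max_input_length
output_ids = [10, 11, 12, 4]  # model generation for this turn, <= max_generate_length
obs_ids = [7] * 3000          # e.g. a long retrieved document (Search-R1 style)

# Gen-only sample: obs_ids are deferred to the next turn's re-tokenized prompt.
gen_only_len = len(prompt_ids) + len(output_ids)
assert gen_only_len <= max_input_length + max_generate_length

# The previous behaviour appended obs_ids to the response, inflating the sample.
old_len = len(prompt_ids) + len(output_ids) + len(obs_ids)
print(gen_only_len, old_len)  # 24 vs. 3024
```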

@@ -494,9 +534,8 @@ async def agent_loop(
assert response_ids is not None and loss_mask is not None
if stop_reason != "length" and response_ids and response_ids[-1] != self.tokenizer.eos_token_id:
response_ids.append(self.tokenizer.eos_token_id)
# TODO(Charlie): this should be 0? Otherwise logprobs will be extremely off. But if it is loss
# masked with 0, why bother adding it?
loss_mask.append(1)
# The EOS was not actually generated by the model, so mask it out.
loss_mask.append(0)
if rollout_logprobs is not None:
rollout_logprobs.append(0.0)
if rollout_expert_indices_out is not None and rollout_expert_indices_out:
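As a side note on the EOS change above, here is an illustrative, self-contained sketch (not taken from the PR) of the per-token alignment it preserves: a post-hoc EOS gets a 0 loss-mask entry and a 0.0 logprob stub, so all per-token lists stay the same length and the fake token never contributes to the loss.

```python
eos_token_id = 4
response_ids = [10, 11, 12]            # model output that did not end with EOS
loss_mask = [1, 1, 1]
rollout_logprobs = [-0.1, -0.2, -0.3]

if response_ids[-1] != eos_token_id:
    response_ids.append(eos_token_id)
    loss_mask.append(0)                # EOS was not generated by the model -> no loss
    rollout_logprobs.append(0.0)       # placeholder logprob, never trained on

assert len(response_ids) == len(loss_mask) == len(rollout_logprobs)
```
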
115 changes: 107 additions & 8 deletions tests/train/generators/test_skyrl_gym_generator.py
@@ -304,7 +304,9 @@ def mock_generate(_):
# No EOS: just add it
expected_response_ids = mock_llm_output_ids + [mock_tokenizer.eos_token_id]

expected_loss_mask = [1] * (len(expected_response_ids))
# The trailing EOS in the singleturn path is always post-hoc appended (the model's EOS
# was stripped and re-added), so it carries no real logprob and must be masked out.
expected_loss_mask = [1] * (len(expected_response_ids) - 1) + [0]

if logprobs_setting is not None:
assert output.rollout_logprobs is not None
@@ -792,14 +794,15 @@ async def llm_generate_side_effect(input_batch):

output_normal = await generator.generate(input_batch_normal)

# Verify normal response keeps original loss mask (all 1s)
# Verify normal response loss mask: the singleturn path strips the model's EOS and re-appends
# it post-hoc, and that appended EOS carries no real logprob so it is masked out (0).
assert len(output_normal["loss_masks"]) == 1
assert len(output_normal["loss_masks"][0]) == 3 # 3 response tokens (already includes EOS token)
assert output_normal["loss_masks"][0] == [
1,
1,
1,
], "Loss mask should remain as 1s for response ending with eos token"
0,
], "Loss mask: real tokens = 1, post-hoc appended EOS = 0"


@pytest.mark.asyncio
@@ -1494,11 +1497,13 @@ def step(self, action):
assert isinstance(reward, list), f"rewards[{i}] should be a list (per-token rewards)"
assert all(isinstance(r, (int, float)) for r in reward), f"rewards[{i}] should contain numeric values"

# Validate response_ids structure
# Validate response_ids are gen-only: obs tokens must not be appended to step-wise response.
# The mock LLM returns [10, 11, 12, eos=4] per turn; with obs excluded, that's exactly the
# expected per-step response.
for i, response_ids in enumerate(generator_output["response_ids"]):
assert isinstance(response_ids, list), f"response_ids[{i}] should be a list"
assert len(response_ids) > 0, f"response_ids[{i}] should not be empty"
assert all(isinstance(token, int) for token in response_ids), f"response_ids[{i}] should contain integers"
assert response_ids == [10, 11, 12, mock_tokenizer.eos_token_id], (
f"response_ids[{i}] should be gen-only output tokens (no obs appended), got {response_ids}"
)

# Validate loss_masks structure
for i, loss_mask in enumerate(generator_output["loss_masks"]):
@@ -1511,3 +1516,97 @@
# Validate stop_reasons
for i, stop_reason in enumerate(generator_output["stop_reasons"]):
assert isinstance(stop_reason, str), f"stop_reasons[{i}] should be a string"


@pytest.mark.asyncio
@patch("skyrl_gym.make")
async def test_step_wise_trajectories_length_check_uses_current_prompt(
mock_make, mock_tokenizer, mock_llm, mock_env_cfg
):
"""Regression: step-wise length check must run AFTER the per-turn re-tokenization.

Previously, `len(agent_loop_state.input_ids) > max_input_length` was checked against the
stale input_ids left over from the previous turn's re-tokenize (which reflected chat_history
one turn behind). That let a turn whose freshly re-tokenized prompt exceeded `max_input_length`
still generate — producing a step with an over-length prompt_ids (observed as
`generate/batch_padded_seq_len` spikes up to ~10k in Search-R1 step-wise runs).
"""
from skyrl.train.generators.base import TrajectoryID

mock_tokenizer.eos_token_id = 4

# apply_chat_template length grows with chat_history message count — 10 tokens per message.
def apply_chat_template_side_effect(messages, **kwargs):
if kwargs.get("tokenize", True):
return [7] * (10 * len(messages))
return "".join([m.get("content", "") for m in messages])

mock_tokenizer.apply_chat_template.side_effect = apply_chat_template_side_effect

async def llm_generate_side_effect(input_batch):
num = len(input_batch["prompt_token_ids"]) if "prompt_token_ids" in input_batch else len(input_batch["prompts"])
return {
"responses": ["step"] * num,
"stop_reasons": ["stop"] * num,
"response_logprobs": None,
"response_ids": [[10, 11, 12, mock_tokenizer.eos_token_id] for _ in range(num)],
}

mock_llm.generate = AsyncMock(side_effect=llm_generate_side_effect)

class NeverDoneEnv(BaseTextEnv):
def init(self, prompt):
return prompt, {}

def step(self, action):
return BaseTextEnvStepOutput(
observations=[{"role": "user", "content": "obs"}], reward=0.0, done=False, metadata={}
)

mock_make.side_effect = lambda *a, **k: NeverDoneEnv()

cfg = GeneratorConfig()
cfg.sampling_params.max_generate_length = 50
cfg.sampling_params.logprobs = None
cfg.apply_overlong_filtering = False
# Turn 1 prompt is 10 tokens (initial user only); turn 2 prompt is 30 tokens (user + assistant
# + user-obs). Choose a limit between the two: the fix must stop at turn 2 re-tokenize, not
# after generating a second step with a 30-token prompt.
cfg.max_input_length = 25
cfg.batched = False
cfg.max_turns = 10
cfg.zero_reward_on_non_stop = False
cfg.use_conversation_multi_turn = True
cfg.step_wise_trajectories = True
cfg.chat_template = ChatTemplateConfig(source="name", name_or_path=None)

generator = SkyRLGymGenerator(
generator_cfg=cfg,
skyrl_gym_cfg=mock_env_cfg,
inference_engine_client=mock_llm,
tokenizer=mock_tokenizer,
)
generator.base_conversation_token_ids = []

out = await generator.agent_loop(
[{"role": "user", "content": "Q?"}],
mock_env_cfg.env_class,
{},
max_tokens=50,
max_input_length=cfg.max_input_length,
trajectory_id=TrajectoryID(instance_id="uid1", repetition_id=0),
)

# Exactly one step emitted (turn 1). Turn 2's re-tokenized prompt (30) exceeds 25, so the
# loop must break before generating a second step.
assert len(out.step_outputs) == 1, (
f"Expected exactly 1 step before hitting max_input_length=25, got "
f"{len(out.step_outputs)}. Each recorded prompt length: "
f"{[len(s.prompt_ids) for s in out.step_outputs]}"
)
# And no step should have a prompt exceeding the limit.
for i, step in enumerate(out.step_outputs):
assert len(step.prompt_ids) <= cfg.max_input_length, (
f"step {i} prompt_ids length {len(step.prompt_ids)} exceeds max_input_length "
f"{cfg.max_input_length}"
)
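
To summarize the ordering this test pins down, a small self-contained sketch (not part of the test file); the 10-tokens-per-message template and the variable names are illustrative only.

```python
# Turn 1's prompt is the initial user message (10 tokens); by turn 2 the chat
# history has grown to user + assistant + user-obs (30 tokens), over the limit.
def template_len(chat_history):
    return 10 * len(chat_history)

max_input_length = 25
chat_seen_by_turn_1 = ["user"]                      # what turn 1 re-tokenized
chat_seen_by_turn_2 = ["user", "assistant", "obs"]  # what turn 2 re-tokenizes

# Buggy order: the check ran against the stale, one-turn-behind prompt length,
# so turn 2 still generated with a 30-token prompt.
stale_check_allows_turn_2 = template_len(chat_seen_by_turn_1) <= max_input_length  # True
# Fixed order: re-tokenize first, then check the fresh prompt length.
fresh_check_allows_turn_2 = template_len(chat_seen_by_turn_2) <= max_input_length  # False
print(stale_check_allows_turn_2, fresh_check_allows_turn_2)
```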