diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py
index 6f0d27cacf..c50c6ba3a5 100644
--- a/skyrl/train/generators/skyrl_gym_generator.py
+++ b/skyrl/train/generators/skyrl_gym_generator.py
@@ -134,6 +134,43 @@ def get_turn_rollout_logprobs(self) -> Optional[List[float]]:
             return None
         return self.output_logprobs + [0.0] * len(self.obs_ids)
 
+    # ----------------------------------------------------------------------------
+    # Gen-only variants (step-wise path). Step-wise emits one training sample per
+    # turn, and each sample's response is just the model's generation — the
+    # observation tokens belong to the *next* turn's prompt. The helpers below
+    # return aligned loss_mask / rollout_logprobs / rollout_expert_indices for
+    # the gen-only response so downstream length invariants stay consistent.
+    # ----------------------------------------------------------------------------
+    def get_turn_gen_only_loss_mask(self) -> List[int]:
+        """Like `get_turn_loss_mask()` but without the trailing observation zeros."""
+        if self.added_eos:
+            # The EOS wasn't actually generated, so mask it out.
+            return [1] * (len(self.output_ids) - 1) + [0]
+        return [1] * len(self.output_ids)
+
+    def get_turn_gen_only_rollout_logprobs(self) -> Optional[List[float]]:
+        """Like `get_turn_rollout_logprobs()` but without the trailing observation zeros.
+
+        `output_logprobs` already covers the manually appended EOS (agent_loop adds a 0.0
+        stub there), so this is exactly the list we want.
+        """
+        if not self.output_logprobs:
+            return None
+        return list(self.output_logprobs)
+
+    def get_turn_gen_only_rollout_expert_indices(self) -> Optional[List[List[List[int]]]]:
+        """Like `get_turn_rollout_expert_indices()` but without the trailing observation pads."""
+        if self.rollout_expert_indices is None:
+            return None
+        if not self.rollout_expert_indices:
+            return self.rollout_expert_indices
+        indices = list(self.rollout_expert_indices)
+        if self.added_eos:
+            layer_num = len(self.rollout_expert_indices[0])
+            topk = len(self.rollout_expert_indices[0][0]) if layer_num > 0 else 0
+            indices.append([[0] * topk for _ in range(layer_num)])
+        return indices
+
 
 class SkyRLGymGenerator(GeneratorInterface):
     def __init__(
@@ -317,10 +354,6 @@ async def agent_loop(
 
         while not agent_loop_state.done:
-            if len(agent_loop_state.input_ids) > max_input_length:
-                stop_reason = "length"
-                break
-
             # 1. Generate output
             if is_step_wise or retokenize_chat_history:
                 # re-apply whole chat template so length check is correct
@@ -335,6 +368,10 @@ async def agent_loop(
                 agent_loop_state.loss_mask = []
                 agent_loop_state.rollout_logprobs = None
 
+            if len(agent_loop_state.input_ids) > max_input_length:
+                stop_reason = "length"
+                break
+
             engine_input = InferenceEngineInput(
                 prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params
             )
@@ -403,13 +440,16 @@ async def agent_loop(
                 agent_loop_state.rollout_expert_indices = []
 
             if is_step_wise:
-                # current response + observation ids
-                turn_response_ids = turn_output.output_ids + turn_output.obs_ids
+                # Gen-only response: the obs tokens will appear in the *next* turn's prompt
+                # (or be recovered as `obs_delta` under `merge_stepwise_output`). Keeping obs
+                # out of `response_ids` caps per-step sequence length at
+                # ``max_input_length + max_generate_length`` — previously obs inflated each
+                # training sample by up to several thousand tokens for e.g. Search-R1.
+                turn_response_ids = turn_output.output_ids
                 turn_prompt_ids = agent_loop_state.input_ids
-                # agent loop only tracks loss mask and rollout logprobs for this turn with step_wise training
-                turn_loss_mask = turn_output.get_turn_loss_mask()
-                turn_response_logprobs: Optional[List[float]] = turn_output.get_turn_rollout_logprobs()
+                turn_loss_mask = turn_output.get_turn_gen_only_loss_mask()
+                turn_response_logprobs: Optional[List[float]] = turn_output.get_turn_gen_only_rollout_logprobs()
+
                 per_step_output = TrajectoryOutput(
                     response_ids=turn_response_ids,
@@ -419,7 +459,7 @@ async def agent_loop(
                     rollout_logprobs=turn_response_logprobs,
                     stop_reason=stop_reason,
                     env_metrics=env.get_metrics() if agent_loop_state.done else {},
-                    rollout_expert_indices=turn_output.get_turn_rollout_expert_indices(),
+                    rollout_expert_indices=turn_output.get_turn_gen_only_rollout_expert_indices(),
                 )
                 agent_loop_output.step_outputs.append(per_step_output)
@@ -494,9 +534,8 @@ async def agent_loop(
         assert response_ids is not None and loss_mask is not None
         if stop_reason != "length" and response_ids and response_ids[-1] != self.tokenizer.eos_token_id:
             response_ids.append(self.tokenizer.eos_token_id)
-            # TODO(Charlie): this should be 0? Otherwise logprobs will be extremely off. But if it is loss
-            # masked with 0, why bother adding it?
-            loss_mask.append(1)
+            # The EOS was not actually generated by the model, so mask it out.
+            loss_mask.append(0)
             if rollout_logprobs is not None:
                 rollout_logprobs.append(0.0)
             if rollout_expert_indices_out is not None and rollout_expert_indices_out:
diff --git a/tests/train/generators/test_skyrl_gym_generator.py b/tests/train/generators/test_skyrl_gym_generator.py
index cf973457ae..fbab844d31 100644
--- a/tests/train/generators/test_skyrl_gym_generator.py
+++ b/tests/train/generators/test_skyrl_gym_generator.py
@@ -304,7 +304,9 @@ def mock_generate(_):
 
     # No EOS: just add it
     expected_response_ids = mock_llm_output_ids + [mock_tokenizer.eos_token_id]
-    expected_loss_mask = [1] * (len(expected_response_ids))
+    # The trailing EOS in the singleturn path is always post-hoc appended (the model's EOS
+    # was stripped and re-added), so it carries no real logprob and must be masked out.
+    expected_loss_mask = [1] * (len(expected_response_ids) - 1) + [0]
 
     if logprobs_setting is not None:
         assert output.rollout_logprobs is not None
@@ -792,14 +794,15 @@ async def llm_generate_side_effect(input_batch):
 
     output_normal = await generator.generate(input_batch_normal)
 
-    # Verify normal response keeps original loss mask (all 1s)
+    # Verify normal response loss mask: the singleturn path strips the model's EOS and re-appends
+    # it post-hoc, and that appended EOS carries no real logprob so it is masked out (0).
     assert len(output_normal["loss_masks"]) == 1
    assert len(output_normal["loss_masks"][0]) == 3  # 3 response tokens (already includes EOS token)
     assert output_normal["loss_masks"][0] == [
         1,
         1,
-        1,
-    ], "Loss mask should remain as 1s for response ending with eos token"
+        0,
+    ], "Loss mask: real tokens = 1, post-hoc appended EOS = 0"
 
 
 @pytest.mark.asyncio
@@ -1494,11 +1497,13 @@ def step(self, action):
         assert isinstance(reward, list), f"rewards[{i}] should be a list (per-token rewards)"
         assert all(isinstance(r, (int, float)) for r in reward), f"rewards[{i}] should contain numeric values"
 
-    # Validate response_ids structure
+    # Validate response_ids are gen-only: obs tokens must not be appended to step-wise response.
+    # The mock LLM returns [10, 11, 12, eos=4] per turn; with obs excluded, that's exactly the
+    # expected per-step response.
     for i, response_ids in enumerate(generator_output["response_ids"]):
-        assert isinstance(response_ids, list), f"response_ids[{i}] should be a list"
-        assert len(response_ids) > 0, f"response_ids[{i}] should not be empty"
-        assert all(isinstance(token, int) for token in response_ids), f"response_ids[{i}] should contain integers"
+        assert response_ids == [10, 11, 12, mock_tokenizer.eos_token_id], (
+            f"response_ids[{i}] should be gen-only output tokens (no obs appended), got {response_ids}"
+        )
 
     # Validate loss_masks structure
     for i, loss_mask in enumerate(generator_output["loss_masks"]):
@@ -1511,3 +1516,97 @@ def step(self, action):
     # Validate stop_reasons
     for i, stop_reason in enumerate(generator_output["stop_reasons"]):
         assert isinstance(stop_reason, str), f"stop_reasons[{i}] should be a string"
+
+
+@pytest.mark.asyncio
+@patch("skyrl_gym.make")
+async def test_step_wise_trajectories_length_check_uses_current_prompt(
+    mock_make, mock_tokenizer, mock_llm, mock_env_cfg
+):
+    """Regression: step-wise length check must run AFTER the per-turn re-tokenization.
+
+    Previously, `len(agent_loop_state.input_ids) > max_input_length` was checked against the
+    stale input_ids left over from the previous turn's re-tokenize (which reflected chat_history
+    one turn behind). That let a turn whose freshly re-tokenized prompt exceeded `max_input_length`
+    still generate — producing a step with an over-length prompt_ids (observed as
+    `generate/batch_padded_seq_len` spikes up to ~10k in Search-R1 step-wise runs).
+    """
+    from skyrl.train.generators.base import TrajectoryID
+
+    mock_tokenizer.eos_token_id = 4
+
+    # apply_chat_template length grows with chat_history message count — 10 tokens per message.
+    def apply_chat_template_side_effect(messages, **kwargs):
+        if kwargs.get("tokenize", True):
+            return [7] * (10 * len(messages))
+        return "".join([m.get("content", "") for m in messages])
+
+    mock_tokenizer.apply_chat_template.side_effect = apply_chat_template_side_effect
+
+    async def llm_generate_side_effect(input_batch):
+        num = len(input_batch["prompt_token_ids"]) if "prompt_token_ids" in input_batch else len(input_batch["prompts"])
+        return {
+            "responses": ["step"] * num,
+            "stop_reasons": ["stop"] * num,
+            "response_logprobs": None,
+            "response_ids": [[10, 11, 12, mock_tokenizer.eos_token_id] for _ in range(num)],
+        }
+
+    mock_llm.generate = AsyncMock(side_effect=llm_generate_side_effect)
+
+    class NeverDoneEnv(BaseTextEnv):
+        def init(self, prompt):
+            return prompt, {}
+
+        def step(self, action):
+            return BaseTextEnvStepOutput(
+                observations=[{"role": "user", "content": "obs"}], reward=0.0, done=False, metadata={}
+            )
+
+    mock_make.side_effect = lambda *a, **k: NeverDoneEnv()
+
+    cfg = GeneratorConfig()
+    cfg.sampling_params.max_generate_length = 50
+    cfg.sampling_params.logprobs = None
+    cfg.apply_overlong_filtering = False
+    # Turn 1 prompt is 10 tokens (initial user only); turn 2 prompt is 30 tokens (user + assistant
+    # + user-obs). Choose a limit between the two: the fix must stop at turn 2 re-tokenize, not
+    # after generating a second step with a 30-token prompt.
+    cfg.max_input_length = 25
+    cfg.batched = False
+    cfg.max_turns = 10
+    cfg.zero_reward_on_non_stop = False
+    cfg.use_conversation_multi_turn = True
+    cfg.step_wise_trajectories = True
+    cfg.chat_template = ChatTemplateConfig(source="name", name_or_path=None)
+
+    generator = SkyRLGymGenerator(
+        generator_cfg=cfg,
+        skyrl_gym_cfg=mock_env_cfg,
+        inference_engine_client=mock_llm,
+        tokenizer=mock_tokenizer,
+    )
+    generator.base_conversation_token_ids = []
+
+    out = await generator.agent_loop(
+        [{"role": "user", "content": "Q?"}],
+        mock_env_cfg.env_class,
+        {},
+        max_tokens=50,
+        max_input_length=cfg.max_input_length,
+        trajectory_id=TrajectoryID(instance_id="uid1", repetition_id=0),
+    )
+
+    # Exactly one step emitted (turn 1). Turn 2's re-tokenized prompt (30) exceeds 25, so the
+    # loop must break before generating a second step.
+    assert len(out.step_outputs) == 1, (
+        f"Expected exactly 1 step before hitting max_input_length=25, got "
+        f"{len(out.step_outputs)}. Each recorded prompt length: "
+        f"{[len(s.prompt_ids) for s in out.step_outputs]}"
+    )
+    # And no step should have a prompt exceeding the limit.
+    for i, step in enumerate(out.step_outputs):
+        assert len(step.prompt_ids) <= cfg.max_input_length, (
+            f"step {i} prompt_ids length {len(step.prompt_ids)} exceeds max_input_length "
+            f"{cfg.max_input_length}"
+        )
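
Reviewer note: the gen-only helpers all enforce one alignment invariant per step-wise
sample: response_ids, loss_mask, and rollout_logprobs have equal length, and a post-hoc
appended EOS contributes mask 0 plus a 0.0 logprob stub so it never enters the loss. A
minimal standalone sketch of that invariant (the token ids and logprob values here are
made up for illustration; this is not SkyRL's actual TurnOutput class):

    # Sketch only: `output_ids`, `added_eos`, and `logprobs` are hypothetical.
    output_ids = [10, 11, 12, 4]        # 4 = EOS, appended post-hoc (not generated)
    added_eos = True
    logprobs = [-0.3, -0.9, -1.2, 0.0]  # agent_loop appends a 0.0 stub for the EOS

    # Mirrors get_turn_gen_only_loss_mask(): mask out the fake EOS.
    loss_mask = [1] * (len(output_ids) - 1) + [0] if added_eos else [1] * len(output_ids)

    assert len(output_ids) == len(loss_mask) == len(logprobs)
    assert loss_mask[-1] == 0 and logprobs[-1] == 0.0  # fake EOS is fully inert

    # get_turn_gen_only_rollout_expert_indices() preserves the same alignment for
    # MoE routing traces by appending one all-zero [layer][topk] pad for the EOS.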
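Reviewer note: the length-check reordering can also be seen in isolation. A runnable
sketch of why the check must run after the per-turn re-tokenization (the `tokenize`
helper and loop below are hypothetical stand-ins mirroring the mocks in the new
regression test, not the actual agent_loop):

    # Each message tokenizes to 10 tokens, as in the test's chat-template mock.
    def tokenize(messages):
        return [7] * (10 * len(messages))

    def run(chat_history, max_input_length, max_turns=10):
        prompt_lens = []
        for _ in range(max_turns):
            input_ids = tokenize(chat_history)     # fresh per-turn re-tokenization
            if len(input_ids) > max_input_length:  # check AFTER re-tokenizing (the fix)
                break
            prompt_lens.append(len(input_ids))     # "generate" this turn's step
            chat_history = chat_history + [
                {"role": "assistant", "content": "a"},
                {"role": "user", "content": "obs"},
            ]
        return prompt_lens

    # Turn 1 prompt = 10 tokens (under the 25 limit); turn 2 prompt = 30 -> stop.
    assert run([{"role": "user", "content": "Q?"}], max_input_length=25) == [10]
    # With the old order (check before re-tokenization), turn 2 would have been
    # judged by the stale 10-token length and emitted an over-length step.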