diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py
index a556a32a42..567add0dfc 100644
--- a/nemo_rl/experience/rollouts.py
+++ b/nemo_rl/experience/rollouts.py
@@ -311,7 +311,12 @@ def run_multi_turn_rollout(
                 >= max_seq_len
             ):
                 # truncate
-                tokenized_obs = tokenized_obs[: max_seq_len - active_input_lengths[i]]
+                # Clamp at zero: a negative bound would slice from the end and
+                # keep most of the observation instead of dropping it.
+                remaining_budget = max(
+                    0, max_seq_len - (len(generated_ids[i]) + active_input_lengths[i])
+                )
+                tokenized_obs = tokenized_obs[:remaining_budget]
                 truncation_mask[i] = True
                 # Record truncation
                 sample_truncated[active_indices[i]] = True
diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py
index b45811d4f8..bcfa1b84d2 100644
--- a/tests/unit/experience/test_rollouts.py
+++ b/tests/unit/experience/test_rollouts.py
@@ -20,6 +20,7 @@
 import torch
 from transformers import AutoTokenizer
 
+from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message
 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
 from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
 from nemo_rl.environments.games.sliding_puzzle import (
@@ -440,6 +441,45 @@ def test_run_multi_step_calculator_vllm(multi_step_setup_vllm):
     print("\nMulti-Step Calculator VLLM Test assertions passed.")
 
 
+@pytest.mark.skipif(
+    not torch.cuda.is_available() or torch.cuda.device_count() < 1,
+    reason="VLLM test requires at least 1 GPU",
+)
+def test_max_seqlen_respected(multi_step_setup_vllm):
+    """Tests that a multi-turn rollout (VLLM) is truncated to max_seq_len tokens."""
+    vllm_generation, rollout_tokenizer, task_to_env, initial_batch, rollout_cluster = (
+        multi_step_setup_vllm
+    )
+    max_rollout_turns = initial_batch["extra_env_info"][0]["max_steps"] + 1
+    max_seq_len = 290
+
+    print("\nRunning multi-step calculator rollout (VLLM)...")
+    vllm_generation.prepare_for_generation()
+    final_batch, rollout_metrics = run_multi_turn_rollout(
+        policy_generation=vllm_generation,
+        input_batch=initial_batch,
+        tokenizer=rollout_tokenizer,
+        task_to_env=task_to_env,
+        max_seq_len=max_seq_len,
+        max_rollout_turns=max_rollout_turns,
+    )
+    vllm_generation.finish_generation()
+    print("Multi-step calculator rollout complete (VLLM).")
+
+    # --- Assertions ---
+    assert isinstance(final_batch, BatchedDataDict)
+    assert "message_log" in final_batch
+    assert "total_reward" in final_batch
+    assert len(final_batch["message_log"]) == len(initial_batch["message_log"])
+    flattened_message_log, _ = batched_message_log_to_flat_message(
+        final_batch["message_log"]
+    )
+    # Check that the sequence length is respected by flattening the message log and checking the length
+    assert len(flattened_message_log["token_ids"][0]) == max_seq_len, (
+        f"Sequence length {len(flattened_message_log['token_ids'][0])} is not equal to max_seq_len {max_seq_len}"
+    )
+
+
 # --- Fixture for Sliding Puzzle Environment ---
 @pytest.fixture(scope="function")
 def sliding_puzzle_environment(rollout_cluster):