From b766c17a4cbb9268d8563ddd21f6ad1cf8f7e3ff Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Wed, 30 Apr 2025 02:03:16 -0700 Subject: [PATCH 1/2] Fixed max seqlen not respected correctly Signed-off-by: Sahil Jain --- nemo_rl/experience/rollouts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index a556a32a42..1e23b0f9fb 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -311,7 +311,7 @@ def run_multi_turn_rollout( >= max_seq_len ): # truncate - tokenized_obs = tokenized_obs[: max_seq_len - active_input_lengths[i]] + tokenized_obs = tokenized_obs[: max_seq_len - (len(generated_ids[i]) + active_input_lengths[i])] truncation_mask[i] = True # Record truncation sample_truncated[active_indices[i]] = True From f8b739c66cdd322db676210b80c1098d6faa59f1 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Wed, 30 Apr 2025 12:14:24 -0700 Subject: [PATCH 2/2] Added unit test for max seqlen multiturn Signed-off-by: Sahil Jain --- nemo_rl/experience/rollouts.py | 4 ++- tests/unit/experience/test_rollouts.py | 40 ++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index 1e23b0f9fb..567add0dfc 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -311,7 +311,9 @@ def run_multi_turn_rollout( >= max_seq_len ): # truncate - tokenized_obs = tokenized_obs[: max_seq_len - (len(generated_ids[i]) + active_input_lengths[i])] + tokenized_obs = tokenized_obs[ + : max_seq_len - (len(generated_ids[i]) + active_input_lengths[i]) + ] truncation_mask[i] = True # Record truncation sample_truncated[active_indices[i]] = True diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py index b45811d4f8..bcfa1b84d2 100644 --- a/tests/unit/experience/test_rollouts.py +++ b/tests/unit/experience/test_rollouts.py @@ -20,6 +20,7 @@ 
import torch from transformers import AutoTokenizer +from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.environments.games.sliding_puzzle import ( @@ -440,6 +441,45 @@ def test_run_multi_step_calculator_vllm(multi_step_setup_vllm): print("\nMulti-Step Calculator VLLM Test assertions passed.") +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 1, + reason="VLLM test requires at least 1 GPU", +) +def test_max_seqlen_respected(multi_step_setup_vllm): + """Tests that max_seq_len is respected in a multi-step calculator rollout with VllmGeneration.""" + vllm_generation, rollout_tokenizer, task_to_env, initial_batch, rollout_cluster = ( + multi_step_setup_vllm + ) + max_rollout_turns = initial_batch["extra_env_info"][0]["max_steps"] + 1 + max_seq_len = 290 + + print("\nRunning multi-step calculator rollout (VLLM)...") + vllm_generation.prepare_for_generation() + final_batch, rollout_metrics = run_multi_turn_rollout( + policy_generation=vllm_generation, + input_batch=initial_batch, + tokenizer=rollout_tokenizer, + task_to_env=task_to_env, + max_seq_len=max_seq_len, + max_rollout_turns=max_rollout_turns, + ) + vllm_generation.finish_generation() + print("Multi-step calculator rollout complete (VLLM).") + + # --- Assertions --- + assert isinstance(final_batch, BatchedDataDict) + assert "message_log" in final_batch + assert "total_reward" in final_batch + assert len(final_batch["message_log"]) == len(initial_batch["message_log"]) + flattened_message_log, _ = batched_message_log_to_flat_message( + final_batch["message_log"] + ) + # Check that the sequence length is respected by flattening the message log and checking the length + assert len(flattened_message_log["token_ids"][0]) == max_seq_len, ( + f"Sequence length {len(flattened_message_log['token_ids'][0])} is not equal to max_seq_len {max_seq_len}" + ) + 
+ # --- Fixture for Sliding Puzzle Environment --- @pytest.fixture(scope="function") def sliding_puzzle_environment(rollout_cluster):