Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nemo_rl/experience/rollouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def run_multi_turn_rollout(
>= max_seq_len
):
# truncate
tokenized_obs = tokenized_obs[: max_seq_len - active_input_lengths[i]]
tokenized_obs = tokenized_obs[: max_seq_len - (len(generated_ids[i]) + active_input_lengths[i])]
Comment thread
SahilJain314 marked this conversation as resolved.
Outdated
truncation_mask[i] = True
# Record truncation
sample_truncated[active_indices[i]] = True
Expand Down