40 changes: 21 additions & 19 deletions trl/trainer/rloo_trainer.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import os
import textwrap
@@ -100,7 +101,6 @@
if is_bitsandbytes_available():
import bitsandbytes as bnb


logger = get_logger(__name__)

# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
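For context, a reward callable in this convention receives the prompts and completions and returns one score per completion. A toy sketch, assuming plain-string completions (the name and scoring rule are illustrative, not from this PR):

```python
def length_reward(prompts: list, completions: list, **kwargs) -> list[float]:
    # toy rule: prefer completions close to 200 characters
    return [1.0 - min(abs(len(c) - 200) / 200.0, 1.0) for c in completions]
```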
@@ -719,7 +719,6 @@ def _get_per_token_logps_and_entropies(

# Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}

if image_grid_thw is not None and pixel_values is not None:
rows_per_image = image_grid_thw.prod(dim=-1)
rows_per_sample = torch.split(rows_per_image, num_images)
@@ -754,7 +753,6 @@ def _get_per_token_logps_and_entropies(
# Divide logits by sampling temperature.
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
logits = logits / self.temperature

completion_ids = input_ids_batch[:, -logits_to_keep:]
logps = selective_log_softmax(logits, completion_ids) # compute logprobs
all_logps.append(logps)
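As a rough sketch of what these two steps compute: temperature scaling, then gathering the log-probability of each realized token. TRL's actual `selective_log_softmax` is chunked for memory, but this naive version is the logical equivalent (an assumption worth checking against the source):

```python
import torch
import torch.nn.functional as F

def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # log-softmax over the vocabulary, then keep only each realized token's log-prob
    logps = F.log_softmax(logits, dim=-1)                               # (B, T, V)
    return logps.gather(dim=-1, index=index.unsqueeze(-1)).squeeze(-1)  # (B, T)

# mirroring the diff: scale by the sampling temperature, then gather
# logits = logits / temperature
# logps = selective_log_softmax(logits, completion_ids)
```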
@@ -939,7 +937,6 @@ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation)
self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches]
inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
self._step += 1
else:
# In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence
# local generation batch == local eval batch
@@ -1022,7 +1019,8 @@ def _generate_single_turn(self, prompts: list):
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step

prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts]
if is_conversational({"prompt": prompts[0]}):
prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts]

# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
if self.vllm_mode == "server":
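The new `is_conversational({"prompt": prompts[0]})` guard distinguishes chat-style prompts (lists of role/content messages) from plain strings, so `prepare_multimodal_messages_vllm` only runs on conversational data. A minimal sketch of what such a check looks at (hypothetical helper, not TRL's exact implementation):

```python
def is_conversational_sketch(example: dict) -> bool:
    # chat-style data is a non-empty list of {"role": ..., "content": ...} messages
    prompt = example.get("prompt")
    return (
        isinstance(prompt, list)
        and len(prompt) > 0
        and isinstance(prompt[0], dict)
        and "role" in prompt[0]
    )

assert is_conversational_sketch({"prompt": [{"role": "user", "content": "Hi"}]})
assert not is_conversational_sketch({"prompt": "Hi"})
```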
@@ -1110,7 +1108,12 @@ def _generate_single_turn(self, prompts: list):

with profiling_context(self, "vLLM.generate"):
if is_conversational({"prompt": prompts[0]}):
all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False)
all_outputs = self.llm.chat(
all_prompts,
sampling_params=sampling_params,
use_tqdm=False,
chat_template_kwargs=self.chat_template_kwargs,
)
else:
all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)
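Standalone, the conversational branch reduces to a `vllm.LLM.chat` call like the sketch below; the model name and template flag are placeholders, and passing `chat_template_kwargs` assumes a vLLM version that supports it (as this PR does):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
sampling_params = SamplingParams(temperature=0.7, max_tokens=128)
outputs = llm.chat(
    [[{"role": "user", "content": "Hello!"}]],        # one conversation per prompt
    sampling_params=sampling_params,
    use_tqdm=False,
    chat_template_kwargs={"enable_thinking": False},  # illustrative kwarg
)
completion_text = outputs[0].outputs[0].text
```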

@@ -1215,8 +1218,18 @@ def _generate(self, prompts: list):
device = self.accelerator.device
mode = "train" if self.model.training else "eval"

# Copy the prompts to avoid modifying the original list
prompts = copy.deepcopy(prompts)

prompt_ids, completion_ids = self._generate_single_turn(prompts)

# Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
if is_conversational({"prompt": prompts[0]}):
contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
completions = [[{"role": "assistant", "content": content}] for content in contents]
else:
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)

# Get completion length per sequence, used for logging
prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device)
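The `copy.deepcopy(prompts)` added above guards against aliasing: conversational prompts are nested lists of dicts, so an in-place edit by a downstream helper would otherwise leak back into the caller's data. A minimal illustration in plain Python:

```python
import copy

prompt = [{"role": "user", "content": "hi"}]

safe = copy.deepcopy(prompt)
safe[0]["content"] = "edited"
assert prompt[0]["content"] == "hi"      # original untouched

shallow = list(prompt)                   # a shallow copy is not enough
shallow[0]["content"] = "edited"
assert prompt[0]["content"] == "edited"  # inner dict is shared
```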
@@ -1231,7 +1244,6 @@ def _generate(self, prompts: list):
self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

# Log completion lengths, mean, min, max
agg_completion_lengths = self.accelerator.gather(completion_lengths)
self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
@@ -1248,7 +1260,7 @@ def _generate(self, prompts: list):
self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())

return prompt_ids, completion_ids
return prompt_ids, completion_ids, completions

def _generate_and_score_completions(
self, inputs: list[dict[str, torch.Tensor | Any]]
@@ -1277,7 +1289,7 @@ def _generate_and_score_completions(
for prompt, image_list in zip(prompts, images, strict=True)
]

prompt_ids_list, completion_ids_list = self._generate(prompts)
prompt_ids_list, completion_ids_list, completions = self._generate(prompts)

# Convert lists of token IDs to padded tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -1368,16 +1380,6 @@ def _generate_and_score_completions(
# Decode
prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text, strict=True):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
if isinstance(bootstrap, list): # for VLM, the format might be [{"type": "text", "text": "..."}]
assert len(bootstrap) == 1 and bootstrap[0]["type"] == "text"
bootstrap = bootstrap[0]["text"]
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
completions = completions_text

# Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
# important because rewards will be normalized per group, and completions are distributed. We will later slice
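The comment above refers to group-wise normalization: in RLOO, each completion's baseline is typically the mean reward of the other completions sampled for the same prompt (leave-one-out). A hedged sketch, assuming `k` completions per prompt and a flat `rewards` tensor ordered by group:

```python
import torch

k = 4                                        # completions per prompt (assumption)
rewards = torch.randn(2 * k)                 # 2 prompts x k completions, flattened
grouped = rewards.view(-1, k)                # (num_prompts, k)
baseline = (grouped.sum(dim=1, keepdim=True) - grouped) / (k - 1)  # leave-one-out mean
advantages = (grouped - baseline).flatten()  # per-completion advantage
```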