6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
@@ -25,7 +25,7 @@ jobs:
check_code_quality:
name: Check code quality
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
# if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
@@ -40,7 +40,7 @@ jobs:
name: Tests
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12', '3.13']
python-version: ['3.10'] # , '3.11', '3.12', '3.13']
fail-fast: false
runs-on:
group: aws-g4dn-2xlarge
@@ -50,7 +50,7 @@ jobs:
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
# if: github.event.pull_request.draft == false
steps:
- name: Git checkout
uses: actions/checkout@v4
255 changes: 147 additions & 108 deletions trl/trainer/grpo_trainer.py
@@ -1227,7 +1227,7 @@ def _generate_single_turn(self, prompts: list):
device = self.accelerator.device
mode = "train" if self.model.training else "eval"

# Generate completions using either vLLM or regular generation
# Step 1: vLLM preparation (if using vLLM, regardless of whether rollout_func is used)
if self.use_vllm:
if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
# wake up colocated vLLM instances if needed
@@ -1254,6 +1254,74 @@ def _generate_single_turn(self, prompts: list):
if isinstance(args, dict): # only convert dict → JSON string
call["function"]["arguments"] = json.dumps(args)

# Step 2: Generation dispatch - independent of vLLM preparation
if self.rollout_func is not None:
# Generate using custom rollout function
# The rollout_func is responsible for generation and may use vLLM internally via generate_rollout_completions()
if self.use_vllm and self.vllm_mode == "server":
# Server mode: gather all prompts, generate on main process, broadcast results
all_prompts = gather_object(prompts)
num_generations = self.num_generations if mode == "train" else self.num_generations_eval

if self.accelerator.is_main_process:
ordered_set_of_prompts = all_prompts[::num_generations]
rollout_prompts = ordered_set_of_prompts
if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}):
rollout_prompts = [
apply_chat_template({"prompt": p}, self.processing_class, **self.chat_template_kwargs)[
"prompt"
]
for p in rollout_prompts
]
output = self.rollout_func(rollout_prompts, self)
required_keys = {"prompt_ids", "completion_ids", "logprobs"}
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields)
else:
payload = None

# Broadcast the completions from the main process to all processes
obj_list = [payload]
broadcast_object_list(obj_list, from_process=0)
all_prompt_ids, all_completion_ids, all_logprobs, all_extra_fields = obj_list[0]

# At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]

process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
prompt_ids = all_prompt_ids[process_slice]
completion_ids = all_completion_ids[process_slice]
logprobs = all_logprobs[process_slice]

# Slice extra fields dict-of-lists per process
extra_fields = {}
for key, values in all_extra_fields.items():
if isinstance(values, list):
extra_fields[key] = values[process_slice]
else:
extra_fields[key] = values
else:
# Colocate or non-vllm mode: each process handles its own prompts
rollout_prompts = prompts
if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}):
rollout_prompts = [
apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)[
"prompt"
]
for prompt in rollout_prompts
]
output = self.rollout_func(rollout_prompts, self)
required_keys = {"prompt_ids", "completion_ids", "logprobs"}
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
prompt_ids = output["prompt_ids"]
completion_ids = output["completion_ids"]
logprobs = output["logprobs"]

elif self.use_vllm:
# Generate completions using vLLM (preparation already done above)
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
if self.vllm_mode == "server":
all_prompts = gather_object(prompts)
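Note on the new dispatch path above: a minimal, hypothetical sketch of the contract it assumes for a custom rollout function, shown for the colocate / non-vLLM branch where each process passes its own prompts. The function receives the (chat-templated) prompts plus the trainer and returns a dict with the required keys "prompt_ids", "completion_ids" and "logprobs"; any additional keys are collected into extra_fields for the reward functions. The function name, body and the "difficulty" field are illustrative only and not part of this PR.

def my_rollout_func(prompts, trainer):
    # Hypothetical sketch, not part of this diff: one entry per input prompt.
    prompt_ids, completion_ids, logprobs, difficulty = [], [], [], []
    for prompt in prompts:
        # A real implementation would tokenize `prompt` and run generation here
        # (possibly through vLLM); dummy token ids keep the sketch short.
        prompt_ids.append([101, 102, 103])
        completion_ids.append([7, 8, 9])
        logprobs.append([-0.1, -0.2, -0.3])  # one logprob per completion token
        difficulty.append(1.0)               # example of an extra, non-required key
    return {
        "prompt_ids": prompt_ids,            # required
        "completion_ids": completion_ids,    # required
        "logprobs": logprobs,                # required
        "difficulty": difficulty,            # collected into extra_fields above
    }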
@@ -1277,31 +1345,20 @@ def _generate_single_turn(self, prompts: list):
"generation_kwargs": self.args.generation_kwargs,
}
with profiling_context(self, "vLLM.generate"):
if self.rollout_func is not None:
rollout_prompts = ordered_set_of_prompts
if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}):
rollout_prompts = [
apply_chat_template(
{"prompt": p}, self.processing_class, **self.chat_template_kwargs
)["prompt"]
for p in rollout_prompts
]
output = self.rollout_func(rollout_prompts, self)
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
output = self.vllm_client.chat(
messages=ordered_set_of_prompts,
**sampling_params,
chat_template_kwargs=self.chat_template_kwargs,
tools=self.tools,
chat_template=self.chat_template,
)
else:
if is_conversational({"prompt": ordered_set_of_prompts[0]}):
output = self.vllm_client.chat(
messages=ordered_set_of_prompts,
**sampling_params,
chat_template_kwargs=self.chat_template_kwargs,
tools=self.tools,
chat_template=self.chat_template,
)
else:
output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
# Extract required fields and collect any extra fields for reward functions
required_keys = {"prompt_ids", "completion_ids", "logprobs"}
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields)
output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
# Extract required fields and collect any extra fields for reward functions
required_keys = {"prompt_ids", "completion_ids", "logprobs"}
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields)
else:
payload = None
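The indexing in the server-mode branches above is easy to misread, so here is a small self-contained sketch (made-up sizes) of how the gathered prompts are deduplicated with [::num_generations], how the returned prompt_ids are repeated to realign with the num_generations completions per prompt, and how each process then takes back its own slice.

# Self-contained sketch of the server-mode indexing (made-up sizes).
num_processes, prompts_per_process, num_generations = 2, 4, 2
all_prompts = ["A", "A", "B", "B", "C", "C", "D", "D"]  # gather_object result across 2 processes

# One prompt per generation group, as in `all_prompts[::num_generations]`.
ordered_set_of_prompts = all_prompts[::num_generations]  # ["A", "B", "C", "D"]

# Generation returns num_generations completions per unique prompt, so repeating
# each prompt entry num_generations times realigns prompts with completions.
all_prompt_ids = [p for p in ordered_set_of_prompts for _ in range(num_generations)]
# -> ["A", "A", "B", "B", "C", "C", "D", "D"]

# After the broadcast, each process keeps only its contiguous slice.
for process_index in range(num_processes):
    process_slice = slice(process_index * prompts_per_process, (process_index + 1) * prompts_per_process)
    print(process_index, all_prompt_ids[process_slice])
# 0 ['A', 'A', 'B', 'B']
# 1 ['C', 'C', 'D', 'D']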

@@ -1331,95 +1388,77 @@ def _generate_single_turn(self, prompts: list):

# Generate completions using colocated vLLM instances: each device holds a vLLM copy and works on its own batch of prompts
elif self.vllm_mode == "colocate":
if self.rollout_func is not None:
rollout_prompts = prompts
if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}):
rollout_prompts = [
apply_chat_template(
{"prompt": prompt}, self.processing_class, **self.chat_template_kwargs
)["prompt"]
for prompt in rollout_prompts
]
output = self.rollout_func(rollout_prompts, self)
required_keys = {"prompt_ids", "completion_ids", "logprobs"}
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
prompt_ids = output["prompt_ids"]
completion_ids = output["completion_ids"]
logprobs = output["logprobs"]
if self.guided_decoding_regex:
guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
else:
if self.guided_decoding_regex:
guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
else:
guided_decoding = None
guided_decoding = None

generation_kwargs = {
"n": 1, # vLLM on each GPU generates only 1 in colocate mode
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"guided_decoding": guided_decoding,
"logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only
}
if self.args.generation_kwargs is not None:
generation_kwargs.update(self.args.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)

generation_kwargs = {
"n": 1, # vLLM on each GPU generates only 1 in colocate mode
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"guided_decoding": guided_decoding,
"logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only
}
if self.args.generation_kwargs is not None:
generation_kwargs.update(self.args.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)

if self.vllm_tensor_parallel_size > 1:
# Gather prompts from all ranks in the TP group and flatten.
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
orig_size = len(prompts)
gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
all_prompts = [p for sublist in gathered_prompts for p in sublist]
else:
all_prompts = prompts

if self.args.vllm_enable_sleep_mode:
self.llm.wake_up(tags=["kv_cache"])
if self.vllm_tensor_parallel_size > 1:
# Gather prompts from all ranks in the TP group and flatten.
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
orig_size = len(prompts)
gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group)
all_prompts = [p for sublist in gathered_prompts for p in sublist]
else:
all_prompts = prompts

with profiling_context(self, "vLLM.generate"):
if is_conversational({"prompt": prompts[0]}):
all_outputs = self.llm.chat(
all_prompts,
sampling_params=sampling_params,
use_tqdm=False,
chat_template_kwargs=self.chat_template_kwargs,
tools=self.tools,
chat_template=self.chat_template,
)
else:
all_outputs = self.llm.generate(
all_prompts, sampling_params=sampling_params, use_tqdm=False
)
if self.args.vllm_enable_sleep_mode:
self.llm.wake_up(tags=["kv_cache"])

with profiling_context(self, "vLLM.generate"):
if is_conversational({"prompt": prompts[0]}):
all_outputs = self.llm.chat(
all_prompts,
sampling_params=sampling_params,
use_tqdm=False,
chat_template_kwargs=self.chat_template_kwargs,
tools=self.tools,
chat_template=self.chat_template,
)
else:
all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False)

all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
all_logprobs = [
[next(iter(lp.values())).logprob for lp in output.logprobs]
for outputs in all_outputs
for output in outputs.outputs
]
all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
all_logprobs = [
[next(iter(lp.values())).logprob for lp in output.logprobs]
for outputs in all_outputs
for output in outputs.outputs
]

if self.vllm_tensor_parallel_size > 1:
# Slice completions for this rank within its TP group.
# Each rank generates all outputs — we keep only our share.
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
prompt_ids = all_prompt_ids[tp_slice]
completion_ids = all_completion_ids[tp_slice]
logprobs = all_logprobs[tp_slice]
else:
prompt_ids = all_prompt_ids
completion_ids = all_completion_ids
logprobs = all_logprobs
if self.vllm_tensor_parallel_size > 1:
# Slice completions for this rank within its TP group.
# Each rank generates all outputs — we keep only our share.
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
prompt_ids = all_prompt_ids[tp_slice]
completion_ids = all_completion_ids[tp_slice]
logprobs = all_logprobs[tp_slice]
else:
prompt_ids = all_prompt_ids
completion_ids = all_completion_ids
logprobs = all_logprobs

extra_fields = {} # No extra fields for colocate mode
extra_fields = {} # No extra fields for colocate mode

if self.args.vllm_enable_sleep_mode:
self.llm.sleep(level=2)
if self.args.vllm_enable_sleep_mode:
self.llm.sleep(level=2)

elif self.use_transformers_paged:
if is_conversational({"prompt": prompts[0]}):
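The tensor-parallel handling in the colocate branch above follows a gather-then-slice pattern: each rank contributes its prompts to the TP group, every rank generates the full group's outputs, and tp_slice keeps only this rank's share. Below is a minimal, non-distributed sketch of just the slicing arithmetic (made-up sizes; the real code gathers with torch.distributed.all_gather_object on self.tp_group).

# Non-distributed sketch of the TP-group slicing (made-up sizes).
vllm_tensor_parallel_size = 2
orig_size = 3  # number of prompts each rank started with
# Pretend these are the completions produced for the whole gathered group.
all_completion_ids = [[i] for i in range(vllm_tensor_parallel_size * orig_size)]

for local_rank_in_group in range(vllm_tensor_parallel_size):
    tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
    print(local_rank_in_group, all_completion_ids[tp_slice])
# 0 [[0], [1], [2]]
# 1 [[3], [4], [5]]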
@@ -1507,7 +1546,7 @@ def _generate_single_turn(self, prompts: list):
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=True)]
completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)]
logprobs = None # not used in this case
extra_fields = {} # No extra fields for non-rollout_func paths
extra_fields = {} # No extra fields for regular generation

return prompt_ids, completion_ids, logprobs, extra_fields
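Taken together, all branches above normalize to the same return shape. A hypothetical illustration for a local batch of two prompts (values made up; logprobs is None on the regular-generation path, and extra_fields is empty unless extra keys were collected above):

# Hypothetical illustration of the returned tuple for two local prompts.
prompt_ids = [[101, 102], [101, 103]]          # one token-id list per prompt
completion_ids = [[7, 8, 9], [7, 10]]          # one token-id list per completion
logprobs = [[-0.1, -0.2, -0.3], [-0.1, -0.4]]  # per-token logprobs, or None for regular generation
extra_fields = {"difficulty": [1.0, 0.5]}      # dict of lists aligned with completions; {} when no extras are collected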
