diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9d571aec883..20b2579a415 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,7 +25,7 @@ jobs: check_code_quality: name: Check code quality runs-on: ubuntu-latest - if: github.event.pull_request.draft == false +# if: github.event.pull_request.draft == false steps: - uses: actions/checkout@v4 - name: Set up Python 3.12 @@ -40,7 +40,7 @@ jobs: name: Tests strategy: matrix: - python-version: ['3.10', '3.11', '3.12', '3.13'] + python-version: ['3.10'] # , '3.11', '3.12', '3.13'] fail-fast: false runs-on: group: aws-g4dn-2xlarge @@ -50,7 +50,7 @@ jobs: defaults: run: shell: bash - if: github.event.pull_request.draft == false +# if: github.event.pull_request.draft == false steps: - name: Git checkout uses: actions/checkout@v4 diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f5f333fa092..edb852cd84d 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1227,7 +1227,7 @@ def _generate_single_turn(self, prompts: list): device = self.accelerator.device mode = "train" if self.model.training else "eval" - # Generate completions using either vLLM or regular generation + # Step 1: vLLM preparation (if using vLLM, regardless of whether rollout_func is used) if self.use_vllm: if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: # wake up colocated vLLM instances if needed @@ -1254,6 +1254,74 @@ def _generate_single_turn(self, prompts: list): if isinstance(args, dict): # only convert dict → JSON string call["function"]["arguments"] = json.dumps(args) + # Step 2: Generation dispatch - independent of vLLM preparation + if self.rollout_func is not None: + # Generate using custom rollout function + # The rollout_func is responsible for generation and may use vLLM internally via generate_rollout_completions() + if self.use_vllm and self.vllm_mode == "server": + # Server mode: gather all prompts, generate on main process, broadcast results + all_prompts = gather_object(prompts) + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + + if self.accelerator.is_main_process: + ordered_set_of_prompts = all_prompts[::num_generations] + rollout_prompts = ordered_set_of_prompts + if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}): + rollout_prompts = [ + apply_chat_template({"prompt": p}, self.processing_class, **self.chat_template_kwargs)[ + "prompt" + ] + for p in rollout_prompts + ] + output = self.rollout_func(rollout_prompts, self) + required_keys = {"prompt_ids", "completion_ids", "logprobs"} + extra_fields = {k: v for k, v in output.items() if k not in required_keys} + payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields) + else: + payload = None + + # Broadcast the completions from the main process to all processes + obj_list = [payload] + broadcast_object_list(obj_list, from_process=0) + all_prompt_ids, all_completion_ids, all_logprobs, all_extra_fields = obj_list[0] + + # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times + all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)] + + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + prompt_ids = all_prompt_ids[process_slice] + completion_ids = all_completion_ids[process_slice] + logprobs = all_logprobs[process_slice] + + # Slice extra 
fields dict-of-lists per process + extra_fields = {} + for key, values in all_extra_fields.items(): + if isinstance(values, list): + extra_fields[key] = values[process_slice] + else: + extra_fields[key] = values + else: + # Colocate or non-vllm mode: each process handles its own prompts + rollout_prompts = prompts + if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}): + rollout_prompts = [ + apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)[ + "prompt" + ] + for prompt in rollout_prompts + ] + output = self.rollout_func(rollout_prompts, self) + required_keys = {"prompt_ids", "completion_ids", "logprobs"} + extra_fields = {k: v for k, v in output.items() if k not in required_keys} + prompt_ids = output["prompt_ids"] + completion_ids = output["completion_ids"] + logprobs = output["logprobs"] + + elif self.use_vllm: + # Generate completions using vLLM (preparation already done above) # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": all_prompts = gather_object(prompts) @@ -1277,31 +1345,20 @@ def _generate_single_turn(self, prompts: list): "generation_kwargs": self.args.generation_kwargs, } with profiling_context(self, "vLLM.generate"): - if self.rollout_func is not None: - rollout_prompts = ordered_set_of_prompts - if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}): - rollout_prompts = [ - apply_chat_template( - {"prompt": p}, self.processing_class, **self.chat_template_kwargs - )["prompt"] - for p in rollout_prompts - ] - output = self.rollout_func(rollout_prompts, self) + if is_conversational({"prompt": ordered_set_of_prompts[0]}): + output = self.vllm_client.chat( + messages=ordered_set_of_prompts, + **sampling_params, + chat_template_kwargs=self.chat_template_kwargs, + tools=self.tools, + chat_template=self.chat_template, + ) else: - if is_conversational({"prompt": ordered_set_of_prompts[0]}): - output = self.vllm_client.chat( - messages=ordered_set_of_prompts, - **sampling_params, - chat_template_kwargs=self.chat_template_kwargs, - tools=self.tools, - chat_template=self.chat_template, - ) - else: - output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) - # Extract required fields and collect any extra fields for reward functions - required_keys = {"prompt_ids", "completion_ids", "logprobs"} - extra_fields = {k: v for k, v in output.items() if k not in required_keys} - payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields) + output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) + # Extract required fields and collect any extra fields for reward functions + required_keys = {"prompt_ids", "completion_ids", "logprobs"} + extra_fields = {k: v for k, v in output.items() if k not in required_keys} + payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields) else: payload = None @@ -1331,95 +1388,77 @@ def _generate_single_turn(self, prompts: list): # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts elif self.vllm_mode == "colocate": - if self.rollout_func is not None: - rollout_prompts = prompts - if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}): - rollout_prompts = [ - apply_chat_template( - {"prompt": prompt}, self.processing_class, **self.chat_template_kwargs - )["prompt"] - for prompt in 
rollout_prompts - ] - output = self.rollout_func(rollout_prompts, self) - required_keys = {"prompt_ids", "completion_ids", "logprobs"} - extra_fields = {k: v for k, v in output.items() if k not in required_keys} - prompt_ids = output["prompt_ids"] - completion_ids = output["completion_ids"] - logprobs = output["logprobs"] + if self.guided_decoding_regex: + guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) else: - if self.guided_decoding_regex: - guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) - else: - guided_decoding = None + guided_decoding = None + + generation_kwargs = { + "n": 1, # vLLM on each GPU generates only 1 in colocate mode + "repetition_penalty": self.repetition_penalty, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "min_p": 0.0 if self.min_p is None else self.min_p, + "max_tokens": self.max_completion_length, + "guided_decoding": guided_decoding, + "logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only + } + if self.args.generation_kwargs is not None: + generation_kwargs.update(self.args.generation_kwargs) + sampling_params = SamplingParams(**generation_kwargs) - generation_kwargs = { - "n": 1, # vLLM on each GPU generates only 1 in colocate mode - "repetition_penalty": self.repetition_penalty, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "min_p": 0.0 if self.min_p is None else self.min_p, - "max_tokens": self.max_completion_length, - "guided_decoding": guided_decoding, - "logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only - } - if self.args.generation_kwargs is not None: - generation_kwargs.update(self.args.generation_kwargs) - sampling_params = SamplingParams(**generation_kwargs) - - if self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - orig_size = len(prompts) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group) - all_prompts = [p for sublist in gathered_prompts for p in sublist] - else: - all_prompts = prompts - - if self.args.vllm_enable_sleep_mode: - self.llm.wake_up(tags=["kv_cache"]) + if self.vllm_tensor_parallel_size > 1: + # Gather prompts from all ranks in the TP group and flatten. + # Each rank starts with its own prompts; after gathering, all ranks see the full group set. 
+ orig_size = len(prompts) + gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group) + all_prompts = [p for sublist in gathered_prompts for p in sublist] + else: + all_prompts = prompts - with profiling_context(self, "vLLM.generate"): - if is_conversational({"prompt": prompts[0]}): - all_outputs = self.llm.chat( - all_prompts, - sampling_params=sampling_params, - use_tqdm=False, - chat_template_kwargs=self.chat_template_kwargs, - tools=self.tools, - chat_template=self.chat_template, - ) - else: - all_outputs = self.llm.generate( - all_prompts, sampling_params=sampling_params, use_tqdm=False - ) + if self.args.vllm_enable_sleep_mode: + self.llm.wake_up(tags=["kv_cache"]) + + with profiling_context(self, "vLLM.generate"): + if is_conversational({"prompt": prompts[0]}): + all_outputs = self.llm.chat( + all_prompts, + sampling_params=sampling_params, + use_tqdm=False, + chat_template_kwargs=self.chat_template_kwargs, + tools=self.tools, + chat_template=self.chat_template, + ) + else: + all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False) - all_prompt_ids = [output.prompt_token_ids for output in all_outputs] - all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - all_logprobs = [ - [next(iter(lp.values())).logprob for lp in output.logprobs] - for outputs in all_outputs - for output in outputs.outputs - ] + all_prompt_ids = [output.prompt_token_ids for output in all_outputs] + all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + all_logprobs = [ + [next(iter(lp.values())).logprob for lp in output.logprobs] + for outputs in all_outputs + for output in outputs.outputs + ] - if self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. - local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - prompt_ids = all_prompt_ids[tp_slice] - completion_ids = all_completion_ids[tp_slice] - logprobs = all_logprobs[tp_slice] - else: - prompt_ids = all_prompt_ids - completion_ids = all_completion_ids - logprobs = all_logprobs + if self.vllm_tensor_parallel_size > 1: + # Slice completions for this rank within its TP group. + # Each rank generates all outputs — we keep only our share. 
+ local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + prompt_ids = all_prompt_ids[tp_slice] + completion_ids = all_completion_ids[tp_slice] + logprobs = all_logprobs[tp_slice] + else: + prompt_ids = all_prompt_ids + completion_ids = all_completion_ids + logprobs = all_logprobs - extra_fields = {} # No extra fields for colocate mode + extra_fields = {} # No extra fields for colocate mode - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=2) + if self.args.vllm_enable_sleep_mode: + self.llm.sleep(level=2) elif self.use_transformers_paged: if is_conversational({"prompt": prompts[0]}): @@ -1507,7 +1546,7 @@ def _generate_single_turn(self, prompts: list): prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=True)] completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)] logprobs = None # not used in this case - extra_fields = {} # No extra fields for non-rollout_func paths + extra_fields = {} # No extra fields for regular generation return prompt_ids, completion_ids, logprobs, extra_fields
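
Note on the custom rollout contract introduced above: rollout_func is invoked as rollout_func(prompts, trainer) with already chat-templated prompt strings, and must return a dict containing at least "prompt_ids", "completion_ids" and "logprobs" (token-id lists and per-token log-probabilities, one entry per completion); any additional keys are collected as extra_fields and forwarded to the reward functions. The sketch below is a minimal, hypothetical example of such a function that uses plain transformers generation instead of the vLLM-backed helper mentioned in the diff comment (generate_rollout_completions()). The model name, the function name my_rollout_func, and the completion_length extra field are illustrative assumptions, not part of this PR. It returns one completion per input prompt, which matches the colocate / non-vLLM dispatch path; in server mode the function receives one prompt per generation group and would be expected to return num_generations completions per prompt.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model choice; any causal LM with a matching tokenizer would do.
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)


def my_rollout_func(prompts: list[str], trainer) -> dict:
    """Generate completions for already chat-templated prompt strings.

    Assumes the trainer exposes max_completion_length and temperature,
    as the GRPO trainer does in this diff.
    """
    prompt_ids, completion_ids, logprobs, lengths = [], [], [], []
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=trainer.max_completion_length,
                do_sample=True,
                temperature=trainer.temperature,
                return_dict_in_generate=True,
                output_scores=True,
            )
        p_ids = inputs["input_ids"][0].tolist()
        c_ids = out.sequences[0][len(p_ids):].tolist()
        # Per-token log-probability of each sampled completion token,
        # computed from the (processed) scores returned by generate().
        step_logprobs = [
            torch.log_softmax(score[0], dim=-1)[tok].item()
            for score, tok in zip(out.scores, c_ids)
        ]
        prompt_ids.append(p_ids)
        completion_ids.append(c_ids)
        logprobs.append(step_logprobs)
        lengths.append(len(c_ids))
    return {
        "prompt_ids": prompt_ids,           # required
        "completion_ids": completion_ids,   # required
        "logprobs": logprobs,               # required
        "completion_length": lengths,       # extra field, forwarded to reward functions
    }

A function with this shape would then be passed to the trainer (assuming the rollout_func constructor argument shown in this PR), e.g. GRPOTrainer(..., rollout_func=my_rollout_func), and the extra "completion_length" values would arrive in the reward functions alongside the completions.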