Skip to content

Commit 2997e5d

Browse files
authored
[rollout] feat: support best-of-n generation in vLLM (Jiayi-Pan#80)
1 parent 8a2a39b commit 2997e5d

File tree

6 files changed

+22
-2
lines changed

6 files changed

+22
-2
lines changed

examples/split_placement/config/ppo_trainer_split.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ actor_rollout_ref:
6666
log_prob_micro_batch_size: 128
6767
# for hf rollout
6868
do_sample: True
69+
# number of responses (i.e. number of sampling times per prompt)
70+
n: 1 # > 1 for grpo
6971

7072
critic:
7173
strategy: fsdp

verl/trainer/config/ppo_megatron_trainer.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ actor_rollout_ref:
7272
layer_name_map:
7373
qkv_layer_name: qkv
7474
gate_proj_layer_name: gate_up
75+
# number of responses (i.e. number of sampling times per prompt)
76+
n: 1
7577

7678
critic:
7779
strategy: megatron

verl/trainer/config/ppo_trainer.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ actor_rollout_ref:
6666
log_prob_micro_batch_size: 128
6767
# for hf rollout
6868
do_sample: True
69+
# number of responses (i.e. number of sampling times per prompt)
70+
n: 1 # > 1 for grpo
6971

7072
critic:
7173
strategy: fsdp

verl/trainer/ppo/ray_trainer.py

+2
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,8 @@ def fit(self):
486486
with _timer('gen', timing_raw):
487487
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
488488

489+
# repeat to align with repeated responses in rollout
490+
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
489491
batch = batch.union(gen_batch_output)
490492

491493
if self.use_reference_policy:

verl/workers/fsdp_workers.py

+4
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,14 @@ def __init__(self, config: DictConfig, role: str):
8181
if self._is_actor:
8282
self.config.actor.ppo_mini_batch_size //= self.device_mesh.shape[0]
8383
self.config.actor.ppo_micro_batch_size //= self.device_mesh.shape[0]
84+
self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
85+
self.config.actor.ppo_micro_batch_size *= self.config.rollout.n
8486
if self._is_rollout:
8587
self.config.rollout.log_prob_micro_batch_size //= self.device_mesh.shape[0]
88+
self.config.rollout.log_prob_micro_batch_size *= self.config.rollout.n
8689
if self._is_ref:
8790
self.config.ref.log_prob_micro_batch_size //= self.device_mesh.shape[0]
91+
self.config.ref.log_prob_micro_batch_size *= self.config.rollout.n
8892

8993
def _build_model_optimizer(self,
9094
model_path,

verl/workers/rollout/vllm_rollout/vllm_rollout.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
167167
'top_k': -1,
168168
'min_p': 0.0,
169169
'temperature': 0,
170+
'n': 1 # if greedy, only 1 response
170171
}
171172

172173
# users can customize different sampling_params at different run
@@ -177,13 +178,20 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
177178
prompt_token_ids=idx_list,
178179
use_tqdm=False)
179180

180-
response = output[0].to(idx.device) # (bs, response_length)
181-
log_probs = output[1].to(idx.device) # (bs, response_length)
181+
# TODO(sgm): disable logprob when recompute_log_prob is enabled
182+
# if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
183+
response = output[0].to(idx.device)
184+
log_probs = output[1].to(idx.device)
182185

183186
if response.shape[1] < self.config.response_length:
184187
response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
185188
log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)
186189

190+
if self.config.n > 1 and do_sample:
191+
idx = idx.repeat_interleave(self.config.n, dim=0)
192+
attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
193+
position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
194+
batch_size = batch_size * self.config.n
187195
seq = torch.cat([idx, response], dim=-1)
188196

189197
response_length = response.size(1)

0 commit comments

Comments
 (0)