@@ -167,6 +167,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
167
167
'top_k' : - 1 ,
168
168
'min_p' : 0.0 ,
169
169
'temperature' : 0 ,
170
+ 'n' : 1 # if greedy, only 1 response
170
171
}
171
172
172
173
# users can customize sampling_params differently for each run
@@ -177,13 +178,20 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
177
178
prompt_token_ids = idx_list ,
178
179
use_tqdm = False )
179
180
180
- response = output [0 ].to (idx .device ) # (bs, response_length)
181
- log_probs = output [1 ].to (idx .device ) # (bs, response_length)
181
+ # TODO(sgm): disable logprob when recompute_log_prob is enabled
182
+ # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
183
+ response = output [0 ].to (idx .device )
184
+ log_probs = output [1 ].to (idx .device )
182
185
183
186
if response .shape [1 ] < self .config .response_length :
184
187
response = pad_sequence_to_length (response , self .config .response_length , self .pad_token_id )
185
188
log_probs = pad_sequence_to_length (log_probs , self .config .response_length , self .pad_token_id )
186
189
190
+ if self .config .n > 1 and do_sample :
191
+ idx = idx .repeat_interleave (self .config .n , dim = 0 )
192
+ attention_mask = attention_mask .repeat_interleave (self .config .n , dim = 0 )
193
+ position_ids = position_ids .repeat_interleave (self .config .n , dim = 0 )
194
+ batch_size = batch_size * self .config .n
187
195
seq = torch .cat ([idx , response ], dim = - 1 )
188
196
189
197
response_length = response .size (1 )
0 commit comments