diff --git a/verl/workers/rollout/schemas.py b/verl/workers/rollout/schemas.py index f43cfe02405..145fa7cf143 100644 --- a/verl/workers/rollout/schemas.py +++ b/verl/workers/rollout/schemas.py @@ -103,12 +103,12 @@ class AsyncRolloutRequest(BaseModel): } } - def get_generation_prompt(self, tokenizer: PreTrainedTokenizer) -> str: + def get_generation_prompt(self, tokenizer: PreTrainedTokenizer) -> list[int]: return tokenizer.apply_chat_template( # type: ignore conversation=[msg.model_dump() for msg in self.messages], tools=[tool.model_dump() for tool in self.tools] if self.tools else None, add_generation_prompt=True, - tokenize=False, + tokenize=True, ) def add_assistant_message( diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py index ada5eb3d2a5..3e8102483b8 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py @@ -482,7 +482,11 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo else: raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}") elif _req.state == AsyncRolloutRequestStateEnum.RUNNING: - generation_prompt = _req.get_generation_prompt(self.tokenizer) + generation_prompt_ids = _req.get_generation_prompt(self.tokenizer) + max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1) + if max_new_tokens <= 0: + finish_reason_type = FinishReasonTypeEnum.STOP + break if not do_sample: kwargs = dict( n=1, @@ -494,7 +498,6 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo top_k=-1, ignore_eos=False, min_new_tokens=0, - max_new_tokens=self.config.response_length, skip_special_tokens=True, spaces_between_special_tokens=True, ) @@ -506,12 +509,13 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo "temperature": self.config.val_kwargs.temperature, "n": 1, # if validate, already repeat in ray_trainer } + kwargs["max_new_tokens"] = max_new_tokens if "n" not in kwargs or kwargs["n"] > 1: # group size is supported in preprocess kwargs["n"] = 1 # users can customize different sampling_params at different run with self.update_sampling_params(**kwargs): output = await self._engine.async_generate( - prompt=generation_prompt, + input_ids=generation_prompt_ids, sampling_params=self.sampling_params, return_logprob=False, )