Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions verl/workers/rollout/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,12 @@ class AsyncRolloutRequest(BaseModel):
}
}

def get_generation_prompt(self, tokenizer: PreTrainedTokenizer) -> str:
def get_generation_prompt(self, tokenizer: PreTrainedTokenizer) -> list[int]:
return tokenizer.apply_chat_template( # type: ignore
conversation=[msg.model_dump() for msg in self.messages],
tools=[tool.model_dump() for tool in self.tools] if self.tools else None,
add_generation_prompt=True,
tokenize=False,
tokenize=True,
)

def add_assistant_message(
Expand Down
10 changes: 7 additions & 3 deletions verl/workers/rollout/sglang_rollout/async_sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,11 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo
else:
raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}")
elif _req.state == AsyncRolloutRequestStateEnum.RUNNING:
generation_prompt = _req.get_generation_prompt(self.tokenizer)
generation_prompt_ids = _req.get_generation_prompt(self.tokenizer)
max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1)
if max_new_tokens <= 0:
finish_reason_type = FinishReasonTypeEnum.STOP
break
if not do_sample:
kwargs = dict(
n=1,
Expand All @@ -494,7 +498,6 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo
top_k=-1,
ignore_eos=False,
min_new_tokens=0,
max_new_tokens=self.config.response_length,
skip_special_tokens=True,
spaces_between_special_tokens=True,
)
Expand All @@ -506,12 +509,13 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo
"temperature": self.config.val_kwargs.temperature,
"n": 1, # if validate, already repeat in ray_trainer
}
kwargs["max_new_tokens"] = max_new_tokens
if "n" not in kwargs or kwargs["n"] > 1: # group size is supported in preprocess
kwargs["n"] = 1
# users can customize different sampling_params at different run
with self.update_sampling_params(**kwargs):
output = await self._engine.async_generate(
prompt=generation_prompt,
input_ids=generation_prompt_ids,
sampling_params=self.sampling_params,
return_logprob=False,
)
Expand Down
Loading