diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index 734a8fd977b..0aa9d540926 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -271,9 +271,6 @@ async def generate( max_new_tokens = sampling_params.pop("max_tokens") else: max_new_tokens = response_length - assert max_new_tokens <= response_length, ( - f"max_new_tokens {max_new_tokens} exceeds available response_length {response_length}" - ) sampling_params["max_new_tokens"] = max_new_tokens return_logprob = sampling_params.pop("logprobs", False) diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 759da0c705d..ef8010e70d9 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -470,9 +470,6 @@ async def generate( max_tokens = sampling_params.pop("max_new_tokens") else: max_tokens = response_length - assert max_tokens <= response_length, ( - f"max_tokens {max_tokens} exceeds available response_length {response_length}" - ) sampling_params["logprobs"] = 0 if sampling_params.pop("logprobs", False) else None sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0)) sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)