diff --git a/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py b/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py
index d45dbec890f..865019d0cec 100644
--- a/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py
+++ b/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py
@@ -73,8 +73,35 @@ async def _generate_step(
         image_data: Optional[list[Any]] = None,
         video_data: Optional[list[Any]] = None,
     ):
+        """Generate sequence with token-in-token-out."""
         prompt_ids = normalize_token_ids(prompt_ids)
-        max_tokens = self.config.max_model_len - len(prompt_ids)
+
+        # Calculate the maximum possible new tokens based on available context space
+        # This serves as a safety upper bound
+        max_possible_tokens = self.config.max_model_len - len(prompt_ids)
+        if max_possible_tokens < 0:
+            raise ValueError(
+                f"Prompt length ({len(prompt_ids)}) exceeds the model's maximum context length "
+                f"({self.config.max_model_len})."
+            )
+
+        # Determine max_tokens from sampling_params or use configured response_length as default
+        if "max_tokens" in sampling_params:
+            max_tokens = sampling_params.pop("max_tokens")
+        elif "max_new_tokens" in sampling_params:
+            # support sglang-style 'max_new_tokens' param
+            max_tokens = sampling_params.pop("max_new_tokens")
+        else:
+            # Normal case: len(prompt_ids) <= prompt_length → max_tokens = response_length (no reduction)
+            # Multi-turn / partial: len(prompt_ids) > prompt_length → previous responses are concatenated
+            # into prompt_ids, causing it to exceed the configured prompt_length
+            max_tokens = self.config.response_length - max(0, len(prompt_ids) - self.config.prompt_length)
+
+        # Clamp max_tokens to the valid range [0, max_possible_tokens]
+        max_tokens = max(0, min(max_tokens, max_possible_tokens))
+        assert max_tokens <= max_possible_tokens, (
+            f"max_tokens {max_tokens} exceeds available context space {max_possible_tokens}"
+        )
         sampling_params["logprobs"] = 1
         sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0))
         sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index 1e30a01beda..3411ae39be7 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -531,8 +531,10 @@ async def generate(
             # support sglang-style 'max_new_tokens' param
             max_tokens = sampling_params.pop("max_new_tokens")
         else:
-            # Default to a calculation that considers configured lengths
-            max_tokens = self.config.response_length + self.config.prompt_length - len(prompt_ids)
+            # Normal case: len(prompt_ids) <= prompt_length → max_tokens = response_length (no reduction)
+            # Multi-turn / partial: len(prompt_ids) > prompt_length → previous responses are concatenated
+            # into prompt_ids, causing it to exceed the configured prompt_length
+            max_tokens = self.config.response_length - max(0, len(prompt_ids) - self.config.prompt_length)

         # Clamp max_tokens to the valid range [0, max_possible_tokens]
         max_tokens = max(0, min(max_tokens, max_possible_tokens))
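
Both hunks implement the same budget rule: keep the full response_length in the single-turn case, shrink it only by the amount the prompt overflows the configured prompt_length (multi-turn / partial rollouts, where earlier responses are concatenated into prompt_ids), and always clamp to the space left in the context window. The following is a minimal standalone sketch of that rule, not code from the patch: compute_max_tokens is a hypothetical helper, and the config values in the asserts are illustrative assumptions.

from typing import Optional

def compute_max_tokens(
    prompt_len: int,
    prompt_length: int,       # configured max prompt length
    response_length: int,     # configured max response length
    max_model_len: int,       # model context window
    requested: Optional[int] = None,  # explicit max_tokens / max_new_tokens, if given
) -> int:
    # Hypothetical helper mirroring the logic in both hunks above.
    max_possible = max_model_len - prompt_len
    if max_possible < 0:
        raise ValueError(f"Prompt length ({prompt_len}) exceeds max_model_len ({max_model_len}).")
    if requested is None:
        # Reduce the response budget only by the amount the prompt
        # overflows its configured length (multi-turn / partial rollout).
        requested = response_length - max(0, prompt_len - prompt_length)
    # Clamp to the valid range [0, max_possible].
    return max(0, min(requested, max_possible))

# Illustrative configs (assumed): prompt_length=1024, response_length=512, max_model_len=2048
assert compute_max_tokens(800, 1024, 512, 2048) == 512                  # normal case: full budget
assert compute_max_tokens(1200, 1024, 512, 2048) == 336                 # multi-turn: 512 - (1200 - 1024)
assert compute_max_tokens(1900, 1024, 512, 2048, requested=500) == 148  # clamped to remaining context

Note that the removed formula in the second file, response_length + prompt_length - len(prompt_ids), agrees with the new one only when the prompt overflows prompt_length; for shorter prompts it would inflate the budget beyond response_length, which is exactly what the max(0, ...) guard in the new formula prevents.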