diff --git a/setup.py b/setup.py
index 6075bb2af8b..3e58f3aae9b 100644
--- a/setup.py
+++ b/setup.py
@@ -49,7 +49,12 @@
 GPU_REQUIRES = ["liger-kernel", "flash-attn"]
 MATH_REQUIRES = ["math-verify"]  # Add math-verify as an optional dependency
 VLLM_REQUIRES = ["tensordict<=0.6.2", "vllm<=0.8.5"]
-SGLANG_REQUIRES = ["tensordict<=0.6.2", "sglang[srt,openai]==0.4.6.post4", "torch-memory-saver>=0.0.5", "torch==2.6.0"]
+SGLANG_REQUIRES = [
+    "tensordict<=0.6.2",
+    "sglang[srt,openai]==0.4.6.post4",
+    "torch-memory-saver>=0.0.5",
+    "torch==2.6.0",
+]

 extras_require = {
     "test": TEST_REQUIRES,
diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
index ae85c3ad2fc..6a3ed180229 100644
--- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py
+++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -22,7 +22,7 @@
 from json import JSONDecodeError
 from typing import TYPE_CHECKING
 from uuid import uuid4
-
+import math
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -213,7 +213,23 @@ def __init__(
             model_hf_config.max_position_embeddings >= self.config.max_model_len
         ), "model context length should be greater than total sequence length"

-        # `max_turns` stands for max number of tool calls
+        # `max_turns` defines the maximum number of tool-calling
+        # rounds within a single request.
+
+        # In multi-turn RL scenarios, each request to the rollout
+        # engine can invoke tools multiple times.
+
+        # The entire multi-turn trajectory is then used for a
+        # single actor and critic training step before being
+        # discarded.
+
+        # Tool-calling ends under three conditions:
+        # 1. The `max_turns` limit is reached.
+        # 2. The trajectory exceeds the model's context length.
+        # 3. A tool returns a terminal state.
+
+        # If `max_turns` is not explicitly set, it defaults
+        # to `max_model_len // 3`, which is a large number.
         if self.config.multi_turn.max_turns is None:
             self.config.multi_turn.max_turns = self.config.max_model_len // 3
@@ -248,7 +264,9 @@ def __init__(
             os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(visible_devices_set)))

         # initialize the inference engine
-        nnodes = -(-tp_size // len(visible_devices_set))
+        visible_devices_count = len(visible_devices_set)
+        nnodes = math.ceil(tp_size / visible_devices_count)
+
         if nnodes > 1:
             ip = get_ip()
             port = get_open_port() if port is None else port
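
A quick aside on the `nnodes` hunk above: `-(-a // b)` and `math.ceil(a / b)` compute the same ceiling division for the positive integers involved here, so the rewrite is purely for readability. An illustrative sanity check, not part of the patch:

```python
import math

# Ceil division: the old negated floor-division idiom vs. the new math.ceil.
# Equivalent for positive operands such as tp_size and device counts.
for tp_size in (1, 2, 3, 7, 8, 16):
    for n_devices in (1, 2, 4, 8):
        assert -(-tp_size // n_devices) == math.ceil(tp_size / n_devices)
print("both ceil-division forms agree")
```
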
@@ -263,6 +281,19 @@ def __init__(
         else:
             dist_init_addr = None

+        # The format of the model weights to be loaded.
+
+        # "auto" will try to load the weights in the safetensors format and
+        # fall back to the pytorch bin format if safetensors format is not
+        # available.
+        # "pt" will load the weights in the pytorch bin format.
+        # "safetensors" will load the weights in the safetensors format.
+        # "dummy" will initialize the weights with random values, which is
+        # mainly for profiling.
+        # "bitsandbytes" will load the weights using bitsandbytes quantization.
+        # "npcache" will load the weights in pytorch format and store a numpy
+        # cache to speed up the loading.
+
         load_format = (
             "dummy" if config.load_format.startswith("dummy") else config.load_format
         )
@@ -306,32 +337,50 @@ def __init__(
         if self._tp_rank == 0:
             self._engine.release_memory_occupation()

-        kwargs = dict(
-            n=1,
-            max_new_tokens=config.response_length,
-            presence_penalty=0.0,
-            frequency_penalty=0.0,
-            repetition_penalty=1.0,
-        )
-        # supporting adding any sampling params from the config file
+        # Initialize sampling parameters with default values
+        self.sampling_params = {
+            "n": 1,
+            "max_new_tokens": config.response_length,
+            "presence_penalty": 0.0,
+            "frequency_penalty": 0.0,
+            "repetition_penalty": 1.0,
+        }
+
+        # Override defaults with any corresponding values from
+        # the config that are valid SGLang SamplingParams attributes.
+
         for k in config.keys():
             if hasattr(SamplingParams(), str(k)):
-                kwargs[k] = config.get(k)
-        print(f"kwargs: {kwargs}")
-        self.sampling_params = kwargs
+                self.sampling_params[k] = config.get(k)
+
+        print(f"Sampling parameters: {self.sampling_params}")

         self.tokenizer = tokenizer
         self.pad_token_id = tokenizer.pad_token_id

     def _initialize_tools(self, config, tokenizer):
-        """Initialize tools from configuration.
+        """Initializes external tools.

         Args:
-            config: Configuration object containing tool settings
-            tokenizer: Tokenizer instance for tool call parsing
+            config: Configuration object containing tool-related settings,
+                specifically `config.multi_turn.tool_config_path`.
+            tokenizer: The tokenizer instance used for parsing tool calls from
+                the model's generated text.

         Returns:
-            tuple: (tool_schemas, tool_map, tool_call_parser_type, sgl_tools, function_call_parser)
+            tuple: A tuple containing:
+                - tool_schemas (list[dict]): OpenAI-formatted JSON schemas
+                  defining each tool's capabilities.
+                - tool_map (dict[str, BaseTool]): A dictionary mapping tool
+                  names to their executable `BaseTool` objects.
+                - tool_call_parser_type (str): The identifier for the specific
+                  parser type (e.g., 'json_mode', 'tool_code') used to extract
+                  tool calls.
+                - sgl_tools (list[sglang.srt.openai_api.protocol.Tool]): Tool
+                  definitions optimized for SGLang's internal engine.
+                - function_call_parser (sglang.srt.function_call_parser.FunctionCallParser):
+                  The active parser instance responsible for extracting
+                  structured tool calls from model outputs.
         """
         if config.multi_turn.tool_config_path is None:
             return [], {}, None, [], None
@@ -363,7 +412,7 @@ def initialize_tools_from_config(tools_config) -> list:
             tool_schema_dict = OmegaConf.to_container(
                 tool_config.tool_schema, resolve=True
             )
-            tool_schema = OpenAIFunctionToolSchema.parse_obj(tool_schema_dict)
+            tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict)

             tool = tool_cls(
                 config=OmegaConf.to_container(tool_config.config, resolve=True),
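
The `parse_obj` to `model_validate` change in the hunk above follows the Pydantic v2 API, where `parse_obj` is deprecated. A minimal sketch of the same call pattern, assuming Pydantic v2 is installed; `ToolSchema` here is a hypothetical stand-in, not verl's `OpenAIFunctionToolSchema`:

```python
from pydantic import BaseModel

class ToolSchema(BaseModel):  # hypothetical stand-in for OpenAIFunctionToolSchema
    type: str
    function: dict

payload = {"type": "function", "function": {"name": "calculator"}}

# Pydantic v1 spelling, deprecated in v2:
# schema = ToolSchema.parse_obj(payload)

# Pydantic v2 spelling, as used by the patch:
schema = ToolSchema.model_validate(payload)
print(schema.function["name"])  # -> calculator
```
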
@@ -398,19 +447,38 @@
     @contextmanager
     def update_sampling_params(self, **kwargs):
-        # update sampling params
-        old_sampling_params_args = {}
-        if kwargs:
-            for key, value in kwargs.items():
-                if key in self.sampling_params:
-                    old_value = self.sampling_params[key]
-                    old_sampling_params_args[key] = old_value
-                    self.sampling_params[key] = value
-        yield
-        # roll back to previous sampling params
-        # if len(old_sampling_params_args):
-        for key, value in old_sampling_params_args.items():
-            self.sampling_params[key] = value
+        """
+        Temporarily updates the model's sampling parameters for the
+        duration of a `with` block. Parameters automatically revert
+        to their original values upon exiting the block.
+
+        Args:
+            **kwargs: Keyword arguments representing sampling parameters
+                to be updated. Only parameters that already exist in
+                `self.sampling_params` will be updated.
+        """
+        # Store original values of parameters that will be updated
+        old_sampling_params_args = {
+            key: self.sampling_params[key]
+            for key in kwargs
+            if key in self.sampling_params
+        }
+
+        # Update sampling parameters with new values
+        for key, value in kwargs.items():
+            if key in self.sampling_params:
+                self.sampling_params[key] = value
+            else:
+                logger.warning(f"Sampling parameter {key} is not supported by SGLang.")
+
+        try:
+            # Execute the code within the `with` block
+            yield
+        finally:
+            # Always restore original values, even if an error occurred
+            # in the `with` block
+            for key, value in old_sampling_params_args.items():
+                self.sampling_params[key] = value

     @GPUMemoryLogger(role="sglang rollout", logger=logger)
     @torch.no_grad()
@@ -795,9 +863,7 @@ async def calc_reward_and_release_fn(name: str, tool: BaseTool):
     @GPUMemoryLogger(role="sglang rollout", logger=logger)
     @torch.no_grad()
     def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataProto:
-        import warnings

-        warnings.warn(
+        logger.warning(
             "`generate_sequences_with_tools` is deprecated, please use `generate_sequences(...)`",
-            DeprecationWarning,
             stacklevel=2,
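
Note on the last hunk: `logging.Logger.warning` accepts `stacklevel` (Python 3.8+), but a bare positional `DeprecationWarning` would be treated as a %-format argument and break log formatting at emit time, so it is dropped along with `warnings.warn`.

The reworked `update_sampling_params` restores overrides in a `finally` block, so parameters revert even if the body raises. A minimal standalone sketch of that behavior; `DummyRollout` is hypothetical, not a verl class:

```python
from contextlib import contextmanager

class DummyRollout:  # hypothetical stand-in for SGLangRollout
    def __init__(self):
        self.sampling_params = {"n": 1, "temperature": 1.0}

    @contextmanager
    def update_sampling_params(self, **kwargs):
        # Save only the keys we will actually override.
        old = {k: self.sampling_params[k] for k in kwargs if k in self.sampling_params}
        self.sampling_params.update({k: v for k, v in kwargs.items() if k in old})
        try:
            yield
        finally:
            # Restore originals even when the `with` body raises.
            self.sampling_params.update(old)

r = DummyRollout()
try:
    with r.update_sampling_params(temperature=0.2):
        assert r.sampling_params["temperature"] == 0.2  # overridden inside the block
        raise RuntimeError("simulated failure inside the block")
except RuntimeError:
    pass
assert r.sampling_params["temperature"] == 1.0  # restored despite the error
```
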