Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@
GPU_REQUIRES = ["liger-kernel", "flash-attn"]
MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency
VLLM_REQUIRES = ["tensordict<=0.6.2", "vllm<=0.8.5"]
SGLANG_REQUIRES = ["tensordict<=0.6.2", "sglang[srt,openai]==0.4.6.post4", "torch-memory-saver>=0.0.5", "torch==2.6.0"]
SGLANG_REQUIRES = [
"tensordict<=0.6.2",
"sglang[srt,openai]==0.4.6.post4",
"torch-memory-saver>=0.0.5",
"torch==2.6.0",
]

extras_require = {
"test": TEST_REQUIRES,
Expand Down
135 changes: 101 additions & 34 deletions verl/workers/rollout/sglang_rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from json import JSONDecodeError
from typing import TYPE_CHECKING
from uuid import uuid4

import math
import numpy as np
import torch
import torch.distributed as dist
Expand Down Expand Up @@ -213,7 +213,23 @@ def __init__(
model_hf_config.max_position_embeddings >= self.config.max_model_len
), "model context length should be greater than total sequence length"

# `max_turns` stands for max number of tool calls
# `max_turns` defines the maximum number of tool-calling
# rounds within a single request.

# In multi-turn RL scenarios, each request to the rollout
# engine can invoke tools multiple times.

# The entire multi-turn trajectory is then used for a
# single actor and critic training step before being
# discarded.

# Tool-calling ends under three conditions:
# 1. The `max_turns` limit is reached.
# 2. The trajectory exceeds the model's context length.
# 3. A tool returns a terminal state.

# If `max_turns` is not explicitly set, it defaults
# to `max_model_len // 3`, which is a large number.
if self.config.multi_turn.max_turns is None:
self.config.multi_turn.max_turns = self.config.max_model_len // 3

Expand Down Expand Up @@ -248,7 +264,9 @@ def __init__(
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(visible_devices_set)))

# initialize the inference engine
nnodes = -(-tp_size // len(visible_devices_set))
visible_devices_count = len(visible_devices_set)
nnodes = math.ceil(tp_size / visible_devices_count)

if nnodes > 1:
ip = get_ip()
port = get_open_port() if port is None else port
Expand All @@ -263,6 +281,19 @@ def __init__(
else:
dist_init_addr = None

# The format of the model weights to be loaded.

# “auto” will try to load the weights in the safetensors format and
# fall back to the pytorch bin format if safetensors format is not
# available.
# “pt” will load the weights in the pytorch bin format.
# “safetensors” will load the weights in the safetensors format.
# “dummy” will initialize the weights with random values, which is
# mainly for profiling.
# “bitsandbytes” will load the weights using bitsandbytes quantization.
# “npcache” will load the weights in pytorch format and store a numpy
# cache to speed up the loading.

load_format = (
"dummy" if config.load_format.startswith("dummy") else config.load_format
)
Expand Down Expand Up @@ -306,32 +337,50 @@ def __init__(
if self._tp_rank == 0:
self._engine.release_memory_occupation()

kwargs = dict(
n=1,
max_new_tokens=config.response_length,
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
)
# supporting adding any sampling params from the config file
# Initialize sampling parameters with default values
self.sampling_params = {
"n": 1,
"max_new_tokens": config.response_length,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"repetition_penalty": 1.0,
}

# Override defaults with any corresponding values from
# the config that are valid SGLang SamplingParams attributes.

for k in config.keys():
if hasattr(SamplingParams(), str(k)):
kwargs[k] = config.get(k)
print(f"kwargs: {kwargs}")
self.sampling_params = kwargs
self.sampling_params[k] = config.get(k)

print(f"Sampling parameters: {self.sampling_params}")

self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id

def _initialize_tools(self, config, tokenizer):
"""Initialize tools from configuration.
"""Initializes external tools.

Args:
config: Configuration object containing tool settings
tokenizer: Tokenizer instance for tool call parsing
config: Configuration object containing tool-related settings,
specifically `config.multi_turn.tool_config_path`.
tokenizer: The tokenizer instance used for parsing tool calls from
the model's generated text.

Returns:
tuple: (tool_schemas, tool_map, tool_call_parser_type, sgl_tools, function_call_parser)
tuple: A tuple containing:
- tool_schemas (list[dict]): OpenAI-formatted JSON schemas
defining each tool's capabilities.
- tool_map (dict[str, BaseTool]): A dictionary mapping tool
names to their executable `BaseTool` objects.
- tool_call_parser_type (str): The identifier for the specific
parser type (e.g., 'json_mode', 'tool_code') used to extract
tool calls.
- sgl_tools (list[sglang.srt.openai_api.protocol.Tool]): Tool
definitions optimized for SGLang's internal engine.
- function_call_parser (sglang.srt.function_call_parser.FunctionCallParser):
The active parser instance responsible for extracting
structured tool calls from model outputs.
"""
if config.multi_turn.tool_config_path is None:
return [], {}, None, [], None
Expand Down Expand Up @@ -363,7 +412,7 @@ def initialize_tools_from_config(tools_config) -> list:
tool_schema_dict = OmegaConf.to_container(
tool_config.tool_schema, resolve=True
)
tool_schema = OpenAIFunctionToolSchema.parse_obj(tool_schema_dict)
tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict)

tool = tool_cls(
config=OmegaConf.to_container(tool_config.config, resolve=True),
Expand Down Expand Up @@ -398,19 +447,38 @@ def initialize_tools_from_config(tools_config) -> list:

@contextmanager
def update_sampling_params(self, **kwargs):
# update sampling params
old_sampling_params_args = {}
if kwargs:
for key, value in kwargs.items():
if key in self.sampling_params:
old_value = self.sampling_params[key]
old_sampling_params_args[key] = old_value
self.sampling_params[key] = value
yield
# roll back to previous sampling params
# if len(old_sampling_params_args):
for key, value in old_sampling_params_args.items():
self.sampling_params[key] = value
"""
Temporarily updates the model's sampling parameters for the
duration of a `with` block. Parameters are automatically fall
back to their original values upon exiting the block.

Args:
**kwargs: Keyword arguments representing sampling parameters
to be updated. Only parameters that already exist in
`self.sampling_params` will be updated.
"""
# Store original values of parameters that will be updated
old_sampling_params_args = {
key: self.sampling_params[key]
for key in kwargs
if key in self.sampling_params
}

# Update sampling parameters with new values
for key, value in kwargs.items():
if key in self.sampling_params:
self.sampling_params[key] = value
else:
logger.warning(f"Sampling parameter {key} is not supported by SGLang.")

try:
yield
# Yield and execute the code within the 'with' block
finally:
# Always restore original values, even if an error occurred
# in the `with` block
for key, value in old_sampling_params_args.items():
self.sampling_params[key] = value

@GPUMemoryLogger(role="sglang rollout", logger=logger)
@torch.no_grad()
Expand Down Expand Up @@ -795,9 +863,8 @@ async def calc_reward_and_release_fn(name: str, tool: BaseTool):
@GPUMemoryLogger(role="sglang rollout", logger=logger)
@torch.no_grad()
def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataProto:
import warnings

warnings.warn(
logger.warning(
"`generate_sequences_with_tools` is deprecated, please use `generate_sequences(...)`",
DeprecationWarning,
stacklevel=2,
Expand Down