diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py index 84eba8a8821..369754747e8 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py @@ -48,4 +48,4 @@ def __init__( trust_remote_code=trust_remote_code, device_mesh=device_mesh, **kwargs, - ) \ No newline at end of file + ) diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index ad7d5439ed4..d33f6d24685 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -40,11 +40,18 @@ from verl import DataProto from verl.third_party.sglang import parallel_state as sglang_ps from verl.tools.base_tool import BaseTool -from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall +from verl.tools.schemas import ( + OpenAIFunctionCallSchema, + OpenAIFunctionParsedSchema, + OpenAIFunctionToolCall, +) from verl.utils.debug import GPUMemoryLogger from verl.utils.model import compute_position_id_with_mask from verl.utils.net_utils import is_ipv6 -from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length +from verl.utils.torch_functional import ( + get_response_mask, + pad_sequence_to_length, +) from verl.workers.rollout.base import BaseRollout from verl.workers.rollout.schemas import ( AsyncRolloutRequest, @@ -61,12 +68,16 @@ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) -# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding. -def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[int]: +# NOTE(sgm): add for verl. We can optimize it by making +# the dataloader yield List[int] without padding. 
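+# Example: with pad_token_id=0 and prompt_token_ids=torch.tensor([0, 0, 5, 6]),
+# this helper returns [5, 6].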
+def _pre_process_inputs( + pad_token_id, + prompt_token_ids: torch.Tensor, +) -> list[int]: # remove the left padding in the prompt token_id - # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is - # not None else self.llm_engine.tokenizer.eos_token_id - non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] + non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][ + 0 + ] token_ids = prompt_token_ids[non_pad_index:].tolist() return token_ids @@ -74,14 +85,11 @@ def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[in # NOTE(linjunrong): adhoc def _post_process_outputs(tokenizer, output): def _map_each_response(resp): - log_probs = [] - output_token_ids = [] - for log_prob, token_ids, _ in resp["meta_info"]["output_token_logprobs"]: - log_probs.append(log_prob) - output_token_ids.append(token_ids) - log_probs = torch.tensor(log_probs) - output_token_ids = torch.tensor(output_token_ids) - return output_token_ids, log_probs + output_token_logprobs = resp["meta_info"]["output_token_logprobs"] + log_probs, output_token_ids = zip( + *[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs] + ) + return torch.tensor(output_token_ids), torch.tensor(log_probs) out_map = map(lambda x: _map_each_response(x), output) batched_output_token_ids = [] @@ -89,17 +97,28 @@ def _map_each_response(resp): for output_token_ids, log_probs in out_map: batched_output_token_ids.append(output_token_ids) batched_logprobs.append(log_probs) - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id - batched_output_token_ids = pad_sequence(batched_output_token_ids, batch_first=True, padding_value=pad_token_id) + pad_token_id = ( + tokenizer.pad_token_id + if tokenizer.pad_token_id is not None + else tokenizer.eos_token_id + ) + batched_output_token_ids = pad_sequence( + batched_output_token_ids, batch_first=True, padding_value=pad_token_id + ) if len(batched_logprobs) > 0: - batched_logprobs = pad_sequence(batched_logprobs, batch_first=True, padding_value=pad_token_id) + batched_logprobs = pad_sequence( + batched_logprobs, batch_first=True, padding_value=pad_token_id + ) return batched_output_token_ids, batched_logprobs def get_tool_call_parser_type(tokenizer: PreTrainedTokenizer) -> str: - for parser_type, parser_cls in FunctionCallParser.ToolCallParserEnum.items(): + items = FunctionCallParser.ToolCallParserEnum.items() + for parser_type, parser_cls in items: parser = parser_cls() - if parser.bot_token in tokenizer.get_vocab() and (parser.eot_token == "" or parser.eot_token in tokenizer.get_vocab()): + if parser.bot_token in tokenizer.get_vocab() and ( + parser.eot_token == "" or parser.eot_token in tokenizer.get_vocab() + ): return parser_type else: raise ValueError(f"No tool call parser found for tokenizer {tokenizer}") @@ -108,7 +127,7 @@ def get_tool_call_parser_type(tokenizer: PreTrainedTokenizer) -> str: class SGLangRollout(BaseRollout): def __init__( self, - actor_module: nn.Module | str, + actor_module: str, config: DictConfig, tokenizer, model_hf_config, @@ -117,31 +136,65 @@ def __init__( device_mesh: DeviceMesh | None = None, **kwargs, ): - """A SGLang rollout. It requires the module is supported by the SGLang. + """Synchronized SGLang rollout engine. 
Args:
-            actor_module: module here follows huggingface APIs
-            config: DictConfig
-            tokenizer: the task/model tokenizer
-            model_hf_config: the huggingface config to initiallize the generating model in SGLang
-            **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
+            actor_module: Huggingface model name or path to the model. The
+                model should be supported by SGLang.
+            config: A DictConfig object containing SGLang-specific operational
+                parameters and rollout settings.
+                Refer to https://docs.sglang.ai/backend/server_arguments.html
+            tokenizer: The tokenizer instance compatible with the actor_module.
+            model_hf_config: The Hugging Face model's configuration (e.g.,
+                `transformers.PretrainedConfig`). It provides architectural
+                details and hyperparameters like `max_position_embeddings`,
+                used by SGLang for correct model initialization. This is
+                the model's inherent design, not SGLang's runtime behavior.
+            port: Optional port for multi-node initialization when nnodes > 1.
+            trust_remote_code: Whether or not to allow for custom models
+                defined on the Hub in their own modeling files.
+            device_mesh: Optional `DeviceMesh` object for distributed setup.
+            **kwargs: Additional keyword arguments, primarily `train_tp` for
+                Megatron Backend integration to initialize hybrid engine
+                process groups.
         """
         super().__init__()
         self.config = config
         self._device_mesh_cpu = device_mesh
         os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true")
 
-        self._tool_schemas, self._tool_map, self._tool_call_parser_type, self._sgl_tools, self._function_call_parser = self._initialize_tools(config, tokenizer)
-        assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine"
+        (
+            self._tool_schemas,
+            self._tool_map,
+            self._tool_call_parser_type,
+            self._sgl_tools,
+            self._function_call_parser,
+        ) = self._initialize_tools(config, tokenizer)
+        # If `free_cache_engine` is enabled, the SGLang engine's KV cache
+        # is freed after each `generate_sequences` call.
+        assert not (
+            not config.enforce_eager and config.free_cache_engine
+        ), "disable CUDA graph (enforce_eager = False) if free cache engine"
         tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
-        assert tensor_parallel_size <= dist.get_world_size(), "tensor parallel size should be less than or equal to the world size"
+        assert (
+            tensor_parallel_size <= dist.get_world_size()
+        ), "tensor parallel size should be less than or equal to world size"
         if kwargs.get("train_tp", None) is not None:
-            # deployed with megatron
+            # `train_tp` is the total tensor parallelism size for Megatron.
+            # `tensor_parallel_size` is the SGLang tensor parallel size.
+            # `num_tp_per_train_tp` is how many SGLang tensor parallel
+            # groups fit into one Megatron training tensor parallel group.
+
+            # This is crucial for aligning SGLang's internal parallel
+            # groups with Megatron's distributed setup, where SGLang
+            # operates within a subset of Megatron's TP.
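+            # For example, with train_tp=8 and tensor_parallel_size=2, one
+            # Megatron TP group of 8 GPUs hosts num_tp_per_train_tp = 4
+            # SGLang TP groups of 2 GPUs each.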
+ os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" os.environ["MEGATRON_IMPORT_TIMERS"] = "0" train_tp = kwargs.get("train_tp", None) + num_tp_per_train_tp = train_tp // tensor_parallel_size sglang_ps.initialize_parallel_state( tensor_model_parallel_size=tensor_parallel_size, @@ -149,11 +202,24 @@ def __init__( ) if not self.config.get("max_model_len", None): - self.config.max_model_len = self.config.prompt_length + self.config.response_length - assert self.config.max_model_len >= self.config.prompt_length + self.config.response_length, f"""max_model_len should be greater than total sequence length (prompt_length + response_length): - {self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}""" - assert model_hf_config.max_position_embeddings >= self.config.max_model_len, "model context length should be greater than total sequence length" - # currently max_turns stand for max number of tool calls + self.config.max_model_len = ( + self.config.prompt_length + self.config.response_length + ) + assert ( + self.config.max_model_len + >= self.config.prompt_length + self.config.response_length + ), f"""max_model_len should be greater than total sequence length + (prompt_length + response_length): {self.config.max_model_len} >= + {self.config.prompt_length} + {self.config.response_length}""" + assert ( + model_hf_config.max_position_embeddings >= self.config.max_model_len + ), "model context length should be greater than total sequence length" + + # TODO(chenyang): is `max_turns` the max number of tool call rounds? + # If so, I think it should be a small number like 3 or 5. + # self.config.max_model_len // 3 is a large number. + + # `max_turns` stands for max number of tool calls if self.config.multi_turn.max_turns is None: self.config.multi_turn.max_turns = self.config.max_model_len // 3 @@ -167,7 +233,10 @@ def __init__( mesh_dim_names=["dp", "tp", "pp"], ) - self._device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs) + self._device_mesh_cpu = init_device_mesh( + "cpu", + **device_mesh_kwargs, + ) self._rank = self._device_mesh_cpu.get_rank() self._tp_rank = self._device_mesh_cpu["tp"].get_local_rank() @@ -176,11 +245,16 @@ def __init__( # get tp_rank of this process in this tp group visible_devices = [None] * self._device_mesh_cpu.size(1) - torch.distributed.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], self._device_mesh_cpu.get_group("tp")) + torch.distributed.all_gather_object( + visible_devices, + os.environ["CUDA_VISIBLE_DEVICES"], + self._device_mesh_cpu.get_group("tp"), + ) visible_devices_set = set(",".join(visible_devices).split(",")) os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(visible_devices_set))) # initialize the inference engine + # TODO(chenyang): Strange way to calculate nnodes. 
         nnodes = -(-tp_size // len(visible_devices_set))
         if nnodes > 1:
             ip = get_ip()
@@ -196,7 +270,9 @@ def __init__(
         else:
             dist_init_addr = None
 
-        load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format
+        load_format = (
+            "dummy" if config.load_format.startswith("dummy") else config.load_format
+        )
         tp_size_per_node = self._tp_size // nnodes
         node_rank = self._tp_rank // tp_size_per_node
         first_rank_in_node = self._tp_rank % tp_size_per_node == 0
@@ -217,12 +293,14 @@ def __init__(
                 dist_init_addr=dist_init_addr,
                 nnodes=nnodes,
                 trust_remote_code=trust_remote_code,
-                # NOTE(linjunrong): add rank to prevent SGLang generate same port inside PortArgs.init_new
+                # NOTE(linjunrong): add rank to prevent SGLang from
+                # generating the same port inside PortArgs.init_new
                 # when random.seed is being set during training
                 port=30000 + rank,
-                # NOTE(Chenyang): if you want to debug the SGLang engine output
-                # please set the following parameters
-                # Otherwise, it will make the engine run too slow
+                # NOTE(Chenyang): if you want to debug the SGLang
+                # engine output, please set these parameters.
+                # Do not set them in production, as they will
+                # make the engine run too slow.
                 # log_level="INFO",
                 # log_requests=True,
                 # log_requests_level=2,
@@ -289,10 +367,15 @@ def initialize_tools_from_config(tools_config) -> list:
 
                 tool_cls = getattr(module, class_name)
 
-                tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True)
+                tool_schema_dict = OmegaConf.to_container(
+                    tool_config.tool_schema, resolve=True
+                )
                 tool_schema = OpenAIFunctionToolSchema.parse_obj(tool_schema_dict)
 
-                tool = tool_cls(config=OmegaConf.to_container(tool_config.config, resolve=True), tool_schema=tool_schema)
+                tool = tool_cls(
+                    config=OmegaConf.to_container(tool_config.config, resolve=True),
+                    tool_schema=tool_schema,
+                )
                 tool_list.append(tool)
 
             return tool_list
@@ -301,7 +384,9 @@ def initialize_tools_from_config(tools_config) -> list:
         tools_config = OmegaConf.load(tools_config_file)
         tool_list = initialize_tools_from_config(tools_config)
 
-        tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list]
+        tool_schemas = [
+            tool.get_openai_tool_schema().model_dump() for tool in tool_list
+        ]
         tool_map = {tool.name: tool for tool in tool_list}
         tool_call_parser_type = get_tool_call_parser_type(tokenizer)
         sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas]
@@ -310,7 +395,13 @@ def initialize_tools_from_config(tools_config) -> list:
             tool_call_parser_type,
         )
 
-        return tool_schemas, tool_map, tool_call_parser_type, sgl_tools, function_call_parser
+        return (
+            tool_schemas,
+            tool_map,
+            tool_call_parser_type,
+            sgl_tools,
+            function_call_parser,
+        )
 
     @contextmanager
     def update_sampling_params(self, **kwargs):
@@ -337,7 +428,9 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
 
     @GPUMemoryLogger(role="sglang rollout", logger=logger)
     @torch.no_grad()
-    def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
+    def _batch_level_generate_sequences(
+        self, prompts: DataProto, **kwargs
+    ) -> DataProto:
         idx = prompts.batch["input_ids"]  # (bs, prompt_length)
         # left-padded attention_mask
         attention_mask = prompts.batch["attention_mask"]
@@ -351,31 +444,51 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP
         # Extract non-tensor data
         non_tensor_batch = prompts.non_tensor_batch
         if "raw_prompt_ids" not in non_tensor_batch:
-            non_tensor_batch["raw_prompt_ids"] =
np.array([_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object) + non_tensor_batch["raw_prompt_ids"] = np.array( + [ + _pre_process_inputs(self.pad_token_id, idx[i]) + for i in range(batch_size) + ], + dtype=object, + ) if "multi_modal_data" in non_tensor_batch: sglang_inputs = [] - for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")): + for raw_prompt_ids, multi_modal_data in zip( + non_tensor_batch.pop("raw_prompt_ids"), + non_tensor_batch.pop("multi_modal_data"), + ): sglang_inputs.append( { "prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data, - "image_data": multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None, + "image_data": ( + multi_modal_data.get("image", None) + if isinstance(multi_modal_data, dict) + else None + ), } ) else: - sglang_inputs = [{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")] + sglang_inputs = [ + {"prompt_token_ids": raw_prompt_ids} + for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") + ] # Ensure token IDs are lists for input_data in sglang_inputs: if isinstance(input_data["prompt_token_ids"], np.ndarray): input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist() elif not isinstance(input_data["prompt_token_ids"], list): - raise TypeError(f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}") + raise TypeError( + f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}" + ) # Extract token IDs and image data for SGLang Engine idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs] - image_list = [input_data.get("image_data", None) for input_data in sglang_inputs] + image_list = [ + input_data.get("image_data", None) for input_data in sglang_inputs + ] do_sample = prompts.meta_info.get("do_sample", True) is_validate = prompts.meta_info.get("validate", False) @@ -432,24 +545,34 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP # log_probs = out[1].to(idx.device) if response.shape[1] < self.config.response_length: - response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) + response = pad_sequence_to_length( + response, self.config.response_length, self.pad_token_id + ) # log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id) # utilize current sampling params if self.sampling_params.get("n", 1) > 1 and do_sample: idx = idx.repeat_interleave(self.sampling_params["n"], dim=0) - attention_mask = attention_mask.repeat_interleave(self.sampling_params["n"], dim=0) - position_ids = position_ids.repeat_interleave(self.sampling_params["n"], dim=0) + attention_mask = attention_mask.repeat_interleave( + self.sampling_params["n"], dim=0 + ) + position_ids = position_ids.repeat_interleave( + self.sampling_params["n"], dim=0 + ) batch_size = batch_size * self.sampling_params["n"] _non_tensor_batch = {} for key, val in non_tensor_batch.items(): - _non_tensor_batch[key] = np.repeat(val, self.sampling_params["n"], axis=0) + _non_tensor_batch[key] = np.repeat( + val, self.sampling_params["n"], axis=0 + ) else: _non_tensor_batch = non_tensor_batch seq = torch.cat([idx, response], dim=-1) response_length = response.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) + delta_position_id = torch.arange( + 1, 
response_length + 1, device=position_ids.device + ) delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1) # TODO(sgm): fix position_ids on right_pad @@ -458,7 +581,9 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] response_position_ids = position_ids[:, -1:] + delta_position_id position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype) + response_attention_mask = get_response_mask( + response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype + ) attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) # all the tp ranks should contain the same data here. data in all ranks are valid @@ -480,7 +605,13 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP return DataProto(batch=batch, non_tensor_batch=_non_tensor_batch) - async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bool = True, is_validate: bool = False, **kwargs) -> AsyncRolloutRequest: + async def _async_rollout_a_request( + self, + req: AsyncRolloutRequest, + do_sample: bool = True, + is_validate: bool = False, + **kwargs, + ) -> AsyncRolloutRequest: assert self._tp_rank == 0, "only the master process can call this function" _req = deepcopy(req) finish_reason_type = None @@ -493,8 +624,12 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo tool_creation_coroutines = [] for tool_schema in _req.tools: tool = self._tool_map[tool_schema.function.name] - create_kwargs = _req.tools_kwargs[tool.name].get("create_kwargs", {}) - tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs)) + create_kwargs = _req.tools_kwargs[tool.name].get( + "create_kwargs", {} + ) + tool_creation_coroutines.append( + tool.create(_req.request_id, **create_kwargs) + ) await asyncio.gather(*tool_creation_coroutines) _req.state = AsyncRolloutRequestStateEnum.RUNNING elif _req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING: @@ -505,13 +640,22 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo self._tool_map[tool_call.function.name].execute( _req.request_id, tool_call.function.arguments, - **_req.tools_kwargs[tool_call.function.name].get("execute_kwargs", {}), + **_req.tools_kwargs[tool_call.function.name].get( + "execute_kwargs", {} + ), ) for tool_call in parsed_tool_calls ] ) - for i, (tool_call, (resp, reward, metrics)) in enumerate(zip(parsed_tool_calls, tool_call_results)): - _req.add_tool_response_message(self.tokenizer, resp, (i == len(parsed_tool_calls) - 1), format=self.config.multi_turn.format) + for i, (tool_call, (resp, reward, metrics)) in enumerate( + zip(parsed_tool_calls, tool_call_results) + ): + _req.add_tool_response_message( + self.tokenizer, + resp, + (i == len(parsed_tool_calls) - 1), + format=self.config.multi_turn.format, + ) if len(_req.input_ids) >= self.config.max_model_len: break if len(_req.input_ids) >= self.config.max_model_len: @@ -519,7 +663,9 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo break _req.state = AsyncRolloutRequestStateEnum.RUNNING else: - raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}") + raise ValueError( + f"Unexpected tool calling last message state: {_req.messages[-1]}" + ) elif _req.state == 
AsyncRolloutRequestStateEnum.RUNNING: generation_prompt = _req.get_generation_prompt(self.tokenizer) if not do_sample: @@ -545,7 +691,9 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo "temperature": self.config.val_kwargs.temperature, "n": 1, # if validate, already repeat in ray_trainer } - if "n" not in kwargs or kwargs["n"] > 1: # group size is supported in preprocess + if ( + "n" not in kwargs or kwargs["n"] > 1 + ): # group size is supported in preprocess kwargs["n"] = 1 # users can customize different sampling_params at different run with self.update_sampling_params(**kwargs): @@ -556,17 +704,29 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo ) content = output["text"] - finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"]) + finish_reason_type = FinishReasonTypeEnum.from_str( + output["meta_info"]["finish_reason"]["type"] + ) current_turns += 1 if finish_reason_type == FinishReasonTypeEnum.LENGTH: - _req.add_assistant_message(self.tokenizer, content, already_over_long=True, format=self.config.multi_turn.format) + _req.add_assistant_message( + self.tokenizer, + content, + already_over_long=True, + format=self.config.multi_turn.format, + ) break else: - if self._function_call_parser and self._function_call_parser.has_tool_call(content): + if ( + self._function_call_parser + and self._function_call_parser.has_tool_call(content) + ): finish_reason_type = FinishReasonTypeEnum.TOOL_CALL _req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING try: - normed_content, tool_calls = self._function_call_parser.parse_non_stream(content) + normed_content, tool_calls = ( + self._function_call_parser.parse_non_stream(content) + ) except JSONDecodeError: normed_content = content tool_calls = [] @@ -575,7 +735,14 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo tool_calls = [] parsed_tool_calls = [] for tool_call in tool_calls: - function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema(OpenAIFunctionParsedSchema(name=tool_call.name, arguments=tool_call.parameters)) + function, has_decode_error = ( + OpenAIFunctionCallSchema.from_openai_function_parsed_schema( + OpenAIFunctionParsedSchema( + name=tool_call.name, + arguments=tool_call.parameters, + ) + ) + ) # Drop the tool call if its arguments has decode error if has_decode_error: continue @@ -593,12 +760,20 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo format=self.config.multi_turn.format, ) else: - _req.add_assistant_message(self.tokenizer, content, format=self.config.multi_turn.format) + _req.add_assistant_message( + self.tokenizer, + content, + format=self.config.multi_turn.format, + ) finish_reason_type = FinishReasonTypeEnum.STOP _req.state = AsyncRolloutRequestStateEnum.COMPLETED break else: - _req.add_assistant_message(self.tokenizer, content, format=self.config.multi_turn.format) + _req.add_assistant_message( + self.tokenizer, + content, + format=self.config.multi_turn.format, + ) break if current_turns >= self.config.multi_turn.max_turns: @@ -606,8 +781,12 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo # Calculate the reward for each tool async def calc_reward_and_release_fn(name: str, tool: BaseTool): - reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get("calc_reward_kwargs", {})) - await tool.release(_req.request_id, 
**_req.tools_kwargs[name].get("release_kwargs", {})) + reward = await tool.calc_reward( + _req.request_id, **_req.tools_kwargs[name].get("calc_reward_kwargs", {}) + ) + await tool.release( + _req.request_id, **_req.tools_kwargs[name].get("release_kwargs", {}) + ) return name, reward tool_reward_tasks = [] @@ -619,10 +798,12 @@ async def calc_reward_and_release_fn(name: str, tool: BaseTool): _req.finalize(self.tokenizer, tool_reward_scores, finish_reason_type) return _req + @GPUMemoryLogger(role="sglang rollout", logger=logger) @torch.no_grad() def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataProto: import warnings + warnings.warn( "`generate_sequences_with_tools` is deprecated, please use `generate_sequences(...)`", DeprecationWarning, @@ -645,10 +826,17 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( asyncio.gather( - *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list], + *[ + self._async_rollout_a_request( + req, do_sample, is_validate, **kwargs + ) + for req in req_list + ], ) ) - sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset)) + sorted_output_req_list = sorted( + output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset) + ) else: sorted_output_req_list = None @@ -667,8 +855,15 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro messages = [] reward_scores = [] for req in sorted_output_req_list: - assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed" - assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), f"""Request {req.request_id} has different length of + assert ( + req.state == AsyncRolloutRequestStateEnum.COMPLETED + ), f"Request {req.request_id} is not completed" + assert ( + len(req.input_ids) + == len(req.attention_mask) + == len(req.position_ids) + == len(req.loss_mask) + ), f"""Request {req.request_id} has different length of {len(req.input_ids)=}, {len(req.attention_mask)=}, {len(req.position_ids)=}, {len(req.loss_mask)=}""" error_message_lines = [ f"""Request {req.request_id} has input_ids length {len(req.input_ids)} @@ -682,50 +877,114 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro error_message = "\n".join(error_message_lines) assert len(req.input_ids) <= self.config.max_model_len, error_message - prompt_ids.append(torch.tensor(req.prompt_ids, dtype=torch.int, device=tgt_device)) - response_ids.append(torch.tensor(req.response_ids, dtype=torch.int, device=tgt_device)) + prompt_ids.append( + torch.tensor(req.prompt_ids, dtype=torch.int, device=tgt_device) + ) + response_ids.append( + torch.tensor(req.response_ids, dtype=torch.int, device=tgt_device) + ) if len(req.response_ids) > self.config.response_length: print( f"""{req.request_id=} has response_ids length {len(req.response_ids)} greater than max_response_len {self.config.response_length},\n{req=}""" ) - prompt_attention_mask.append(torch.tensor(req.prompt_attention_mask, dtype=torch.int, device=tgt_device)) - response_attention_mask.append(torch.tensor(req.response_attention_mask, dtype=torch.int, device=tgt_device)) - prompt_position_ids.append(torch.tensor(req.prompt_position_ids, dtype=torch.int, device=tgt_device)) - response_position_ids.append(torch.tensor(req.response_position_ids, dtype=torch.int, 
device=tgt_device)) - prompt_loss_mask.append(torch.tensor(req.prompt_loss_mask, dtype=torch.int, device=tgt_device)) - response_loss_mask.append(torch.tensor(req.response_loss_mask, dtype=torch.int, device=tgt_device)) + prompt_attention_mask.append( + torch.tensor( + req.prompt_attention_mask, dtype=torch.int, device=tgt_device + ) + ) + response_attention_mask.append( + torch.tensor( + req.response_attention_mask, dtype=torch.int, device=tgt_device + ) + ) + prompt_position_ids.append( + torch.tensor( + req.prompt_position_ids, dtype=torch.int, device=tgt_device + ) + ) + response_position_ids.append( + torch.tensor( + req.response_position_ids, dtype=torch.int, device=tgt_device + ) + ) + prompt_loss_mask.append( + torch.tensor(req.prompt_loss_mask, dtype=torch.int, device=tgt_device) + ) + response_loss_mask.append( + torch.tensor(req.response_loss_mask, dtype=torch.int, device=tgt_device) + ) messages.append({"messages": req.messages}) reward_scores.append(req.reward_scores) - prompt_ids = pad_sequence(prompt_ids, batch_first=True, padding_value=self.pad_token_id, padding_side="left") + prompt_ids = pad_sequence( + prompt_ids, + batch_first=True, + padding_value=self.pad_token_id, + padding_side="left", + ) if prompt_ids.shape[1] < self.config.prompt_length: - prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True) - response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id) + prompt_ids = pad_sequence_to_length( + prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True + ) + response_ids = pad_sequence( + response_ids, batch_first=True, padding_value=self.pad_token_id + ) if response_ids.shape[1] < self.config.response_length: - response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id) - prompt_attention_mask = pad_sequence(prompt_attention_mask, batch_first=True, padding_value=0, padding_side="left") + response_ids = pad_sequence_to_length( + response_ids, self.config.response_length, self.pad_token_id + ) + prompt_attention_mask = pad_sequence( + prompt_attention_mask, + batch_first=True, + padding_value=0, + padding_side="left", + ) if prompt_attention_mask.shape[1] < self.config.prompt_length: - prompt_attention_mask = pad_sequence_to_length(prompt_attention_mask, self.config.prompt_length, 0, left_pad=True) - response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0) + prompt_attention_mask = pad_sequence_to_length( + prompt_attention_mask, self.config.prompt_length, 0, left_pad=True + ) + response_attention_mask = pad_sequence( + response_attention_mask, batch_first=True, padding_value=0 + ) if response_attention_mask.shape[1] < self.config.response_length: - response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0) - prompt_position_ids = pad_sequence(prompt_position_ids, batch_first=True, padding_value=0, padding_side="left") + response_attention_mask = pad_sequence_to_length( + response_attention_mask, self.config.response_length, 0 + ) + prompt_position_ids = pad_sequence( + prompt_position_ids, batch_first=True, padding_value=0, padding_side="left" + ) if prompt_position_ids.shape[1] < self.config.prompt_length: - prompt_position_ids = pad_sequence_to_length(prompt_position_ids, self.config.prompt_length, 0, left_pad=True) + prompt_position_ids = pad_sequence_to_length( + prompt_position_ids, self.config.prompt_length, 0, left_pad=True + ) 
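+        # Prompts are left-padded and responses right-padded, so the
+        # torch.cat calls below produce [pad..., prompt | response, ...pad]
+        # with the real tokens contiguous at the junction.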
         response_length = response_ids.size(1)
-        delta_position_id = torch.arange(1, response_length + 1, device=response_ids.device)
-        delta_position_id = delta_position_id.unsqueeze(0).repeat(len(sorted_output_req_list), 1)
+        delta_position_id = torch.arange(
+            1, response_length + 1, device=response_ids.device
+        )
+        delta_position_id = delta_position_id.unsqueeze(0).repeat(
+            len(sorted_output_req_list), 1
+        )
         response_position_ids = prompt_position_ids[:, -1:] + delta_position_id
 
-        prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left")
+        prompt_loss_mask = pad_sequence(
+            prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left"
+        )
         if prompt_loss_mask.shape[1] < self.config.prompt_length:
-            prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True)
-        response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0)
+            prompt_loss_mask = pad_sequence_to_length(
+                prompt_loss_mask, self.config.prompt_length, 0, left_pad=True
+            )
+        response_loss_mask = pad_sequence(
+            response_loss_mask, batch_first=True, padding_value=0
+        )
         if response_loss_mask.shape[1] < self.config.response_length:
-            response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0)
+            response_loss_mask = pad_sequence_to_length(
+                response_loss_mask, self.config.response_length, 0
+            )
 
         input_ids = torch.cat((prompt_ids, response_ids), dim=-1)
-        attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
+        attention_mask = torch.cat(
+            (prompt_attention_mask, response_attention_mask), dim=-1
+        )
         position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1)
         loss_mask = torch.cat((prompt_loss_mask, response_loss_mask), dim=-1)
@@ -743,13 +1002,27 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro
         )
 
         # free cache engine
-        if self.config.free_cache_engine and self._engine is not None and self._tp_rank == 0:
+        if (
+            self.config.free_cache_engine
+            and self._engine is not None
+            and self._tp_rank == 0
+        ):
             self._engine.flush_cache()
 
-        return DataProto(batch=batch, non_tensor_batch={"messages": np.array(messages), "reward_scores": np.array(reward_scores)})
+        return DataProto(
+            batch=batch,
+            non_tensor_batch={
+                "messages": np.array(messages),
+                "reward_scores": np.array(reward_scores),
+            },
+        )
 
-    def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int) -> list[AsyncRolloutRequest]:
-        assert "raw_prompt" in prompts.non_tensor_batch, "need data.return_raw_chat=True, due to no official way do parse_messages"
+    def _preprocess_prompt_to_async_rollout_requests(
+        self, prompts: DataProto, n: int
+    ) -> list[AsyncRolloutRequest]:
+        assert (
+            "raw_prompt" in prompts.non_tensor_batch
+        ), "need data.return_raw_chat=True, because there is no official way to parse_messages"
         req_list = []
         for data_idx, raw_prompt in enumerate(prompts.non_tensor_batch["raw_prompt"]):
             for rollout_offset in range(n):
@@ -765,10 +1038,16 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in
                         tokenize=False,
                         return_tensors="pt",
                     )
-                    input_data = self.tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False)
+                    input_data = self.tokenizer(
+                        prompt_with_chat_template,
+                        return_tensors="pt",
+                        add_special_tokens=False,
+                    )
                     _input_ids = input_data["input_ids"][0].tolist()
                     _attention_mask = input_data["attention_mask"][0].tolist()
-                    _position_ids =
compute_position_id_with_mask(input_data["attention_mask"][0]).tolist() + _position_ids = compute_position_id_with_mask( + input_data["attention_mask"][0] + ).tolist() if len(_input_ids) > self.config.prompt_length: logger.warning( "Prompt {} has length {} greater than max_prompt_len {}", @@ -780,9 +1059,15 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in _attention_mask = _attention_mask[: self.config.prompt_length] _position_ids = _position_ids[: self.config.prompt_length] else: - _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx]) - _attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx]) - _position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist() + _input_ids = _pre_process_inputs( + self.pad_token_id, prompts.batch["input_ids"][data_idx] + ) + _attention_mask = _pre_process_inputs( + 0, prompts.batch["attention_mask"][data_idx] + ) + _position_ids = compute_position_id_with_mask( + torch.tensor(_attention_mask) + ).tolist() _tool_schemas = [] _tools_kwargs = {} @@ -808,11 +1093,19 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in response_loss_mask=[], reward_scores={}, max_response_len=self.config.response_length, - max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length), + max_model_len=min( + self.config.max_model_len, + self.config.prompt_length + self.config.response_length, + ), ) error_message = f"Request {req.request_id} has mismatched lengths: input_ids={len(req.input_ids)}, attention_mask={len(req.attention_mask)}, position_ids={len(req.position_ids)}, loss_mask={len(req.loss_mask)}" - assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), error_message + assert ( + len(req.input_ids) + == len(req.attention_mask) + == len(req.position_ids) + == len(req.loss_mask) + ), error_message req_list.append(req) diff --git a/verl/workers/rollout/sglang_rollout/utils.py b/verl/workers/rollout/sglang_rollout/utils.py index c887e3f9e7d..438facd9e14 100644 --- a/verl/workers/rollout/sglang_rollout/utils.py +++ b/verl/workers/rollout/sglang_rollout/utils.py @@ -34,7 +34,9 @@ def broadcast_pyobj( The `rank` here refer to the source rank on global process group (regardless of dist_group argument). """ - device = torch.device("cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu") + device = torch.device( + "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu" + ) if rank == src: if len(data) == 0: @@ -44,7 +46,9 @@ def broadcast_pyobj( serialized_data = pickle.dumps(data) size = len(serialized_data) - tensor_data = torch.ByteTensor(np.frombuffer(serialized_data, dtype=np.uint8)).to(device) + tensor_data = torch.ByteTensor( + np.frombuffer(serialized_data, dtype=np.uint8) + ).to(device) tensor_size = torch.tensor([size], dtype=torch.long, device=device) dist.broadcast(tensor_size, src=src, group=dist_group)
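
Below is a minimal single-process sketch of the pickle/ByteTensor round trip that `broadcast_pyobj` performs around `dist.broadcast`; the sample `data` value is illustrative, and the actual cross-rank broadcast is elided.

import pickle

import numpy as np
import torch

# Sender side: pickle the object, then view the bytes as a uint8 tensor so
# that dist.broadcast can ship it to the other ranks.
data = [{"prompt_token_ids": [5, 6, 7]}]
serialized_data = pickle.dumps(data)
tensor_data = torch.ByteTensor(np.frombuffer(serialized_data, dtype=np.uint8))

# ... dist.broadcast(tensor_data, src=src, group=dist_group) runs here ...

# Receiver side: invert the transformation to recover the Python object.
restored = pickle.loads(tensor_data.numpy().tobytes())
assert restored == data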