diff --git a/docs/sglang_multiturn/multiturn.rst b/docs/sglang_multiturn/multiturn.rst
new file mode 100644
index 00000000000..8d6cde3261b
--- /dev/null
+++ b/docs/sglang_multiturn/multiturn.rst
@@ -0,0 +1,40 @@
+Multi-turn Rollout Support
+==========================
+
+Basic Configuration
+~~~~~~~~~~~~~~~~~~~
+
+To enable multi-turn rollout, make sure to configure the following fields in your rollout configuration:
+
+.. code-block:: yaml
+
+    actor_rollout_ref:
+        rollout:
+            multi_turn: True
+            name: "sglang_async"
+
+This configuration activates the sglang_async engine for multi-turn interaction during rollout.
+
+Custom Tool Configuration
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+For custom environment interaction tools, you can specify your tool configurations in a YAML file.
+To do so, use the following format in your rollout config:
+
+.. code-block:: yaml
+
+    actor_rollout_ref:
+        rollout:
+            tool_kwargs:
+                tools_config_file:
+
+This allows integration of customized tool behaviors during actor rollout steps. You may refer to the GSM8KTool_example_configuration_ for guidance.
+
+GSM8K Multi-turn Training Performance
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+See the training performance of multi-turn rollout on the GSM8K task HERE_.
+
+.. _HERE: https://wandb.ai/zhaochenyang20/gsm8k_async_rl/runs/1ro1r7om?nw=nwuserzhaochenyang20
+
+.. _GSM8KTool_example_configuration: ../../examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml
\ No newline at end of file
diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn.yaml
new file mode 100644
index 00000000000..fb4f9bab464
--- /dev/null
+++ b/examples/sglang_multiturn/config/gsm8k_multiturn.yaml
@@ -0,0 +1,184 @@
+data:
+  tokenizer: null
+  train_files: /root/data/gsm8k/train.parquet
+  val_files: /root/data/gsm8k/test.parquet
+  prompt_key: prompt
+  max_prompt_length: 512
+  max_response_length: 512
+  train_batch_size: 8
+  val_batch_size: 8
+  return_raw_input_ids: True # This should be set to True when the policy and the reward model use different tokenizers
+  return_raw_chat: True
+  shuffle: False
+  filter_overlong_prompts: False # For large-scale datasets, filtering overlong prompts can be time-consuming. You can set filter_overlong_prompts_workers to use multiprocessing and speed it up.
+
+
+actor_rollout_ref:
+  hybrid_engine: True
+  model:
+    path: Qwen/Qwen2.5-7B
+    external_lib: null
+    override_config: { }
+    enable_gradient_checkpointing: True
+    use_remove_padding: True
+    trust_remote_code: True
+  actor:
+    strategy: fsdp # This is for backward-compatibility
+    ppo_mini_batch_size: 4
+    ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
+    ppo_micro_batch_size_per_gpu: 1 # for dynamic bsz
+    use_dynamic_bsz: True
+    ppo_max_token_len_per_gpu: 32768 # n * ${data.max_prompt_length} + ${data.max_response_length}
+    grad_clip: 0.5
+    clip_ratio: 0.2
+    clip_ratio_c: 3.0
+    loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean"
+    entropy_coeff: 0.0
+    use_kl_loss: True # True for GRPO
+    kl_loss_coef: 0.0001 # for grpo
+    kl_loss_type: low_var_kl # for grpo
+    ppo_epochs: 1
+    shuffle: False
+    ulysses_sequence_parallel_size: 1 # sp size
+    checkpoint:
+      contents: ['model', 'optimizer', 'extra']
+    optim:
+      lr: 1e-6
+      # lr_warmup_steps_ratio: 0.
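+      # Example (assumption, not enabled in this run): to turn on a short LR warmup,
+      # one could uncomment the options below and set, e.g.:
+      #   lr_warmup_steps_ratio: 0.03
+      #   warmup_style: cosine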
# the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + # warmup_style: constant # select from constant/cosine + # total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: True + optimizer_offload: True + fsdp_size: -1 + ref: + fsdp_config: + param_offload: True + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + rollout: + name: sglang_async + prompt_length: ${data.max_prompt_length} # not use for opensource + response_length: ${data.max_response_length} + max_model_len: null + # for vllm rollout + dtype: bfloat16 # should align with FSDP + temperature: ${.sampling_params.temperature} # this is currently ignored + gpu_memory_utilization: 0.8 + enable_memory_saver: False + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 4 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: True # could get higher throughput + # for hf rollout + do_sample: True + # number of responses (i.e. num sample times) + n: 1 + multi_turn: True + max_turns: 3 + plugin_browser: False + path: ${actor_rollout_ref.model.path} + sampling_params: + temperature: 0.8 + max_new_tokens: 192 + stop: [] + val_kwargs: + # sampling parameters for validation + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1.0 + temperature: 0 + n: 2 + do_sample: False # default eager for validation + tool_kwargs: + tools_config_file: "examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" + +critic: + strategy: fsdp + optim: + lr: 1e-5 + # lr_warmup_steps_ratio: 0. 
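+    # Note: with algorithm.adv_estimator set to grpo below, no value function is trained,
+    # so these critic optimizer settings are generally unused and are kept mainly for
+    # config completeness.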
# the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + # warmup_style: constant # select from constant/cosine + # total_training_steps: -1 # must be override by program + model: + path: Qwen/Qwen2.5-7B + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: { } + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: True + use_remove_padding: False + fsdp_config: + param_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + fsdp_size: -1 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: 1 # sp size + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + checkpoint: + contents: ['model', 'optimizer', 'extra'] + +reward_model: + enable: False + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: grpo + use_kl_in_reward: False + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 + +trainer: + hybrid_engine: True + total_epochs: 3 + total_training_steps: null + project_name: gsm8k_async_rl + experiment_name: qwen7b_sft2_16k_t08_n8_v6 + logger: [ 'console', 'wandb' ] + val_generations_to_log_to_wandb: 0 + nnodes: 1 + n_gpus_per_node: 4 + save_freq: 100 + # auto: find the last ckpt to resume. 
If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + resume_from_path: False + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + remove_previous_ckpt_in_save: True + del_local_ckpt_after_load: True + default_local_dir: /workspace/gsm8k/ckpt/${trainer.project_name}/${trainer.experiment_name} + val_before_train: False + balance_batch: False \ No newline at end of file diff --git a/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml new file mode 100644 index 00000000000..e861eb6035c --- /dev/null +++ b/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml @@ -0,0 +1,18 @@ +tools: + - class_name: "verl.workers.tool.gsm8k_tool.Gsm8kTool" + config: {} + tool_schema: + type: "function" + function: + name: "calc_gsm8k_reward" + description: "A tool for calculating the reward of gsm8k" + parameters: + type: "object" + properties: + response: + type: "string" + description: "The model's response to the GSM8K math problem" + ground_truth: + type: "string" + description: "The ground truth answer to the GSM8K math problem" + required: ["response", "ground_truth"] diff --git a/examples/sglang_multiturn/run_qwen2.5-7b_math_gsm8k_fsdp_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-7b_math_gsm8k_fsdp_multiturn.sh new file mode 100644 index 00000000000..4db63378cbc --- /dev/null +++ b/examples/sglang_multiturn/run_qwen2.5-7b_math_gsm8k_fsdp_multiturn.sh @@ -0,0 +1,6 @@ +set -x + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet + +python3 -m verl.trainer.main_ppo --config-path=/root/verl/examples/sglang_multiturn/config --config-name='gsm8k_multiturn' diff --git a/tests/e2e/run_qwen2.5-7b_gsm8k_fsdp_multiturn_e2e.sh b/tests/e2e/run_qwen2.5-7b_gsm8k_fsdp_multiturn_e2e.sh new file mode 100644 index 00000000000..b609b4e68d2 --- /dev/null +++ b/tests/e2e/run_qwen2.5-7b_gsm8k_fsdp_multiturn_e2e.sh @@ -0,0 +1,144 @@ +set -x + +NOW=$(date +"%Y%m%d_%H_%M") +EXPERIMENT_NAME="qwen7b_sft2_${NOW}" +LOG_DIR="logs" +mkdir -p ${LOG_DIR} + +# Run the PPO training with complete configuration +python3 -m verl.trainer.main_ppo \ + data.train_files=/root/data/gsm8k/train.parquet \ + data.val_files=/root/data/gsm8k/test.parquet \ + data.prompt_key=prompt \ + data.max_prompt_length=512 \ + data.max_response_length=512 \ + data.train_batch_size=8 \ + data.val_batch_size=8 \ + data.return_raw_input_ids=True \ + data.return_raw_chat=True \ + data.shuffle=False \ + data.filter_overlong_prompts=False \ + actor_rollout_ref.hybrid_engine=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B \ + actor_rollout_ref.model.external_lib=null \ + actor_rollout_ref.model.override_config={} \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.model.trust_remote_code=True \ + actor_rollout_ref.actor.strategy=fsdp \ + actor_rollout_ref.actor.ppo_mini_batch_size=4 \ + actor_rollout_ref.actor.ppo_micro_batch_size=null \ + +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \ + actor_rollout_ref.actor.grad_clip=0.5 \ + actor_rollout_ref.actor.clip_ratio=0.2 \ + actor_rollout_ref.actor.clip_ratio_c=3.0 \ + +actor_rollout_ref.actor.loss_agg_mode="token-mean" \ + actor_rollout_ref.actor.entropy_coeff=0.0 \ + actor_rollout_ref.actor.use_kl_loss=True \ + 
actor_rollout_ref.actor.kl_loss_coef=0.0001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.ppo_epochs=1 \ + actor_rollout_ref.actor.shuffle=False \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.checkpoint.contents=['model','optimizer','extra'] \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.fsdp_config.wrap_policy.min_num_params=0 \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.ref.fsdp_config.wrap_policy.min_num_params=0 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=null \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=null \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=\${actor_rollout_ref.actor.use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=\${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=\${actor_rollout_ref.actor.ulysses_sequence_parallel_size} \ + actor_rollout_ref.rollout.name=sglang_async \ + actor_rollout_ref.rollout.prompt_length=\${data.max_prompt_length} \ + actor_rollout_ref.rollout.response_length=\${data.max_response_length} \ + actor_rollout_ref.rollout.max_model_len=null \ + actor_rollout_ref.rollout.dtype=bfloat16 \ + actor_rollout_ref.rollout.temperature=\${actor_rollout_ref.rollout.sampling_params.temperature} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + +actor_rollout_ref.rollout.enable_memory_saver=False \ + actor_rollout_ref.rollout.ignore_eos=False \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.load_format=dummy_dtensor \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=null \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=null \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=\${actor_rollout_ref.actor.use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=\${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} \ + actor_rollout_ref.rollout.disable_log_stats=True \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.do_sample=True \ + actor_rollout_ref.rollout.n=1 \ + +actor_rollout_ref.rollout.max_turns=3 \ + +actor_rollout_ref.rollout.plugin_browser=False \ + +actor_rollout_ref.rollout.path=\${actor_rollout_ref.model.path} \ + +actor_rollout_ref.rollout.sampling_params.temperature=0.8 \ + +actor_rollout_ref.rollout.sampling_params.max_new_tokens=192 \ + +actor_rollout_ref.rollout.sampling_params.stop=[] \ + actor_rollout_ref.rollout.val_kwargs.top_k=-1 \ + actor_rollout_ref.rollout.val_kwargs.top_p=1.0 \ + actor_rollout_ref.rollout.val_kwargs.temperature=0 \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=False \ + actor_rollout_ref.rollout.tool_kwargs.tools_config_file="examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ + critic.strategy=fsdp \ + critic.optim.lr=1e-5 \ + critic.model.path=Qwen/Qwen2.5-7B \ + critic.model.tokenizer_path=\${actor_rollout_ref.model.path} \ + critic.model.override_config={} \ + critic.model.external_lib=\${actor_rollout_ref.model.external_lib} \ + critic.model.enable_gradient_checkpointing=True \ + 
critic.model.use_remove_padding=False \ + critic.model.fsdp_config.param_offload=False \ + critic.model.fsdp_config.optimizer_offload=False \ + critic.model.fsdp_config.wrap_policy.min_num_params=0 \ + critic.model.fsdp_config.fsdp_size=-1 \ + critic.ppo_mini_batch_size=\${actor_rollout_ref.actor.ppo_mini_batch_size} \ + critic.ppo_micro_batch_size=null \ + critic.ppo_micro_batch_size_per_gpu=null \ + critic.forward_micro_batch_size=\${critic.ppo_micro_batch_size} \ + critic.forward_micro_batch_size_per_gpu=\${critic.ppo_micro_batch_size_per_gpu} \ + critic.use_dynamic_bsz=\${actor_rollout_ref.actor.use_dynamic_bsz} \ + critic.ppo_max_token_len_per_gpu=32768 \ + critic.forward_max_token_len_per_gpu=\${critic.ppo_max_token_len_per_gpu} \ + critic.ulysses_sequence_parallel_size=1 \ + critic.ppo_epochs=\${actor_rollout_ref.actor.ppo_epochs} \ + critic.shuffle=\${actor_rollout_ref.actor.shuffle} \ + critic.grad_clip=1.0 \ + critic.cliprange_value=0.5 \ + critic.checkpoint.contents=['model','optimizer','extra'] \ + reward_model.enable=False \ + algorithm.gamma=1.0 \ + algorithm.lam=1.0 \ + algorithm.adv_estimator=grpo \ + algorithm.use_kl_in_reward=False \ + algorithm.kl_penalty=kl \ + algorithm.kl_ctrl.type=fixed \ + algorithm.kl_ctrl.kl_coef=0.001 \ + +trainer.hybrid_engine=True \ + trainer.total_epochs=3 \ + trainer.total_training_steps=null \ + trainer.project_name=gsm8k_async_rl \ + trainer.experiment_name="${EXPERIMENT_NAME}" \ + trainer.logger=['console','wandb'] \ + +trainer.val_generations_to_log_to_wandb=0 \ + trainer.nnodes=1 \ + trainer.n_gpus_per_node=4 \ + trainer.resume_mode=auto \ + +trainer.resume_from_path=False \ + trainer.test_freq=-1 \ + trainer.critic_warmup=0 \ + trainer.default_hdfs_dir=null \ + +trainer.remove_previous_ckpt_in_save=True \ + trainer.del_local_ckpt_after_load=True \ + trainer.default_local_dir=/workspace/gsm8k/ckpt/\${trainer.project_name}/\${trainer.experiment_name} \ + trainer.val_before_train=False \ + trainer.balance_batch=False \ + | tee ${LOG_DIR}/${EXPERIMENT_NAME}.log diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 7dd270bbaa4..b9a968356af 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -103,6 +103,9 @@ actor_rollout_ref: temperature: 0 n: 1 do_sample: False # default eager for validation + multi_turn: False # should set rollout.name to sglang_async if True + tool_kwargs: + tools_config_file: None critic: rollout_n: ${actor_rollout_ref.rollout.n} diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 35aa5a2f285..df0c6115f47 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -134,14 +134,19 @@ def _check_resource_available(self): from verl.utils.torch_functional import masked_mean -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty='kl'): +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty='kl', multi_turn=False): responses = data.batch['responses'] response_length = responses.size(1) token_level_scores = data.batch['token_level_scores'] batch_size = data.batch.batch_size[0] - attention_mask = data.batch['attention_mask'] - response_mask = attention_mask[:, -response_length:] - + + if multi_turn: + loss_mask = data.batch['loss_mask'] + response_mask = loss_mask[:, -response_length:] + else: + attention_mask = data.batch['attention_mask'] + response_mask = attention_mask[:, -response_length:] + # compute kl 
between ref_policy and current policy # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. kld = core_algos.kl_penalty(data.batch['old_log_probs'], data.batch['ref_log_prob'], @@ -886,7 +891,8 @@ def fit(self): if self.config.algorithm.use_kl_in_reward: batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, - kl_penalty=self.config.algorithm.kl_penalty) + kl_penalty=self.config.algorithm.kl_penalty, + multi_turn=self.config.actor_rollout_ref.rollout.get('multi_turn', False)) metrics.update(kl_metrics) else: batch.batch['token_level_rewards'] = batch.batch['token_level_scores'] diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index c254f52a8e6..74f641f697d 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -360,7 +360,7 @@ def _build_rollout(self, trust_remote_code=False): device_mesh=rollout_device_mesh) log_gpu_memory_usage('After building sharding manager', logger=None) elif rollout_name == 'sglang_async': - from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout + from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout from verl.workers.sharding_manager.fsdp_async_sglang import FSDPAsyncSGLangShardingManager log_gpu_memory_usage(f'Before building {rollout_name} rollout', logger=None) rollout = AsyncSGLangRollout(actor_module=self.config.model.path, @@ -378,7 +378,9 @@ def _build_rollout(self, trust_remote_code=False): device_mesh=rollout_device_mesh) log_gpu_memory_usage('After building sharding manager', logger=None) - + else: + raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported") + return rollout, rollout_sharding_manager @register(dispatch_mode=Dispatch.ONE_TO_ALL) @@ -529,7 +531,12 @@ def generate_sequences(self, prompts: DataProto): log_gpu_memory_usage('After entering rollout sharding manager', logger=logger) prompts = self.rollout_sharding_manager.preprocess_data(prompts) - output = self.rollout.generate_sequences(prompts=prompts) + + from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout + if isinstance(self.rollout, AsyncSGLangRollout) and hasattr(self.rollout, '_tool_schemas') and len(self.rollout._tool_schemas) > 0: + output = self.rollout.generate_sequences_with_tools(prompts=prompts) + else: + output = self.rollout.generate_sequences(prompts=prompts) log_gpu_memory_usage('After rollout generation', logger=logger) output = self.rollout_sharding_manager.postprocess_data(output) diff --git a/verl/workers/rollout/sglang_rollout/__init__.py b/verl/workers/rollout/sglang_rollout/__init__.py index 81320e8aec5..714b047db83 100644 --- a/verl/workers/rollout/sglang_rollout/__init__.py +++ b/verl/workers/rollout/sglang_rollout/__init__.py @@ -12,3 +12,4 @@ # See the License for the specific language governing permissions and from .sglang_rollout import SGLangRollout +from .async_sglang_rollout import AsyncSGLangRollout \ No newline at end of file diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py index de0514dd918..6b448f57569 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py @@ -123,7 +123,6 @@ def __init__( config: DictConfig, tokenizer, model_hf_config, - tool_list: Optional[List[BaseTool]] = None, **kwargs, ): """A SGLang rollout. It requires the module is supported by the SGLang. 
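Note on the tool-loading hunk that follows: a tool referenced from tool_kwargs.tools_config_file is constructed as tool_cls(config=..., tool_schema=...) and, judging from how this diff uses it, is expected to expose name, get_openai_tool_schema() and an async create(request_id) hook. The sketch below is a hypothetical minimal tool under those assumptions; the class and module names are placeholders, and verl's actual tool base class defines further hooks (e.g. executing the call and scoring it) that this diff does not show.

# Hypothetical minimal custom tool, inferred from how this diff instantiates and
# calls tools; MyEnvTool is a placeholder name, not part of verl.
from verl.workers.tool.data_model import OpenAIFunctionToolSchema


class MyEnvTool:
    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
        self.config = config
        self.tool_schema = tool_schema
        # the rollout builds a name -> tool map from this attribute
        self.name = tool_schema.function.name

    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
        # dumped to a dict and handed to sglang's function-call parser
        return self.tool_schema

    async def create(self, request_id: str, **kwargs) -> None:
        # awaited once per rollout request before generation starts;
        # allocate any per-request state (e.g. an environment instance) here
        return None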
@@ -137,11 +136,53 @@ def __init__( """ super().__init__() self.config = config + self.max_turns = getattr(config, "max_turns", 1) + + tool_list = None + if config.get("tool_kwargs") and config.tool_kwargs.get("tools_config_file", None) is not None: + from omegaconf import OmegaConf + def initialize_tools(tools_config) -> List: + import sys + import importlib.util + from typing import List + from verl.workers.tool.data_model import OpenAIFunctionToolSchema + + tool_list = [] + + for tool_config in tools_config.tools: + cls_name = tool_config.class_name + module_name, class_name = cls_name.rsplit(".", 1) + + if module_name not in sys.modules: + spec = importlib.util.find_spec(module_name) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + else: + module = sys.modules[module_name] + + tool_cls = getattr(module, class_name) + + tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True) + tool_schema = OpenAIFunctionToolSchema.parse_obj(tool_schema_dict) + + tool = tool_cls( + config=OmegaConf.to_container(tool_config.config, resolve=True), + tool_schema=tool_schema + ) + tool_list.append(tool) + + return tool_list + + tools_config_file = config.tool_kwargs.tools_config_file + tools_config = OmegaConf.load(tools_config_file) + tool_list = initialize_tools(tools_config) + if tool_list is not None: - self._tool_schemas = [tool.get_openai_tool_schema() for tool in tool_list] + self._tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list] self._tool_map = {tool.name: tool for tool in tool_list} self._tool_call_parser_type = get_tool_call_parser_type(tokenizer) - self._sgl_tools = [Tool.model_validate(tool) for tool in self._tool_schemas] + self._sgl_tools = self._tool_schemas self._function_call_parser = FunctionCallParser( self._sgl_tools, self._tool_call_parser_type, @@ -384,7 +425,11 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo _req = deepcopy(req) finish_reason_type = None output = None - while True: + + current_turns = 0 + while current_turns < self.max_turns: + current_turns += 1 + if _req.state == AsyncRolloutRequestStateEnum.PENDING: await asyncio.gather(*[tool.create(_req.request_id) for tool in self._tool_map.values()]) _req.state = AsyncRolloutRequestStateEnum.RUNNING @@ -429,6 +474,7 @@ async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bo sampling_params=self.sampling_params, return_logprob=False, ) + content = output["text"] finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"]) if finish_reason_type == FinishReasonTypeEnum.LENGTH: diff --git a/verl/workers/rollout/sglang_rollout/utils.py b/verl/workers/rollout/sglang_rollout/utils.py deleted file mode 100644 index e51986e2e6d..00000000000 --- a/verl/workers/rollout/sglang_rollout/utils.py +++ /dev/null @@ -1,115 +0,0 @@ -import uuid -import torch -from typing import List -from verl.workers.rollout.data_model import ( - AsyncRolloutRequest, - AsyncRolloutRequestStateEnum, - Message, -) -from verl.workers.tool.data_model import OpenAIFunctionToolSchema -from verl import DataProto - - -def prompts_to_async_rollout_requests( - prompts: DataProto, tokenizer, tool: List[OpenAIFunctionToolSchema] = None -) -> List[AsyncRolloutRequest]: - if tools is None: - tools = [] - requests = [] - - input_ids = prompts.batch["input_ids"] - batch_size = input_ids.size(0) - - for i in range(batch_size): - request_id = 
str(uuid.uuid4()) - - if tokenizer is not None: - prompt_ids = input_ids[i].tolist() - pad_token_id = ( - tokenizer.pad_token_id if hasattr(tokenizer, "pad_token_id") else None - ) - if pad_token_id is not None: - prompt_ids = [id for id in prompt_ids if id != pad_token_id] - prompt = tokenizer.decode(prompt_ids) - else: - prompt = str(input_ids[i].tolist()) - - messages = [Message(role="user", content=prompt)] - - request = AsyncRolloutRequest( - request_id=request_id, - state=AsyncRolloutRequestStateEnum.PENDING, - prompt=prompt, - messages=messages, - tools=tools, - ) - - requests.append(request) - - return requests - - -def messages_to_ids_with_loss_mask( - messages: List[Message], - tokenizer, - tools: List[OpenAIFunctionToolSchema] = None, - max_length: int = None, -) -> Tuple[List[int], List[int]]: - formatted_messages = [ - {"role": msg.role, "content": msg.content} for msg in messages - ] - - tools_dict = None - if tools: - tools_dict = [tool.model_dump() for tool in tools] - - input_ids = tokenizer.apply_chat_template( - formatted_messages, tools=tools_dict, tokenize=True, add_generation_prompt=True - ) - - if max_length and len(input_ids) > max_length: - input_ids = input_ids[:max_length] - - loss_mask = [0] * len(input_ids) - - current_pos = 0 - for msg in formatted_messages: - tokens = tokenizer.encode(msg["content"], add_special_tokens=False) - token_length = len(tokens) - - if msg["role"] == "assistant": - approx_start = current_pos - approx_end = approx_start + token_length - - for i in range(approx_start, min(approx_end, len(loss_mask))): - loss_mask[i] = 1 - - current_pos += token_length + 1 - - return input_ids, loss_mask - - -def ids_to_messages( - input_ids: torch.Tensor, - tokenizer, - skip_special_tokens: bool = True, - role: str = "assistant", -) -> List[Message]: - if isinstance(input_ids, torch.Tensor): - input_ids = input_ids.tolist() - - pad_token_id = tokenizer.pad_token_id - if pad_token_id is not None: - if isinstance(input_ids[0], list): - input_ids = [ - [tid for tid in seq if tid != pad_token_id] for seq in input_ids - ] - else: - input_ids = [tid for tid in input_ids if tid != pad_token_id] - - text = tokenizer.decode(input_ids, skip_special_tokens=skip_special_tokens) - - message = Message(role=role, content=text) - - return [message] -
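A closing note on the loss_mask consumed by apply_kl_penalty above: in multi-turn rollouts only assistant-generated tokens should receive the KL penalty, while prompt, tool-response and padding tokens are masked out. The sketch below illustrates that bookkeeping under simplified assumptions (pre-tokenized per-turn segments, no chat-template special tokens); it is not part of this PR, where the loss_mask tensor is produced on the rollout side (the deleted messages_to_ids_with_loss_mask helper approximated the same idea).

# Hedged sketch: build a per-token loss mask over a multi-turn conversation so that
# only assistant turns are trained on / KL-penalized. Deliberately simplified.
from typing import List, Tuple


def build_loss_mask(segments: List[Tuple[str, List[int]]]) -> Tuple[List[int], List[int]]:
    """segments: (role, token_ids) pairs in conversation order, already tokenized."""
    input_ids: List[int] = []
    loss_mask: List[int] = []
    for role, token_ids in segments:
        input_ids.extend(token_ids)
        # 1 for assistant-generated tokens, 0 for user/system/tool tokens
        loss_mask.extend([1 if role == "assistant" else 0] * len(token_ids))
    return input_ids, loss_mask


# e.g. build_loss_mask([("user", [1, 2]), ("assistant", [3, 4, 5]), ("tool", [6])])
# -> ([1, 2, 3, 4, 5, 6], [0, 0, 1, 1, 1, 0])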