diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml
index da9629a56b4..965712ab883 100644
--- a/.github/workflows/e2e_ppo_trainer.yml
+++ b/.github/workflows/e2e_ppo_trainer.yml
@@ -251,6 +251,16 @@ jobs:
         run: |
           ray stop --force
           ENGINE=sglang bash tests/special_e2e/ppo_trainer/run_function_reward.sh
+      - name: Running GSM8K E2E training tests on sglang async
+        run: |
+          ray stop --force
+          ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
+      - name: Running GSM8K E2E training tests on vllm async
+        run: |
+          ray stop --force
+          export VLLM_USE_V1=1
+          ray start --head
+          ENGINE=vllm ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh

   e2e_ppo_trainer_sglang_multiturn_with_tool:
     runs-on: [L20x8]
diff --git a/tests/special_e2e/ppo_trainer/run_function_reward.sh b/tests/special_e2e/ppo_trainer/run_function_reward.sh
index 6d1d6ac69dc..c3b311029bb 100644
--- a/tests/special_e2e/ppo_trainer/run_function_reward.sh
+++ b/tests/special_e2e/ppo_trainer/run_function_reward.sh
@@ -14,6 +14,12 @@ MAX_RESPONSE_LEN=${MAX_RESPONSE_LEN:-512}

 ENGINE=${ENGINE:-vllm}
 ROLLOUT_MODE=${ROLLOUT_MODE:-sync}
+
+RETURN_RAW_CHAT="False"
+if [ "$ROLLOUT_MODE" = "async" ]; then
+    RETURN_RAW_CHAT="True"
+fi
+
 GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.8}
 ACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False}
 ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False}
@@ -84,6 +90,7 @@ python3 -m verl.trainer.main_ppo \
     data.train_batch_size="${train_prompt_bsz}" \
     data.max_prompt_length="${MAX_PROMPT_LEN}" \
     data.max_response_length="${MAX_RESPONSE_LEN}" \
+    data.return_raw_chat=${RETURN_RAW_CHAT} \
     actor_rollout_ref.model.path="${MODEL_PATH}" \
     actor_rollout_ref.model.use_shm=${USE_SHM} \
     actor_rollout_ref.model.lora_rank=${LORA_RANK} \
diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
index b0d5470302e..153fb167f89 100644
--- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py
+++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
@@ -29,9 +29,8 @@ class AsyncSglangServer(AsyncServerBase):

     def __init__(self, config: DictConfig, dp_size: int, dp_rank: int, wg_prefix: str):
         super().__init__()
-        self.config = config
-        rollout_config = config.get("rollout", {})
-        self._tp_size = rollout_config.get("tensor_model_parallel_size", 1)
+        self.config = config.actor_rollout_ref
+        self._tp_size = self.config.rollout.get("tensor_model_parallel_size", 1)
         self._dp_size = dp_size
         self._dp_rank = dp_rank
         self.wg_prefix = wg_prefix
diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
index 4e9c58eb714..fa06ffdc4e8 100644
--- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py
+++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -1084,7 +1084,7 @@ async def chat_completion(self, json_request):
             request_id=str(uuid4()),
             state=AsyncRolloutRequestStateEnum.PENDING,
             messages=[Message.model_validate(msg) for msg in json_request["messages"]],
-            tools=_tool_schemas,
+            tool_schemas=_tool_schemas,
             tools_kwargs=_tools_kwargs,
             input_ids=_input_ids,
             prompt_ids=_input_ids,
@@ -1099,8 +1099,12 @@ async def chat_completion(self, json_request):
             prompt_loss_mask=[0] * len(_input_ids),
             response_loss_mask=[],
             reward_scores={},
+            max_prompt_len=self.config.prompt_length,
             max_response_len=self.config.response_length,
             max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),
+            use_inference_chat_template=self.config.multi_turn.use_inference_chat_template,
+            enable_tokenization_sanity_check=self.config.multi_turn.enable_tokenization_sanity_check,
+            tokenizer=self.tokenizer,
         )

         # json_request already contains sampling_params