diff --git a/docs/examples/config.rst b/docs/examples/config.rst index 7f0116d8b48..0541c3dc17f 100644 --- a/docs/examples/config.rst +++ b/docs/examples/config.rst @@ -169,8 +169,10 @@ Actor/Rollout/Reference Policy # for hf rollout do_sample: True engine_kwargs: # inference engine parameters - swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB - attention_backend: fa3 # null means use the engine default value, available options: flashinfer, triton, flashmla + vllm: + swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB + sglang: + attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla # number of responses (i.e. num sample times) n: 1 # > 1 for grpo, rloo val_kwargs: @@ -320,15 +322,17 @@ Reference model will be enabled when ``actor.use_kl_loss`` or/and ``algorithm.us deterministic outputs. When set to True, the rollout will use the ``actor_rollout_ref.rollout.val_kwargs`` parameters (top_k, top_p, temperature) to control the sampling behavior. -- ``actor_rollout_ref.rollout.engine_kwargs.swap_space``: swap space in GB used by the inference engine. - - ``null``: means not setting and using the engine default value (usually, e.g., 4 GB for vLLM) - - Positive integer, e.g., ``32`` means 32 GB. - -- ``actor_rollout_ref.rollout.engine_kwargs.attention_backend``: The attention backend to use for the inference engine. - - ``null``: means not setting and using the engine default value (usually, e.g., ``fa3`` for SGLang) - - ``flashinfer``: Use flashinfer attention backend. - - ``triton``: Use triton attention backend. - - ``flashmla``: Use flashmla attention backend. +- ``actor_rollout_ref.rollout.engine_kwargs.vllm``: extra vllm engine args + - ``swap_space``: swap space in GB used by the inference engine. + - ``null``: means not setting and using the engine default value (usually, e.g., 4 GB for vLLM) + - Positive integer, e.g., ``32`` means 32 GB. + +- ``actor_rollout_ref.rollout.engine_kwargs.sglang``: extra sglang engine args + - ``attention_backend``: The attention backend to use for the inference engine. + - ``null``: means not setting and using the engine default value (usually, e.g., ``fa3`` for SGLang) + - ``flashinfer``: Use flashinfer attention backend. + - ``triton``: Use triton attention backend. + - ``flashmla``: Use flashmla attention backend. - ``actor_rollout_ref.rollout.ignore_eos``: Whether to ignore the EOS token and continue generating tokens after the EOS token is generated. diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index a8e9b250134..2b1a0941594 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -149,8 +149,10 @@ actor_rollout_ref: # number of responses (i.e. num sample times) n: 1 engine_kwargs: # inference engine parameters - swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB - attention_backend: fa3 # null means use the engine default value, available options: flashinfer, triton, flashmla + vllm: + swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB + sglang: + attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla val_kwargs: # sampling parameters for validation top_k: -1 # 0 for hf rollout, -1 for vllm rollout diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 36d30d03b84..6d3634ea09e 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -120,8 +120,10 @@ actor_rollout_ref: # number of responses (i.e. num sample times) n: 1 # > 1 for grpo engine_kwargs: # inference engine parameters - swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB - attention_backend: null # null means use the engine default value, available options: fa3, flashinfer, triton, flashmla + vllm: + swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB + sglang: + attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla val_kwargs: # sampling parameters for validation top_k: -1 # 0 for hf rollout, -1 for vllm rollout diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index 47b604fe817..ed852f769f0 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -172,7 +172,7 @@ def __init__( load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format # copy it to avoid secretly modifying the engine config - engine_kwargs = {} if "engine_kwargs" not in config else OmegaConf.to_container(deepcopy(config.engine_kwargs)) + engine_kwargs = {} if "engine_kwargs" not in config or "sglang" not in config.engine_kwargs else OmegaConf.to_container(deepcopy(config.engine_kwargs.sglang)) engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None} self.inference_engine = VerlEngine( model_path=actor_module, diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 9af5020d5b2..37a39a5ee82 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -96,7 +96,7 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model ): vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp) - rope_scaling_config = getattr(model_hf_config, 'rope_scaling', None) + rope_scaling_config = getattr(model_hf_config, "rope_scaling", None) if not rope_scaling_config: assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, "model context length should be greater than total sequence length" @@ -110,7 +110,7 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model ) # copy it to avoid secretly modifying the engine config - engine_kwargs = {} if "engine_kwargs" not in config else OmegaConf.to_container(deepcopy(config.engine_kwargs)) + engine_kwargs = {} if "engine_kwargs" not in config or "vllm" not in config.engine_kwargs else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm)) # For each vLLM engine parameter, # - `None` means not setting it, so we pop it, and leave it to vLLM default value # (which can vary across different vLLM versions); diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index be2944478ff..e8ae44437dd 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -108,7 +108,7 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf else: vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size) - rope_scaling_config = getattr(model_hf_config, 'rope_scaling', None) + rope_scaling_config = getattr(model_hf_config, "rope_scaling", None) if not rope_scaling_config: max_position_embeddings = None if hasattr(model_hf_config, "max_position_embeddings"): @@ -136,7 +136,7 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf limit_mm_per_prompt = {"image": config.get("limit_images")} # copy it to avoid secretly modifying the engine config - engine_kwargs = {} if "engine_kwargs" not in config else OmegaConf.to_container(deepcopy(config.engine_kwargs)) + engine_kwargs = {} if "engine_kwargs" not in config or "vllm" not in config.engine_kwargs else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm)) # For each vLLM engine parameter, # - `None` means not setting it, so we pop it, and leave it to vLLM default value # (which can vary across different vLLM versions);