diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index d6260ed6b19..b5c37f691a3 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -287,6 +287,7 @@ actor_rollout_ref: file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml served_model_name: ${oc.select:actor_rollout_ref.model.path,null} quantization: null + quantization_config_file: null layer_name_map: qkv_layer_name: qkv gate_proj_layer_name: gate_up diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index da79bfb943a..e696d715c9e 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -275,6 +275,8 @@ actor_rollout_ref: port: 9090 file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + quantization: null + quantization_config_file: null layered_summon: false model: _target_: verl.workers.config.HFModelConfig diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index 968d9e11277..c97526092ee 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -325,3 +325,8 @@ prometheus: # Specify served_model_name to avoid displaying overly long model paths in Grafana served_model_name: ${oc.select:actor_rollout_ref.model.path,null} +# type of quantization in vllm, currently supports fp8 and torchao +quantization: null + +# extra quantization information serialized in a config file, e.g. 
torchao_config.json +quantization_config_file: null diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index 7c3c09ced4b..2f205375475 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -211,6 +211,8 @@ class RolloutConfig(BaseConfig): quantization: Optional[str] = None + quantization_config_file: Optional[str] = None + enable_rollout_routing_replay: bool = False enable_sleep_mode: bool = True diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 00a8f75ef9d..14efc2973ac 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -263,7 +263,12 @@ async def launch_server(self, master_address: str = None, master_port: int = Non set_expandable_segments(True) quantization = self.config.quantization + if quantization is not None: + _SUPPORTED_QUANTIZATION = ["fp8", "torchao"] + if quantization not in _SUPPORTED_QUANTIZATION: + raise ValueError(f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {quantization}") + if quantization == "fp8": FP8_BLOCK_QUANT_KWARGS = { "activation_scheme": "dynamic", @@ -275,8 +280,14 @@ async def launch_server(self, master_address: str = None, master_port: int = Non # Apply vllm fp8 patches # Will remove the patch after vllm support on-the-fly quant for rollout natively. 
apply_vllm_fp8_patches() - else: - raise ValueError(f"Currently only support fp8 quantization, got: {quantization}") + + hf_overrides = {} + if quantization is not None and self.config.quantization_config_file is not None: + hf_overrides["quantization_config_file"] = self.config.quantization_config_file + + if quantization == "fp8": + hf_overrides["quantization_config"] = fp8_block_quant_kwargs + args = { "dtype": self.config.dtype, "load_format": self.config.load_format, @@ -296,7 +307,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non "seed": self.config.get("seed", 0), "override_generation_config": json.dumps(override_generation_config), "quantization": quantization, - "hf_overrides": {"quantization_config": fp8_block_quant_kwargs} if quantization == "fp8" else None, + "hf_overrides": hf_overrides, **engine_kwargs, } diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 7d540cff9a1..7a0105bedaa 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -191,12 +191,17 @@ def _init_worker(self, all_kwargs: list[dict[str, Any]]): lora_dtype = getattr(torch, self.config.dtype) self.vllm_config.lora_config = LoRAConfig(lora_dtype=lora_dtype, **self.lora_config) if self.config.quantization is not None: + _SUPPORTED_QUANTIZATION = ["fp8", "torchao"] + if self.config.quantization not in _SUPPORTED_QUANTIZATION: + raise ValueError( + f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {self.config.quantization}" + ) + if self.config.quantization == "fp8": # Apply vllm fp8 patches # Will remove the patch after vllm support on-the-fly quant for rollout natively. 
apply_vllm_fp8_patches() - else: - raise ValueError(f"Currently only support fp8 quantization, got: {self.config.quantization}") + self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config) self.inference_engine.init_worker(all_kwargs)