Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ actor_rollout_ref:
file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml
served_model_name: ${oc.select:actor_rollout_ref.model.path,null}
quantization: null
quantization_config_file: null
layer_name_map:
qkv_layer_name: qkv
gate_proj_layer_name: gate_up
Expand Down
2 changes: 2 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ actor_rollout_ref:
port: 9090
file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml
served_model_name: ${oc.select:actor_rollout_ref.model.path,null}
quantization: null
quantization_config_file: null
layered_summon: false
model:
_target_: verl.workers.config.HFModelConfig
Expand Down
5 changes: 5 additions & 0 deletions verl/trainer/config/rollout/rollout.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,8 @@ prometheus:
# Specify served_model_name to avoid displaying overly long model paths in Grafana
served_model_name: ${oc.select:actor_rollout_ref.model.path,null}

# Type of quantization to use in vLLM; currently supports fp8 and torchao
quantization: null
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment above `quantization` and `quantization_config_file`.


# Extra quantization information serialized in a config file, e.g. torchao_config.json
quantization_config_file: null
2 changes: 2 additions & 0 deletions verl/workers/config/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ class RolloutConfig(BaseConfig):

quantization: Optional[str] = None

quantization_config_file: Optional[str] = None

enable_rollout_routing_replay: bool = False

enable_sleep_mode: bool = True
Expand Down
17 changes: 14 additions & 3 deletions verl/workers/rollout/vllm_rollout/vllm_async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,12 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
set_expandable_segments(True)

quantization = self.config.quantization

if quantization is not None:
_SUPPORTED_QUANTIZATION = ["fp8", "torchao"]
if quantization not in _SUPPORTED_QUANTIZATION:
raise ValueError(f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {quantization}")

if quantization == "fp8":
FP8_BLOCK_QUANT_KWARGS = {
"activation_scheme": "dynamic",
Expand All @@ -275,8 +280,14 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
# Apply vllm fp8 patches
# Will remove the patch after vllm support on-the-fly quant for rollout natively.
apply_vllm_fp8_patches()
else:
raise ValueError(f"Currently only support fp8 quantization, got: {quantization}")

hf_overrides = {}
if quantization is not None and self.config.quantization_config_file is not None:
hf_overrides["quantization_config_file"] = self.config.quantization_config_file

if quantization == "fp8":
hf_overrides["quantization_config"] = fp8_block_quant_kwargs

args = {
"dtype": self.config.dtype,
"load_format": self.config.load_format,
Expand All @@ -296,7 +307,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
"seed": self.config.get("seed", 0),
"override_generation_config": json.dumps(override_generation_config),
"quantization": quantization,
"hf_overrides": {"quantization_config": fp8_block_quant_kwargs} if quantization == "fp8" else None,
"hf_overrides": hf_overrides,
**engine_kwargs,
}

Expand Down
9 changes: 7 additions & 2 deletions verl/workers/rollout/vllm_rollout/vllm_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,17 @@ def _init_worker(self, all_kwargs: list[dict[str, Any]]):
lora_dtype = getattr(torch, self.config.dtype)
self.vllm_config.lora_config = LoRAConfig(lora_dtype=lora_dtype, **self.lora_config)
if self.config.quantization is not None:
_SUPPORTED_QUANTIZATION = ["fp8", "torchao"]
if self.config.quantization not in _SUPPORTED_QUANTIZATION:
raise ValueError(
f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {self.config.quantization}"
)

if self.config.quantization == "fp8":
# Apply vllm fp8 patches
# Will remove the patch after vllm support on-the-fly quant for rollout natively.
apply_vllm_fp8_patches()
else:
raise ValueError(f"Currently only support fp8 quantization, got: {self.config.quantization}")

self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config)
self.inference_engine.init_worker(all_kwargs)

Expand Down
Loading