Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 121 additions & 5 deletions tests/quantization/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def test_pre_quantized_model(vllm_runner):
output = llm.generate_greedy(["The capital of France is"],
max_tokens=32)
assert output
print(output)


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
Expand All @@ -42,7 +41,6 @@ def test_opt_125m_int8wo_model_loading_with_params(vllm_runner,
max_tokens=32)

assert output
print(output)


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
Expand All @@ -57,7 +55,6 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
max_tokens=32)

assert output
print(output)


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
Expand All @@ -72,7 +69,6 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
max_tokens=32)

assert output
print(output)


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
Expand All @@ -92,7 +88,127 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
max_tokens=32)

assert output
print(output)


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_on_the_fly_quant_config_dict_json(vllm_runner):
    """Testing on the fly quantization, load_weights integration point,
    with config dict serialized to json string
    """
    torch._dynamo.reset()

    import json

    from torchao.core.config import config_to_dict
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig, PerRow)

    # Build an fp8 dynamic-activation / fp8-weight torchao config and hand
    # it to vLLM as a JSON string through hf_overrides.
    quant_cfg = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    overrides = {
        "quantization_config_dict_json":
        json.dumps(config_to_dict(quant_cfg))
    }

    with vllm_runner(model_name="facebook/opt-125m",
                     dtype="bfloat16",
                     pt_load_map_location="cuda:0",
                     quantization="torchao",
                     hf_overrides=overrides) as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)

    # Only checks that generation produced something; greedy output quality
    # is not asserted here.
    assert output


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_on_the_fly_quant_config_file(vllm_runner):
    """Testing on the fly quantization, load_weights integration point,
    with config file.

    Serializes a torchao config to a temporary JSON file and points vLLM at
    it via the ``quantization_config_file`` hf_override.
    """
    torch._dynamo.reset()
    model_name = "facebook/opt-125m"
    import json
    import os
    from tempfile import NamedTemporaryFile

    from torchao.core.config import config_to_dict
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig, PerRow)

    config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

    # delete=False so the file can be reopened by name (required on
    # Windows, harmless elsewhere); the `with` block closes/flushes it.
    with NamedTemporaryFile(mode="w", delete=False) as f:
        f.write(json.dumps(config_to_dict(config)))
        config_file_name = str(f.name)

    try:
        hf_overrides = {"quantization_config_file": config_file_name}
        with vllm_runner(model_name=model_name,
                         dtype="bfloat16",
                         pt_load_map_location="cuda:0",
                         quantization="torchao",
                         hf_overrides=hf_overrides) as llm:
            output = llm.generate_greedy(["The capital of France is"],
                                         max_tokens=32)

        assert output
    finally:
        # delete=False means nothing else will remove the temp file —
        # clean it up ourselves so repeated runs don't leak files.
        os.remove(config_file_name)


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_reload_weights():
    """Load a model with dummy (random) weights, switch the load format to
    `auto`, reload real weights in place, and check generation runs.
    """
    import json

    from torchao.core.config import config_to_dict
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig, PerRow)

    from vllm import LLM, SamplingParams

    quant_cfg = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    overrides = {
        "quantization_config_dict_json":
        json.dumps(config_to_dict(quant_cfg))
    }

    llm = LLM(
        model="Qwen/Qwen3-0.6B",
        dtype="bfloat16",
        load_format="dummy",
        enforce_eager=True,
        quantization="torchao",
        hf_overrides=overrides,
    )

    # Update load format from `dummy` to `auto`
    llm.collective_rpc("update_config",
                       args=({
                           "load_config": {
                               "load_format": "auto"
                           }
                       }, ))
    # Now reload real weights inplace
    llm.collective_rpc("reload_weights")

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Greedy-ish sampling; we only assert that each prompt yields some text.
    sampling_params = SamplingParams(temperature=0, top_p=0.95)
    for result in llm.generate(prompts, sampling_params):
        assert result.outputs[0].text
        # can also uncomment locally to make sure the generated
        # output makes sense
        # print(f"Prompt: {result.prompt!r}")
        # print(f"Output: {result.outputs[0].text!r}")
        # print("-" * 60)


if __name__ == "__main__":
Expand Down
72 changes: 64 additions & 8 deletions vllm/model_executor/layers/quantization/torchao.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from typing import Any, Optional

import torch
Expand Down Expand Up @@ -40,7 +41,8 @@ class TorchAOConfig(QuantizationConfig):

def __init__(self,
torchao_config,
skip_modules: Optional[list[str]] = None) -> None:
skip_modules: Optional[list[str]] = None,
is_checkpoint_torchao_serialized: bool = False) -> None:
"""
# TorchAO quantization relies on tensor subclasses. In order,
# to enable proper caching this needs standalone compile
Expand All @@ -58,9 +60,11 @@ def __init__(self,
super().__init__()
self.torchao_config = torchao_config
self.skip_modules = skip_modules or []
self.is_checkpoint_torchao_serialized = is_checkpoint_torchao_serialized

def __repr__(self) -> str:
return f"TorchAOConfig({self.torchao_config})"
return f"TorchAOConfig({self.torchao_config=}, {self.skip_modules=}, " \
f"{self.is_checkpoint_torchao_serialized=})"

def get_name(self) -> QuantizationMethods:
return "torchao"
Expand All @@ -74,7 +78,10 @@ def get_min_capability(cls) -> int:

@staticmethod
def get_config_filenames() -> list[str]:
return ["config.json"]
"""torchao doesn't require additional config files, we use
`config.json` from huggingface: `model_config.hf_config`
"""
return []

@classmethod
def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
Expand All @@ -87,6 +94,10 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
"`pip install torchao>=0.10.0` to use torchao quantization."
) from err

quant_method = cls.get_from_keys_or(config, ["quant_method"], None)
is_checkpoint_torchao_serialized = (quant_method is not None
and "torchao" in quant_method)

hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
assert hf_config is not None, "quant_type must be specified"
assert len(hf_config) == 1 and "default" in hf_config, (
Expand All @@ -110,7 +121,38 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":
if layer_cfg is None:
skip_modules.append(layer)

return cls(ao_config, skip_modules)
return cls(ao_config, skip_modules, is_checkpoint_torchao_serialized)

@classmethod
def from_config_file(cls, config_file: str) -> "TorchAOConfig":
    """Initialize class from a config file. Example:
    ```
    config = (
        Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    )
    fn = "torchao_config.json"

    with open(fn, "w") as f:
        f.write(json.dumps(config_to_dict(config)))
    ```

    Args:
        config_file: path to a JSON file produced by
            ``json.dumps(config_to_dict(<AOBaseConfig>))``.

    Returns:
        A TorchAOConfig built from the deserialized torchao config.
    """
    # json.load parses directly from the file handle; the previous
    # seek(0) + read() + loads() on a freshly opened file was redundant.
    with open(config_file) as f:
        config_dict = json.load(f)

    # Wrap in the hf-style {"quant_type": {"default": ...}} layout that
    # from_config expects.
    hf_config = {"quant_type": {"default": config_dict}}
    return cls.from_config(hf_config)

@classmethod
def from_config_dict_json(cls, config_dict_json: str) -> "TorchAOConfig":
    """Initialize class from a config-dict JSON string, obtained via:

        torchao_config_object = <some AOBaseConfig object>
        config_dict_json = json.dumps(config_to_dict(torchao_config_object))

    Args:
        config_dict_json: JSON serialization of a torchao config dict.

    Returns:
        A TorchAOConfig built from the deserialized torchao config.
    """
    config_dict = json.loads(config_dict_json)
    # Wrap in the hf-style {"quant_type": {"default": ...}} layout that
    # from_config expects.
    hf_config = {"quant_type": {"default": config_dict}}
    return cls.from_config(hf_config)

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
Expand All @@ -128,7 +170,9 @@ def get_quant_method(self, layer: torch.nn.Module,
c = module_fqn_to_config.get(
module_fqn) or module_fqn_to_config.get("_default", None)
if c is not None:
current_torchao_config = TorchAOConfig(c, self.skip_modules)
current_torchao_config = TorchAOConfig(
c, self.skip_modules,
self.is_checkpoint_torchao_serialized)
return TorchAOLinearMethod(current_torchao_config)
else:
return UnquantizedLinearMethod()
Expand Down Expand Up @@ -172,7 +216,7 @@ class TorchAOLinearMethod(LinearMethodBase):
"""Linear method for torchao.

Args:
quant_config: The torchao quantization config, a string that encodes
quant_config: The torchao quantization config, a string that encodes
the type of quantization and all relevant arguments.
"""

Expand All @@ -197,8 +241,9 @@ def create_weights(
),
requires_grad=False,
)
weight = torchao_quantize_param_data(weight,
self.quant_config.torchao_config)
if self.quant_config.is_checkpoint_torchao_serialized:
weight = torchao_quantize_param_data(
weight, self.quant_config.torchao_config)

set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

Expand All @@ -212,3 +257,14 @@ def apply(
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return F.linear(x, layer.weight, bias)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """On-the-fly quantization hook: if the checkpoint was not already
    torchao-serialized, quantize the freshly loaded bf16/fp16 weight now.
    """
    if not self.quant_config.is_checkpoint_torchao_serialized:
        # Checkpoint held plain weights — quantize them in place and
        # re-register so the layer holds the quantized parameter.
        quantized = torchao_quantize_param_data(
            layer.weight, self.quant_config.torchao_config)
        set_weight_attrs(quantized, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", quantized)
31 changes: 29 additions & 2 deletions vllm/model_executor/model_loader/default_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,35 @@ def download_model(self, model_config: ModelConfig) -> None:
def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model))

# if we don't have `model.weight_metadata_and_attr_saved` defined and
# set to True, it means that this is either offline quantization case
# or the first run of online quantization
# see online_quantization.py for detailed notes
offline_quantization_or_first_run_of_online_quantization = not getattr(
model, "weight_metadata_and_attr_saved", False)

if model_config.quantization is None:
# model is not quantized
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model))
elif offline_quantization_or_first_run_of_online_quantization:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
# see online_quantization.py for detailed notes
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model))
else:
# to avoid circular dependency
from vllm.model_executor.model_loader.online_quantization import (
load_weights_and_online_quantize)

# subsequent runs of weight loading with online
# quantization
loaded_weights = load_weights_and_online_quantize(
self, model, model_config)

self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
Expand Down
Loading