Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for quantize_() with Float8Linear module #1344

Merged
merged 6 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
import torch.nn as nn

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
Expand Down Expand Up @@ -531,6 +531,21 @@ def test_inference_mode(self):
with torch.inference_mode(mode=True):
m(x)

@unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available")
def test_quantize(self):
x = torch.randn(32, 32, device="cuda")
m = nn.Sequential(nn.Linear(32, 32)).cuda()
m = convert_to_float8_training(m)
assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
from torchao.quantization.quant_api import float8_weight_only, quantize_

quantize_(m, float8_weight_only())
assert (
m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn
), "Post quantization dtype should be torch.float8_e4m3fn"
with torch.no_grad():
m(x)


class TestScaledMM:
@unittest.skipIf(
Expand Down Expand Up @@ -576,7 +591,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else:
atol, rtol = 2e-3, 2e-3
atol, rtol = 3e-3, 3e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

@unittest.skipIf(not is_cuda_8_9, "CUDA not available")
Expand Down Expand Up @@ -751,7 +766,7 @@ def test_swap_root_linear_with_children_raises(self):
config = Float8LinearConfig(emulate=emulate)
with self.assertRaisesRegex(
AssertionError,
"Does not support a root nn.Linear with children",
"Does not support a root torch.nn.modules.linear with children",
jainapurva marked this conversation as resolved.
Show resolved Hide resolved
):
convert_to_float8_training(module, config=config)

Expand Down
12 changes: 7 additions & 5 deletions torchao/float8/float8_linear_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def swap_linear_layers(
from_float_func: Callable[[nn.Linear], nn.Linear],
*,
module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
target_module: nn.Module = nn.Linear,
) -> nn.Module:
"""
Generic function to swap linear layers in a module with a new type of linear layer.
Expand All @@ -71,20 +72,21 @@ def swap_linear_layers(

Args:
module: Module to modify.
from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
from_float_func: Function that accepts some type of linear layer and returns a new type of linear layer.
module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
that pass the filter function will be swapped. The inputs to the
filter function are the module instance, and the FQN.
target_module: Replace these modules

Returns:
nn.Module: The modified module with swapped linear layers.
"""
if isinstance(module, nn.Linear) and (
if isinstance(module, target_module) and (
module_filter_fn is None or module_filter_fn(module, "")
):
if len(list(module.children())) > 0:
raise AssertionError(
f"Does not support a root nn.Linear with children: {module}"
f"Does not support a root {target_module.__module__} with children: {module.__module__}"
)
return from_float_func(
module,
Expand All @@ -108,12 +110,12 @@ def post_order_traversal(

post_order_traversal(child_module, new_fqn, module)

if isinstance(module, nn.Linear) and (
if isinstance(module, target_module) and (
module_filter_fn is None or module_filter_fn(module, cur_fqn)
):
assert (
parent_module is not None
), f"Linear root module should return early: {module}"
), f"{target_module} root module should return early: {module}"
new_linear_module = from_float_func(module)
cur_module_name = cur_fqn.split(".")[-1]
setattr(parent_module, cur_module_name, new_linear_module)
Expand Down
20 changes: 20 additions & 0 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
to_affine_quantized_intx,
to_marlinqqq_quantized_intx,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import swap_linear_layers
from torchao.float8.inference import Float8MMConfig
from torchao.quantization.linear_activation_weight_observed_tensor import (
LinearActivationWeightObservedTensor,
Expand Down Expand Up @@ -199,6 +201,22 @@ def change_linear_weights_to_int4_woqtensors(
########
# TO BE DEPRECATED END
########
def dequantize_float8_training(model: nn.Module) -> nn.Module:
jainapurva marked this conversation as resolved.
Show resolved Hide resolved
"""
Converts `Float8Linear` modules in `model` to `torch.nn.Linear`.
"""

def dequant_func(mod: Float8Linear) -> nn.Linear:
new_module = nn.Linear(mod.in_features, mod.out_features)
new_module.weight = mod.weight
new_module.bias = mod.bias
return new_module

return swap_linear_layers(
model,
dequant_func,
target_module=Float8Linear,
)


def _replace_with_custom_fn_if_matches_filter(
Expand All @@ -222,6 +240,8 @@ def _replace_with_custom_fn_if_matches_filter(
Returns:
None
"""
if isinstance(model, Float8Linear):
model = dequantize_float8_training(model)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you just move your code snippet from the other file here:

if isinstance(model, Float8Linear):
    with torch.device("meta"):
        new_module = nn.Linear(model.in_features, model.out_features)
    new_module.weight = model.weight
    new_module.bias = model.bias        
    model = new_module

and not need any changes to torchao/float8?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vkuzo what do you think about having dequantizing a model as a separate API? it feels a bit weird to have this logic in _replace_with_custom_fn_if_matches_filter which is supposed to be a simple module replacement function I feel.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my 5c -

  1. we can make it work now without adding any public APIs, with minimal increase in complexity
  2. if it's important to have a public API for "remove low precision training from a model", we can have that conversation in parallel

wdyt

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the motivation for adding a new API is making the dequantizing step more explicit for user, instead of hide it in a module replacement function.

but agree this can happen in parallel. also it's probably not worth spending time to discuss as of now, and wait until there are more use cases might be better

if filter_fn(model, cur_fqn[:-1]):
if device is not None:
model.to(device=device) # move to device before quantization
Expand Down
Loading