Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for quantize_() with Float8Linear module #1344

Merged
merged 6 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions torchao/float8/float8_linear_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _update_history_stack(
def swap_linear_layers(
module: nn.Module,
from_float_func: Callable[[nn.Linear], nn.Linear],
target_module: nn.Module = nn.Linear,
*,
module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
Expand All @@ -71,20 +72,21 @@ def swap_linear_layers(

Args:
module: Module to modify.
from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
from_float_func: Function that accepts some type of linear layer and returns a new type of linear layer.
module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
that pass the filter function will be swapped. The inputs to the
filter function are the module instance, and the FQN.
target_module: Replace these modules

Returns:
nn.Module: The modified module with swapped linear layers.
"""
if isinstance(module, nn.Linear) and (
if isinstance(module, target_module) and (
module_filter_fn is None or module_filter_fn(module, "")
):
if len(list(module.children())) > 0:
raise AssertionError(
f"Does not support a root nn.Linear with children: {module}"
f"Does not support a root {target_module} with children: {module}"
)
return from_float_func(
module,
Expand All @@ -108,12 +110,12 @@ def post_order_traversal(

post_order_traversal(child_module, new_fqn, module)

if isinstance(module, nn.Linear) and (
if isinstance(module, target_module) and (
module_filter_fn is None or module_filter_fn(module, cur_fqn)
):
assert (
parent_module is not None
), f"Linear root module should return early: {module}"
), f"{target_module} root module should return early: {module}"
new_linear_module = from_float_func(module)
cur_module_name = cur_fqn.split(".")[-1]
setattr(parent_module, cur_module_name, new_linear_module)
Expand Down Expand Up @@ -319,3 +321,20 @@ def inner_func():
for child in fp8_layers:
# Set a flag to signal that initialization is done
child.is_amax_initialized = True

def dequantize_float8_training(model: nn.Module) -> nn.Module:
"""
Converts `Float8Linear` modules in `model` to `torch.nn.Linear`.
"""

def dequant_func(mod: Float8Linear) -> nn.Linear:
new_module = nn.Linear(mod.in_features, mod.out_features)
new_module.weight = mod.weight
new_module.bias = mod.bias
return new_module

return swap_linear_layers(
model,
dequant_func,
target_module=Float8Linear,
)
5 changes: 5 additions & 0 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
to_affine_quantized_intx,
to_marlinqqq_quantized_intx,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import dequantize_float8_training
from torchao.float8.inference import Float8MMConfig
from torchao.quantization.linear_activation_weight_observed_tensor import (
LinearActivationWeightObservedTensor,
Expand Down Expand Up @@ -222,6 +224,9 @@ def _replace_with_custom_fn_if_matches_filter(
Returns:
None
"""
# If model is Float8Linear, convert it to Linear before moving forward
if isinstance(model, Float8Linear):
model = dequantize_float8_training(model)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you just move your code snippet from the other file here:

if isinstance(model, Float8Linear):
    with torch.device("meta"):
        new_module = nn.Linear(model.in_features, model.out_features)
    new_module.weight = model.weight
    new_module.bias = model.bias        
    model = new_module

and not need any changes to torchao/float8?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vkuzo what do you think about having dequantizing a model as a separate API? it feels a bit weird to have this logic in _replace_with_custom_fn_if_matches_filter which is supposed to be a simple module replacement function I feel.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my 5c -

  1. we can make it work now without adding any public APIs, with minimal increase in complexity
  2. if it's important to have a public API for "remove low precision training from a model", we can have that conversation in parallel

wdyt

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the motivation for adding a new API is making the dequantizing step more explicit for user, instead of hide it in a module replacement function.

but agree this can happen in parallel. also it's probably not worth spending time to discuss as of now, and wait until there are more use cases might be better

if filter_fn(model, cur_fqn[:-1]):
if device is not None:
model.to(device=device) # move to device before quantization
Expand Down
Loading