Commit

Updated logic
jainapurva committed Nov 28, 2024
1 parent dfdba92 commit 94655a9
Showing 2 changed files with 10 additions and 25 deletions.
12 changes: 5 additions & 7 deletions torchao/float8/float8_linear_utils.py
@@ -61,7 +61,6 @@ def swap_linear_layers(
     from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
-    target_module: nn.Module = nn.Linear,
 ) -> nn.Module:
     """
     Generic function to swap linear layers in a module with a new type of linear layer.
@@ -72,21 +71,20 @@
     Args:
         module: Module to modify.
-        from_float_func: Function that accepts some type of linear layer and returns a new type of linear layer.
+        from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
         module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
             that pass the filter function will be swapped. The inputs to the
             filter function are the module instance, and the FQN.
-        target_module: Replace these modules
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
-    if isinstance(module, target_module) and (
+    if isinstance(module, nn.Linear) and (
         module_filter_fn is None or module_filter_fn(module, "")
     ):
         if len(list(module.children())) > 0:
             raise AssertionError(
-                f"Does not support a root {target_module.__module__} with children: {module.__module__}"
+                f"Does not support a root nn.Linear with children: {module}"
             )
         return from_float_func(
             module,
@@ -110,12 +108,12 @@ def post_order_traversal(
 
             post_order_traversal(child_module, new_fqn, module)
 
-        if isinstance(module, target_module) and (
+        if isinstance(module, nn.Linear) and (
             module_filter_fn is None or module_filter_fn(module, cur_fqn)
         ):
             assert (
                 parent_module is not None
-            ), f"{target_module} root module should return early: {module}"
+            ), f"Linear root module should return early: {module}"
             new_linear_module = from_float_func(module)
             cur_module_name = cur_fqn.split(".")[-1]
             setattr(parent_module, cur_module_name, new_linear_module)
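For reference, the sketch below exercises the simplified swap_linear_layers signature after this change: swapping is now always keyed on nn.Linear rather than a configurable target_module. The ToyModel class, the upcast_linear helper, and the filter are illustrative assumptions, not code from this commit.

# Minimal usage sketch (assumptions: ToyModel and upcast_linear are placeholders).
import torch.nn as nn

from torchao.float8.float8_linear_utils import swap_linear_layers


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 32)
        self.head = nn.Linear(32, 4)

    def forward(self, x):
        return self.head(self.proj(x))


def upcast_linear(mod: nn.Linear) -> nn.Linear:
    # Stand-in for a real conversion function (e.g. a float8 swap); it simply
    # rebuilds the layer and reuses the existing parameters.
    new_mod = nn.Linear(mod.in_features, mod.out_features, bias=mod.bias is not None)
    new_mod.weight = mod.weight
    new_mod.bias = mod.bias
    return new_mod


model = ToyModel()
# Swap only the layer whose fully qualified name is "proj".
model = swap_linear_layers(model, upcast_linear, module_filter_fn=lambda mod, fqn: fqn == "proj")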
23 changes: 5 additions & 18 deletions torchao/quantization/quant_api.py
@@ -40,7 +40,6 @@
     to_marlinqqq_quantized_intx,
 )
 from torchao.float8.float8_linear import Float8Linear
-from torchao.float8.float8_linear_utils import swap_linear_layers
 from torchao.float8.inference import Float8MMConfig
 from torchao.quantization.linear_activation_weight_observed_tensor import (
     LinearActivationWeightObservedTensor,
@@ -224,24 +223,12 @@ def _replace_with_custom_fn_if_matches_filter(
     Returns:
         None
     """
-
-    def dequantize_float8_training(model: nn.Module) -> nn.Module:
-        """Converts `Float8Linear` modules in `model` to `torch.nn.Linear`."""
-
-        def dequant_func(mod: Float8Linear) -> nn.Linear:
-            new_module = nn.Linear(mod.in_features, mod.out_features)
-            new_module.weight = mod.weight
-            new_module.bias = mod.bias
-            return new_module
-
-        return swap_linear_layers(
-            model,
-            dequant_func,
-            target_module=Float8Linear,
-        )
-
     if isinstance(model, Float8Linear):
-        model = dequantize_float8_training(model)
+        with torch.device("meta"):
+            new_module = nn.Linear(model.in_features, model.out_features)
+        new_module.weight = model.weight
+        new_module.bias = model.bias
+        model = new_module
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device)  # move to device before quantization
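The inlined replacement above builds the plain nn.Linear under torch.device("meta") so the temporary module allocates no real parameter storage; the Float8Linear's existing weight and bias tensors are then reattached. Below is a standalone sketch of that pattern, not code from this commit; it assumes only a module exposing in_features, out_features, weight, and bias.

# Illustrative sketch of the meta-device replacement pattern.
import torch
import torch.nn as nn


def to_plain_linear(float8_mod: nn.Module) -> nn.Linear:
    # Construct the replacement on the meta device so the temporary nn.Linear
    # allocates no real weight/bias storage of its own.
    with torch.device("meta"):
        new_module = nn.Linear(float8_mod.in_features, float8_mod.out_features)
    # Reattach the original parameters; the meta-device placeholders are discarded.
    new_module.weight = float8_mod.weight
    new_module.bias = float8_mod.bias
    return new_module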
