Implicit conversion

jainapurva committed Nov 27, 2024
1 parent 62bd5da commit a184759

Showing 4 changed files with 34 additions and 20 deletions.
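
Taken together, the commit moves dequantize_float8_training out of torchao.float8 and into torchao.quantization.quant_api, and has quantize_ implicitly swap any Float8Linear module back to a plain nn.Linear before a quantization config is applied. A minimal usage sketch of the new behavior, mirroring the added test (it assumes a CUDA device, a torchao build containing this commit, and that convert_to_float8_training is imported from torchao.float8 as in the package __init__):

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.quantization.quant_api import float8_weight_only, quantize_

m = nn.Sequential(nn.Linear(32, 32)).cuda()
m = convert_to_float8_training(m)   # module 0 is now a Float8Linear
quantize_(m, float8_weight_only())  # Float8Linear is implicitly converted back to nn.Linear first
with torch.no_grad():
    m(torch.randn(32, 32, device="cuda"))
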
test/float8/test_base.py (14 additions, 0 deletions)
@@ -531,6 +531,20 @@ def test_inference_mode(self):
        with torch.inference_mode(mode=True):
            m(x)

    def test_quantize(self):
        x = torch.randn(32, 32, device="cuda")
        m = nn.Sequential(nn.Linear(32, 32)).cuda()
        m = convert_to_float8_training(m)
        assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
        from torchao.quantization.quant_api import float8_weight_only, quantize_

        quantize_(m, float8_weight_only())
        assert (
            m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn
        ), "Post quantization dtype should be torch.float8_e4m3fn"
        with torch.no_grad():
            m(x)


class TestScaledMM:
    @unittest.skipIf(
torchao/float8/__init__.py (0 additions, 2 deletions)
@@ -14,7 +14,6 @@
from torchao.float8.float8_linear import WeightWithDelayedFloat8CastTensor
from torchao.float8.float8_linear_utils import (
    convert_to_float8_training,
    dequantize_float8_training,
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)
@@ -55,6 +54,5 @@
"linear_requires_sync",
"sync_float8_amax_and_scale_history",
"precompute_float8_dynamic_scale_for_fsdp",
"dequantize_float8_training",
# note: Float8Tensor and Float8Linear are not public APIs
]
torchao/float8/float8_linear_utils.py (0 additions, 18 deletions)
@@ -321,21 +321,3 @@ def inner_func():
    for child in fp8_layers:
        # Set a flag to signal that initialization is done
        child.is_amax_initialized = True


def dequantize_float8_training(model: nn.Module) -> nn.Module:
    """
    Converts `Float8Linear` modules in `model` to `torch.nn.Linear`.
    """

    def dequant_func(mod: Float8Linear) -> nn.Linear:
        new_module = nn.Linear(mod.in_features, mod.out_features)
        new_module.weight = mod.weight
        new_module.bias = mod.bias
        return new_module

    return swap_linear_layers(
        model,
        dequant_func,
        target_module=Float8Linear,
    )
torchao/quantization/quant_api.py (20 additions, 0 deletions)
@@ -39,6 +39,8 @@
    to_affine_quantized_intx,
    to_marlinqqq_quantized_intx,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import swap_linear_layers
from torchao.float8.inference import Float8MMConfig
from torchao.quantization.linear_activation_weight_observed_tensor import (
    LinearActivationWeightObservedTensor,
@@ -199,6 +201,22 @@ def change_linear_weights_to_int4_woqtensors(
########
# TO BE DEPRECATED END
########
def dequantize_float8_training(model: nn.Module) -> nn.Module:
    """
    Converts `Float8Linear` modules in `model` to `torch.nn.Linear`.
    """

    def dequant_func(mod: Float8Linear) -> nn.Linear:
        new_module = nn.Linear(mod.in_features, mod.out_features)
        new_module.weight = mod.weight
        new_module.bias = mod.bias
        return new_module

    return swap_linear_layers(
        model,
        dequant_func,
        target_module=Float8Linear,
    )


def _replace_with_custom_fn_if_matches_filter(
@@ -222,6 +240,8 @@ def _replace_with_custom_fn_if_matches_filter(
    Returns:
        None
    """
    if isinstance(model, Float8Linear):
        model = dequantize_float8_training(model)
    if filter_fn(model, cur_fqn[:-1]):
        if device is not None:
            model.to(device=device)  # move to device before quantization
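
For reference, the relocated helper can also be called directly on a float8-training model to recover plain nn.Linear modules that share the original weight and bias parameters. A small sketch, assuming the new import location in quant_api.py shown above (the helper is defined at module level there, though it is not advertised as a public API):

from torchao.quantization.quant_api import dequantize_float8_training

model = dequantize_float8_training(model)  # every Float8Linear child replaced by an equivalent nn.Linear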
