Enable 8-bit
Summary: Enables 8-bit kernel in operators and tests

Differential Revision: D65688410
metascroy authored and facebook-github-bot committed Nov 8, 2024
1 parent 242f181 commit ac23009
Showing 10 changed files with 74 additions and 9 deletions.
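
For context, here is a minimal sketch of how the newly enabled 8-bit path might be exercised end to end. It uses the IntxWeightEmbeddingQuantizer name visible in the tests below; the constructor arguments (device, precision, bitwidth, groupsize) and the quantize() method are assumptions, since the diff truncates the call site.

```python
import copy

import torch

# Assumed import path; the quantizer name appears in the embedding test below.
from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer

model = torch.nn.Sequential(torch.nn.Embedding(128, 256))

# Constructor arguments are assumptions inferred from the truncated test code.
quantizer = IntxWeightEmbeddingQuantizer(
    device="cpu",
    precision=torch.float32,
    bitwidth=8,  # newly allowed by this commit; previously capped at 7
    groupsize=64,
)
quantized_model = quantizer.quantize(copy.deepcopy(model))

indices = torch.randint(0, 128, (7,), dtype=torch.int32)
out = quantized_model(indices)
```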
@@ -65,7 +65,7 @@ def __init__(
         group_size: int,
         target: str,
     ):
-        assert nbit <= 7
+        assert nbit <= 8
         self.nbit = nbit
         self.group_size = group_size
         self.target = target_from_str(target)
@@ -36,6 +36,7 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) {
   DEFINE_OP(5);
   DEFINE_OP(6);
   DEFINE_OP(7);
+  DEFINE_OP(8);
 }
 
 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
@@ -46,6 +47,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) {
   DEFINE_CPU_IMPL(5);
   DEFINE_CPU_IMPL(6);
   DEFINE_CPU_IMPL(7);
+  DEFINE_CPU_IMPL(8);
 }
 
 TORCH_LIBRARY_IMPL(torchao, Meta, m) {
@@ -56,4 +58,5 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) {
   DEFINE_META_IMPL(5);
   DEFINE_META_IMPL(6);
   DEFINE_META_IMPL(7);
+  DEFINE_META_IMPL(8);
 }
@@ -37,3 +37,4 @@ DEFINE_OP(4);
 DEFINE_OP(5);
 DEFINE_OP(6);
 DEFINE_OP(7);
+DEFINE_OP(8);
@@ -68,6 +68,7 @@ TORCH_LIBRARY(torchao, m) {
   DEFINE_OP(5);
   DEFINE_OP(6);
   DEFINE_OP(7);
+  DEFINE_OP(8);
 }
 
 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
@@ -78,6 +79,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) {
   DEFINE_CPU_IMPL(5);
   DEFINE_CPU_IMPL(6);
   DEFINE_CPU_IMPL(7);
+  DEFINE_CPU_IMPL(8);
 }
 
 TORCH_LIBRARY_IMPL(torchao, Meta, m) {
@@ -88,4 +90,5 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) {
   DEFINE_META_IMPL(5);
   DEFINE_META_IMPL(6);
   DEFINE_META_IMPL(7);
+  DEFINE_META_IMPL(8);
 }
@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Unlike ATen, ExecuTorch op registration appears to allow only one
+// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a
+// new file is needed for each variant.
+
+#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
+
+namespace {
+Tensor _op_out(
+    RuntimeContext& ctx,
+    const Tensor& activations,
+    const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
+    const Tensor& n_tensor,
+    const Tensor& k_tensor,
+    Tensor& out) {
+  (void)ctx;
+  linear_out_cpu</*weight_nbit*/ 8, /*has_weight_zeros*/ false>(
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
+  return out;
+}
+} // namespace
+
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_8bit0zp_weight.out", _op_out);
@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Unlike ATen, ExecuTorch op registration appears to allow only one
+// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a
+// new file is needed for each variant.
+
+#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
+
+namespace {
+Tensor _op_out(
+    RuntimeContext& ctx,
+    const Tensor& activations,
+    const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
+    const Tensor& n_tensor,
+    const Tensor& k_tensor,
+    Tensor& out) {
+  (void)ctx;
+  linear_out_cpu</*weight_nbit*/ 8, /*has_weight_zeros*/ true>(
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
+  return out;
+}
+} // namespace
+
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_8bit_weight.out", _op_out);
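
For orientation, here is a hedged sketch of how the 8-bit out-variant registered above might be invoked from Python once the corresponding ATen ops library is loaded. The argument layout mirrors the _op_out signature in the two new files (group_size, n, and k passed as 1-element tensors plus a preallocated out tensor); the library path and the producer of packed_weights are not part of this diff and are purely illustrative.

```python
import torch

# Illustrative path; the built library's location is not part of this diff.
# torch.ops.load_library("libtorchao_ops_aten.so")

def linear_8bit_out(activations, packed_weights, group_size, n, k):
    # Mirrors the _op_out signature above: scalar metadata travels as
    # 1-element tensors, and the result is written into a preallocated out.
    out = torch.empty(activations.size(0), n, dtype=torch.float32)
    return torch.ops.torchao._linear_8bit_act_8bit_weight.out(
        activations,
        packed_weights,
        torch.tensor([group_size]),
        torch.tensor([n]),
        torch.tensor([k]),
        out=out,
    )
```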
torchao/experimental/quant_api.py (10 changes: 5 additions & 5 deletions)
@@ -198,7 +198,7 @@ def forward(self, x):

 def _maybe_get_quantized_linear_native(nbit, has_weight_zeros):
     try:
-        if nbit in [1, 2, 3, 4, 5, 6, 7]:
+        if nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
             wzp_suffix = "" if has_weight_zeros else "0zp"
             return _Int8DynActIntxWeightQuantizedLinearNative(
                 pack_weight_op=getattr(
@@ -230,7 +230,7 @@ def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}):
     has_weight_zeros = kwargs["has_weight_zeros"]
 
     assert not isinstance(module, nn.Linear)
-    assert nbit >= 1 and nbit <= 7
+    assert nbit >= 1 and nbit <= 8
 
     for name, child in module.named_children():
         if not isinstance(child, nn.Linear):
@@ -366,9 +366,9 @@ def quantize_and_pack_weights(self, weights, group_size):
         weight_qvals, weight_scales, weight_zeros = _quantize(
             weights, self.group_size, self.nbit, has_weight_zeros=True
         )
-        self.weight_qvals = weight_qvals.to(torch.int8)
+        self.weight_qvals = weight_qvals.to(torch.int32)
         self.weight_scales = weight_scales
-        self.weight_zeros = weight_zeros.to(torch.int8)
+        self.weight_zeros = weight_zeros.to(torch.int32)
 
     def forward(self, x):
         shape = x.shape
@@ -394,7 +394,7 @@ def _replace_embedding_with_quantized_embedding(module: nn.Module, kwargs={}):
     nbit = kwargs["nbit"]
 
     assert not isinstance(module, nn.Embedding)
-    assert nbit >= 1 and nbit <= 7
+    assert nbit >= 1 and nbit <= 8
 
     for name, child in module.named_children():
         if not isinstance(child, nn.Embedding):
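The quant_api.py hunk above widens the stored weight_qvals and weight_zeros from int8 to int32. As a toy illustration of the group-wise quantize step they come from (a hypothetical re-implementation, not torchao's _quantize, whose signature and rounding details differ):

```python
import torch

def toy_groupwise_quantize(w: torch.Tensor, group_size: int, nbit: int):
    """Hypothetical sketch of group-wise affine quantization; the real
    torchao _quantize differs in signature and rounding details."""
    qmin, qmax = -(1 << (nbit - 1)), (1 << (nbit - 1)) - 1
    grouped = w.reshape(-1, group_size)
    wmin = grouped.min(dim=1, keepdim=True).values
    wmax = grouped.max(dim=1, keepdim=True).values
    scales = (wmax - wmin).clamp(min=1e-9) / (qmax - qmin)
    zeros = qmin - torch.round(wmin / scales)
    qvals = torch.clamp(torch.round(grouped / scales) + zeros, qmin, qmax)
    # Stored as int32, matching the widened dtypes in the diff above.
    return qvals.to(torch.int32), scales, zeros.to(torch.int32)

# Example: quantize a 2x8 weight with group_size=4 at the newly enabled nbit=8.
q, s, z = toy_groupwise_quantize(torch.randn(2, 8), group_size=4, nbit=8)
```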
@@ -65,7 +65,7 @@ def test_accuracy(self):
         model = torch.nn.Sequential(*[torch.nn.Embedding(num_embeddings, embedding_dim)])
         indices = torch.randint(0, num_embeddings, (7,), dtype=torch.int32)
 
-        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+        for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
             print(f"Testing nbit={nbit}")
             quantized_model = copy.deepcopy(model)
             quantizer = IntxWeightEmbeddingQuantizer(
@@ -67,7 +67,7 @@ def test_accuracy(self):
         activations = torch.randn(2, 3, m, k, dtype=torch.float32)
         model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
 
-        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+        for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
             for has_weight_zeros in [True, False]:
                 print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
                 quantized_model = copy.deepcopy(model)
@@ -70,7 +70,7 @@ def test_accuracy(self):
         activations = torch.randn(m, k, dtype=torch.float32)
         model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
 
-        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+        for nbit in [1, 2, 3, 4, 5, 6, 7, 8]:
             for has_weight_zeros in [True, False]:
                 print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
                 quantized_model = copy.deepcopy(model)