diff --git a/kernels/quantized/cpu/op_embedding4b.cpp b/kernels/quantized/cpu/op_embedding4b.cpp
new file mode 100644
index 00000000000..33be86e5cc4
--- /dev/null
+++ b/kernels/quantized/cpu/op_embedding4b.cpp
@@ -0,0 +1,344 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <algorithm>
+#include <cinttypes>
+#include <cmath>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = exec_aten::Tensor;
+using Scalar = exec_aten::Scalar;
+using ScalarType = exec_aten::ScalarType;
+
+namespace {
+
+/**
+ * Asserts that the parameters are valid.
+ */
+void check_embedding_4bit_args(
+    const Tensor& weight,
+    const Tensor& weight_scales,
+    const optional<Tensor>& opt_weight_zero_points,
+    const int64_t weight_quant_min,
+    const int64_t weight_quant_max,
+    const Tensor& indices,
+    exec_aten::optional<ScalarType> out_dtype,
+    Tensor& out) {
+  ET_CHECK_MSG(
+      weight.dim() == 2, "weight must be 2D but got %zd dims", weight.dim());
+
+  ET_CHECK_MSG(
+      weight_scales.dim() == 1 || weight_scales.dim() == 2,
+      "weight_scales must be 1D or 2D but got %zd dims",
+      weight_scales.dim());
+
+  ET_CHECK_MSG(
+      weight_scales.size(0) == weight.size(0),
+      "Number of scales must be == weight.size(0)=%zd"
+      ", but got %zd",
+      weight.size(0),
+      weight_scales.size(0));
+
+  if (weight_scales.dim() == 2) {
+    auto num_groups = weight_scales.size(1);
+    ET_CHECK_MSG(
+        // each packed uint8 column holds 2 quantized columns
+        (2 * weight.size(1)) % num_groups == 0,
+        "Number of groups must divide 2 * weight.size(1)=%zd"
+        ", but got # of groups = %zd",
+        2 * weight.size(1),
+        num_groups);
+  }
+
+  ET_CHECK_MSG(
+      weight.scalar_type() == ScalarType::Byte,
+      "weight.scalar_type() %" PRId8 " is not supported:",
+      static_cast<int8_t>(weight.scalar_type()));
+
+  ET_CHECK_MSG(
+      out.scalar_type() == ScalarType::Float ||
+          out.scalar_type() == ScalarType::Half,
+      "out.scalar_type() %" PRId8 " is not supported:",
+      static_cast<int8_t>(out.scalar_type()));
+
+  ET_CHECK_MSG(
+      weight_scales.scalar_type() == ScalarType::Float ||
+          weight_scales.scalar_type() == ScalarType::Half,
+      "weight_scales.scalar_type() %" PRId8 " is not supported:",
+      static_cast<int8_t>(weight_scales.scalar_type()));
+
+  if (opt_weight_zero_points.has_value()) {
+    ET_CHECK_MSG(
+        opt_weight_zero_points.value().dim() == weight_scales.dim(),
+        "weight_zero_points's rank must match that of weight_scales. "
+        "weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8,
+        static_cast<int8_t>(opt_weight_zero_points.value().dim()),
+        static_cast<int8_t>(weight_scales.dim()));
+
+    ET_CHECK_MSG(
+        opt_weight_zero_points.value().scalar_type() == out.scalar_type(),
+        "weight zero points scalar type %" PRId8
+        " does not match out.scalar_type()",
+        static_cast<int8_t>(opt_weight_zero_points.value().scalar_type()));
+
+    for (int32_t i = 0; i < weight_scales.dim(); ++i) {
+      ET_CHECK_MSG(
+          opt_weight_zero_points.value().size(i) == weight_scales.size(i),
+          "Dimension size mismatch at dim %" PRId32
+          ": weight_zero_points size = %zd"
+          ", weight_scales size = %zd.",
+          i,
+          opt_weight_zero_points.value().size(i),
+          weight_scales.size(i));
+    }
+  }
+
+  ET_CHECK_MSG(
+      indices.scalar_type() == ScalarType::Long,
+      "indices.scalar_type() %" PRId8 " is not supported; only Long is supported",
+      static_cast<int8_t>(indices.scalar_type()));
+
+  ET_CHECK_MSG(
+      weight_quant_min <= weight_quant_max,
+      "weight quant min: %" PRId64
+      " is greater than weight quant max: %" PRId64,
+      weight_quant_min,
+      weight_quant_max);
+
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out.scalar_type() == out_dtype.value(),
+        "output_dtype must match the dtype of the out tensor");
+  }
+}
+
+static inline int32_t weight_value(const unsigned char* w_data, int32_t index) {
+  int32_t odd = index & 1;
+  index >>= 1;
+  if (odd) {
+    return (int32_t)(w_data[index] & 0x0F) - 8;
+  } else {
+    return (int32_t)((w_data[index] >> 4) & 0x0F) - 8;
+  }
+}
+
+/**
+ * Retrieves the embeddings specified by indices, dequantizes them, and stores
+ * them in out. Weight will always be uint8.
+ */
+template <typename CTYPE_PARAMS, typename CTYPE_OUT>
+void embedding_4bit_per_channel(
+    const Tensor& weight,
+    const Tensor& weight_scales,
+    const optional<Tensor>& opt_weight_zero_points,
+    const Tensor& indices,
+    Tensor& out) {
+  auto embedding_dim = weight.size(1) * 2;
+
+  int32_t num_groups_per_channel = 1;
+  if (weight_scales.dim() == 2) {
+    num_groups_per_channel = weight_scales.size(1);
+  }
+  int32_t group_size = embedding_dim / num_groups_per_channel;
+
+  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+  const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();
+
+  const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
+  const CTYPE_PARAMS* zero_points = nullptr;
+  if (opt_weight_zero_points.has_value()) {
+    zero_points =
+        opt_weight_zero_points.value().const_data_ptr<CTYPE_PARAMS>();
+  }
+
+  for (int i = 0; i < indices.numel(); i++) {
+    int64_t index = indices_ptr[i];
+    // If using groupwise embedding
+    int32_t qparams_index = index * num_groups_per_channel;
+    CTYPE_PARAMS zp = 0.0;
+    const CTYPE_PARAMS* scale_ptr = scales + qparams_index;
+    const CTYPE_PARAMS* zero_points_ptr = nullptr;
+    if (opt_weight_zero_points.has_value()) {
+      zero_points_ptr = zero_points + qparams_index;
+    }
+
+    const uint8_t* w_data =
+        weight.const_data_ptr<uint8_t>() + weight.size(1) * index;
+
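+    // Dequantize the row: out[j] = (q[j] - zero_point) * scale, using the
+    // scale/zero_point of the group that column j falls into.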
+    for (int j = 0; j < embedding_dim; ++j) {
+      int32_t group_id = j / group_size;
+      const CTYPE_PARAMS scale = scale_ptr[group_id];
+      if (opt_weight_zero_points.has_value()) {
+        zp = zero_points_ptr[group_id];
+      }
+      out_data[j] = static_cast<CTYPE_OUT>(
+          (static_cast<float>(weight_value(w_data, j)) -
+           static_cast<float>(zp)) *
+          static_cast<float>(scale));
+    }
+    out_data += embedding_dim;
+  }
+}
+
+void resize_out_tensor(
+    const Tensor& weight,
+    const Tensor& indices,
+    Tensor& out) {
+  exec_aten::SizesType expected_output_size[kTensorDimensionLimit];
+  for (size_t i = 0; i < indices.dim(); i++) {
+    expected_output_size[i] = indices.size(i);
+  }
+  const size_t embedding_dim = 2 * weight.size(1);
+  expected_output_size[out.dim() - 1] = embedding_dim;
+
+  exec_aten::ArrayRef<exec_aten::SizesType> output_size{
+      expected_output_size, static_cast<size_t>(out.dim())};
+
+  torch::executor::Error err = resize_tensor(out, output_size);
+  ET_CHECK_MSG(
+      err == torch::executor::Error::Ok,
+      "Failed to resize out Tensor in quantized_embedding_4bit_out");
+}
+
+} // namespace
+
+/**
+ * Retrieves the embeddings specified by indices, dequantizes them, and stores
+ * them in out. The weight is quantized per channel, with a scale and
+ * zero_point for each embedding.
+ *
+ * Corresponds to the out variant of torch.ops.quantized.embedding_4bit.
+ *
+ * NOTE: quant_min, quant_max, and Dtype are not used in computation, but are
+ * rather metadata passed around, which can be useful for pattern matching. See
+ * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for
+ * more info.
+ */
+Tensor& quantized_embedding_4bit_out(
+    // TODO Evaluate whether this name is appropriate for an operator that takes
+    // non quant input and returns fp output
+    const Tensor& weight,
+    const Tensor& weight_scales,
+    const optional<Tensor>& opt_weight_zero_points,
+    const int64_t weight_quant_min,
+    const int64_t weight_quant_max,
+    const Tensor& indices,
+    Tensor& out) {
+  ScalarType out_type = out.scalar_type();
+
+  // TODO (jakeszwe): improve these to account for the size of out in relation
+  // to weight and indices accounting for a possible batch dimension
+  check_embedding_4bit_args(
+      weight,
+      weight_scales,
+      opt_weight_zero_points,
+      weight_quant_min,
+      weight_quant_max,
+      indices,
+      out_type,
+      out);
+
+  constexpr auto name = "quantized_decomposed::embedding_4bit.out";
+  ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
+    embedding_4bit_per_channel<CTYPE_OUT, CTYPE_OUT>(
+        weight, weight_scales, opt_weight_zero_points, indices, out);
+  });
+
+  return out;
+}
+
+Tensor& quantized_embedding_4bit_out(
+    RuntimeContext& context,
+    const Tensor& weight,
+    const Tensor& weight_scales,
+    const optional<Tensor>& opt_weight_zero_points,
+    int64_t weight_quant_min,
+    int64_t weight_quant_max,
+    const Tensor& indices,
+    Tensor& out) {
+  // TODO(larryliu): Add a context arg to the real op function and remove this
+  // wrapper
+  (void)context;
+  resize_out_tensor(weight, indices, out);
+  return quantized_embedding_4bit_out(
+      weight,
+      weight_scales,
+      opt_weight_zero_points,
+      weight_quant_min,
+      weight_quant_max,
+      indices,
+      out);
+}
+
+Tensor& quantized_embedding_4bit_dtype_out(
+    // TODO Evaluate whether this name is appropriate for an operator that takes
+    // non quant input and returns fp output
+    const Tensor& weight,
+    const Tensor& weight_scales,
+    const optional<Tensor>& opt_weight_zero_points,
+    const int64_t weight_quant_min,
+    const int64_t weight_quant_max,
+    const Tensor& indices,
+    exec_aten::optional<ScalarType> out_dtype,
+    Tensor& out) {
+  // TODO (jakeszwe): improve these to account for the size of out in relation
+  // to weight and indices accounting for a possible batch dimension
+  check_embedding_4bit_args(
+      weight,
+      weight_scales,
+      opt_weight_zero_points,
+      weight_quant_min,
+      weight_quant_max,
+      indices,
+      out_dtype,
+      out);
+
+  ScalarType params_type = weight_scales.scalar_type();
+  ScalarType out_type = out.scalar_type();
+
+  constexpr auto name = "quantized_decomposed::embedding_4bit.dtype_out";
+  ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
+    ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
+      embedding_4bit_per_channel<CTYPE_P, CTYPE_OUT>(
+          weight, weight_scales, opt_weight_zero_points, indices, out);
+    });
+  });
+
+  return out;
+}
+
+Tensor& quantized_embedding_4bit_dtype_out(
+    RuntimeContext& context,
+    const Tensor& weight,
+    const Tensor& weight_scales,
+    const optional<Tensor>& opt_weight_zero_points,
+    int64_t weight_quant_min,
+    int64_t weight_quant_max,
+    const Tensor& indices,
+    exec_aten::optional<ScalarType> out_dtype,
+    Tensor& out) {
+  // TODO(larryliu): Add a context arg to the real op function and remove this
+  // wrapper
+  (void)context;
+  resize_out_tensor(weight, indices, out);
+  return quantized_embedding_4bit_dtype_out(
+      weight,
+      weight_scales,
+      opt_weight_zero_points,
+      weight_quant_min,
+      weight_quant_max,
+      indices,
+      out_dtype,
+      out);
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/quantized/cpu/targets.bzl b/kernels/quantized/cpu/targets.bzl
index 3a6f74631ac..39552aaaf10 100644
--- a/kernels/quantized/cpu/targets.bzl
+++ b/kernels/quantized/cpu/targets.bzl
@@ -23,6 +23,9 @@ _QUANT_OPS = (
     op_target(
         name = "op_embedding",
     ),
+    op_target(
+        name = "op_embedding4b",
+    ),
     op_target(
         name = "op_mixed_mm",
         deps = [
diff --git a/kernels/quantized/quantized.yaml b/kernels/quantized/quantized.yaml
index 484641318b4..ade5575e32e 100644
--- a/kernels/quantized/quantized.yaml
+++ b/kernels/quantized/quantized.yaml
@@ -46,6 +46,18 @@
   - arg_meta: null
     kernel_name: torch::executor::quantized_embedding_byte_dtype_out
+- func: quantized_decomposed::embedding_4bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::quantized_embedding_4bit_out
+
+- func: quantized_decomposed::embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::quantized_embedding_4bit_dtype_out
+
 - func: quantized_decomposed::mixed_mm.out(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
diff --git a/kernels/quantized/test/op_embedding4b_test.cpp b/kernels/quantized/test/op_embedding4b_test.cpp
new file mode 100644
index 00000000000..56944c57857
--- /dev/null
+++ b/kernels/quantized/test/op_embedding4b_test.cpp
@@ -0,0 +1,141 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/quantized/NativeFunctions.h> // Declares the operator
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+#include <executorch/runtime/platform/runtime.h>
+#include <executorch/test/utils/DeathTest.h>
+
+#include <gtest/gtest.h>
+#include <limits>
+
+using namespace ::testing;
+using exec_aten::ArrayRef;
+using exec_aten::optional;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using torch::executor::native::quantized_embedding_4bit_out;
+
+using torch::executor::testing::TensorFactory;
+
+TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbedding) {
+  et_pal_init();
+  TensorFactory<ScalarType::Byte> tfb;
+  TensorFactory<ScalarType::Float> tf;
+  TensorFactory<ScalarType::Long> tfl;
+
+  int64_t quant_min = -8;
+  int64_t quant_max = 7;
+
+  Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
+  Tensor weight_zero_points = tf.make({3}, {1, -5, 0});
+
+  // -3, 1, 6, 7,
+  // 2, -5, -4, 0,
+  // -8, 3, -1, 6,
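+  //
+  // Each row above lists the 4-bit quantized values of one embedding row;
+  // consecutive pairs are packed into a single byte with an offset of 8,
+  // high nibble first, e.g. (-3, 1) -> (5, 9) -> 0x59 = 89.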
+
+  Tensor qweight = tfb.make({3, 2}, {89, 239, 163, 72, 11, 126});
+
+  Tensor indices = tfl.make({3}, {0, 2, 1});
+
+  Tensor out = tf.zeros({3, 4});
+  Tensor expected = tf.make(
+      {3, 4}, {-2.0, 0.0, 2.5, 3.0, -12.0, 4.5, -1.5, 9.0, 7.0, 0.0, 1.0, 5.0});
+
+  quantized_embedding_4bit_out(
+      qweight,
+      weight_scales,
+      weight_zero_points,
+      quant_min,
+      quant_max,
+      indices,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+
+  // Groupwise quantization. groupsize = 2
+  weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0});
+  weight_zero_points = tf.make({3, 2}, {1, -5, 0, 2, -3, -1});
+  /*
+  fp_weight = [-2.0, 0.0, 11.0, 12.0,
+               3.0, -7.5, -12.0, -4.0,
+               -12.5, 15.0, 0.0, 21.0]
+  */
+
+  out = tf.zeros({3, 4});
+  expected = tf.make(
+      {3, 4},
+      {-2.0, 0.0, 11.0, 12.0, -12.5, 15.0, 0.0, 21.0, 3.0, -7.5, -12.0, -4.0});
+
+  quantized_embedding_4bit_out(
+      qweight,
+      weight_scales,
+      weight_zero_points,
+      quant_min,
+      quant_max,
+      indices,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingDeath1) {
+  et_pal_init();
+  TensorFactory<ScalarType::Byte> tfb;
+  TensorFactory<ScalarType::Float> tf;
+  TensorFactory<ScalarType::Long> tfl;
+
+  int64_t quant_min = -8;
+  int64_t quant_max = 7;
+
+  Tensor weight_scales = tf.make({4}, {0.5, 1.0, 1.5, 3.3});
+  Tensor weight_zero_points = tf.make({4}, {1, 5, 7, 5});
+  Tensor qweight = tfb.make({3, 2}, {89, 239, 163, 72, 11, 126});
+  Tensor indices = tfl.make({3}, {0, 2, 1});
+
+  Tensor out = tf.zeros({3, 4});
+  ET_EXPECT_DEATH(
+      quantized_embedding_4bit_out(
+          qweight,
+          weight_scales,
+          weight_zero_points,
+          quant_min,
+          quant_max,
+          indices,
+          out),
+      "");
+}
+
+TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingDeath2) {
+  et_pal_init();
+  TensorFactory<ScalarType::Byte> tfb;
+  TensorFactory<ScalarType::Float> tf;
+  TensorFactory<ScalarType::Long> tfl;
+
+  int64_t quant_min = -8;
+  int64_t quant_max = 7;
+
+  Tensor weight_scales = tf.make({2}, {0.5, 1.0});
+  Tensor weight_zero_points = tf.make({2}, {1, 5});
+  Tensor qweight = tfb.make({3, 2}, {89, 239, 163, 72, 11, 126});
+  Tensor indices = tfl.make({3}, {0, 2, 1});
+
+  Tensor out = tf.zeros({3, 4});
+  ET_EXPECT_DEATH(
+      quantized_embedding_4bit_out(
+          qweight,
+          weight_scales,
+          weight_zero_points,
+          quant_min,
+          quant_max,
+          indices,
+          out),
+      "");
+}
diff --git a/kernels/quantized/test/targets.bzl b/kernels/quantized/test/targets.bzl
index e06090cae91..a4129ee22fb 100644
--- a/kernels/quantized/test/targets.bzl
+++ b/kernels/quantized/test/targets.bzl
@@ -25,6 +25,7 @@ def define_common_targets():
         "//executorch/kernels/portable/cpu:op_embedding",
         "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
     ])
+    op_test("op_embedding4b_test", kernel_name = "quantized")
     op_test("op_mixed_mm_test", kernel_name = "quantized", deps = [
         "//executorch/kernels/quantized/cpu:op_mixed_mm",
         "//executorch/kernels/quantized:generated_lib_headers",