diff --git a/kernels/optimized/cpu/op_linear.cpp b/kernels/optimized/cpu/op_linear.cpp index 210000b384d..d81bfd8643f 100644 --- a/kernels/optimized/cpu/op_linear.cpp +++ b/kernels/optimized/cpu/op_linear.cpp @@ -6,17 +6,69 @@ * LICENSE file in the root directory of this source tree. */ +#include + +#include + #include +#include +#include #include #include -#include - namespace torch { namespace executor { namespace native { -using Tensor = executorch::aten::Tensor; +namespace { +using ::executorch::aten::Tensor; +using ::executorch::cpublas::gemm; +using ::executorch::cpublas::TransposeType; +using ::executorch::runtime::toString; +using ::executorch::vec::map; +using ::executorch::vec::Vectorized; + +// Use vector store to initialize with scalar bias. +template +void initialize_scalar( + const ssize_t out_numel, + const scalar_t init, + scalar_t* out) { + using Vec = Vectorized; + + // Initialize a vector with the scalar initial value. + Vec init_vec(init); + + ssize_t d = 0; + for (; d < out_numel - (out_numel % Vec::size()); d += Vec::size()) { + // Vector-length store. + init_vec.store(out + d); + } + if (out_numel - d > 0) { + // Sub-vector-length store. + init_vec.store(out + d, static_cast(out_numel - d)); + } +} + +// Use std::memcpy to initialize with vector bias. +template +void initialize_to_vector( + const ssize_t n, + const ssize_t m, + const scalar_t* bias, + scalar_t* out) { + // Output is a n x m x scalar_t, while bias is m x scalar_t. + const size_t row_size = static_cast(m) * sizeof(scalar_t); + for (const auto col : c10::irange(n)) { + std::memcpy( + // Point to Column `col` of the output tensor. + out + col * m, + bias, + row_size); + } +} + +} // namespace Tensor& opt_linear_out( RuntimeContext& ctx, @@ -24,12 +76,6 @@ Tensor& opt_linear_out( const Tensor& mat2, const optional& bias, Tensor& out) { - ET_KERNEL_CHECK_MSG( - ctx, - !bias.has_value(), - InvalidArgument, - out, - "bias not supported yet in linear"); ET_KERNEL_CHECK(ctx, check_linear_args(in, mat2, out), InvalidArgument, out); size_t output_ndim = 0; @@ -46,28 +92,74 @@ Tensor& opt_linear_out( return out; } - int flattened_input_dim = 1; + ssize_t n = 1; for (int ii = 0; ii < in.dim() - 1; ++ii) { - flattened_input_dim *= in.sizes()[ii]; + n *= in.sizes()[ii]; } + const ssize_t k = in.sizes()[in.dim() - 1]; + const ssize_t m = mat2.size(0); + + if (bias.has_value()) { + ET_KERNEL_CHECK_MSG( + ctx, + // Bias and output dtype must match. + bias->dtype() == out.dtype(), + InvalidArgument, + out, + "Bias has wrong dtype! Expected bias dtype to be the same as out dtype %s" + " but got %s", + toString(bias->dtype()), + toString(out.dtype())); + + ET_KERNEL_CHECK_MSG( + ctx, + // Either no bias or bias is a 1D tensor of size m or 1. + bias->dim() == 1 && (bias->size(0) == m || bias->size(0) == 1), + InvalidArgument, + out, + "Bias has wrong dimensionality! Expected 1-D tensor of size %d or empty," + " but got %d-D tensor with %d elements", + static_cast(m), + static_cast(bias->dim()), + static_cast(bias->numel())); + } + ET_SWITCH_REAL_TYPES_AND2( - Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() { - size_t n = flattened_input_dim; - size_t k = in.sizes()[in.dim() - 1]; - size_t m = mat2.size(0); - - executorch::cpublas::gemm( - executorch::cpublas::TransposeType::Transpose, - executorch::cpublas::TransposeType::NoTranspose, + Half, BFloat16, out.scalar_type(), ctx, "linear.out", CTYPE, [&] { + // Fill output with bias if it is provided. + if (bias.has_value() && bias->numel() == 1) { + // Scalar version of initialization. + initialize_scalar( + out.numel(), + *bias->const_data_ptr(), + out.mutable_data_ptr()); + } else if (bias.has_value()) { + // Assume bias is a 1D tensor of size m. + initialize_to_vector( + n, + m, + bias->const_data_ptr(), + out.mutable_data_ptr()); + } + + // Set beta to 1 if bias was applied so that GEMM adds to the pre-filled + // bias, otherwise beta remains 0 (i.e. the output is fully overwritten + // by GEMM). + const CTYPE beta = + bias.has_value() ? static_cast(1) : static_cast(0); + + gemm( + /*transa=*/TransposeType::Transpose, + /*transb=*/TransposeType::NoTranspose, m, n, k, - static_cast(1), + /*alpha=*/static_cast(1), mat2.const_data_ptr(), k, in.const_data_ptr(), k, - static_cast(0), + beta, out.mutable_data_ptr(), m); }); diff --git a/kernels/test/op_linear_test.cpp b/kernels/test/op_linear_test.cpp index d894c5a818a..0ad5790a550 100644 --- a/kernels/test/op_linear_test.cpp +++ b/kernels/test/op_linear_test.cpp @@ -18,7 +18,8 @@ #include #include -using namespace ::testing; +namespace { + using executorch::aten::ArrayRef; using executorch::aten::Scalar; using executorch::aten::ScalarType; @@ -31,7 +32,15 @@ class OpLinearOutTest : public OperatorTest { return torch::executor::aten::linear_outf(context_, self, mat2, {}, out); } - template + Tensor& op_linear_out( + const Tensor& self, + const Tensor& mat2, + const Tensor& bias, + Tensor& out) { + return torch::executor::aten::linear_outf(context_, self, mat2, bias, out); + } + + template void test_dtype() { TensorFactory tf; @@ -43,16 +52,16 @@ class OpLinearOutTest : public OperatorTest { } } - // matmul gives 32 * 2 * 3 = 192 - Tensor x = tf.full({3, 32}, 2); - Tensor y = tf.full({5, 32}, 3); + // matmul gives 19 * 2 * 3 = 114 + Tensor x = tf.full({3, 19}, 2); + Tensor y = tf.full({5, 19}, 3); // Output shape should be (3, 5) Tensor out = tf.zeros({3, 5}); op_linear_out(x, y, out); - Tensor expected = tf.full({3, 5}, 192); + Tensor expected = tf.full({3, 5}, 114); EXPECT_TENSOR_EQ(out, expected); } @@ -88,6 +97,80 @@ TEST_F(OpLinearOutTest, AllDtypesSupported) { // for those types. } +TEST_F(OpLinearOutTest, BiasTest) { + TensorFactory tf; + + // Initialize input tensors. + constexpr int kReduceDim = 4; + constexpr int kDimX = 3, kDimY = 2; + constexpr int kValueX = 1; + constexpr int kValueY = 2; + constexpr int kValueBias0 = 4, kValueBias1 = 7; + const Tensor x = tf.full({kDimX, kReduceDim}, kValueX); + const Tensor y = tf.full({kDimY, kReduceDim}, kValueY); + const Tensor b = tf.make({kDimY}, {kValueBias0, kValueBias1}); + // Output matrix is also empty + Tensor out = tf.zeros({kDimX, kDimY}); + // Initialize expected tensor. + constexpr int kValueExpected0 = kValueX * kValueY * kReduceDim + kValueBias0; + constexpr int kValueExpected1 = kValueX * kValueY * kReduceDim + kValueBias1; + // Check that the bias is added to the correct position in the output matrix. + const Tensor expected = tf.make( + {kDimX, kDimY}, + {kValueExpected0, + kValueExpected1, + kValueExpected0, + kValueExpected1, + kValueExpected0, + kValueExpected1}); + + EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected); +} + +TEST_F(OpLinearOutTest, BiasBroadcastTest) { + TensorFactory tf; + + // Initialize input tensors. + constexpr int kReduceDim = 4; + constexpr int kDimX = 3, kDimY = 5; + constexpr int kValueX = 1; + constexpr int kValueY = 2; + constexpr int kValueBias = 4; + const Tensor x = tf.full({kDimX, kReduceDim}, kValueX); + const Tensor y = tf.full({kDimY, kReduceDim}, kValueY); + const Tensor b = tf.full({1}, kValueBias); + // Output matrix is also empty + Tensor out = tf.zeros({kDimX, kDimY}); + // Initialize expected tensor. + constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias; + const Tensor expected = tf.full({kDimX, kDimY}, kValueExpected); + + EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected); +} + +TEST_F(OpLinearOutTest, BiasDtypeMismatch) { + TensorFactory tf; + TensorFactory tf_bias; + + // Initialize input tensors. + constexpr int kReduceDim = 4; + constexpr int kDimX = 3, kDimY = 5; + constexpr int kValueX = 1; + constexpr int kValueY = 2; + constexpr int kValueBias = 4; + Tensor x = tf.full({kDimX, kReduceDim}, kValueX); + Tensor y = tf.full({kDimY, kReduceDim}, kValueY); + // Same size as output. + Tensor b = tf_bias.full({kDimY}, kValueBias); + // Output matrix is also empty + Tensor out = tf.zeros({kDimX, kDimY}); + // Initialize expected tensor. + constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias; + Tensor expected = tf.full({kDimX, kDimY}, kValueExpected); + + ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, y, b, out)); +} + TEST_F(OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) { TensorFactory tf; @@ -297,5 +380,4 @@ TEST_F(OpLinearOutTest, DynamicShapeUnbound) { Tensor ret = op_linear_out(x, y, out); EXPECT_TENSOR_CLOSE(out, expected_result); } - -// TODO: support and test bias +} // namespace