Merged
134 changes: 113 additions & 21 deletions kernels/optimized/cpu/op_linear.cpp
@@ -6,30 +6,76 @@
* LICENSE file in the root directory of this source tree.
*/

#include <array>
#include <cstring>

#include <c10/util/irange.h>

#include <executorch/kernels/optimized/blas/CPUBlas.h>
#include <executorch/kernels/optimized/vec/functional_base.h>
#include <executorch/kernels/optimized/vec/vec_base.h>
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
namespace torch {
namespace executor {
namespace native {

namespace {
using ::executorch::aten::Tensor;
using ::executorch::cpublas::gemm;
using ::executorch::cpublas::TransposeType;
using ::executorch::runtime::toString;
using ::executorch::vec::map;
using ::executorch::vec::Vectorized;

// Use vector store to initialize with scalar bias.
template <typename scalar_t>
void initialize_scalar(
const ssize_t out_numel,
const scalar_t init,
scalar_t* out) {
using Vec = Vectorized<scalar_t>;

// Initialize a vector with the scalar initial value.
Vec init_vec(init);

ssize_t d = 0;
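// Example (assuming Vec::size() == 8): for out_numel == 10, the loop below
// issues one full 8-lane store to out[0..7], and the tail store afterwards
// writes the remaining 2 elements to out[8..9].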
for (; d < out_numel - (out_numel % Vec::size()); d += Vec::size()) {
// Vector-length store.
init_vec.store(out + d);
}
if (out_numel - d > 0) {
// Sub-vector-length store.
init_vec.store(out + d, static_cast<int>(out_numel - d));
}
}

// Use std::memcpy to initialize with vector bias.
template <typename scalar_t>
void initialize_to_vector(
const ssize_t n,
const ssize_t m,
const scalar_t* bias,
scalar_t* out) {
// Output is an n x m matrix of scalar_t; bias is a length-m vector of scalar_t.
const size_t row_size = static_cast<size_t>(m) * sizeof(scalar_t);
for (const auto row : c10::irange(n)) {
std::memcpy(
Contributor:
To handle 2-D bias, you need to fix this; the bias pointer is not advancing.

Contributor Author:
Hmm, it's more complicated than that: bias can be 1 x m, n x 1, or n x m.

// Point to row `row` of the output tensor.
out + row * m,
bias,
row_size);
}
}
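
// A hedged sketch (hypothetical, NOT part of this kernel) of how the
// broadcastable 2-D biases raised in the review thread above could be
// materialized: bias may be 1 x m (one row repeated down), n x 1 (one
// value splatted per output row), or n x m (copied verbatim). The helper
// name and the `bias_rows`/`bias_cols` parameters are assumptions for
// illustration only; the kernel below accepts only 1-D bias of size m or 1.
template <typename scalar_t>
void initialize_to_2d_bias(
    const ssize_t n,
    const ssize_t m,
    const ssize_t bias_rows, // 1 or n
    const ssize_t bias_cols, // 1 or m
    const scalar_t* bias,
    scalar_t* out) {
  for (const auto row : c10::irange(n)) {
    // Advance the bias pointer only when bias actually has n rows.
    const scalar_t* bias_row = bias + (bias_rows == 1 ? 0 : row * bias_cols);
    if (bias_cols == 1) {
      // n x 1 (or 1 x 1): splat a single value across the output row.
      initialize_scalar<scalar_t>(m, *bias_row, out + row * m);
    } else {
      // 1 x m or n x m: copy a full row of m values.
      std::memcpy(
          out + row * m, bias_row, static_cast<size_t>(m) * sizeof(scalar_t));
    }
  }
}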

} // namespace

Tensor& opt_linear_out(
RuntimeContext& ctx,
const Tensor& in,
const Tensor& mat2,
const optional<Tensor>& bias,
Tensor& out) {
ET_KERNEL_CHECK(ctx, check_linear_args(in, mat2, out), InvalidArgument, out);

size_t output_ndim = 0;
@@ -46,28 +92,74 @@ Tensor& opt_linear_out(
return out;
}

ssize_t n = 1;
for (int ii = 0; ii < in.dim() - 1; ++ii) {
n *= in.sizes()[ii];
}
const ssize_t k = in.sizes()[in.dim() - 1];
const ssize_t m = mat2.size(0);

if (bias.has_value()) {
ET_KERNEL_CHECK_MSG(
ctx,
// Bias and output dtype must match.
bias->dtype() == out.dtype(),
InvalidArgument,
out,
"Bias has wrong dtype! Expected bias dtype to be the same as out dtype %s"
" but got %s",
toString(bias->dtype()),
toString(out.dtype()));

ET_KERNEL_CHECK_MSG(
ctx,
// Either no bias or bias is a 1D tensor of size m or 1.
bias->dim() == 1 && (bias->size(0) == m || bias->size(0) == 1),
InvalidArgument,
out,
"Bias has wrong dimensionality! Expected 1-D tensor of size %d or empty,"
" but got %d-D tensor with %d elements",
static_cast<int>(m),
static_cast<int>(bias->dim()),
static_cast<int>(bias->numel()));
}

ET_SWITCH_REAL_TYPES_AND2(
Half, BFloat16, out.scalar_type(), ctx, "linear.out", CTYPE, [&] {
// Fill output with bias if it is provided.
if (bias.has_value() && bias->numel() == 1) {
// Scalar version of initialization.
initialize_scalar<CTYPE>(
out.numel(),
*bias->const_data_ptr<CTYPE>(),
out.mutable_data_ptr<CTYPE>());
} else if (bias.has_value()) {
// Assume bias is a 1D tensor of size m.
initialize_to_vector<CTYPE>(
n,
m,
bias->const_data_ptr<CTYPE>(),
out.mutable_data_ptr<CTYPE>());
}

// Set beta to 1 if bias was applied so that GEMM adds to the pre-filled
// bias, otherwise beta remains 0 (i.e. the output is fully overwritten
// by GEMM).
const CTYPE beta =
bias.has_value() ? static_cast<CTYPE>(1) : static_cast<CTYPE>(0);
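// cpublas::gemm uses column-major BLAS conventions: op(A) = mat2 (m x k,
// via Transpose) times op(B) = in (k x n) yields a column-major m x n
// result, which is exactly the row-major n x m output. Net effect:
// out = in * mat2^T + beta * out, so the prefilled bias survives when
// beta == 1 and is overwritten when beta == 0.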

gemm(
/*transa=*/TransposeType::Transpose,
/*transb=*/TransposeType::NoTranspose,
m,
n,
k,
/*alpha=*/static_cast<CTYPE>(1),
mat2.const_data_ptr<CTYPE>(),
k,
in.const_data_ptr<CTYPE>(),
k,
beta,
out.mutable_data_ptr<CTYPE>(),
m);
});
98 changes: 90 additions & 8 deletions kernels/test/op_linear_test.cpp
@@ -18,7 +18,8 @@
#include <gtest/gtest.h>
#include <limits>

namespace {

using executorch::aten::ArrayRef;
using executorch::aten::Scalar;
using executorch::aten::ScalarType;
@@ -31,7 +32,15 @@ class OpLinearOutTest : public OperatorTest {
return torch::executor::aten::linear_outf(context_, self, mat2, {}, out);
}

Tensor& op_linear_out(
const Tensor& self,
const Tensor& mat2,
const Tensor& bias,
Tensor& out) {
return torch::executor::aten::linear_outf(context_, self, mat2, bias, out);
}

template <class CTYPE, ScalarType DTYPE>
void test_dtype() {
TensorFactory<DTYPE> tf;

@@ -43,16 +52,16 @@ class OpLinearOutTest : public OperatorTest {
}
}

// matmul gives 19 * 2 * 3 = 114
Tensor x = tf.full({3, 19}, 2);
Tensor y = tf.full({5, 19}, 3);

// Output shape should be (3, 5)
Tensor out = tf.zeros({3, 5});

op_linear_out(x, y, out);

Tensor expected = tf.full({3, 5}, 114);

EXPECT_TENSOR_EQ(out, expected);
}
@@ -88,6 +97,80 @@ TEST_F(OpLinearOutTest, AllDtypesSupported) {
// for those types.
}

TEST_F(OpLinearOutTest, BiasTest) {
TensorFactory<ScalarType::Int> tf;

// Initialize input tensors.
constexpr int kReduceDim = 4;
constexpr int kDimX = 3, kDimY = 2;
constexpr int kValueX = 1;
constexpr int kValueY = 2;
constexpr int kValueBias0 = 4, kValueBias1 = 7;
const Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
const Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
const Tensor b = tf.make({kDimY}, {kValueBias0, kValueBias1});
// Output tensor, zero-initialized.
Tensor out = tf.zeros({kDimX, kDimY});
// Initialize expected tensor.
constexpr int kValueExpected0 = kValueX * kValueY * kReduceDim + kValueBias0;
constexpr int kValueExpected1 = kValueX * kValueY * kReduceDim + kValueBias1;
// Check that the bias is added to the correct position in the output matrix.
const Tensor expected = tf.make(
{kDimX, kDimY},
{kValueExpected0,
kValueExpected1,
kValueExpected0,
kValueExpected1,
kValueExpected0,
kValueExpected1});

EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected);
}

TEST_F(OpLinearOutTest, BiasBroadcastTest) {
TensorFactory<ScalarType::Int> tf;

// Initialize input tensors.
constexpr int kReduceDim = 4;
constexpr int kDimX = 3, kDimY = 5;
constexpr int kValueX = 1;
constexpr int kValueY = 2;
constexpr int kValueBias = 4;
const Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
const Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
const Tensor b = tf.full({1}, kValueBias);
// Output tensor, zero-initialized.
Tensor out = tf.zeros({kDimX, kDimY});
// Initialize expected tensor.
constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias;
const Tensor expected = tf.full({kDimX, kDimY}, kValueExpected);

EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected);
}

TEST_F(OpLinearOutTest, BiasDtypeMismatch) {
TensorFactory<ScalarType::Int> tf;
TensorFactory<ScalarType::Short> tf_bias;

// Initialize input tensors.
constexpr int kReduceDim = 4;
constexpr int kDimX = 3, kDimY = 5;
constexpr int kValueX = 1;
constexpr int kValueY = 2;
constexpr int kValueBias = 4;
Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
// Bias matches the output's last dimension in size, but not in dtype.
Tensor b = tf_bias.full({kDimY}, kValueBias);
// Output tensor, zero-initialized.
Tensor out = tf.zeros({kDimX, kDimY});
// Initialize expected tensor.
constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias;
Tensor expected = tf.full({kDimX, kDimY}, kValueExpected);

ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, y, b, out));
}

TEST_F(OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) {
TensorFactory<ScalarType::Float> tf;

@@ -297,5 +380,4 @@ TEST_F(OpLinearOutTest, DynamicShapeUnbound) {
Tensor ret = op_linear_out(x, y, out);
EXPECT_TENSOR_CLOSE(out, expected_result);
}

} // namespace