22 changes: 16 additions & 6 deletions onnxruntime/contrib_ops/cpu/activations.h
@@ -77,17 +77,27 @@
          const T* p_input = input_data + start;
          T* p_output = output_data + start;
          int64_t count = std::min(length_per_task, elem_count - start);
-         for (int64_t i = 0; i < count; i++) {
-           p_output[i] = p_input[i] * alpha_;
-         }

-         MlasComputeLogistic(p_output, p_output, onnxruntime::narrow<size_t>(count));
+         if (alpha_ != 1.0f) {
+           // TODO: Consider vectorizing this scalar multiplication.
[annotation on line 82 — GitHub Actions / Optional Lint C++, [cpplint] reported by reviewdog] Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
+           // It needs exposing a new API in MLAS to take in a scalar
+           // that will be used in the elementwise multiplication.
+           // Estimate the cost-benefit tradeoff before proceeding
+           // with that optimization.
+           for (int64_t i = 0; i < count; i++) {
+             p_output[i] = p_input[i] * alpha_;
+           }

-         for (int64_t i = 0; i < count; i++) {
-           p_output[i] = p_input[i] * p_output[i];
+           MlasComputeLogistic(p_output, p_output, onnxruntime::narrow<size_t>(count));
+         } else {
+           // SILU activation - this needs no `alpha_` scaling as `alpha_` will be 1.0f
+           MlasComputeLogistic(p_input, p_output, onnxruntime::narrow<size_t>(count));
          }

+         MlasEltwiseMul<float>(p_input, p_output, p_output, onnxruntime::narrow<size_t>(count));
        },
        0);

    return Status::OK();
  }

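For context, the hunk above restructures QuickGelu, which computes QuickGelu(x) = x * sigmoid(alpha * x) and reduces to SiLU when alpha == 1. A minimal scalar sketch of that math (illustrative only, not code from this PR; quick_gelu_reference is a made-up name):

#include <cmath>
#include <cstddef>

// Scalar reference for the kernel above: out[i] = in[i] * sigmoid(alpha * in[i]).
void quick_gelu_reference(const float* input, float* output, std::size_t n, float alpha) {
  for (std::size_t i = 0; i < n; ++i) {
    const float scaled = input[i] * alpha;                // the scalar-multiply loop; skipped when alpha == 1
    const float sig = 1.0f / (1.0f + std::exp(-scaled));  // the elementwise logistic (sigmoid) step
    output[i] = input[i] * sig;                           // the elementwise multiply now done by MlasEltwiseMul
  }
}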
10 changes: 10 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1126,6 +1126,16 @@ MlasEltwiseAdd(
    size_t N
    );

+template <typename T>
+void
+MLASCALL
+MlasEltwiseMul(
+    const T* left,
+    const T* right,
+    T* output,
+    size_t N
+    );

template<typename T>
void
MLASCALL
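As a usage illustration only (the wrapper name and buffers are made up, and the include assumes the MLAS include path is configured), a caller invokes the new declaration the same way the contrib op above does:

#include <vector>
#include "mlas.h"

// Elementwise multiply two equally sized float buffers into `out`.
void eltwise_mul_example(const std::vector<float>& a, const std::vector<float>& b, std::vector<float>& out) {
  // The caller guarantees that a, b and out all hold out.size() elements.
  MlasEltwiseMul<float>(a.data(), b.data(), out.data(), out.size());
}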
32 changes: 32 additions & 0 deletions onnxruntime/core/mlas/lib/eltwise.cpp
@@ -53,6 +53,38 @@ MlasEltwiseAdd<float>(
    }
}

+template <>
+void
+MLASCALL
+MlasEltwiseMul<float>(
+    const float* left,
+    const float* right,
+    float* output,
+    size_t N
+) {
+    while (N > 0) {
+        if (N >= 4) {
+            MLAS_FLOAT32X4 LeftVec = MlasLoadFloat32x4(left);
+            MLAS_FLOAT32X4 RightVec = MlasLoadFloat32x4(right);
+
+            MLAS_FLOAT32X4 ResultVec = MlasMultiplyFloat32x4(LeftVec, RightVec);
+
+            MlasStoreFloat32x4(output, ResultVec);
+
+            left += 4;
+            right += 4;
+            output += 4;
+            N -= 4;
+        } else {
+            *output = *left * *right;
+
+            left += 1;
+            right += 1;
+            output += 1;
+            N -= 1;
+        }
+    }
+}

template <>
void
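To make the loop structure in MlasEltwiseMul<float> concrete, here is a worked trace for N = 6 (illustrative commentary, not code from the PR):

// N = 6: the vector branch (N >= 4) multiplies elements 0..3 with
//        MlasMultiplyFloat32x4, advances all three pointers by 4, and N becomes 2.
// N = 2: the vector branch no longer applies, so the scalar branch multiplies
//        element 4, advances the pointers by 1, and N becomes 1.
// N = 1: the scalar branch multiplies element 5, N becomes 0, and the while loop exits.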
52 changes: 52 additions & 0 deletions onnxruntime/test/mlas/unittest/test_eltwise.cpp
@@ -97,10 +97,62 @@ class MlasEltwiseAddTest : public MlasTestBase {
  }
};

+class MlasEltwiseMulTest : public MlasTestBase {
+ private:
+  MatrixGuardBuffer<float> BufferInputLeft;
+  MatrixGuardBuffer<float> BufferInputRight;
+  MatrixGuardBuffer<float> BufferOutput;
+  MatrixGuardBuffer<float> BufferOutputReference;
+
+  void Test(size_t N, float MinimumValue, float MaximumValue, const std::optional<float>& ScalarValue = std::nullopt) {
+    float* InputLeft = BufferInputLeft.GetBuffer(N);
+    float* InputRight = BufferInputRight.GetBuffer(N);
+    float* Output = BufferOutput.GetBuffer(N);
+    float* OutputReference = BufferOutputReference.GetBuffer(N);
+
+    std::default_random_engine generator(static_cast<unsigned>(N));
+    std::uniform_real_distribution<float> distribution(MinimumValue, MaximumValue);
+
+    for (size_t n = 0; n < N; n++) {
+      InputLeft[n] = distribution(generator);
+      InputRight[n] = ScalarValue.value_or(distribution(generator));
+    }
+
+    for (size_t n = 0; n < N; n++) {
+      OutputReference[n] = InputLeft[n] * InputRight[n];
+    }
+
+    MlasEltwiseMul(InputLeft, InputRight, Output, N);
+
+    constexpr float AbsoluteTolerance = 1e-6f;
+    constexpr float RelativeTolerance = 1e-6f;
+
+    for (size_t n = 0; n < N; n++) {
+      float diff = std::fabs(Output[n] - OutputReference[n]);
+      ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(OutputReference[n]) * RelativeTolerance)
+          << " @" << n << " of " << N << ", got: " << Output[n] << ", expecting: " << OutputReference[n];
+    }
+  }
+
+ public:
+  static const char* GetTestSuiteName() {
+    static const std::string suite_name("Eltwise_Mul");
+    return suite_name.c_str();
+  }
+
+  void ExecuteShort(void) override {
+    for (size_t n = 1; n < 128; n++) {
+      Test(n, -10.f, 10.f);
+      Test(n, -10.f, 10.f, -5000.f);
+    }
+  }
+};

static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
  size_t count = 0;
  if (is_short_execute) {
    count += MlasDirectShortExecuteTests<MlasEltwiseAddTest>::RegisterShortExecute();
+    count += MlasDirectShortExecuteTests<MlasEltwiseMulTest>::RegisterShortExecute();
  }
  return count;
});
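The assertion in MlasEltwiseMulTest accepts a result that is close in either absolute or relative terms; the relative branch is what lets the -5000.f scalar case pass, since products there reach magnitudes around 50000 and float rounding alone can exceed the 1e-6 absolute bound. A standalone sketch of that predicate (WithinTolerance is a made-up name, not part of the test):

#include <cmath>

// Mirrors the check in MlasEltwiseMulTest::Test: pass if within the absolute
// tolerance, or within the relative tolerance scaled by the expected magnitude.
bool WithinTolerance(float got, float expected,
                     float abs_tol = 1e-6f, float rel_tol = 1e-6f) {
  const float diff = std::fabs(got - expected);
  return diff <= abs_tol || diff <= std::fabs(expected) * rel_tol;
}

// Example: got = 50000.004f, expected = 50000.0f -> diff is about 0.004, which fails
// the absolute bound (1e-6) but passes the relative bound (50000 * 1e-6 = 0.05).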