Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for extended exp operation as halide_extended_exp. #8206

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,38 @@ Expr halide_exp(const Expr &x_full) {
return result;
}

// Split e^x_full into a pseudo-mantissa and a power-of-two exponent such that
// e^x_full is approximately mantissa * 2^exponent. Returns {mantissa, exponent};
// the exponent is the floor() result itself and is not converted to an
// integer type.
Tuple halide_extended_exp(const Expr &x_full) {
    // Only inputs whose element type is 32-bit float are accepted.
    const Type input_type = x_full.type();
    internal_assert(input_type.element_of() == Float(32));

    // ln(2) expressed as the sum of a larger and a much smaller constant, so
    // the range reduction below is carried out as two separate subtractions.
    const float ln2_hi = 0.6931457519f;
    const float ln2_lo = 1.4286067653e-6f;
    const float inv_ln2 = 1.0f / logf(2.0f);

    // exponent = floor(x / ln(2))
    Expr exponent = floor(x_full * inv_ln2);

    // reduced = x - exponent * ln(2), subtracting the two parts in turn.
    Expr reduced = (x_full - exponent * ln2_hi) - exponent * ln2_lo;

    // Polynomial coefficients handed to evaluate_polynomial for the reduced
    // argument.
    float poly_coeffs[] = {
        0.00031965933071842413f,
        0.00119156835564003744f,
        0.00848988645943932717f,
        0.04160188091348320655f,
        0.16667983794100929562f,
        0.49999899033463041098f,
        1.0f,
        1.0f};
    const int num_coeffs = sizeof(poly_coeffs) / sizeof(poly_coeffs[0]);
    Expr mantissa = evaluate_polynomial(reduced, poly_coeffs, num_coeffs);

    // When the exponent is not finite, replace the mantissa with 1 so that the
    // mantissa itself is never a NaN or an infinity.
    mantissa = strict_float(select(!is_finite(exponent), 1, mantissa));
    mantissa = common_subexpression_elimination(mantissa);

    return {mantissa, exponent};
}

Expr halide_erf(const Expr &x_full) {
user_assert(x_full.type() == Float(32)) << "halide_erf only works for Float(32)";

Expand Down
28 changes: 28 additions & 0 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,34 @@ Expr halide_exp(const Expr &a);
Expr halide_erf(const Expr &a);
// @}

/** Extended exponential which produces two output values,
* each of the same precision as the input, as described in
* "The Two-Pass Softmax Algorithm" by Marat Dukhan and
* Artsiom Ablavatski [https://arxiv.org/abs/2001.04438].
*
* The first element of the returned Tuple is a pseudo-mantissa while
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pseudo

* the second is an exponent which is an integer. The product of the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So is the returned Tuple a pair of (float32, int32), or is it (float32, float32) where the second is always an integral value?

* pseudo-mantissa and 2 raised to the returned exponent is the
* desired result e^a. For arguments up to slightly greater than
* 11629079, the pseudo-mantissa is guaranteed to be within the
* interval (-e, e). For larger arguments, the exponent result of the
* tuple may not be able to represent the exact integer necessary to
* keep the pseudo-mantissa within bounds. Thus it can become
* progressively larger in magnitude as the argument increases.
*
* Ideally this routine will maintain a degree of accuracy through the
* entire range and be able to produce results out to the end of the
* numeric range. At present neither of these properties is true due to
* the following issues:
* - Range reduction may overflow when scaling the argument.
* - Range reduction is increasingly inaccurate in reducing the value
* due to the implementation. This results in overflow in the polynomial
* evaluation.
* - Even if the above two issues were resolved, the approximation polynomial
* would have to run on values outside its intended approximation range.
*/
Tuple halide_extended_exp(const Expr &a);

/** Raise an expression to an integer power by repeatedly multiplying
* it by itself. */
Expr raise_to_integer_power(Expr a, int64_t b);
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ tests(GROUPS correctness
erf.cpp
exception.cpp
explicit_inline_reductions.cpp
extended_exp.cpp
extern_bounds_inference.cpp
extern_consumer.cpp
extern_error.cpp
Expand Down
140 changes: 140 additions & 0 deletions test/correctness/extended_exp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#include "Halide.h"
#include <cmath>
#include <iomanip>
#include <iostream>
#include <limits>

using namespace Halide;
using Halide::Internal::halide_exp;
using Halide::Internal::halide_extended_exp;

// Check the two pass softmax algorithm from "The Two-Pass Softmax Algorithm"
// by Marat Dukhan and Artsiom Ablavatski [https://arxiv.org/abs/2001.04438],
// implemented here with halide_extended_exp, against a three pass softmax
// that subtracts the maximum input before exponentiating. A naive two pass
// softmax, which overflows easily, is also defined but is not part of the
// comparison.
//
// `scale` multiplies the random inputs. The test prints a message and exits
// with status 1 if the two results differ by more than a relative error of
// .0001, or if the two pass probabilities fall outside the checked bounds.
void two_pass_softmax_test(float scale) {
    Var x("x");
    RDom r(0, 1024);

    Func input("input");
    input(x) = 0.0f;
    input(r) = random_float() * scale;

    // Naive two pass algorithm. Doesn't work for large values or large size inputs.
    // Nothing in `result` below consumes it.
    Func in_exp("in_exp");
    in_exp(x) = halide_exp(input(x));
    Func exp_sum("exp_sum");
    exp_sum() = sum(in_exp(r));

    Func naive_softmax("naive_softmax");
    naive_softmax(x) = in_exp(x) / exp_sum();

    // Three pass algorithm that works for all inputs.
    Func max_input("max_input");
    max_input() = maximum(input(r));
    Func biased_in_exp("biased_in_exp");
    biased_in_exp(x) = halide_exp(input(x) - max_input());
    Func biased_exp_sum("biased_exp_sum");
    biased_exp_sum() = sum(biased_in_exp(r));

    Func three_pass_softmax("three_pass_softmax");
    three_pass_softmax(x) = biased_in_exp(x) / biased_exp_sum();

    // Two pass extended exp algorithm. Element 0 of the Tuple is the
    // pseudo-mantissa and element 1 is the exponent.
    Func in_extended_exp("in_extended_exp");
    in_extended_exp(x) = halide_extended_exp(input(x));

    // Running sum kept as (mantissa, exponent). Each step rescales both the
    // accumulated mantissa and the incoming one to the larger of the two
    // exponents before adding them.
    Func extended_exp_sum("extended_exp_sum");
    extended_exp_sum() = Tuple(0.0f, std::numeric_limits<float>::lowest());  // mantissa, exponent
    Expr max_exp = max(extended_exp_sum()[1], in_extended_exp(r)[1]);
    Expr mantissa_sum = in_extended_exp(r)[0] * pow(2, in_extended_exp(r)[1] - max_exp) +
                        extended_exp_sum()[0] * pow(2, extended_exp_sum()[1] - max_exp);
    extended_exp_sum() = Tuple(mantissa_sum, max_exp);

    Expr lambda = 1 / extended_exp_sum()[0];
    Func two_pass_softmax("two_pass_softmax");
    two_pass_softmax(x) = in_extended_exp(x)[0] * lambda * pow(2, in_extended_exp(x)[1] - extended_exp_sum()[1]);

    // The denominator is clamped below at .000001f so that tiny reference
    // probabilities do not blow up the relative error.
    Func relative_error("relative_error");
    relative_error(x) = abs(three_pass_softmax(x) - two_pass_softmax(x)) / max(.000001f, three_pass_softmax(x));
    Func max_relative_error("max_relative_error");
    max_relative_error() = maximum(relative_error(r));
    Func max_prob("max_prob");
    max_prob() = maximum(two_pass_softmax(r));
    Func min_prob("min_prob");
    min_prob() = minimum(two_pass_softmax(r));
    Func sum_prob("sum_prob");
    sum_prob() = sum(two_pass_softmax(r));

    Func result("result");
    result() = Tuple(max_relative_error(), max_prob(), min_prob(), sum_prob());
    exp_sum.compute_root();
    biased_exp_sum.compute_root();
    extended_exp_sum.compute_root();
    naive_softmax.compute_root();
    three_pass_softmax.compute_root();
    two_pass_softmax.compute_root();

    auto output = result.realize();

    // Read element `index` of the realized Tuple as a zero-dimensional float
    // buffer, using the typed Buffer conversion rather than a reference cast.
    auto scalar_output = [&output](int index) -> float {
        Buffer<float> element = output[index];
        return element();
    };

    const float max_relative_error_result = scalar_output(0);
    const float max_probability = scalar_output(1);
    const float min_probability = scalar_output(2);
    const float sum_probability = scalar_output(3);

    if (max_relative_error_result > .0001f) {
        std::cout << "Failed: Softmax results do not match.\n";
        exit(1);
    }

    if (max_probability > 1.0f) {
        std::cout << "Failed: Softmax probability is greater than 1.0f.\n";
        exit(1);
    }

    if (min_probability < 0.0f) {
        std::cout << "Failed: Softmax probability is negative.\n";
        exit(1);
    }

    if (sum_probability > 1.0001f) {
        std::cout << "Failed: Softmax probability sum is too large.\n";
        exit(1);
    }
}

// Returns true when `actual` agrees with `expected`: either the two compare
// equal (which also covers matching infinities, whose difference is NaN) or
// they differ by at most .00001. Every comparison involving a NaN is false,
// so a NaN value is always reported as a mismatch.
static bool extended_exp_value_matches(float expected, float actual) {
    return expected == actual || std::fabs(expected - actual) <= .00001;
}

// Evaluate halide_extended_exp(x) and compare the resulting pseudo-mantissa
// and exponent with the expected values. A mantissa whose magnitude exceeds e
// is only reported; a mismatch prints a message and exits with status 1.
void expect(float x, float mantissa, float exponent) {
    float computed_mantissa;
    float computed_exponent;
    evaluate(halide_extended_exp(x), &computed_mantissa, &computed_exponent);

    // Informational only: this does not fail the test.
    if (std::fabs(computed_mantissa) > std::exp(1.0f)) {
        std::cout << "Mantissa large for x " << x << " mantissa " << computed_mantissa
                  << " exponent " << computed_exponent << "\n";
    }

    if (!extended_exp_value_matches(mantissa, computed_mantissa) ||
        !extended_exp_value_matches(exponent, computed_exponent)) {
        std::cout << "Failed: halide_extended_exp(" << x << ") == {"
                  << computed_mantissa << ", " << computed_exponent
                  << "} expected {"
                  << mantissa << ", " << exponent << "}\n";
        exit(1);
    }
}

// Test driver: checks a handful of fixed halide_extended_exp values, then runs
// the softmax comparison over a range of input scales.
int main(int argc, char **argv) {
    using float_limits = std::numeric_limits<float>;

    // Floating point values written to std::cout from here on use hexfloat
    // formatting.
    std::cout << std::hexfloat;

    // Spot checks: expect(x, expected mantissa, expected exponent).
    expect(0, 1, 0);
    expect(1, exp(1.0f) / 2, 1);
    expect(88, 1.94149, 126);
    expect(0x1.62e43p+23f, 0x1.085012p+0, 0x1p+24);
    expect(float_limits::lowest(), 1.0f, -float_limits::infinity());
    expect(float_limits::max(), 1.0f, float_limits::infinity());

    // Softmax comparison at each input scale, in this order.
    const float softmax_scales[] = {
        1.0f,
        10000.0f,
        -10000.0f,
        float_limits::max(),
        float_limits::lowest(),
    };
    for (const float scale : softmax_scales) {
        two_pass_softmax_test(scale);
    }

    std::cout << "Success!\n";
    return 0;
}
Loading