Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 28 additions & 16 deletions cpp/src/arrow/compute/kernels/aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -953,14 +953,13 @@ class TestPrimitiveVarStdKernel : public ::testing::Test {
using ScalarType = typename TypeTraits<DoubleType>::ScalarType;

void AssertVarStdIs(const Array& array, const VarianceOptions& options,
double expected_var, double diff = 0) {
AssertVarStdIsInternal(array, options, expected_var, diff);
double expected_var) {
AssertVarStdIsInternal(array, options, expected_var);
}

void AssertVarStdIs(const std::shared_ptr<ChunkedArray>& array,
const VarianceOptions& options, double expected_var,
double diff = 0) {
AssertVarStdIsInternal(array, options, expected_var, diff);
const VarianceOptions& options, double expected_var) {
AssertVarStdIsInternal(array, options, expected_var);
}

void AssertVarStdIs(const std::string& json, const VarianceOptions& options,
Expand Down Expand Up @@ -999,18 +998,14 @@ class TestPrimitiveVarStdKernel : public ::testing::Test {

private:
void AssertVarStdIsInternal(const Datum& array, const VarianceOptions& options,
double expected_var, double diff = 0) {
double expected_var) {
ASSERT_OK_AND_ASSIGN(Datum out_var, Variance(array, options));
ASSERT_OK_AND_ASSIGN(Datum out_std, Stddev(array, options));
auto var = checked_cast<const ScalarType*>(out_var.scalar().get());
auto std = checked_cast<const ScalarType*>(out_std.scalar().get());
ASSERT_TRUE(var->is_valid && std->is_valid);
ASSERT_DOUBLE_EQ(std->value * std->value, var->value);
if (diff == 0) {
ASSERT_DOUBLE_EQ(var->value, expected_var); // < 4ULP
} else {
ASSERT_NEAR(var->value, expected_var, diff);
}
ASSERT_DOUBLE_EQ(var->value, expected_var); // < 4ULP
}

void AssertVarStdIsInvalidInternal(const Datum& array, const VarianceOptions& options) {
Expand Down Expand Up @@ -1070,22 +1065,39 @@ TEST_F(TestVarStdKernelStability, Basics) {
VarianceOptions options{1}; // ddof = 1
this->AssertVarStdIs("[100000004, 100000007, 100000013, 100000016]", options, 30.0);
this->AssertVarStdIs("[1000000004, 1000000007, 1000000013, 1000000016]", options, 30.0);

#ifndef __MINGW32__ // MinGW has precision issues
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was only the 32-bit MinGW build, i.e. it was perhaps not MinGW but x87 (perhaps you can check with a 32-bit Linux build?).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test failed on the MinGW 32-bit community CI, and I see a similar comment in the decimal unit test.
https://github.com/apache/arrow/blob/master/cpp/src/arrow/util/decimal_test.cc#L695

I didn't test it on my side. Maybe I can start a 32-bit VM to check.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I've checked and there is no failure on Linux i386. It does seem MinGW-related.

// This test is to make sure our variance combining method is stable.
// XXX: The reference value from numpy is actually wrong due to floating
// point limits. The correct result should equal variance(90, 0) = 4050.
std::vector<std::string> chunks = {"[40000008000000490]", "[40000008000000400]"};
this->AssertVarStdIs(chunks, options, 3904.0);
#endif
}

// Compensated (Kahan) accumulation of `addend` into `sum`.
// `adjust` holds the rounding error left over from the previous call; it is
// subtracted from the incoming value before adding, then refreshed with the
// error produced by this addition. Both `sum` and `adjust` are updated in place.
// Reference: https://en.wikipedia.org/wiki/Kahan_summation_algorithm
void KahanSum(double& sum, double& adjust, double addend) {
  // Remove the error carried over from earlier additions.
  const double corrected = addend - adjust;
  // Low-order bits of `corrected` may be lost in this addition.
  const double updated = sum + corrected;
  // (updated - sum) is what was actually added; the difference from
  // `corrected` is the newly lost part (negated), kept for the next call.
  adjust = (updated - sum) - corrected;
  sum = updated;
}

// Calculate reference variance with Welford's online algorithm
// Calculate reference variance with Welford's online algorithm + Kahan summation
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
std::pair<double, double> WelfordVar(const Array& array) {
const auto& array_numeric = reinterpret_cast<const DoubleArray&>(array);
const auto values = array_numeric.raw_values();
internal::BitmapReader reader(array.null_bitmap_data(), array.offset(), array.length());
double count = 0, mean = 0, m2 = 0;
double mean_adjust = 0, m2_adjust = 0;
for (int64_t i = 0; i < array.length(); ++i) {
if (reader.IsSet()) {
++count;
double delta = values[i] - mean;
mean += delta / count;
KahanSum(mean, mean_adjust, delta / count);
double delta2 = values[i] - mean;
m2 += delta * delta2;
KahanSum(m2, m2_adjust, delta * delta2);
}
reader.Next();
}
Expand Down Expand Up @@ -1116,8 +1128,8 @@ TEST_F(TestVarStdKernelRandom, Basics) {
double var_population, var_sample;
std::tie(var_population, var_sample) = WelfordVar(*(array->Slice(0, total_size)));

this->AssertVarStdIs(chunked, VarianceOptions{0}, var_population, 0.0001);
this->AssertVarStdIs(chunked, VarianceOptions{1}, var_sample, 0.0001);
this->AssertVarStdIs(chunked, VarianceOptions{0}, var_population);
this->AssertVarStdIs(chunked, VarianceOptions{1}, var_sample);
}

} // namespace compute
Expand Down
21 changes: 11 additions & 10 deletions cpp/src/arrow/compute/kernels/aggregate_var_std.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,32 +53,33 @@ struct VarStdState {
[]() {});

this->count = count;
this->sum = sum;
this->mean = mean;
this->m2 = m2;
}

// Combine `m2` from two chunks
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
// Combine `m2` from two chunks (m2 = n*s2)
// https://www.emathzone.com/tutorials/basic-statistics/combined-variance.html
void MergeFrom(const ThisType& state) {
if (state.count == 0) {
return;
}
if (this->count == 0) {
this->count = state.count;
this->sum = state.sum;
this->mean = state.mean;
this->m2 = state.m2;
return;
}
double delta = this->sum / this->count - state.sum / state.count;
this->m2 += state.m2 +
delta * delta * this->count * state.count / (this->count + state.count);
double mean = (this->mean * this->count + state.mean * state.count) /
(this->count + state.count);
this->m2 += state.m2 + this->count * (this->mean - mean) * (this->mean - mean) +
state.count * (state.mean - mean) * (state.mean - mean);
this->count += state.count;
this->sum += state.sum;
this->mean = mean;
}

int64_t count = 0;
double sum = 0;
double m2 = 0; // sum((X-mean)^2)
double mean = 0;
double m2 = 0; // m2 = count*s2 = sum((X-mean)^2)
};

enum class VarOrStd : bool { Var, Std };
Expand Down