Skip to content

Commit 10d3abe

Browse files
natashasehgalfacebook-github-bot
authored andcommitted
TDigest Aggregate Fuzzer Test (#13301)
Summary: Pull Request resolved: #13301 Differential Revision: D74505772
1 parent 04d0b47 commit 10d3abe

File tree

4 files changed

+266
-2
lines changed

4 files changed

+266
-2
lines changed

velox/functions/prestosql/fuzzer/AggregationFuzzerTest.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include "velox/functions/prestosql/fuzzer/MinMaxInputGenerator.h"
3636
#include "velox/functions/prestosql/fuzzer/QDigestAggInputGenerator.h"
3737
#include "velox/functions/prestosql/fuzzer/QDigestAggResultVerifier.h"
38+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateInputGenerator.h"
39+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateResultVerifier.h"
3840
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
3941
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"
4042
#include "velox/vector/fuzzer/VectorFuzzer.h"
@@ -82,7 +84,7 @@ getCustomInputGenerators() {
8284
{"approx_percentile", std::make_shared<ApproxPercentileInputGenerator>()},
8385
{"qdigest_agg", std::make_shared<QDigestAggInputGenerator>()},
8486
{"map_union_sum", std::make_shared<MapUnionSumInputGenerator>()},
85-
};
87+
{"tdigest_agg", std::make_shared<TDigestAggregateInputGenerator>()}};
8688
}
8789

8890
} // namespace
@@ -133,7 +135,6 @@ int main(int argc, char** argv) {
133135
"max_data_size_for_stats",
134136
"any_value",
135137
};
136-
137138
static const std::unordered_set<std::string> functionsRequireSortedInput = {
138139
"tdigest_agg",
139140
"qdigest_agg",
@@ -146,6 +147,7 @@ int main(int argc, char** argv) {
146147
using facebook::velox::exec::test::MinMaxByResultVerifier;
147148
using facebook::velox::exec::test::QDigestAggResultVerifier;
148149
using facebook::velox::exec::test::setupReferenceQueryRunner;
150+
using facebook::velox::exec::test::TDigestAggregateResultVerifier;
149151
using facebook::velox::exec::test::TransformResultVerifier;
150152

151153
auto makeArrayVerifier = []() {
@@ -190,6 +192,7 @@ int main(int argc, char** argv) {
190192
"transform_values({}, (k, v) -> \"$internal$canonicalize\"(v))")},
191193
// Semantically inconsistent functions
192194
{"skewness", nullptr},
195+
{"tdigest_agg", std::make_shared<TDigestAggregateResultVerifier>()},
193196
{"kurtosis", nullptr},
194197
{"entropy", nullptr},
195198
// https://github.com/facebookincubator/velox/issues/6330
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <boost/random/uniform_int_distribution.hpp>
19+
#include <boost/random/uniform_real_distribution.hpp>
20+
#include <string>
21+
22+
#include "velox/exec/fuzzer/InputGenerator.h"
23+
#include "velox/vector/FlatVector.h"
24+
#include "velox/vector/fuzzer/VectorFuzzer.h"
25+
26+
namespace facebook::velox::exec::test {
27+
28+
class TDigestAggregateInputGenerator : public InputGenerator {
29+
public:
30+
std::vector<VectorPtr> generate(
31+
const std::vector<TypePtr>& types,
32+
VectorFuzzer& fuzzer,
33+
FuzzerGenerator& rng,
34+
memory::MemoryPool* pool) override {
35+
VELOX_CHECK_GE(types.size(), 1);
36+
VELOX_CHECK_LE(types.size(), 3);
37+
38+
// Type checks
39+
VELOX_CHECK(types[0]->isDouble());
40+
if (types.size() > 1) {
41+
VELOX_CHECK(types[1]->isBigint());
42+
}
43+
if (types.size() > 2) {
44+
VELOX_CHECK(types[2]->isDouble());
45+
}
46+
47+
std::vector<VectorPtr> inputs;
48+
inputs.reserve(types.size());
49+
50+
// Values vector
51+
auto valuesVector = fuzzer.fuzz(types[0]);
52+
inputs.push_back(valuesVector);
53+
54+
// Weight is optional
55+
if (types.size() > 1) {
56+
auto weightsVector = fuzzer.fuzz(types[1]);
57+
inputs.push_back(weightsVector);
58+
}
59+
60+
// Compression is optional
61+
if (types.size() > 2) {
62+
const auto size = fuzzer.getOptions().vectorSize;
63+
// Make sure to use the same value of 'compression' for all batches in a
64+
// given Fuzzer iteration.
65+
if (!compression_.has_value()) {
66+
boost::random::uniform_real_distribution<double> dist(10.0, 1000.0);
67+
compression_ = dist(rng);
68+
}
69+
inputs.push_back(BaseVector::createConstant(
70+
DOUBLE(), compression_.value(), size, pool));
71+
}
72+
73+
return inputs;
74+
}
75+
76+
void reset() override {
77+
compression_.reset();
78+
}
79+
80+
private:
81+
std::optional<double> compression_;
82+
};
83+
84+
} // namespace facebook::velox::exec::test
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "velox/core/PlanNode.h"
19+
#include "velox/exec/fuzzer/ResultVerifier.h"
20+
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
21+
#include "velox/exec/tests/utils/PlanBuilder.h"
22+
#include "velox/functions/lib/TDigest.h"
23+
#include "velox/vector/ComplexVector.h"
24+
25+
namespace facebook::velox::exec::test {
26+
27+
class TDigestAggregateResultVerifier : public ResultVerifier {
28+
public:
29+
bool supportsCompare() override {
30+
return true;
31+
}
32+
33+
bool supportsVerify() override {
34+
return false;
35+
}
36+
37+
void initialize(
38+
const std::vector<RowVectorPtr>& /*input*/,
39+
const std::vector<core::ExprPtr>& /*projections*/,
40+
const std::vector<std::string>& groupingKeys,
41+
const core::AggregationNode::Aggregate& aggregate,
42+
const std::string& aggregateName) override {
43+
keys_ = groupingKeys;
44+
resultName_ = aggregateName;
45+
argumentTypeKind_ = aggregate.call->inputs()[0]->type()->kind();
46+
}
47+
48+
void initializeWindow(
49+
const std::vector<RowVectorPtr>& /*input*/,
50+
const std::vector<core::ExprPtr>& /*projections*/,
51+
const std::vector<std::string>& /*partitionByKeys*/,
52+
const std::vector<SortingKeyAndOrder>& /*sortingKeysAndOrders*/,
53+
const core::WindowNode::Function& function,
54+
const std::string& /*frame*/,
55+
const std::string& windowName) override {
56+
keys_ = {"row_number"};
57+
resultName_ = windowName;
58+
argumentTypeKind_ = function.functionCall->inputs()[0]->type()->kind();
59+
}
60+
61+
bool compare(const RowVectorPtr& result, const RowVectorPtr& altResult)
62+
override {
63+
VELOX_CHECK_EQ(result->size(), altResult->size());
64+
65+
auto projection = keys_;
66+
projection.push_back(resultName_);
67+
68+
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
69+
auto builder = PlanBuilder(planNodeIdGenerator).values({result});
70+
if (!keys_.empty()) {
71+
builder = builder.orderBy(keys_, false);
72+
}
73+
auto sortByKeys = builder.project(projection).planNode();
74+
auto sortedResult =
75+
AssertQueryBuilder(sortByKeys).copyResults(result->pool());
76+
77+
builder = PlanBuilder(planNodeIdGenerator).values({altResult});
78+
if (!keys_.empty()) {
79+
builder = builder.orderBy(keys_, false);
80+
}
81+
sortByKeys = builder.project(projection).planNode();
82+
auto sortedAltResult =
83+
AssertQueryBuilder(sortByKeys).copyResults(altResult->pool());
84+
85+
VELOX_CHECK_EQ(sortedResult->size(), sortedAltResult->size());
86+
auto size = sortedResult->size();
87+
for (auto i = 0; i < size; i++) {
88+
auto resultIsNull = sortedResult->childAt(resultName_)->isNullAt(i);
89+
auto altResultIsNull = sortedAltResult->childAt(resultName_)->isNullAt(i);
90+
if (resultIsNull || altResultIsNull) {
91+
VELOX_CHECK(resultIsNull && altResultIsNull);
92+
continue;
93+
}
94+
95+
auto resultValue = sortedResult->childAt(resultName_)
96+
->as<SimpleVector<StringView>>()
97+
->valueAt(i);
98+
auto altResultValue = sortedAltResult->childAt(resultName_)
99+
->as<SimpleVector<StringView>>()
100+
->valueAt(i);
101+
if (resultValue == altResultValue) {
102+
continue;
103+
} else {
104+
try {
105+
facebook::velox::functions::TDigest<> resultTdigest =
106+
createDigest(resultValue.data());
107+
facebook::velox::functions::TDigest<> altResultTdigest =
108+
createDigest(altResultValue.data());
109+
110+
// Compare TDigest values at specific quantiles
111+
for (auto quantile : kQuantiles) {
112+
double resultQuantile = resultTdigest.estimateQuantile(quantile);
113+
double altResultQuantile =
114+
altResultTdigest.estimateQuantile(quantile);
115+
116+
if (std::abs(resultQuantile - altResultQuantile) > kError) {
117+
return false;
118+
}
119+
}
120+
} catch (const std::exception& e) {
121+
// Consider false if can't deserialize
122+
return false;
123+
}
124+
}
125+
}
126+
return true;
127+
}
128+
129+
bool verify(const RowVectorPtr& /*result*/) override {
130+
VELOX_UNSUPPORTED();
131+
}
132+
133+
void reset() override {
134+
keys_.clear();
135+
resultName_.clear();
136+
}
137+
138+
private:
139+
static constexpr double kError = 0.0001;
140+
141+
static constexpr double kQuantiles[] = {
142+
0.01,
143+
0.05,
144+
0.1,
145+
0.25,
146+
0.50,
147+
0.75,
148+
0.9,
149+
0.95,
150+
0.99,
151+
};
152+
153+
std::vector<std::string> keys_;
154+
std::string resultName_;
155+
TypeKind argumentTypeKind_;
156+
157+
facebook::velox::functions::TDigest<> createDigest(const char* inputData) {
158+
VELOX_CHECK_NOT_NULL(inputData, "TDigest input data cannot be null");
159+
facebook::velox::functions::TDigest<> tdigest;
160+
std::vector<int16_t> positions;
161+
try {
162+
tdigest.mergeDeserialized(positions, inputData);
163+
tdigest.compress(positions);
164+
} catch (const std::exception& e) {
165+
VELOX_FAIL("Failed to deserialize TDigest: {}", e.what());
166+
}
167+
return tdigest;
168+
}
169+
};
170+
171+
} // namespace facebook::velox::exec::test

velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include "velox/functions/prestosql/fuzzer/MinMaxInputGenerator.h"
3030
#include "velox/functions/prestosql/fuzzer/QDigestAggInputGenerator.h"
3131
#include "velox/functions/prestosql/fuzzer/QDigestAggResultVerifier.h"
32+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateInputGenerator.h"
33+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateResultVerifier.h"
3234
#include "velox/functions/prestosql/fuzzer/WindowOffsetInputGenerator.h"
3335
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
3436
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"
@@ -76,6 +78,7 @@ getCustomInputGenerators() {
7678
{"approx_set", std::make_shared<ApproxDistinctInputGenerator>()},
7779
{"approx_percentile", std::make_shared<ApproxPercentileInputGenerator>()},
7880
{"qdigest_agg", std::make_shared<QDigestAggInputGenerator>()},
81+
{"tdigest_agg", std::make_shared<TDigestAggregateInputGenerator>()},
7982
{"lead", std::make_shared<WindowOffsetInputGenerator>(1)},
8083
{"lag", std::make_shared<WindowOffsetInputGenerator>(1)},
8184
{"nth_value", std::make_shared<WindowOffsetInputGenerator>(1)},
@@ -134,6 +137,7 @@ int main(int argc, char** argv) {
134137
using facebook::velox::exec::test::ApproxPercentileResultVerifier;
135138
using facebook::velox::exec::test::AverageResultVerifier;
136139
using facebook::velox::exec::test::QDigestAggResultVerifier;
140+
using facebook::velox::exec::test::TDigestAggregateResultVerifier;
137141

138142
static const std::unordered_map<
139143
std::string,
@@ -146,6 +150,7 @@ int main(int argc, char** argv) {
146150
std::make_shared<ApproxPercentileResultVerifier>()},
147151
{"approx_most_frequent", nullptr},
148152
{"qdigest_agg", std::make_shared<QDigestAggResultVerifier>()},
153+
{"tdigest_agg", std::make_shared<TDigestAggregateResultVerifier>()},
149154
{"merge", nullptr},
150155
// Semantically inconsistent functions
151156
{"skewness", nullptr},
@@ -183,6 +188,7 @@ int main(int argc, char** argv) {
183188
"min_by",
184189
"multimap_agg",
185190
"qdigest_agg",
191+
"tdigest_agg",
186192
};
187193

188194
using Runner = facebook::velox::exec::test::WindowFuzzerRunner;

0 commit comments

Comments
 (0)