Skip to content

Commit f4d10b3

Browse files
natashasehgal authored and
facebook-github-bot committed
TDigest Aggregate Fuzzer Test (#13301)
Summary: Pull Request resolved: #13301 Differential Revision: D74505772
1 parent 9b6a3c1 commit f4d10b3

File tree

4 files changed

+289
-0
lines changed

4 files changed

+289
-0
lines changed

velox/functions/prestosql/fuzzer/AggregationFuzzerTest.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
#include "velox/functions/prestosql/fuzzer/NoisyCountIfResultVerifier.h"
3838
#include "velox/functions/prestosql/fuzzer/QDigestAggInputGenerator.h"
3939
#include "velox/functions/prestosql/fuzzer/QDigestAggResultVerifier.h"
40+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateInputGenerator.h"
41+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateResultVerifier.h"
4042
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
4143
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"
4244
#include "velox/vector/fuzzer/VectorFuzzer.h"
@@ -86,6 +88,7 @@ getCustomInputGenerators() {
8688
{"map_union_sum", std::make_shared<MapUnionSumInputGenerator>()},
8789
{"noisy_count_if_gaussian",
8890
std::make_shared<NoisyCountIfInputGenerator>()},
91+
{"tdigest_agg", std::make_shared<TDigestAggregateInputGenerator>()},
8992
};
9093
}
9194

@@ -154,6 +157,7 @@ int main(int argc, char** argv) {
154157
using facebook::velox::exec::test::NoisyCountIfResultVerifier;
155158
using facebook::velox::exec::test::QDigestAggResultVerifier;
156159
using facebook::velox::exec::test::setupReferenceQueryRunner;
160+
using facebook::velox::exec::test::TDigestAggregateResultVerifier;
157161
using facebook::velox::exec::test::TransformResultVerifier;
158162

159163
auto makeArrayVerifier = []() {
@@ -205,6 +209,7 @@ int main(int argc, char** argv) {
205209
{"sum_data_size_for_stats", nullptr},
206210
{"noisy_count_if_gaussian",
207211
std::make_shared<NoisyCountIfResultVerifier>()},
212+
{"tdigest_agg", std::make_shared<TDigestAggregateResultVerifier>()},
208213
};
209214

210215
using Runner = facebook::velox::exec::test::AggregationFuzzerRunner;
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <boost/random/uniform_int_distribution.hpp>
19+
#include <boost/random/uniform_real_distribution.hpp>
20+
#include <string>
21+
22+
#include "velox/exec/fuzzer/InputGenerator.h"
23+
#include "velox/vector/FlatVector.h"
24+
#include "velox/vector/fuzzer/VectorFuzzer.h"
25+
26+
namespace facebook::velox::exec::test {
27+
28+
class TDigestAggregateInputGenerator : public InputGenerator {
29+
public:
30+
std::vector<VectorPtr> generate(
31+
const std::vector<TypePtr>& types,
32+
VectorFuzzer& fuzzer,
33+
FuzzerGenerator& rng,
34+
memory::MemoryPool* pool) override {
35+
VELOX_CHECK_GE(types.size(), 1);
36+
VELOX_CHECK_LE(types.size(), 3);
37+
38+
// Type checks
39+
VELOX_CHECK(types[0]->isDouble());
40+
if (types.size() > 1) {
41+
VELOX_CHECK(types[1]->isBigint());
42+
}
43+
if (types.size() > 2) {
44+
VELOX_CHECK(types[2]->isDouble());
45+
}
46+
47+
std::vector<VectorPtr> inputs;
48+
inputs.reserve(types.size());
49+
50+
// Values vector
51+
auto valuesVector = fuzzer.fuzz(types[0]);
52+
inputs.push_back(valuesVector);
53+
54+
// Weight is optional
55+
if (types.size() > 1) {
56+
auto weightsVector = fuzzer.fuzz(types[1]);
57+
inputs.push_back(weightsVector);
58+
}
59+
60+
// Compression is optional
61+
if (types.size() > 2) {
62+
const auto size = fuzzer.getOptions().vectorSize;
63+
// Make sure to use the same value of 'compression' for all batches in a
64+
// given Fuzzer iteration.
65+
if (!compression_.has_value()) {
66+
boost::random::uniform_real_distribution<double> dist(10.0, 1000.0);
67+
compression_ = dist(rng);
68+
}
69+
inputs.push_back(BaseVector::createConstant(
70+
DOUBLE(), compression_.value(), size, pool));
71+
}
72+
73+
return inputs;
74+
}
75+
76+
void reset() override {
77+
compression_.reset();
78+
}
79+
80+
private:
81+
std::optional<double> compression_;
82+
};
83+
84+
} // namespace facebook::velox::exec::test
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "velox/core/PlanNode.h"
19+
#include "velox/exec/fuzzer/ResultVerifier.h"
20+
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
21+
#include "velox/exec/tests/utils/PlanBuilder.h"
22+
#include "velox/functions/lib/TDigest.h"
23+
#include "velox/vector/ComplexVector.h"
24+
25+
namespace facebook::velox::exec::test {
26+
27+
class TDigestAggregateResultVerifier : public ResultVerifier {
28+
public:
29+
bool supportsCompare() override {
30+
return true;
31+
}
32+
33+
bool supportsVerify() override {
34+
return false;
35+
}
36+
37+
void initialize(
38+
const std::vector<RowVectorPtr>& /*input*/,
39+
const std::vector<core::ExprPtr>& /*projections*/,
40+
const std::vector<std::string>& groupingKeys,
41+
const core::AggregationNode::Aggregate& aggregate,
42+
const std::string& aggregateName) override {
43+
keys_ = groupingKeys;
44+
resultName_ = aggregateName;
45+
46+
// Check if the input type is double() which is the only valid one for
47+
// TDigest
48+
auto inputType = aggregate.call->inputs()[0]->type();
49+
if (inputType->kind() != TypeKind::DOUBLE) {
50+
VELOX_FAIL(
51+
"TDigest only supports DOUBLE input type, got {}",
52+
inputType->toString());
53+
}
54+
}
55+
56+
void initializeWindow(
57+
const std::vector<RowVectorPtr>& /*input*/,
58+
const std::vector<core::ExprPtr>& /*projections*/,
59+
const std::vector<std::string>& /*partitionByKeys*/,
60+
const std::vector<SortingKeyAndOrder>& /*sortingKeysAndOrders*/,
61+
const core::WindowNode::Function& function,
62+
const std::string& /*frame*/,
63+
const std::string& windowName) override {
64+
keys_ = {"row_number"};
65+
resultName_ = windowName;
66+
67+
// Check input type is double
68+
auto inputType = function.functionCall->inputs()[0]->type();
69+
if (inputType->kind() != TypeKind::DOUBLE) {
70+
VELOX_FAIL(
71+
"TDigest only supports DOUBLE input type, got {}",
72+
inputType->toString());
73+
}
74+
75+
// Check return type is varbinary
76+
auto returnType = function.functionCall->type();
77+
if (returnType->kind() != TypeKind::VARBINARY) {
78+
VELOX_FAIL(
79+
"TDigest return type must be VARBINARY, got {}",
80+
returnType->toString());
81+
}
82+
}
83+
84+
bool compare(const RowVectorPtr& result, const RowVectorPtr& altResult)
85+
override {
86+
VELOX_CHECK_EQ(result->size(), altResult->size());
87+
88+
auto projection = keys_;
89+
projection.push_back(resultName_);
90+
91+
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
92+
auto builder = PlanBuilder(planNodeIdGenerator).values({result});
93+
if (!keys_.empty()) {
94+
builder = builder.orderBy(keys_, false);
95+
}
96+
auto sortByKeys = builder.project(projection).planNode();
97+
auto sortedResult =
98+
AssertQueryBuilder(sortByKeys).copyResults(result->pool());
99+
100+
builder = PlanBuilder(planNodeIdGenerator).values({altResult});
101+
if (!keys_.empty()) {
102+
builder = builder.orderBy(keys_, false);
103+
}
104+
sortByKeys = builder.project(projection).planNode();
105+
auto sortedAltResult =
106+
AssertQueryBuilder(sortByKeys).copyResults(altResult->pool());
107+
108+
VELOX_CHECK_EQ(sortedResult->size(), sortedAltResult->size());
109+
auto size = sortedResult->size();
110+
for (auto i = 0; i < size; i++) {
111+
auto resultIsNull = sortedResult->childAt(resultName_)->isNullAt(i);
112+
auto altResultIsNull = sortedAltResult->childAt(resultName_)->isNullAt(i);
113+
if (resultIsNull || altResultIsNull) {
114+
VELOX_CHECK(resultIsNull && altResultIsNull);
115+
continue;
116+
}
117+
118+
auto resultValue = sortedResult->childAt(resultName_)
119+
->as<SimpleVector<StringView>>()
120+
->valueAt(i);
121+
auto altResultValue = sortedAltResult->childAt(resultName_)
122+
->as<SimpleVector<StringView>>()
123+
->valueAt(i);
124+
if (resultValue == altResultValue) {
125+
continue;
126+
} else {
127+
checkEquivalentTDigest(resultValue, altResultValue);
128+
}
129+
}
130+
return true;
131+
}
132+
133+
bool verify(const RowVectorPtr& /*result*/) override {
134+
VELOX_UNSUPPORTED();
135+
}
136+
137+
void reset() override {
138+
keys_.clear();
139+
resultName_.clear();
140+
}
141+
142+
private:
143+
void checkEquivalentTDigest(
144+
const StringView& result,
145+
const StringView& altResult) {
146+
// Create TDigests from serialized data
147+
facebook::velox::functions::TDigest<> resultTdigest;
148+
facebook::velox::functions::TDigest<> altResultTdigest;
149+
std::vector<int16_t> positions;
150+
151+
try {
152+
resultTdigest.mergeDeserialized(positions, result.data());
153+
resultTdigest.compress(positions);
154+
155+
positions.clear();
156+
altResultTdigest.mergeDeserialized(positions, altResult.data());
157+
altResultTdigest.compress(positions);
158+
} catch (const std::exception& e) {
159+
VELOX_FAIL("Failed to deserialize TDigest: {}", e.what());
160+
}
161+
162+
// Compare TDigest values at specific quantiles
163+
for (auto quantile : kQuantiles) {
164+
double resultQuantile = resultTdigest.estimateQuantile(quantile);
165+
double altResultQuantile = altResultTdigest.estimateQuantile(quantile);
166+
167+
VELOX_CHECK(
168+
std::abs(resultQuantile - altResultQuantile) <= kError,
169+
"TDigest quantile values differ at {}: {} vs {}",
170+
quantile,
171+
resultQuantile,
172+
altResultQuantile);
173+
}
174+
}
175+
176+
static constexpr double kError = 0.0001;
177+
178+
static constexpr double kQuantiles[] = {
179+
0.01,
180+
0.05,
181+
0.1,
182+
0.25,
183+
0.50,
184+
0.75,
185+
0.9,
186+
0.95,
187+
0.99,
188+
};
189+
190+
std::vector<std::string> keys_;
191+
std::string resultName_;
192+
};
193+
194+
} // namespace facebook::velox::exec::test

velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include "velox/functions/prestosql/fuzzer/MinMaxInputGenerator.h"
3030
#include "velox/functions/prestosql/fuzzer/QDigestAggInputGenerator.h"
3131
#include "velox/functions/prestosql/fuzzer/QDigestAggResultVerifier.h"
32+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateInputGenerator.h"
33+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateResultVerifier.h"
3234
#include "velox/functions/prestosql/fuzzer/WindowOffsetInputGenerator.h"
3335
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
3436
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"
@@ -76,6 +78,7 @@ getCustomInputGenerators() {
7678
{"approx_set", std::make_shared<ApproxDistinctInputGenerator>()},
7779
{"approx_percentile", std::make_shared<ApproxPercentileInputGenerator>()},
7880
{"qdigest_agg", std::make_shared<QDigestAggInputGenerator>()},
81+
{"tdigest_agg", std::make_shared<TDigestAggregateInputGenerator>()},
7982
{"lead", std::make_shared<WindowOffsetInputGenerator>(1)},
8083
{"lag", std::make_shared<WindowOffsetInputGenerator>(1)},
8184
{"nth_value", std::make_shared<WindowOffsetInputGenerator>(1)},
@@ -143,6 +146,7 @@ int main(int argc, char** argv) {
143146
using facebook::velox::exec::test::ApproxPercentileResultVerifier;
144147
using facebook::velox::exec::test::AverageResultVerifier;
145148
using facebook::velox::exec::test::QDigestAggResultVerifier;
149+
using facebook::velox::exec::test::TDigestAggregateResultVerifier;
146150

147151
static const std::unordered_map<
148152
std::string,
@@ -155,6 +159,7 @@ int main(int argc, char** argv) {
155159
std::make_shared<ApproxPercentileResultVerifier>()},
156160
{"approx_most_frequent", nullptr},
157161
{"qdigest_agg", std::make_shared<QDigestAggResultVerifier>()},
162+
{"tdigest_agg", std::make_shared<TDigestAggregateResultVerifier>()},
158163
{"merge", nullptr},
159164
// Semantically inconsistent functions
160165
{"skewness", nullptr},
@@ -192,6 +197,7 @@ int main(int argc, char** argv) {
192197
"min_by",
193198
"multimap_agg",
194199
"qdigest_agg",
200+
"tdigest_agg",
195201
};
196202

197203
using Runner = facebook::velox::exec::test::WindowFuzzerRunner;

0 commit comments

Comments
 (0)