Skip to content

Commit db467b4

Browse files
natashasehgalfacebook-github-bot
authored andcommitted
TDigest Aggregate Fuzzer Test (#13301)
Summary: Pull Request resolved: #13301 Differential Revision: D74505772
1 parent a2211ea commit db467b4

File tree

4 files changed

+280
-0
lines changed

4 files changed

+280
-0
lines changed

velox/functions/prestosql/fuzzer/AggregationFuzzerTest.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
#include "velox/functions/prestosql/fuzzer/NoisyCountResultVerifier.h"
4040
#include "velox/functions/prestosql/fuzzer/QDigestAggInputGenerator.h"
4141
#include "velox/functions/prestosql/fuzzer/QDigestAggResultVerifier.h"
42+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateInputGenerator.h"
43+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateResultVerifier.h"
4244
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
4345
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"
4446
#include "velox/vector/fuzzer/VectorFuzzer.h"
@@ -89,6 +91,7 @@ getCustomInputGenerators() {
8991
{"noisy_count_if_gaussian",
9092
std::make_shared<NoisyCountIfInputGenerator>()},
9193
{"noisy_count_gaussian", std::make_shared<NoisyCountInputGenerator>()},
94+
{"tdigest_agg", std::make_shared<TDigestAggregateInputGenerator>()},
9295
};
9396
}
9497

@@ -158,6 +161,7 @@ int main(int argc, char** argv) {
158161
using facebook::velox::exec::test::NoisyCountResultVerifier;
159162
using facebook::velox::exec::test::QDigestAggResultVerifier;
160163
using facebook::velox::exec::test::setupReferenceQueryRunner;
164+
using facebook::velox::exec::test::TDigestAggregateResultVerifier;
161165
using facebook::velox::exec::test::TransformResultVerifier;
162166

163167
auto makeArrayVerifier = []() {
@@ -211,6 +215,7 @@ int main(int argc, char** argv) {
211215
std::make_shared<NoisyCountIfResultVerifier>()},
212216
{"noisy_count_gaussian",
213217
std::make_shared<NoisyCountResultVerifier>()},
218+
{"tdigest_agg", std::make_shared<TDigestAggregateResultVerifier>()},
214219
};
215220

216221
using Runner = facebook::velox::exec::test::AggregationFuzzerRunner;
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <boost/random/uniform_int_distribution.hpp>
19+
#include <boost/random/uniform_real_distribution.hpp>
20+
21+
#include "velox/exec/fuzzer/InputGenerator.h"
22+
#include "velox/vector/fuzzer/VectorFuzzer.h"
23+
24+
namespace facebook::velox::exec::test {
25+
26+
class TDigestAggregateInputGenerator : public InputGenerator {
27+
public:
28+
std::vector<VectorPtr> generate(
29+
const std::vector<TypePtr>& types,
30+
VectorFuzzer& fuzzer,
31+
FuzzerGenerator& rng,
32+
memory::MemoryPool* pool) override {
33+
VELOX_CHECK_GE(types.size(), 1);
34+
VELOX_CHECK_LE(types.size(), 3);
35+
36+
std::vector<VectorPtr> inputs;
37+
inputs.reserve(types.size());
38+
39+
40+
// Values vector
41+
VELOX_CHECK(types[0]->isDouble());
42+
auto valuesVector = fuzzer.fuzz(types[0]);
43+
inputs.push_back(valuesVector);
44+
45+
// Weight is optional
46+
if (types.size() > 1) {
47+
VELOX_CHECK(types[1]->isBigint());
48+
auto weightsVector = fuzzer.fuzz(types[1]);
49+
inputs.push_back(weightsVector);
50+
}
51+
52+
// Compression is optional
53+
if (types.size() > 2) {
54+
VELOX_CHECK(types[2]->isDouble());
55+
const auto size = fuzzer.getOptions().vectorSize;
56+
// Make sure to use the same value of 'compression' for all batches in a
57+
// given Fuzzer iteration.
58+
if (!compression_.has_value()) {
59+
boost::random::uniform_real_distribution<double> dist(10.0, 1000.0);
60+
compression_ = dist(rng);
61+
}
62+
inputs.push_back(BaseVector::createConstant(
63+
DOUBLE(), compression_.value(), size, pool));
64+
}
65+
66+
return inputs;
67+
}
68+
69+
void reset() override {
70+
compression_.reset();
71+
}
72+
73+
private:
74+
std::optional<double> compression_;
75+
};
76+
77+
} // namespace facebook::velox::exec::test
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "velox/core/PlanNode.h"
19+
#include "velox/exec/fuzzer/ResultVerifier.h"
20+
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
21+
#include "velox/exec/tests/utils/PlanBuilder.h"
22+
#include "velox/functions/lib/TDigest.h"
23+
#include "velox/vector/ComplexVector.h"
24+
25+
namespace facebook::velox::exec::test {
26+
27+
class TDigestAggregateResultVerifier : public ResultVerifier {
28+
public:
29+
bool supportsCompare() override {
30+
return true;
31+
}
32+
33+
bool supportsVerify() override {
34+
return false;
35+
}
36+
37+
void initialize(
38+
const std::vector<RowVectorPtr>& /*input*/,
39+
const std::vector<core::ExprPtr>& /*projections*/,
40+
const std::vector<std::string>& groupingKeys,
41+
const core::AggregationNode::Aggregate& aggregate,
42+
const std::string& aggregateName) override {
43+
keys_ = groupingKeys;
44+
resultName_ = aggregateName;
45+
46+
// Check TDigest types
47+
validateTDigestTypes(aggregate.call);
48+
}
49+
50+
void initializeWindow(
51+
const std::vector<RowVectorPtr>& /*input*/,
52+
const std::vector<core::ExprPtr>& /*projections*/,
53+
const std::vector<std::string>& /*partitionByKeys*/,
54+
const std::vector<SortingKeyAndOrder>& /*sortingKeysAndOrders*/,
55+
const core::WindowNode::Function& function,
56+
const std::string& /*frame*/,
57+
const std::string& windowName) override {
58+
keys_ = {"row_number"};
59+
resultName_ = windowName;
60+
61+
// Check TDigest types
62+
validateTDigestTypes(function.functionCall);
63+
}
64+
65+
bool compare(const RowVectorPtr& result, const RowVectorPtr& altResult)
66+
override {
67+
VELOX_CHECK_EQ(result->size(), altResult->size());
68+
69+
auto projection = keys_;
70+
projection.push_back(resultName_);
71+
72+
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
73+
auto builder = PlanBuilder(planNodeIdGenerator).values({result});
74+
if (!keys_.empty()) {
75+
builder = builder.orderBy(keys_, false);
76+
}
77+
auto sortByKeys = builder.project(projection).planNode();
78+
auto sortedResult =
79+
AssertQueryBuilder(sortByKeys).copyResults(result->pool());
80+
81+
builder = PlanBuilder(planNodeIdGenerator).values({altResult});
82+
if (!keys_.empty()) {
83+
builder = builder.orderBy(keys_, false);
84+
}
85+
sortByKeys = builder.project(projection).planNode();
86+
auto sortedAltResult =
87+
AssertQueryBuilder(sortByKeys).copyResults(altResult->pool());
88+
89+
VELOX_CHECK_EQ(sortedResult->size(), sortedAltResult->size());
90+
auto size = sortedResult->size();
91+
for (auto i = 0; i < size; i++) {
92+
auto resultIsNull = sortedResult->childAt(resultName_)->isNullAt(i);
93+
auto altResultIsNull = sortedAltResult->childAt(resultName_)->isNullAt(i);
94+
if (resultIsNull || altResultIsNull) {
95+
VELOX_CHECK(resultIsNull && altResultIsNull);
96+
continue;
97+
}
98+
99+
auto resultValue = sortedResult->childAt(resultName_)
100+
->as<SimpleVector<StringView>>()
101+
->valueAt(i);
102+
auto altResultValue = sortedAltResult->childAt(resultName_)
103+
->as<SimpleVector<StringView>>()
104+
->valueAt(i);
105+
if (resultValue == altResultValue) {
106+
continue;
107+
} else {
108+
checkEquivalentTDigest(resultValue, altResultValue);
109+
}
110+
}
111+
return true;
112+
}
113+
114+
bool verify(const RowVectorPtr& /*result*/) override {
115+
VELOX_UNSUPPORTED();
116+
}
117+
118+
void reset() override {
119+
keys_.clear();
120+
resultName_.clear();
121+
}
122+
123+
private:
124+
// Helper method to check TDigest input and return types
125+
void validateTDigestTypes(const core::CallTypedExprPtr& call) const {
126+
// Check input type is double
127+
auto inputType = call->inputs()[0]->type();
128+
if (inputType->kind() != TypeKind::DOUBLE) {
129+
VELOX_FAIL(
130+
"TDigest only supports DOUBLE input type, got {}",
131+
inputType->toString());
132+
}
133+
auto returnType = call->type();
134+
if (returnType->kind() != TypeKind::VARBINARY) {
135+
VELOX_FAIL(
136+
"TDigest return type must be VARBINARY, got {}",
137+
returnType->toString());
138+
}
139+
}
140+
141+
void checkEquivalentTDigest(
142+
const StringView& result,
143+
const StringView& altResult) {
144+
// Create TDigests from serialized data
145+
facebook::velox::functions::TDigest<> resultTdigest;
146+
facebook::velox::functions::TDigest<> altResultTdigest;
147+
std::vector<int16_t> positions;
148+
149+
try {
150+
resultTdigest.mergeDeserialized(positions, result.data());
151+
resultTdigest.compress(positions);
152+
153+
positions.clear();
154+
altResultTdigest.mergeDeserialized(positions, altResult.data());
155+
altResultTdigest.compress(positions);
156+
} catch (const std::exception& e) {
157+
VELOX_FAIL("Failed to deserialize TDigest: {}", e.what());
158+
}
159+
160+
// Compare TDigest values at specific quantiles
161+
for (auto quantile : kQuantiles) {
162+
double resultQuantile = resultTdigest.estimateQuantile(quantile);
163+
double altResultQuantile = altResultTdigest.estimateQuantile(quantile);
164+
165+
variant resultVariant(resultQuantile);
166+
variant altResultVariant(altResultQuantile);
167+
VELOX_CHECK(
168+
resultVariant.equalsWithEpsilon(altResultVariant),
169+
"TDigest quantile values differ at {}: {} vs {}",
170+
quantile,
171+
resultQuantile,
172+
altResultQuantile);
173+
}
174+
}
175+
176+
static constexpr double kQuantiles[] = {
177+
0.01,
178+
0.05,
179+
0.1,
180+
0.25,
181+
0.50,
182+
0.75,
183+
0.9,
184+
0.95,
185+
0.99,
186+
};
187+
188+
std::vector<std::string> keys_;
189+
std::string resultName_;
190+
};
191+
192+
} // namespace facebook::velox::exec::test

velox/functions/prestosql/fuzzer/WindowFuzzerTest.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include "velox/functions/prestosql/fuzzer/MinMaxInputGenerator.h"
3030
#include "velox/functions/prestosql/fuzzer/QDigestAggInputGenerator.h"
3131
#include "velox/functions/prestosql/fuzzer/QDigestAggResultVerifier.h"
32+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateInputGenerator.h"
33+
#include "velox/functions/prestosql/fuzzer/TDigestAggregateResultVerifier.h"
3234
#include "velox/functions/prestosql/fuzzer/WindowOffsetInputGenerator.h"
3335
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
3436
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"
@@ -76,6 +78,7 @@ getCustomInputGenerators() {
7678
{"approx_set", std::make_shared<ApproxDistinctInputGenerator>()},
7779
{"approx_percentile", std::make_shared<ApproxPercentileInputGenerator>()},
7880
{"qdigest_agg", std::make_shared<QDigestAggInputGenerator>()},
81+
{"tdigest_agg", std::make_shared<TDigestAggregateInputGenerator>()},
7982
{"lead", std::make_shared<WindowOffsetInputGenerator>(1)},
8083
{"lag", std::make_shared<WindowOffsetInputGenerator>(1)},
8184
{"nth_value", std::make_shared<WindowOffsetInputGenerator>(1)},
@@ -144,6 +147,7 @@ int main(int argc, char** argv) {
144147
using facebook::velox::exec::test::ApproxPercentileResultVerifier;
145148
using facebook::velox::exec::test::AverageResultVerifier;
146149
using facebook::velox::exec::test::QDigestAggResultVerifier;
150+
using facebook::velox::exec::test::TDigestAggregateResultVerifier;
147151

148152
static const std::unordered_map<
149153
std::string,
@@ -156,6 +160,7 @@ int main(int argc, char** argv) {
156160
std::make_shared<ApproxPercentileResultVerifier>()},
157161
{"approx_most_frequent", nullptr},
158162
{"qdigest_agg", std::make_shared<QDigestAggResultVerifier>()},
163+
{"tdigest_agg", std::make_shared<TDigestAggregateResultVerifier>()},
159164
{"merge", nullptr},
160165
// Semantically inconsistent functions
161166
{"skewness", nullptr},
@@ -193,6 +198,7 @@ int main(int argc, char** argv) {
193198
"min_by",
194199
"multimap_agg",
195200
"qdigest_agg",
201+
"tdigest_agg",
196202
};
197203

198204
using Runner = facebook::velox::exec::test::WindowFuzzerRunner;

0 commit comments

Comments
 (0)