From f4c5c175c8019d064da73efe4ba4f8c022e36edd Mon Sep 17 00:00:00 2001 From: Mahadevuni Naveen Kumar Date: Fri, 13 Feb 2026 16:44:53 +0530 Subject: [PATCH] feat(native): Implement Sketch Theta aggregate and scalar functions --- presto-docs/src/main/sphinx/presto-cpp.rst | 1 + .../src/main/sphinx/presto_cpp/functions.rst | 8 + .../sphinx/presto_cpp/functions/sketch.rst | 38 ++ presto-native-execution/CMakeLists.txt | 2 + .../presto_cpp/main/CMakeLists.txt | 1 + .../presto_cpp/main/PrestoServer.cpp | 4 + .../presto_cpp/main/functions/CMakeLists.txt | 1 + .../functions/theta_sketch/CMakeLists.txt | 17 + .../theta_sketch/ThetaSketchAggregate.cpp | 255 +++++++++++++ .../theta_sketch/ThetaSketchFunctions.cpp | 67 ++++ .../theta_sketch/ThetaSketchRegistration.h | 39 ++ .../theta_sketch/tests/CMakeLists.txt | 29 ++ .../tests/ThetaSketchAggregationTest.cpp | 353 ++++++++++++++++++ .../functions/TestThetaSketchFunctions.java | 117 ++++++ 14 files changed, 932 insertions(+) create mode 100644 presto-docs/src/main/sphinx/presto_cpp/functions.rst create mode 100644 presto-docs/src/main/sphinx/presto_cpp/functions/sketch.rst create mode 100644 presto-native-execution/presto_cpp/main/functions/theta_sketch/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/functions/theta_sketch/ThetaSketchAggregate.cpp create mode 100644 presto-native-execution/presto_cpp/main/functions/theta_sketch/ThetaSketchFunctions.cpp create mode 100644 presto-native-execution/presto_cpp/main/functions/theta_sketch/ThetaSketchRegistration.h create mode 100644 presto-native-execution/presto_cpp/main/functions/theta_sketch/tests/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/functions/theta_sketch/tests/ThetaSketchAggregationTest.cpp create mode 100644 presto-native-execution/src/test/java/com/facebook/presto/nativeworker/functions/TestThetaSketchFunctions.java diff --git a/presto-docs/src/main/sphinx/presto-cpp.rst b/presto-docs/src/main/sphinx/presto-cpp.rst index 305b520ef4f55..f826d6beebafb 100644 --- a/presto-docs/src/main/sphinx/presto-cpp.rst +++ b/presto-docs/src/main/sphinx/presto-cpp.rst @@ -9,6 +9,7 @@ Note: Presto C++ is in active development. See :doc:`Limitations varbinary + + Computes a theta sketch from an input dataset. The output from + this function can be used as an input to any of the other ``sketch_theta_*`` + family of functions. + +.. function:: sketch_theta_estimate(sketch) -> double + + Returns the estimate of distinct values from the input sketch. + +.. function:: sketch_theta_summary(sketch) -> row(estimate double, theta double, upper_bound_std double, lower_bound_std double, retained_entries int) + + Returns a summary of the input sketch which includes the distinct values + estimate alongside other useful information such as the sketch theta + parameter, current error bounds corresponding to 1 standard deviation, and + the number of retained entries in the sketch. + +.. _Apache DataSketches: https://datasketches.apache.org/ +.. _Theta sketch documentation: https://datasketches.apache.org/docs/Theta/ThetaSketches.html#theta-sketch-framework diff --git a/presto-native-execution/CMakeLists.txt b/presto-native-execution/CMakeLists.txt index 97f3973145950..101a2bd422af9 100644 --- a/presto-native-execution/CMakeLists.txt +++ b/presto-native-execution/CMakeLists.txt @@ -235,6 +235,8 @@ if(PRESTO_ENABLE_JWT) add_compile_definitions(PRESTO_ENABLE_JWT) endif() +find_package(DataSketches) + if("${MAX_LINK_JOBS}") set_property(GLOBAL APPEND PROPERTY JOB_POOLS "presto_link_job_pool=${MAX_LINK_JOBS}") else() diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index 9138bb9be597b..dc6935a833c87 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -69,6 +69,7 @@ target_link_libraries( presto_session_properties presto_velox_plan_conversion presto_hive_functions + presto_theta_sketch_functions velox_abfs velox_aggregates velox_caching diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index fce198fec93c9..0f89fbdcd46a2 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -32,6 +32,7 @@ #include "presto_cpp/main/connectors/SystemConnector.h" #include "presto_cpp/main/connectors/hive/functions/HiveFunctionRegistration.h" #include "presto_cpp/main/functions/FunctionMetadata.h" +#include "presto_cpp/main/functions/theta_sketch/ThetaSketchRegistration.h" #include "presto_cpp/main/http/HttpConstants.h" #include "presto_cpp/main/http/filters/AccessLogFilter.h" #include "presto_cpp/main/http/filters/HttpEndpointLatencyFilter.h" @@ -1451,6 +1452,9 @@ void PrestoServer::registerFunctions() { velox::connector::hasConnector("hive-hadoop2")) { hive::functions::registerHiveNativeFunctions(); } + + functions::aggregate::theta_sketch::registerAllThetaSketchFunctions( + prestoBuiltinFunctionPrefix_); } void PrestoServer::registerRemoteFunctions() { diff --git a/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt index 20020ea182e5d..7ac8ebcd7b1aa 100644 --- a/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt @@ -14,6 +14,7 @@ add_library(presto_function_metadata OBJECT FunctionMetadata.cpp) target_link_libraries(presto_function_metadata presto_common velox_function_registry) add_subdirectory(dynamic_registry) +add_subdirectory(theta_sketch) if(PRESTO_ENABLE_REMOTE_FUNCTIONS) add_subdirectory(remote) diff --git a/presto-native-execution/presto_cpp/main/functions/theta_sketch/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/theta_sketch/CMakeLists.txt new file mode 100644 index 0000000000000..757c61764ea82 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/theta_sketch/CMakeLists.txt @@ -0,0 +1,17 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_library(presto_theta_sketch_functions ThetaSketchAggregate.cpp ThetaSketchFunctions.cpp) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/functions/theta_sketch/ThetaSketchAggregate.cpp b/presto-native-execution/presto_cpp/main/functions/theta_sketch/ThetaSketchAggregate.cpp new file mode 100644 index 0000000000000..25aca5a9bac7f --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/theta_sketch/ThetaSketchAggregate.cpp @@ -0,0 +1,255 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/functions/theta_sketch/ThetaSketchRegistration.h" + +#include "DataSketches/theta_sketch.hpp" +#include "DataSketches/theta_union.hpp" + +#include "velox/exec/Aggregate.h" +#include "velox/exec/SimpleAggregateAdapter.h" +#include "velox/functions/prestosql/aggregates/AggregateNames.h" +#include "velox/type/HugeInt.h" + +namespace facebook::presto::functions::aggregate { + +namespace { + +const char* const kThetaSketch = "sketch_theta"; + +template +class ThetaSketchAggregate { + public: + // Type(s) of input vector(s) wrapped in Row. + using InputType = velox::Row; + + // Type of intermediate result + using IntermediateType = velox::Varbinary; + + // Type of output vector. + using OutputType = velox::Varbinary; + + static constexpr bool default_null_behavior_ = false; + + static bool toIntermediate( + velox::exec::out_type& out, + velox::exec::optional_arg_type in) { + if (in.has_value()) { + auto updateSketch = datasketches::update_theta_sketch::builder().build(); + if constexpr (std::is_same_v) { + updateSketch.update(std::to_string(in.value())); + } else if constexpr ( + std::is_same_v || + std::is_same_v) { + const auto& strView = in.value(); + updateSketch.update(std::string(strView.data(), strView.size())); + } else { + updateSketch.update(in.value()); + } + datasketches::theta_union thetaUnion = + datasketches::theta_union::builder().build(); + thetaUnion.update(updateSketch); + auto compactSketch = thetaUnion.get_result(); + out.resize(compactSketch.get_serialized_size_bytes()); + auto serializedBytes = compactSketch.serialize(); + std::memcpy(out.data(), serializedBytes.data(), out.size()); + } + return true; + } + + struct AccumulatorType { + datasketches::theta_union thetaUnion = + datasketches::theta_union::builder().build(); + datasketches::update_theta_sketch updateSketch = + datasketches::update_theta_sketch::builder().build(); + + AccumulatorType() = delete; + + // Constructor used in initializeNewGroups(). + explicit AccumulatorType( + velox::HashStringAllocator* /*allocator*/, + ThetaSketchAggregate* /*fn*/) {} + + void updateUnion() { + thetaUnion.update(updateSketch); + updateSketch.reset(); + } + + // addInput expects one parameter of exec::arg_type for each child-type T + // wrapped in InputType. + bool addInput( + velox::HashStringAllocator* /*allocator*/, + velox::exec::optional_arg_type data) { + if (data.has_value()) { + if constexpr (std::is_same_v) { + updateSketch.update(std::to_string(data.value())); + } else if constexpr ( + std::is_same_v || + std::is_same_v) { + const auto& strView = data.value(); + updateSketch.update(std::string(strView.data(), strView.size())); + } else { + updateSketch.update(data.value()); + } + } + return true; + } + + // combine expects one parameter of exec::arg_type. + bool combine( + velox::HashStringAllocator* /*allocator*/, + velox::exec::optional_arg_type other) { + if (other.has_value()) { + updateUnion(); + auto compactSketch = datasketches::wrapped_compact_theta_sketch::wrap( + other->data(), other->size()); + thetaUnion.update(compactSketch); + } + return true; + } + + bool getResult(velox::exec::out_type& out) { + updateUnion(); + auto compactSketch = thetaUnion.get_result(); + out.resize(compactSketch.get_serialized_size_bytes()); + auto serializedBytes = compactSketch.serialize(); + std::memcpy(out.data(), serializedBytes.data(), out.size()); + return true; + } + + bool writeFinalResult( + bool nonNullGroup, + velox::exec::out_type& out) { + return getResult(out); + } + + bool writeIntermediateResult( + bool nonNullGroup, + velox::exec::out_type& out) { + return getResult(out); + } + }; +}; + +} // namespace + +velox::exec::AggregateRegistrationResult registerThetaSketchAggregate( + const std::string& prefix, + bool withCompanionFunctions, + bool overwrite) { + std::vector> + signatures; + std::string returnType = "varbinary"; + std::string intermediateType = "varbinary"; + + for (const auto& inputType : + {"tinyint", + "smallint", + "integer", + "bigint", + "hugeint", + "real", + "double", + "varchar", + "date", + "timestamp"}) { + signatures.push_back( + velox::exec::AggregateFunctionSignatureBuilder() + .returnType(returnType) + .intermediateType(intermediateType) + .argumentType(inputType) + .build()); + } + signatures.push_back( + velox::exec::AggregateFunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .returnType(returnType) + .intermediateType(intermediateType) + .argumentType("DECIMAL(a_precision, a_scale)") + .build()); + signatures.push_back( + velox::exec::AggregateFunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .returnType(returnType) + .intermediateType(intermediateType) + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("double") + .build()); + + auto name = prefix + kThetaSketch; + + return velox::exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + velox::core::AggregationNode::Step step, + const std::vector& argTypes, + const velox::TypePtr& resultType, + const velox::core::QueryConfig& /*config*/) + -> std::unique_ptr { + VELOX_CHECK_LE( + argTypes.size(), 1, "{} takes at most one argument", name); + auto inputType = argTypes[0]; + if (velox::exec::isRawInput(step)) { + switch (inputType->kind()) { + case velox::TypeKind::TINYINT: + return std::make_unique>>(step, argTypes, resultType); + case velox::TypeKind::SMALLINT: + return std::make_unique>>(step, argTypes, resultType); + case velox::TypeKind::INTEGER: + return std::make_unique>>(step, argTypes, resultType); + case velox::TypeKind::BIGINT: + return std::make_unique>>(step, argTypes, resultType); + case velox::TypeKind::HUGEINT: + return std::make_unique>>( + step, argTypes, resultType); + case velox::TypeKind::REAL: + return std::make_unique>>(step, argTypes, resultType); + case velox::TypeKind::DOUBLE: + return std::make_unique>>(step, argTypes, resultType); + case velox::TypeKind::VARCHAR: + return std::make_unique>>( + step, argTypes, resultType); + case velox::TypeKind::TIMESTAMP: + return std::make_unique::NativeType>>>( + step, argTypes, resultType); + default: + VELOX_FAIL( + "Unknown input type for {} aggregation {}", + name, + inputType->kindName()); + } + } else { + return std::make_unique>>( + step, argTypes, resultType); + } + }, + withCompanionFunctions, + overwrite); +} + +} // namespace facebook::presto::functions::aggregate diff --git a/presto-native-execution/presto_cpp/main/functions/theta_sketch/ThetaSketchFunctions.cpp b/presto-native-execution/presto_cpp/main/functions/theta_sketch/ThetaSketchFunctions.cpp new file mode 100644 index 0000000000000..dab48b74dfff6 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/theta_sketch/ThetaSketchFunctions.cpp @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/functions/theta_sketch/ThetaSketchRegistration.h" + +#include "DataSketches/theta_sketch.hpp" + +#include "velox/velox/functions/Macros.h" +#include "velox/velox/functions/Registerer.h" + +namespace facebook::presto::functions { + +namespace { + +template +struct ThetaSketchEstimateFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type& in) { + auto compactSketch = + datasketches::wrapped_compact_theta_sketch::wrap(in.data(), in.size()); + result = compactSketch.get_estimate(); + } +}; + +template +struct ThetaSketchSummaryFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + FOLLY_ALWAYS_INLINE void call( + out_type>& result, + const arg_type& in) { + auto compactSketch = + datasketches::wrapped_compact_theta_sketch::wrap(in.data(), in.size()); + result.copy_from( + std::make_tuple( + compactSketch.get_estimate(), + compactSketch.get_theta(), + compactSketch.get_upper_bound(1), + compactSketch.get_lower_bound(1), + compactSketch.get_num_retained())); + } +}; +} // namespace + +void registerThetaSketchFunctions(const std::string& prefix) { + velox:: + registerFunction( + {prefix + "sketch_theta_estimate"}); + velox::registerFunction< + ThetaSketchSummaryFunction, + velox::Row, + velox::Varbinary>({prefix + "sketch_theta_summary"}); +} +} // namespace facebook::presto::functions diff --git a/presto-native-execution/presto_cpp/main/functions/theta_sketch/ThetaSketchRegistration.h b/presto-native-execution/presto_cpp/main/functions/theta_sketch/ThetaSketchRegistration.h new file mode 100644 index 0000000000000..be706b5a24b79 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/theta_sketch/ThetaSketchRegistration.h @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/exec/Aggregate.h" + +namespace facebook::presto::functions::aggregate { + +velox::exec::AggregateRegistrationResult registerThetaSketchAggregate( + const std::string& prefix, + bool withCompanionFunctions = true, + bool overwrite = false); +} // namespace facebook::presto::functions::aggregate + +namespace facebook::presto::functions { + +void registerThetaSketchFunctions(const std::string& prefix = ""); +} + +namespace facebook::presto::functions::aggregate::theta_sketch { +namespace { +void registerAllThetaSketchFunctions(const std::string& prefix = "") { + facebook::presto::functions::aggregate::registerThetaSketchAggregate(prefix); + facebook::presto::functions::registerThetaSketchFunctions(prefix); +} +} // namespace +} // namespace facebook::presto::functions::aggregate::theta_sketch diff --git a/presto-native-execution/presto_cpp/main/functions/theta_sketch/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/theta_sketch/tests/CMakeLists.txt new file mode 100644 index 0000000000000..1eb4676b3c96e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/theta_sketch/tests/CMakeLists.txt @@ -0,0 +1,29 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_executable(presto_server_sketch_functions_test ThetaSketchAggregationTest.cpp) + +add_test( + NAME presto_server_sketch_functions_test + COMMAND presto_server_sketch_functions_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries( + presto_server_sketch_functions_test + presto_theta_sketch_functions + velox_exec + velox_exec_test_lib + velox_functions_aggregates_test_lib + GTest::gtest + GTest::gtest_main +) diff --git a/presto-native-execution/presto_cpp/main/functions/theta_sketch/tests/ThetaSketchAggregationTest.cpp b/presto-native-execution/presto_cpp/main/functions/theta_sketch/tests/ThetaSketchAggregationTest.cpp new file mode 100644 index 0000000000000..ae4617d847e11 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/theta_sketch/tests/ThetaSketchAggregationTest.cpp @@ -0,0 +1,353 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "DataSketches/theta_sketch.hpp" +#include "DataSketches/theta_union.hpp" + +#include "presto_cpp/main/functions/theta_sketch/ThetaSketchRegistration.h" +#include "velox/common/hyperloglog/HllUtils.h" +#include "velox/exec/PlanNodeStats.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; +using namespace facebook::velox::functions::aggregate::test; +using namespace datasketches; + +namespace facebook::presto::functions::aggregate::test { +namespace { +class ThetaSketchAggregationTest : public AggregationTestBase { + protected: + static const std::vector kFruits; + static const std::vector kVegetables; + + void SetUp() override { + folly::SingletonVault::singleton()->registrationComplete(); + AggregationTestBase::SetUp(); + presto::functions::aggregate::theta_sketch::registerAllThetaSketchFunctions( + ""); + } + + template + void testGlobalAgg(const VectorPtr& values) { + auto vectors = makeRowVector({values}); + auto expected = makeRowVector({makeFlatVector( + {getExpectedResult(values)}, VARBINARY())}); + + testAggregations({vectors}, {}, {"sketch_theta(c0)"}, {expected}); + } + + template + const std::string getExpectedResult(const VectorPtr& values) { + update_theta_sketch updateSketch = update_theta_sketch::builder().build(); + FlatVector* flatVector = values->asFlatVector(); + for (auto i = 0; i < flatVector->size(); i++) { + if (!flatVector->isNullAt(i)) { + if constexpr ( + std::is_same_v || std::is_same_v) { + const auto& strView = flatVector->valueAt(i); + updateSketch.update(std::string(strView.data(), strView.size())); + } else if constexpr (std::is_same_v) { + updateSketch.update(std::to_string(flatVector->valueAt(i))); + } else { + updateSketch.update(flatVector->valueAt(i)); + } + } + } + theta_union thetaUnion = theta_union::builder().build(); + thetaUnion.update(updateSketch); + + std::stringstream s(std::ios::in | std::ios::out | std::ios::binary); + thetaUnion.get_result().serialize(s); + return s.str(); + } + + template + const RowVectorPtr getExpectedResultForGroupBy( + const VectorPtr& keys, + const VectorPtr& values) { + VELOX_CHECK_EQ(keys->size(), values->size()); + typedef struct thetaUnionStruct { + theta_union thetaUnion = theta_union::builder().build(); + update_theta_sketch updateSketch = update_theta_sketch::builder().build(); + bool hasNull = false; + } theta_unionStruct; + + std::unordered_map groupedTheta; + FlatVector* keysVector = keys->asFlatVector(); + FlatVector* valuesVector = values->asFlatVector(); + + for (auto i = 0; i < keysVector->size(); ++i) { + auto key = keysVector->valueAt(i); + if (!valuesVector->isNullAt(i)) { + auto value = valuesVector->valueAt(i); + if constexpr ( + std::is_same_v || std::is_same_v) { + groupedTheta[key].updateSketch.update( + std::string(value.data(), value.size())); + } else if constexpr (std::is_same_v) { + groupedTheta[key].updateSketch.update(std::to_string(value)); + } else { + groupedTheta[key].updateSketch.update(value); + } + } else { + groupedTheta[keysVector->valueAt(i)].hasNull = true; + } + } + + std::unordered_map results; + + for (auto& iter : groupedTheta) { + groupedTheta[iter.first].thetaUnion.update( + groupedTheta[iter.first].updateSketch); + std::stringstream s(std::ios::in | std::ios::out | std::ios::binary); + groupedTheta[iter.first].thetaUnion.get_result().serialize(s); + results[iter.first] = s.str(); + } + + return toRowVector(results); + } + + template + RowVectorPtr toRowVector(const std::unordered_map& data) { + std::vector keys(data.size()); + transform(data.begin(), data.end(), keys.begin(), [](auto pair) { + return pair.first; + }); + + std::vector values(data.size()); + transform(data.begin(), data.end(), values.begin(), [](auto pair) { + return pair.second; + }); + + return makeRowVector( + {makeFlatVector(keys), makeFlatVector(values, VARBINARY())}); + } + + template + void testGroupByAgg(const VectorPtr& keys, const VectorPtr& values) { + auto vectors = makeRowVector({keys, values}); + auto expectedResults = getExpectedResultForGroupBy(keys, values); + + testAggregations( + {vectors}, {"c0"}, {"sketch_theta(c1)"}, {expectedResults}); + } + + template + void runNumericTest(int64_t minValue) { + vector_size_t size = 50'000; + auto keys = makeFlatVector( + size, [&minValue](auto row) { return (minValue + row) % 2; }); + auto values = makeFlatVector( + size, [&minValue](auto row) { return minValue + row; }); + testGroupByAgg(keys, values); + testGlobalAgg(values); + } +}; + +const std::vector ThetaSketchAggregationTest::kFruits = { + "apple", + "banana", + "cherry", + "dragonfruit", + "grapefruit", + "melon", + "orange", + "pear", + "pineapple", + "unknown fruit with a very long name", + "watermelon"}; + +const std::vector ThetaSketchAggregationTest::kVegetables = { + "cucumber", + "tomato", + "potato", + "squash", + "unknown vegetable with a very long name"}; + +TEST_F(ThetaSketchAggregationTest, NumericTest) { + runNumericTest(0); + runNumericTest(std::numeric_limits::max()); + runNumericTest(std::numeric_limits::max()); + runNumericTest(std::numeric_limits::max()); + runNumericTest(std::numeric_limits::max()); +} + +TEST_F(ThetaSketchAggregationTest, VarcharTest) { + vector_size_t size = 50'000; + + auto keys = makeFlatVector(size, [](auto row) { return row % 2; }); + auto values = makeFlatVector(size, [&](auto row) { + return StringView( + row % 2 == 0 ? kFruits[row % kFruits.size()] + : kVegetables[row % kVegetables.size()]); + }); + + testGroupByAgg(keys, values); + testGlobalAgg(values); +} + +TEST_F(ThetaSketchAggregationTest, FloatingPointTest) { + vector_size_t size = 50'000; + auto keys = makeFlatVector(size, [](auto row) { return row % 2; }); + + { + auto values = makeFlatVector(size, [](auto row) { + return static_cast(rand()) / + (static_cast(RAND_MAX / 50000)); + }); + testGroupByAgg(keys, values); + testGlobalAgg(values); + } + + { + auto values = makeFlatVector(50000, [](auto row) { + return static_cast(rand()) / + (static_cast(RAND_MAX / 50000)); + }); + testGroupByAgg(keys, values); + testGlobalAgg(values); + } +} + +TEST_F(ThetaSketchAggregationTest, TimestampTest) { + vector_size_t size = 50'000; + auto keys = makeFlatVector(size, [](auto row) { return row % 2; }); + auto values = makeFlatVector( + size, [](auto row) { return Timestamp(row, row); }); + testGroupByAgg(keys, values); + testGlobalAgg(values); +} + +TEST_F(ThetaSketchAggregationTest, AllNullsTest) { + vector_size_t size = 5000; + auto keys = makeFlatVector(size, [](auto row) { return row % 2; }); + auto values = + makeFlatVector(size, [](auto row) { return row; }, nullEvery(1)); + testGroupByAgg(keys, values); + testGlobalAgg(values); +} + +TEST_F(ThetaSketchAggregationTest, MixedNullsTest) { + vector_size_t size = 5000; + auto keys = makeFlatVector(size, [](auto row) { return row % 2; }); + auto values = + makeFlatVector(size, [](auto row) { return row; }, nullEvery(2)); + testGroupByAgg(keys, values); + testGlobalAgg(values); +} + +TEST_F(ThetaSketchAggregationTest, streaming) { + auto rawInput1 = makeFlatVector({1, 2, 3}); + auto rawInput2 = makeFlatVector(1000, folly::identity); + auto combinedInput = makeFlatVector({1, 2, 3}); + combinedInput->append(rawInput2->wrappedVector()); + auto result = testStreaming("sketch_theta", true, {rawInput1}, {rawInput2}); + auto expectedResult = getExpectedResult(combinedInput); + ASSERT_EQ(result->size(), 1); + ASSERT_EQ(result->asFlatVector()->valueAt(0), expectedResult); + + result = testStreaming("sketch_theta", false, {rawInput1}, {rawInput2}); + ASSERT_EQ(result->size(), 1); + ASSERT_EQ(result->asFlatVector()->valueAt(0), expectedResult); +} + +TEST_F(ThetaSketchAggregationTest, testSketchThetaEstimate) { + auto sketch_theta_estimate = [this](auto input) { + auto op = PlanBuilder() + .values({makeRowVector({input})}) + .singleAggregation({}, {"sketch_theta(c0)"}) + .project({"sketch_theta_estimate(a0)"}) + .planNode(); + + return readSingleValue(op); + }; + + // Empty sketch + auto input = makeFlatVector({}); + ASSERT_EQ(sketch_theta_estimate(input).value(), 0.0); + + // Single value sketch + input = makeFlatVector(1); + ASSERT_EQ(sketch_theta_estimate(input).value(), 1.0); + + // Many value sketch + input = makeFlatVector(100, [](auto row) { return row; }); + update_theta_sketch updateSketch = update_theta_sketch::builder().build(); + for (auto i = 0; i < input->size(); ++i) { + updateSketch.update(input->valueAt(i)); + } + theta_union thetaUnion = theta_union::builder().build(); + thetaUnion.update(updateSketch); + + ASSERT_EQ( + sketch_theta_estimate(input).value(), + thetaUnion.get_result().get_estimate()); +} + +void assertSummaryMatches( + compact_theta_sketch compactSketch, + variant sketchSummary) { + auto row = + sketchSummary.value>(); + ASSERT_EQ(row.at(0).value(), compactSketch.get_estimate()); + ASSERT_EQ(row.at(1).value(), compactSketch.get_theta()); + ASSERT_EQ( + row.at(2).value(), compactSketch.get_upper_bound(1)); + ASSERT_EQ( + row.at(3).value(), compactSketch.get_lower_bound(1)); + ASSERT_EQ( + row.at(4).value(), compactSketch.get_num_retained()); +} + +TEST_F(ThetaSketchAggregationTest, testSketchThetaSummary) { + auto sketch_theta_summary = [this](auto input) { + auto op = PlanBuilder() + .values({makeRowVector({input})}) + .singleAggregation({}, {"sketch_theta(c0)"}) + .project({"sketch_theta_summary(a0)"}) + .planNode(); + + return readSingleValue(op); + }; + + // Empty sketch + auto input = makeFlatVector({}); + update_theta_sketch updateSketch = update_theta_sketch::builder().build(); + theta_union thetaUnion = theta_union::builder().build(); + thetaUnion.update(updateSketch); + assertSummaryMatches(thetaUnion.get_result(), sketch_theta_summary(input)); + + // Single value sketch + input = makeFlatVector(1); + updateSketch = update_theta_sketch::builder().build(); + updateSketch.update(1); + thetaUnion = theta_union::builder().build(); + thetaUnion.update(updateSketch); + assertSummaryMatches(thetaUnion.get_result(), sketch_theta_summary(input)); + + // Many value sketch + input = makeFlatVector(100, [](auto row) { return row; }); + updateSketch = update_theta_sketch::builder().build(); + for (auto i = 0; i < input->size(); ++i) { + updateSketch.update(input->valueAt(i)); + } + thetaUnion = theta_union::builder().build(); + thetaUnion.update(updateSketch); + assertSummaryMatches(thetaUnion.get_result(), sketch_theta_summary(input)); +} +} // namespace +} // namespace facebook::presto::functions::aggregate::test diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/functions/TestThetaSketchFunctions.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/functions/TestThetaSketchFunctions.java new file mode 100644 index 0000000000000..66a9218d9f338 --- /dev/null +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/functions/TestThetaSketchFunctions.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.nativeworker.functions; + +import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils; +import com.facebook.presto.testing.ExpectedQueryRunner; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tests.DistributedQueryRunner; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class TestThetaSketchFunctions + extends AbstractTestQueryFramework +{ + private String storageFormat; + + @BeforeClass + @Override + public void init() + throws Exception + { + storageFormat = "PARQUET"; + super.init(); + } + + @Override + protected QueryRunner createQueryRunner() throws Exception + { + DistributedQueryRunner queryRunner = (DistributedQueryRunner) PrestoNativeQueryRunnerUtils.nativeHiveQueryRunnerBuilder() + .setStorageFormat(storageFormat) + .build(); + return queryRunner; + } + + @Override + protected ExpectedQueryRunner createExpectedQueryRunner() + throws Exception + { + return PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder() + .setStorageFormat(storageFormat) + .build(); + } + + @Override + protected void createTables() + { + QueryRunner queryRunner = (QueryRunner) getExpectedQueryRunner(); + queryRunner.execute("DROP TABLE IF EXISTS test_sketch_theta_functions"); + queryRunner.execute("CREATE TABLE test_sketch_theta_functions (" + + "nullColumn integer, t tinyint, s smallint, i integer," + + "l bigint, r real, d double, v varchar, " + "dt date" + ", ts timestamp," + + "sd decimal(10,4)" + + ", ld decimal(25,5))"); + queryRunner.execute("INSERT INTO test_sketch_theta_functions VALUES(" + + "null,cast(25 as tinyint),cast(250 as smallint),40000,2147483650," + + "214748.3650,2147483650123283628.72323,'sampletesttext'," + + "date'2025-11-12'" + + ",timestamp'2025-11-12 03:47:58',cast(214748.3650 as DECIMAL(10,4))," + + "cast(2147483650123283628.72323 as DECIMAL(25,5)))"); + } + + @Test + public void testSketchThetaSummary() + { + assertQuery("SELECT sketch_theta_summary(sketch_theta(nullColumn)) FROM test_sketch_theta_functions"); + + assertQuery("SELECT sketch_theta_summary(sketch_theta(t)) FROM test_sketch_theta_functions"); + assertQuery("SELECT sketch_theta_summary(sketch_theta(s)) FROM test_sketch_theta_functions"); + assertQuery("SELECT sketch_theta_summary(sketch_theta(t)) FROM test_sketch_theta_functions"); + assertQuery("SELECT sketch_theta_summary(sketch_theta(l)) FROM test_sketch_theta_functions"); + + assertQuery("SELECT sketch_theta_summary(sketch_theta(r)) FROM test_sketch_theta_functions"); + assertQuery("SELECT sketch_theta_summary(sketch_theta(d)) FROM test_sketch_theta_functions"); + + assertQuery("SELECT sketch_theta_summary(sketch_theta(v)) FROM test_sketch_theta_functions"); + + assertQuery("SELECT sketch_theta_summary(sketch_theta(dt)) FROM test_sketch_theta_functions"); + assertQuery("SELECT sketch_theta_summary(sketch_theta(ts)) FROM test_sketch_theta_functions"); + + assertQuery("SELECT sketch_theta_summary(sketch_theta(sd)) FROM test_sketch_theta_functions"); + assertQuery("SELECT sketch_theta_summary(sketch_theta(ld)) FROM test_sketch_theta_functions"); + } + + @Test + public void testSketchThetaEstimate() + { + assertQuery("SELECT sketch_theta_estimate(sketch_theta(nullColumn)) FROM test_sketch_theta_functions"); + + assertQuery("SELECT sketch_theta_estimate(sketch_theta(t)) FROM test_sketch_theta_functions"); + assertQuery("SELECT sketch_theta_estimate(sketch_theta(s)) FROM test_sketch_theta_functions"); + assertQuery("SELECT sketch_theta_estimate(sketch_theta(t)) FROM test_sketch_theta_functions"); + assertQuery("SELECT sketch_theta_estimate(sketch_theta(l)) FROM test_sketch_theta_functions"); + + assertQuery("SELECT sketch_theta_estimate(sketch_theta(r)) FROM test_sketch_theta_functions"); + assertQuery("SELECT sketch_theta_estimate(sketch_theta(d)) FROM test_sketch_theta_functions"); + + assertQuery("SELECT sketch_theta_estimate(sketch_theta(v)) FROM test_sketch_theta_functions"); + + assertQuery("SELECT sketch_theta_estimate(sketch_theta(dt)) FROM test_sketch_theta_functions"); + assertQuery("SELECT sketch_theta_estimate(sketch_theta(ts)) FROM test_sketch_theta_functions"); + + assertQuery("SELECT sketch_theta_estimate(sketch_theta(sd)) FROM test_sketch_theta_functions"); + assertQuery("SELECT sketch_theta_estimate(sketch_theta(ld)) FROM test_sketch_theta_functions"); + } +}