Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions velox/experimental/cudf/exec/CudfHashAggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1530,6 +1530,87 @@ StepAwareAggregationRegistry& getStepAwareAggregationRegistry() {
return registry;
}

// Get aggregation function signatures map from the CUDF registry
exec::AggregateFunctionSignatureMap getCudfAggregationFunctionSignatureMap() {
exec::AggregateFunctionSignatureMap result;
const auto& registry = getStepAwareAggregationRegistry();

for (const auto& [name, stepMap] : registry) {
const auto singleIt = stepMap.find(core::AggregationNode::Step::kSingle);
const auto partialIt = stepMap.find(core::AggregationNode::Step::kPartial);

// We need both single (for return type) and partial (for intermediate
// type) signatures to build AggregateFunctionSignature entries.
if (singleIt == stepMap.end() || partialIt == stepMap.end()) {
continue;
}

const auto& singleSignatures = singleIt->second;
const auto& partialSignatures = partialIt->second;
const auto signatureCount =
std::min(singleSignatures.size(), partialSignatures.size());

if (signatureCount == 0) {
continue;
}

std::vector<exec::AggregateFunctionSignaturePtr> aggregateSignatures;
aggregateSignatures.reserve(signatureCount);

for (size_t i = 0; i < signatureCount; ++i) {
const auto& singleSignature = singleSignatures[i];
const auto& partialSignature = partialSignatures[i];

exec::AggregateFunctionSignatureBuilder builder;

// Preserve declared signature variables.
for (const auto& [_, variable] : singleSignature->variables()) {
if (variable.isTypeParameter()) {
if (variable.knownTypesOnly()) {
builder.knownTypeVariable(variable.name());
} else if (variable.orderableTypesOnly()) {
builder.orderableTypeVariable(variable.name());
} else if (variable.comparableTypesOnly()) {
builder.comparableTypeVariable(variable.name());
} else {
builder.typeVariable(variable.name());
}
} else if (variable.isIntegerParameter()) {
builder.integerVariable(
variable.name(),
variable.constraint().empty()
? std::nullopt
: std::make_optional(variable.constraint()));
}
}

builder.returnType(singleSignature->returnType().toString());
builder.intermediateType(partialSignature->returnType().toString());

const auto& argumentTypes = singleSignature->argumentTypes();
const auto& constantArguments = singleSignature->constantArguments();
for (size_t argIndex = 0; argIndex < argumentTypes.size(); ++argIndex) {
const auto argType = argumentTypes[argIndex].toString();
if (constantArguments[argIndex]) {
builder.constantArgumentType(argType);
} else {
builder.argumentType(argType);
}
}

if (singleSignature->variableArity()) {
builder.variableArity();
}

aggregateSignatures.push_back(builder.build());
}

result[name] = std::move(aggregateSignatures);
}

return result;
}

bool registerAggregationFunctionForStep(
const std::string& name,
core::AggregationNode::Step step,
Expand Down
4 changes: 4 additions & 0 deletions velox/experimental/cudf/exec/CudfHashAggregation.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "velox/experimental/cudf/exec/NvtxHelper.h"
#include "velox/experimental/cudf/vector/CudfVector.h"

#include "velox/exec/Aggregate.h"
#include "velox/exec/Operator.h"
#include "velox/expression/FunctionSignature.h"

Expand Down Expand Up @@ -185,6 +186,9 @@ using StepAwareAggregationRegistry = std::unordered_map<
// Get the step-aware aggregation registry
StepAwareAggregationRegistry& getStepAwareAggregationRegistry();

// Get aggregation function signatures map from the CUDF registry.
exec::AggregateFunctionSignatureMap getCudfAggregationFunctionSignatureMap();

// Register aggregation function signatures for a specific step
bool registerAggregationFunctionForStep(
const std::string& name,
Expand Down
25 changes: 25 additions & 0 deletions velox/experimental/cudf/expression/ExpressionEvaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,31 @@ static bool matchCallAgainstSignatures(

} // namespace

// Get function signatures map from the CUDF registry
std::unordered_map<std::string, std::vector<const exec::FunctionSignature*>>
getCudfFunctionSignatureMap() {
std::unordered_map<std::string, std::vector<const exec::FunctionSignature*>>
result;
const auto& registry = getCudfFunctionRegistry();

for (const auto& [name, spec] : registry) {
// Expose only fully-qualified functions (catalog.schema.function).
if (std::count(name.begin(), name.end(), '.') != 2 || name.front() == '.' ||
name.back() == '.') {
continue;
}
std::vector<const exec::FunctionSignature*> signatures;
for (const auto& sig : spec.signatures) {
signatures.push_back(sig.get());
}
if (!signatures.empty()) {
result[name] = signatures;
}
}

return result;
}

class SplitFunction : public CudfFunction {
public:
SplitFunction(const std::shared_ptr<velox::exec::Expr>& expr) {
Expand Down
5 changes: 5 additions & 0 deletions velox/experimental/cudf/expression/ExpressionEvaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ void registerCudfFunctions(

bool registerBuiltinFunctions(const std::string& prefix);

// Get function signatures map from the CUDF registry
// Returns a map of function names to their function signatures
std::unordered_map<std::string, std::vector<const exec::FunctionSignature*>>
getCudfFunctionSignatureMap();

class CudfExpression {
public:
virtual ~CudfExpression() = default;
Expand Down
Loading