From 5358e76155267f8eeeb510ff34321d4ce8fd2e26 Mon Sep 17 00:00:00 2001 From: rongrong Date: Tue, 19 Feb 2019 17:33:57 -0800 Subject: [PATCH 1/8] Remove unused variable in LocalExecutionPlanner --- .../com/facebook/presto/sql/planner/LocalExecutionPlanner.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index d7dc1913b33f9..b1002d24f3286 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -191,7 +191,6 @@ import com.google.common.collect.Multimap; import com.google.common.collect.SetMultimap; import com.google.common.primitives.Ints; -import io.airlift.log.Logger; import io.airlift.units.DataSize; import javax.inject.Inject; @@ -285,8 +284,6 @@ public class LocalExecutionPlanner { - private static final Logger log = Logger.get(LocalExecutionPlanner.class); - private final Metadata metadata; private final SqlParser sqlParser; private final Optional explainAnalyzeContext; From 05bae3b2a9ddec920286be76544d2c0d01f1ba44 Mon Sep 17 00:00:00 2001 From: rongrong Date: Fri, 15 Feb 2019 17:21:41 -0800 Subject: [PATCH 2/8] Add FunctionHandle Add FunctionHandle abstraction. Eventually FunctionHandle should function similarly to TableHandle, as an abstraction, with implementation details decided by FunctionNamespace. Right now this is just a wrapper on Signature to enable further refactor. --- .../presto/metadata/FunctionHandle.java | 65 +++++++++++++++++++ .../presto/metadata/FunctionManager.java | 4 +- .../presto/metadata/FunctionNamespace.java | 4 +- .../presto/metadata/FunctionRegistry.java | 14 ++-- .../presto/metadata/TestFunctionRegistry.java | 18 ++--- 5 files changed, 85 insertions(+), 20 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/metadata/FunctionHandle.java diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionHandle.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionHandle.java new file mode 100644 index 0000000000000..937dae0a2d1be --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionHandle.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class FunctionHandle +{ + private final Signature signature; + + @JsonCreator + public FunctionHandle(@JsonProperty("signature") Signature signature) + { + this.signature = requireNonNull(signature, "signature is null"); + checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature); + } + + @JsonProperty + public Signature getSignature() + { + return signature; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FunctionHandle that = (FunctionHandle) o; + return Objects.equals(signature, that.signature); + } + + @Override + public int hashCode() + { + return Objects.hash(signature); + } + + @Override + public String toString() + { + return signature.toString(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java index abc7a560e9a81..3f64d29d5dd5e 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java @@ -63,7 +63,7 @@ public List listFunctions() public Signature resolveFunction(QualifiedName name, List parameterTypes) { - return globalFunctionNamespace.resolveFunction(name, parameterTypes); + return globalFunctionNamespace.resolveFunction(name, parameterTypes).getSignature(); } public WindowFunctionSupplier getWindowFunctionImplementation(Signature signature) @@ -93,7 +93,7 @@ public boolean canResolveOperator(OperatorType operatorType, Type returnType, Li public Signature resolveOperator(OperatorType operatorType, List argumentTypes) { - return globalFunctionNamespace.resolveOperator(operatorType, argumentTypes); + return globalFunctionNamespace.resolveOperator(operatorType, argumentTypes).getSignature(); } public boolean isRegistered(Signature signature) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java index 302ef6b437799..7dd2ea1cf0ace 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java @@ -48,7 +48,7 @@ public List listFunctions() return registry.list(); } - public Signature resolveFunction(QualifiedName name, List parameterTypes) + public FunctionHandle resolveFunction(QualifiedName name, List parameterTypes) { return registry.resolveFunction(name, parameterTypes); } @@ -78,7 +78,7 @@ public boolean canResolveOperator(OperatorType operatorType, Type returnType, Li return registry.canResolveOperator(operatorType, returnType, argumentTypes); } - public Signature resolveOperator(OperatorType operatorType, List argumentTypes) + public FunctionHandle resolveOperator(OperatorType operatorType, List argumentTypes) throws OperatorNotFoundException { return registry.resolveOperator(operatorType, argumentTypes); diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java index fe02f6b243d4b..8f7882901fcad 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java @@ -682,7 +682,7 @@ public boolean isAggregationFunction(QualifiedName name) return Iterables.any(functions.get(name), function -> function.getSignature().getKind() == AGGREGATE); } - public Signature resolveFunction(QualifiedName name, List parameterTypes) + public FunctionHandle resolveFunction(QualifiedName name, List parameterTypes) { Collection allCandidates = functions.get(name); List exactCandidates = allCandidates.stream() @@ -691,7 +691,7 @@ public Signature resolveFunction(QualifiedName name, List Optional match = matchFunctionExact(exactCandidates, parameterTypes); if (match.isPresent()) { - return match.get(); + return new FunctionHandle(match.get()); } List genericCandidates = allCandidates.stream() @@ -700,12 +700,12 @@ public Signature resolveFunction(QualifiedName name, List match = matchFunctionExact(genericCandidates, parameterTypes); if (match.isPresent()) { - return match.get(); + return new FunctionHandle(match.get()); } match = matchFunctionWithCoercion(allCandidates, parameterTypes); if (match.isPresent()) { - return match.get(); + return new FunctionHandle(match.get()); } List expectedParameters = new ArrayList<>(); @@ -732,7 +732,7 @@ public Signature resolveFunction(QualifiedName name, List // verify we have one parameter of the proper type checkArgument(parameterTypes.size() == 1, "Expected one argument to literal function, but got %s", parameterTypes); - return getMagicLiteralFunctionSignature(type); + return new FunctionHandle(getMagicLiteralFunctionSignature(type)); } throw new PrestoException(FUNCTION_NOT_FOUND, message); @@ -1072,11 +1072,11 @@ public boolean isRegistered(Signature signature) } } - public Signature resolveOperator(OperatorType operatorType, List argumentTypes) + public FunctionHandle resolveOperator(OperatorType operatorType, List argumentTypes) throws OperatorNotFoundException { try { - return resolveFunction(QualifiedName.of(mangleOperatorName(operatorType)), fromTypes(argumentTypes)); + return new FunctionHandle(resolveFunction(QualifiedName.of(mangleOperatorName(operatorType)), fromTypes(argumentTypes)).getSignature()); } catch (PrestoException e) { if (e.getErrorCode().getCode() == FUNCTION_NOT_FOUND.toErrorCode().getCode()) { diff --git a/presto-main/src/test/java/com/facebook/presto/metadata/TestFunctionRegistry.java b/presto-main/src/test/java/com/facebook/presto/metadata/TestFunctionRegistry.java index 3afcbe8303ae0..d6e3107ef362c 100644 --- a/presto-main/src/test/java/com/facebook/presto/metadata/TestFunctionRegistry.java +++ b/presto-main/src/test/java/com/facebook/presto/metadata/TestFunctionRegistry.java @@ -86,8 +86,8 @@ public void testExactMatchBeforeCoercion() if (function.getSignature().getArgumentTypes().stream().anyMatch(TypeSignature::isCalculated)) { continue; } - Signature exactOperator = registry.resolveOperator(operatorType, resolveTypes(function.getSignature().getArgumentTypes(), typeManager)); - assertEquals(exactOperator, function.getSignature()); + FunctionHandle exactOperator = registry.resolveOperator(operatorType, resolveTypes(function.getSignature().getArgumentTypes(), typeManager)); + assertEquals(exactOperator.getSignature(), function.getSignature()); foundOperator = true; } assertTrue(foundOperator); @@ -103,8 +103,8 @@ public void testMagicLiteralFunction() TypeRegistry typeManager = new TypeRegistry(); FunctionRegistry registry = createFunctionRegistry(typeManager); - Signature function = registry.resolveFunction(QualifiedName.of(signature.getName()), fromTypeSignatures(signature.getArgumentTypes())); - assertEquals(function.getArgumentTypes(), ImmutableList.of(parseTypeSignature(StandardTypes.BIGINT))); + FunctionHandle functionHandle = registry.resolveFunction(QualifiedName.of(signature.getName()), fromTypeSignatures(signature.getArgumentTypes())); + assertEquals(functionHandle.getSignature().getArgumentTypes(), ImmutableList.of(parseTypeSignature(StandardTypes.BIGINT))); assertEquals(signature.getReturnType().getBase(), StandardTypes.TIMESTAMP_WITH_TIME_ZONE); } @@ -353,16 +353,16 @@ public ResolveFunctionAssertion forParameters(String... parameters) public ResolveFunctionAssertion returns(SignatureBuilder functionSignature) { - Signature expectedSignature = functionSignature.name(TEST_FUNCTION_NAME).build(); - Signature actualSignature = resolveSignature(); - assertEquals(actualSignature, expectedSignature); + FunctionHandle expectedFunction = new FunctionHandle(functionSignature.name(TEST_FUNCTION_NAME).build()); + FunctionHandle actualFunction = resolveFunctionHandle(); + assertEquals(expectedFunction, actualFunction); return this; } public ResolveFunctionAssertion failsWithMessage(String... messages) { try { - resolveSignature(); + resolveFunctionHandle(); fail("didn't fail as expected"); } catch (RuntimeException e) { @@ -376,7 +376,7 @@ public ResolveFunctionAssertion failsWithMessage(String... messages) return this; } - private Signature resolveSignature() + private FunctionHandle resolveFunctionHandle() { FeaturesConfig featuresConfig = new FeaturesConfig(); FunctionManager functionManager = new FunctionManager(typeRegistry, blockEncoding, featuresConfig); From 7dc031e85145e9e54a1aeeb2f1b67474e2cb8977 Mon Sep 17 00:00:00 2001 From: rongrong Date: Tue, 19 Feb 2019 16:54:04 -0800 Subject: [PATCH 3/8] Resolve functions to a handle during analysis --- .../presto/metadata/FunctionManager.java | 11 ++++++++++ .../presto/sql/analyzer/Analysis.java | 12 +++++----- .../sql/analyzer/ExpressionAnalyzer.java | 22 ++++++++++--------- .../sql/analyzer/WindowFunctionValidator.java | 2 +- .../presto/sql/planner/QueryPlanner.java | 4 ++-- 5 files changed, 32 insertions(+), 19 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java index 3f64d29d5dd5e..bce227a68ff36 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.Session; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.operator.window.WindowFunctionSupplier; @@ -61,6 +62,16 @@ public List listFunctions() return globalFunctionNamespace.listFunctions(); } + public FunctionHandle resolveFunction(Session session, QualifiedName name, List parameterTypes) + { + // TODO Actually use session + // Session will be used to provide information about the order of function namespaces to through resolving the function. + // This is likely to be in terms of SQL path. Currently we still don't have support multiple function namespaces, nor + // SQL path. As a result, session is not used here. We still add this to distinguish the two versions of resolveFunction + // while the refactoring is on-going. + return globalFunctionNamespace.resolveFunction(name, parameterTypes); + } + public Signature resolveFunction(QualifiedName name, List parameterTypes) { return globalFunctionNamespace.resolveFunction(name, parameterTypes).getSignature(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java index 5ca74f928b643..160707a7b97b6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.sql.analyzer; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.QualifiedObjectName; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.security.AccessControl; import com.facebook.presto.spi.ColumnHandle; @@ -112,7 +112,7 @@ public class Analysis private final Map, Type> coercions = new LinkedHashMap<>(); private final Set> typeOnlyCoercions = new LinkedHashSet<>(); private final Map, List> relationCoercions = new LinkedHashMap<>(); - private final Map, Signature> functionSignature = new LinkedHashMap<>(); + private final Map, FunctionHandle> functionHandles = new LinkedHashMap<>(); private final Map, LambdaArgumentDeclaration> lambdaArgumentReferences = new LinkedHashMap<>(); private final Map columns = new LinkedHashMap<>(); @@ -452,14 +452,14 @@ public void registerTable(Table table, TableHandle handle) tables.put(NodeRef.of(table), handle); } - public Signature getFunctionSignature(FunctionCall function) + public FunctionHandle getFunctionHandle(FunctionCall function) { - return functionSignature.get(NodeRef.of(function)); + return functionHandles.get(NodeRef.of(function)); } - public void addFunctionSignatures(Map, Signature> infos) + public void addFunctionHandles(Map, FunctionHandle> infos) { - functionSignature.putAll(infos); + functionHandles.putAll(infos); } public Set> getColumnReferences() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index fded35beb139b..0103c85dd7e83 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.execution.warnings.WarningCollector; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.OperatorNotFoundException; @@ -176,7 +177,7 @@ public class ExpressionAnalyzer private final boolean isDescribe; private final boolean legacyRowFieldOrdinalAccess; - private final Map, Signature> resolvedFunctions = new LinkedHashMap<>(); + private final Map, FunctionHandle> resolvedFunctions = new LinkedHashMap<>(); private final Set> scalarSubqueries = new LinkedHashSet<>(); private final Set> existsSubqueries = new LinkedHashSet<>(); private final Map, Type> expressionCoercions = new LinkedHashMap<>(); @@ -215,7 +216,7 @@ public ExpressionAnalyzer( this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); } - public Map, Signature> getResolvedFunctions() + public Map, FunctionHandle> getResolvedFunctions() { return unmodifiableMap(resolvedFunctions); } @@ -865,7 +866,8 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext argumentTypes = argumentTypesBuilder.build(); - Signature function = resolveFunction(node, argumentTypes, functionManager); + FunctionHandle function = resolveFunction(session, node, argumentTypes, functionManager); + Signature functionSignature = function.getSignature(); if (node.getOrderBy().isPresent()) { for (SortItem sortItem : node.getOrderBy().get().getSortItems()) { @@ -878,8 +880,8 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext getFunctionInputTypes() } } - public static Signature resolveFunction(FunctionCall node, List argumentTypes, FunctionManager functionManager) + public static FunctionHandle resolveFunction(Session session, FunctionCall node, List argumentTypes, FunctionManager functionManager) { try { - return functionManager.resolveFunction(node.getName(), argumentTypes); + return functionManager.resolveFunction(session, node.getName(), argumentTypes); } catch (PrestoException e) { if (e.getErrorCode().getCode() == StandardErrorCode.FUNCTION_NOT_FOUND.toErrorCode().getCode()) { @@ -1588,11 +1590,11 @@ public static ExpressionAnalysis analyzeExpression( Map, Type> expressionTypes = analyzer.getExpressionTypes(); Map, Type> expressionCoercions = analyzer.getExpressionCoercions(); Set> typeOnlyCoercions = analyzer.getTypeOnlyCoercions(); - Map, Signature> resolvedFunctions = analyzer.getResolvedFunctions(); + Map, FunctionHandle> resolvedFunctions = analyzer.getResolvedFunctions(); analysis.addTypes(expressionTypes); analysis.addCoercions(expressionCoercions, typeOnlyCoercions); - analysis.addFunctionSignatures(resolvedFunctions); + analysis.addFunctionHandles(resolvedFunctions); analysis.addColumnReferences(analyzer.getColumnReferences()); analysis.addLambdaArgumentReferences(analyzer.getLambdaArgumentReferences()); analysis.addTableColumnReferences(accessControl, session.getIdentity(), analyzer.getTableColumnReferences()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/WindowFunctionValidator.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/WindowFunctionValidator.java index 3215a71f510ff..b65659bde0aee 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/WindowFunctionValidator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/WindowFunctionValidator.java @@ -29,7 +29,7 @@ protected Void visitFunctionCall(FunctionCall functionCall, Analysis analysis) { requireNonNull(analysis, "analysis is null"); - Signature signature = analysis.getFunctionSignature(functionCall); + Signature signature = analysis.getFunctionHandle(functionCall).getSignature(); if (signature != null && signature.getKind() == WINDOW && !functionCall.getWindow().isPresent()) { throw new SemanticException(WINDOW_REQUIRES_OVER, functionCall, "Window function %s requires an OVER clause", signature.getName()); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 9af4e526775d5..479cef33537f0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -556,7 +556,7 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) } aggregationTranslations.put(aggregate, newSymbol); - aggregationsBuilder.put(newSymbol, new Aggregation((FunctionCall) rewritten, analysis.getFunctionSignature(aggregate), Optional.empty())); + aggregationsBuilder.put(newSymbol, new Aggregation((FunctionCall) rewritten, analysis.getFunctionHandle(aggregate).getSignature(), Optional.empty())); } Map aggregations = aggregationsBuilder.build(); @@ -797,7 +797,7 @@ private PlanBuilder window(PlanBuilder subPlan, List windowFunctio outputTranslations.put(windowFunction, newSymbol); WindowNode.Function function = new WindowNode.Function( - (FunctionCall) rewritten, analysis.getFunctionSignature(windowFunction), frame); + (FunctionCall) rewritten, analysis.getFunctionHandle(windowFunction).getSignature(), frame); List sourceSymbols = subPlan.getRoot().getOutputSymbols(); ImmutableList.Builder orderBySymbols = ImmutableList.builder(); From 7d50083719b80c3030cac57b8393b1c57dc4b123 Mon Sep 17 00:00:00 2001 From: rongrong Date: Tue, 19 Feb 2019 17:33:02 -0800 Subject: [PATCH 4/8] Switch window function to use FunctionHandle --- .../presto/metadata/FunctionManager.java | 4 +- .../presto/metadata/FunctionNamespace.java | 4 +- .../presto/metadata/FunctionRegistry.java | 3 +- .../sql/analyzer/TypeSignatureProvider.java | 5 +++ .../sql/planner/LocalExecutionPlanner.java | 8 ++-- .../presto/sql/planner/QueryPlanner.java | 2 +- .../UnaliasSymbolReferences.java | 6 +-- .../optimizations/WindowFilterPushDown.java | 2 +- .../presto/sql/planner/plan/WindowNode.java | 16 ++++---- .../sql/planner/sanity/TypeValidator.java | 2 +- .../presto/sql/planner/TestTypeValidator.java | 37 +++++-------------- .../assertions/WindowFunctionMatcher.java | 14 +++---- .../sql/planner/assertions/WindowMatcher.java | 6 +-- .../rule/TestMergeAdjacentWindows.java | 19 ++++------ .../rule/TestPruneWindowColumns.java | 25 +++++-------- ...stSwapAdjacentWindowsBySpecifications.java | 27 ++++++-------- .../sql/planner/plan/TestWindowNode.java | 17 +++------ 17 files changed, 83 insertions(+), 114 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java index bce227a68ff36..b84c51e1708c6 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java @@ -77,9 +77,9 @@ public Signature resolveFunction(QualifiedName name, List return globalFunctionNamespace.resolveFunction(name, parameterTypes).getSignature(); } - public WindowFunctionSupplier getWindowFunctionImplementation(Signature signature) + public WindowFunctionSupplier getWindowFunctionImplementation(FunctionHandle functionHandle) { - return globalFunctionNamespace.getWindowFunctionImplementation(signature); + return globalFunctionNamespace.getWindowFunctionImplementation(functionHandle); } public InternalAggregationFunction getAggregateFunctionImplementation(Signature signature) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java index 7dd2ea1cf0ace..2d8443fb23dec 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java @@ -53,9 +53,9 @@ public FunctionHandle resolveFunction(QualifiedName name, List boundTypeParameters) return typeSignatureResolver.apply(boundTypeParameters); } + public static List fromTypes(Type... types) + { + return fromTypes(ImmutableList.copyOf(types)); + } + public static List fromTypes(List types) { return types.stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index b1002d24f3286..54a5e51af8e09 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -21,8 +21,8 @@ import com.facebook.presto.execution.buffer.OutputBuffer; import com.facebook.presto.execution.buffer.PagesSerdeFactory; import com.facebook.presto.index.IndexManager; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.AggregationOperator.AggregationOperatorFactory; import com.facebook.presto.operator.AssignUniqueIdOperator; import com.facebook.presto.operator.DeleteOperator.DeleteOperatorFactory; @@ -901,15 +901,15 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext FrameInfo frameInfo = new FrameInfo(frame.getType(), frame.getStartType(), frameStartChannel, frame.getEndType(), frameEndChannel); FunctionCall functionCall = entry.getValue().getFunctionCall(); - Signature signature = entry.getValue().getSignature(); + FunctionHandle functionHandle = entry.getValue().getFunctionHandle(); ImmutableList.Builder arguments = ImmutableList.builder(); for (Expression argument : functionCall.getArguments()) { Symbol argumentSymbol = Symbol.from(argument); arguments.add(source.getLayout().get(argumentSymbol)); } Symbol symbol = entry.getKey(); - WindowFunctionSupplier windowFunctionSupplier = metadata.getFunctionManager().getWindowFunctionImplementation(signature); - Type type = metadata.getType(signature.getReturnType()); + WindowFunctionSupplier windowFunctionSupplier = metadata.getFunctionManager().getWindowFunctionImplementation(functionHandle); + Type type = metadata.getType(functionHandle.getSignature().getReturnType()); windowFunctionsBuilder.add(window(windowFunctionSupplier, type, frameInfo, arguments.build())); windowFunctionOutputSymbolsBuilder.add(symbol); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 479cef33537f0..1e8fbd25f285d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -797,7 +797,7 @@ private PlanBuilder window(PlanBuilder subPlan, List windowFunctio outputTranslations.put(windowFunction, newSymbol); WindowNode.Function function = new WindowNode.Function( - (FunctionCall) rewritten, analysis.getFunctionHandle(windowFunction).getSignature(), frame); + (FunctionCall) rewritten, analysis.getFunctionHandle(windowFunction), frame); List sourceSymbols = subPlan.getRoot().getOutputSymbols(); ImmutableList.Builder orderBySymbols = ImmutableList.builder(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 1e82d82564a95..a1f1df7534a89 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -15,7 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.warnings.WarningCollector; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.sql.planner.DeterminismEvaluator; import com.facebook.presto.sql.planner.OrderingScheme; @@ -195,10 +195,10 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) Symbol symbol = entry.getKey(); FunctionCall canonicalFunctionCall = (FunctionCall) canonicalize(entry.getValue().getFunctionCall()); - Signature signature = entry.getValue().getSignature(); + FunctionHandle functionHandle = entry.getValue().getFunctionHandle(); WindowNode.Frame canonicalFrame = canonicalize(entry.getValue().getFrame()); - functions.put(canonicalize(symbol), new WindowNode.Function(canonicalFunctionCall, signature, canonicalFrame)); + functions.put(canonicalize(symbol), new WindowNode.Function(canonicalFunctionCall, functionHandle, canonicalFrame)); } return new WindowNode( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java index bc2d3a291fed9..c52ca851329a6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java @@ -277,7 +277,7 @@ private static boolean canOptimizeWindowFunction(WindowNode node) return false; } Symbol rowNumberSymbol = getOnlyElement(node.getWindowFunctions().entrySet()).getKey(); - return isRowNumberSignature(node.getWindowFunctions().get(rowNumberSymbol).getSignature()); + return isRowNumberSignature(node.getWindowFunctions().get(rowNumberSymbol).getFunctionHandle().getSignature()); } private static boolean isRowNumberSignature(Signature signature) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/WindowNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/WindowNode.java index f62c14caec6de..5161b1cde86c9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/WindowNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/WindowNode.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.planner.plan; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.tree.Expression; @@ -326,17 +326,17 @@ public int hashCode() public static final class Function { private final FunctionCall functionCall; - private final Signature signature; + private final FunctionHandle functionHandle; private final Frame frame; @JsonCreator public Function( @JsonProperty("functionCall") FunctionCall functionCall, - @JsonProperty("signature") Signature signature, + @JsonProperty("functionHandle") FunctionHandle functionHandle, @JsonProperty("frame") Frame frame) { this.functionCall = requireNonNull(functionCall, "functionCall is null"); - this.signature = requireNonNull(signature, "Signature is null"); + this.functionHandle = requireNonNull(functionHandle, "Signature is null"); this.frame = requireNonNull(frame, "Frame is null"); } @@ -347,9 +347,9 @@ public FunctionCall getFunctionCall() } @JsonProperty - public Signature getSignature() + public FunctionHandle getFunctionHandle() { - return signature; + return functionHandle; } @JsonProperty @@ -361,7 +361,7 @@ public Frame getFrame() @Override public int hashCode() { - return Objects.hash(functionCall, signature, frame); + return Objects.hash(functionCall, functionHandle, frame); } @Override @@ -375,7 +375,7 @@ public boolean equals(Object obj) } Function other = (Function) obj; return Objects.equals(this.functionCall, other.functionCall) && - Objects.equals(this.signature, other.signature) && + Objects.equals(this.functionHandle, other.functionHandle) && Objects.equals(this.frame, other.frame); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java index 3c0f05162eed0..2f5e09d74297b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java @@ -147,7 +147,7 @@ public Void visitUnion(UnionNode node, Void context) private void checkWindowFunctions(Map functions) { for (Map.Entry entry : functions.entrySet()) { - Signature signature = entry.getValue().getSignature(); + Signature signature = entry.getValue().getFunctionHandle().getSignature(); FunctionCall call = entry.getValue().getFunctionCall(); checkSignature(entry.getKey(), signature); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java index 09796e1bbb349..ee592f558b0da 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java @@ -15,7 +15,9 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.execution.warnings.WarningCollector; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionKind; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.ColumnHandle; @@ -60,6 +62,7 @@ import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.SINGLE; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; @@ -69,6 +72,7 @@ public class TestTypeValidator private static final TableHandle TEST_TABLE_HANDLE = new TableHandle(new ConnectorId("test"), new TestingTableHandle()); private static final SqlParser SQL_PARSER = new SqlParser(); private static final TypeValidator TYPE_VALIDATOR = new TypeValidator(); + private static final FunctionManager FUNCTION_MANAGER = createTestMetadataManager().getFunctionManager(); private SymbolAllocator symbolAllocator; private TableScanNode baseTableScan; @@ -145,14 +149,7 @@ public void testValidUnion() public void testValidWindow() { Symbol windowSymbol = symbolAllocator.newSymbol("sum", DOUBLE); - Signature signature = new Signature( - "sum", - FunctionKind.WINDOW, - ImmutableList.of(), - ImmutableList.of(), - DOUBLE.getTypeSignature(), - ImmutableList.of(DOUBLE.getTypeSignature()), - false); + FunctionHandle functionHandle = FUNCTION_MANAGER.resolveFunction(TEST_SESSION, QualifiedName.of("sum"), fromTypes(DOUBLE)); FunctionCall functionCall = new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnC.toSymbolReference())); WindowNode.Frame frame = new WindowNode.Frame( @@ -164,7 +161,7 @@ public void testValidWindow() Optional.empty(), Optional.empty()); - WindowNode.Function function = new WindowNode.Function(functionCall, signature, frame); + WindowNode.Function function = new WindowNode.Function(functionCall, functionHandle, frame); WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty()); @@ -298,14 +295,7 @@ public void testInvalidAggregationFunctionSignature() public void testInvalidWindowFunctionCall() { Symbol windowSymbol = symbolAllocator.newSymbol("sum", DOUBLE); - Signature signature = new Signature( - "sum", - FunctionKind.WINDOW, - ImmutableList.of(), - ImmutableList.of(), - DOUBLE.getTypeSignature(), - ImmutableList.of(DOUBLE.getTypeSignature()), - false); + FunctionHandle functionHandle = FUNCTION_MANAGER.resolveFunction(TEST_SESSION, QualifiedName.of("sum"), fromTypes(DOUBLE)); FunctionCall functionCall = new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnA.toSymbolReference())); // should be columnC WindowNode.Frame frame = new WindowNode.Frame( @@ -317,7 +307,7 @@ public void testInvalidWindowFunctionCall() Optional.empty(), Optional.empty()); - WindowNode.Function function = new WindowNode.Function(functionCall, signature, frame); + WindowNode.Function function = new WindowNode.Function(functionCall, functionHandle, frame); WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty()); @@ -337,14 +327,7 @@ public void testInvalidWindowFunctionCall() public void testInvalidWindowFunctionSignature() { Symbol windowSymbol = symbolAllocator.newSymbol("sum", DOUBLE); - Signature signature = new Signature( - "sum", - FunctionKind.WINDOW, - ImmutableList.of(), - ImmutableList.of(), - BIGINT.getTypeSignature(), // should be DOUBLE - ImmutableList.of(DOUBLE.getTypeSignature()), - false); + FunctionHandle functionHandle = FUNCTION_MANAGER.resolveFunction(TEST_SESSION, QualifiedName.of("sum"), fromTypes(BIGINT)); // should be DOUBLE FunctionCall functionCall = new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnC.toSymbolReference())); WindowNode.Frame frame = new WindowNode.Frame( @@ -356,7 +339,7 @@ public void testInvalidWindowFunctionSignature() Optional.empty(), Optional.empty()); - WindowNode.Function function = new WindowNode.Function(functionCall, signature, frame); + WindowNode.Function function = new WindowNode.Function(functionCall, functionHandle, frame); WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java index 2daa8c47f5730..07a0941e52cad 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowFunctionMatcher.java @@ -14,8 +14,8 @@ package com.facebook.presto.sql.planner.assertions; import com.facebook.presto.Session; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; @@ -34,21 +34,21 @@ public class WindowFunctionMatcher implements RvalueMatcher { private final ExpectedValueProvider callMaker; - private final Optional signature; + private final Optional functionHandle; private final Optional> frameMaker; /** * @param callMaker Always validates the function call - * @param signature Optionally validates the signature + * @param functionHandle Optionally validates the function handle * @param frameMaker Optionally validates the frame */ public WindowFunctionMatcher( ExpectedValueProvider callMaker, - Optional signature, + Optional functionHandle, Optional> frameMaker) { this.callMaker = requireNonNull(callMaker, "functionCall is null"); - this.signature = requireNonNull(signature, "signature is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); this.frameMaker = requireNonNull(frameMaker, "frameMaker is null"); } @@ -68,7 +68,7 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada List matchedOutputs = windowNode.getWindowFunctions().entrySet().stream() .filter(assignment -> expectedCall.equals(assignment.getValue().getFunctionCall()) - && signature.map(assignment.getValue().getSignature()::equals).orElse(true) + && functionHandle.map(assignment.getValue().getFunctionHandle()::equals).orElse(true) && expectedFrame.map(assignment.getValue().getFrame()::equals).orElse(true)) .map(Map.Entry::getKey) .collect(toImmutableList()); @@ -88,7 +88,7 @@ public String toString() return toStringHelper(this) .omitNullValues() .add("callMaker", callMaker) - .add("signature", signature.orElse(null)) + .add("functionHandle", functionHandle.orElse(null)) .add("frameMaker", frameMaker.orElse(null)) .toString(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java index de6053030846e..cd5369e7c22bb 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/WindowMatcher.java @@ -15,8 +15,8 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.WindowNode; @@ -193,13 +193,13 @@ private Builder addFunction(Optional outputAlias, ExpectedValueProvider< public Builder addFunction( String outputAlias, ExpectedValueProvider functionCall, - Signature signature, + FunctionHandle functionHandle, ExpectedValueProvider frame) { windowFunctionMatchers.add( new AliasMatcher( Optional.of(outputAlias), - new WindowFunctionMatcher(functionCall, Optional.of(signature), Optional.of(frame)))); + new WindowFunctionMatcher(functionCall, Optional.of(functionHandle), Optional.of(frame)))); return this; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index 6b97b41c2ccf6..cc94aa684a4a2 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -13,8 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.metadata.FunctionKind; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; @@ -34,8 +33,11 @@ import java.util.Optional; import java.util.stream.Collectors; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; @@ -57,14 +59,7 @@ public class TestMergeAdjacentWindows Optional.empty(), Optional.empty()); - private static final Signature signature = new Signature( - "avg", - FunctionKind.WINDOW, - ImmutableList.of(), - ImmutableList.of(), - DOUBLE.getTypeSignature(), - ImmutableList.of(DOUBLE.getTypeSignature()), - false); + private static final FunctionHandle FUNCTION_HANDLE = createTestMetadataManager().getFunctionManager().resolveFunction(TEST_SESSION, QualifiedName.of("avg"), fromTypes(DOUBLE)); private static final String columnAAlias = "ALIAS_A"; private static final ExpectedValueProvider specificationA = specification(ImmutableList.of(columnAAlias), ImmutableList.of(), ImmutableMap.of()); @@ -231,7 +226,7 @@ private WindowNode.Function newWindowNodeFunction(String functionName, String... new FunctionCall( QualifiedName.of(functionName), Arrays.stream(symbols).map(SymbolReference::new).collect(Collectors.toList())), - signature, + FUNCTION_HANDLE, frame); } @@ -243,7 +238,7 @@ private WindowNode.Function newWindowNodeFunction(String functionName, Optional< window, false, Arrays.stream(symbols).map(SymbolReference::new).collect(Collectors.toList())), - signature, + FUNCTION_HANDLE, frame); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java index edc12994c2b24..89ae8d2e4b880 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneWindowColumns.java @@ -13,8 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.metadata.FunctionKind; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.sql.planner.OrderingScheme; import com.facebook.presto.sql.planner.Symbol; @@ -40,7 +39,10 @@ import java.util.Set; import java.util.function.Predicate; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; @@ -55,14 +57,7 @@ public class TestPruneWindowColumns extends BaseRuleTest { - private static final Signature signature = new Signature( - "min", - FunctionKind.WINDOW, - ImmutableList.of(), - ImmutableList.of(), - BIGINT.getTypeSignature(), - ImmutableList.of(BIGINT.getTypeSignature()), - false); + private static final FunctionHandle FUNCTION_HANDLE = createTestMetadataManager().getFunctionManager().resolveFunction(TEST_SESSION, QualifiedName.of("min"), fromTypes(BIGINT)); private static final List inputSymbolNameList = ImmutableList.of("orderKey", "partitionKey", "hash", "startValue1", "startValue2", "endValue1", "endValue2", "input1", "input2", "unused"); @@ -115,7 +110,7 @@ public void testOneFunctionNotNeeded() .addFunction( "output2", functionCall("min", ImmutableList.of("input2")), - signature, + FUNCTION_HANDLE, frameProvider2) .hashSymbol("hash"), strictProject( @@ -171,12 +166,12 @@ public void testUnusedInputNotNeeded() .addFunction( "output1", functionCall("min", ImmutableList.of("input1")), - signature, + FUNCTION_HANDLE, frameProvider1) .addFunction( "output2", functionCall("min", ImmutableList.of("input2")), - signature, + FUNCTION_HANDLE, frameProvider2) .hashSymbol("hash"), strictProject( @@ -221,7 +216,7 @@ private static PlanNode buildProjectedWindow( output1, new WindowNode.Function( new FunctionCall(QualifiedName.of("min"), ImmutableList.of(input1.toSymbolReference())), - signature, + FUNCTION_HANDLE, new WindowNode.Frame( WindowFrame.Type.RANGE, UNBOUNDED_PRECEDING, @@ -233,7 +228,7 @@ private static PlanNode buildProjectedWindow( output2, new WindowNode.Function( new FunctionCall(QualifiedName.of("min"), ImmutableList.of(input2.toSymbolReference())), - signature, + FUNCTION_HANDLE, new WindowNode.Frame( WindowFrame.Type.RANGE, UNBOUNDED_PRECEDING, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java index 782776342a036..14e416aed2a79 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java @@ -13,8 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; -import com.facebook.presto.metadata.FunctionKind; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.plan.WindowNode; @@ -29,8 +28,11 @@ import java.util.Optional; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; @@ -42,7 +44,7 @@ public class TestSwapAdjacentWindowsBySpecifications extends BaseRuleTest { private WindowNode.Frame frame; - private Signature signature; + private FunctionHandle functionHandle; public TestSwapAdjacentWindowsBySpecifications() { @@ -55,14 +57,7 @@ public TestSwapAdjacentWindowsBySpecifications() Optional.empty(), Optional.empty()); - signature = new Signature( - "avg", - FunctionKind.WINDOW, - ImmutableList.of(), - ImmutableList.of(), - DOUBLE.getTypeSignature(), - ImmutableList.of(BIGINT.getTypeSignature()), - false); + functionHandle = createTestMetadataManager().getFunctionManager().resolveFunction(TEST_SESSION, QualifiedName.of("avg"), fromTypes(BIGINT)); } @Test @@ -81,7 +76,7 @@ public void doesNotFireOnPlanWithSingleWindowNode() ImmutableList.of(p.symbol("a")), Optional.empty()), ImmutableMap.of(p.symbol("avg_1"), - new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), ImmutableList.of()), signature, frame)), + new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), ImmutableList.of()), functionHandle, frame)), p.values(p.symbol("a")))) .doesNotFire(); } @@ -104,12 +99,12 @@ public void subsetComesFirst() ImmutableList.of(p.symbol("a")), Optional.empty()), ImmutableMap.of(p.symbol("avg_1", DOUBLE), - new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowA, false, ImmutableList.of(new SymbolReference("a"))), signature, frame)), + new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowA, false, ImmutableList.of(new SymbolReference("a"))), functionHandle, frame)), p.window(new WindowNode.Specification( ImmutableList.of(p.symbol("a"), p.symbol("b")), Optional.empty()), ImmutableMap.of(p.symbol("avg_2", DOUBLE), - new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowAB, false, ImmutableList.of(new SymbolReference("b"))), signature, frame)), + new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowAB, false, ImmutableList.of(new SymbolReference("b"))), functionHandle, frame)), p.values(p.symbol("a"), p.symbol("b"))))) .matches( window(windowMatcherBuilder -> windowMatcherBuilder @@ -132,12 +127,12 @@ public void dependentWindowsAreNotReordered() ImmutableList.of(p.symbol("a")), Optional.empty()), ImmutableMap.of(p.symbol("avg_1"), - new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowA, false, ImmutableList.of(new SymbolReference("avg_2"))), signature, frame)), + new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowA, false, ImmutableList.of(new SymbolReference("avg_2"))), functionHandle, frame)), p.window(new WindowNode.Specification( ImmutableList.of(p.symbol("a"), p.symbol("b")), Optional.empty()), ImmutableMap.of(p.symbol("avg_2"), - new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowA, false, ImmutableList.of(new SymbolReference("a"))), signature, frame)), + new WindowNode.Function(new FunctionCall(QualifiedName.of("avg"), windowA, false, ImmutableList.of(new SymbolReference("a"))), functionHandle, frame)), p.values(p.symbol("a"), p.symbol("b"))))) .doesNotFire(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java index 12c8a9537a2e2..249ef77902c1c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java @@ -13,8 +13,7 @@ */ package com.facebook.presto.sql.planner.plan; -import com.facebook.presto.metadata.FunctionKind; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.server.SliceDeserializer; import com.facebook.presto.server.SliceSerializer; import com.facebook.presto.spi.block.SortOrder; @@ -42,7 +41,10 @@ import java.util.Set; import java.util.UUID; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static org.testng.Assert.assertEquals; public class TestWindowNode @@ -89,14 +91,7 @@ public void testSerializationRoundtrip() throws Exception { Symbol windowSymbol = symbolAllocator.newSymbol("sum", BIGINT); - Signature signature = new Signature( - "sum", - FunctionKind.WINDOW, - ImmutableList.of(), - ImmutableList.of(), - BIGINT.getTypeSignature(), - ImmutableList.of(BIGINT.getTypeSignature()), - false); + FunctionHandle functionHandle = createTestMetadataManager().getFunctionManager().resolveFunction(TEST_SESSION, QualifiedName.of("sum"), fromTypes(BIGINT)); FunctionCall functionCall = new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnC.toSymbolReference())); WindowNode.Frame frame = new WindowNode.Frame( WindowFrame.Type.RANGE, @@ -113,7 +108,7 @@ public void testSerializationRoundtrip() Optional.of(new OrderingScheme( ImmutableList.of(columnB), ImmutableMap.of(columnB, SortOrder.ASC_NULLS_FIRST)))); - Map functions = ImmutableMap.of(windowSymbol, new WindowNode.Function(functionCall, signature, frame)); + Map functions = ImmutableMap.of(windowSymbol, new WindowNode.Function(functionCall, functionHandle, frame)); Optional hashSymbol = Optional.of(columnB); Set prePartitionedInputs = ImmutableSet.of(columnA); WindowNode windowNode = new WindowNode( From 8d329026227f0d9a08edb773f7b9ce480bc11a69 Mon Sep 17 00:00:00 2001 From: rongrong Date: Wed, 20 Feb 2019 10:59:18 -0800 Subject: [PATCH 5/8] Switch aggergation function to use FunctionHandle --- .../benchmark/CountAggregationBenchmark.java | 11 +- .../DoubleSumAggregationBenchmark.java | 11 +- .../presto/benchmark/HandTpchQuery1.java | 25 +-- .../presto/benchmark/HandTpchQuery6.java | 14 +- .../benchmark/HashAggregationBenchmark.java | 11 +- ...patialPartitioningInternalAggregation.java | 23 +-- .../AbstractTestGeoAggregationFunctions.java | 20 +-- .../presto/metadata/FunctionManager.java | 4 +- .../presto/metadata/FunctionNamespace.java | 4 +- .../presto/metadata/FunctionRegistry.java | 3 +- .../sql/planner/LocalExecutionPlanner.java | 4 +- .../presto/sql/planner/LogicalPlanner.java | 2 +- .../presto/sql/planner/PlanOptimizers.java | 8 +- .../presto/sql/planner/QueryPlanner.java | 2 +- .../planner/StatisticsAggregationPlanner.java | 15 +- .../presto/sql/planner/SymbolAllocator.java | 6 + .../rule/AddIntermediateAggregations.java | 5 +- .../rule/ExpressionRewriteRuleSet.java | 2 +- .../rule/ImplementFilteredAggregations.java | 2 +- ...ipleDistinctAggregationToMarkDistinct.java | 2 +- .../rule/PruneCountAggregationOverScalar.java | 4 +- .../rule/PruneOrderByInAggregation.java | 4 +- .../rule/PushAggregationThroughOuterJoin.java | 2 +- ...PushPartialAggregationThroughExchange.java | 15 +- ...RewriteSpatialPartitioningAggregation.java | 11 +- .../rule/SimplifyCountOverConstant.java | 21 ++- .../SingleDistinctAggregationToGroupBy.java | 2 +- .../TransformCorrelatedInPredicateToJoin.java | 24 ++- ...formCorrelatedScalarAggregationToJoin.java | 2 +- .../TransformExistsApplyToLateralNode.java | 12 +- .../ImplementIntersectAndExceptAsUnion.java | 27 +++- .../OptimizeMixedDistinctAggregations.java | 30 ++-- .../ScalarAggregationToJoinRewriter.java | 6 +- .../planner/optimizations/SymbolMapper.java | 2 +- ...uantifiedComparisonApplyToLateralJoin.java | 32 ++-- .../sql/planner/plan/AggregationNode.java | 14 +- .../planner/plan/StatisticAggregations.java | 15 +- .../sql/planner/sanity/TypeValidator.java | 2 +- .../presto/cost/TestCostCalculator.java | 7 +- ...kHashAndStreamingAggregationOperators.java | 18 ++- .../operator/TestAggregationOperator.java | 39 +++-- .../operator/TestHashAggregationOperator.java | 56 +++---- .../operator/TestRealAverageAggregation.java | 13 +- .../TestStreamingAggregationOperator.java | 15 +- .../operator/TestTableFinishOperator.java | 11 +- .../operator/TestTableWriterOperator.java | 12 +- .../AbstractTestAggregationFunction.java | 10 +- .../AbstractTestApproximateCountDistinct.java | 3 +- .../BenchmarkArrayAggregation.java | 12 +- .../BenchmarkGroupedTypedHistogram.java | 23 +-- .../TestApproximateCountDistinctBoolean.java | 10 +- .../TestApproximateCountDistinctDouble.java | 10 +- .../TestApproximateCountDistinctInteger.java | 10 +- ...TestApproximateCountDistinctIpAddress.java | 10 +- .../TestApproximateCountDistinctLong.java | 9 +- ...stApproximateCountDistinctLongDecimal.java | 10 +- .../TestApproximateCountDistinctSmallint.java | 10 +- .../TestApproximateCountDistinctTinyint.java | 10 +- ...TestApproximateCountDistinctVarBinary.java | 15 +- .../TestApproximatePercentileAggregation.java | 76 ++++----- .../aggregation/TestArbitraryAggregation.java | 57 +++---- .../aggregation/TestArrayAggregation.java | 69 ++++----- .../aggregation/TestChecksumAggregation.java | 79 ++++------ .../TestDoubleHistogramAggregation.java | 22 +-- .../operator/aggregation/TestHistogram.java | 106 ++++--------- .../aggregation/TestMapAggAggregation.java | 96 +++--------- .../aggregation/TestMapUnionAggregation.java | 37 +++-- .../TestMultimapAggAggregation.java | 14 +- ...TestQuantileDigestAggregationFunction.java | 40 ++--- .../TestRealHistogramAggregation.java | 15 +- .../minmaxby/TestMinMaxByAggregation.java | 146 +++++++++--------- .../minmaxby/TestMinMaxByNAggregation.java | 115 ++++---------- .../TestEffectivePredicateExtractor.java | 23 +-- .../presto/sql/planner/TestTypeValidator.java | 29 +--- .../iterative/rule/test/PlanBuilder.java | 9 +- presto-ml/pom.xml | 5 + .../ml/TestEvaluateClassifierPredictions.java | 18 +-- 77 files changed, 725 insertions(+), 943 deletions(-) diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/CountAggregationBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/CountAggregationBenchmark.java index 89888fc18c2f2..137ea7f786715 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/CountAggregationBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/CountAggregationBenchmark.java @@ -13,12 +13,13 @@ */ package com.facebook.presto.benchmark; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.operator.AggregationOperator.AggregationOperatorFactory; import com.facebook.presto.operator.OperatorFactory; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.sql.planner.plan.AggregationNode.Step; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.testing.LocalQueryRunner; import com.google.common.collect.ImmutableList; @@ -26,8 +27,9 @@ import java.util.Optional; import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; public class CountAggregationBenchmark extends AbstractSimpleOperatorBenchmark @@ -41,8 +43,9 @@ public CountAggregationBenchmark(LocalQueryRunner localQueryRunner) protected List createOperatorFactories() { OperatorFactory tableScanOperator = createTableScanOperator(0, new PlanNodeId("test"), "orders", "orderkey"); - InternalAggregationFunction countFunction = localQueryRunner.getMetadata().getFunctionManager().getAggregateFunctionImplementation( - new Signature("count", AGGREGATE, BIGINT.getTypeSignature())); + FunctionManager functionManager = localQueryRunner.getMetadata().getFunctionManager(); + InternalAggregationFunction countFunction = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(testSessionBuilder().build(), QualifiedName.of("count"), fromTypes(BIGINT))); AggregationOperatorFactory aggregationOperator = new AggregationOperatorFactory(1, new PlanNodeId("test"), Step.SINGLE, ImmutableList.of(countFunction.bind(ImmutableList.of(0), Optional.empty())), false); return ImmutableList.of(tableScanOperator, aggregationOperator); } diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/DoubleSumAggregationBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/DoubleSumAggregationBenchmark.java index 5bcb482c83cb8..05201aa544d8e 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/DoubleSumAggregationBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/DoubleSumAggregationBenchmark.java @@ -13,13 +13,14 @@ */ package com.facebook.presto.benchmark; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.AggregationOperator.AggregationOperatorFactory; import com.facebook.presto.operator.OperatorFactory; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.sql.planner.plan.AggregationNode.Step; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.testing.LocalQueryRunner; import com.google.common.collect.ImmutableList; @@ -27,8 +28,9 @@ import java.util.Optional; import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; public class DoubleSumAggregationBenchmark extends AbstractSimpleOperatorBenchmark @@ -42,8 +44,9 @@ public DoubleSumAggregationBenchmark(LocalQueryRunner localQueryRunner) protected List createOperatorFactories() { OperatorFactory tableScanOperator = createTableScanOperator(0, new PlanNodeId("test"), "orders", "totalprice"); - InternalAggregationFunction doubleSum = MetadataManager.createTestMetadataManager().getFunctionManager().getAggregateFunctionImplementation( - new Signature("sum", AGGREGATE, DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature())); + FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); + InternalAggregationFunction doubleSum = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(testSessionBuilder().build(), QualifiedName.of("sum"), fromTypes(DOUBLE))); AggregationOperatorFactory aggregationOperator = new AggregationOperatorFactory(1, new PlanNodeId("test"), Step.SINGLE, ImmutableList.of(doubleSum.bind(ImmutableList.of(0), Optional.empty())), false); return ImmutableList.of(tableScanOperator, aggregationOperator); } diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java index 0c22945db2c5d..b49ad3083f741 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java @@ -13,8 +13,9 @@ */ package com.facebook.presto.benchmark; +import com.facebook.presto.Session; import com.facebook.presto.benchmark.HandTpchQuery1.TpchQuery1Operator.TpchQuery1OperatorFactory; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.operator.DriverContext; import com.facebook.presto.operator.HashAggregationOperator.HashAggregationOperatorFactory; import com.facebook.presto.operator.Operator; @@ -27,6 +28,7 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.AggregationNode.Step; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.testing.LocalQueryRunner; import com.facebook.presto.util.DateTimeUtils; import com.google.common.collect.ImmutableList; @@ -37,11 +39,12 @@ import java.util.Optional; import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DateType.DATE; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.google.common.base.Preconditions.checkState; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Objects.requireNonNull; @@ -58,14 +61,16 @@ public HandTpchQuery1(LocalQueryRunner localQueryRunner) { super(localQueryRunner, "hand_tpch_query_1", 1, 5); - longAverage = localQueryRunner.getMetadata().getFunctionManager().getAggregateFunctionImplementation( - new Signature("avg", AGGREGATE, DOUBLE.getTypeSignature(), BIGINT.getTypeSignature())); - doubleAverage = localQueryRunner.getMetadata().getFunctionManager().getAggregateFunctionImplementation( - new Signature("avg", AGGREGATE, DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature())); - doubleSum = localQueryRunner.getMetadata().getFunctionManager().getAggregateFunctionImplementation( - new Signature("sum", AGGREGATE, DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature())); - countFunction = localQueryRunner.getMetadata().getFunctionManager().getAggregateFunctionImplementation( - new Signature("count", AGGREGATE, BIGINT.getTypeSignature())); + FunctionManager functionManager = localQueryRunner.getMetadata().getFunctionManager(); + Session session = testSessionBuilder().setCatalog("tpch").setSchema("tiny").build(); + longAverage = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(session, QualifiedName.of("avg"), fromTypes(BIGINT))); + doubleAverage = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(session, QualifiedName.of("avg"), fromTypes(DOUBLE))); + doubleSum = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(session, QualifiedName.of("sum"), fromTypes(DOUBLE))); + countFunction = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(session, QualifiedName.of("count"), ImmutableList.of())); } @Override diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery6.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery6.java index 80a0bc251d708..4ee67bcb953b6 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery6.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery6.java @@ -13,7 +13,8 @@ */ package com.facebook.presto.benchmark; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.Session; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.operator.AggregationOperator.AggregationOperatorFactory; import com.facebook.presto.operator.FilterAndProjectOperator; import com.facebook.presto.operator.OperatorFactory; @@ -29,6 +30,7 @@ import com.facebook.presto.sql.gen.PageFunctionCompiler; import com.facebook.presto.sql.planner.plan.AggregationNode.Step; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.testing.LocalQueryRunner; import com.facebook.presto.util.DateTimeUtils; import com.google.common.collect.ImmutableList; @@ -39,11 +41,12 @@ import java.util.function.Supplier; import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DateType.DATE; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.relational.Expressions.field; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static io.airlift.units.DataSize.Unit.BYTE; public class HandTpchQuery6 @@ -54,9 +57,10 @@ public class HandTpchQuery6 public HandTpchQuery6(LocalQueryRunner localQueryRunner) { super(localQueryRunner, "hand_tpch_query_6", 10, 100); - - doubleSum = localQueryRunner.getMetadata().getFunctionManager().getAggregateFunctionImplementation( - new Signature("sum", AGGREGATE, DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature())); + FunctionManager functionManager = localQueryRunner.getMetadata().getFunctionManager(); + Session session = testSessionBuilder().setCatalog("tpch").setSchema("tiny").build(); + doubleSum = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(session, QualifiedName.of("sum"), fromTypes(DOUBLE))); } @Override diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java index 0b3e45ba073b3..a5f7a4560d830 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java @@ -13,13 +13,14 @@ */ package com.facebook.presto.benchmark; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.operator.HashAggregationOperator.HashAggregationOperatorFactory; import com.facebook.presto.operator.OperatorFactory; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.AggregationNode.Step; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.testing.LocalQueryRunner; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; @@ -29,8 +30,9 @@ import java.util.Optional; import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static io.airlift.units.DataSize.Unit.MEGABYTE; public class HashAggregationBenchmark @@ -42,8 +44,9 @@ public HashAggregationBenchmark(LocalQueryRunner localQueryRunner) { super(localQueryRunner, "hash_agg", 5, 25); - doubleSum = localQueryRunner.getMetadata().getFunctionManager().getAggregateFunctionImplementation( - new Signature("sum", AGGREGATE, DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature())); + FunctionManager functionManager = localQueryRunner.getMetadata().getFunctionManager(); + doubleSum = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(testSessionBuilder().build(), QualifiedName.of("sum"), fromTypes(DOUBLE))); } @Override diff --git a/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java b/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java index 7bdde305031a3..1b19b0e02e260 100644 --- a/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java +++ b/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/TestSpatialPartitioningInternalAggregation.java @@ -20,8 +20,7 @@ import com.facebook.presto.block.BlockAssertions; import com.facebook.presto.geospatial.KdbTreeUtils; import com.facebook.presto.geospatial.Rectangle; -import com.facebook.presto.metadata.FunctionKind; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.operator.aggregation.Accumulator; import com.facebook.presto.operator.aggregation.AccumulatorFactory; import com.facebook.presto.operator.aggregation.GroupedAccumulator; @@ -30,7 +29,7 @@ import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; import org.testng.annotations.BeforeClass; @@ -40,15 +39,15 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.geospatial.KdbTree.buildKdbTree; import static com.facebook.presto.geospatial.serde.GeometrySerde.serialize; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.createGroupByIdBlock; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.getFinalBlock; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.getGroupValue; import static com.facebook.presto.plugin.geospatial.GeometryType.GEOMETRY; -import static com.facebook.presto.plugin.geospatial.GeometryType.GEOMETRY_TYPE_NAME; -import static com.facebook.presto.spi.type.StandardTypes.INTEGER; -import static com.facebook.presto.spi.type.StandardTypes.VARCHAR; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.google.common.math.DoubleMath.roundToInt; import static java.math.RoundingMode.CEILING; import static org.testng.Assert.assertEquals; @@ -98,15 +97,9 @@ public void test(int partitionCount) private InternalAggregationFunction getFunction() { - return functionAssertions - .getMetadata() - .getFunctionManager() - .getAggregateFunctionImplementation( - new Signature("spatial_partitioning", - FunctionKind.AGGREGATE, - TypeSignature.parseTypeSignature(VARCHAR), - TypeSignature.parseTypeSignature(GEOMETRY_TYPE_NAME), - TypeSignature.parseTypeSignature(INTEGER))); + FunctionManager functionManager = functionAssertions.getMetadata().getFunctionManager(); + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("spatial_partitioning"), fromTypes(GEOMETRY, INTEGER))); } private List makeGeometries() diff --git a/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/aggregation/AbstractTestGeoAggregationFunctions.java b/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/aggregation/AbstractTestGeoAggregationFunctions.java index 9dac5a4572b5f..d3d1541bc1ab4 100644 --- a/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/aggregation/AbstractTestGeoAggregationFunctions.java +++ b/presto-geospatial/src/test/java/com/facebook/presto/plugin/geospatial/aggregation/AbstractTestGeoAggregationFunctions.java @@ -16,14 +16,13 @@ import com.esri.core.geometry.ogc.OGCGeometry; import com.facebook.presto.block.BlockAssertions; import com.facebook.presto.geospatial.serde.GeometrySerde; -import com.facebook.presto.metadata.FunctionKind; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.operator.scalar.AbstractTestFunctions; import com.facebook.presto.plugin.geospatial.GeoPlugin; -import com.facebook.presto.plugin.geospatial.GeometryType; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import io.airlift.slice.Slice; import org.testng.annotations.BeforeClass; @@ -33,9 +32,11 @@ import java.util.function.BiFunction; import java.util.stream.Collectors; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.metadata.FunctionExtractor.extractFunctions; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.plugin.geospatial.GeometryType.GEOMETRY; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; public abstract class AbstractTestGeoAggregationFunctions extends AbstractTestFunctions @@ -50,14 +51,9 @@ public void registerFunctions() functionAssertions.getTypeRegistry().addType(type); } functionAssertions.getMetadata().addFunctions(extractFunctions(plugin.getFunctions())); - function = functionAssertions - .getMetadata() - .getFunctionManager() - .getAggregateFunctionImplementation(new Signature( - getFunctionName(), - FunctionKind.AGGREGATE, - parseTypeSignature(GeometryType.GEOMETRY_TYPE_NAME), - parseTypeSignature(GeometryType.GEOMETRY_TYPE_NAME))); + FunctionManager functionManager = functionAssertions.getMetadata().getFunctionManager(); + function = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(getFunctionName()), fromTypes(GEOMETRY))); } protected void assertAggregatedGeometries(String testDescription, String expectedWkt, String... wkts) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java index b84c51e1708c6..906e4ada7b475 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java @@ -82,9 +82,9 @@ public WindowFunctionSupplier getWindowFunctionImplementation(FunctionHandle fun return globalFunctionNamespace.getWindowFunctionImplementation(functionHandle); } - public InternalAggregationFunction getAggregateFunctionImplementation(Signature signature) + public InternalAggregationFunction getAggregateFunctionImplementation(FunctionHandle functionHandle) { - return globalFunctionNamespace.getAggregateFunctionImplementation(signature); + return globalFunctionNamespace.getAggregateFunctionImplementation(functionHandle); } public ScalarFunctionImplementation getScalarFunctionImplementation(Signature signature) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java index 2d8443fb23dec..c2651bb79771f 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java @@ -58,9 +58,9 @@ public WindowFunctionSupplier getWindowFunctionImplementation(FunctionHandle fun return registry.getWindowFunctionImplementation(functionHandle); } - public InternalAggregationFunction getAggregateFunctionImplementation(Signature signature) + public InternalAggregationFunction getAggregateFunctionImplementation(FunctionHandle functionHandle) { - return registry.getAggregateFunctionImplementation(signature); + return registry.getAggregateFunctionImplementation(functionHandle); } public ScalarFunctionImplementation getScalarFunctionImplementation(Signature signature) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java index 11aed1a29d349..8289f2697b3c2 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java @@ -932,8 +932,9 @@ public WindowFunctionSupplier getWindowFunctionImplementation(FunctionHandle fun } } - public InternalAggregationFunction getAggregateFunctionImplementation(Signature signature) + public InternalAggregationFunction getAggregateFunctionImplementation(FunctionHandle functionHandle) { + Signature signature = functionHandle.getSignature(); checkArgument(signature.getKind() == AGGREGATE, "%s is not an aggregate function", signature); checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index 54a5e51af8e09..7cef2d1edc76f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -2522,7 +2522,7 @@ private AccumulatorFactory buildAccumulatorFactory( { InternalAggregationFunction internalAggregationFunction = metadata .getFunctionManager() - .getAggregateFunctionImplementation(aggregation.getSignature()); + .getAggregateFunctionImplementation(aggregation.getFunctionHandle()); List valueChannels = new ArrayList<>(); for (Expression argument : aggregation.getCall().getArguments()) { @@ -2538,7 +2538,7 @@ private AccumulatorFactory buildAccumulatorFactory( .map(LambdaExpression.class::cast) .collect(toImmutableList()); if (!lambdaExpressions.isEmpty()) { - List functionTypes = aggregation.getSignature().getArgumentTypes().stream() + List functionTypes = aggregation.getFunctionHandle().getSignature().getArgumentTypes().stream() .filter(typeSignature -> typeSignature.getBase().equals(FunctionType.NAME)) .map(typeSignature -> (FunctionType) (metadata.getTypeManager().getType(typeSignature))) .collect(toImmutableList()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index 398616d3a378c..436968aa64e85 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -149,7 +149,7 @@ public LogicalPlanner(Session session, this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); - this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(symbolAllocator, metadata); + this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(session, symbolAllocator, metadata); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 168014c551fad..ae5418ef4cda4 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -299,7 +299,7 @@ public PlanOptimizers( estimatedExchangesCostCalculator, ImmutableSet.of(new RemoveRedundantIdentityProjections())), new SetFlatteningOptimizer(), - new ImplementIntersectAndExceptAsUnion(), + new ImplementIntersectAndExceptAsUnion(metadata.getFunctionManager()), new LimitPushDown(), // Run the LimitPushDown after flattening set operators to make it easier to do the set flattening new PruneUnreferencedOutputs(), inlineProjections, @@ -313,7 +313,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new TransformExistsApplyToLateralNode(metadata.getFunctionManager()))), - new TransformQuantifiedComparisonApplyToLateralJoin(metadata), + new TransformQuantifiedComparisonApplyToLateralJoin(metadata.getFunctionManager()), new IterativeOptimizer( ruleStats, statsCalculator, @@ -330,7 +330,7 @@ public PlanOptimizers( estimatedExchangesCostCalculator, ImmutableSet.of( new RemoveUnreferencedScalarApplyNodes(), - new TransformCorrelatedInPredicateToJoin(), // must be run after PruneUnreferencedOutputs + new TransformCorrelatedInPredicateToJoin(metadata.getFunctionManager()), // must be run after PruneUnreferencedOutputs new TransformCorrelatedScalarSubquery(), // must be run after TransformCorrelatedScalarAggregationToJoin new TransformCorrelatedLateralJoinToJoin(), new ImplementFilteredAggregations())), @@ -367,7 +367,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new SimplifyCountOverConstant())), + ImmutableSet.of(new SimplifyCountOverConstant(metadata.getFunctionManager()))), new LimitPushDown(), // Run LimitPushDown before WindowFilterPushDown new WindowFilterPushDown(metadata), // This must run after PredicatePushDown and LimitPushDown so that it squashes any successive filter nodes and limits new IterativeOptimizer( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 1e8fbd25f285d..8943ffbe87005 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -556,7 +556,7 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) } aggregationTranslations.put(aggregate, newSymbol); - aggregationsBuilder.put(newSymbol, new Aggregation((FunctionCall) rewritten, analysis.getFunctionHandle(aggregate).getSignature(), Optional.empty())); + aggregationsBuilder.put(newSymbol, new Aggregation((FunctionCall) rewritten, analysis.getFunctionHandle(aggregate), Optional.empty())); } Map aggregations = aggregationsBuilder.build(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java index 91f0ddc96bb60..ab5c9d750e496 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java @@ -13,9 +13,10 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.Session; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.MaxDataSizeForStats; import com.facebook.presto.operator.aggregation.SumDataSizeForStats; import com.facebook.presto.spi.PrestoException; @@ -49,11 +50,13 @@ public class StatisticsAggregationPlanner { + private final Session session; private final SymbolAllocator symbolAllocator; private final Metadata metadata; - public StatisticsAggregationPlanner(SymbolAllocator symbolAllocator, Metadata metadata) + public StatisticsAggregationPlanner(Session session, SymbolAllocator symbolAllocator, Metadata metadata) { + this.session = requireNonNull(session, "session is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); } @@ -80,7 +83,7 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta QualifiedName count = QualifiedName.of("count"); AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation( new FunctionCall(count, ImmutableList.of()), - functionManager.resolveFunction(count, ImmutableList.of()), + functionManager.resolveFunction(session, count, ImmutableList.of()), Optional.empty()); Symbol symbol = symbolAllocator.newSymbol("rowCount", BIGINT); aggregations.put(symbol, aggregation); @@ -128,13 +131,13 @@ private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticType private ColumnStatisticsAggregation createAggregation(QualifiedName functionName, SymbolReference input, Type inputType, Type outputType) { - Signature signature = metadata.getFunctionManager().resolveFunction(functionName, TypeSignatureProvider.fromTypes(ImmutableList.of(inputType))); - Type resolvedType = metadata.getType(getOnlyElement(signature.getArgumentTypes())); + FunctionHandle functionHandle = metadata.getFunctionManager().resolveFunction(session, functionName, TypeSignatureProvider.fromTypes(ImmutableList.of(inputType))); + Type resolvedType = metadata.getType(getOnlyElement(functionHandle.getSignature().getArgumentTypes())); verify(resolvedType.equals(inputType), "resolved function input type does not match the input type: %s != %s", resolvedType, inputType); return new ColumnStatisticsAggregation( new AggregationNode.Aggregation( new FunctionCall(functionName, ImmutableList.of(input)), - signature, + functionHandle, Optional.empty()), outputType); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java index 1131e1f8f8c2f..34f07544b1c21 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/SymbolAllocator.java @@ -20,6 +20,7 @@ import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.Identifier; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.primitives.Ints; @@ -51,6 +52,11 @@ public Symbol newSymbol(Symbol symbolHint) return newSymbol(symbolHint.getName(), symbols.get(symbolHint)); } + public Symbol newSymbol(QualifiedName nameHint, Type type) + { + return newSymbol(nameHint.getSuffix(), type, null); + } + public Symbol newSymbol(String nameHint, Type type) { return newSymbol(nameHint, type, null); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java index bf4f04179a5fe..47c33e27b9c56 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddIntermediateAggregations.java @@ -27,7 +27,6 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.tree.FunctionCall; -import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -183,8 +182,8 @@ private static Map outputsAsInputs(Map builder.put( output, new Aggregation( - new FunctionCall(QualifiedName.of(aggregation.getSignature().getName()), ImmutableList.of(output.toSymbolReference())), - aggregation.getSignature(), + new FunctionCall(aggregation.getCall().getName(), ImmutableList.of(output.toSymbolReference())), + aggregation.getFunctionHandle(), Optional.empty())); // No mask for INTERMEDIATE } return builder.build(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java index 4fbead2e3140a..e610b18bd86f1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java @@ -152,7 +152,7 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context FunctionCall call = (FunctionCall) rewriter.rewrite(aggregation.getCall(), context); aggregations.put( entry.getKey(), - new Aggregation(call, aggregation.getSignature(), aggregation.getMask())); + new Aggregation(call, aggregation.getFunctionHandle(), aggregation.getMask())); if (!aggregation.getCall().equals(call)) { anyRewritten = true; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java index 3cc5a6f1030dd..9d9277ca79004 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementFilteredAggregations.java @@ -107,7 +107,7 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont aggregations.put(output, new Aggregation( new FunctionCall(call.getName(), call.getWindow(), Optional.empty(), call.getOrderBy(), call.isDistinct(), call.getArguments()), - entry.getValue().getSignature(), + entry.getValue().getFunctionHandle(), mask)); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java index b7836217b0e96..96c16f48d35e3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java @@ -157,7 +157,7 @@ public Result apply(AggregationNode parent, Captures captures, Context context) call.getOrderBy(), false, call.getArguments()), - aggregation.getSignature(), + aggregation.getFunctionHandle(), Optional.of(marker))); } else { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java index 40ca3b32ea1de..05765426a7cea 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneCountAggregationOverScalar.java @@ -15,7 +15,6 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -55,9 +54,8 @@ public Result apply(AggregationNode parent, Captures captures, Context context) for (Map.Entry entry : assignments.entrySet()) { AggregationNode.Aggregation aggregation = entry.getValue(); requireNonNull(aggregation, "aggregation is null"); - Signature signature = aggregation.getSignature(); FunctionCall functionCall = aggregation.getCall(); - if (!"count".equals(signature.getName()) || !functionCall.getArguments().isEmpty()) { + if (!"count".equals(functionCall.getName().getSuffix()) || !functionCall.getArguments().isEmpty()) { return Result.empty(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOrderByInAggregation.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOrderByInAggregation.java index f22d6e110cf2a..f25f3aa00f530 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOrderByInAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneOrderByInAggregation.java @@ -60,7 +60,7 @@ public Result apply(AggregationNode node, Captures captures, Context context) aggregations.put(entry); } // getAggregateFunctionImplementation can be expensive, so check it last. - else if (functionManager.getAggregateFunctionImplementation(aggregation.getSignature()).isOrderSensitive()) { + else if (functionManager.getAggregateFunctionImplementation(aggregation.getFunctionHandle()).isOrderSensitive()) { aggregations.put(entry); } else { @@ -71,7 +71,7 @@ else if (functionManager.getAggregateFunctionImplementation(aggregation.getSigna aggregation.getCall().getArguments(), aggregation.getCall().getFilter()); - aggregations.put(entry.getKey(), new Aggregation(rewritten, aggregation.getSignature(), aggregation.getMask())); + aggregations.put(entry.getKey(), new Aggregation(rewritten, aggregation.getFunctionHandle(), aggregation.getMask())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java index f9b6ec057bba6..67685af716e4c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushAggregationThroughOuterJoin.java @@ -298,7 +298,7 @@ private Optional createAggregationOverNull(AggregationNod AggregationNode.Aggregation overNullAggregation = new AggregationNode.Aggregation( (FunctionCall) inlineSymbols(sourcesSymbolMapping, aggregation.getCall()), - aggregation.getSignature(), + aggregation.getFunctionHandle(), aggregation.getMask().map(x -> Symbol.from(sourcesSymbolMapping.get(x)))); Symbol overNullSymbol = symbolAllocator.newSymbol(overNullAggregation.getCall(), symbolAllocator.getTypes().get(aggregationSymbol)); aggregationsOverNullBuilder.put(overNullSymbol, overNullAggregation); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java index f5dbb7a0e0ba1..d93361743a52c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java @@ -16,8 +16,8 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.sql.planner.Partitioning; import com.facebook.presto.sql.planner.PartitioningScheme; @@ -202,18 +202,19 @@ private PlanNode split(AggregationNode node, Context context) Map finalAggregation = new HashMap<>(); for (Map.Entry entry : node.getAggregations().entrySet()) { AggregationNode.Aggregation originalAggregation = entry.getValue(); - Signature signature = originalAggregation.getSignature(); - InternalAggregationFunction function = functionManager.getAggregateFunctionImplementation(signature); - Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(signature.getName(), function.getIntermediateType()); + QualifiedName functionName = originalAggregation.getCall().getName(); + FunctionHandle functionHandle = originalAggregation.getFunctionHandle(); + InternalAggregationFunction function = functionManager.getAggregateFunctionImplementation(functionHandle); + Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(functionName, function.getIntermediateType()); checkState(!originalAggregation.getCall().getOrderBy().isPresent(), "Aggregate with ORDER BY does not support partial aggregation"); - intermediateAggregation.put(intermediateSymbol, new AggregationNode.Aggregation(originalAggregation.getCall(), signature, originalAggregation.getMask())); + intermediateAggregation.put(intermediateSymbol, new AggregationNode.Aggregation(originalAggregation.getCall(), functionHandle, originalAggregation.getMask())); // rewrite final aggregation in terms of intermediate function finalAggregation.put(entry.getKey(), new AggregationNode.Aggregation( new FunctionCall( - QualifiedName.of(signature.getName()), + functionName, ImmutableList.builder() .add(intermediateSymbol.toSymbolReference()) .addAll(originalAggregation.getCall().getArguments().stream() @@ -221,7 +222,7 @@ private PlanNode split(AggregationNode node, Context context) .collect(toImmutableList())) .build()), - signature, + functionHandle, Optional.empty())); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java index 83c81c0e7e8c5..7e630ec4f2307 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteSpatialPartitioningAggregation.java @@ -16,7 +16,7 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; @@ -34,10 +34,9 @@ import java.util.Map; import static com.facebook.presto.SystemSessionProperties.getHashPartitionCount; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; -import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; @@ -62,7 +61,6 @@ public class RewriteSpatialPartitioningAggregation { private static final TypeSignature GEOMETRY_TYPE_SIGNATURE = parseTypeSignature("Geometry"); private static final String NAME = "spatial_partitioning"; - private static final Signature INTERNAL_SIGNATURE = new Signature(NAME, AGGREGATE, VARCHAR.getTypeSignature(), GEOMETRY_TYPE_SIGNATURE, INTEGER.getTypeSignature()); private static final Pattern PATTERN = aggregation() .matching(RewriteSpatialPartitioningAggregation::hasSpatialPartitioningAggregation); @@ -96,9 +94,10 @@ public Result apply(AggregationNode node, Captures captures, Context context) Aggregation aggregation = entry.getValue(); FunctionCall call = aggregation.getCall(); QualifiedName name = call.getName(); + Type geometryType = metadata.getType(GEOMETRY_TYPE_SIGNATURE); if (name.toString().equals(NAME) && call.getArguments().size() == 1) { Expression geometry = getOnlyElement(call.getArguments()); - Symbol envelopeSymbol = context.getSymbolAllocator().newSymbol("envelope", metadata.getType(GEOMETRY_TYPE_SIGNATURE)); + Symbol envelopeSymbol = context.getSymbolAllocator().newSymbol("envelope", geometryType); if (geometry instanceof FunctionCall && ((FunctionCall) geometry).getName().toString().equalsIgnoreCase("ST_Envelope")) { envelopeAssignments.put(envelopeSymbol, geometry); } @@ -108,7 +107,7 @@ public Result apply(AggregationNode node, Captures captures, Context context) aggregations.put(entry.getKey(), new Aggregation( new FunctionCall(name, ImmutableList.of(envelopeSymbol.toSymbolReference(), partitionCountSymbol.toSymbolReference())), - INTERNAL_SIGNATURE, + metadata.getFunctionManager().resolveFunction(context.getSession(), QualifiedName.of(NAME), fromTypes(geometryType, INTEGER)), aggregation.getMask())); } else { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java index 90ae926dfc791..e160e7052f5eb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyCountOverConstant.java @@ -16,8 +16,7 @@ import com.facebook.presto.matching.Capture; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.metadata.Signature; -import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.AggregationNode; @@ -36,11 +35,12 @@ import java.util.Map.Entry; import static com.facebook.presto.matching.Capture.newCapture; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; import static com.facebook.presto.sql.planner.plan.Patterns.project; import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static java.util.Objects.requireNonNull; public class SimplifyCountOverConstant implements Rule @@ -50,6 +50,13 @@ public class SimplifyCountOverConstant private static final Pattern PATTERN = aggregation() .with(source().matching(project().capturedAs(CHILD))); + private final FunctionManager functionManager; + + public SimplifyCountOverConstant(FunctionManager functionManager) + { + this.functionManager = requireNonNull(functionManager, "functionManager is null"); + } + @Override public Pattern getPattern() { @@ -72,7 +79,7 @@ public Result apply(AggregationNode parent, Captures captures, Context context) changed = true; aggregations.put(symbol, new AggregationNode.Aggregation( new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), - new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT)), + functionManager.resolveFunction(context.getSession(), QualifiedName.of("count"), fromTypes(BIGINT)), aggregation.getMask())); } } @@ -94,8 +101,8 @@ public Result apply(AggregationNode parent, Captures captures, Context context) private static boolean isCountOverConstant(AggregationNode.Aggregation aggregation, Assignments inputs) { - Signature signature = aggregation.getSignature(); - if (!signature.getName().equals("count") || signature.getArgumentTypes().size() != 1) { + FunctionCall call = aggregation.getCall(); + if (!call.getName().equals("count") || call.getArguments().size() != 1) { return false; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java index c5ea622135b54..e9214481f68ba 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java @@ -167,7 +167,7 @@ private static AggregationNode.Aggregation removeDistinct(AggregationNode.Aggreg call.getOrderBy(), false, call.getArguments()), - aggregation.getSignature(), + aggregation.getFunctionHandle(), aggregation.getMask()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java index 290963979bf30..67da6f9278ffe 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedInPredicateToJoin.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.Session; import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; -import com.facebook.presto.metadata.FunctionKind; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; @@ -97,6 +97,13 @@ public class TransformCorrelatedInPredicateToJoin private static final Pattern PATTERN = applyNode() .with(nonEmpty(correlation())); + private final FunctionManager functionManager; + + public TransformCorrelatedInPredicateToJoin(FunctionManager functionManager) + { + this.functionManager = requireNonNull(functionManager, "functionManager is null"); + } + @Override public Pattern getPattern() { @@ -118,10 +125,11 @@ public Result apply(ApplyNode apply, Captures captures, Context context) InPredicate inPredicate = (InPredicate) assignmentExpression; Symbol inPredicateOutputSymbol = getOnlyElement(subqueryAssignments.getSymbols()); - return apply(apply, inPredicate, inPredicateOutputSymbol, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator()); + return apply(context.getSession(), apply, inPredicate, inPredicateOutputSymbol, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator()); } private Result apply( + Session session, ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, @@ -137,6 +145,7 @@ private Result apply( } PlanNode projection = buildInPredicateEquivalent( + session, apply, inPredicate, inPredicateOutputSymbol, @@ -148,6 +157,7 @@ private Result apply( } private PlanNode buildInPredicateEquivalent( + Session session, ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, @@ -199,8 +209,8 @@ private PlanNode buildInPredicateEquivalent( idAllocator.getNextId(), leftOuterJoin, ImmutableMap.builder() - .put(countMatchesSymbol, countWithFilter(matchCondition)) - .put(countNullMatchesSymbol, countWithFilter(nullMatchCondition)) + .put(countMatchesSymbol, countWithFilter(session, matchCondition)) + .put(countNullMatchesSymbol, countWithFilter(session, nullMatchCondition)) .build(), singleGroupingSet(probeSide.getOutputSymbols()), ImmutableList.of(), @@ -241,7 +251,7 @@ private static JoinNode leftOuterJoin(PlanNodeIdAllocator idAllocator, AssignUni Optional.empty()); } - private static AggregationNode.Aggregation countWithFilter(Expression condition) + private AggregationNode.Aggregation countWithFilter(Session session, Expression condition) { FunctionCall countCall = new FunctionCall( QualifiedName.of("count"), @@ -253,7 +263,7 @@ private static AggregationNode.Aggregation countWithFilter(Expression condition) return new AggregationNode.Aggregation( countCall, - new Signature("count", FunctionKind.AGGREGATE, BIGINT.getTypeSignature()), + functionManager.resolveFunction(session, QualifiedName.of("count"), ImmutableList.of()), Optional.empty()); /* mask */ } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java index 2919d22916332..c7c545c353052 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformCorrelatedScalarAggregationToJoin.java @@ -96,7 +96,7 @@ public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context return Result.empty(); } - ScalarAggregationToJoinRewriter rewriter = new ScalarAggregationToJoinRewriter(functionManager, context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup()); + ScalarAggregationToJoinRewriter rewriter = new ScalarAggregationToJoinRewriter(functionManager, context.getSession(), context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup()); PlanNode rewrittenNode = rewriter.rewriteScalarAggregation(lateralJoinNode, aggregation.get()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java index 38d133fe46b93..22da0c631605f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.java @@ -16,7 +16,6 @@ import com.facebook.presto.matching.Captures; import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.optimizations.PlanNodeDecorrelator; @@ -81,12 +80,12 @@ public class TransformExistsApplyToLateralNode private static final QualifiedName COUNT = QualifiedName.of("count"); private static final FunctionCall COUNT_CALL = new FunctionCall(COUNT, ImmutableList.of()); - private final Signature countSignature; + + private final FunctionManager functionManager; public TransformExistsApplyToLateralNode(FunctionManager functionManager) { - requireNonNull(functionManager, "functionManager is null"); - countSignature = functionManager.resolveFunction(COUNT, ImmutableList.of()); + this.functionManager = requireNonNull(functionManager, "functionManager is null"); } @Override @@ -162,7 +161,10 @@ private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context) new AggregationNode( context.getIdAllocator().getNextId(), parent.getSubquery(), - ImmutableMap.of(count, new Aggregation(COUNT_CALL, countSignature, Optional.empty())), + ImmutableMap.of(count, new Aggregation( + COUNT_CALL, + functionManager.resolveFunction(context.getSession(), COUNT, ImmutableList.of()), + Optional.empty())), globalAggregation(), ImmutableList.of(), AggregationNode.Step.SINGLE, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java index fef7cc07ba8f0..b1890ab7bcd3c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ImplementIntersectAndExceptAsUnion.java @@ -15,7 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.warnings.WarningCollector; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.ExpressionUtils; @@ -50,10 +50,9 @@ import java.util.Map; import java.util.Optional; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; @@ -105,6 +104,13 @@ public class ImplementIntersectAndExceptAsUnion implements PlanOptimizer { + private final FunctionManager functionManager; + + public ImplementIntersectAndExceptAsUnion(FunctionManager functionManager) + { + this.functionManager = requireNonNull(functionManager, "functionManager is null"); + } + @Override public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) { @@ -114,19 +120,23 @@ public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, Sym requireNonNull(symbolAllocator, "symbolAllocator is null"); requireNonNull(idAllocator, "idAllocator is null"); - return SimplePlanRewriter.rewriteWith(new Rewriter(idAllocator, symbolAllocator), plan); + return SimplePlanRewriter.rewriteWith(new Rewriter(session, functionManager, idAllocator, symbolAllocator), plan); } private static class Rewriter extends SimplePlanRewriter { private static final String MARKER = "marker"; - private static final Signature COUNT_AGGREGATION = new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BOOLEAN)); + + private final Session session; + private final FunctionManager functionManager; private final PlanNodeIdAllocator idAllocator; private final SymbolAllocator symbolAllocator; - private Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) + private Rewriter(Session session, FunctionManager functionManager, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) { + this.session = requireNonNull(session, "session is null"); + this.functionManager = requireNonNull(functionManager, "functionManager is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); } @@ -235,9 +245,10 @@ private AggregationNode computeCounts(UnionNode sourceNode, List origina for (int i = 0; i < markers.size(); i++) { Symbol output = aggregationOutputs.get(i); + QualifiedName name = QualifiedName.of("count"); aggregations.put(output, new Aggregation( - new FunctionCall(QualifiedName.of("count"), ImmutableList.of(markers.get(i).toSymbolReference())), - COUNT_AGGREGATION, + new FunctionCall(name, ImmutableList.of(markers.get(i).toSymbolReference())), + functionManager.resolveFunction(session, name, fromTypes(BIGINT)), Optional.empty())); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 009306f335e20..27f3e94c3dd94 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -16,7 +16,6 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.TypeSignatureProvider; @@ -87,7 +86,7 @@ public OptimizeMixedDistinctAggregations(Metadata metadata) public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) { if (isOptimizeDistinctAggregationEnabled(session)) { - return SimplePlanRewriter.rewriteWith(new Optimizer(idAllocator, symbolAllocator, metadata), plan, Optional.empty()); + return SimplePlanRewriter.rewriteWith(new Optimizer(idAllocator, symbolAllocator, metadata, session), plan, Optional.empty()); } return plan; @@ -99,12 +98,14 @@ private static class Optimizer private final PlanNodeIdAllocator idAllocator; private final SymbolAllocator symbolAllocator; private final Metadata metadata; + private final Session session; - private Optimizer(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Metadata metadata) + private Optimizer(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Metadata metadata, Session session) { this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); + this.session = requireNonNull(session, "session is null"); } @Override @@ -164,20 +165,19 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext createAggregationNode( COUNT, ImmutableList.of(nonNullableAggregationSourceSymbol.toSymbolReference())), functionManager.resolveFunction( + session, COUNT, fromTypeSignatures(scalarAggregationSourceTypeSignatures)), entry.getValue().getMask())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 304ac0ed77887..b14507ffa54da 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -113,7 +113,7 @@ private Aggregation map(Aggregation aggregation) { return new Aggregation( (FunctionCall) map(aggregation.getCall()), - aggregation.getSignature(), + aggregation.getFunctionHandle(), aggregation.getMask().map(this::map)); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java index 4af5e75aeabbe..89df84a030e2e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/TransformQuantifiedComparisonApplyToLateralJoin.java @@ -16,12 +16,11 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.ExpressionUtils; +import com.facebook.presto.sql.analyzer.TypeSignatureProvider; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; @@ -55,7 +54,7 @@ import java.util.function.Function; import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; -import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypeSignatures; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.plan.AggregationNode.globalAggregation; import static com.facebook.presto.sql.planner.plan.SimplePlanRewriter.rewriteWith; import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL; @@ -75,17 +74,17 @@ public class TransformQuantifiedComparisonApplyToLateralJoin implements PlanOptimizer { - private final Metadata metadata; + private final FunctionManager functionManager; - public TransformQuantifiedComparisonApplyToLateralJoin(Metadata metadata) + public TransformQuantifiedComparisonApplyToLateralJoin(FunctionManager functionManager) { - this.metadata = requireNonNull(metadata, "metadata is null"); + this.functionManager = requireNonNull(functionManager, "functionManager is null"); } @Override public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) { - return rewriteWith(new Rewriter(idAllocator, types, symbolAllocator, metadata), plan, null); + return rewriteWith(new Rewriter(functionManager, session, idAllocator, types, symbolAllocator), plan, null); } private static class Rewriter @@ -95,17 +94,19 @@ private static class Rewriter private static final QualifiedName MAX = QualifiedName.of("max"); private static final QualifiedName COUNT = QualifiedName.of("count"); + private final FunctionManager functionManager; + private final Session session; private final PlanNodeIdAllocator idAllocator; private final TypeProvider types; private final SymbolAllocator symbolAllocator; - private final Metadata metadata; - public Rewriter(PlanNodeIdAllocator idAllocator, TypeProvider types, SymbolAllocator symbolAllocator, Metadata metadata) + public Rewriter(FunctionManager functionManager, Session session, PlanNodeIdAllocator idAllocator, TypeProvider types, SymbolAllocator symbolAllocator) { + this.functionManager = requireNonNull(functionManager, "functionManager is null"); + this.session = requireNonNull(session, "session is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.types = requireNonNull(types, "types is null"); this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator is null"); - this.metadata = requireNonNull(metadata, "metadata is null"); } @Override @@ -138,9 +139,8 @@ private PlanNode rewriteQuantifiedApplyNode(ApplyNode node, QuantifiedComparison Symbol countAllValue = symbolAllocator.newSymbol("count_all", BigintType.BIGINT); Symbol countNonNullValue = symbolAllocator.newSymbol("count_non_null", BigintType.BIGINT); - FunctionManager functionManager = metadata.getFunctionManager(); List outputColumnReferences = ImmutableList.of(outputColumn.toSymbolReference()); - List outputColumnTypeSignature = ImmutableList.of(outputColumnType.getTypeSignature()); + List outputColumnTypeSignatures = fromTypes(outputColumnType); subqueryPlan = new AggregationNode( idAllocator.getNextId(), @@ -148,19 +148,19 @@ private PlanNode rewriteQuantifiedApplyNode(ApplyNode node, QuantifiedComparison ImmutableMap.of( minValue, new Aggregation( new FunctionCall(MIN, outputColumnReferences), - functionManager.resolveFunction(MIN, fromTypeSignatures(outputColumnTypeSignature)), + functionManager.resolveFunction(session, MIN, outputColumnTypeSignatures), Optional.empty()), maxValue, new Aggregation( new FunctionCall(MAX, outputColumnReferences), - functionManager.resolveFunction(MAX, fromTypeSignatures(outputColumnTypeSignature)), + functionManager.resolveFunction(session, MAX, outputColumnTypeSignatures), Optional.empty()), countAllValue, new Aggregation( new FunctionCall(COUNT, emptyList()), - functionManager.resolveFunction(COUNT, emptyList()), + functionManager.resolveFunction(session, COUNT, emptyList()), Optional.empty()), countNonNullValue, new Aggregation( new FunctionCall(COUNT, outputColumnReferences), - functionManager.resolveFunction(COUNT, fromTypeSignatures(outputColumnTypeSignature)), + functionManager.resolveFunction(session, COUNT, outputColumnTypeSignatures), Optional.empty())), globalAggregation(), ImmutableList.of(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AggregationNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AggregationNode.java index c28456eadf862..5f97a0e9ccf83 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AggregationNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/AggregationNode.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.sql.planner.plan; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.tree.FunctionCall; @@ -216,7 +216,7 @@ public boolean isDecomposable(FunctionManager functionManager) .anyMatch(FunctionCall::isDistinct); boolean decomposableFunctions = getAggregations().values().stream() - .map(Aggregation::getSignature) + .map(Aggregation::getFunctionHandle) .map(functionManager::getAggregateFunctionImplementation) .allMatch(InternalAggregationFunction::isDecomposable); @@ -359,17 +359,17 @@ public static Step partialInput(Step step) public static class Aggregation { private final FunctionCall call; - private final Signature signature; + private final FunctionHandle functionHandle; private final Optional mask; @JsonCreator public Aggregation( @JsonProperty("call") FunctionCall call, - @JsonProperty("signature") Signature signature, + @JsonProperty("functionHandle") FunctionHandle functionHandle, @JsonProperty("mask") Optional mask) { this.call = requireNonNull(call, "call is null"); - this.signature = requireNonNull(signature, "signature is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); this.mask = requireNonNull(mask, "mask is null"); } @@ -380,9 +380,9 @@ public FunctionCall getCall() } @JsonProperty - public Signature getSignature() + public FunctionHandle getFunctionHandle() { - return signature; + return functionHandle; } @JsonProperty diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregations.java index 463673586228b..b957db4bb7615 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregations.java @@ -13,14 +13,13 @@ */ package com.facebook.presto.sql.planner.plan; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.tree.FunctionCall; -import com.facebook.presto.sql.tree.QualifiedName; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -65,15 +64,15 @@ public Parts createPartialAggregations(SymbolAllocator symbolAllocator, Function ImmutableMap.Builder mappings = ImmutableMap.builder(); for (Map.Entry entry : aggregations.entrySet()) { Aggregation originalAggregation = entry.getValue(); - Signature signature = originalAggregation.getSignature(); - InternalAggregationFunction function = functionManager.getAggregateFunctionImplementation(signature); - Symbol partialSymbol = symbolAllocator.newSymbol(signature.getName(), function.getIntermediateType()); + FunctionHandle functionHandle = originalAggregation.getFunctionHandle(); + InternalAggregationFunction function = functionManager.getAggregateFunctionImplementation(functionHandle); + Symbol partialSymbol = symbolAllocator.newSymbol(originalAggregation.getCall().getName(), function.getIntermediateType()); mappings.put(entry.getKey(), partialSymbol); - partialAggregation.put(partialSymbol, new Aggregation(originalAggregation.getCall(), signature, originalAggregation.getMask())); + partialAggregation.put(partialSymbol, new Aggregation(originalAggregation.getCall(), functionHandle, originalAggregation.getMask())); finalAggregation.put(entry.getKey(), new Aggregation( - new FunctionCall(QualifiedName.of(signature.getName()), ImmutableList.of(partialSymbol.toSymbolReference())), - signature, + new FunctionCall(originalAggregation.getCall().getName(), ImmutableList.of(partialSymbol.toSymbolReference())), + functionHandle, Optional.empty())); } groupingSymbols.forEach(symbol -> mappings.put(symbol, symbol)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java index 2f5e09d74297b..a0f7e7b147015 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/TypeValidator.java @@ -173,7 +173,7 @@ private void checkCall(Symbol symbol, FunctionCall call) private void checkFunctionSignature(Map aggregations) { for (Map.Entry entry : aggregations.entrySet()) { - checkSignature(entry.getKey(), entry.getValue().getSignature()); + checkSignature(entry.getKey(), entry.getValue().getFunctionHandle().getSignature()); } } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java index 05344a6d662dd..4ca515e9ac337 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestCostCalculator.java @@ -24,13 +24,11 @@ import com.facebook.presto.metadata.CatalogManager; import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.metadata.TableLayoutHandle; import com.facebook.presto.security.AllowAllAccessControl; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.TupleDomain; -import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.planner.NodePartitioningManager; @@ -73,10 +71,9 @@ import java.util.Optional; import java.util.function.Function; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.spi.type.BigintType.BIGINT; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.planner.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; @@ -768,7 +765,7 @@ private AggregationNode aggregation(String id, PlanNode source) { AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation( new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), - new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT)), + metadata.getFunctionManager().resolveFunction(TEST_SESSION, QualifiedName.of("count"), ImmutableList.of()), Optional.empty()); return new AggregationNode( diff --git a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java index 30ad15feb2f7a..9256797006401 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java @@ -14,8 +14,8 @@ package com.facebook.presto.operator; import com.facebook.presto.RowPagesBuilder; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.HashAggregationOperator.HashAggregationOperatorFactory; import com.facebook.presto.operator.StreamingAggregationOperator.StreamingAggregationOperatorFactory; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; @@ -26,6 +26,7 @@ import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; import io.airlift.units.DataSize; @@ -53,11 +54,11 @@ import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.BenchmarkHashAndStreamingAggregationOperators.Context.ROWS_PER_PAGE; import static com.facebook.presto.operator.BenchmarkHashAndStreamingAggregationOperators.Context.TOTAL_PAGES; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.units.DataSize.Unit.GIGABYTE; import static io.airlift.units.DataSize.Unit.MEGABYTE; @@ -80,11 +81,12 @@ public class BenchmarkHashAndStreamingAggregationOperators { private static final MetadataManager metadata = MetadataManager.createTestMetadataManager(); + private static final FunctionManager functionManager = metadata.getFunctionManager(); - private static final InternalAggregationFunction LONG_SUM = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("sum", AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature())); - private static final InternalAggregationFunction COUNT = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("count", AGGREGATE, BIGINT.getTypeSignature())); + private static final InternalAggregationFunction LONG_SUM = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("sum"), fromTypes(BIGINT))); + private static final InternalAggregationFunction COUNT = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("count"), ImmutableList.of())); @State(Thread) public static class Context @@ -144,12 +146,12 @@ private OperatorFactory createStreamingAggregationOperatorFactory() AggregationNode.Step.SINGLE, ImmutableList.of(COUNT.bind(ImmutableList.of(0), Optional.empty()), LONG_SUM.bind(ImmutableList.of(1), Optional.empty())), - new JoinCompiler(MetadataManager.createTestMetadataManager(), new FeaturesConfig())); + new JoinCompiler(metadata, new FeaturesConfig())); } private OperatorFactory createHashAggregationOperatorFactory(Optional hashChannel) { - JoinCompiler joinCompiler = new JoinCompiler(MetadataManager.createTestMetadataManager(), new FeaturesConfig()); + JoinCompiler joinCompiler = new JoinCompiler(metadata, new FeaturesConfig()); SpillerFactory spillerFactory = (types, localSpillContext, aggregatedMemoryContext) -> null; return new HashAggregationOperatorFactory( diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestAggregationOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestAggregationOperator.java index f5f6cbc5e5ddf..fa3404d9531fa 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestAggregationOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestAggregationOperator.java @@ -13,14 +13,15 @@ */ package com.facebook.presto.operator; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.AggregationOperator.AggregationOperatorFactory; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.spi.Page; -import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.plan.AggregationNode.Step; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.testing.MaterializedResult; import com.google.common.collect.ImmutableList; import org.testng.annotations.AfterMethod; @@ -34,14 +35,13 @@ import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.OperatorAssertion.assertOperatorEquals; import static com.facebook.presto.operator.OperatorAssertion.toPages; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.RealType.REAL; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static com.google.common.collect.Iterables.getOnlyElement; @@ -56,18 +56,13 @@ @Test(singleThreaded = true) public class TestAggregationOperator { - private static final MetadataManager metadata = MetadataManager.createTestMetadataManager(); - - private static final InternalAggregationFunction LONG_AVERAGE = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("avg", AGGREGATE, DOUBLE.getTypeSignature(), BIGINT.getTypeSignature())); - private static final InternalAggregationFunction DOUBLE_SUM = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("sum", AGGREGATE, DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature())); - private static final InternalAggregationFunction LONG_SUM = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("sum", AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature())); - private static final InternalAggregationFunction REAL_SUM = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("sum", AGGREGATE, REAL.getTypeSignature(), REAL.getTypeSignature())); - private static final InternalAggregationFunction COUNT = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("count", AGGREGATE, BIGINT.getTypeSignature())); + private static final FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); + + private static final InternalAggregationFunction LONG_AVERAGE = getAggregation("avg", BIGINT); + private static final InternalAggregationFunction DOUBLE_SUM = getAggregation("sum", DOUBLE); + private static final InternalAggregationFunction LONG_SUM = getAggregation("sum", BIGINT); + private static final InternalAggregationFunction REAL_SUM = getAggregation("sum", REAL); + private static final InternalAggregationFunction COUNT = getAggregation("count", BIGINT); private ExecutorService executor; private ScheduledExecutorService scheduledExecutor; @@ -89,11 +84,8 @@ public void tearDown() @Test public void testAggregation() { - MetadataManager metadata = MetadataManager.createTestMetadataManager(); - InternalAggregationFunction countVarcharColumn = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.VARCHAR))); - InternalAggregationFunction maxVarcharColumn = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction countVarcharColumn = getAggregation("count", VARCHAR); + InternalAggregationFunction maxVarcharColumn = getAggregation("max", VARCHAR); List input = rowPagesBuilder(VARCHAR, BIGINT, VARCHAR, BIGINT, REAL, DOUBLE, VARCHAR) .addSequencePage(100, 0, 0, 300, 500, 400, 500, 500) .build(); @@ -169,4 +161,9 @@ private void testMemoryTracking(boolean useSystemMemory) assertEquals(driverContext.getSystemMemoryUsage(), 0); assertEquals(driverContext.getMemoryUsage(), 0); } + + private static InternalAggregationFunction getAggregation(String name, Type... arguments) + { + return functionManager.getAggregateFunctionImplementation(functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(name), fromTypes(arguments))); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java index c08e42a9d79fb..514ac83fdf97b 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java @@ -16,8 +16,8 @@ import com.facebook.presto.ExceededMemoryLimitException; import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.memory.context.AggregatedMemoryContext; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.HashAggregationOperator.HashAggregationOperatorFactory; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.operator.aggregation.builder.HashAggregationBuilder; @@ -25,7 +25,6 @@ import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.PageBuilderStatus; -import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spiller.Spiller; import com.facebook.presto.spiller.SpillerFactory; @@ -33,6 +32,7 @@ import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.planner.plan.AggregationNode.Step; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; @@ -56,7 +56,6 @@ import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.GroupByHashYieldAssertion.GroupByHashYieldResult; import static com.facebook.presto.operator.GroupByHashYieldAssertion.createPagesWithDistinctHashKeys; import static com.facebook.presto.operator.GroupByHashYieldAssertion.finishOperatorWithYieldingGroupByHash; @@ -67,8 +66,8 @@ import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static com.google.common.base.Strings.nullToEmpty; @@ -96,14 +95,12 @@ @Test(singleThreaded = true) public class TestHashAggregationOperator { - private static final MetadataManager metadata = MetadataManager.createTestMetadataManager(); + private static final FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); - private static final InternalAggregationFunction LONG_AVERAGE = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("avg", AGGREGATE, DOUBLE.getTypeSignature(), BIGINT.getTypeSignature())); - private static final InternalAggregationFunction LONG_SUM = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("sum", AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature())); - private static final InternalAggregationFunction COUNT = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("count", AGGREGATE, BIGINT.getTypeSignature())); + private static final InternalAggregationFunction LONG_AVERAGE = getAggregation("avg", BIGINT); + private static final InternalAggregationFunction LONG_SUM = getAggregation("sum", BIGINT); + private static final InternalAggregationFunction COUNT = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("count"), ImmutableList.of())); private static final int MAX_BLOCK_SIZE_IN_BYTES = 64 * 1024; @@ -154,13 +151,9 @@ public void tearDown() @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues") public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) { - MetadataManager metadata = MetadataManager.createTestMetadataManager(); - InternalAggregationFunction countVarcharColumn = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.VARCHAR))); - InternalAggregationFunction countBooleanColumn = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BOOLEAN))); - InternalAggregationFunction maxVarcharColumn = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction countVarcharColumn = getAggregation("count", VARCHAR); + InternalAggregationFunction countBooleanColumn = getAggregation("count", BOOLEAN); + InternalAggregationFunction maxVarcharColumn = getAggregation("max", VARCHAR); List hashChannels = Ints.asList(1); RowPagesBuilder rowPagesBuilder = rowPagesBuilder(hashEnabled, hashChannels, VARCHAR, VARCHAR, VARCHAR, BIGINT, BOOLEAN); List input = rowPagesBuilder @@ -216,13 +209,9 @@ public void testHashAggregation(boolean hashEnabled, boolean spillEnabled, long @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues") public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEnabled, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) { - MetadataManager metadata = MetadataManager.createTestMetadataManager(); - InternalAggregationFunction countVarcharColumn = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.VARCHAR))); - InternalAggregationFunction countBooleanColumn = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BOOLEAN))); - InternalAggregationFunction maxVarcharColumn = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction countVarcharColumn = getAggregation("count", VARCHAR); + InternalAggregationFunction countBooleanColumn = getAggregation("count", BOOLEAN); + InternalAggregationFunction maxVarcharColumn = getAggregation("max", VARCHAR); Optional groupIdChannel = Optional.of(1); List groupByChannels = Ints.asList(1, 2); @@ -267,9 +256,7 @@ public void testHashAggregationWithGlobals(boolean hashEnabled, boolean spillEna @Test(dataProvider = "hashEnabledAndMemoryLimitForMergeValues") public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean spillEnabled, long memoryLimitForMerge, long memoryLimitForMergeWithMemory) { - MetadataManager metadata = MetadataManager.createTestMetadataManager(); - InternalAggregationFunction arrayAggColumn = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("array_agg", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction arrayAggColumn = getAggregation("array_agg", BIGINT); List hashChannels = Ints.asList(1); RowPagesBuilder rowPagesBuilder = rowPagesBuilder(hashEnabled, hashChannels, BIGINT, BIGINT); @@ -311,9 +298,7 @@ public void testHashAggregationMemoryReservation(boolean hashEnabled, boolean sp @Test(dataProvider = "hashEnabled", expectedExceptions = ExceededMemoryLimitException.class, expectedExceptionsMessageRegExp = "Query exceeded per-node user memory limit of 10B.*") public void testMemoryLimit(boolean hashEnabled) { - MetadataManager metadata = MetadataManager.createTestMetadataManager(); - InternalAggregationFunction maxVarcharColumn = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction maxVarcharColumn = getAggregation("max", VARCHAR); List hashChannels = Ints.asList(1); RowPagesBuilder rowPagesBuilder = rowPagesBuilder(hashEnabled, hashChannels, VARCHAR, BIGINT, VARCHAR, BIGINT); @@ -623,9 +608,7 @@ public void testMergeWithMemorySpill() @Test public void testSpillerFailure() { - MetadataManager metadata = MetadataManager.createTestMetadataManager(); - InternalAggregationFunction maxVarcharColumn = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction maxVarcharColumn = getAggregation("max", VARCHAR); List hashChannels = Ints.asList(1); ImmutableList types = ImmutableList.of(VARCHAR, BIGINT, VARCHAR, BIGINT); @@ -754,6 +737,11 @@ private int getHashCapacity(Operator operator) return ((InMemoryHashAggregationBuilder) aggregationBuilder).getCapacity(); } + private static InternalAggregationFunction getAggregation(String name, Type... arguments) + { + return functionManager.getAggregateFunctionImplementation(functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(name), fromTypes(arguments))); + } + private static class DummySpillerFactory implements SpillerFactory { diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestRealAverageAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/TestRealAverageAggregation.java index 961d3cf894395..c1f96e31bec8a 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestRealAverageAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestRealAverageAggregation.java @@ -13,24 +13,25 @@ */ package com.facebook.presto.operator; -import com.facebook.presto.metadata.FunctionKind; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.AbstractTestAggregationFunction; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.util.List; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; import static com.facebook.presto.spi.type.RealType.REAL; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static java.lang.Float.floatToRawIntBits; @Test(singleThreaded = true) @@ -42,9 +43,9 @@ public class TestRealAverageAggregation @BeforeClass public void setUp() { - MetadataManager metadata = MetadataManager.createTestMetadataManager(); - avgFunction = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("avg", FunctionKind.AGGREGATE, parseTypeSignature(StandardTypes.REAL), parseTypeSignature(StandardTypes.REAL))); + FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); + avgFunction = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("avg"), fromTypes(REAL))); } @Test diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestStreamingAggregationOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestStreamingAggregationOperator.java index b549e4a363dc4..3172c847399ac 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestStreamingAggregationOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestStreamingAggregationOperator.java @@ -14,8 +14,8 @@ package com.facebook.presto.operator; import com.facebook.presto.RowPagesBuilder; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.StreamingAggregationOperator.StreamingAggregationOperatorFactory; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.spi.Page; @@ -23,6 +23,7 @@ import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.testing.MaterializedResult; import com.google.common.collect.ImmutableList; import org.testng.annotations.AfterMethod; @@ -35,11 +36,11 @@ import java.util.concurrent.ScheduledExecutorService; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.OperatorAssertion.assertOperatorEquals; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static io.airlift.concurrent.Threads.daemonThreadsNamed; @@ -50,12 +51,12 @@ @Test(singleThreaded = true) public class TestStreamingAggregationOperator { - private static final MetadataManager metadata = MetadataManager.createTestMetadataManager(); + private static final FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); - private static final InternalAggregationFunction LONG_SUM = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("sum", AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature())); - private static final InternalAggregationFunction COUNT = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("count", AGGREGATE, BIGINT.getTypeSignature())); + private static final InternalAggregationFunction LONG_SUM = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("sum"), fromTypes(BIGINT))); + private static final InternalAggregationFunction COUNT = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("count"), ImmutableList.of())); private ExecutorService executor; private ScheduledExecutorService scheduledExecutor; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestTableFinishOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestTableFinishOperator.java index 9c77155262544..a56b7322da32f 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestTableFinishOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestTableFinishOperator.java @@ -14,7 +14,7 @@ package com.facebook.presto.operator; import com.facebook.presto.Session; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.operator.TableFinishOperator.TableFinishOperatorFactory; import com.facebook.presto.operator.TableFinishOperator.TableFinisher; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; @@ -27,6 +27,7 @@ import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.StatisticAggregationsDescriptor; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; @@ -41,13 +42,14 @@ import java.util.concurrent.ScheduledExecutorService; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.assertBlockEquals; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.operator.PageAssertions.assertPageEquals; import static com.facebook.presto.spi.statistics.ColumnStatisticType.MAX_VALUE; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static com.google.common.base.Preconditions.checkState; @@ -63,8 +65,9 @@ public class TestTableFinishOperator { - private static final InternalAggregationFunction LONG_MAX = createTestMetadataManager().getFunctionManager().getAggregateFunctionImplementation( - new Signature("max", AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature())); + private static final FunctionManager functionManager = createTestMetadataManager().getFunctionManager(); + private static final InternalAggregationFunction LONG_MAX = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("max"), fromTypes(BIGINT))); private ScheduledExecutorService scheduledExecutor; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestTableWriterOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestTableWriterOperator.java index ae8656843cde2..174dd72b5af77 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestTableWriterOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestTableWriterOperator.java @@ -17,8 +17,8 @@ import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.memory.context.MemoryTrackingContext; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.OutputTableHandle; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.AggregationOperator.AggregationOperatorFactory; import com.facebook.presto.operator.DevNullOperator.DevNullOperatorFactory; import com.facebook.presto.operator.TableWriterOperator.TableWriterInfo; @@ -37,6 +37,7 @@ import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.TableWriterNode; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import org.testng.annotations.AfterClass; @@ -53,11 +54,11 @@ import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.operator.PageAssertions.assertPageEquals; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static io.airlift.concurrent.Threads.daemonThreadsNamed; @@ -74,8 +75,6 @@ public class TestTableWriterOperator { private static final ConnectorId CONNECTOR_ID = new ConnectorId("testConnectorId"); - private static final InternalAggregationFunction LONG_MAX = createTestMetadataManager().getFunctionManager().getAggregateFunctionImplementation( - new Signature("max", AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature())); private ExecutorService executor; private ScheduledExecutorService scheduledExecutor; @@ -201,13 +200,16 @@ public void testStatisticsAggregation() DriverContext driverContext = createTaskContext(executor, scheduledExecutor, session) .addPipelineContext(0, true, true, false) .addDriverContext(); + FunctionManager functionManager = createTestMetadataManager().getFunctionManager(); + InternalAggregationFunction longMaxFunction = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("max"), fromTypes(BIGINT))); TableWriterOperator operator = (TableWriterOperator) createTableWriterOperator( pageSinkManager, new AggregationOperatorFactory( 1, new PlanNodeId("test"), AggregationNode.Step.SINGLE, - ImmutableList.of(LONG_MAX.bind(ImmutableList.of(0), Optional.empty())), + ImmutableList.of(longMaxFunction.bind(ImmutableList.of(0), Optional.empty())), true), outputTypes, session, diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestAggregationFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestAggregationFunction.java index e1a71980f6afe..3d234e7e893a0 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestAggregationFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestAggregationFunction.java @@ -13,9 +13,10 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.Session; import com.facebook.presto.block.BlockEncodingManager; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.Plugin; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; @@ -36,17 +37,20 @@ import static com.facebook.presto.metadata.FunctionExtractor.extractFunctions; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypeSignatures; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; public abstract class AbstractTestAggregationFunction { protected TypeRegistry typeRegistry; protected FunctionManager functionManager; + protected Session session; @BeforeClass public final void initTestAggregationFunction() { typeRegistry = new TypeRegistry(); functionManager = new FunctionManager(typeRegistry, new BlockEncodingManager(typeRegistry), new FeaturesConfig()); + session = testSessionBuilder().build(); } @AfterClass(alwaysRun = true) @@ -73,8 +77,8 @@ protected void registerTypes(Plugin plugin) protected final InternalAggregationFunction getFunction() { List parameterTypes = fromTypeSignatures(Lists.transform(getFunctionParameterTypes(), TypeSignature::parseTypeSignature)); - Signature signature = functionManager.resolveFunction(QualifiedName.of(getFunctionName()), parameterTypes); - return functionManager.getAggregateFunctionImplementation(signature); + FunctionHandle functionHandle = functionManager.resolveFunction(session, QualifiedName.of(getFunctionName()), parameterTypes); + return functionManager.getAggregateFunctionImplementation(functionHandle); } protected abstract String getFunctionName(); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestApproximateCountDistinct.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestApproximateCountDistinct.java index 150662215c296..4738b828d25b8 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestApproximateCountDistinct.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/AbstractTestApproximateCountDistinct.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; @@ -46,7 +47,7 @@ public abstract class AbstractTestApproximateCountDistinct public abstract Object randomValue(); - protected static final MetadataManager metadata = MetadataManager.createTestMetadataManager(); + protected static final FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); protected int getUniqueValuesCount() { diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkArrayAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkArrayAggregation.java index 8e9b561c6a67e..3181c6968a083 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkArrayAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkArrayAggregation.java @@ -13,13 +13,14 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; import org.openjdk.jmh.annotations.Benchmark; @@ -44,11 +45,12 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static org.openjdk.jmh.annotations.Level.Invocation; @SuppressWarnings("MethodMayBeStatic") @@ -84,7 +86,7 @@ public static class BenchmarkData @Setup(Invocation) public void setup() { - MetadataManager metadata = MetadataManager.createTestMetadataManager(); + FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); Block block; Type elementType; switch (type) { @@ -104,8 +106,8 @@ public void setup() throw new UnsupportedOperationException(); } ArrayType arrayType = new ArrayType(elementType); - Signature signature = new Signature(name, AGGREGATE, arrayType.getTypeSignature(), elementType.getTypeSignature()); - InternalAggregationFunction function = metadata.getFunctionManager().getAggregateFunctionImplementation(signature); + InternalAggregationFunction function = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(name), fromTypes(elementType))); accumulator = function.bind(ImmutableList.of(0), Optional.empty()).createAccumulator(); block = createChannel(ARRAY_SIZE, elementType); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkGroupedTypedHistogram.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkGroupedTypedHistogram.java index d861fc9b1fda9..791071aa0bec3 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkGroupedTypedHistogram.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/BenchmarkGroupedTypedHistogram.java @@ -13,16 +13,15 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.GroupByIdBlock; import com.facebook.presto.operator.aggregation.groupByAggregations.GroupByAggregationTestUtils; import com.facebook.presto.operator.aggregation.histogram.HistogramGroupImplementation; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; -import com.facebook.presto.spi.type.MapType; -import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.tree.QualifiedName; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.Fork; import org.openjdk.jmh.annotations.Measurement; @@ -48,13 +47,11 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createStringsBlock; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.aggregation.histogram.Histogram.NAME; -import static com.facebook.presto.spi.type.BigintType.BIGINT; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; -import static com.facebook.presto.util.StructuralTestUtil.mapType; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; @OutputTimeUnit(TimeUnit.SECONDS) //@BenchmarkMode(Mode.AverageTime) @@ -159,14 +156,10 @@ public GroupedAccumulator testSharedGroupWithLargeBlocksRunner(Data data) private static InternalAggregationFunction getInternalAggregationFunctionVarChar(HistogramGroupImplementation groupMode) { - MapType mapType = mapType(VARCHAR, BIGINT); - MetadataManager metadata = getMetadata(groupMode); - - return metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.VARCHAR))); + FunctionManager functionManager = getMetadata(groupMode).getFunctionManager(); + + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(NAME), fromTypes(VARCHAR))); } private static MetadataManager getMetadata(HistogramGroupImplementation groupMode) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctBoolean.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctBoolean.java index b7ad249817ed3..79628906387df 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctBoolean.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctBoolean.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.operator.aggregation; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Booleans; @@ -24,10 +24,10 @@ import java.util.List; import java.util.concurrent.ThreadLocalRandom; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; -import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; public class TestApproximateCountDistinctBoolean extends AbstractTestApproximateCountDistinct @@ -35,8 +35,8 @@ public class TestApproximateCountDistinctBoolean @Override public InternalAggregationFunction getAggregationFunction() { - return metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_distinct", AGGREGATE, BIGINT.getTypeSignature(), BOOLEAN.getTypeSignature(), DOUBLE.getTypeSignature())); + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("approx_distinct"), fromTypes(BOOLEAN, DOUBLE))); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctDouble.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctDouble.java index ab55316b9fd03..5ea1caa6104e1 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctDouble.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctDouble.java @@ -13,14 +13,14 @@ */ package com.facebook.presto.operator.aggregation; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import java.util.concurrent.ThreadLocalRandom; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; -import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; public class TestApproximateCountDistinctDouble extends AbstractTestApproximateCountDistinct @@ -28,8 +28,8 @@ public class TestApproximateCountDistinctDouble @Override public InternalAggregationFunction getAggregationFunction() { - return metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_distinct", AGGREGATE, BIGINT.getTypeSignature(), DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature())); + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("approx_distinct"), fromTypes(DOUBLE, DOUBLE))); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctInteger.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctInteger.java index 27202a887eb77..86f1d5afa71a8 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctInteger.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctInteger.java @@ -13,15 +13,15 @@ */ package com.facebook.presto.operator.aggregation; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import java.util.concurrent.ThreadLocalRandom; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; -import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; public class TestApproximateCountDistinctInteger extends AbstractTestApproximateCountDistinct @@ -29,8 +29,8 @@ public class TestApproximateCountDistinctInteger @Override public InternalAggregationFunction getAggregationFunction() { - return metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_distinct", AGGREGATE, BIGINT.getTypeSignature(), INTEGER.getTypeSignature(), DOUBLE.getTypeSignature())); + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("approx_distinct"), fromTypes(INTEGER, DOUBLE))); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctIpAddress.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctIpAddress.java index 4e9008cd32c3a..354e6eed8ae93 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctIpAddress.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctIpAddress.java @@ -13,15 +13,15 @@ */ package com.facebook.presto.operator.aggregation; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import io.airlift.slice.Slices; import java.util.concurrent.ThreadLocalRandom; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; -import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.type.IpAddressType.IPADDRESS; public class TestApproximateCountDistinctIpAddress @@ -30,8 +30,8 @@ public class TestApproximateCountDistinctIpAddress @Override public InternalAggregationFunction getAggregationFunction() { - return metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_distinct", AGGREGATE, BIGINT.getTypeSignature(), IPADDRESS.getTypeSignature(), DOUBLE.getTypeSignature())); + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("approx_distinct"), fromTypes(IPADDRESS, DOUBLE))); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctLong.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctLong.java index 4434503b1e7bf..fbd8eb8135676 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctLong.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctLong.java @@ -13,14 +13,15 @@ */ package com.facebook.presto.operator.aggregation; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import java.util.concurrent.ThreadLocalRandom; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; public class TestApproximateCountDistinctLong extends AbstractTestApproximateCountDistinct @@ -28,8 +29,8 @@ public class TestApproximateCountDistinctLong @Override public InternalAggregationFunction getAggregationFunction() { - return metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_distinct", AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature(), DOUBLE.getTypeSignature())); + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("approx_distinct"), fromTypes(BIGINT, DOUBLE))); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctLongDecimal.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctLongDecimal.java index 2391dd03b7ab8..db4a85daefe5e 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctLongDecimal.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctLongDecimal.java @@ -13,17 +13,17 @@ */ package com.facebook.presto.operator.aggregation; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import io.airlift.slice.Slices; import java.util.concurrent.ThreadLocalRandom; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; -import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.DecimalType.createDecimalType; import static com.facebook.presto.spi.type.Decimals.MAX_PRECISION; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; public class TestApproximateCountDistinctLongDecimal extends AbstractTestApproximateCountDistinct @@ -33,8 +33,8 @@ public class TestApproximateCountDistinctLongDecimal @Override public InternalAggregationFunction getAggregationFunction() { - return metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_distinct", AGGREGATE, BIGINT.getTypeSignature(), LONG_DECIMAL.getTypeSignature(), DOUBLE.getTypeSignature())); + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("approx_distinct"), fromTypes(LONG_DECIMAL, DOUBLE))); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctSmallint.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctSmallint.java index 20917b56cd507..ee4d1a3351217 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctSmallint.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctSmallint.java @@ -13,15 +13,15 @@ */ package com.facebook.presto.operator.aggregation; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import java.util.concurrent.ThreadLocalRandom; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; -import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.SmallintType.SMALLINT; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; public class TestApproximateCountDistinctSmallint extends AbstractTestApproximateCountDistinct @@ -29,8 +29,8 @@ public class TestApproximateCountDistinctSmallint @Override public InternalAggregationFunction getAggregationFunction() { - return metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_distinct", AGGREGATE, BIGINT.getTypeSignature(), SMALLINT.getTypeSignature(), DOUBLE.getTypeSignature())); + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("approx_distinct"), fromTypes(SMALLINT, DOUBLE))); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctTinyint.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctTinyint.java index 5816ed8aae9cf..973c21d3375c2 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctTinyint.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctTinyint.java @@ -13,15 +13,15 @@ */ package com.facebook.presto.operator.aggregation; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import java.util.concurrent.ThreadLocalRandom; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; -import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.TinyintType.TINYINT; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; public class TestApproximateCountDistinctTinyint extends AbstractTestApproximateCountDistinct @@ -29,8 +29,8 @@ public class TestApproximateCountDistinctTinyint @Override public InternalAggregationFunction getAggregationFunction() { - return metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_distinct", AGGREGATE, BIGINT.getTypeSignature(), TINYINT.getTypeSignature(), DOUBLE.getTypeSignature())); + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("approx_distinct"), fromTypes(TINYINT, DOUBLE))); } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctVarBinary.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctVarBinary.java index a12ff62342f7d..6fbf85ab9b668 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctVarBinary.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximateCountDistinctVarBinary.java @@ -13,17 +13,16 @@ */ package com.facebook.presto.operator.aggregation; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.VarcharType; +import com.facebook.presto.sql.tree.QualifiedName; import io.airlift.slice.Slices; import java.util.concurrent.ThreadLocalRandom; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; -import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; public class TestApproximateCountDistinctVarBinary extends AbstractTestApproximateCountDistinct @@ -31,14 +30,14 @@ public class TestApproximateCountDistinctVarBinary @Override public InternalAggregationFunction getAggregationFunction() { - return metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_distinct", AGGREGATE, BIGINT.getTypeSignature(), parseTypeSignature("varchar"), DOUBLE.getTypeSignature())); + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("approx_distinct"), fromTypes(VARCHAR, DOUBLE))); } @Override public Type getValueType() { - return VarcharType.VARCHAR; + return VARCHAR; } @Override diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximatePercentileAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximatePercentileAggregation.java index 4bd186825b9f9..fcdd954ecb1dd 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximatePercentileAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestApproximatePercentileAggregation.java @@ -13,14 +13,17 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.RunLengthEncodedBlock; import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; import static com.facebook.presto.block.BlockAssertions.createDoubleSequenceBlock; import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; @@ -28,57 +31,35 @@ import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock; import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static com.facebook.presto.block.BlockAssertions.createSequenceBlockOfReal; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.RealType.REAL; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; public class TestApproximatePercentileAggregation { - private static final MetadataManager metadata = MetadataManager.createTestMetadataManager(); - - private static final InternalAggregationFunction DOUBLE_APPROXIMATE_PERCENTILE_AGGREGATION = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature())); - private static final InternalAggregationFunction DOUBLE_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature(), BIGINT.getTypeSignature(), DOUBLE.getTypeSignature())); - private static final InternalAggregationFunction DOUBLE_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION_WITH_ACCURACY = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature(), BIGINT.getTypeSignature(), DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature())); - - private static final InternalAggregationFunction LONG_APPROXIMATE_PERCENTILE_AGGREGATION = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature(), DOUBLE.getTypeSignature())); - private static final InternalAggregationFunction LONG_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature(), BIGINT.getTypeSignature(), DOUBLE.getTypeSignature())); - private static final InternalAggregationFunction LONG_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION_WITH_ACCURACY = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, BIGINT.getTypeSignature(), BIGINT.getTypeSignature(), BIGINT.getTypeSignature(), DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature())); - - private static final InternalAggregationFunction DOUBLE_APPROXIMATE_PERCENTILE_ARRAY_AGGREGATION = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, parseTypeSignature("array(double)"), DOUBLE.getTypeSignature(), parseTypeSignature("array(double)"))); - private static final InternalAggregationFunction DOUBLE_APPROXIMATE_PERCENTILE_ARRAY_WEIGHTED_AGGREGATION = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, parseTypeSignature("array(double)"), DOUBLE.getTypeSignature(), BIGINT.getTypeSignature(), parseTypeSignature("array(double)"))); - - private static final InternalAggregationFunction LONG_APPROXIMATE_PERCENTILE_ARRAY_AGGREGATION = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, parseTypeSignature("array(bigint)"), BIGINT.getTypeSignature(), parseTypeSignature("array(double)"))); - private static final InternalAggregationFunction LONG_APPROXIMATE_PERCENTILE_ARRAY_WEIGHTED_AGGREGATION = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, parseTypeSignature("array(bigint)"), BIGINT.getTypeSignature(), BIGINT.getTypeSignature(), parseTypeSignature("array(double)"))); - - private static final InternalAggregationFunction FLOAT_APPROXIMATE_PERCENTILE_AGGREGATION = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, REAL.getTypeSignature(), - ImmutableList.of(REAL.getTypeSignature(), DOUBLE.getTypeSignature()))); - private static final InternalAggregationFunction FLOAT_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, REAL.getTypeSignature(), - ImmutableList.of(REAL.getTypeSignature(), BIGINT.getTypeSignature(), DOUBLE.getTypeSignature()))); - private static final InternalAggregationFunction FLOAT_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION_WITH_ACCURACY = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, REAL.getTypeSignature(), - ImmutableList.of(REAL.getTypeSignature(), BIGINT.getTypeSignature(), DOUBLE.getTypeSignature(), DOUBLE.getTypeSignature()))); - - private static final InternalAggregationFunction FLOAT_APPROXIMATE_PERCENTILE_ARRAY_AGGREGATION = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, new ArrayType(REAL).getTypeSignature(), - ImmutableList.of(REAL.getTypeSignature(), new ArrayType(DOUBLE).getTypeSignature()))); - private static final InternalAggregationFunction FLOAT_APPROXIMATE_PERCENTILE_ARRAY_WEIGHTED_AGGREGATION = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("approx_percentile", AGGREGATE, new ArrayType(REAL).getTypeSignature(), - ImmutableList.of(REAL.getTypeSignature(), BIGINT.getTypeSignature(), new ArrayType(DOUBLE).getTypeSignature()))); + private static final FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); + + private static final InternalAggregationFunction DOUBLE_APPROXIMATE_PERCENTILE_AGGREGATION = getAggregation(DOUBLE, DOUBLE); + private static final InternalAggregationFunction DOUBLE_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION = getAggregation(DOUBLE, BIGINT, DOUBLE); + private static final InternalAggregationFunction DOUBLE_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION_WITH_ACCURACY = getAggregation(DOUBLE, BIGINT, DOUBLE, DOUBLE); + + private static final InternalAggregationFunction LONG_APPROXIMATE_PERCENTILE_AGGREGATION = getAggregation(BIGINT, DOUBLE); + private static final InternalAggregationFunction LONG_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION = getAggregation(BIGINT, BIGINT, DOUBLE); + private static final InternalAggregationFunction LONG_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION_WITH_ACCURACY = getAggregation(BIGINT, BIGINT, DOUBLE, DOUBLE); + + private static final InternalAggregationFunction DOUBLE_APPROXIMATE_PERCENTILE_ARRAY_AGGREGATION = getAggregation(DOUBLE, new ArrayType(DOUBLE)); + private static final InternalAggregationFunction DOUBLE_APPROXIMATE_PERCENTILE_ARRAY_WEIGHTED_AGGREGATION = getAggregation(DOUBLE, BIGINT, new ArrayType(DOUBLE)); + + private static final InternalAggregationFunction LONG_APPROXIMATE_PERCENTILE_ARRAY_AGGREGATION = getAggregation(BIGINT, new ArrayType(DOUBLE)); + private static final InternalAggregationFunction LONG_APPROXIMATE_PERCENTILE_ARRAY_WEIGHTED_AGGREGATION = getAggregation(BIGINT, BIGINT, new ArrayType(DOUBLE)); + + private static final InternalAggregationFunction FLOAT_APPROXIMATE_PERCENTILE_AGGREGATION = getAggregation(REAL, DOUBLE); + private static final InternalAggregationFunction FLOAT_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION = getAggregation(REAL, BIGINT, DOUBLE); + private static final InternalAggregationFunction FLOAT_APPROXIMATE_PERCENTILE_WEIGHTED_AGGREGATION_WITH_ACCURACY = getAggregation(REAL, BIGINT, DOUBLE, DOUBLE); + private static final InternalAggregationFunction FLOAT_APPROXIMATE_PERCENTILE_ARRAY_AGGREGATION = getAggregation(REAL, new ArrayType(DOUBLE)); + private static final InternalAggregationFunction FLOAT_APPROXIMATE_PERCENTILE_ARRAY_WEIGHTED_AGGREGATION = getAggregation(REAL, BIGINT, new ArrayType(DOUBLE)); @Test public void testLongPartialStep() @@ -462,6 +443,11 @@ public void testDoublePartialStep() createRLEBlock(ImmutableList.of(0.5, 0.8), 3)); } + private static InternalAggregationFunction getAggregation(Type... arguments) + { + return functionManager.getAggregateFunctionImplementation(functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("approx_percentile"), fromTypes(arguments))); + } + private static RunLengthEncodedBlock createRLEBlock(double percentile, int positionCount) { BlockBuilder blockBuilder = DOUBLE.createBlockBuilder(null, 1); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArbitraryAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArbitraryAggregation.java index 442d34a788be8..18f5ed1a26a8a 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArbitraryAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArbitraryAggregation.java @@ -13,31 +13,38 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; -import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; import java.util.Arrays; import java.util.Set; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createArrayBigintBlock; import static com.facebook.presto.block.BlockAssertions.createBooleansBlock; import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; import static com.facebook.presto.block.BlockAssertions.createIntsBlock; import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static com.facebook.presto.block.BlockAssertions.createStringsBlock; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static org.testng.Assert.assertNotNull; public class TestArbitraryAggregation { private static final MetadataManager metadata = MetadataManager.createTestMetadataManager(); + private static final FunctionManager functionManager = metadata.getFunctionManager(); @Test public void testAllRegistered() @@ -45,15 +52,14 @@ public void testAllRegistered() Set allTypes = metadata.getTypeManager().getTypes().stream().collect(toImmutableSet()); for (Type valueType : allTypes) { - assertNotNull(metadata.getFunctionManager().getAggregateFunctionImplementation(new Signature("arbitrary", AGGREGATE, valueType.getTypeSignature(), valueType.getTypeSignature()))); + assertNotNull(getAggregation(valueType)); } } @Test public void testNullBoolean() { - InternalAggregationFunction booleanAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("arbitrary", AGGREGATE, parseTypeSignature(StandardTypes.BOOLEAN), parseTypeSignature(StandardTypes.BOOLEAN))); + InternalAggregationFunction booleanAgg = getAggregation(BOOLEAN); assertAggregation( booleanAgg, null, @@ -63,8 +69,7 @@ public void testNullBoolean() @Test public void testValidBoolean() { - InternalAggregationFunction booleanAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("arbitrary", AGGREGATE, parseTypeSignature(StandardTypes.BOOLEAN), parseTypeSignature(StandardTypes.BOOLEAN))); + InternalAggregationFunction booleanAgg = getAggregation(BOOLEAN); assertAggregation( booleanAgg, true, @@ -74,8 +79,7 @@ public void testValidBoolean() @Test public void testNullLong() { - InternalAggregationFunction longAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("arbitrary", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction longAgg = getAggregation(BIGINT); assertAggregation( longAgg, null, @@ -85,8 +89,7 @@ public void testNullLong() @Test public void testValidLong() { - InternalAggregationFunction longAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("arbitrary", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction longAgg = getAggregation(BIGINT); assertAggregation( longAgg, 1L, @@ -96,8 +99,7 @@ public void testValidLong() @Test public void testNullDouble() { - InternalAggregationFunction doubleAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("arbitrary", AGGREGATE, parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction doubleAgg = getAggregation(DOUBLE); assertAggregation( doubleAgg, null, @@ -107,8 +109,7 @@ public void testNullDouble() @Test public void testValidDouble() { - InternalAggregationFunction doubleAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("arbitrary", AGGREGATE, parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction doubleAgg = getAggregation(DOUBLE); assertAggregation( doubleAgg, 2.0, @@ -118,8 +119,7 @@ public void testValidDouble() @Test public void testNullString() { - InternalAggregationFunction stringAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("arbitrary", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction stringAgg = getAggregation(VARCHAR); assertAggregation( stringAgg, null, @@ -129,8 +129,7 @@ public void testNullString() @Test public void testValidString() { - InternalAggregationFunction stringAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("arbitrary", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction stringAgg = getAggregation(VARCHAR); assertAggregation( stringAgg, "a", @@ -140,8 +139,7 @@ public void testValidString() @Test public void testNullArray() { - InternalAggregationFunction arrayAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("arbitrary", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"))); + InternalAggregationFunction arrayAgg = getAggregation(new ArrayType(BIGINT)); assertAggregation( arrayAgg, null, @@ -151,8 +149,7 @@ public void testNullArray() @Test public void testValidArray() { - InternalAggregationFunction arrayAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("arbitrary", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"))); + InternalAggregationFunction arrayAgg = getAggregation(new ArrayType(BIGINT)); assertAggregation( arrayAgg, ImmutableList.of(23L, 45L), @@ -162,11 +159,15 @@ public void testValidArray() @Test public void testValidInt() { - InternalAggregationFunction arrayAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("arbitrary", AGGREGATE, parseTypeSignature("integer"), parseTypeSignature("integer"))); + InternalAggregationFunction intAgg = getAggregation(INTEGER); assertAggregation( - arrayAgg, + intAgg, 3, createIntsBlock(3, 3, null)); } + + private static InternalAggregationFunction getAggregation(Type... arguments) + { + return functionManager.getAggregateFunctionImplementation(functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("arbitrary"), fromTypes(arguments))); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayAggregation.java index a6d993dbf854c..7d2b868606774 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestArrayAggregation.java @@ -13,16 +13,18 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.groupByAggregations.AggregationTestInput; import com.facebook.presto.operator.aggregation.groupByAggregations.AggregationTestInputBuilder; import com.facebook.presto.operator.aggregation.groupByAggregations.AggregationTestOutput; import com.facebook.presto.operator.aggregation.groupByAggregations.GroupByAggregationTestUtils; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.SqlDate; -import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; import org.testng.internal.collections.Ints; @@ -33,26 +35,28 @@ import java.util.Optional; import java.util.Random; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createArrayBigintBlock; import static com.facebook.presto.block.BlockAssertions.createBooleansBlock; import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static com.facebook.presto.block.BlockAssertions.createStringsBlock; import static com.facebook.presto.block.BlockAssertions.createTypedLongsBlock; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DateType.DATE; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static org.testng.Assert.assertTrue; public class TestArrayAggregation { - private static final MetadataManager metadata = MetadataManager.createTestMetadataManager(); + private static final FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); @Test public void testEmpty() { - InternalAggregationFunction bigIntAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("array_agg", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction bigIntAgg = getAggregation(BIGINT); assertAggregation( bigIntAgg, null, @@ -62,8 +66,7 @@ public void testEmpty() @Test public void testNullOnly() { - InternalAggregationFunction bigIntAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("array_agg", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction bigIntAgg = getAggregation(BIGINT); assertAggregation( bigIntAgg, Arrays.asList(null, null, null), @@ -73,8 +76,7 @@ public void testNullOnly() @Test public void testNullPartial() { - InternalAggregationFunction bigIntAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("array_agg", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction bigIntAgg = getAggregation(BIGINT); assertAggregation( bigIntAgg, Arrays.asList(null, 2L, null, 3L, null), @@ -84,8 +86,7 @@ public void testNullPartial() @Test public void testBoolean() { - InternalAggregationFunction booleanAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("array_agg", AGGREGATE, parseTypeSignature("array(boolean)"), parseTypeSignature(StandardTypes.BOOLEAN))); + InternalAggregationFunction booleanAgg = getAggregation(BOOLEAN); assertAggregation( booleanAgg, Arrays.asList(true, false), @@ -95,8 +96,7 @@ public void testBoolean() @Test public void testBigInt() { - InternalAggregationFunction bigIntAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("array_agg", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction bigIntAgg = getAggregation(BIGINT); assertAggregation( bigIntAgg, Arrays.asList(2L, 1L, 2L), @@ -106,8 +106,7 @@ public void testBigInt() @Test public void testVarchar() { - InternalAggregationFunction varcharAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("array_agg", AGGREGATE, parseTypeSignature("array(varchar)"), parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction varcharAgg = getAggregation(VARCHAR); assertAggregation( varcharAgg, Arrays.asList("hello", "world"), @@ -117,8 +116,7 @@ public void testVarchar() @Test public void testDate() { - InternalAggregationFunction varcharAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("array_agg", AGGREGATE, parseTypeSignature("array(date)"), parseTypeSignature(StandardTypes.DATE))); + InternalAggregationFunction varcharAgg = getAggregation(DATE); assertAggregation( varcharAgg, Arrays.asList(new SqlDate(1), new SqlDate(2), new SqlDate(4)), @@ -128,9 +126,7 @@ public void testDate() @Test public void testArray() { - InternalAggregationFunction varcharAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("array_agg", AGGREGATE, parseTypeSignature("array(array(bigint))"), parseTypeSignature("array(bigint)"))); - + InternalAggregationFunction varcharAgg = getAggregation(new ArrayType(BIGINT)); assertAggregation( varcharAgg, Arrays.asList(Arrays.asList(1L), Arrays.asList(1L, 2L), Arrays.asList(1L, 2L, 3L)), @@ -140,8 +136,7 @@ public void testArray() @Test public void testEmptyStateOutputsNull() { - InternalAggregationFunction bigIntAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("array_agg", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction bigIntAgg = getAggregation(BIGINT); GroupedAccumulator groupedAccumulator = bigIntAgg.bind(Ints.asList(new int[] {}), Optional.empty()) .createGroupedAccumulator(); BlockBuilder blockBuilder = groupedAccumulator.getFinalType().createBlockBuilder(null, 1000); @@ -153,12 +148,7 @@ public void testEmptyStateOutputsNull() @Test public void testWithMultiplePages() { - InternalAggregationFunction varcharAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature( - "array_agg", - AGGREGATE, - parseTypeSignature("array(varchar)"), - parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction varcharAgg = getAggregation(VARCHAR); AggregationTestInputBuilder testInputBuilder = new AggregationTestInputBuilder( new Block[] { @@ -173,12 +163,7 @@ public void testWithMultiplePages() @Test public void testMultipleGroupsWithMultiplePages() { - InternalAggregationFunction varcharAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature( - "array_agg", - AGGREGATE, - parseTypeSignature("array(varchar)"), - parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction varcharAgg = getAggregation(VARCHAR); Block block1 = createStringsBlock("a", "b", "c", "d", "e"); Block block2 = createStringsBlock("f", "g", "h", "i", "j"); @@ -203,12 +188,7 @@ public void testMultipleGroupsWithMultiplePages() public void testManyValues() { // Test many values so multiple BlockBuilders will be used to store group state. - InternalAggregationFunction varcharAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature( - "array_agg", - AGGREGATE, - parseTypeSignature("array(varchar)"), - parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction varcharAgg = getAggregation(VARCHAR); int numGroups = 50000; int arraySize = 30; @@ -242,4 +222,9 @@ private GroupedAccumulator createGroupedAccumulator(InternalAggregationFunction return function.bind(Ints.asList(args), Optional.empty()) .createGroupedAccumulator(); } + + private InternalAggregationFunction getAggregation(Type... arguments) + { + return functionManager.getAggregateFunctionImplementation(functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("array_agg"), fromTypes(arguments))); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestChecksumAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestChecksumAggregation.java index cba9bc508ac27..274dabce8d69c 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestChecksumAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestChecksumAggregation.java @@ -13,20 +13,17 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.ArrayType; -import com.facebook.presto.spi.type.BigintType; -import com.facebook.presto.spi.type.BooleanType; import com.facebook.presto.spi.type.DecimalType; -import com.facebook.presto.spi.type.DoubleType; import com.facebook.presto.spi.type.SqlVarbinary; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.VarbinaryType; -import com.facebook.presto.spi.type.VarcharType; +import com.facebook.presto.sql.tree.QualifiedName; import org.testng.annotations.Test; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createArrayBigintBlock; import static com.facebook.presto.block.BlockAssertions.createBooleansBlock; import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; @@ -34,104 +31,83 @@ import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static com.facebook.presto.block.BlockAssertions.createShortDecimalsBlock; import static com.facebook.presto.block.BlockAssertions.createStringsBlock; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; import static com.facebook.presto.operator.aggregation.ChecksumAggregationFunction.PRIME64; -import static com.facebook.presto.spi.type.StandardTypes.BIGINT; -import static com.facebook.presto.spi.type.StandardTypes.BOOLEAN; -import static com.facebook.presto.spi.type.StandardTypes.DOUBLE; -import static com.facebook.presto.spi.type.StandardTypes.VARBINARY; -import static com.facebook.presto.spi.type.StandardTypes.VARCHAR; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.DecimalType.createDecimalType; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.airlift.slice.Slices.wrappedLongArray; import static java.util.Arrays.asList; public class TestChecksumAggregation { - private static final MetadataManager metadata = MetadataManager.createTestMetadataManager(); + private static final FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); @Test public void testEmpty() { - InternalAggregationFunction booleanAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("checksum", - AGGREGATE, - parseTypeSignature(VARBINARY), - parseTypeSignature(BOOLEAN))); + InternalAggregationFunction booleanAgg = getAggregation(BOOLEAN); assertAggregation(booleanAgg, null, createBooleansBlock()); } @Test public void testBoolean() { - InternalAggregationFunction booleanAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("checksum", - AGGREGATE, - parseTypeSignature(VARBINARY), - parseTypeSignature(BOOLEAN))); + InternalAggregationFunction booleanAgg = getAggregation(BOOLEAN); Block block = createBooleansBlock(null, null, true, false, false); - assertAggregation(booleanAgg, expectedChecksum(BooleanType.BOOLEAN, block), block); + assertAggregation(booleanAgg, expectedChecksum(BOOLEAN, block), block); } @Test public void testLong() { - InternalAggregationFunction longAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("checksum", - AGGREGATE, - parseTypeSignature(VARBINARY), - parseTypeSignature(BIGINT))); + InternalAggregationFunction longAgg = getAggregation(BIGINT); Block block = createLongsBlock(null, 1L, 2L, 100L, null, Long.MAX_VALUE, Long.MIN_VALUE); - assertAggregation(longAgg, expectedChecksum(BigintType.BIGINT, block), block); + assertAggregation(longAgg, expectedChecksum(BIGINT, block), block); } @Test public void testDouble() { - InternalAggregationFunction doubleAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("checksum", - AGGREGATE, - parseTypeSignature(VARBINARY), - parseTypeSignature(DOUBLE))); + InternalAggregationFunction doubleAgg = getAggregation(DOUBLE); Block block = createDoublesBlock(null, 2.0, null, 3.0, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, Double.NaN); - assertAggregation(doubleAgg, expectedChecksum(DoubleType.DOUBLE, block), block); + assertAggregation(doubleAgg, expectedChecksum(DOUBLE, block), block); } @Test public void testString() { - InternalAggregationFunction stringAgg = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("checksum", - AGGREGATE, - parseTypeSignature(VARBINARY), - parseTypeSignature(VARCHAR))); + InternalAggregationFunction stringAgg = getAggregation(VARCHAR); Block block = createStringsBlock("a", "a", null, "b", "c"); - assertAggregation(stringAgg, expectedChecksum(VarcharType.VARCHAR, block), block); + assertAggregation(stringAgg, expectedChecksum(VARCHAR, block), block); } @Test public void testShortDecimal() { - InternalAggregationFunction decimalAgg = metadata.getFunctionManager().getAggregateFunctionImplementation(new Signature("checksum", AGGREGATE, parseTypeSignature(VARBINARY), parseTypeSignature("decimal(10,2)"))); + InternalAggregationFunction decimalAgg = getAggregation(createDecimalType(10, 2)); Block block = createShortDecimalsBlock("11.11", "22.22", null, "33.33", "44.44"); - DecimalType shortDecimalType = DecimalType.createDecimalType(1); + DecimalType shortDecimalType = createDecimalType(1); assertAggregation(decimalAgg, expectedChecksum(shortDecimalType, block), block); } @Test public void testLongDecimal() { - InternalAggregationFunction decimalAgg = metadata.getFunctionManager().getAggregateFunctionImplementation(new Signature("checksum", AGGREGATE, parseTypeSignature(VARBINARY), parseTypeSignature("decimal(19,2)"))); + InternalAggregationFunction decimalAgg = getAggregation(createDecimalType(19, 2)); Block block = createLongDecimalsBlock("11.11", "22.22", null, "33.33", "44.44"); - DecimalType longDecimalType = DecimalType.createDecimalType(19); + DecimalType longDecimalType = createDecimalType(19); assertAggregation(decimalAgg, expectedChecksum(longDecimalType, block), block); } @Test public void testArray() { - ArrayType arrayType = new ArrayType(BigintType.BIGINT); - InternalAggregationFunction stringAgg = metadata.getFunctionManager().getAggregateFunctionImplementation(new Signature("checksum", AGGREGATE, VarbinaryType.VARBINARY.getTypeSignature(), arrayType.getTypeSignature())); + ArrayType arrayType = new ArrayType(BIGINT); + InternalAggregationFunction stringAgg = getAggregation(arrayType); Block block = createArrayBigintBlock(asList(null, asList(1L, 2L), asList(3L, 4L), asList(5L, 6L))); assertAggregation(stringAgg, expectedChecksum(arrayType, block), block); } @@ -149,4 +125,9 @@ private static SqlVarbinary expectedChecksum(Type type, Block block) } return new SqlVarbinary(wrappedLongArray(result).getBytes()); } + + private InternalAggregationFunction getAggregation(Type argument) + { + return functionManager.getAggregateFunctionImplementation(functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("checksum"), fromTypes(argument))); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleHistogramAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleHistogramAggregation.java index f31be048bb12f..bae7de379def5 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleHistogramAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleHistogramAggregation.java @@ -13,17 +13,14 @@ */ package com.facebook.presto.operator.aggregation; -import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.MapType; -import com.facebook.presto.spi.type.StandardTypes; -import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.type.TypeRegistry; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; import org.testng.annotations.Test; @@ -31,12 +28,12 @@ import java.util.Map; import java.util.Optional; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.getFinalBlock; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.getIntermediateBlock; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.util.StructuralTestUtil.mapType; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -48,17 +45,10 @@ public class TestDoubleHistogramAggregation public TestDoubleHistogramAggregation() { - TypeRegistry typeRegistry = new TypeRegistry(); - FunctionManager functionManager = new FunctionManager(typeRegistry, new BlockEncodingManager(typeRegistry), new FeaturesConfig()); + FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); InternalAggregationFunction function = functionManager.getAggregateFunctionImplementation( - new Signature("numeric_histogram", - AGGREGATE, - parseTypeSignature("map(double,double)"), - parseTypeSignature(StandardTypes.BIGINT), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.DOUBLE))); + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("numeric_histogram"), fromTypes(BIGINT, DOUBLE, DOUBLE))); factory = function.bind(ImmutableList.of(0, 1, 2), Optional.empty()); - input = makeInput(10); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestHistogram.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestHistogram.java index 2087d3313c1cf..23ca80c7c094c 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestHistogram.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestHistogram.java @@ -14,8 +14,8 @@ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.groupByAggregations.AggregationTestInput; import com.facebook.presto.operator.aggregation.groupByAggregations.AggregationTestInputBuilder; import com.facebook.presto.operator.aggregation.groupByAggregations.AggregationTestOutput; @@ -27,10 +27,10 @@ import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlTimestampWithTimeZone; -import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.TimeZoneKey; -import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.joda.time.DateTime; @@ -39,7 +39,6 @@ import org.testng.internal.collections.Ints; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -48,12 +47,12 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createBooleansBlock; import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static com.facebook.presto.block.BlockAssertions.createStringArraysBlock; import static com.facebook.presto.block.BlockAssertions.createStringsBlock; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.OperatorAssertion.toRow; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; import static com.facebook.presto.operator.aggregation.histogram.Histogram.NAME; @@ -64,8 +63,8 @@ import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.TimeZoneKey.getTimeZoneKey; import static com.facebook.presto.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.util.DateTimeZoneIndex.getDateTimeZone; import static com.facebook.presto.util.StructuralTestUtil.mapBlockOf; import static com.facebook.presto.util.StructuralTestUtil.mapType; @@ -79,41 +78,25 @@ public class TestHistogram @Test public void testSimpleHistograms() { - MapType mapType = mapType(VARCHAR, BIGINT); - InternalAggregationFunction aggregationFunction = getAggregation(mapType.getTypeSignature(), parseTypeSignature(StandardTypes.VARCHAR)); + InternalAggregationFunction aggregationFunction = getAggregation(VARCHAR); assertAggregation( aggregationFunction, ImmutableMap.of("a", 1L, "b", 1L, "c", 1L), createStringsBlock("a", "b", "c")); - mapType = mapType(BIGINT, BIGINT); - aggregationFunction = getMetadata().getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.BIGINT))); + aggregationFunction = getAggregation(BIGINT); assertAggregation( aggregationFunction, ImmutableMap.of(100L, 1L, 200L, 1L, 300L, 1L), createLongsBlock(100L, 200L, 300L)); - mapType = mapType(DOUBLE, BIGINT); - aggregationFunction = getMetadata().getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.DOUBLE))); + aggregationFunction = getAggregation(DOUBLE); assertAggregation( aggregationFunction, ImmutableMap.of(0.1, 1L, 0.3, 1L, 0.2, 1L), createDoublesBlock(0.1, 0.3, 0.2)); - mapType = mapType(BOOLEAN, BIGINT); - aggregationFunction = getMetadata().getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.BOOLEAN))); + aggregationFunction = getAggregation(BOOLEAN); assertAggregation( aggregationFunction, ImmutableMap.of(true, 1L, false, 1L), @@ -123,29 +106,25 @@ public void testSimpleHistograms() @Test public void testSharedGroupBy() { - MapType mapType = mapType(VARCHAR, BIGINT); - InternalAggregationFunction aggregationFunction = getAggregation(mapType.getTypeSignature(), parseTypeSignature(StandardTypes.VARCHAR)); + InternalAggregationFunction aggregationFunction = getAggregation(VARCHAR); assertAggregation( aggregationFunction, ImmutableMap.of("a", 1L, "b", 1L, "c", 1L), createStringsBlock("a", "b", "c")); - mapType = mapType(BIGINT, BIGINT); - aggregationFunction = getAggregation(mapType.getTypeSignature(), parseTypeSignature(StandardTypes.BIGINT)); + aggregationFunction = getAggregation(BIGINT); assertAggregation( aggregationFunction, ImmutableMap.of(100L, 1L, 200L, 1L, 300L, 1L), createLongsBlock(100L, 200L, 300L)); - mapType = mapType(DOUBLE, BIGINT); - aggregationFunction = getAggregation(mapType.getTypeSignature(), parseTypeSignature(StandardTypes.DOUBLE)); + aggregationFunction = getAggregation(DOUBLE); assertAggregation( aggregationFunction, ImmutableMap.of(0.1, 1L, 0.3, 1L, 0.2, 1L), createDoublesBlock(0.1, 0.3, 0.2)); - mapType = mapType(BOOLEAN, BIGINT); - aggregationFunction = getAggregation(mapType.getTypeSignature(), parseTypeSignature(StandardTypes.BOOLEAN)); + aggregationFunction = getAggregation(BOOLEAN); assertAggregation( aggregationFunction, ImmutableMap.of(true, 1L, false, 1L), @@ -155,19 +134,13 @@ public void testSharedGroupBy() @Test public void testDuplicateKeysValues() { - MapType mapType = mapType(VARCHAR, BIGINT); - InternalAggregationFunction aggregationFunction = getAggregation(mapType.getTypeSignature(), parseTypeSignature(StandardTypes.VARCHAR)); + InternalAggregationFunction aggregationFunction = getAggregation(VARCHAR); assertAggregation( aggregationFunction, ImmutableMap.of("a", 2L, "b", 1L), createStringsBlock("a", "b", "a")); - mapType = mapType(TIMESTAMP_WITH_TIME_ZONE, BIGINT); - aggregationFunction = getMetadata().getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.TIMESTAMP_WITH_TIME_ZONE))); + aggregationFunction = getAggregation(TIMESTAMP_WITH_TIME_ZONE); long timestampWithTimeZone1 = packDateTimeWithZone(new DateTime(1970, 1, 1, 0, 0, 0, 0, DATE_TIME_ZONE).getMillis(), TIME_ZONE_KEY); long timestampWithTimeZone2 = packDateTimeWithZone(new DateTime(2015, 1, 1, 0, 0, 0, 0, DATE_TIME_ZONE).getMillis(), TIME_ZONE_KEY); assertAggregation( @@ -179,19 +152,11 @@ public void testDuplicateKeysValues() @Test public void testWithNulls() { - MapType mapType = mapType(BIGINT, BIGINT); - InternalAggregationFunction aggregationFunction = getAggregation(mapType.getTypeSignature(), parseTypeSignature(StandardTypes.BIGINT)); + InternalAggregationFunction aggregationFunction = getAggregation(BIGINT); assertAggregation( aggregationFunction, ImmutableMap.of(1L, 1L, 2L, 1L), createLongsBlock(2L, null, 1L)); - - mapType = mapType(BIGINT, BIGINT); - aggregationFunction = getMetadata().getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.BIGINT))); assertAggregation( aggregationFunction, null, @@ -202,8 +167,7 @@ public void testWithNulls() public void testArrayHistograms() { ArrayType arrayType = new ArrayType(VARCHAR); - MapType mapType = mapType(arrayType, BIGINT); - InternalAggregationFunction aggregationFunction = getAggregation(mapType.getTypeSignature(), arrayType.getTypeSignature()); + InternalAggregationFunction aggregationFunction = getAggregation(arrayType); assertAggregation( aggregationFunction, ImmutableMap.of(ImmutableList.of("a", "b", "c"), 1L, ImmutableList.of("d", "e", "f"), 1L, ImmutableList.of("c", "b", "a"), 1L), @@ -214,8 +178,7 @@ public void testArrayHistograms() public void testMapHistograms() { MapType innerMapType = mapType(VARCHAR, VARCHAR); - MapType mapType = mapType(innerMapType, BIGINT); - InternalAggregationFunction aggregationFunction = getAggregation(mapType.getTypeSignature(), innerMapType.getTypeSignature()); + InternalAggregationFunction aggregationFunction = getAggregation(innerMapType); BlockBuilder builder = innerMapType.createBlockBuilder(null, 3); innerMapType.writeObject(builder, mapBlockOf(VARCHAR, VARCHAR, ImmutableMap.of("a", "b"))); @@ -234,8 +197,7 @@ public void testRowHistograms() RowType innerRowType = RowType.from(ImmutableList.of( RowType.field("f1", BIGINT), RowType.field("f2", DOUBLE))); - MapType mapType = mapType(innerRowType, BIGINT); - InternalAggregationFunction aggregationFunction = getAggregation(mapType.getTypeSignature(), innerRowType.getTypeSignature()); + InternalAggregationFunction aggregationFunction = getAggregation(innerRowType); BlockBuilder builder = innerRowType.createBlockBuilder(null, 3); innerRowType.writeObject(builder, toRow(ImmutableList.of(BIGINT, DOUBLE), 1L, 1.0)); innerRowType.writeObject(builder, toRow(ImmutableList.of(BIGINT, DOUBLE), 2L, 2.0)); @@ -250,8 +212,7 @@ public void testRowHistograms() @Test public void testLargerHistograms() { - MapType mapType = mapType(VARCHAR, BIGINT); - InternalAggregationFunction aggregationFunction = getAggregation(mapType.getTypeSignature(), parseTypeSignature(StandardTypes.VARCHAR)); + InternalAggregationFunction aggregationFunction = getInternalDefaultVarCharAggregationn(); assertAggregation( aggregationFunction, ImmutableMap.of("a", 25L, "b", 10L, "c", 12L, "d", 1L, "e", 2L), @@ -424,32 +385,23 @@ private void testSharedGroupByWithOverlappingValuesRunner(InternalAggregationFun private InternalAggregationFunction getInternalDefaultVarCharAggregationn() { - TypeSignature returnType = mapType(VARCHAR, BIGINT).getTypeSignature(); - TypeSignature argumentType = parseTypeSignature(StandardTypes.VARCHAR); - - return getAggregation(returnType, argumentType); + return getAggregation(VARCHAR); } - private InternalAggregationFunction getAggregation(TypeSignature returnType, TypeSignature... arguments) + private InternalAggregationFunction getAggregation(Type... arguments) { - MetadataManager metadata = getMetadata(NEW); - Signature signature = new Signature(NAME, - AGGREGATE, - returnType, - Arrays.asList(arguments)); - return metadata.getFunctionManager().getAggregateFunctionImplementation(signature); + FunctionManager functionManager = getFunctionManager(NEW); + return functionManager.getAggregateFunctionImplementation(functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(NAME), fromTypes(arguments))); } - public MetadataManager getMetadata() + public FunctionManager getFunctionManager() { - return getMetadata(NEW); + return getFunctionManager(NEW); } - public MetadataManager getMetadata(HistogramGroupImplementation groupMode) + public FunctionManager getFunctionManager(HistogramGroupImplementation groupMode) { - MetadataManager metadata = MetadataManager.createTestMetadataManager(new FeaturesConfig() - .setHistogramGroupImplementation(groupMode)); - - return metadata; + return MetadataManager.createTestMetadataManager(new FeaturesConfig() + .setHistogramGroupImplementation(groupMode)).getFunctionManager(); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapAggAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapAggAggregation.java index 30aa8e222f98d..acc30da22ee00 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapAggAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapAggAggregation.java @@ -13,13 +13,14 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.RowType; -import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -27,50 +28,38 @@ import java.util.LinkedHashMap; import java.util.Map; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createBooleansBlock; import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; import static com.facebook.presto.block.BlockAssertions.createStringArraysBlock; import static com.facebook.presto.block.BlockAssertions.createStringsBlock; import static com.facebook.presto.block.BlockAssertions.createTypedLongsBlock; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.OperatorAssertion.toRow; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; import static com.facebook.presto.operator.aggregation.MapAggregationFunction.NAME; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.util.StructuralTestUtil.mapBlockOf; import static com.facebook.presto.util.StructuralTestUtil.mapType; public class TestMapAggAggregation { - private static final MetadataManager metadata = MetadataManager.createTestMetadataManager(); + private static final FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); @Test public void testDuplicateKeysValues() { - MapType mapType = mapType(DOUBLE, VARCHAR); - InternalAggregationFunction aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction aggFunc = getAggregation(DOUBLE, VARCHAR); assertAggregation( aggFunc, ImmutableMap.of(1.0, "a"), createDoublesBlock(1.0, 1.0, 1.0), createStringsBlock("a", "b", "c")); - mapType = mapType(DOUBLE, INTEGER); - aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.INTEGER))); + aggFunc = getAggregation(DOUBLE, INTEGER); assertAggregation( aggFunc, ImmutableMap.of(1.0, 99, 2.0, 99, 3.0, 99), @@ -81,39 +70,21 @@ public void testDuplicateKeysValues() @Test public void testSimpleMaps() { - MapType mapType = mapType(DOUBLE, VARCHAR); - InternalAggregationFunction aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction aggFunc = getAggregation(DOUBLE, VARCHAR); assertAggregation( aggFunc, ImmutableMap.of(1.0, "a", 2.0, "b", 3.0, "c"), createDoublesBlock(1.0, 2.0, 3.0), createStringsBlock("a", "b", "c")); - mapType = mapType(DOUBLE, INTEGER); - aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.INTEGER))); + aggFunc = getAggregation(DOUBLE, INTEGER); assertAggregation( aggFunc, ImmutableMap.of(1.0, 3, 2.0, 2, 3.0, 1), createDoublesBlock(1.0, 2.0, 3.0), createTypedLongsBlock(INTEGER, ImmutableList.of(3L, 2L, 1L))); - mapType = mapType(DOUBLE, BOOLEAN); - aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.BOOLEAN))); + aggFunc = getAggregation(DOUBLE, BOOLEAN); assertAggregation( aggFunc, ImmutableMap.of(1.0, true, 2.0, false, 3.0, false), @@ -124,12 +95,7 @@ public void testSimpleMaps() @Test public void testNull() { - InternalAggregationFunction doubleDouble = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, - AGGREGATE, - mapType(DOUBLE, DOUBLE).getTypeSignature(), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction doubleDouble = getAggregation(DOUBLE, DOUBLE); assertAggregation( doubleDouble, ImmutableMap.of(1.0, 2.0), @@ -156,14 +122,7 @@ public void testNull() @Test public void testDoubleArrayMap() { - ArrayType arrayType = new ArrayType(VARCHAR); - MapType mapType = mapType(DOUBLE, arrayType); - InternalAggregationFunction aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation(new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.DOUBLE), - arrayType.getTypeSignature())); - + InternalAggregationFunction aggFunc = getAggregation(DOUBLE, new ArrayType(VARCHAR)); assertAggregation( aggFunc, ImmutableMap.of(1.0, ImmutableList.of("a", "b"), @@ -177,12 +136,7 @@ public void testDoubleArrayMap() public void testDoubleMapMap() { MapType innerMapType = mapType(VARCHAR, VARCHAR); - MapType mapType = mapType(DOUBLE, innerMapType); - InternalAggregationFunction aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation(new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.DOUBLE), - innerMapType.getTypeSignature())); + InternalAggregationFunction aggFunc = getAggregation(DOUBLE, innerMapType); BlockBuilder builder = innerMapType.createBlockBuilder(null, 3); innerMapType.writeObject(builder, mapBlockOf(VARCHAR, VARCHAR, ImmutableMap.of("a", "b"))); @@ -204,12 +158,7 @@ public void testDoubleRowMap() RowType innerRowType = RowType.from(ImmutableList.of( RowType.field("f1", INTEGER), RowType.field("f2", DOUBLE))); - MapType mapType = mapType(DOUBLE, innerRowType); - InternalAggregationFunction aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation(new Signature(NAME, - AGGREGATE, - mapType.getTypeSignature(), - parseTypeSignature(StandardTypes.DOUBLE), - innerRowType.getTypeSignature())); + InternalAggregationFunction aggFunc = getAggregation(DOUBLE, innerRowType); BlockBuilder builder = innerRowType.createBlockBuilder(null, 3); innerRowType.writeObject(builder, toRow(ImmutableList.of(INTEGER, DOUBLE), 1L, 1.0)); @@ -228,15 +177,7 @@ public void testDoubleRowMap() @Test public void testArrayDoubleMap() { - ArrayType arrayType = new ArrayType(VARCHAR); - MapType mapType = mapType(arrayType, DOUBLE); - InternalAggregationFunction aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation(new Signature( - NAME, - AGGREGATE, - mapType.getTypeSignature(), - arrayType.getTypeSignature(), - parseTypeSignature(StandardTypes.DOUBLE))); - + InternalAggregationFunction aggFunc = getAggregation(new ArrayType(VARCHAR), DOUBLE); assertAggregation( aggFunc, ImmutableMap.of( @@ -246,4 +187,9 @@ public void testArrayDoubleMap() createStringArraysBlock(ImmutableList.of(ImmutableList.of("a", "b"), ImmutableList.of("c", "d"), ImmutableList.of("e", "f"))), createDoublesBlock(1.0, 2.0, 3.0)); } + + private InternalAggregationFunction getAggregation(Type... arguments) + { + return functionManager.getAggregateFunctionImplementation(functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(NAME), fromTypes(arguments))); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapUnionAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapUnionAggregation.java index a7f7774ec158f..ebcd70665123d 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapUnionAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMapUnionAggregation.java @@ -13,10 +13,12 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.metadata.FunctionHandle; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -24,13 +26,14 @@ import java.util.HashMap; import java.util.Map; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; import static com.facebook.presto.operator.aggregation.MapUnionAggregation.NAME; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.util.StructuralTestUtil.arrayBlockOf; import static com.facebook.presto.util.StructuralTestUtil.mapBlockOf; import static com.facebook.presto.util.StructuralTestUtil.mapType; @@ -38,14 +41,14 @@ public class TestMapUnionAggregation { - private static final MetadataManager metadata = MetadataManager.createTestMetadataManager(); + private static final FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); @Test public void testSimpleWithDuplicates() { MapType mapType = mapType(DOUBLE, VARCHAR); - InternalAggregationFunction aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); + FunctionHandle functionHandle = functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(NAME), fromTypes(mapType)); + InternalAggregationFunction aggFunc = functionManager.getAggregateFunctionImplementation(functionHandle); assertAggregation( aggFunc, ImmutableMap.of(23.0, "aaa", 33.0, "bbb", 43.0, "ccc", 53.0, "ddd", 13.0, "eee"), @@ -55,8 +58,8 @@ public void testSimpleWithDuplicates() mapBlockOf(DOUBLE, VARCHAR, ImmutableMap.of(43.0, "ccc", 53.0, "ddd", 13.0, "eee")))); mapType = mapType(DOUBLE, BIGINT); - aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); + functionHandle = functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(NAME), fromTypes(mapType)); + aggFunc = functionManager.getAggregateFunctionImplementation(functionHandle); assertAggregation( aggFunc, ImmutableMap.of(1.0, 99L, 2.0, 99L, 3.0, 99L, 4.0, 44L), @@ -66,8 +69,8 @@ public void testSimpleWithDuplicates() mapBlockOf(DOUBLE, BIGINT, ImmutableMap.of(1.0, 44L, 2.0, 44L, 4.0, 44L)))); mapType = mapType(BOOLEAN, BIGINT); - aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); + functionHandle = functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(NAME), fromTypes(mapType)); + aggFunc = functionManager.getAggregateFunctionImplementation(functionHandle); assertAggregation( aggFunc, ImmutableMap.of(false, 12L, true, 13L), @@ -81,8 +84,8 @@ public void testSimpleWithDuplicates() public void testSimpleWithNulls() { MapType mapType = mapType(DOUBLE, VARCHAR); - InternalAggregationFunction aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); + FunctionHandle functionHandle = functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(NAME), fromTypes(mapType)); + InternalAggregationFunction aggFunc = functionManager.getAggregateFunctionImplementation(functionHandle); Map expected = mapOf(23.0, "aaa", 33.0, null, 43.0, "ccc", 53.0, "ddd"); @@ -100,8 +103,8 @@ public void testSimpleWithNulls() public void testStructural() { MapType mapType = mapType(DOUBLE, new ArrayType(VARCHAR)); - InternalAggregationFunction aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); + FunctionHandle functionHandle = functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(NAME), fromTypes(mapType)); + InternalAggregationFunction aggFunc = functionManager.getAggregateFunctionImplementation(functionHandle); assertAggregation( aggFunc, ImmutableMap.of( @@ -133,8 +136,8 @@ public void testStructural() ImmutableList.of("w", "z"))))); mapType = mapType(DOUBLE, mapType(VARCHAR, VARCHAR)); - aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); + functionHandle = functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(NAME), fromTypes(mapType)); + aggFunc = functionManager.getAggregateFunctionImplementation(functionHandle); assertAggregation( aggFunc, ImmutableMap.of( @@ -159,8 +162,8 @@ public void testStructural() ImmutableMap.of("e", "f"))))); mapType = mapType(new ArrayType(VARCHAR), DOUBLE); - aggFunc = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), mapType.getTypeSignature())); + functionHandle = functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(NAME), fromTypes(mapType)); + aggFunc = functionManager.getAggregateFunctionImplementation(functionHandle); assertAggregation( aggFunc, ImmutableMap.of( diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMultimapAggAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMultimapAggAggregation.java index 36483e374333d..030dbca4eea64 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMultimapAggAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestMultimapAggAggregation.java @@ -14,8 +14,7 @@ package com.facebook.presto.operator.aggregation; import com.facebook.presto.RowPageBuilder; -import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.operator.aggregation.groupByAggregations.AggregationTestInput; import com.facebook.presto.operator.aggregation.groupByAggregations.AggregationTestInputBuilder; import com.facebook.presto.operator.aggregation.groupByAggregations.AggregationTestOutput; @@ -24,9 +23,9 @@ import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.type.ArrayType; -import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; @@ -40,20 +39,21 @@ import java.util.Optional; import java.util.Random; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; import static com.facebook.presto.operator.aggregation.multimapagg.MultimapAggregationFunction.NAME; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.util.StructuralTestUtil.mapType; import static com.google.common.base.Preconditions.checkState; import static org.testng.Assert.assertTrue; public class TestMultimapAggAggregation { - private static final MetadataManager metadata = createTestMetadataManager(); + private static final FunctionManager functionManager = createTestMetadataManager().getFunctionManager(); @Test public void testSingleValueMap() @@ -183,9 +183,7 @@ private static void testMultimapAgg(Type keyType, List expectedKeys, T private static InternalAggregationFunction getInternalAggregationFunction(Type keyType, Type valueType) { - MapType mapType = mapType(keyType, new ArrayType(valueType)); - Signature signature = new Signature(NAME, AGGREGATE, mapType.getTypeSignature(), keyType.getTypeSignature(), valueType.getTypeSignature()); - return metadata.getFunctionManager().getAggregateFunctionImplementation(signature); + return functionManager.getAggregateFunctionImplementation(functionManager.resolveFunction(TEST_SESSION, QualifiedName.of(NAME), fromTypes(keyType, valueType))); } private static void testMultimapAgg(InternalAggregationFunction aggFunc, Type keyType, List expectedKeys, Type valueType, List expectedValues) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestQuantileDigestAggregationFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestQuantileDigestAggregationFunction.java index a07485865aba2..d7e72c4396fa1 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestQuantileDigestAggregationFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestQuantileDigestAggregationFunction.java @@ -13,14 +13,15 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.scalar.AbstractTestFunctions; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.SqlVarbinary; import com.facebook.presto.spi.type.StandardTypes; -import com.facebook.presto.spi.type.TypeSignature; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.base.Joiner; import com.google.common.primitives.Floats; import io.airlift.stats.QuantileDigest; @@ -32,6 +33,7 @@ import java.util.stream.Collectors; import java.util.stream.LongStream; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; import static com.facebook.presto.block.BlockAssertions.createDoubleSequenceBlock; import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; @@ -39,13 +41,16 @@ import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static com.facebook.presto.block.BlockAssertions.createRLEBlock; import static com.facebook.presto.block.BlockAssertions.createSequenceBlockOfReal; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; import static com.facebook.presto.operator.aggregation.FloatingPointBitsConverterUtil.doubleToSortableLong; import static com.facebook.presto.operator.aggregation.FloatingPointBitsConverterUtil.floatToSortableInt; import static com.facebook.presto.operator.aggregation.TestMergeQuantileDigestFunction.QDIGEST_EQUALITY; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.RealType.REAL; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.Double.NaN; import static java.lang.Integer.max; @@ -160,34 +165,31 @@ public void testBigintsWithWeight() LongStream.range(-1000, 1000).toArray()); } - private InternalAggregationFunction getAggregationFunction(String... type) + private InternalAggregationFunction getAggregationFunction(Type... type) { - TypeSignature[] typeSignatures = Arrays.stream(type).map(TypeSignature::parseTypeSignature).toArray(TypeSignature[]::new); - return METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("qdigest_agg", - AGGREGATE, - parseTypeSignature(format("qdigest(%s)", type[0])), - typeSignatures)); + FunctionManager functionManager = METADATA.getFunctionManager(); + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("qdigest_agg"), fromTypes(type))); } private void testAggregationBigint(Block inputBlock, Block weightsBlock, double maxError, long... inputs) { // Test without weights and accuracy testAggregationBigints( - getAggregationFunction(StandardTypes.BIGINT), + getAggregationFunction(BIGINT), new Page(inputBlock), maxError, inputs); // Test with weights and without accuracy testAggregationBigints( - getAggregationFunction(StandardTypes.BIGINT, StandardTypes.BIGINT), + getAggregationFunction(BIGINT, BIGINT), new Page(inputBlock, weightsBlock), maxError, inputs); // Test with weights and accuracy testAggregationBigints( - getAggregationFunction(StandardTypes.BIGINT, StandardTypes.BIGINT, StandardTypes.DOUBLE), + getAggregationFunction(BIGINT, BIGINT, DOUBLE), new Page(inputBlock, weightsBlock, createRLEBlock(maxError, inputBlock.getPositionCount())), maxError, inputs); @@ -197,19 +199,19 @@ private void testAggregationReal(Block longsBlock, Block weightsBlock, double ma { // Test without weights and accuracy testAggregationReal( - getAggregationFunction(StandardTypes.REAL), + getAggregationFunction(REAL), new Page(longsBlock), maxError, inputs); // Test with weights and without accuracy testAggregationReal( - getAggregationFunction(StandardTypes.REAL, StandardTypes.BIGINT), + getAggregationFunction(REAL, BIGINT), new Page(longsBlock, weightsBlock), maxError, inputs); // Test with weights and accuracy testAggregationReal( - getAggregationFunction(StandardTypes.REAL, StandardTypes.BIGINT, StandardTypes.DOUBLE), + getAggregationFunction(REAL, BIGINT, DOUBLE), new Page(longsBlock, weightsBlock, createRLEBlock(maxError, longsBlock.getPositionCount())), maxError, inputs); @@ -219,19 +221,19 @@ private void testAggregationDouble(Block longsBlock, Block weightsBlock, double { // Test without weights and accuracy testAggregationDoubles( - getAggregationFunction(StandardTypes.DOUBLE), + getAggregationFunction(DOUBLE), new Page(longsBlock), maxError, inputs); // Test with weights and without accuracy testAggregationDoubles( - getAggregationFunction(StandardTypes.DOUBLE, StandardTypes.BIGINT), + getAggregationFunction(DOUBLE, BIGINT), new Page(longsBlock, weightsBlock), maxError, inputs); // Test with weights and accuracy testAggregationDoubles( - getAggregationFunction(StandardTypes.DOUBLE, StandardTypes.BIGINT, StandardTypes.DOUBLE), + getAggregationFunction(DOUBLE, BIGINT, DOUBLE), new Page(longsBlock, weightsBlock, createRLEBlock(maxError, longsBlock.getPositionCount())), maxError, inputs); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealHistogramAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealHistogramAggregation.java index 0dab9a6a05085..ae19d0f9cddff 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealHistogramAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealHistogramAggregation.java @@ -15,14 +15,13 @@ import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.MapType; -import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.type.TypeRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; @@ -31,13 +30,13 @@ import java.util.Map; import java.util.Optional; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.getFinalBlock; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.getIntermediateBlock; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.RealType.REAL; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.util.StructuralTestUtil.mapType; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -52,14 +51,8 @@ public TestRealHistogramAggregation() TypeRegistry typeRegistry = new TypeRegistry(); FunctionManager functionManager = new FunctionManager(typeRegistry, new BlockEncodingManager(typeRegistry), new FeaturesConfig()); InternalAggregationFunction function = functionManager.getAggregateFunctionImplementation( - new Signature("numeric_histogram", - AGGREGATE, - parseTypeSignature("map(real, real)"), - parseTypeSignature(StandardTypes.BIGINT), - parseTypeSignature(StandardTypes.REAL), - parseTypeSignature(StandardTypes.DOUBLE))); + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("numeric_histogram"), fromTypes(BIGINT, REAL, DOUBLE))); factory = function.bind(ImmutableList.of(0, 1, 2), Optional.empty()); - input = makeInput(10); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/minmaxby/TestMinMaxByAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/minmaxby/TestMinMaxByAggregation.java index ec0d2da1c7b5d..0584ab79a68c4 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/minmaxby/TestMinMaxByAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/minmaxby/TestMinMaxByAggregation.java @@ -13,22 +13,22 @@ */ package com.facebook.presto.operator.aggregation.minmaxby; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.operator.aggregation.state.StateCompiler; -import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlDecimal; -import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.type.UnknownType; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; import java.util.List; import java.util.Set; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createArrayBigintBlock; import static com.facebook.presto.block.BlockAssertions.createBooleansBlock; import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; @@ -37,19 +37,23 @@ import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static com.facebook.presto.block.BlockAssertions.createShortDecimalsBlock; import static com.facebook.presto.block.BlockAssertions.createStringsBlock; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.type.DecimalType.createDecimalType; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Arrays.asList; import static org.testng.Assert.assertNotNull; public class TestMinMaxByAggregation { - private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager(); + private static final MetadataManager metadata = MetadataManager.createTestMetadataManager(); + private static final FunctionManager functionManager = metadata.getFunctionManager(); @Test public void testAllRegistered() @@ -61,8 +65,8 @@ public void testAllRegistered() for (Type keyType : orderableTypes) { for (Type valueType : getTypes()) { if (StateCompiler.getSupportedFieldTypes().contains(valueType.getJavaType())) { - assertNotNull(METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("min_by", AGGREGATE, valueType.getTypeSignature(), valueType.getTypeSignature(), keyType.getTypeSignature()))); - assertNotNull(METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("max_by", AGGREGATE, valueType.getTypeSignature(), valueType.getTypeSignature(), keyType.getTypeSignature()))); + assertNotNull(getMinByAggregation(valueType, keyType)); + assertNotNull(getMaxByAggregation(valueType, keyType)); } } } @@ -70,11 +74,11 @@ public void testAllRegistered() private static List getTypes() { - List simpleTypes = METADATA.getTypeManager().getTypes(); + List simpleTypes = metadata.getTypeManager().getTypes(); return new ImmutableList.Builder() .addAll(simpleTypes) .add(VARCHAR) - .add(DecimalType.createDecimalType(1)) + .add(createDecimalType(1)) .add(RowType.anonymous(ImmutableList.of(BIGINT, VARCHAR, DOUBLE))) .build(); } @@ -82,17 +86,15 @@ private static List getTypes() @Test public void testMinUnknown() { - InternalAggregationFunction unknownKey = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature(UnknownType.NAME), parseTypeSignature(UnknownType.NAME), parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction unknownKey = getMinByAggregation(UNKNOWN, DOUBLE); assertAggregation( unknownKey, null, createBooleansBlock(null, null), createDoublesBlock(1.0, 2.0)); - InternalAggregationFunction unknownValue = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(UnknownType.NAME))); + InternalAggregationFunction unknownValue = getMinByAggregation(DOUBLE, UNKNOWN); assertAggregation( - unknownKey, + unknownValue, null, createDoublesBlock(1.0, 2.0), createBooleansBlock(null, null)); @@ -101,17 +103,15 @@ public void testMinUnknown() @Test public void testMaxUnknown() { - InternalAggregationFunction unknownKey = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature(UnknownType.NAME), parseTypeSignature(UnknownType.NAME), parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction unknownKey = getMaxByAggregation(UNKNOWN, DOUBLE); assertAggregation( unknownKey, null, createBooleansBlock(null, null), createDoublesBlock(1.0, 2.0)); - InternalAggregationFunction unknownValue = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(UnknownType.NAME))); + InternalAggregationFunction unknownValue = getMaxByAggregation(DOUBLE, UNKNOWN); assertAggregation( - unknownKey, + unknownValue, null, createDoublesBlock(1.0, 2.0), createBooleansBlock(null, null)); @@ -120,8 +120,7 @@ public void testMaxUnknown() @Test public void testMinNull() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction function = getMinByAggregation(DOUBLE, DOUBLE); assertAggregation( function, 1.0, @@ -137,8 +136,7 @@ public void testMinNull() @Test public void testMaxNull() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction function = getMaxByAggregation(DOUBLE, DOUBLE); assertAggregation( function, null, @@ -154,8 +152,7 @@ public void testMaxNull() @Test public void testMinDoubleDouble() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction function = getMinByAggregation(DOUBLE, DOUBLE); assertAggregation( function, null, @@ -172,8 +169,7 @@ public void testMinDoubleDouble() @Test public void testMaxDoubleDouble() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE), parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction function = getMaxByAggregation(DOUBLE, DOUBLE); assertAggregation( function, null, @@ -190,8 +186,7 @@ public void testMaxDoubleDouble() @Test public void testMinDoubleVarchar() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction function = getMinByAggregation(VARCHAR, DOUBLE); assertAggregation( function, "z", @@ -208,8 +203,7 @@ public void testMinDoubleVarchar() @Test public void testMaxDoubleVarchar() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction function = getMaxByAggregation(VARCHAR, DOUBLE); assertAggregation( function, "a", @@ -226,8 +220,7 @@ public void testMaxDoubleVarchar() @Test public void testMinLongLongArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMinByAggregation(new ArrayType(BIGINT), BIGINT); assertAggregation( function, ImmutableList.of(8L, 9L), @@ -244,8 +237,7 @@ public void testMinLongLongArray() @Test public void testMinLongArrayLong() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature("array(bigint)"))); + InternalAggregationFunction function = getMinByAggregation(BIGINT, new ArrayType(BIGINT)); assertAggregation( function, 3L, @@ -262,8 +254,7 @@ public void testMinLongArrayLong() @Test public void testMaxLongArrayLong() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature("array(bigint)"))); + InternalAggregationFunction function = getMaxByAggregation(BIGINT, new ArrayType(BIGINT)); assertAggregation( function, 1L, @@ -280,8 +271,7 @@ public void testMaxLongArrayLong() @Test public void testMaxLongLongArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMaxByAggregation(new ArrayType(BIGINT), BIGINT); assertAggregation( function, ImmutableList.of(1L, 2L), @@ -298,7 +288,8 @@ public void testMaxLongLongArray() @Test public void testMinLongDecimalDecimal() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("min_by", AGGREGATE, parseTypeSignature("decimal(19,1)"), parseTypeSignature("decimal(19,1)"), parseTypeSignature("decimal(19,1)"))); + Type decimalType = createDecimalType(19, 1); + InternalAggregationFunction function = getMinByAggregation(decimalType, decimalType); assertAggregation( function, SqlDecimal.of("2.2"), @@ -309,7 +300,8 @@ public void testMinLongDecimalDecimal() @Test public void testMaxLongDecimalDecimal() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("max_by", AGGREGATE, parseTypeSignature("decimal(19,1)"), parseTypeSignature("decimal(19,1)"), parseTypeSignature("decimal(19,1)"))); + Type decimalType = createDecimalType(19, 1); + InternalAggregationFunction function = getMaxByAggregation(decimalType, decimalType); assertAggregation( function, SqlDecimal.of("3.3"), @@ -320,7 +312,8 @@ public void testMaxLongDecimalDecimal() @Test public void testMinShortDecimalDecimal() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("min_by", AGGREGATE, parseTypeSignature("decimal(10,1)"), parseTypeSignature("decimal(10,1)"), parseTypeSignature("decimal(10,1)"))); + Type decimalType = createDecimalType(10, 1); + InternalAggregationFunction function = getMinByAggregation(decimalType, decimalType); assertAggregation( function, SqlDecimal.of("2.2"), @@ -331,7 +324,8 @@ public void testMinShortDecimalDecimal() @Test public void testMaxShortDecimalDecimal() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("max_by", AGGREGATE, parseTypeSignature("decimal(10,1)"), parseTypeSignature("decimal(10,1)"), parseTypeSignature("decimal(10,1)"))); + Type decimalType = createDecimalType(10, 1); + InternalAggregationFunction function = getMaxByAggregation(decimalType, decimalType); assertAggregation( function, SqlDecimal.of("3.3"), @@ -342,7 +336,7 @@ public void testMaxShortDecimalDecimal() @Test public void testMinBooleanVarchar() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("min_by", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.BOOLEAN))); + InternalAggregationFunction function = getMinByAggregation(VARCHAR, BOOLEAN); assertAggregation( function, "b", @@ -353,7 +347,7 @@ public void testMinBooleanVarchar() @Test public void testMaxBooleanVarchar() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("max_by", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.BOOLEAN))); + InternalAggregationFunction function = getMaxByAggregation(VARCHAR, BOOLEAN); assertAggregation( function, "c", @@ -364,7 +358,7 @@ public void testMaxBooleanVarchar() @Test public void testMinIntegerVarchar() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("min_by", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.INTEGER))); + InternalAggregationFunction function = getMinByAggregation(VARCHAR, INTEGER); assertAggregation( function, "a", @@ -375,7 +369,7 @@ public void testMinIntegerVarchar() @Test public void testMaxIntegerVarchar() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("max_by", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.INTEGER))); + InternalAggregationFunction function = getMaxByAggregation(VARCHAR, INTEGER); assertAggregation( function, "c", @@ -386,7 +380,7 @@ public void testMaxIntegerVarchar() @Test public void testMinBooleanLongArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("min_by", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.BOOLEAN))); + InternalAggregationFunction function = getMinByAggregation(new ArrayType(BIGINT), BOOLEAN); assertAggregation( function, null, @@ -397,7 +391,7 @@ public void testMinBooleanLongArray() @Test public void testMaxBooleanLongArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("max_by", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.BOOLEAN))); + InternalAggregationFunction function = getMaxByAggregation(new ArrayType(BIGINT), BOOLEAN); assertAggregation( function, asList(2L, 2L), @@ -408,7 +402,7 @@ public void testMaxBooleanLongArray() @Test public void testMinLongVarchar() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("min_by", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMinByAggregation(VARCHAR, BIGINT); assertAggregation( function, "a", @@ -419,7 +413,7 @@ public void testMinLongVarchar() @Test public void testMaxLongVarchar() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation(new Signature("max_by", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMaxByAggregation(VARCHAR, BIGINT); assertAggregation( function, "c", @@ -430,8 +424,7 @@ public void testMaxLongVarchar() @Test public void testMinDoubleLongArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction function = getMinByAggregation(new ArrayType(BIGINT), DOUBLE); assertAggregation( function, asList(3L, 4L), @@ -448,8 +441,7 @@ public void testMinDoubleLongArray() @Test public void testMaxDoubleLongArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.DOUBLE))); + InternalAggregationFunction function = getMaxByAggregation(new ArrayType(BIGINT), DOUBLE); assertAggregation( function, null, @@ -466,8 +458,7 @@ public void testMaxDoubleLongArray() @Test public void testMinSliceLongArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction function = getMinByAggregation(new ArrayType(BIGINT), VARCHAR); assertAggregation( function, asList(3L, 4L), @@ -484,8 +475,7 @@ public void testMinSliceLongArray() @Test public void testMaxSliceLongArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"), parseTypeSignature(StandardTypes.VARCHAR))); + InternalAggregationFunction function = getMaxByAggregation(new ArrayType(BIGINT), VARCHAR); assertAggregation( function, asList(2L, 2L), @@ -502,8 +492,7 @@ public void testMaxSliceLongArray() @Test public void testMinLongArrayLongArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"))); + InternalAggregationFunction function = getMinByAggregation(new ArrayType(BIGINT), new ArrayType(BIGINT)); assertAggregation( function, asList(1L, 2L), @@ -514,8 +503,7 @@ public void testMinLongArrayLongArray() @Test public void testMaxLongArrayLongArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"))); + InternalAggregationFunction function = getMaxByAggregation(new ArrayType(BIGINT), new ArrayType(BIGINT)); assertAggregation( function, asList(3L, 3L), @@ -526,8 +514,7 @@ public void testMaxLongArrayLongArray() @Test public void testMinLongArraySlice() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature("array(bigint)"))); + InternalAggregationFunction function = getMinByAggregation(VARCHAR, new ArrayType(BIGINT)); assertAggregation( function, "c", @@ -538,8 +525,7 @@ public void testMinLongArraySlice() @Test public void testMaxLongArraySlice() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature("array(bigint)"))); + InternalAggregationFunction function = getMaxByAggregation(VARCHAR, new ArrayType(BIGINT)); assertAggregation( function, "a", @@ -550,8 +536,7 @@ public void testMaxLongArraySlice() @Test public void testMinUnknownSlice() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(UnknownType.NAME))); + InternalAggregationFunction function = getMinByAggregation(VARCHAR, UNKNOWN); assertAggregation( function, null, @@ -562,8 +547,7 @@ public void testMinUnknownSlice() @Test public void testMaxUnknownSlice() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(UnknownType.NAME))); + InternalAggregationFunction function = getMaxByAggregation(VARCHAR, UNKNOWN); assertAggregation( function, null, @@ -574,8 +558,7 @@ public void testMaxUnknownSlice() @Test public void testMinUnknownLongArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"), parseTypeSignature(UnknownType.NAME))); + InternalAggregationFunction function = getMinByAggregation(new ArrayType(BIGINT), UNKNOWN); assertAggregation( function, null, @@ -586,12 +569,21 @@ public void testMinUnknownLongArray() @Test public void testMaxUnknownLongArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", AGGREGATE, parseTypeSignature("array(bigint)"), parseTypeSignature("array(bigint)"), parseTypeSignature(UnknownType.NAME))); + InternalAggregationFunction function = getMaxByAggregation(new ArrayType(BIGINT), UNKNOWN); assertAggregation( function, null, createArrayBigintBlock(asList(asList(3L, 3L), null, asList(1L, 2L))), createArrayBigintBlock(asList(null, null, null))); } + + private InternalAggregationFunction getMinByAggregation(Type... arguments) + { + return functionManager.getAggregateFunctionImplementation(functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("min_by"), fromTypes(arguments))); + } + + private InternalAggregationFunction getMaxByAggregation(Type... arguments) + { + return functionManager.getAggregateFunctionImplementation(functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("max_by"), fromTypes(arguments))); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/minmaxby/TestMinMaxByNAggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/minmaxby/TestMinMaxByNAggregation.java index a79e9896d0dd3..74a969fb71fd2 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/minmaxby/TestMinMaxByNAggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/minmaxby/TestMinMaxByNAggregation.java @@ -13,42 +13,41 @@ */ package com.facebook.presto.operator.aggregation.minmaxby; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.spi.type.ArrayType; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; import java.util.Arrays; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createArrayBigintBlock; import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static com.facebook.presto.block.BlockAssertions.createRLEBlock; import static com.facebook.presto.block.BlockAssertions.createStringsBlock; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.groupedAggregation; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.DoubleType.DOUBLE; +import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static org.testng.Assert.assertEquals; public class TestMinMaxByNAggregation { - private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager(); + private static final FunctionManager functionManager = MetadataManager.createTestMetadataManager().getFunctionManager(); @Test public void testMaxDoubleDouble() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", - AGGREGATE, - parseTypeSignature("array(double)"), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMaxByAggregation(DOUBLE, DOUBLE, BIGINT); assertAggregation( function, Arrays.asList((Double) null), @@ -102,13 +101,7 @@ public void testMaxDoubleDouble() @Test public void testMinDoubleDouble() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", - AGGREGATE, - parseTypeSignature("array(double)"), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMinByAggregation(DOUBLE, DOUBLE, BIGINT); assertAggregation( function, Arrays.asList((Double) null), @@ -141,13 +134,7 @@ public void testMinDoubleDouble() @Test public void testMinDoubleVarchar() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", - AGGREGATE, - parseTypeSignature("array(varchar)"), - parseTypeSignature(StandardTypes.VARCHAR), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMinByAggregation(VARCHAR, DOUBLE, BIGINT); assertAggregation( function, ImmutableList.of("z", "a"), @@ -173,13 +160,7 @@ public void testMinDoubleVarchar() @Test public void testMaxDoubleVarchar() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", - AGGREGATE, - parseTypeSignature("array(varchar)"), - parseTypeSignature(StandardTypes.VARCHAR), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMaxByAggregation(VARCHAR, DOUBLE, BIGINT); assertAggregation( function, ImmutableList.of("a", "z"), @@ -205,13 +186,7 @@ public void testMaxDoubleVarchar() @Test public void testMinVarcharDouble() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", - AGGREGATE, - parseTypeSignature("array(double)"), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.VARCHAR), - parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMinByAggregation(DOUBLE, VARCHAR, BIGINT); assertAggregation( function, ImmutableList.of(2.0, 3.0), @@ -237,13 +212,7 @@ public void testMinVarcharDouble() @Test public void testMaxVarcharDouble() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", - AGGREGATE, - parseTypeSignature("array(double)"), - parseTypeSignature(StandardTypes.DOUBLE), - parseTypeSignature(StandardTypes.VARCHAR), - parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMaxByAggregation(DOUBLE, VARCHAR, BIGINT); assertAggregation( function, ImmutableList.of(1.0, 2.0), @@ -269,13 +238,7 @@ public void testMaxVarcharDouble() @Test public void testMinVarcharArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", - AGGREGATE, - parseTypeSignature("array(array(bigint))"), - parseTypeSignature("array(bigint)"), - parseTypeSignature(StandardTypes.VARCHAR), - parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMinByAggregation(new ArrayType(BIGINT), VARCHAR, BIGINT); assertAggregation( function, ImmutableList.of(ImmutableList.of(2L, 3L), ImmutableList.of(4L, 5L)), @@ -287,13 +250,7 @@ public void testMinVarcharArray() @Test public void testMaxVarcharArray() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", - AGGREGATE, - parseTypeSignature("array(array(bigint))"), - parseTypeSignature("array(bigint)"), - parseTypeSignature(StandardTypes.VARCHAR), - parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMaxByAggregation(new ArrayType(BIGINT), VARCHAR, BIGINT); assertAggregation( function, ImmutableList.of(ImmutableList.of(1L, 2L), ImmutableList.of(3L, 4L)), @@ -305,13 +262,7 @@ public void testMaxVarcharArray() @Test public void testMinArrayVarchar() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("min_by", - AGGREGATE, - parseTypeSignature("array(varchar)"), - parseTypeSignature(StandardTypes.VARCHAR), - parseTypeSignature("array(bigint)"), - parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMinByAggregation(VARCHAR, new ArrayType(BIGINT), BIGINT); assertAggregation( function, ImmutableList.of("b", "x", "z"), @@ -323,13 +274,7 @@ public void testMinArrayVarchar() @Test public void testMaxArrayVarchar() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", - AGGREGATE, - parseTypeSignature("array(varchar)"), - parseTypeSignature(StandardTypes.VARCHAR), - parseTypeSignature("array(bigint)"), - parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMaxByAggregation(VARCHAR, new ArrayType(BIGINT), BIGINT); assertAggregation( function, ImmutableList.of("a", "z", "x"), @@ -341,13 +286,7 @@ public void testMaxArrayVarchar() @Test public void testOutOfBound() { - InternalAggregationFunction function = METADATA.getFunctionManager().getAggregateFunctionImplementation( - new Signature("max_by", - AGGREGATE, - parseTypeSignature("array(varchar)"), - parseTypeSignature(StandardTypes.VARCHAR), - parseTypeSignature(StandardTypes.BIGINT), - parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction function = getMaxByAggregation(VARCHAR, BIGINT, BIGINT); try { groupedAggregation(function, new Page(createStringsBlock("z"), createLongsBlock(0), createLongsBlock(10001))); } @@ -355,4 +294,16 @@ public void testOutOfBound() assertEquals(e.getMessage(), "third argument of max_by/min_by must be less than or equal to 10000; found 10001"); } } + + private InternalAggregationFunction getMaxByAggregation(Type... arguments) + { + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("max_by"), fromTypes(arguments))); + } + + private InternalAggregationFunction getMinByAggregation(Type... arguments) + { + return functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("min_by"), fromTypes(arguments))); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java index 206f0ee482448..51008d0540ab5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestEffectivePredicateExtractor.java @@ -14,17 +14,15 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.connector.ConnectorId; -import com.facebook.presto.metadata.FunctionKind; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.metadata.TableLayoutHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.block.SortOrder; import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.spi.predicate.TupleDomain; -import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.planner.plan.AggregationNode; import com.facebook.presto.sql.planner.plan.AggregationNode.Aggregation; import com.facebook.presto.sql.planner.plan.Assignments; @@ -53,7 +51,6 @@ import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; import com.facebook.presto.testing.TestingMetadata.TestingTableHandle; import com.facebook.presto.testing.TestingTransactionHandle; -import com.facebook.presto.type.UnknownType; import com.google.common.base.Preconditions; import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; @@ -75,7 +72,7 @@ import java.util.Set; import java.util.UUID; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.sql.ExpressionUtils.and; import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; @@ -142,6 +139,8 @@ public void setUp() @Test public void testAggregation() { + FunctionCall functionCall = new FunctionCall(QualifiedName.of("count"), ImmutableList.of()); + FunctionHandle functionHandle = metadata.getFunctionManager().resolveFunction(TEST_SESSION, QualifiedName.of("count"), ImmutableList.of()); PlanNode node = new AggregationNode(newId(), filter(baseTableScan, and( @@ -153,8 +152,8 @@ public void testAggregation() greaterThan(AE, bigintLiteral(2)), equals(EE, FE))), ImmutableMap.of( - C, new Aggregation(fakeFunction(), fakeFunctionHandle("test", AGGREGATE), Optional.empty()), - D, new Aggregation(fakeFunction(), fakeFunctionHandle("test", AGGREGATE), Optional.empty())), + C, new Aggregation(functionCall, functionHandle, Optional.empty()), + D, new Aggregation(functionCall, functionHandle, Optional.empty())), singleGroupingSet(ImmutableList.of(A, B, C)), ImmutableList.of(), AggregationNode.Step.FINAL, @@ -758,16 +757,6 @@ private static IsNullPredicate isNull(Expression expression) return new IsNullPredicate(expression); } - private static FunctionCall fakeFunction() - { - return new FunctionCall(QualifiedName.of("test"), ImmutableList.of()); - } - - private static Signature fakeFunctionHandle(String name, FunctionKind kind) - { - return new Signature(name, kind, TypeSignature.parseTypeSignature(UnknownType.NAME), ImmutableList.of()); - } - private Set normalizeConjuncts(Expression... conjuncts) { return normalizeConjuncts(Arrays.asList(conjuncts)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java index ee592f558b0da..40312b39f0a35 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestTypeValidator.java @@ -16,9 +16,7 @@ import com.facebook.presto.connector.ConnectorId; import com.facebook.presto.execution.warnings.WarningCollector; import com.facebook.presto.metadata.FunctionHandle; -import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.predicate.TupleDomain; @@ -187,14 +185,7 @@ public void testValidAggregation() baseTableScan, ImmutableMap.of(aggregationSymbol, new Aggregation( new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnC.toSymbolReference())), - new Signature( - "sum", - FunctionKind.AGGREGATE, - ImmutableList.of(), - ImmutableList.of(), - DOUBLE.getTypeSignature(), - ImmutableList.of(DOUBLE.getTypeSignature()), - false), + FUNCTION_MANAGER.resolveFunction(TEST_SESSION, QualifiedName.of("sum"), fromTypes(DOUBLE)), Optional.empty())), singleGroupingSet(ImmutableList.of(columnA, columnB)), ImmutableList.of(), @@ -245,14 +236,7 @@ public void testInvalidAggregationFunctionCall() baseTableScan, ImmutableMap.of(aggregationSymbol, new Aggregation( new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnA.toSymbolReference())), - new Signature( - "sum", - FunctionKind.AGGREGATE, - ImmutableList.of(), - ImmutableList.of(), - DOUBLE.getTypeSignature(), - ImmutableList.of(DOUBLE.getTypeSignature()), - false), + FUNCTION_MANAGER.resolveFunction(TEST_SESSION, QualifiedName.of("sum"), fromTypes(DOUBLE)), Optional.empty())), singleGroupingSet(ImmutableList.of(columnA, columnB)), ImmutableList.of(), @@ -273,14 +257,7 @@ public void testInvalidAggregationFunctionSignature() baseTableScan, ImmutableMap.of(aggregationSymbol, new Aggregation( new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnC.toSymbolReference())), - new Signature( - "sum", - FunctionKind.AGGREGATE, - ImmutableList.of(), - ImmutableList.of(), - BIGINT.getTypeSignature(), // should be DOUBLE - ImmutableList.of(DOUBLE.getTypeSignature()), - false), + FUNCTION_MANAGER.resolveFunction(TEST_SESSION, QualifiedName.of("sum"), fromTypes(BIGINT)), // should be DOUBLE Optional.empty())), singleGroupingSet(ImmutableList.of(columnA, columnB)), ImmutableList.of(), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index fe0f890895790..b21a0abecd149 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -13,10 +13,11 @@ */ package com.facebook.presto.sql.planner.iterative.rule.test; +import com.facebook.presto.Session; import com.facebook.presto.connector.ConnectorId; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.IndexHandle; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.metadata.TableLayoutHandle; import com.facebook.presto.spi.ColumnHandle; @@ -94,6 +95,7 @@ import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.util.MoreLists.nElements; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -255,6 +257,7 @@ public class AggregationBuilder private Step step = Step.SINGLE; private Optional hashSymbol = Optional.empty(); private Optional groupIdSymbol = Optional.empty(); + private Session session = testSessionBuilder().build(); public AggregationBuilder source(PlanNode source) { @@ -276,8 +279,8 @@ private AggregationBuilder addAggregation(Symbol output, Expression expression, { checkArgument(expression instanceof FunctionCall); FunctionCall aggregation = (FunctionCall) expression; - Signature signature = metadata.getFunctionManager().resolveFunction(aggregation.getName(), TypeSignatureProvider.fromTypes(inputTypes)); - return addAggregation(output, new Aggregation(aggregation, signature, mask)); + FunctionHandle functionHandle = metadata.getFunctionManager().resolveFunction(session, aggregation.getName(), TypeSignatureProvider.fromTypes(inputTypes)); + return addAggregation(output, new Aggregation(aggregation, functionHandle, mask)); } public AggregationBuilder addAggregation(Symbol output, Aggregation aggregation) diff --git a/presto-ml/pom.xml b/presto-ml/pom.xml index 893b97ef795cc..369982ab9be4f 100644 --- a/presto-ml/pom.xml +++ b/presto-ml/pom.xml @@ -21,6 +21,11 @@ presto-array + + com.facebook.presto + presto-parser + + com.facebook.thirdparty libsvm diff --git a/presto-ml/src/test/java/com/facebook/presto/ml/TestEvaluateClassifierPredictions.java b/presto-ml/src/test/java/com/facebook/presto/ml/TestEvaluateClassifierPredictions.java index 5438945549a64..742083a63b139 100644 --- a/presto-ml/src/test/java/com/facebook/presto/ml/TestEvaluateClassifierPredictions.java +++ b/presto-ml/src/test/java/com/facebook/presto/ml/TestEvaluateClassifierPredictions.java @@ -14,15 +14,14 @@ package com.facebook.presto.ml; import com.facebook.presto.RowPageBuilder; -import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.aggregation.Accumulator; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; -import com.facebook.presto.spi.type.StandardTypes; +import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -30,25 +29,24 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.metadata.FunctionExtractor.extractFunctions; -import static com.facebook.presto.metadata.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.type.BigintType.BIGINT; -import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static org.testng.Assert.assertEquals; public class TestEvaluateClassifierPredictions { - private final Metadata metadata = MetadataManager.createTestMetadataManager(); + private final MetadataManager metadata = MetadataManager.createTestMetadataManager(); + private final FunctionManager functionManager = metadata.getFunctionManager(); @Test public void testEvaluateClassifierPredictions() { metadata.addFunctions(extractFunctions(new MLPlugin().getFunctions())); - InternalAggregationFunction aggregation = metadata.getFunctionManager().getAggregateFunctionImplementation( - new Signature("evaluate_classifier_predictions", - AGGREGATE, - parseTypeSignature(StandardTypes.VARCHAR), parseTypeSignature(StandardTypes.BIGINT), parseTypeSignature(StandardTypes.BIGINT))); + InternalAggregationFunction aggregation = functionManager.getAggregateFunctionImplementation( + functionManager.resolveFunction(TEST_SESSION, QualifiedName.of("evaluate_classifier_predictions"), fromTypes(BIGINT, BIGINT))); Accumulator accumulator = aggregation.bind(ImmutableList.of(0, 1), Optional.empty()).createAccumulator(); accumulator.addInput(getPage()); BlockBuilder finalOut = accumulator.getFinalType().createBlockBuilder(null, 1); From f316414395b0c7d909089911141216ab4344c333 Mon Sep 17 00:00:00 2001 From: rongrong Date: Tue, 26 Feb 2019 17:56:00 -0800 Subject: [PATCH 6/8] Switch cast to FunctionHandle --- .../com/facebook/presto/cost/StatsUtil.java | 7 ++--- .../presto/metadata/FunctionManager.java | 14 +++++----- .../presto/metadata/FunctionNamespace.java | 4 +-- .../presto/metadata/FunctionRegistry.java | 14 ++++------ .../presto/operator/scalar/ArrayJoin.java | 3 +-- .../operator/scalar/ArrayToArrayCast.java | 8 +++--- .../presto/operator/scalar/MapToMapCast.java | 2 +- .../presto/operator/scalar/RowToRowCast.java | 12 +++------ .../operator/scalar/TryCastFunction.java | 6 +++-- .../sql/InterpretedFunctionInvoker.java | 27 ++++++++++++++----- .../sql/analyzer/ExpressionAnalyzer.java | 5 ++-- .../presto/sql/gen/CastCodeGenerator.java | 8 +++--- .../presto/sql/gen/NullIfCodeGenerator.java | 8 +++--- .../presto/sql/planner/DomainTranslator.java | 17 +++++++----- .../sql/planner/ExpressionInterpreter.java | 8 +++--- .../sql/planner/LiteralInterpreter.java | 6 +++-- .../sql/planner/planPrinter/PlanPrinter.java | 7 ++--- .../optimizer/ExpressionOptimizer.java | 3 ++- .../presto/metadata/TestFunctionRegistry.java | 11 ++++---- .../presto/type/TestTypeRegistry.java | 9 +++++-- 20 files changed, 103 insertions(+), 76 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/cost/StatsUtil.java b/presto-main/src/main/java/com/facebook/presto/cost/StatsUtil.java index 113c8de96a880..1ab6678366300 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/StatsUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/StatsUtil.java @@ -14,9 +14,9 @@ package com.facebook.presto.cost; import com.facebook.presto.Session; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.BooleanType; @@ -32,6 +32,7 @@ import java.util.OptionalDouble; +import static com.facebook.presto.spi.function.OperatorType.CAST; import static java.util.Collections.singletonList; final class StatsUtil @@ -47,9 +48,9 @@ static OptionalDouble toStatsRepresentation(FunctionManager functionManager, Con { if (convertibleToDoubleWithCast(type)) { InterpretedFunctionInvoker functionInvoker = new InterpretedFunctionInvoker(functionManager); - Signature castSignature = functionManager.getCoercion(type, DoubleType.DOUBLE); + FunctionHandle cast = functionManager.lookupCast(CAST, type.getTypeSignature(), DoubleType.DOUBLE.getTypeSignature()); - return OptionalDouble.of((double) functionInvoker.invoke(castSignature, session, singletonList(value))); + return OptionalDouble.of((double) functionInvoker.invoke(cast, session, singletonList(value))); } if (DateType.DATE.equals(type)) { diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java index 906e4ada7b475..24bdad8532c6f 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java @@ -87,6 +87,11 @@ public InternalAggregationFunction getAggregateFunctionImplementation(FunctionHa return globalFunctionNamespace.getAggregateFunctionImplementation(functionHandle); } + public ScalarFunctionImplementation getScalarFunctionImplementation(FunctionHandle functionHandle) + { + return globalFunctionNamespace.getScalarFunctionImplementation(functionHandle.getSignature()); + } + public ScalarFunctionImplementation getScalarFunctionImplementation(Signature signature) { return globalFunctionNamespace.getScalarFunctionImplementation(signature); @@ -112,13 +117,8 @@ public boolean isRegistered(Signature signature) return globalFunctionNamespace.isRegistered(signature); } - public Signature getCoercion(Type fromType, Type toType) - { - return getCoercion(fromType.getTypeSignature(), toType.getTypeSignature()); - } - - public Signature getCoercion(TypeSignature fromType, TypeSignature toType) + public FunctionHandle lookupCast(OperatorType castType, TypeSignature fromType, TypeSignature toType) { - return globalFunctionNamespace.getCoercion(fromType, toType); + return globalFunctionNamespace.lookupCast(castType, fromType, toType); } } diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java index c2651bb79771f..92b9ff6adc4ef 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java @@ -84,9 +84,9 @@ public FunctionHandle resolveOperator(OperatorType operatorType, List> toTypes(List typeSignatureProviders, TypeManager typeManager) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayJoin.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayJoin.java index 9b0cf185f4814..d826153aa31f9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayJoin.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayJoin.java @@ -40,7 +40,6 @@ import java.util.Map; import java.util.Optional; -import static com.facebook.presto.metadata.Signature.internalOperator; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; @@ -188,7 +187,7 @@ private static ScalarFunctionImplementation specializeArrayJoin(Map elementType = type.getJavaType(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayToArrayCast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayToArrayCast.java index f257d0658fd35..c2cc5453583e9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayToArrayCast.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayToArrayCast.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlOperator; @@ -36,7 +37,6 @@ import java.lang.invoke.MethodHandle; -import static com.facebook.presto.metadata.Signature.internalOperator; import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; @@ -75,9 +75,9 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in Type fromType = boundVariables.getTypeVariable("F"); Type toType = boundVariables.getTypeVariable("T"); - Signature signature = internalOperator(CAST.name(), toType.getTypeSignature(), ImmutableList.of(fromType.getTypeSignature())); - ScalarFunctionImplementation function = functionManager.getScalarFunctionImplementation(signature); - Class castOperatorClass = generateArrayCast(typeManager, signature, function); + FunctionHandle functionHandle = functionManager.lookupCast(CAST, fromType.getTypeSignature(), toType.getTypeSignature()); + ScalarFunctionImplementation function = functionManager.getScalarFunctionImplementation(functionHandle); + Class castOperatorClass = generateArrayCast(typeManager, functionHandle.getSignature(), function); MethodHandle methodHandle = methodHandle(castOperatorClass, "castArray", ConnectorSession.class, Block.class); return new ScalarFunctionImplementation( false, diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToMapCast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToMapCast.java index 047b984d5aec4..f5573e43f6779 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToMapCast.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapToMapCast.java @@ -106,7 +106,7 @@ private MethodHandle buildProcessor(FunctionManager functionManager, Type fromTy MethodHandle getter = nativeValueGetter(fromType); // Adapt cast that takes ([ConnectorSession,] ?) to one that takes (?, ConnectorSession), where ? is the return type of getter. - ScalarFunctionImplementation castImplementation = functionManager.getScalarFunctionImplementation(functionManager.getCoercion(fromType, toType)); + ScalarFunctionImplementation castImplementation = functionManager.getScalarFunctionImplementation(functionManager.lookupCast(CAST, fromType.getTypeSignature(), toType.getTypeSignature())); MethodHandle cast = castImplementation.getMethodHandle(); if (cast.type().parameterArray()[0] != ConnectorSession.class) { cast = MethodHandles.dropArguments(cast, 0, ConnectorSession.class); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowToRowCast.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowToRowCast.java index 6da191d663092..4d60b92412987 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowToRowCast.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowToRowCast.java @@ -14,8 +14,8 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlOperator; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; @@ -43,7 +43,6 @@ import java.lang.invoke.MethodHandle; import java.util.List; -import static com.facebook.presto.metadata.Signature.internalOperator; import static com.facebook.presto.metadata.Signature.withVariadicBound; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; @@ -142,18 +141,15 @@ private static Class generateRowCast(Type fromType, Type toType, FunctionMana // loop through to append member blocks for (int i = 0; i < toTypes.size(); i++) { - Signature signature = internalOperator( - CAST.name(), - toTypes.get(i).getTypeSignature(), - ImmutableList.of(fromTypes.get(i).getTypeSignature())); - ScalarFunctionImplementation function = functionManager.getScalarFunctionImplementation(signature); + FunctionHandle functionHandle = functionManager.lookupCast(CAST, fromTypes.get(i).getTypeSignature(), toTypes.get(i).getTypeSignature()); + ScalarFunctionImplementation function = functionManager.getScalarFunctionImplementation(functionHandle); Type currentFromType = fromTypes.get(i); if (currentFromType.equals(UNKNOWN)) { body.append(singleRowBlockWriter.invoke("appendNull", BlockBuilder.class).pop()); continue; } BytecodeExpression fromElement = constantType(binder, currentFromType).getValue(value, constantInt(i)); - BytecodeExpression toElement = invokeFunction(scope, cachedInstanceBinder, signature.getName(), function, fromElement); + BytecodeExpression toElement = invokeFunction(scope, cachedInstanceBinder, CAST.name(), function, fromElement); IfStatement ifElementNull = new IfStatement("if the element in the row type is null..."); ifElementNull.condition(value.invoke("isNull", boolean.class, constantInt(i))) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/TryCastFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/TryCastFunction.java index cdcf07bb29d53..a20faaa8eade2 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/TryCastFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/TryCastFunction.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionKind; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.Signature; @@ -28,6 +29,7 @@ import java.util.List; import static com.facebook.presto.metadata.Signature.typeVariable; +import static com.facebook.presto.spi.function.OperatorType.CAST; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static java.lang.invoke.MethodHandles.catchException; import static java.lang.invoke.MethodHandles.constant; @@ -80,8 +82,8 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in MethodHandle tryCastHandle; // the resulting method needs to return a boxed type - Signature signature = functionManager.getCoercion(fromType, toType); - ScalarFunctionImplementation implementation = functionManager.getScalarFunctionImplementation(signature); + FunctionHandle functionHandle = functionManager.lookupCast(CAST, fromType.getTypeSignature(), toType.getTypeSignature()); + ScalarFunctionImplementation implementation = functionManager.getScalarFunctionImplementation(functionHandle); argumentProperties = ImmutableList.of(implementation.getArgumentProperty(0)); MethodHandle coercion = implementation.getMethodHandle(); coercion = coercion.asType(methodType(returnType, coercion.type())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/InterpretedFunctionInvoker.java b/presto-main/src/main/java/com/facebook/presto/sql/InterpretedFunctionInvoker.java index 7160916be943a..29849175c791a 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/InterpretedFunctionInvoker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/InterpretedFunctionInvoker.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; @@ -40,23 +41,37 @@ public InterpretedFunctionInvoker(FunctionManager functionManager) this.functionManager = requireNonNull(functionManager, "registry is null"); } + public Object invoke(FunctionHandle functionHandle, ConnectorSession session, Object... arguments) + { + return invoke(functionHandle, session, Arrays.asList(arguments)); + } + + public Object invoke(FunctionHandle functionHandle, ConnectorSession session, List arguments) + { + return invoke(functionManager.getScalarFunctionImplementation(functionHandle), session, arguments); + } + public Object invoke(Signature function, ConnectorSession session, Object... arguments) { return invoke(function, session, Arrays.asList(arguments)); } + public Object invoke(Signature function, ConnectorSession session, List arguments) + { + return invoke(functionManager.getScalarFunctionImplementation(function), session, arguments); + } + /** * Arguments must be the native container type for the corresponding SQL types. *

* Returns a value in the native container type corresponding to the declared SQL return type */ - public Object invoke(Signature function, ConnectorSession session, List arguments) + private Object invoke(ScalarFunctionImplementation function, ConnectorSession session, List arguments) { - ScalarFunctionImplementation implementation = functionManager.getScalarFunctionImplementation(function); - MethodHandle method = implementation.getMethodHandle(); + MethodHandle method = function.getMethodHandle(); // handle function on instance method, to allow use of fields - method = bindInstanceFactory(method, implementation); + method = bindInstanceFactory(method, function); if (method.type().parameterCount() > 0 && method.type().parameterType(0) == ConnectorSession.class) { method = method.bindTo(session); @@ -64,9 +79,9 @@ public Object invoke(Signature function, ConnectorSession session, List List actualArguments = new ArrayList<>(); for (int i = 0; i < arguments.size(); i++) { Object argument = arguments.get(i); - ArgumentProperty argumentProperty = implementation.getArgumentProperty(i); + ArgumentProperty argumentProperty = function.getArgumentProperty(i); if (argumentProperty.getArgumentType() == VALUE_TYPE) { - if (implementation.getArgumentProperty(i).getNullConvention() == USE_NULL_FLAG) { + if (function.getArgumentProperty(i).getNullConvention() == USE_NULL_FLAG) { boolean isNull = argument == null; if (isNull) { argument = Defaults.defaultValue(method.type().parameterType(actualArguments.size())); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index 0103c85dd7e83..e46a9dff7f040 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -118,6 +118,7 @@ import java.util.function.Function; import static com.facebook.presto.SystemSessionProperties.isLegacyRowFieldOrdinalAccessEnabled; +import static com.facebook.presto.spi.function.OperatorType.CAST; import static com.facebook.presto.spi.function.OperatorType.SUBSCRIPT; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; @@ -724,7 +725,7 @@ protected Type visitGenericLiteral(GenericLiteral node, StackableAstVisitorConte if (!JSON.equals(type)) { try { - functionManager.getCoercion(VARCHAR, type); + functionManager.lookupCast(CAST, VARCHAR.getTypeSignature(), type.getTypeSignature()); } catch (IllegalArgumentException e) { throw new SemanticException(TYPE_MISMATCH, node, "No literal form for type %s", type); @@ -1005,7 +1006,7 @@ public Type visitCast(Cast node, StackableAstVisitorContext context) Type value = process(node.getExpression(), context); if (!value.equals(UNKNOWN) && !node.isTypeOnly()) { try { - functionManager.getCoercion(value, type); + functionManager.lookupCast(CAST, value.getTypeSignature(), type.getTypeSignature()); } catch (OperatorNotFoundException e) { throw new SemanticException(TYPE_MISMATCH, node, "Cannot cast %s to %s", value, type); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java index 530f034795016..71d0a02ad49dc 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/CastCodeGenerator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.gen; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.relational.RowExpression; @@ -24,6 +25,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.spi.function.OperatorType.CAST; import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; public class CastCodeGenerator @@ -34,13 +36,13 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon { RowExpression argument = arguments.get(0); - Signature function = generatorContext + FunctionHandle function = generatorContext .getFunctionManager() - .getCoercion(argument.getType(), returnType); + .lookupCast(CAST, argument.getType().getTypeSignature(), returnType.getTypeSignature()); BytecodeBlock block = new BytecodeBlock() .append(generatorContext.generateCall( - function.getName(), + CAST.name(), generatorContext.getFunctionManager().getScalarFunctionImplementation(function), ImmutableList.of(generatorContext.generate(argument, Optional.empty())))); outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output))); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java index 6232b277c8f90..f22fc1b32784e 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.gen; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.spi.function.OperatorType; @@ -30,6 +31,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.spi.function.OperatorType.CAST; import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static com.facebook.presto.sql.gen.BytecodeUtils.ifWasNullPopAndGoto; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; @@ -99,11 +101,11 @@ private static BytecodeNode cast( return argument; } - Signature function = generatorContext + FunctionHandle functionHandle = generatorContext .getFunctionManager() - .getCoercion(actualType.getTypeSignature(), requiredType); + .lookupCast(CAST, actualType.getTypeSignature(), requiredType); // TODO: do we need a full function call? (nullability checks, etc) - return generatorContext.generateCall(function.getName(), generatorContext.getFunctionManager().getScalarFunctionImplementation(function), ImmutableList.of(argument)); + return generatorContext.generateCall(CAST.name(), generatorContext.getFunctionManager().getScalarFunctionImplementation(functionHandle), ImmutableList.of(argument)); } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java index c716166192ec4..e64a9d64fd807 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java @@ -15,8 +15,9 @@ import com.facebook.presto.Session; import com.facebook.presto.execution.warnings.WarningCollector; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.Signature; +import com.facebook.presto.metadata.OperatorNotFoundException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.predicate.DiscreteValues; import com.facebook.presto.spi.predicate.Domain; @@ -59,7 +60,7 @@ import java.util.Map; import java.util.Optional; -import static com.facebook.presto.metadata.Signature.internalOperator; +import static com.facebook.presto.spi.function.OperatorType.CAST; import static com.facebook.presto.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static com.facebook.presto.sql.ExpressionUtils.and; import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; @@ -659,17 +660,19 @@ private Optional floorValue(Type fromType, Type toType, Object value) .map((operator) -> functionInvoker.invoke(operator, session.toConnectorSession(), value)); } - private Optional getSaturatedFloorCastOperator(Type fromType, Type toType) + private Optional getSaturatedFloorCastOperator(Type fromType, Type toType) { - if (metadata.getFunctionManager().canResolveOperator(SATURATED_FLOOR_CAST, toType, ImmutableList.of(fromType))) { - return Optional.of(internalOperator(SATURATED_FLOOR_CAST, toType, ImmutableList.of(fromType))); + try { + return Optional.of(metadata.getFunctionManager().lookupCast(SATURATED_FLOOR_CAST, fromType.getTypeSignature(), toType.getTypeSignature())); + } + catch (OperatorNotFoundException e) { + return Optional.empty(); } - return Optional.empty(); } private int compareOriginalValueToCoerced(Type originalValueType, Object originalValue, Type coercedValueType, Object coercedValue) { - Signature castToOriginalTypeOperator = metadata.getFunctionManager().getCoercion(coercedValueType, originalValueType); + FunctionHandle castToOriginalTypeOperator = metadata.getFunctionManager().lookupCast(CAST, coercedValueType.getTypeSignature(), originalValueType.getTypeSignature()); Object coercedValueInOriginalType = functionInvoker.invoke(castToOriginalTypeOperator, session.toConnectorSession(), coercedValue); Block originalValueBlock = Utils.nativeValueToBlock(originalValueType, originalValue); Block coercedValueBlock = Utils.nativeValueToBlock(originalValueType, coercedValueInOriginalType); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java index 9b4c18da8cc66..57bd1cff983ed 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.client.FailureInfo; import com.facebook.presto.execution.warnings.WarningCollector; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.scalar.ArraySubscriptOperator; @@ -115,6 +116,7 @@ import static com.facebook.presto.SystemSessionProperties.isLegacyRowFieldOrdinalAccessEnabled; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.spi.function.OperatorType.CAST; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.TypeUtils.writeNativeValue; import static com.facebook.presto.spi.type.VarcharType.createVarcharType; @@ -822,8 +824,8 @@ protected Object visitNullIfExpression(NullIfExpression node, Object context) Type commonType = metadata.getTypeManager().getCommonSuperType(firstType, secondType).get(); - Signature firstCast = metadata.getFunctionManager().getCoercion(firstType, commonType); - Signature secondCast = metadata.getFunctionManager().getCoercion(secondType, commonType); + FunctionHandle firstCast = metadata.getFunctionManager().lookupCast(CAST, firstType.getTypeSignature(), commonType.getTypeSignature()); + FunctionHandle secondCast = metadata.getFunctionManager().lookupCast(CAST, secondType.getTypeSignature(), commonType.getTypeSignature()); // cast(first as ) == cast(second as ) boolean equal = Boolean.TRUE.equals(invokeOperator( @@ -1130,7 +1132,7 @@ public Object visitCast(Cast node, Object context) return null; } - Signature operator = metadata.getFunctionManager().getCoercion(sourceType, targetType); + FunctionHandle operator = metadata.getFunctionManager().lookupCast(CAST, sourceType.getTypeSignature(), targetType.getTypeSignature()); try { return functionInvoker.invoke(operator, session, ImmutableList.of(value)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LiteralInterpreter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LiteralInterpreter.java index 13b8324bcd2be..fbbe9b276d2d1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LiteralInterpreter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LiteralInterpreter.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.Signature; import com.facebook.presto.spi.ConnectorSession; @@ -39,6 +40,7 @@ import io.airlift.slice.Slice; import static com.facebook.presto.metadata.FunctionKind.SCALAR; +import static com.facebook.presto.spi.function.OperatorType.CAST; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL; @@ -136,8 +138,8 @@ protected Object visitGenericLiteral(GenericLiteral node, ConnectorSession sessi } try { - Signature signature = metadata.getFunctionManager().getCoercion(VARCHAR, type); - return functionInvoker.invoke(signature, session, ImmutableList.of(utf8Slice(node.getValue()))); + FunctionHandle functionHandle = metadata.getFunctionManager().lookupCast(CAST, VARCHAR.getTypeSignature(), type.getTypeSignature()); + return functionInvoker.invoke(functionHandle, session, ImmutableList.of(utf8Slice(node.getValue()))); } catch (IllegalArgumentException e) { throw new SemanticException(TYPE_MISMATCH, node, "No literal form for type %s", type); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 09c257f6a4c26..8e9e163c6a2d0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -20,9 +20,9 @@ import com.facebook.presto.cost.StatsAndCosts; import com.facebook.presto.execution.StageInfo; import com.facebook.presto.execution.StageStats; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.OperatorNotFoundException; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.TableHandle; import com.facebook.presto.operator.StageExecutionDescriptor; import com.facebook.presto.spi.ColumnHandle; @@ -116,6 +116,7 @@ import java.util.stream.Stream; import static com.facebook.presto.execution.StageInfo.getAllStages; +import static com.facebook.presto.spi.function.OperatorType.CAST; import static com.facebook.presto.spi.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.sql.planner.planPrinter.PlanNodeStatsSummarizer.aggregateStageStats; @@ -1199,8 +1200,8 @@ private static String castToVarchar(Type type, Object value, FunctionManager fun } try { - Signature coercion = functionManager.getCoercion(type, VARCHAR); - Slice coerced = (Slice) new InterpretedFunctionInvoker(functionManager).invoke(coercion, session.toConnectorSession(), value); + FunctionHandle cast = functionManager.lookupCast(CAST, type.getTypeSignature(), VARCHAR.getTypeSignature()); + Slice coerced = (Slice) new InterpretedFunctionInvoker(functionManager).invoke(cast, session.toConnectorSession(), value); return coerced.toStringUtf8(); } catch (OperatorNotFoundException e) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/optimizer/ExpressionOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/optimizer/ExpressionOptimizer.java index 32b95c2b07221..c9a96b88ddc11 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/optimizer/ExpressionOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/optimizer/ExpressionOptimizer.java @@ -18,6 +18,7 @@ import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.relational.CallExpression; @@ -255,7 +256,7 @@ private CallExpression rewriteCast(CallExpression call) } return call( - functionManager.getCoercion(call.getArguments().get(0).getType(), call.getType()), + functionManager.lookupCast(OperatorType.CAST, call.getArguments().get(0).getType().getTypeSignature(), call.getType().getTypeSignature()).getSignature(), call.getType(), call.getArguments()); } diff --git a/presto-main/src/test/java/com/facebook/presto/metadata/TestFunctionRegistry.java b/presto-main/src/test/java/com/facebook/presto/metadata/TestFunctionRegistry.java index d6e3107ef362c..eaae26231b1d1 100644 --- a/presto-main/src/test/java/com/facebook/presto/metadata/TestFunctionRegistry.java +++ b/presto-main/src/test/java/com/facebook/presto/metadata/TestFunctionRegistry.java @@ -26,7 +26,6 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.type.TypeRegistry; -import com.google.common.base.Functions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.testng.annotations.Test; @@ -40,6 +39,8 @@ import static com.facebook.presto.metadata.Signature.typeVariable; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty.valueTypeArgumentProperty; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention.RETURN_NULL_ON_NULL; +import static com.facebook.presto.spi.function.OperatorType.CAST; +import static com.facebook.presto.spi.function.OperatorType.SATURATED_FLOOR_CAST; import static com.facebook.presto.spi.type.HyperLogLogType.HYPER_LOG_LOG; import static com.facebook.presto.spi.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature; @@ -63,10 +64,8 @@ public void testIdentityCast() { TypeRegistry typeManager = new TypeRegistry(); FunctionRegistry registry = createFunctionRegistry(typeManager); - Signature exactOperator = registry.getCoercion(HYPER_LOG_LOG, HYPER_LOG_LOG); - assertEquals(exactOperator.getName(), mangleOperatorName(OperatorType.CAST.name())); - assertEquals(transform(exactOperator.getArgumentTypes(), Functions.toStringFunction()), ImmutableList.of(StandardTypes.HYPER_LOG_LOG)); - assertEquals(exactOperator.getReturnType().getBase(), StandardTypes.HYPER_LOG_LOG); + FunctionHandle exactOperator = registry.lookupCast(CAST, HYPER_LOG_LOG.getTypeSignature(), HYPER_LOG_LOG.getTypeSignature()); + assertEquals(exactOperator, new FunctionHandle(new Signature(mangleOperatorName(CAST.name()), SCALAR, HYPER_LOG_LOG.getTypeSignature(), HYPER_LOG_LOG.getTypeSignature()))); } @Test @@ -77,7 +76,7 @@ public void testExactMatchBeforeCoercion() boolean foundOperator = false; for (SqlFunction function : registry.listOperators()) { OperatorType operatorType = unmangleOperator(function.getSignature().getName()); - if (operatorType == OperatorType.CAST || operatorType == OperatorType.SATURATED_FLOOR_CAST) { + if (operatorType == CAST || operatorType == SATURATED_FLOOR_CAST) { continue; } if (!function.getSignature().getTypeVariableConstraints().isEmpty()) { diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestTypeRegistry.java b/presto-main/src/test/java/com/facebook/presto/type/TestTypeRegistry.java index ae7ba0c14051b..96f6dfed87c22 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestTypeRegistry.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestTypeRegistry.java @@ -15,6 +15,7 @@ import com.facebook.presto.block.BlockEncodingManager; import com.facebook.presto.metadata.FunctionManager; +import com.facebook.presto.metadata.OperatorNotFoundException; import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; @@ -251,8 +252,12 @@ public void testCastOperatorsExistForCoercions() for (Type sourceType : types) { for (Type resultType : types) { if (typeRegistry.canCoerce(sourceType, resultType) && sourceType != UNKNOWN && resultType != UNKNOWN) { - assertTrue(functionManager.canResolveOperator(OperatorType.CAST, resultType, ImmutableList.of(sourceType)), - format("'%s' -> '%s' coercion exists but there is no cast operator", sourceType, resultType)); + try { + functionManager.lookupCast(OperatorType.CAST, sourceType.getTypeSignature(), resultType.getTypeSignature()); + } + catch (OperatorNotFoundException e) { + fail(format("'%s' -> '%s' coercion exists but there is no cast operator", sourceType, resultType)); + } } } } From 7363e311c81c4fe535be646974236e8f03ed0e48 Mon Sep 17 00:00:00 2001 From: rongrong Date: Sat, 2 Mar 2019 22:38:26 -0800 Subject: [PATCH 7/8] Remove FunctionManager canResolveOperator --- .../presto/metadata/FunctionManager.java | 5 ----- .../presto/metadata/FunctionNamespace.java | 5 ----- .../presto/metadata/FunctionRegistry.java | 6 ------ .../presto/metadata/MetadataManager.java | 21 ++++++++++++++----- 4 files changed, 16 insertions(+), 21 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java index 24bdad8532c6f..c11f5ee707962 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java @@ -102,11 +102,6 @@ public boolean isAggregationFunction(QualifiedName name) return globalFunctionNamespace.isAggregationFunction(name); } - public boolean canResolveOperator(OperatorType operatorType, Type returnType, List argumentTypes) - { - return globalFunctionNamespace.canResolveOperator(operatorType, returnType, argumentTypes); - } - public Signature resolveOperator(OperatorType operatorType, List argumentTypes) { return globalFunctionNamespace.resolveOperator(operatorType, argumentTypes).getSignature(); diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java index 92b9ff6adc4ef..f1aebdd9bd8af 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionNamespace.java @@ -73,11 +73,6 @@ public boolean isAggregationFunction(QualifiedName name) return registry.isAggregationFunction(name); } - public boolean canResolveOperator(OperatorType operatorType, Type returnType, List argumentTypes) - { - return registry.canResolveOperator(operatorType, returnType, argumentTypes); - } - public FunctionHandle resolveOperator(OperatorType operatorType, List argumentTypes) throws OperatorNotFoundException { diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java index 44d1d331b28e4..59b286c771ead 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionRegistry.java @@ -1053,12 +1053,6 @@ public List listOperators() .collect(toImmutableList()); } - public boolean canResolveOperator(OperatorType operatorType, Type returnType, List argumentTypes) - { - Signature signature = internalOperator(operatorType, returnType, argumentTypes); - return isRegistered(signature); - } - public boolean isRegistered(Signature signature) { try { diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java index c71d0c1019af2..b97d332d6f97e 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java @@ -220,23 +220,23 @@ public final void verifyComparableOrderableContract() Multimap missingOperators = HashMultimap.create(); for (Type type : typeManager.getTypes()) { if (type.isComparable()) { - if (!functions.canResolveOperator(HASH_CODE, BIGINT, ImmutableList.of(type))) { + if (!canResolveOperator(HASH_CODE, BIGINT, ImmutableList.of(type))) { missingOperators.put(type, HASH_CODE); } - if (!functions.canResolveOperator(EQUAL, BOOLEAN, ImmutableList.of(type, type))) { + if (!canResolveOperator(EQUAL, BOOLEAN, ImmutableList.of(type, type))) { missingOperators.put(type, EQUAL); } - if (!functions.canResolveOperator(NOT_EQUAL, BOOLEAN, ImmutableList.of(type, type))) { + if (!canResolveOperator(NOT_EQUAL, BOOLEAN, ImmutableList.of(type, type))) { missingOperators.put(type, NOT_EQUAL); } } if (type.isOrderable()) { for (OperatorType operator : ImmutableList.of(LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL)) { - if (!functions.canResolveOperator(operator, BOOLEAN, ImmutableList.of(type, type))) { + if (!canResolveOperator(operator, BOOLEAN, ImmutableList.of(type, type))) { missingOperators.put(type, operator); } } - if (!functions.canResolveOperator(BETWEEN, BOOLEAN, ImmutableList.of(type, type, type))) { + if (!canResolveOperator(BETWEEN, BOOLEAN, ImmutableList.of(type, type, type))) { missingOperators.put(type, BETWEEN); } } @@ -1189,6 +1189,17 @@ private static JsonCodec createTestingViewCodec() return new JsonCodecFactory(provider).jsonCodec(ViewDefinition.class); } + private boolean canResolveOperator(OperatorType operatorType, Type returnType, List argumentTypes) + { + try { + getFunctionManager().resolveOperator(operatorType, argumentTypes); + return true; + } + catch (OperatorNotFoundException e) { + return false; + } + } + @VisibleForTesting public Map> getCatalogsByQueryId() { From de711286f1171e4b45c078af7cd5d77dcf981e33 Mon Sep 17 00:00:00 2001 From: rongrong Date: Sat, 2 Mar 2019 23:18:24 -0800 Subject: [PATCH 8/8] Switch resolveOperator to FunctionHandle --- .../metadata/FunctionInvokerProvider.java | 15 ++++++++---- .../presto/metadata/FunctionManager.java | 6 ++--- .../scalar/RowComparisonOperator.java | 6 ++--- .../scalar/RowDistinctFromOperator.java | 6 ++--- .../operator/scalar/RowEqualOperator.java | 4 ++-- .../sql/analyzer/ExpressionAnalyzer.java | 2 +- .../presto/sql/gen/BytecodeUtils.java | 6 ----- .../presto/sql/gen/InCodeGenerator.java | 24 +++++++++---------- .../presto/sql/gen/NullIfCodeGenerator.java | 12 +++++----- .../presto/sql/gen/SwitchCodeGenerator.java | 7 +++--- .../sql/planner/ExpressionInterpreter.java | 8 +++---- 11 files changed, 47 insertions(+), 49 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionInvokerProvider.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionInvokerProvider.java index 4ff7a5e209076..6c0d8a2a36f5c 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionInvokerProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionInvokerProvider.java @@ -37,23 +37,28 @@ public class FunctionInvokerProvider { - private final FunctionRegistry functionRegistry; + private final FunctionManager functionManager; - public FunctionInvokerProvider(FunctionRegistry functionRegistry) + public FunctionInvokerProvider(FunctionManager functionManager) { - this.functionRegistry = functionRegistry; + this.functionManager = functionManager; } public FunctionInvoker createFunctionInvoker(Signature signature, Optional invocationConvention) { - ScalarFunctionImplementation scalarFunctionImplementation = functionRegistry.getScalarFunctionImplementation(signature); + return createFunctionInvoker(new FunctionHandle(signature), invocationConvention); + } + + public FunctionInvoker createFunctionInvoker(FunctionHandle functionHandle, Optional invocationConvention) + { + ScalarFunctionImplementation scalarFunctionImplementation = functionManager.getScalarFunctionImplementation(functionHandle); for (ScalarImplementationChoice choice : scalarFunctionImplementation.getAllChoices()) { if (checkChoice(choice.getArgumentProperties(), choice.isNullable(), choice.hasSession(), invocationConvention)) { return new FunctionInvoker(choice.getMethodHandle()); } } checkState(invocationConvention.isPresent()); - throw new PrestoException(FUNCTION_NOT_FOUND, format("Dependent function implementation (%s) with convention (%s) is not available", signature, invocationConvention.toString())); + throw new PrestoException(FUNCTION_NOT_FOUND, format("Dependent function implementation (%s) with convention (%s) is not available", functionHandle, invocationConvention.toString())); } @VisibleForTesting diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java index c11f5ee707962..c023a4b085a9d 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/FunctionManager.java @@ -41,7 +41,7 @@ public FunctionManager(TypeManager typeManager, BlockEncodingSerde blockEncoding { FunctionRegistry functionRegistry = new FunctionRegistry(typeManager, blockEncodingSerde, featuresConfig, this); this.globalFunctionNamespace = new FunctionNamespace(functionRegistry); - this.functionInvokerProvider = new FunctionInvokerProvider(functionRegistry); + this.functionInvokerProvider = new FunctionInvokerProvider(this); if (typeManager instanceof TypeRegistry) { ((TypeRegistry) typeManager).setFunctionManager(this); } @@ -102,9 +102,9 @@ public boolean isAggregationFunction(QualifiedName name) return globalFunctionNamespace.isAggregationFunction(name); } - public Signature resolveOperator(OperatorType operatorType, List argumentTypes) + public FunctionHandle resolveOperator(OperatorType operatorType, List argumentTypes) { - return globalFunctionNamespace.resolveOperator(operatorType, argumentTypes).getSignature(); + return globalFunctionNamespace.resolveOperator(operatorType, argumentTypes); } public boolean isRegistered(Signature signature) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowComparisonOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowComparisonOperator.java index 637fb421de145..d896bbd098737 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowComparisonOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowComparisonOperator.java @@ -13,8 +13,8 @@ * limitations under the License. */ +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlOperator; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.function.OperatorType; @@ -48,8 +48,8 @@ protected List getMethodHandles(RowType type, FunctionManager func { ImmutableList.Builder argumentMethods = ImmutableList.builder(); for (Type parameterType : type.getTypeParameters()) { - Signature signature = functionManager.resolveOperator(operatorType, ImmutableList.of(parameterType, parameterType)); - argumentMethods.add(functionManager.getScalarFunctionImplementation(signature).getMethodHandle()); + FunctionHandle operatorHandle = functionManager.resolveOperator(operatorType, ImmutableList.of(parameterType, parameterType)); + argumentMethods.add(functionManager.getScalarFunctionImplementation(operatorHandle).getMethodHandle()); } return argumentMethods.build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowDistinctFromOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowDistinctFromOperator.java index 7e0ab017d72b6..5f219f3224d81 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowDistinctFromOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowDistinctFromOperator.java @@ -14,9 +14,9 @@ */ import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionInvoker; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlOperator; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ReturnPlaceConvention; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ScalarImplementationChoice; @@ -65,9 +65,9 @@ public ScalarFunctionImplementation specialize(BoundVariables boundVariables, in ImmutableList.Builder argumentMethods = ImmutableList.builder(); Type type = boundVariables.getTypeVariable("T"); for (Type parameterType : type.getTypeParameters()) { - Signature signature = functionManager.resolveOperator(IS_DISTINCT_FROM, ImmutableList.of(parameterType, parameterType)); + FunctionHandle operatorHandle = functionManager.resolveOperator(IS_DISTINCT_FROM, ImmutableList.of(parameterType, parameterType)); FunctionInvoker functionInvoker = functionManager.getFunctionInvokerProvider().createFunctionInvoker( - signature, + operatorHandle, Optional.of(new InvocationConvention( ImmutableList.of(NULL_FLAG, NULL_FLAG), InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL, diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowEqualOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowEqualOperator.java index 7b55e4eecfddd..9941a3bef19e8 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowEqualOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/RowEqualOperator.java @@ -14,8 +14,8 @@ */ import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.metadata.SqlOperator; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.RowType; @@ -76,7 +76,7 @@ public static List resolveFieldEqualOperators(RowType rowType, Fun private static MethodHandle resolveEqualOperator(Type type, FunctionManager functionManager) { - Signature operator = functionManager.resolveOperator(EQUAL, ImmutableList.of(type, type)); + FunctionHandle operator = functionManager.resolveOperator(EQUAL, ImmutableList.of(type, type)); ScalarFunctionImplementation implementation = functionManager.getScalarFunctionImplementation(operator); return implementation.getMethodHandle(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index e46a9dff7f040..ca5932c6d8950 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -1244,7 +1244,7 @@ private Type getOperator(StackableAstVisitorContext context, Expression Signature operatorSignature; try { - operatorSignature = functionManager.resolveOperator(operatorType, argumentTypes.build()); + operatorSignature = functionManager.resolveOperator(operatorType, argumentTypes.build()).getSignature(); } catch (OperatorNotFoundException e) { throw new SemanticException(TYPE_MISMATCH, node, "%s", e.getMessage()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java index 8bd78ffa6c691..74fc1d6b401ef 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/BytecodeUtils.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.sql.gen; -import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.ArgumentProperty; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation.NullConvention; @@ -419,11 +418,6 @@ public static BytecodeExpression invoke(Binding binding, String name) return invokeDynamic(BOOTSTRAP_METHOD, ImmutableList.of(binding.getBindingId()), name, binding.getType()); } - public static BytecodeExpression invoke(Binding binding, Signature signature) - { - return invoke(binding, signature.getName()); - } - public static BytecodeNode generateWrite( CallSiteBinder callSiteBinder, Scope scope, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java index 015d42787090d..4b87a9300aeb0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/InCodeGenerator.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.sql.gen; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.FunctionManager; import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; -import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.BigintType; import com.facebook.presto.spi.type.DateType; import com.facebook.presto.spi.type.IntegerType; @@ -43,6 +43,7 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.presto.spi.function.OperatorType.EQUAL; import static com.facebook.presto.spi.function.OperatorType.HASH_CODE; import static com.facebook.presto.spi.function.OperatorType.INDETERMINATE; import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; @@ -121,10 +122,10 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon SwitchGenerationCase switchGenerationCase = checkSwitchGenerationCase(type, values); - Signature hashCodeSignature = generatorContext.getFunctionManager().resolveOperator(HASH_CODE, ImmutableList.of(type)); - MethodHandle hashCodeFunction = generatorContext.getFunctionManager().getScalarFunctionImplementation(hashCodeSignature).getMethodHandle(); - Signature isIndeterminateSignature = generatorContext.getFunctionManager().resolveOperator(INDETERMINATE, ImmutableList.of(type)); - ScalarFunctionImplementation isIndeterminateFunction = generatorContext.getFunctionManager().getScalarFunctionImplementation(isIndeterminateSignature); + FunctionHandle hashCodeHandle = generatorContext.getFunctionManager().resolveOperator(HASH_CODE, ImmutableList.of(type)); + MethodHandle hashCodeFunction = generatorContext.getFunctionManager().getScalarFunctionImplementation(hashCodeHandle).getMethodHandle(); + FunctionHandle isIndeterminateHandle = generatorContext.getFunctionManager().resolveOperator(INDETERMINATE, ImmutableList.of(type)); + ScalarFunctionImplementation isIndeterminateFunction = generatorContext.getFunctionManager().getScalarFunctionImplementation(isIndeterminateHandle); ImmutableListMultimap.Builder hashBucketsBuilder = ImmutableListMultimap.builder(); ImmutableList.Builder defaultBucket = ImmutableList.builder(); @@ -203,7 +204,6 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon value, testValues, false, - isIndeterminateSignature, isIndeterminateFunction); switchBuilder.addCase(bucket.getKey(), caseBlock); } @@ -214,7 +214,7 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon switchBlock = new BytecodeBlock() .comment("lookupSwitch(hashCode())") .getVariable(value) - .append(invoke(hashCodeBinding, hashCodeSignature)) + .append(invoke(hashCodeBinding, HASH_CODE.name())) .invokeStatic(Long.class, "hashCode", int.class, long.class) .putVariable(expression) .append(switchBuilder.build()); @@ -248,7 +248,6 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon value, defaultBucket.build(), true, - isIndeterminateSignature, isIndeterminateFunction) .setDescription("default"); @@ -296,7 +295,6 @@ private static BytecodeBlock buildInCase( Variable value, Collection testValues, boolean checkForNulls, - Signature isIndeterminateSignature, ScalarFunctionImplementation isIndeterminateFunction) { Variable caseWasNull = null; // caseWasNull is set to true the first time a null in `testValues` is encountered @@ -323,7 +321,7 @@ private static BytecodeBlock buildInCase( // That is incorrect. Doing an explicit check for indeterminate is required to correctly return NULL. if (testValues.isEmpty()) { elseBlock.append(new BytecodeBlock() - .append(generatorContext.generateCall(isIndeterminateSignature.getName(), isIndeterminateFunction, ImmutableList.of(value))) + .append(generatorContext.generateCall(INDETERMINATE.name(), isIndeterminateFunction, ImmutableList.of(value))) .putVariable(wasNull)); } else { @@ -333,8 +331,8 @@ private static BytecodeBlock buildInCase( elseBlock.gotoLabel(noMatchLabel); - Signature equalsSignature = generatorContext.getFunctionManager().resolveOperator(OperatorType.EQUAL, ImmutableList.of(type, type)); - ScalarFunctionImplementation equalsFunction = generatorContext.getFunctionManager().getScalarFunctionImplementation(equalsSignature); + FunctionHandle equalsHandle = generatorContext.getFunctionManager().resolveOperator(EQUAL, ImmutableList.of(type, type)); + ScalarFunctionImplementation equalsFunction = generatorContext.getFunctionManager().getScalarFunctionImplementation(equalsHandle); BytecodeNode elseNode = elseBlock; for (BytecodeNode testNode : testValues) { @@ -342,7 +340,7 @@ private static BytecodeBlock buildInCase( IfStatement test = new IfStatement(); BytecodeNode equalsCall = generatorContext.generateCall( - equalsSignature.getName(), + EQUAL.name(), equalsFunction, ImmutableList.of(value, testNode)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java index f22fc1b32784e..424a51a9779c1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/NullIfCodeGenerator.java @@ -16,7 +16,6 @@ import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.Signature; import com.facebook.presto.operator.scalar.ScalarFunctionImplementation; -import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignature; import com.facebook.presto.sql.relational.RowExpression; @@ -32,6 +31,7 @@ import java.util.Optional; import static com.facebook.presto.spi.function.OperatorType.CAST; +import static com.facebook.presto.spi.function.OperatorType.EQUAL; import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static com.facebook.presto.sql.gen.BytecodeUtils.ifWasNullPopAndGoto; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; @@ -62,14 +62,14 @@ public BytecodeNode generateExpression(Signature signature, BytecodeGeneratorCon Type secondType = second.getType(); // if (equal(cast(first as ), cast(second as )) - Signature equalsSignature = generatorContext.getFunctionManager().resolveOperator(OperatorType.EQUAL, ImmutableList.of(firstType, secondType)); - ScalarFunctionImplementation equalsFunction = generatorContext.getFunctionManager().getScalarFunctionImplementation(equalsSignature); + FunctionHandle equalFunction = generatorContext.getFunctionManager().resolveOperator(EQUAL, ImmutableList.of(firstType, secondType)); + ScalarFunctionImplementation equalsFunction = generatorContext.getFunctionManager().getScalarFunctionImplementation(equalFunction); BytecodeNode equalsCall = generatorContext.generateCall( - equalsSignature.getName(), + EQUAL.name(), equalsFunction, ImmutableList.of( - cast(generatorContext, firstValue, firstType, equalsSignature.getArgumentTypes().get(0)), - cast(generatorContext, generatorContext.generate(second, Optional.empty()), secondType, equalsSignature.getArgumentTypes().get(1)))); + cast(generatorContext, firstValue, firstType, equalFunction.getSignature().getArgumentTypes().get(0)), + cast(generatorContext, generatorContext.generate(second, Optional.empty()), secondType, equalFunction.getSignature().getArgumentTypes().get(1)))); BytecodeBlock conditionBlock = new BytecodeBlock() .append(equalsCall) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java b/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java index fc17ae0deec05..7bb32666f83fe 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/gen/SwitchCodeGenerator.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.sql.gen; +import com.facebook.presto.metadata.FunctionHandle; import com.facebook.presto.metadata.Signature; -import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.relational.CallExpression; import com.facebook.presto.sql.relational.RowExpression; @@ -32,6 +32,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.spi.function.OperatorType.EQUAL; import static com.facebook.presto.sql.gen.BytecodeGenerator.generateWrite; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; @@ -115,7 +116,7 @@ else if ( == ) { RowExpression result = ((CallExpression) clause).getArguments().get(1); // call equals(value, operand) - Signature equalsFunction = generatorContext.getFunctionManager().resolveOperator(OperatorType.EQUAL, ImmutableList.of(value.getType(), operand.getType())); + FunctionHandle equalsFunction = generatorContext.getFunctionManager().resolveOperator(EQUAL, ImmutableList.of(value.getType(), operand.getType())); // TODO: what if operand is null? It seems that the call will return "null" (which is cleared below) // and the code only does the right thing because the value in the stack for that scenario is @@ -123,7 +124,7 @@ else if ( == ) { // This code should probably be checking for wasNull after the call and "failing" the equality // check if wasNull is true BytecodeNode equalsCall = generatorContext.generateCall( - equalsFunction.getName(), + EQUAL.name(), generatorContext.getFunctionManager().getScalarFunctionImplementation(equalsFunction), ImmutableList.of(generatorContext.generate(operand, Optional.empty()), getTempVariableNode)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java index 57bd1cff983ed..eff3971cac36d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java @@ -709,8 +709,8 @@ protected Object visitArithmeticUnary(ArithmeticUnaryExpression node, Object con case PLUS: return value; case MINUS: - Signature operatorSignature = metadata.getFunctionManager().resolveOperator(OperatorType.NEGATION, types(node.getValue())); - MethodHandle handle = metadata.getFunctionManager().getScalarFunctionImplementation(operatorSignature).getMethodHandle(); + FunctionHandle operatorHandle = metadata.getFunctionManager().resolveOperator(OperatorType.NEGATION, types(node.getValue())); + MethodHandle handle = metadata.getFunctionManager().getScalarFunctionImplementation(operatorHandle).getMethodHandle(); if (handle.type().parameterCount() > 0 && handle.type().parameterType(0) == ConnectorSession.class) { handle = handle.bindTo(session); @@ -1263,8 +1263,8 @@ private boolean hasUnresolvedValue(List values) private Object invokeOperator(OperatorType operatorType, List argumentTypes, List argumentValues) { - Signature operatorSignature = metadata.getFunctionManager().resolveOperator(operatorType, argumentTypes); - return functionInvoker.invoke(operatorSignature, session, argumentValues); + FunctionHandle operatorHandle = metadata.getFunctionManager().resolveOperator(operatorType, argumentTypes); + return functionInvoker.invoke(operatorHandle, session, argumentValues); } private Expression toExpression(Object base, Type type)