diff --git a/core/trino-main/src/main/java/io/trino/execution/CallTask.java b/core/trino-main/src/main/java/io/trino/execution/CallTask.java index c7ecfa652620..9fd3276d7451 100644 --- a/core/trino-main/src/main/java/io/trino/execution/CallTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/CallTask.java @@ -26,6 +26,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorAccessControl; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.eventlistener.ClauseInfo; import io.trino.spi.eventlistener.RoutineInfo; import io.trino.spi.procedure.Procedure; import io.trino.spi.procedure.Procedure.Argument; @@ -202,7 +203,7 @@ else if (ConnectorAccessControl.class.equals(type)) { } accessControl.checkCanExecuteProcedure(session.toSecurityContext(), procedureName); - stateMachine.setRoutines(ImmutableList.of(new RoutineInfo(procedureName.getObjectName(), session.getUser()))); + stateMachine.setRoutines(ImmutableList.of(new RoutineInfo(procedureName.getObjectName(), session.getUser(), ClauseInfo.CALL))); try { procedure.getMethodHandle().invokeWithArguments(arguments); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index 2e65779da0e3..9d31fe131a99 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -37,6 +37,7 @@ import io.trino.spi.connector.ColumnSchema; import io.trino.spi.connector.ConnectorTableMetadata; import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.eventlistener.ClauseInfo; import io.trino.spi.eventlistener.ColumnDetail; import io.trino.spi.eventlistener.ColumnInfo; import io.trino.spi.eventlistener.RoutineInfo; @@ -644,9 +645,9 @@ public ResolvedFunction getResolvedFunction(Expression node) return resolvedFunctions.get(NodeRef.of(node)).getFunction(); } - public void addResolvedFunction(Expression node, ResolvedFunction function, String authorization) + public void addResolvedFunction(Expression node, ResolvedFunction function, String authorization, ClauseInfo clauseInfo) { - resolvedFunctions.put(NodeRef.of(node), new RoutineEntry(function, authorization)); + resolvedFunctions.put(NodeRef.of(node), new RoutineEntry(function, authorization, clauseInfo)); } public Set> getColumnReferences() @@ -1136,7 +1137,7 @@ public List getReferencedTables() public List getRoutines() { return resolvedFunctions.entrySet().stream() - .map(entry -> new RoutineInfo(entry.getValue().function.getSignature().getName(), entry.getValue().getAuthorization())) + .map(entry -> new RoutineInfo(entry.getValue().function.getSignature().getName(), entry.getValue().getAuthorization(), entry.getValue().getClauseInfo())) .collect(toImmutableList()); } @@ -1959,11 +1960,13 @@ private static class RoutineEntry { private final ResolvedFunction function; private final String authorization; + private final ClauseInfo clauseInfo; - public RoutineEntry(ResolvedFunction function, String authorization) + public RoutineEntry(ResolvedFunction function, String authorization, ClauseInfo clauseInfo) { this.function = requireNonNull(function, "function is null"); this.authorization = requireNonNull(authorization, "authorization is null"); + this.clauseInfo = requireNonNull(clauseInfo, "clauseInfo is null"); } public ResolvedFunction getFunction() @@ -1975,6 +1978,11 @@ public String getAuthorization() { return authorization; } + + public ClauseInfo getClauseInfo() + { + return clauseInfo; + } } private static class UpdateTarget diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java index afb1c4b4d82f..0424ef52130c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java @@ -34,6 +34,7 @@ import io.trino.spi.ErrorCodeSupplier; import io.trino.spi.TrinoException; import io.trino.spi.TrinoWarning; +import io.trino.spi.eventlistener.ClauseInfo; import io.trino.spi.function.BoundSignature; import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.OperatorType; @@ -3379,7 +3380,7 @@ public static ExpressionAnalysis analyzePatternRecognitionExpression( ExpressionAnalyzer analyzer = new ExpressionAnalyzer(plannerContext, accessControl, statementAnalyzerFactory, analysis, session, TypeProvider.empty(), warningCollector); analyzer.analyze(expression, scope, labels); - updateAnalysis(analysis, analyzer, session, accessControl); + updateAnalysis(analysis, analyzer, session, accessControl, ClauseInfo.PATTERN_RECOGNITION); return new ExpressionAnalysis( analyzer.getExpressionTypes(), @@ -3435,12 +3436,13 @@ public static ExpressionAnalysis analyzeExpression( Analysis analysis, Expression expression, WarningCollector warningCollector, - CorrelationSupport correlationSupport) + CorrelationSupport correlationSupport, + ClauseInfo clauseInfo) { ExpressionAnalyzer analyzer = new ExpressionAnalyzer(plannerContext, accessControl, statementAnalyzerFactory, analysis, session, TypeProvider.empty(), warningCollector); analyzer.analyze(expression, scope, correlationSupport); - updateAnalysis(analysis, analyzer, session, accessControl); + updateAnalysis(analysis, analyzer, session, accessControl, clauseInfo); analysis.addExpressionFields(expression, analyzer.getSourceFields()); return new ExpressionAnalysis( @@ -3470,7 +3472,7 @@ public static ExpressionAnalysis analyzeWindow( ExpressionAnalyzer analyzer = new ExpressionAnalyzer(plannerContext, accessControl, statementAnalyzerFactory, analysis, session, TypeProvider.empty(), warningCollector); analyzer.analyzeWindow(window, scope, originalNode, correlationSupport); - updateAnalysis(analysis, analyzer, session, accessControl); + updateAnalysis(analysis, analyzer, session, accessControl, ClauseInfo.WINDOW); return new ExpressionAnalysis( analyzer.getExpressionTypes(), @@ -3484,7 +3486,7 @@ public static ExpressionAnalysis analyzeWindow( analyzer.getWindowFunctions()); } - private static void updateAnalysis(Analysis analysis, ExpressionAnalyzer analyzer, Session session, AccessControl accessControl) + private static void updateAnalysis(Analysis analysis, ExpressionAnalyzer analyzer, Session session, AccessControl accessControl, ClauseInfo clauseInfo) { analysis.addTypes(analyzer.getExpressionTypes()); analysis.addCoercions( @@ -3493,7 +3495,7 @@ private static void updateAnalysis(Analysis analysis, ExpressionAnalyzer analyze analyzer.getSortKeyCoercionsForFrameBoundCalculation(), analyzer.getSortKeyCoercionsForFrameBoundComparison()); analysis.addFrameBoundCalculations(analyzer.getFrameBoundCalculations()); - analyzer.getResolvedFunctions().forEach((key, value) -> analysis.addResolvedFunction(key.getNode(), value, session.getUser())); + analyzer.getResolvedFunctions().forEach((key, value) -> analysis.addResolvedFunction(key.getNode(), value, session.getUser(), clauseInfo)); analysis.addColumnReferences(analyzer.getColumnReferences()); analysis.addLambdaArgumentReferences(analyzer.getLambdaArgumentReferences()); analysis.addTableColumnReferences(accessControl, session.getIdentity(), analyzer.getTableColumnReferences()); diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index b898ec00f976..67f3024c881f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -68,6 +68,7 @@ import io.trino.spi.connector.PointerType; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.TableProcedureMetadata; +import io.trino.spi.eventlistener.ClauseInfo; import io.trino.spi.function.FunctionKind; import io.trino.spi.function.OperatorType; import io.trino.spi.ptf.Argument; @@ -1441,7 +1442,7 @@ protected Scope visitUnnest(Unnest node, Optional scope) for (Expression expression : node.getExpressions()) { List expressionOutputs = new ArrayList<>(); - ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, createScope(scope)); + ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, createScope(scope), ClauseInfo.UNNEST); Type expressionType = expressionAnalysis.getType(expression); if (expressionType instanceof ArrayType) { Type elementType = ((ArrayType) expressionType).getElementType(); @@ -1855,7 +1856,7 @@ protected Scope visitTable(Table table, Optional scope) if (addRowIdColumn) { FieldReference reference = new FieldReference(outputFields.size() - 1); - analyzeExpression(reference, tableScope); + analyzeExpression(reference, tableScope, ClauseInfo.FROM); analysis.setRowIdField(table, reference); } @@ -2161,7 +2162,7 @@ protected Scope visitPatternRecognitionRelation(PatternRecognitionRelation relat for (Expression expression : relation.getPartitionBy()) { // The PARTITION BY clause is a list of columns of the row pattern input table. validateAndGetInputField(expression, inputScope); - Type type = analyzeExpression(expression, inputScope).getType(expression); + Type type = analyzeExpression(expression, inputScope, ClauseInfo.PATTERN_RECOGNITION).getType(expression); if (!type.isComparable()) { throw semanticException(TYPE_MISMATCH, expression, "%s is not comparable, and therefore cannot be used in PARTITION BY", type); } @@ -2172,7 +2173,7 @@ protected Scope visitPatternRecognitionRelation(PatternRecognitionRelation relat // The ORDER BY clause is a list of columns of the row pattern input table. Expression expression = sortItem.getSortKey(); validateAndGetInputField(expression, inputScope); - Type type = analyzeExpression(expression, inputScope).getType(sortItem.getSortKey()); + Type type = analyzeExpression(expression, inputScope, ClauseInfo.PATTERN_RECOGNITION).getType(sortItem.getSortKey()); if (!type.isOrderable()) { throw semanticException(TYPE_MISMATCH, sortItem, "%s is not orderable, and therefore cannot be used in ORDER BY", type); } @@ -2665,7 +2666,7 @@ else if (node.getType() == FULL) { // Need to register coercions in case when join criteria requires coercion (e.g. join on char(1) = char(2)) // Correlations are only currently support in the join criteria for INNER joins - ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, output, node.getType() == INNER ? CorrelationSupport.ALLOWED : CorrelationSupport.DISALLOWED); + ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, output, node.getType() == INNER ? CorrelationSupport.ALLOWED : CorrelationSupport.DISALLOWED, ClauseInfo.JOIN); Type clauseType = expressionAnalysis.getType(expression); if (!clauseType.equals(BOOLEAN)) { if (!clauseType.equals(UNKNOWN)) { @@ -2746,7 +2747,7 @@ protected Scope visitUpdate(Update update, Optional scope) ImmutableList.Builder expressionTypesBuilder = ImmutableList.builder(); for (UpdateAssignment assignment : update.getAssignments()) { Expression expression = assignment.getValue(); - ExpressionAnalysis analysis = analyzeExpression(expression, tableScope); + ExpressionAnalysis analysis = analyzeExpression(expression, tableScope, ClauseInfo.UPDATE); analysesBuilder.add(analysis); expressionTypesBuilder.add(analysis.getType(expression)); } @@ -2871,7 +2872,7 @@ protected Scope visitMerge(Merge merge, Optional scope) // Analyze all expressions in the Merge node Expression mergePredicate = merge.getPredicate(); - ExpressionAnalysis predicateAnalysis = analyzeExpression(mergePredicate, joinScope, CorrelationSupport.DISALLOWED); + ExpressionAnalysis predicateAnalysis = analyzeExpression(mergePredicate, joinScope, CorrelationSupport.DISALLOWED, ClauseInfo.MERGE); Type mergePredicateType = predicateAnalysis.getType(mergePredicate); if (!typeCoercion.canCoerce(mergePredicateType, BOOLEAN)) { throw semanticException(TYPE_MISMATCH, mergePredicate, "The MERGE predicate must evaluate to a boolean: actual type %s", mergePredicateType); @@ -2905,7 +2906,7 @@ else if (operation instanceof MergeInsert && caseColumnNames.isEmpty()) { if (operation.getExpression().isPresent()) { Expression predicate = operation.getExpression().get(); - analysis.recordSubqueries(merge, analyzeExpression(predicate, joinScope)); + analysis.recordSubqueries(merge, analyzeExpression(predicate, joinScope, ClauseInfo.MERGE)); Type predicateType = analysis.getType(predicate); if (!predicateType.equals(BOOLEAN)) { @@ -2922,7 +2923,7 @@ else if (operation instanceof MergeInsert && caseColumnNames.isEmpty()) { for (int index = 0; index < caseColumnNames.size(); index++) { String columnName = caseColumnNames.get(index); Expression expression = setExpressions.get(index); - ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, joinScope); + ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, joinScope, ClauseInfo.MERGE); analysis.recordSubqueries(merge, expressionAnalysis); Type targetType = requireNonNull(dataColumnTypes.get(columnName)); setColumnTypesBuilder.add(targetType); @@ -3132,7 +3133,7 @@ protected Scope visitValues(Values node, Optional scope) checkState(node.getRows().size() >= 1); List rowTypes = node.getRows().stream() - .map(row -> analyzeExpression(row, createScope(scope)).getType(row)) + .map(row -> analyzeExpression(row, createScope(scope), ClauseInfo.VALUES).getType(row)) .map(type -> { if (type instanceof RowType) { return type; @@ -3407,7 +3408,7 @@ private void analyzeHaving(QuerySpecification node, Scope scope) throw semanticException(NESTED_WINDOW, windowExpressions.get(0), "HAVING clause cannot contain window functions or row pattern measures"); } - ExpressionAnalysis expressionAnalysis = analyzeExpression(predicate, scope); + ExpressionAnalysis expressionAnalysis = analyzeExpression(predicate, scope, ClauseInfo.HAVING); analysis.recordSubqueries(node, expressionAnalysis); Type predicateType = expressionAnalysis.getType(predicate); @@ -3483,7 +3484,7 @@ private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope, } else { verifyNoAggregateWindowOrGroupingFunctions(session, metadata, column, "GROUP BY clause"); - analyzeExpression(column, scope); + analyzeExpression(column, scope, ClauseInfo.GROUP_BY); } ResolvedField field = analysis.getColumnReferenceFields().get(NodeRef.of(column)); @@ -3491,7 +3492,7 @@ private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope, sets.add(ImmutableList.of(ImmutableSet.of(field.getFieldId()))); } else { - analysis.recordSubqueries(node, analyzeExpression(column, scope)); + analysis.recordSubqueries(node, analyzeExpression(column, scope, ClauseInfo.GROUP_BY)); complexExpressions.add(column); } @@ -3500,7 +3501,7 @@ private GroupingSetAnalysis analyzeGroupBy(QuerySpecification node, Scope scope, } else { for (Expression column : groupingElement.getExpressions()) { - analyzeExpression(column, scope); + analyzeExpression(column, scope, ClauseInfo.GROUP_BY); if (!analysis.getColumnReferences().contains(NodeRef.of(column))) { throw semanticException(INVALID_COLUMN_REFERENCE, column, "GROUP BY expression must be a column reference: %s", column); } @@ -3805,7 +3806,7 @@ private void analyzeAllColumnsFromTable( checkState(field.getRelationAlias().isPresent(), "missing relation alias"); fieldExpression = new DereferenceExpression(DereferenceExpression.from(field.getRelationAlias().get()), new Identifier(field.getName().get())); } - analyzeExpression(fieldExpression, scope); + analyzeExpression(fieldExpression, scope, ClauseInfo.SELECT); outputExpressionBuilder.add(fieldExpression); selectExpressionBuilder.add(new SelectExpression(fieldExpression, Optional.empty())); @@ -3843,7 +3844,7 @@ private void analyzeAllFieldsFromRowTypeExpression( { ImmutableList.Builder itemOutputFieldBuilder = ImmutableList.builder(); - ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, scope); + ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, scope, ClauseInfo.SELECT); Type type = expressionAnalysis.getType(expression); if (!(type instanceof RowType)) { throw semanticException(TYPE_MISMATCH, node.getSelect(), "expected expression of type Row"); @@ -3858,7 +3859,7 @@ private void analyzeAllFieldsFromRowTypeExpression( for (int i = 0; i < referencedFieldsCount; i++) { Expression outputExpression = new SubscriptExpression(expression, new LongLiteral("" + (i + 1))); outputExpressionBuilder.add(outputExpression); - analyzeExpression(outputExpression, scope); + analyzeExpression(outputExpression, scope, ClauseInfo.SELECT); unfoldedExpressionsBuilder.add(outputExpression); Type outputExpressionType = type.getTypeParameters().get(i); @@ -3884,7 +3885,7 @@ private void analyzeSelectSingleColumn( ImmutableList.Builder selectExpressionBuilder) { Expression expression = singleColumn.getExpression(); - ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, scope); + ExpressionAnalysis expressionAnalysis = analyzeExpression(expression, scope, ClauseInfo.SELECT); analysis.recordSubqueries(node, expressionAnalysis); outputExpressionBuilder.add(expression); selectExpressionBuilder.add(new SelectExpression(expression, Optional.empty())); @@ -3903,7 +3904,7 @@ private void analyzeWhere(Node node, Scope scope, Expression predicate) { verifyNoAggregateWindowOrGroupingFunctions(session, metadata, predicate, "WHERE clause"); - ExpressionAnalysis expressionAnalysis = analyzeExpression(predicate, scope); + ExpressionAnalysis expressionAnalysis = analyzeExpression(predicate, scope, ClauseInfo.WHERE); analysis.recordSubqueries(node, expressionAnalysis); Type predicateType = expressionAnalysis.getType(predicate); @@ -4070,7 +4071,7 @@ private Type getViewColumnType(ViewColumn column, QualifiedObjectName name, Node } } - private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope) + private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope, ClauseInfo clauseInfo) { return ExpressionAnalyzer.analyzeExpression( session, @@ -4081,10 +4082,11 @@ private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope) analysis, expression, warningCollector, - correlationSupport); + correlationSupport, + clauseInfo); } - private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope, CorrelationSupport correlationSupport) + private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope, CorrelationSupport correlationSupport, ClauseInfo clauseInfo) { return ExpressionAnalyzer.analyzeExpression( session, @@ -4095,7 +4097,8 @@ private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope, analysis, expression, warningCollector, - correlationSupport); + correlationSupport, + clauseInfo); } private void analyzeRowFilter(String currentIdentity, Table table, QualifiedObjectName name, Scope scope, ViewExpression filter) @@ -4127,7 +4130,8 @@ private void analyzeRowFilter(String currentIdentity, Table table, QualifiedObje analysis, expression, warningCollector, - correlationSupport); + correlationSupport, + ClauseInfo.ROW_FILTER); } catch (TrinoException e) { throw new TrinoException(e::getErrorCode, extractLocation(table), format("Invalid row filter for '%s': %s", name, e.getRawMessage()), e); @@ -4182,7 +4186,8 @@ private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObj analysis, expression, warningCollector, - correlationSupport); + correlationSupport, + ClauseInfo.COLUMN_MASK); } catch (TrinoException e) { throw new TrinoException(e::getErrorCode, extractLocation(table), format("Invalid column mask for '%s.%s': %s", tableName, column, e.getRawMessage()), e); @@ -4218,7 +4223,7 @@ private List descriptorToFields(Scope scope) for (int fieldIndex = 0; fieldIndex < scope.getRelationType().getAllFieldCount(); fieldIndex++) { FieldReference expression = new FieldReference(fieldIndex); builder.add(expression); - analyzeExpression(expression, scope); + analyzeExpression(expression, scope, ClauseInfo.SELECT); } return builder.build(); } @@ -4612,7 +4617,8 @@ private List analyzeOrderBy(Node node, List sortItems, Sco analysis, expression, WarningCollector.NOOP, - correlationSupport); + correlationSupport, + ClauseInfo.ORDER_BY); analysis.recordSubqueries(node, expressionAnalysis); Type type = analysis.getType(expression); @@ -4709,12 +4715,12 @@ else if (node.getRowCount() instanceof LongLiteral) { private OptionalLong analyzeParameterAsRowCount(Parameter parameter, Scope scope, String context) { if (analysis.isDescribe()) { - analyzeExpression(parameter, scope); + analyzeExpression(parameter, scope, ClauseInfo.SELECT); analysis.addCoercion(parameter, BIGINT, false); return OptionalLong.empty(); } // validate parameter index - analyzeExpression(parameter, scope); + analyzeExpression(parameter, scope, ClauseInfo.SELECT); Expression providedValue = analysis.getParameters().get(NodeRef.of(parameter)); Object value; try { @@ -4809,7 +4815,7 @@ private Optional extractTableVersion(Table table, Optionaljava.method.addedToInterface method io.trino.spi.block.BlockBuilder io.trino.spi.block.BlockBuilder::newBlockBuilderLike(int, io.trino.spi.block.BlockBuilderStatus) + + java.annotation.removed + method void io.trino.spi.eventlistener.RoutineInfo::<init>(java.lang.String, java.lang.String) + diff --git a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/ClauseInfo.java b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/ClauseInfo.java new file mode 100644 index 000000000000..a8e64c3fd3d6 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/ClauseInfo.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.eventlistener; + +public enum ClauseInfo +{ + SELECT, + FROM, + WHERE, + GROUP_BY, + HAVING, + WINDOW, + ORDER_BY, + OFFSET, + LIMIT, + JOIN, + UNNEST, + UPDATE, + MERGE, + ROW_FILTER, + COLUMN_MASK, + VALUES, + PATTERN_RECOGNITION, + CALL +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/RoutineInfo.java b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/RoutineInfo.java index d9d3b2e47ad2..ee29b94132e6 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/eventlistener/RoutineInfo.java +++ b/core/trino-spi/src/main/java/io/trino/spi/eventlistener/RoutineInfo.java @@ -16,6 +16,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Optional; + import static java.util.Objects.requireNonNull; /** @@ -25,14 +27,34 @@ public class RoutineInfo { private final String routine; private final String authorization; + private final Optional clauseInfo; + + @Deprecated + public RoutineInfo(String routine, String authorization) + { + this( + routine, + authorization, + Optional.empty()); + } + + public RoutineInfo(String routine, String authorization, ClauseInfo clauseInfo) + { + this( + routine, + authorization, + Optional.of(requireNonNull(clauseInfo, "clauseInfo is null"))); + } @JsonCreator public RoutineInfo( @JsonProperty("routine") String routine, - @JsonProperty("authorization") String authorization) + @JsonProperty("authorization") String authorization, + @JsonProperty("clauseInfo") Optional clauseInfo) { this.routine = requireNonNull(routine, "routine is null"); this.authorization = requireNonNull(authorization, "authorization is null"); + this.clauseInfo = requireNonNull(clauseInfo, "clauseInfo is null"); } @JsonProperty @@ -46,4 +68,10 @@ public String getAuthorization() { return authorization; } + + @JsonProperty + public Optional getClauseInfo() + { + return clauseInfo; + } } diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java index 40938d31668d..0e44920234e0 100644 --- a/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestEventListenerBasic.java @@ -34,6 +34,7 @@ import io.trino.spi.connector.ConnectorMaterializedViewDefinition.Column; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.eventlistener.ClauseInfo; import io.trino.spi.eventlistener.ColumnDetail; import io.trino.spi.eventlistener.ColumnInfo; import io.trino.spi.eventlistener.OutputColumnMetadata; @@ -332,7 +333,7 @@ private void assertFailedQuery(Session session, @Language("SQL") String sql, Str public void testReferencedTablesAndRoutines() throws Exception { - QueryEvents queryEvents = runQueryAndWaitForEvents("SELECT sum(linenumber) FROM lineitem").getQueryEvents(); + QueryEvents queryEvents = runQueryAndWaitForEvents("SELECT sum(linenumber) FROM lineitem WHERE floor(linenumber) = 3 HAVING sum(linenumber) > 1").getQueryEvents(); QueryCompletedEvent event = queryEvents.getQueryCompletedEvent(); @@ -352,11 +353,28 @@ public void testReferencedTablesAndRoutines() assertTrue(column.getMasks().isEmpty()); List routines = event.getMetadata().getRoutines(); - assertEquals(tables.size(), 1); - - RoutineInfo routine = routines.get(0); - assertEquals(routine.getRoutine(), "sum"); - assertEquals(routine.getAuthorization(), "user"); + assertEquals(routines.size(), 3); + + List selectRoutines = routines.stream() + .filter(routine -> Optional.of(ClauseInfo.SELECT).equals(routine.getClauseInfo())) + .toList(); + assertEquals(selectRoutines.size(), 1); + assertEquals(selectRoutines.get(0).getRoutine(), "sum"); + assertEquals(selectRoutines.get(0).getAuthorization(), "user"); + + List whereRoutines = routines.stream() + .filter(routine -> Optional.of(ClauseInfo.WHERE).equals(routine.getClauseInfo())) + .toList(); + assertEquals(whereRoutines.size(), 1); + assertEquals(whereRoutines.get(0).getRoutine(), "floor"); + assertEquals(whereRoutines.get(0).getAuthorization(), "user"); + + List havingRoutines = routines.stream() + .filter(routine -> Optional.of(ClauseInfo.HAVING).equals(routine.getClauseInfo())) + .toList(); + assertEquals(havingRoutines.size(), 1); + assertEquals(havingRoutines.get(0).getRoutine(), "sum"); + assertEquals(havingRoutines.get(0).getAuthorization(), "user"); } @Test