diff --git a/presto-docs/src/main/sphinx/plugin/native-sidecar-plugin.rst b/presto-docs/src/main/sphinx/plugin/native-sidecar-plugin.rst index 760176afbe635..f44019a37e46e 100644 --- a/presto-docs/src/main/sphinx/plugin/native-sidecar-plugin.rst +++ b/presto-docs/src/main/sphinx/plugin/native-sidecar-plugin.rst @@ -106,3 +106,16 @@ Property Name Description query plans against native engine, ensuring execution compatibility. ============================================ ===================================================================== ============================== +Expression optimizer +----------------- + +These properties must be configured in ``etc/expression-manager/native.properties`` to use the native expression optimizer of the ``NativeSidecarPlugin``. + +============================================ ===================================================================== ============================== +Property Name Description Value +============================================ ===================================================================== ============================== +``expression-manager-factory.name`` Identifier for the expression optimizer. Enables optimization of `native` + expressions using the native expression optimizer. +============================================ ===================================================================== ============================== + +To enable the native expression optimizer for your session, set the expression_optimizer_name session property to native: ``SET SESSION expression_optimizer_name = 'native'`` diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java index 2781b84a63240..1e04a550ab0ed 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java @@ -855,6 +855,11 @@ public CatalogSchemaName getDefaultNamespace() return defaultNamespace; } + public HandleResolver getHandleResolver() + { + return handleResolver; + } + protected Type getType(UserDefinedType userDefinedType) { // Distinct type diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java index c46f717b8524c..38621cb5c76f3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java @@ -23,6 +23,18 @@ public class HandleJsonModule implements Module { + private final HandleResolver handleResolver; + + public HandleJsonModule() + { + this(null); + } + + public HandleJsonModule(HandleResolver handleResolver) + { + this.handleResolver = handleResolver; + } + @Override public void configure(Binder binder) { @@ -40,6 +52,11 @@ public void configure(Binder binder) jsonBinder(binder).addModuleBinding().to(PartitioningHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(FunctionHandleJacksonModule.class); - binder.bind(HandleResolver.class).in(Scopes.SINGLETON); + if (handleResolver == null) { + binder.bind(HandleResolver.class).in(Scopes.SINGLETON); + } + else { + binder.bind(HandleResolver.class).toInstance(handleResolver); + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java index be66084783461..3cf4b8f540dc9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java @@ -83,6 +83,7 @@ import com.facebook.presto.sql.analyzer.FunctionsConfig; import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.analyzer.TypeSignatureProvider; +import com.facebook.presto.testing.TestProcedureRegistry; import com.facebook.presto.transaction.TransactionManager; import com.facebook.presto.type.TypeDeserializer; import com.google.common.annotations.VisibleForTesting; @@ -301,6 +302,21 @@ public static MetadataManager createTestMetadataManager(TransactionManager trans procedureRegistry); } + public static MetadataManager createTestMetadataManager(FunctionAndTypeManager functionAndTypeManager) + { + BlockEncodingManager blockEncodingManager = new BlockEncodingManager(); + return new MetadataManager( + functionAndTypeManager, + blockEncodingManager, + createTestingSessionPropertyManager(), + new SchemaPropertyManager(), + new TablePropertyManager(), + new ColumnPropertyManager(), + new AnalyzePropertyManager(), + functionAndTypeManager.getTransactionManager(), + new TestProcedureRegistry()); + } + @Override public final void verifyComparableOrderableContract() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java b/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java index 5119bbda21ef1..d980a8706db16 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java @@ -13,12 +13,14 @@ */ package com.facebook.presto.sql.expressions; +import com.facebook.airlift.log.Logger; import com.facebook.presto.FullConnectorSession; import com.facebook.presto.Session; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.ExpressionOptimizerProvider; import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext; @@ -46,12 +48,14 @@ public class ExpressionOptimizerManager implements ExpressionOptimizerProvider { + private static final Logger log = Logger.get(ExpressionOptimizerManager.class); public static final String DEFAULT_EXPRESSION_OPTIMIZER_NAME = "default"; private static final File EXPRESSION_MANAGER_CONFIGURATION_DIRECTORY = new File("etc/expression-manager/"); private static final String EXPRESSION_MANAGER_FACTORY_NAME = "expression-manager-factory.name"; private final NodeManager nodeManager; private final FunctionAndTypeManager functionAndTypeManager; + private final RowExpressionSerde rowExpressionSerde; private final FunctionResolution functionResolution; private final File configurationDirectory; @@ -59,16 +63,17 @@ public class ExpressionOptimizerManager private final Map expressionOptimizers = new ConcurrentHashMap<>(); @Inject - public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager) + public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, RowExpressionSerde rowExpressionSerde) { - this(nodeManager, functionAndTypeManager, EXPRESSION_MANAGER_CONFIGURATION_DIRECTORY); + this(nodeManager, functionAndTypeManager, rowExpressionSerde, EXPRESSION_MANAGER_CONFIGURATION_DIRECTORY); } - public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, File configurationDirectory) + public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, RowExpressionSerde rowExpressionSerde, File configurationDirectory) { requireNonNull(nodeManager, "nodeManager is null"); this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); this.configurationDirectory = requireNonNull(configurationDirectory, "configurationDirectory is null"); expressionOptimizers.put(DEFAULT_EXPRESSION_OPTIMIZER_NAME, new RowExpressionOptimizer(functionAndTypeManager)); @@ -88,7 +93,7 @@ public void loadExpressionOptimizerFactories() } } - private void loadExpressionOptimizerFactory(File configurationFile) + public void loadExpressionOptimizerFactory(File configurationFile) throws IOException { String optimizerName = getNameWithoutExtension(configurationFile.getName()); @@ -104,13 +109,16 @@ private void loadExpressionOptimizerFactory(File configurationFile) public void loadExpressionOptimizerFactory(String factoryName, String optimizerName, Map properties) { + requireNonNull(factoryName, "factoryName is null"); checkArgument(expressionOptimizerFactories.containsKey(factoryName), "ExpressionOptimizerFactory %s is not registered, registered factories: ", factoryName, expressionOptimizerFactories.keySet()); + log.info("-- Loading expression optimizer [%s] --", optimizerName); ExpressionOptimizer optimizer = expressionOptimizerFactories.get(factoryName).createOptimizer( properties, - new ExpressionOptimizerContext(nodeManager, functionAndTypeManager, functionResolution)); + new ExpressionOptimizerContext(nodeManager, rowExpressionSerde, functionAndTypeManager, functionResolution)); expressionOptimizers.put(optimizerName, optimizer); + log.info("-- Added expression optimizer [%s] --", optimizerName); } public void addExpressionOptimizerFactory(ExpressionOptimizerFactory expressionOptimizerFactory) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java b/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java new file mode 100644 index 0000000000000..cf276d6de3223 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.expressions; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.RowExpression; + +import javax.inject.Inject; + +import static java.util.Objects.requireNonNull; + +public class JsonCodecRowExpressionSerde + implements RowExpressionSerde +{ + private final JsonCodec codec; + + @Inject + public JsonCodecRowExpressionSerde(JsonCodec codec) + { + this.codec = requireNonNull(codec, "codec is null"); + } + + @Override + public String serialize(RowExpression expression) + { + return codec.toJson(expression); + } + + @Override + public RowExpression deserialize(String data) + { + return codec.fromJson(data); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java index caf443379fb4c..811fe88c20c10 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java @@ -13,43 +13,24 @@ */ package com.facebook.presto.sql.relational; -import com.facebook.presto.common.CatalogSchemaName; -import com.facebook.presto.common.QualifiedObjectName; -import com.facebook.presto.expressions.RowExpressionRewriter; -import com.facebook.presto.expressions.RowExpressionTreeRewriter; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ConnectorSession; -import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.function.FunctionHandle; -import com.facebook.presto.spi.function.FunctionMetadata; -import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.sql.analyzer.TypeSignatureProvider; import com.facebook.presto.sql.planner.RowExpressionInterpreter; -import com.google.common.collect.ImmutableList; -import jakarta.annotation.Nullable; -import java.util.IdentityHashMap; -import java.util.Map; import java.util.function.Function; -import static com.facebook.presto.common.Utils.checkState; -import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; -import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.sql.planner.LiteralEncoder.toRowExpression; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.lang.String.format; import static java.util.Objects.requireNonNull; public final class RowExpressionOptimizer implements ExpressionOptimizer { private final FunctionAndTypeManager functionAndTypeManager; - private final CatalogSchemaName defaultNamespace; public RowExpressionOptimizer(Metadata metadata) { @@ -59,14 +40,13 @@ public RowExpressionOptimizer(Metadata metadata) public RowExpressionOptimizer(FunctionAndTypeManager functionAndTypeManager) { this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); - this.defaultNamespace = functionAndTypeManager.getDefaultNamespace(); } @Override public RowExpression optimize(RowExpression rowExpression, Level level, ConnectorSession session) { if (level.ordinal() <= OPTIMIZED.ordinal()) { - return getRowExpression(rowExpression, level, session, null); + return toRowExpression(rowExpression.getSourceLocation(), new RowExpressionInterpreter(rowExpression, functionAndTypeManager, session, level).optimize(), rowExpression.getType()); } throw new IllegalArgumentException("Not supported optimization level: " + level); } @@ -74,118 +54,7 @@ public RowExpression optimize(RowExpression rowExpression, Level level, Connecto @Override public RowExpression optimize(RowExpression expression, Level level, ConnectorSession session, Function variableResolver) { - return getRowExpression(expression, level, session, variableResolver); - } - - private RowExpression getRowExpression(RowExpression expression, Level level, ConnectorSession session, @Nullable Function variableResolver) - { - BuiltInNamespaceRewriter visitor = new BuiltInNamespaceRewriter(); - RowExpressionInterpreter interpreter = new RowExpressionInterpreter( - visitor.convertToInterpreterNamespace(expression), - functionAndTypeManager, - session, - level); - return visitor.restoreOriginalNamespaces(toRowExpression( - expression.getSourceLocation(), - interpreter.optimize(variableResolver != null ? variableResolver::apply : null), - expression.getType())); - } - - /** - * TODO: GIANT HACK - * This class is a hack and should eventually be removed. It is used to ensure consistent constant folding behavior when the built-in - * function namespace has been switched (for example, to native.default. in the case of native functions). This will no longer be needed - * when the native sidecar is capable of providing its own expression optimizer. - */ - private class BuiltInNamespaceRewriter - { - private final Map defaultToOriginalFunctionHandles = new IdentityHashMap<>(); - - public RowExpression convertToInterpreterNamespace(RowExpression expression) - { - if (defaultNamespace.equals(JAVA_BUILTIN_NAMESPACE)) { - // No need to replace built-in namespaces if the default namespace is already the Java built-in namespace - return expression; - } - return RowExpressionTreeRewriter.rewriteWith(new ReplaceBuiltInNamespaces(), expression, null); - } - - public RowExpression restoreOriginalNamespaces(RowExpression expression) - { - if (defaultToOriginalFunctionHandles.isEmpty()) { - return expression; - } - return RowExpressionTreeRewriter.rewriteWith(new ReplaceOriginalNamespaces(), expression, null); - } - - private class ReplaceBuiltInNamespaces - extends RowExpressionRewriter - { - @Override - public RowExpression rewriteCall(CallExpression call, Void context, RowExpressionTreeRewriter treeRewriter) - { - FunctionHandle functionHandle = call.getFunctionHandle(); - FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(functionHandle); - if (!functionMetadata.getImplementationType().canBeEvaluatedInCoordinator()) { - checkState(!functionHandle.getCatalogSchemaName().equals(JAVA_BUILTIN_NAMESPACE), - format("FunctionHandle %s is already in the Java built-in namespace (%s), yet is marked as ineligible to be evaluated in the coordinator", functionHandle, functionHandle.getCatalogSchemaName())); - - // Replace the namespace with the Java built-in namespace - FunctionHandle javaNamespaceFunctionHandle; - try { - javaNamespaceFunctionHandle = functionAndTypeManager.lookupFunction( - QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, call.getDisplayName()), - functionHandle.getArgumentTypes().stream().map(TypeSignatureProvider::new).collect(toImmutableList())); - } - catch (PrestoException e) { - if (e.getErrorCode().equals(FUNCTION_NOT_FOUND.toErrorCode())) { - // If the function is not found in the Java built-in namespace, let default rewriter handle it - return null; - } - throw e; // Rethrow other exceptions - } - - checkState(functionAndTypeManager.getFunctionMetadata(javaNamespaceFunctionHandle).getImplementationType().canBeEvaluatedInCoordinator(), - format("FunctionHandle %s in the Java built-in namespace (%s) is not eligible to be evaluated in the coordinator", javaNamespaceFunctionHandle, JAVA_BUILTIN_NAMESPACE)); - - defaultToOriginalFunctionHandles.put(javaNamespaceFunctionHandle, functionHandle); - ImmutableList rewrittenArgs = call.getArguments().stream() - .map(arg -> treeRewriter.rewrite(arg, context)) - .collect(toImmutableList()); - return new CallExpression( - call.getSourceLocation(), - call.getDisplayName(), - javaNamespaceFunctionHandle, - call.getType(), - rewrittenArgs); - } - - // Return null to let the default rewriter handle it (which will rewrite children automatically) - return null; - } - } - - private class ReplaceOriginalNamespaces - extends RowExpressionRewriter - { - @Override - public RowExpression rewriteCall(CallExpression call, Void context, RowExpressionTreeRewriter treeRewriter) - { - if (defaultToOriginalFunctionHandles.containsKey(call.getFunctionHandle())) { - FunctionHandle originalFunctionHandle = defaultToOriginalFunctionHandles.get(call.getFunctionHandle()); - ImmutableList rewrittenArgs = call.getArguments().stream() - .map(arg -> treeRewriter.rewrite(arg, context)) - .collect(toImmutableList()); - return new CallExpression( - call.getSourceLocation(), - call.getDisplayName(), - originalFunctionHandle, - call.getType(), - rewrittenArgs); - } - // Return null to let the default rewriter handle it (which will rewrite children automatically) - return null; - } - } + RowExpressionInterpreter interpreter = new RowExpressionInterpreter(expression, functionAndTypeManager, session, level); + return toRowExpression(expression.getSourceLocation(), interpreter.optimize(variableResolver::apply), expression.getType()); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index aaad7753c672f..ed686664b10d2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -153,6 +153,7 @@ import com.facebook.presto.spi.plan.StageExecutionDescriptor; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.procedure.ProcedureRegistry; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spiller.FileSingleStreamSpillerFactory; import com.facebook.presto.spiller.GenericPartitioningSpillerFactory; import com.facebook.presto.spiller.GenericSpillerFactory; @@ -179,6 +180,7 @@ import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -472,7 +474,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, this.pageIndexerFactory = new GroupByHashPageIndexerFactory(joinCompiler); NodeInfo nodeInfo = new NodeInfo("test"); - expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()), getFunctionAndTypeManager()); + expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()), getFunctionAndTypeManager(), new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))); this.accessControl = new TestingAccessControlManager(transactionManager); this.statsNormalizer = new StatsNormalizer(); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java b/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java index 913e0db26adc0..7340e45b42c1d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java @@ -34,6 +34,7 @@ import java.util.Map; import java.util.Properties; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; @@ -65,6 +66,7 @@ public void setUp() manager = new ExpressionOptimizerManager( pluginNodeManager, METADATA.getFunctionAndTypeManager(), + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)), directory); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java index cea4abb4e0787..a2f535e4d5c2a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java @@ -23,9 +23,11 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.Optimizer; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.TypeProvider; @@ -49,6 +51,7 @@ import java.util.function.Consumer; import java.util.function.Function; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlan; import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlanDoesNotMatch; import static com.facebook.presto.transaction.TransactionBuilder.transaction; @@ -177,7 +180,8 @@ private List getMinimalOptimizers() metadata, new ExpressionOptimizerManager( new PluginNodeManager(new InMemoryNodeManager()), - queryRunner.getFunctionAndTypeManager())).rules())); + queryRunner.getFunctionAndTypeManager(), + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)))).rules())); } private void inTransaction(Function transactionSessionConsumer) diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java index 7817b311053df..78ff684386afc 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.tree.Expression; @@ -38,6 +39,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; @@ -185,7 +187,7 @@ private static void assertSimplifies(String expression, String rowExpressionExpe Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression)); InMemoryNodeManager nodeManager = new InMemoryNodeManager(); - ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager), METADATA.getFunctionAndTypeManager()); + ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager), METADATA.getFunctionAndTypeManager(), new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))); TestingRowExpressionTranslator translator = new TestingRowExpressionTranslator(METADATA); RowExpression actualRowExpression = translator.translate(actualExpression, TypeProvider.viewOf(TYPES)); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/relational/TestRowExpressionOptimizer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/relational/TestRowExpressionOptimizer.java index 0a100ac9f8757..dca72b62123ef 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/relational/TestRowExpressionOptimizer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/relational/TestRowExpressionOptimizer.java @@ -13,50 +13,27 @@ */ package com.facebook.presto.sql.relational; -import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.IntArrayBlock; -import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.RowType; -import com.facebook.presto.common.type.StandardTypes; -import com.facebook.presto.common.type.Type; -import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig; -import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor; -import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors; -import com.facebook.presto.functionNamespace.testing.InMemoryFunctionNamespaceManager; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.function.FunctionHandle; -import com.facebook.presto.spi.function.FunctionImplementationType; -import com.facebook.presto.spi.function.Parameter; -import com.facebook.presto.spi.function.RoutineCharacteristics; -import com.facebook.presto.spi.function.SqlFunctionHandle; -import com.facebook.presto.spi.function.SqlFunctionId; -import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; -import com.facebook.presto.spi.relation.LambdaDefinitionExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; -import com.facebook.presto.sql.analyzer.FunctionsConfig; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import java.util.List; -import java.util.Optional; - import static com.facebook.airlift.testing.Assertions.assertInstanceOf; import static com.facebook.presto.block.BlockAssertions.toValues; import static com.facebook.presto.common.function.OperatorType.ADD; import static com.facebook.presto.common.function.OperatorType.EQUAL; -import static com.facebook.presto.common.function.OperatorType.GREATER_THAN; -import static com.facebook.presto.common.function.OperatorType.LESS_THAN; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; -import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.JsonType.JSON; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; @@ -66,11 +43,6 @@ import static com.facebook.presto.metadata.CastType.JSON_TO_MAP_CAST; import static com.facebook.presto.metadata.CastType.JSON_TO_ROW_CAST; import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; -import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; -import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; -import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; -import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.JAVA; -import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; @@ -80,60 +52,10 @@ import static com.facebook.presto.testing.TestingConnectorSession.SESSION; import static com.facebook.presto.util.StructuralTestUtil.mapType; import static io.airlift.slice.Slices.utf8Slice; -import static java.lang.String.format; import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertThrows; public class TestRowExpressionOptimizer { - private static final SqlInvokedFunction CPP_FOO = new SqlInvokedFunction( - new QualifiedObjectName("native", "default", "sqrt"), - ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), - parseTypeSignature(StandardTypes.DOUBLE), - "sqrt(x)", - RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), - "", - notVersioned()); - private static final SqlInvokedFunction CPP_BAR = new SqlInvokedFunction( - new QualifiedObjectName("native", "default", "cbrt"), - ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), - parseTypeSignature(StandardTypes.DOUBLE), - "cbrt(x)", - RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), - "", - notVersioned()); - private static final SqlInvokedFunction CPP_CUSTOM_FUNCTION = new SqlInvokedFunction( - new QualifiedObjectName("native", "default", "cpp_custom_func"), - ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), - parseTypeSignature(StandardTypes.BIGINT), - "cpp_custom_func(x)", - RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), - "", - notVersioned()); - private static final String nativePrefix = "native.default"; - - private static final RowExpression CUBE_ROOT_EXP = call( - "cbrt", - new SqlFunctionHandle( - new SqlFunctionId( - QualifiedObjectName.valueOf(format("%s.cbrt", nativePrefix)), - ImmutableList.of(BIGINT.getTypeSignature())), - "1"), - DOUBLE, - ImmutableList.of( - constant(27L, BIGINT))); - - private static final RowExpression SQUARE_ROOT_EXP = call( - "sqrt", - new SqlFunctionHandle( - new SqlFunctionId( - QualifiedObjectName.valueOf(format("%s.sqrt", nativePrefix)), - ImmutableList.of(BIGINT.getTypeSignature())), - "1"), - DOUBLE, - ImmutableList.of( - constant(64L, BIGINT))); - private FunctionAndTypeManager functionAndTypeManager; private RowExpressionOptimizer optimizer; @@ -212,194 +134,11 @@ public void testCastWithJsonParseOptimization() call(JSON_TO_ROW_CAST.name(), functionAndTypeManager.lookupCast(JSON_TO_ROW_CAST, VARCHAR, functionAndTypeManager.getType(parseTypeSignature("row(varchar,bigint)"))), RowType.anonymous(ImmutableList.of(VARCHAR, BIGINT)), field(1, VARCHAR))); } - @Test - public void testDefaultExpressionOptimizerUsesJavaNamespaceForBuiltInFunctions() - { - RowExpressionOptimizer nativeOptimizer = getNativeOptimizer(); - assertEquals(nativeOptimizer.optimize(SQUARE_ROOT_EXP, OPTIMIZED, SESSION), constant(8.0, DOUBLE)); - assertThrows(IllegalArgumentException.class, () -> optimizer.optimize(SQUARE_ROOT_EXP, OPTIMIZED, SESSION)); - - assertEquals(nativeOptimizer.optimize(CUBE_ROOT_EXP, OPTIMIZED, SESSION), constant(3.0, DOUBLE)); - assertThrows(IllegalArgumentException.class, () -> optimizer.optimize(CUBE_ROOT_EXP, OPTIMIZED, SESSION)); - } - - @Test - public void testFunctionNotInPrestoDefaultNamespaceIsNotEvaluated() - { - RowExpressionOptimizer nativeOptimizer = getNativeOptimizer(); - - // Create a call expression to the custom native function - RowExpression customFunctionCall = call( - "cpp_custom_func", - new SqlFunctionHandle( - new SqlFunctionId( - QualifiedObjectName.valueOf(format("%s.cpp_custom_func", nativePrefix)), - ImmutableList.of(BIGINT.getTypeSignature())), - "1"), - BIGINT, - ImmutableList.of(constant(42L, BIGINT))); - - // The function should not be evaluated since it doesn't exist in presto.default namespace - // It should return the original call expression unchanged - RowExpression optimized = nativeOptimizer.optimize(customFunctionCall, OPTIMIZED, SESSION); - assertEquals(optimized, customFunctionCall); - assertInstanceOf(optimized, CallExpression.class); - - // Verify that the function handle remains the same (not replaced) - CallExpression optimizedCall = (CallExpression) optimized; - assertEquals(optimizedCall.getFunctionHandle().getCatalogSchemaName().toString(), nativePrefix); - - // Create a call expression to the custom native function with a sqrt call expression arg - RowExpression customFunctionWithCallExpressionCall = call( - "cpp_custom_func", - new SqlFunctionHandle( - new SqlFunctionId( - QualifiedObjectName.valueOf(format("%s.cpp_custom_func", nativePrefix)), - ImmutableList.of(BIGINT.getTypeSignature())), - "1"), - BIGINT, - ImmutableList.of(SQUARE_ROOT_EXP)); - - // The inner CallExpression should be optimized, but the outer shouldn't since the function doesn't exist in presto.default namespace - optimized = nativeOptimizer.optimize(customFunctionWithCallExpressionCall, OPTIMIZED, SESSION); - assertEquals( - optimized, - call( - "cpp_custom_func", - new SqlFunctionHandle( - new SqlFunctionId( - QualifiedObjectName.valueOf(format("%s.cpp_custom_func", nativePrefix)), - ImmutableList.of(BIGINT.getTypeSignature())), - "1"), - BIGINT, - ImmutableList.of(constant(8.0, DOUBLE)))); - assertInstanceOf(optimized, CallExpression.class); - // Verify that the function handle remains the same (not replaced) - optimizedCall = (CallExpression) optimized; - assertEquals(optimizedCall.getFunctionHandle().getCatalogSchemaName().toString(), nativePrefix); - assertEquals(optimizedCall.getChildren().get(0), constant(8.0, DOUBLE)); - } - - @Test - public void testSpecialFormExpressionsWhenDefaultNamespaceIsSwitched() - { - RowExpressionOptimizer nativeOptimizer = getNativeOptimizer(); - - RowExpression leftCondition = callOperator( - GREATER_THAN, - SQUARE_ROOT_EXP, - constant(5.0, DOUBLE)); - - RowExpression rightCondition = callOperator( - LESS_THAN, - CUBE_ROOT_EXP, - constant(10.0, DOUBLE)); - - RowExpression andExpr = new SpecialFormExpression( - SpecialFormExpression.Form.AND, - BOOLEAN, - ImmutableList.of(leftCondition, rightCondition)); - - RowExpression optimized = nativeOptimizer.optimize(andExpr, OPTIMIZED, SESSION); - assertEquals(optimized, constant(true, BOOLEAN)); - - // Lambda expressions inside Special form expressions - List lambdaTypes = ImmutableList.of(DOUBLE, DOUBLE); - LambdaDefinitionExpression lambda = - new LambdaDefinitionExpression( - Optional.empty(), - lambdaTypes, - ImmutableList.of("s", "x"), - callOperator(ADD, SQUARE_ROOT_EXP, CUBE_ROOT_EXP)); - - andExpr = new SpecialFormExpression( - SpecialFormExpression.Form.AND, - BOOLEAN, - ImmutableList.of(lambda, rightCondition)); - - // rightCondition is always true, hence the expression should be reduced to "(s, x) -> 11.0". - optimized = nativeOptimizer.optimize(andExpr, OPTIMIZED, SESSION); - - assertInstanceOf(optimized, LambdaDefinitionExpression.class); - LambdaDefinitionExpression lambdaExpression = (LambdaDefinitionExpression) optimized; - assertEquals(lambdaExpression.getArgumentTypes(), lambdaTypes); - assertEquals(lambdaExpression.getBody(), constant(11.0, DOUBLE)); - } - - @Test - public void testLambdaExpressionsWhenDefaultNamespaceIsSwitched() - { - RowExpressionOptimizer nativeOptimizer = getNativeOptimizer(); - - List lambdaTypes = ImmutableList.of(DOUBLE, DOUBLE); - - LambdaDefinitionExpression lambda = - new LambdaDefinitionExpression( - Optional.empty(), - lambdaTypes, - ImmutableList.of("s", "x"), - callOperator(ADD, SQUARE_ROOT_EXP, CUBE_ROOT_EXP)); - - LambdaDefinitionExpression nestedLambda = - new LambdaDefinitionExpression( - Optional.empty(), - ImmutableList.of(DOUBLE), - ImmutableList.of("y"), - lambda); - - RowExpression optimized = nativeOptimizer.optimize(lambda, OPTIMIZED, SESSION); - - assertInstanceOf(optimized, LambdaDefinitionExpression.class); - LambdaDefinitionExpression lambdaExpression = (LambdaDefinitionExpression) optimized; - assertEquals(lambdaExpression.getArgumentTypes(), lambdaTypes); - assertEquals(lambdaExpression.getBody(), constant(11.0, DOUBLE)); - - // Nested lambda - RowExpression optimizedOuterLambda = nativeOptimizer.optimize(nestedLambda, OPTIMIZED, SESSION); - - assertInstanceOf(optimizedOuterLambda, LambdaDefinitionExpression.class); - LambdaDefinitionExpression outerLambdaExpression = (LambdaDefinitionExpression) optimizedOuterLambda; - assertEquals(outerLambdaExpression.getArgumentTypes(), ImmutableList.of(DOUBLE)); - assertInstanceOf(outerLambdaExpression.getBody(), LambdaDefinitionExpression.class); - - LambdaDefinitionExpression innerLambdaExpression = (LambdaDefinitionExpression) outerLambdaExpression.getBody(); - assertEquals(innerLambdaExpression.getArgumentTypes(), lambdaTypes); - assertEquals(innerLambdaExpression.getBody(), constant(11.0, DOUBLE)); - } - private static RowExpression ifExpression(RowExpression condition, long trueValue, long falseValue) { return new SpecialFormExpression(IF, BIGINT, ImmutableList.of(condition, constant(trueValue, BIGINT), constant(falseValue, BIGINT))); } - private static RowExpressionOptimizer getNativeOptimizer() - { - String nativePrefix = "native.default"; - MetadataManager metadata = MetadataManager.createTestMetadataManager(new FunctionsConfig().setDefaultNamespacePrefix(nativePrefix)); - - metadata.getFunctionAndTypeManager().addFunctionNamespace( - "native", - new InMemoryFunctionNamespaceManager( - "native", - new SqlFunctionExecutors( - ImmutableMap.of( - CPP, FunctionImplementationType.CPP, - JAVA, FunctionImplementationType.JAVA), - new NoopSqlFunctionExecutor()), - new SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("cpp"))); - metadata.getFunctionAndTypeManager().createFunction(CPP_FOO, true); - metadata.getFunctionAndTypeManager().createFunction(CPP_BAR, true); - // Create a custom function that only exists in native namespace - metadata.getFunctionAndTypeManager().createFunction(CPP_CUSTOM_FUNCTION, true); - return new RowExpressionOptimizer(metadata); - } - - private RowExpression callOperator(OperatorType operator, RowExpression left, RowExpression right) - { - FunctionHandle functionHandle = functionAndTypeManager.resolveOperator(operator, fromTypes(left.getType(), right.getType())); - return Expressions.call(operator.getOperator(), functionHandle, left.getType(), left, right); - } - private RowExpression optimize(RowExpression expression) { return optimizer.optimize(expression, OPTIMIZED, SESSION); diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index 88f191bcc05cf..63beb77c6bcf8 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -169,6 +169,7 @@ import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PageSorter; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.analyzer.ViewDefinition; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.plan.SimplePlanFragment; @@ -177,6 +178,7 @@ import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.PredicateCompiler; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.session.WorkerSessionPropertyProvider; import com.facebook.presto.spiller.FileSingleStreamSpillerFactory; @@ -217,6 +219,7 @@ import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -384,6 +387,7 @@ else if (serverConfig.isCoordinator()) { // expression manager binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON); + binder.bind(RowExpressionSerde.class).to(JsonCodecRowExpressionSerde.class).in(Scopes.SINGLETON); // schema properties binder.bind(SchemaPropertyManager.class).in(Scopes.SINGLETON); @@ -591,6 +595,7 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon jsonCodecBinder(binder).bindJsonCodec(SqlInvokedFunction.class); jsonCodecBinder(binder).bindJsonCodec(TaskSource.class); jsonCodecBinder(binder).bindJsonCodec(TableWriteInfo.class); + jsonCodecBinder(binder).bindJsonCodec(RowExpression.class); smileCodecBinder(binder).bindSmileCodec(TaskStatus.class); smileCodecBinder(binder).bindSmileCodec(TaskInfo.class); thriftCodecBinder(binder).bindThriftCodec(TaskStatus.class); diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index 3f19772431415..debae20a1aa3c 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -726,6 +726,11 @@ public NodeManager getPluginNodeManager() return pluginNodeManager; } + public FunctionAndTypeManager getFunctionAndTypeManager() + { + return functionAndTypeManager; + } + public NodePartitioningManager getNodePartitioningManager() { return nodePartitioningManager; diff --git a/presto-native-execution/pom.xml b/presto-native-execution/pom.xml index 30f44d30d5c93..d5a9ac99b4425 100644 --- a/presto-native-execution/pom.xml +++ b/presto-native-execution/pom.xml @@ -55,6 +55,12 @@ test + + com.facebook.airlift + json + test + + org.weakref jmxutils diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeBuiltInFunctions.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeBuiltInFunctions.java index d7dff72ba068d..49f6a21fc2080 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeBuiltInFunctions.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeBuiltInFunctions.java @@ -32,10 +32,12 @@ import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.RoutineCharacteristics; import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.security.AllowAllAccessControl; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.PlanFragmenter; import com.facebook.presto.sql.planner.PlanOptimizers; @@ -56,6 +58,7 @@ import java.util.Optional; import java.util.regex.Pattern; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.builtin.tools.WorkerFunctionUtil.createSqlInvokedFunction; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createNation; @@ -187,7 +190,8 @@ private QueryExplainer getQueryExplainerFromProvidedQueryRunner(QueryRunner quer featuresConfig, new ExpressionOptimizerManager( new PluginNodeManager(new InMemoryNodeManager()), - queryRunner.getMetadata().getFunctionAndTypeManager()), + queryRunner.getMetadata().getFunctionAndTypeManager(), + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))), new TaskManagerConfig(), new AllowAllAccessControl()) .getPlanningTimeOptimizers(); diff --git a/presto-native-sidecar-plugin/pom.xml b/presto-native-sidecar-plugin/pom.xml index f3b1abfd7e0c4..6ddcc91fcb757 100644 --- a/presto-native-sidecar-plugin/pom.xml +++ b/presto-native-sidecar-plugin/pom.xml @@ -237,6 +237,12 @@ test + + javax.ws.rs + javax.ws.rs-api + test + + com.facebook.presto presto-client diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/NativeSidecarPlugin.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/NativeSidecarPlugin.java index 7ad15da8201ba..c6d2073e1b170 100644 --- a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/NativeSidecarPlugin.java +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/NativeSidecarPlugin.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sidecar; +import com.facebook.presto.sidecar.expressions.NativeExpressionOptimizerFactory; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory; import com.facebook.presto.sidecar.nativechecker.NativePlanCheckerProviderFactory; import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory; @@ -21,6 +22,7 @@ import com.facebook.presto.spi.function.FunctionNamespaceManagerFactory; import com.facebook.presto.spi.plan.PlanCheckerProviderFactory; import com.facebook.presto.spi.session.WorkerSessionPropertyProviderFactory; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; import com.facebook.presto.spi.type.TypeManagerFactory; import com.google.common.collect.ImmutableList; @@ -51,6 +53,12 @@ public Iterable getFunctionNamespaceManagerFact return ImmutableList.of(new NativeFunctionNamespaceManagerFactory()); } + @Override + public Iterable getExpressionOptimizerFactories() + { + return ImmutableList.of(new NativeExpressionOptimizerFactory(getClassLoader())); + } + private static ClassLoader getClassLoader() { ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionOptimizer.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionOptimizer.java new file mode 100644 index 0000000000000..59b03497726d1 --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionOptimizer.java @@ -0,0 +1,387 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sidecar.expressions; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.RowExpressionVisitor; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.inject.Inject; + +import java.util.ArrayDeque; +import java.util.HashMap; +import java.util.IdentityHashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; + +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Collections.newSetFromMap; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toMap; + +public class NativeExpressionOptimizer + implements ExpressionOptimizer +{ + private final FunctionMetadataManager functionMetadataManager; + private final StandardFunctionResolution resolution; + private final NativeSidecarExpressionInterpreter rowExpressionInterpreterService; + + @Inject + public NativeExpressionOptimizer( + NativeSidecarExpressionInterpreter rowExpressionInterpreterService, + FunctionMetadataManager functionMetadataManager, + StandardFunctionResolution resolution) + { + this.rowExpressionInterpreterService = requireNonNull(rowExpressionInterpreterService, "rowExpressionInterpreterService is null"); + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.resolution = requireNonNull(resolution, "resolution is null"); + } + + @Override + public RowExpression optimize(RowExpression expression, Level level, ConnectorSession session, Function variableResolver) + { + // Collect expressions to optimize + CollectingVisitor collectingVisitor = new CollectingVisitor(functionMetadataManager, level, resolution); + expression.accept(collectingVisitor, variableResolver); + List expressionsToOptimize = collectingVisitor.getExpressionsToOptimize(); + + // Create a map of original expressions and expressions with variables resolved to constants or row expressions. + Map expressions = expressionsToOptimize.stream() + .collect(toMap( + Function.identity(), + rowExpression -> rowExpression.accept( + new ReplacingVisitor(variable -> { + // Apply resolver + Object replacement = variableResolver.apply(variable); + // Preserve original variable if resolver returns null + return replacement != null + ? toRowExpression(variable.getSourceLocation(), replacement, variable.getType()) + : variable; + }), + null), + (a, b) -> a)); + if (expressions.isEmpty()) { + return expression; + } + + // Constants can be trivially replaced without invoking the interpreter. Move them into a separate map. + Map constants = new HashMap<>(); + Iterator> entries = expressions.entrySet().iterator(); + while (entries.hasNext()) { + Map.Entry entry = entries.next(); + if (entry.getValue() instanceof ConstantExpression) { + constants.put(entry.getKey(), entry.getValue()); + entries.remove(); + } + } + + // Optimize the expressions using the sidecar interpreter + Map replacements = new HashMap<>(); + if (!expressions.isEmpty()) { + // The native endpoint only supports optimizer levels OPTIMIZED or EVALUATED. + // In the sidecar, SERIALIZABLE is effectively the same as OPTIMIZED, + // so if SERIALIZABLE is requested, we use OPTIMIZED instead. + replacements.putAll( + rowExpressionInterpreterService.optimizeBatch( + session, + expressions, + level.ordinal() < OPTIMIZED.ordinal() ? OPTIMIZED : level)); + } + + // Add back in the constants + replacements.putAll(constants); + + // Replace all the expressions in the original expression with the optimized expressions + return toRowExpression(expression.getSourceLocation(), expression.accept(new ReplacingVisitor(replacements), null), expression.getType()); + } + + /** + * This visitor collects expressions that can be optimized by the sidecar interpreter. + */ + private static class CollectingVisitor + implements RowExpressionVisitor + { + private final FunctionMetadataManager functionMetadataManager; + private final Level optimizationLevel; + private final StandardFunctionResolution resolution; + private final Set expressionsToOptimize = newSetFromMap(new IdentityHashMap<>()); + private final Set hasOptimizedChildren = newSetFromMap(new IdentityHashMap<>()); + + public CollectingVisitor(FunctionMetadataManager functionMetadataManager, Level optimizationLevel, StandardFunctionResolution resolution) + { + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.optimizationLevel = requireNonNull(optimizationLevel, "optimizationLevel is null"); + this.resolution = requireNonNull(resolution, "resolution is null"); + } + + @Override + public Void visitExpression(RowExpression node, Object context) + { + visitNode(node, false); + return null; + } + + @Override + public Void visitConstant(ConstantExpression node, Object context) + { + visitNode(node, true); + return null; + } + + @Override + public Void visitVariableReference(VariableReferenceExpression node, Object context) + { + Object value = null; + if (context instanceof Function) { + value = ((Function) context).apply(node); + } + // If context is null or not a function, value stays null + if (value == null || value instanceof RowExpression) { + visitNode(node, false); + return null; + } + visitNode(node, true); + return null; + } + + @Override + public Void visitCall(CallExpression node, Object context) + { + // If the optimization level is not EVALUATED, then we cannot optimize non-deterministic functions + boolean isDeterministic = functionMetadataManager.getFunctionMetadata(node.getFunctionHandle()).isDeterministic(); + boolean canBeEvaluated = (optimizationLevel.ordinal() < EVALUATED.ordinal() && isDeterministic) || + optimizationLevel.ordinal() == EVALUATED.ordinal(); + + // All arguments must be optimizable in order to evaluate the function + for (RowExpression child : node.getArguments()) { + child.accept(this, context); + } + + boolean allConstantFoldable = node.getArguments().stream() + .allMatch(this::canBeOptimized); + + if (canBeEvaluated && allConstantFoldable) { + visitNode(node, true); + return null; + } + + // If it's a cast and the type is already the same, then it's constant foldable + if (resolution.isCastFunction(node.getFunctionHandle()) + && node.getArguments().size() == 1 + && node.getType().equals(node.getArguments().get(0).getType())) { + visitNode(node, true); + return null; + } + visitNode(node, false); + return null; + } + + @Override + public Void visitSpecialForm(SpecialFormExpression node, Object context) + { + // Most special form expressions short circuit, meaning that they potentially don't evaluate all arguments. For example, the AND expression + // will stop evaluating arguments as soon as it finds a false argument. Because a sub-expression could be simplified into a constant, and this + // constant could cause the expression to short circuit, if there is at least one argument which is optimizable, then the entire expression should + // be sent to the sidecar to be optimized. + for (RowExpression child : node.getArguments()) { + child.accept(this, context); + } + + boolean anyArgumentsOptimizable = node.getArguments().stream() + .anyMatch(this::canBeOptimized); + + // If any arguments are constant foldable, then the whole expression is constant foldable + if (anyArgumentsOptimizable) { + visitNode(node, true); + return null; + } + + // If the special form is COALESCE, then we can optimize it if there are any duplicate arguments + if (node.getForm() == COALESCE) { + ImmutableSet.Builder uniqueArgs = ImmutableSet.builder(); + int optimizableCount = 0; + // Check if there's any duplicate arguments, these can be de-duplicated + for (RowExpression argument : node.getArguments()) { + // The duplicate argument must either be a leaf (variable reference) or constant foldable + if (canBeOptimized(argument)) { + uniqueArgs.add(argument); + optimizableCount++; + } + } + + // There is a duplicate when the number of optimizable args > number of unique ones + // If there were any duplicates, or if there's no arguments (cancel out), or if there's only one argument (just return it), + // then it's also constant foldable + boolean canBeOptimized = uniqueArgs.build().size() < optimizableCount || node.getArguments().size() <= 1; + if (canBeOptimized) { + visitNode(node, true); + return null; + } + } + visitNode(node, false); + return null; + } + + @Override + public Void visitLambda(LambdaDefinitionExpression node, Object context) + { + node.getBody().accept(this, (Function) variable -> variable); + if (canBeOptimized(node.getBody())) { + visitNode(node, true); + return null; + } + visitNode(node, false); + return null; + } + + public boolean canBeOptimized(RowExpression rowExpression) + { + return expressionsToOptimize.contains(rowExpression); + } + + private void visitNode(RowExpression node, boolean canBeOptimized) + { + requireNonNull(node, "node is null"); + // If the present node can be optimized, then we send the whole expression. Because an expression may consist of many + // sub-expressions, we need to ensure that we don't send the sub-expression along with its parent expression. For example, + // if we have the expression (a + b) + c, and we can optimize a + b, then we don't want to send a + b to the sidecar, because + // it will be optimized twice. Instead, we want to send (a + b) + c to the sidecar, and then remove a + b from the list of + // expressions to optimize. + // We need to traverse the entire subtree of possible expressions to optimize because some special form expressions may + // short circuit, and we need to ensure that we don't send the sub-expression to the sidecar if the parent expression is + // constant foldable. For example, consider the expression false AND (true OR a). Although the expression true OR a is + // constant foldable, the parent expression is also constant foldable, and we don't want to send both the parent expression + // and the sub-expression to the sidecar because the entire expression can be constant folded in one pass. + if (canBeOptimized) { + ArrayDeque queue = new ArrayDeque<>(node.getChildren()); + while (!queue.isEmpty()) { + RowExpression expression = queue.poll(); + if (hasOptimizedChildren.remove(expression)) { + expressionsToOptimize.remove(expression); + queue.addAll(expression.getChildren()); + } + } + expressionsToOptimize.add(node); + hasOptimizedChildren.add(node); + } + else if (node.getChildren().stream().anyMatch(hasOptimizedChildren::contains)) { + hasOptimizedChildren.add(node); + } + } + + public List getExpressionsToOptimize() + { + return ImmutableList.copyOf(expressionsToOptimize); + } + } + + /** + * This visitor replaces expressions with their optimized versions. + */ + private static class ReplacingVisitor + implements RowExpressionVisitor + { + private final Function resolver; + + public ReplacingVisitor(Map replacements) + { + requireNonNull(replacements, "replacements is null"); + this.resolver = i -> replacements.getOrDefault(i, i); + } + + public ReplacingVisitor(Function variableResolver) + { + requireNonNull(variableResolver, "variableResolver is null"); + this.resolver = i -> i instanceof VariableReferenceExpression ? variableResolver.apply((VariableReferenceExpression) i) : i; + } + + private boolean canBeReplaced(RowExpression rowExpression) + { + return resolver.apply(rowExpression) != rowExpression; + } + + @Override + public RowExpression visitExpression(RowExpression originalExpression, Void context) + { + return resolver.apply(originalExpression); + } + + @Override + public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context) + { + if (canBeReplaced(lambda.getBody())) { + return new LambdaDefinitionExpression( + lambda.getSourceLocation(), + lambda.getArgumentTypes(), + lambda.getArguments(), + toRowExpression(lambda.getSourceLocation(), resolver.apply(lambda.getBody()), lambda.getBody().getType())); + } + return lambda; + } + + @Override + public RowExpression visitCall(CallExpression call, Void context) + { + if (canBeReplaced(call)) { + return resolver.apply(call); + } + List updatedArguments = call.getArguments().stream() + .map(argument -> toRowExpression(argument.getSourceLocation(), argument.accept(this, context), argument.getType())) + .collect(toImmutableList()); + return new CallExpression(call.getSourceLocation(), call.getDisplayName(), call.getFunctionHandle(), call.getType(), updatedArguments); + } + + @Override + public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Void context) + { + if (canBeReplaced(specialForm)) { + return resolver.apply(specialForm); + } + List updatedArguments = specialForm.getArguments().stream() + .map(argument -> toRowExpression(argument.getSourceLocation(), argument.accept(this, context), argument.getType())) + .collect(toImmutableList()); + return new SpecialFormExpression(specialForm.getSourceLocation(), specialForm.getForm(), specialForm.getType(), updatedArguments); + } + } + + private static RowExpression toRowExpression(Optional sourceLocation, Object object, Type type) + { + requireNonNull(type, "type is null"); + + if (object instanceof RowExpression) { + return (RowExpression) object; + } + + // If it's not a RowExpression, we assume it's a literal value. + return new ConstantExpression(sourceLocation, object, type); + } +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionOptimizerFactory.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionOptimizerFactory.java new file mode 100644 index 0000000000000..b1eb030e7ec0f --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionOptimizerFactory.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sidecar.expressions; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.presto.sidecar.NativeSidecarCommunicationModule; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; +import com.google.inject.Injector; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class NativeExpressionOptimizerFactory + implements ExpressionOptimizerFactory +{ + public static final String NAME = "native"; + + private final ClassLoader classLoader; + + @Override + public String getName() + { + return NAME; + } + + public NativeExpressionOptimizerFactory(ClassLoader classLoader) + { + this.classLoader = requireNonNull(classLoader, "classLoader is null"); + } + + @Override + public ExpressionOptimizer createOptimizer(Map config, ExpressionOptimizerContext context) + { + requireNonNull(context, "context is null"); + + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + Bootstrap app = new Bootstrap( + new NativeSidecarCommunicationModule(), + new NativeExpressionsModule(context.getNodeManager(), context.getRowExpressionSerde(), context.getFunctionMetadataManager(), context.getFunctionResolution())); + + Injector injector = app + .noStrictConfig() + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .quiet() + .initialize(); + return injector.getInstance(NativeExpressionOptimizer.class); + } + } +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionsModule.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionsModule.java new file mode 100644 index 0000000000000..d7d6b4dae75d1 --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionsModule.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sidecar.expressions; + +import com.facebook.airlift.json.JsonModule; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.RowExpression; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static java.util.Objects.requireNonNull; + +public class NativeExpressionsModule + implements Module +{ + private final NodeManager nodeManager; + private final RowExpressionSerde rowExpressionSerde; + private final FunctionMetadataManager functionMetadataManager; + private final StandardFunctionResolution functionResolution; + + public NativeExpressionsModule(NodeManager nodeManager, RowExpressionSerde rowExpressionSerde, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution) + { + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); + this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); + this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); + } + + @Override + public void configure(Binder binder) + { + // Core dependencies + binder.bind(NodeManager.class).toInstance(nodeManager); + binder.bind(RowExpressionSerde.class).toInstance(rowExpressionSerde); + binder.bind(FunctionMetadataManager.class).toInstance(functionMetadataManager); + binder.bind(StandardFunctionResolution.class).toInstance(functionResolution); + + // JSON dependencies and setup + binder.install(new JsonModule()); + jsonBinder(binder).addDeserializerBinding(RowExpression.class).to(RowExpressionDeserializer.class).in(Scopes.SINGLETON); + jsonBinder(binder).addSerializerBinding(RowExpression.class).to(RowExpressionSerializer.class).in(Scopes.SINGLETON); + jsonCodecBinder(binder).bindListJsonCodec(RowExpression.class); + jsonCodecBinder(binder).bindListJsonCodec(RowExpressionOptimizationResult.class); + + binder.bind(NativeSidecarExpressionInterpreter.class).in(Scopes.SINGLETON); + + // The main service provider + binder.bind(NativeExpressionOptimizer.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeSidecarExpressionInterpreter.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeSidecarExpressionInterpreter.java new file mode 100644 index 0000000000000..6af455060c8ae --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeSidecarExpressionInterpreter.java @@ -0,0 +1,163 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sidecar.expressions; + +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.HttpUriBuilder; +import com.facebook.airlift.http.client.Request; +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.sidecar.ForSidecarInfo; +import com.facebook.presto.sidecar.NativeSidecarFailureInfo; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.Node; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.inject.Inject; + +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.airlift.http.client.JsonBodyGenerator.jsonBodyGenerator; +import static com.facebook.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; +import static com.facebook.airlift.http.client.Request.Builder.preparePost; +import static com.facebook.presto.spi.StandardErrorCode.GENERIC_USER_ERROR; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.net.HttpHeaders.ACCEPT; +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static java.util.Objects.requireNonNull; + +public class NativeSidecarExpressionInterpreter +{ + public static final String PRESTO_TIME_ZONE_HEADER = "X-Presto-Time-Zone"; + public static final String PRESTO_USER_HEADER = "X-Presto-User"; + public static final String PRESTO_EXPRESSION_OPTIMIZER_LEVEL_HEADER = "X-Presto-Expression-Optimizer-Level"; + private static final String EXPRESSIONS_ENDPOINT = "/v1/expressions"; + + private final NodeManager nodeManager; + private final HttpClient httpClient; + private final JsonCodec> rowExpressionCodec; + private final JsonCodec> rowExpressionOptimizationResultJsonCodec; + + @Inject + public NativeSidecarExpressionInterpreter( + @ForSidecarInfo HttpClient httpClient, + NodeManager nodeManager, + JsonCodec> rowExpressionOptimizationResultJsonCodec, + JsonCodec> rowExpressionCodec) + { + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.rowExpressionOptimizationResultJsonCodec = requireNonNull(rowExpressionOptimizationResultJsonCodec, "rowExpressionOptimizationResultJsonCodec is null"); + this.rowExpressionCodec = requireNonNull(rowExpressionCodec, "rowExpressionCodec is null"); + } + + public Map optimizeBatch(ConnectorSession session, Map expressions, ExpressionOptimizer.Level level) + { + ImmutableList.Builder originalExpressionsBuilder = ImmutableList.builder(); + ImmutableList.Builder resolvedExpressionsBuilder = ImmutableList.builder(); + for (Map.Entry entry : expressions.entrySet()) { + originalExpressionsBuilder.add(entry.getKey()); + resolvedExpressionsBuilder.add(entry.getValue()); + } + List originalExpressions = originalExpressionsBuilder.build(); + List resolvedExpressions = resolvedExpressionsBuilder.build(); + + List rowExpressionOptimizationResults = optimize(session, level, resolvedExpressions); + + Optional exception = rePackageExceptions(rowExpressionOptimizationResults); + if (exception.isPresent()) { + throw new PrestoException(GENERIC_USER_ERROR, "Errors encountered while optimizing expressions.", exception.get()); + } + + checkArgument( + rowExpressionOptimizationResults.size() == resolvedExpressions.size(), + "Expected %s optimized expressions, but got %s", + resolvedExpressions.size(), + rowExpressionOptimizationResults.size()); + + ImmutableMap.Builder result = ImmutableMap.builder(); + for (int i = 0; i < rowExpressionOptimizationResults.size(); i++) { + result.put(originalExpressions.get(i), rowExpressionOptimizationResults.get(i).getOptimizedExpression()); + } + return result.build(); + } + + public List optimize(ConnectorSession session, ExpressionOptimizer.Level level, List resolvedExpressions) + { + List optimizedExpressions; + try { + optimizedExpressions = httpClient.execute( + getSidecarRequest(session, level, resolvedExpressions), + createJsonResponseHandler(rowExpressionOptimizationResultJsonCodec)); + } + catch (Exception e) { + throw new PrestoException(INVALID_ARGUMENTS, "Failed to get optimized expressions from sidecar.", e); + } + return optimizedExpressions; + } + + private Request getSidecarRequest(ConnectorSession session, Level level, List resolvedExpressions) + { + return preparePost() + .setUri(getSidecarLocation()) + .setBodyGenerator(jsonBodyGenerator(rowExpressionCodec, resolvedExpressions)) + .setHeader(CONTENT_TYPE, JSON_UTF_8.toString()) + .setHeader(ACCEPT, JSON_UTF_8.toString()) + .setHeader(PRESTO_TIME_ZONE_HEADER, session.getSqlFunctionProperties().getTimeZoneKey().getId()) + .setHeader(PRESTO_USER_HEADER, session.getUser()) + .setHeader(PRESTO_EXPRESSION_OPTIMIZER_LEVEL_HEADER, level.name()) + .build(); + } + + private URI getSidecarLocation() + { + Node sidecarNode = nodeManager.getSidecarNode(); + return HttpUriBuilder + .uriBuilderFrom(sidecarNode.getHttpUri()) + .appendPath(EXPRESSIONS_ENDPOINT) + .build(); + } + + private static Optional rePackageExceptions(List rowExpressionOptimizationResults) + { + // Extract all exceptions from rowExpressionOptimizationResults + List exceptions = rowExpressionOptimizationResults.stream() + .map(RowExpressionOptimizationResult::getExpressionFailureInfo) + .map(NativeSidecarFailureInfo::toException) + .filter(e -> e.getMessage() != null && !e.getMessage().isEmpty()) + .collect(toImmutableList()); + + if (exceptions.isEmpty()) { + return Optional.empty(); + } + + Exception primary = exceptions.get(0); + + for (int i = 1; i < exceptions.size(); i++) { + primary.addSuppressed(exceptions.get(i)); + } + + return Optional.of(primary); + } +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/RowExpressionDeserializer.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/RowExpressionDeserializer.java new file mode 100644 index 0000000000000..36d7727f7bf97 --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/RowExpressionDeserializer.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sidecar.expressions; + +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.RowExpression; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.jsontype.TypeDeserializer; +import com.google.inject.Inject; + +import java.io.IOException; + +import static java.util.Objects.requireNonNull; + +public final class RowExpressionDeserializer + extends JsonDeserializer +{ + private final RowExpressionSerde rowExpressionSerde; + + @Inject + public RowExpressionDeserializer(RowExpressionSerde rowExpressionSerde) + { + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); + } + + @Override + public RowExpression deserialize(JsonParser jsonParser, DeserializationContext context) + throws IOException + { + return rowExpressionSerde.deserialize(jsonParser.readValueAsTree().toString()); + } + + @Override + public RowExpression deserializeWithType(JsonParser jsonParser, DeserializationContext context, TypeDeserializer typeDeserializer) + throws IOException + { + return deserialize(jsonParser, context); + } +} diff --git a/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/RowExpressionSerializer.java b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/RowExpressionSerializer.java new file mode 100644 index 0000000000000..0b347c05e26b7 --- /dev/null +++ b/presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/RowExpressionSerializer.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sidecar.expressions; + +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.RowExpression; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.jsontype.TypeSerializer; +import com.google.inject.Inject; + +import java.io.IOException; + +import static java.util.Objects.requireNonNull; + +public final class RowExpressionSerializer + extends JsonSerializer +{ + private final RowExpressionSerde rowExpressionSerde; + + @Inject + public RowExpressionSerializer(RowExpressionSerde rowExpressionSerde) + { + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); + } + + @Override + public void serialize(RowExpression rowExpression, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) + throws IOException + { + jsonGenerator.writeRawValue(rowExpressionSerde.serialize(rowExpression)); + } + + @Override + public void serializeWithType(RowExpression rowExpression, JsonGenerator jsonGenerator, SerializerProvider serializerProvider, TypeSerializer typeSerializer) + throws IOException + { + serialize(rowExpression, jsonGenerator, serializerProvider); + } +} diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunner.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunner.java index ead98f4a62f43..a16cebb13212b 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunner.java +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunner.java @@ -40,13 +40,21 @@ public static void main(String[] args) javaQueryRunner.close(); // Launch distributed runner. - DistributedQueryRunner queryRunner = (DistributedQueryRunner) PrestoNativeQueryRunnerUtils.nativeHiveQueryRunnerBuilder() - .setCoordinatorSidecarEnabled(true) - .build(); - setupNativeSidecarPlugin(queryRunner); + DistributedQueryRunner queryRunner = getQueryRunner(); Thread.sleep(10); Logger log = Logger.get(DistributedQueryRunner.class); log.info("======== SERVER STARTED ========"); log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl()); } + + public static DistributedQueryRunner getQueryRunner() + throws Exception + { + // Launch distributed runner. + DistributedQueryRunner queryRunner = (DistributedQueryRunner) PrestoNativeQueryRunnerUtils.nativeHiveQueryRunnerBuilder() + .setCoordinatorSidecarEnabled(true) + .build(); + setupNativeSidecarPlugin(queryRunner); + return queryRunner; + } } diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java index 7f49d63b746de..d03643b6feab9 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/NativeSidecarPluginQueryRunnerUtils.java @@ -14,6 +14,7 @@ package com.facebook.presto.sidecar; import com.facebook.presto.scalar.sql.NativeSqlInvokedFunctionsPlugin; +import com.facebook.presto.sidecar.expressions.NativeExpressionOptimizerFactory; import com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManagerFactory; import com.facebook.presto.sidecar.sessionpropertyproviders.NativeSystemSessionPropertyProviderFactory; import com.facebook.presto.sidecar.typemanager.NativeTypeManagerFactory; @@ -51,6 +52,7 @@ public static void setupNativeSidecarPlugin(QueryRunner queryRunner) queryRunner.loadTypeManager(NativeTypeManagerFactory.NAME); queryRunner.loadPlanCheckerProviderManager("native", ImmutableMap.of()); + queryRunner.getExpressionManager().loadExpressionOptimizerFactory(NativeExpressionOptimizerFactory.NAME, "native", ImmutableMap.of()); queryRunner.installPlugin(new NativeSqlInvokedFunctionsPlugin()); } } diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/expressions/TestNativeExpressionInterpreter.java b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/expressions/TestNativeExpressionInterpreter.java index aa6770c756cd2..c67df04b4e75f 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/expressions/TestNativeExpressionInterpreter.java +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/expressions/TestNativeExpressionInterpreter.java @@ -14,10 +14,7 @@ package com.facebook.presto.sidecar.expressions; import com.facebook.airlift.bootstrap.Bootstrap; -import com.facebook.airlift.http.client.HttpUriBuilder; -import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.JsonModule; -import com.facebook.airlift.log.Logger; import com.facebook.drift.codec.guice.ThriftCodecModule; import com.facebook.presto.block.BlockJsonSerde; import com.facebook.presto.common.block.Block; @@ -29,9 +26,11 @@ import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.HandleJsonModule; -import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils; +import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.operator.scalar.FunctionAssertions; -import com.facebook.presto.sidecar.NativeSidecarFailureInfo; +import com.facebook.presto.sidecar.ForSidecarInfo; +import com.facebook.presto.sidecar.NativeSidecarPluginQueryRunner; +import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.ExpressionOptimizer; @@ -41,134 +40,62 @@ import com.facebook.presto.spi.relation.RowExpressionVisitor; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.expressions.AbstractTestExpressionInterpreter; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.tests.DistributedQueryRunner; import com.facebook.presto.type.TypeDeserializer; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; import com.google.inject.Injector; -import com.google.inject.Key; import com.google.inject.Module; import com.google.inject.Scopes; import org.intellij.lang.annotations.Language; import org.testng.annotations.AfterClass; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import java.io.IOException; -import java.net.ServerSocket; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Path; import java.util.ArrayList; import java.util.List; import java.util.Optional; -import java.util.UUID; import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static com.facebook.airlift.http.client.HttpClientBinder.httpClientBinder; import static com.facebook.airlift.json.JsonBinder.jsonBinder; -import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; -import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; -import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.getNativeQueryRunnerParameters; -import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION; -import static com.google.common.net.HttpHeaders.ACCEPT; -import static com.google.common.net.HttpHeaders.CONTENT_TYPE; -import static com.google.common.net.MediaType.JSON_UTF_8; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.google.inject.multibindings.Multibinder.newSetBinder; import static java.lang.String.format; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; -import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; public class TestNativeExpressionInterpreter extends AbstractTestExpressionInterpreter { - private static final Logger log = Logger.get(TestNativeExpressionInterpreter.class); - - private JsonCodec codec; - private TestVisitor visitor; - private Process sidecar; - private URI expressionUri; + private final TestVisitor visitor; + private final MetadataManager metadata; + private final TestingRowExpressionTranslator translator; + private final DistributedQueryRunner queryRunner; + private final NativeSidecarExpressionInterpreter rowExpressionInterpreter; public TestNativeExpressionInterpreter() - { - METADATA.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(APPLY_FUNCTION)); - } - - @BeforeClass - public void setup() throws Exception { - codec = getJsonCodec(); - visitor = new TestVisitor(); - int port = findRandomPort(); - HttpUriBuilder sidecarUri = HttpUriBuilder.uriBuilder() - .scheme("http") - .host("127.0.0.1") - .port(port); - expressionUri = sidecarUri.appendPath("/v1/expressions").build(); - sidecar = getSidecarProcess(sidecarUri.build(), port); - - try { - HttpClient client = HttpClient.newHttpClient(); - URI infoUri = sidecarUri.appendPath("/v1/info").build(); - HttpRequest request = HttpRequest.newBuilder() - .uri(infoUri) - .header(ACCEPT, JSON_UTF_8.toString()) - .GET() - .build(); - - long timeoutMs = 15000; - long pollIntervalMs = 1000; - long deadline = System.currentTimeMillis() + timeoutMs; - boolean sidecarProcessStarted = false; - - while (System.currentTimeMillis() < deadline) { - try { - HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); - if (response.statusCode() != 500) { - sidecarProcessStarted = true; - break; - } - } - catch (IOException e) { - // ignore and retry until deadline - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - - try { - Thread.sleep(pollIntervalMs); - } - catch (InterruptedException e) { - // ignore and retry until deadline - } - } - - assertTrue(sidecarProcessStarted, format("Sidecar did not start properly within %d ms", timeoutMs)); - } - catch (Exception e) { - log.error(e, "Failed while waiting for sidecar startup"); - throw new Exception(e); - } + this.queryRunner = NativeSidecarPluginQueryRunner.getQueryRunner(); + FunctionAndTypeManager functionAndTypeManager = queryRunner.getCoordinator().getFunctionAndTypeManager(); + this.metadata = createTestMetadataManager(functionAndTypeManager); + this.translator = new TestingRowExpressionTranslator(metadata); + this.rowExpressionInterpreter = getRowExpressionInterpreter(functionAndTypeManager, queryRunner.getCoordinator().getPluginNodeManager()); + this.visitor = new TestVisitor(); } - @AfterClass + @AfterClass(alwaysRun = true) public void tearDown() { - sidecar.destroyForcibly(); + closeAllRuntimeException(queryRunner); } /// Velox permits Bigint to Varchar cast but Presto does not. @@ -207,47 +134,47 @@ public void testFailedExpressionOptimization() { // TODO: Velox COALESCE rewrite should be enhanced to deduplicate fail expressions. assertFailedMatches("coalesce(0 / 0 > 1, unbound_boolean, 0 / 0 = 0)", - "COALESCE\\(presto.default.\\$operator\\$greater_than\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\), 1\\), unbound_boolean, presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\), 0\\)\\)"); + "COALESCE\\(presto.default.\\$operator\\$greater_than\\(presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\), 1\\), unbound_boolean, presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\), 0\\)\\)"); - assertFailedMatches("if(false, 1, 0 / 0)", "presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)"); + assertFailedMatches("if(false, 1, 0 / 0)", "presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\)"); assertFailedMatches("CASE unbound_long WHEN 1 THEN 1 WHEN 0 / 0 THEN 2 END", - "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(1, unbound_long\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)\\), unbound_long\\), 2\\), null\\)"); + "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(1, unbound_long\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\)\\), unbound_long\\), 2\\), null\\)"); assertFailedMatches("CASE unbound_boolean WHEN true THEN 1 ELSE 0 / 0 END", - "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(true, unbound_boolean\\), 1\\), presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)\\)"); + "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(true, unbound_boolean\\), 1\\), presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\)\\)"); assertFailedMatches("CASE bound_long WHEN unbound_long THEN 1 WHEN 0 / 0 THEN 2 ELSE 1 END", - "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(unbound_long, 1234\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)\\), 1234\\), 2\\), 1\\)"); + "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(unbound_long, 1234\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\)\\), 1234\\), 2\\), 1\\)"); assertFailedMatches("case when unbound_boolean then 1 when 0 / 0 = 0 then 2 end", - "SWITCH\\(WHEN\\(unbound_boolean, 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\), 0\\), 2\\), null\\)"); + "SWITCH\\(WHEN\\(unbound_boolean, 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\), 0\\), 2\\), null\\)"); assertFailedMatches("case when unbound_boolean then 1 else 0 / 0 end", - "SWITCH\\(WHEN\\(unbound_boolean, 1\\), presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)\\)"); + "SWITCH\\(WHEN\\(unbound_boolean, 1\\), presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\)\\)"); assertFailedMatches("case when unbound_boolean then 0 / 0 else 1 end", - "SWITCH\\(WHEN\\(unbound_boolean, presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)\\), 1\\)"); + "SWITCH\\(WHEN\\(unbound_boolean, presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\)\\), 1\\)"); assertFailedMatches("case true " + "when unbound_long = 1 then 1 " + "when 0 / 0 = 0 then 2 " + "else 33 end", - "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(unbound_long, 1\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\), 0\\), 2\\), 33\\)"); + "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(unbound_long, 1\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\), 0\\), 2\\), 33\\)"); assertFailedMatches("case 1 " + "when 0 / 0 then 1 " + "when 0 / 0 then 2 " + "else 1 " + "end", - "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\), 1\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\), 1\\), 2\\), 1\\)"); + "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\), 1\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\), 1\\), 2\\), 1\\)"); assertFailedMatches("case 1 " + "when unbound_long then 1 " + "when 0 / 0 then 2 " + "else 1 " + "end", - "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(unbound_long, 1\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)\\), 1\\), 2\\), 1\\)"); + "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(unbound_long, 1\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.\\$operator\\$cast\\(native.default.fail\\(.*\\)\\)\\), 1\\), 2\\), 1\\)"); } /// Sidecar will return an ExecutionFailure when an expression throws during evaluation. The caller of expression @@ -429,36 +356,10 @@ private void assertEvaluateFails(@Language("SQL") String expression, @Language(" { RowExpression rowExpression = sqlToRowExpression(expression); rowExpression = rowExpression.accept(visitor, null); - HttpResponse response = null; - try { - response = getSidecarResponse(rowExpression, ExpressionOptimizer.Level.EVALUATED); - } - catch (Exception e) { - log.error(e, "Failed to get sidecar response: %s.", e.getMessage()); - throw new RuntimeException(e); - } - assertEquals(response.statusCode(), 200, "Sidecar returned error."); - String responseBody = response.body(); - ObjectMapper mapper = new ObjectMapper(); - NativeSidecarFailureInfo result = null; - try { - // Response should be of type NativeSidecarFailureInfo. - JsonNode expressionOptimizationResultList = mapper.readTree(responseBody); - assertTrue(expressionOptimizationResultList.isArray()); - JsonNode expressionOptimizationResult = expressionOptimizationResultList.get(0); - assertNull(expressionOptimizationResult.get("optimizedExpression")); - JsonNode failureInfo = expressionOptimizationResult.get("expressionFailureInfo"); - JsonCodec errorCodec = jsonCodec(NativeSidecarFailureInfo.class); - result = errorCodec.fromJson(failureInfo.toString()); - - assertNotNull(result.getMessage()); - assertTrue(result.getMessage().contains(errorMessage), format("Sidecar response: %s did not contain expected error message: %s.", response.body(), errorMessage)); - } - catch (JsonProcessingException e) { - log.error(e, "Failed to decode RowExpression from sidecar response: %s.", e.getMessage()); - throw new RuntimeException(e); - } + RowExpressionOptimizationResult response = optimize(rowExpression, ExpressionOptimizer.Level.EVALUATED); + assertNotNull(response.getExpressionFailureInfo().getMessage()); + assertTrue(response.getExpressionFailureInfo().getMessage().contains(errorMessage), format("Sidecar response: %s did not contain expected error message: %s.", response, errorMessage)); } /// Checks that the string representation of the failed optimized expression matches expected. @@ -480,8 +381,8 @@ public void assertDoNotOptimize(@Language("SQL") String expression, ExpressionOp private RowExpression sqlToRowExpression(String expression) { - Expression parsedExpression = FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); - return TRANSLATOR.translate(parsedExpression, SYMBOL_TYPES); + Expression parsedExpression = FunctionAssertions.createExpression(expression, metadata, SYMBOL_TYPES); + return translator.translate(parsedExpression, SYMBOL_TYPES); } @Override @@ -493,67 +394,32 @@ public void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL private RowExpression optimizeRowExpression(RowExpression expression, ExpressionOptimizer.Level level) { expression = expression.accept(visitor, null); - HttpResponse response = null; - try { - response = getSidecarResponse(expression, level); - } - catch (Exception e) { - log.error(e, "Failed to get sidecar response: %s.", e.getMessage()); - throw new RuntimeException(e); - } + RowExpressionOptimizationResult response = optimize(expression, level); - assertEquals(response.statusCode(), 200, "Sidecar returned error."); - String responseBody = response.body(); - ObjectMapper mapper = new ObjectMapper(); - RowExpression result = expression; - try { - // Response should be a JSON array consisting of a single RowExpression. - JsonNode expressionOptimizationResultList = mapper.readTree(responseBody); - assertTrue(expressionOptimizationResultList.isArray()); - JsonNode expressionOptimizationResult = expressionOptimizationResultList.get(0); - // Presto protocol generates a concrete struct for `expressionFailureInfo` instead of a `shared_ptr`, so - // `expressionFailureInfo` cannot be null. Hence, we check that the message field is empty to verify there - // is no failure. - assertTrue(expressionOptimizationResult.get("expressionFailureInfo").get("message").isEmpty()); - JsonNode optimizedExpression = expressionOptimizationResult.get("optimizedExpression"); - result = codec.fromJson(optimizedExpression.toString()); - } - catch (JsonProcessingException e) { - log.error(e, "Failed to decode RowExpression from sidecar response: %s.", e.getMessage()); - throw new RuntimeException(e); - } - - return result; + assertNotNull(response.getExpressionFailureInfo().getMessage()); + assertTrue(response.getExpressionFailureInfo().getMessage().isEmpty()); + return response.getOptimizedExpression(); } - private HttpResponse getSidecarResponse(RowExpression expression, ExpressionOptimizer.Level level) - throws IOException, InterruptedException + private RowExpressionOptimizationResult optimize(RowExpression expression, ExpressionOptimizer.Level level) { - String json = String.format("[%s]", codec.toJson(expression)); - HttpClient client = HttpClient.newHttpClient(); - HttpRequest request = HttpRequest.newBuilder() - .uri(expressionUri) - .header(CONTENT_TYPE, JSON_UTF_8.toString()) - .header(ACCEPT, JSON_UTF_8.toString()) - .header("X-Presto-Time-Zone", TEST_SESSION.getSqlFunctionProperties().getTimeZoneKey().getId()) - .header("X-Presto-Expression-Optimizer-Level", level.name()) - .POST(HttpRequest.BodyPublishers.ofString(json, StandardCharsets.UTF_8)) - .build(); - - return client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + List results = rowExpressionInterpreter.optimize(TEST_SESSION.toConnectorSession(), level, List.of(expression)); + // Since we are only sending in a rowExpression at a time, the result is going to be of fixed size 1. + assertEquals(results.size(), 1); + return results.get(0); } - private JsonCodec getJsonCodec() + private NativeSidecarExpressionInterpreter getRowExpressionInterpreter(FunctionAndTypeManager functionAndTypeManager, NodeManager nodeManager) { Module module = binder -> { + binder.bind(NodeManager.class).toInstance(nodeManager); + binder.bind(TypeManager.class).toInstance(functionAndTypeManager); binder.install(new JsonModule()); - binder.install(new HandleJsonModule()); + binder.install(new HandleJsonModule(functionAndTypeManager.getHandleResolver())); binder.bind(ConnectorManager.class).toProvider(() -> null).in(Scopes.SINGLETON); binder.install(new ThriftCodecModule()); configBinder(binder).bindConfig(FeaturesConfig.class); - FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); - binder.bind(TypeManager.class).toInstance(functionAndTypeManager); jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); newSetBinder(binder, Type.class); @@ -561,54 +427,20 @@ private JsonCodec getJsonCodec() newSetBinder(binder, BlockEncoding.class); jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class); jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class); - jsonCodecBinder(binder).bindJsonCodec(RowExpression.class); + jsonCodecBinder(binder).bindListJsonCodec(RowExpression.class); + jsonCodecBinder(binder).bindListJsonCodec(RowExpressionOptimizationResult.class); + + httpClientBinder(binder).bindHttpClient("sidecar", ForSidecarInfo.class); + + binder.bind(NativeSidecarExpressionInterpreter.class).in(Scopes.SINGLETON); }; Bootstrap app = new Bootstrap(ImmutableList.of(module)); Injector injector = app .doNotInitializeLogging() .quiet() .initialize(); - return injector.getInstance(new Key>() {}); - } - - private static Process getSidecarProcess(URI discoveryUri, int port) - throws IOException - { - Path tempDirectoryPath = Files.createTempDirectory(PrestoNativeQueryRunnerUtils.class.getSimpleName()); - log.info("Temp directory for Sidecar: %s", tempDirectoryPath.toString()); - - String configProperties = format("discovery.uri=%s%n" + - "presto.version=testversion%n" + - "system-memory-gb=4%n" + - "native-sidecar=true%n" + - "http-server.http.port=%d", discoveryUri, port); - - Files.write(tempDirectoryPath.resolve("config.properties"), configProperties.getBytes()); - Files.write(tempDirectoryPath.resolve("node.properties"), - format("node.id=%s%n" + - "node.internal-address=127.0.0.1%n" + - "node.environment=testing%n" + - "node.location=test-location", UUID.randomUUID()).getBytes()); - - Path catalogDirectoryPath = tempDirectoryPath.resolve("catalog"); - Files.createDirectory(catalogDirectoryPath); - PrestoNativeQueryRunnerUtils.NativeQueryRunnerParameters nativeQueryRunnerParameters = getNativeQueryRunnerParameters(); - String prestoServerPath = nativeQueryRunnerParameters.serverBinary.toString(); - - return new ProcessBuilder(prestoServerPath, "--logtostderr=1", "--v=1") - .directory(tempDirectoryPath.toFile()) - .redirectErrorStream(true) - .redirectOutput(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("sidecar.out").toFile())) - .redirectError(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("sidecar.out").toFile())) - .start(); - } - public static int findRandomPort() - throws IOException - { - try (ServerSocket socket = new ServerSocket(0)) { - return socket.getLocalPort(); - } + return injector.getInstance(NativeSidecarExpressionInterpreter.class); } private static class TestVisitor diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java index 889aeca27deb9..d7e41e512f830 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java @@ -138,6 +138,7 @@ import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PageSorter; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.analyzer.ViewDefinition; import com.facebook.presto.spi.memory.ClusterMemoryPoolManager; import com.facebook.presto.spi.plan.SimplePlanFragment; @@ -146,6 +147,7 @@ import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.PredicateCompiler; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.session.WorkerSessionPropertyProvider; import com.facebook.presto.spiller.GenericPartitioningSpillerFactory; @@ -180,6 +182,7 @@ import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -310,6 +313,7 @@ protected void setup(Binder binder) jsonCodecBinder(binder).bindJsonCodec(BroadcastFileInfo.class); jsonCodecBinder(binder).bindJsonCodec(SimplePlanFragment.class); binder.bind(SimplePlanFragmentSerde.class).to(JsonCodecSimplePlanFragmentSerde.class).in(Scopes.SINGLETON); + jsonCodecBinder(binder).bindJsonCodec(RowExpression.class); // smile codecs smileCodecBinder(binder).bindSmileCodec(TaskSource.class); @@ -367,6 +371,7 @@ protected void setup(Binder binder) // expression manager binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON); + binder.bind(RowExpressionSerde.class).to(JsonCodecRowExpressionSerde.class).in(Scopes.SINGLETON); // tracer provider managers binder.bind(TracerProviderManager.class).in(Scopes.SINGLETON); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java b/presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java new file mode 100644 index 0000000000000..ab5381aa2556c --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/RowExpressionSerde.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi; + +import com.facebook.presto.spi.relation.RowExpression; + +public interface RowExpressionSerde +{ + String serialize(RowExpression expression); + + RowExpression deserialize(String value); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java index c9a6f84aaaa1d..51b7cce149d8f 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/sql/planner/ExpressionOptimizerContext.java @@ -14,6 +14,7 @@ package com.facebook.presto.spi.sql.planner; import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.StandardFunctionResolution; @@ -22,12 +23,14 @@ public class ExpressionOptimizerContext { private final NodeManager nodeManager; + private final RowExpressionSerde rowExpressionSerde; private final FunctionMetadataManager functionMetadataManager; private final StandardFunctionResolution functionResolution; - public ExpressionOptimizerContext(NodeManager nodeManager, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution) + public ExpressionOptimizerContext(NodeManager nodeManager, RowExpressionSerde rowExpressionSerde, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution) { this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); } @@ -37,6 +40,11 @@ public NodeManager getNodeManager() return nodeManager; } + public RowExpressionSerde getRowExpressionSerde() + { + return rowExpressionSerde; + } + public FunctionMetadataManager getFunctionMetadataManager() { return functionMetadataManager; diff --git a/presto-tests/pom.xml b/presto-tests/pom.xml index 9e21ba392f653..e907fa0868449 100644 --- a/presto-tests/pom.xml +++ b/presto-tests/pom.xml @@ -478,6 +478,7 @@ com.facebook.airlift.drift:drift-codec com.facebook.airlift:jaxrs com.facebook.presto:presto-plugin-toolkit + com.facebook.airlift:node diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index 3f95501b03cd9..3553f6c070a5c 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.tests; -import com.facebook.airlift.node.NodeInfo; import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.transaction.TransactionId; @@ -29,11 +28,13 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.security.AccessDeniedException; import com.facebook.presto.spi.security.AllowAllAccessControl; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.Plan; @@ -64,6 +65,7 @@ import java.util.OptionalLong; import java.util.function.Consumer; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException; import static com.facebook.presto.sql.SqlFormatter.formatSql; import static com.facebook.presto.transaction.TransactionBuilder.transaction; @@ -81,7 +83,6 @@ public abstract class AbstractTestQueryFramework { - private static final NodeInfo NODE_INFO = new NodeInfo("test"); private QueryRunner queryRunner; private ExpectedQueryRunner expectedQueryRunner; private SqlParser sqlParser; @@ -599,7 +600,8 @@ protected QueryExplainer getQueryExplainer() featuresConfig, new ExpressionOptimizerManager( new PluginNodeManager(new InMemoryNodeManager()), - queryRunner.getMetadata().getFunctionAndTypeManager()), + queryRunner.getMetadata().getFunctionAndTypeManager(), + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))), new TaskManagerConfig(), queryRunner.getAccessControl()) .getPlanningTimeOptimizers();