diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java index 15c33950ce14b..630f4670f6cc2 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java @@ -86,14 +86,6 @@ public Field(Optional nodeLocation, Optional relati this.aliased = aliased; } - public static Field newUnqualified(Optional name, Type type) - { - requireNonNull(name, "name is null"); - requireNonNull(type, "type is null"); - - return new Field(Optional.empty(), Optional.empty(), name, type, false, Optional.empty(), Optional.empty(), false); - } - public Optional getNodeLocation() { return nodeLocation; diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java index a45f03d7c0618..19ac259c84f33 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java @@ -52,6 +52,8 @@ import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.procedure.BaseProcedure; import com.facebook.presto.spi.procedure.DistributedProcedure; import com.facebook.presto.spi.procedure.Procedure; @@ -86,6 +88,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; import static 
com.facebook.presto.metadata.FunctionExtractor.extractFunctions; import static com.facebook.presto.spi.ConnectorId.createInformationSchemaConnectorId; @@ -215,6 +218,12 @@ public synchronized void addConnectorFactory(ConnectorFactory connectorFactory) ConnectorFactory existingConnectorFactory = connectorFactories.putIfAbsent(connectorFactory.getName(), connectorFactory); checkArgument(existingConnectorFactory == null, "Connector %s is already registered", connectorFactory.getName()); handleResolver.addConnectorName(connectorFactory.getName(), connectorFactory.getHandleResolver()); + connectorFactory.getTableFunctionHandleResolver().ifPresent(resolver -> { + handleResolver.addTableFunctionNamespace(connectorFactory.getName(), resolver); + }); + connectorFactory.getTableFunctionSplitResolver().ifPresent(resolver -> { + handleResolver.addTableFunctionSplitNamespace(connectorFactory.getName(), resolver); + }); } public synchronized ConnectorId createConnection(String catalogName, String connectorName, Map properties) @@ -334,6 +343,7 @@ private synchronized void addConnectorInternal(MaterializedConnector connector) metadataManager.getAnalyzePropertyManager().addProperties(connectorId, connector.getAnalyzeProperties()); metadataManager.getSessionPropertyManager().addConnectorSessionProperties(connectorId, connector.getSessionProperties()); metadataManager.getFunctionAndTypeManager().getTableFunctionRegistry().addTableFunctions(connectorId, connector.getTableFunctions()); + metadataManager.getFunctionAndTypeManager().addTableFunctionProcessorProvider(connectorId, connector.getTableFunctionProcessorProvider()); } public synchronized void dropConnection(String catalogName) @@ -346,6 +356,7 @@ public synchronized void dropConnection(String catalogName) removeConnectorInternal(createInformationSchemaConnectorId(connectorId)); removeConnectorInternal(createSystemTablesConnectorId(connectorId)); 
metadataManager.getFunctionAndTypeManager().getTableFunctionRegistry().removeTableFunctions(connectorId); + metadataManager.getFunctionAndTypeManager().removeTableFunctionProcessorProvider(connectorId); }); } @@ -422,6 +433,7 @@ private static class MaterializedConnector private final Set> functions; private final Set connectorTableFunctions; + private final Function connectorTableFunctionProcessorProvider; private final ConnectorPageSourceProvider pageSourceProvider; private final Optional pageSinkProvider; private final Optional indexProvider; @@ -459,6 +471,7 @@ public MaterializedConnector(ConnectorId connectorId, Connector connector) Set connectorTableFunctions = connector.getTableFunctions(); requireNonNull(connectorTableFunctions, format("Connector '%s' returned a null table functions set", connectorId)); this.connectorTableFunctions = ImmutableSet.copyOf(connectorTableFunctions); + this.connectorTableFunctionProcessorProvider = connector.getTableFunctionProcessorProvider(); ConnectorPageSourceProvider connectorPageSourceProvider = null; try { @@ -660,5 +673,10 @@ public Set getTableFunctions() { return connectorTableFunctions; } + + public Function getTableFunctionProcessorProvider() + { + return connectorTableFunctionProcessorProvider; + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnector.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnector.java index 860f6e47a6285..708c3f314cc48 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnector.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnector.java @@ -15,11 +15,14 @@ import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.transaction.TransactionId; +import com.facebook.presto.operator.table.ExcludeColumns; +import com.facebook.presto.operator.table.Sequence; import 
com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorPageSource; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableLayout; import com.facebook.presto.spi.ConnectorTableLayoutHandle; @@ -34,6 +37,9 @@ import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.transaction.IsolationLevel; import com.facebook.presto.transaction.InternalConnector; @@ -45,7 +51,9 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Function; +import static com.facebook.presto.operator.table.Sequence.getSequenceFunctionSplitSource; import static java.util.Objects.requireNonNull; public class GlobalSystemConnector @@ -56,12 +64,14 @@ public class GlobalSystemConnector private final String connectorId; private final Set systemTables; private final Set procedures; + private final Set tableFunctions; - public GlobalSystemConnector(String connectorId, Set systemTables, Set procedures) + public GlobalSystemConnector(String connectorId, Set systemTables, Set procedures, Set tableFunctions) { this.connectorId = requireNonNull(connectorId, "connectorId is null"); this.systemTables = ImmutableSet.copyOf(requireNonNull(systemTables, "systemTables is null")); this.procedures = ImmutableSet.copyOf(requireNonNull(procedures, "procedures is 
null")); + this.tableFunctions = ImmutableSet.copyOf(requireNonNull(tableFunctions, "tableFunctions is null")); } @Override @@ -138,8 +148,22 @@ public Map> listTableColumns(ConnectorSess @Override public ConnectorSplitManager getSplitManager() { - return (transactionHandle, session, layout, splitSchedulingContext) -> { - throw new UnsupportedOperationException(); + return new ConnectorSplitManager() { + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) + { + throw new UnsupportedOperationException(); + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorTableFunctionHandle function) + { + if (function instanceof Sequence.SequenceFunctionHandle) { + Sequence.SequenceFunctionHandle sequenceFunctionHandle = (Sequence.SequenceFunctionHandle) function; + return getSequenceFunctionSplitSource(sequenceFunctionHandle); + } + throw new UnsupportedOperationException(); + } }; } @@ -166,4 +190,24 @@ public Set getProcedures() { return procedures; } + + @Override + public Set getTableFunctions() + { + return tableFunctions; + } + + @Override + public Function getTableFunctionProcessorProvider() + { + return connectorTableFunctionHandle -> { + if (connectorTableFunctionHandle instanceof ExcludeColumns.ExcludeColumnsFunctionHandle) { + return ExcludeColumns.getExcludeColumnsFunctionProcessorProvider(); + } + else if (connectorTableFunctionHandle instanceof Sequence.SequenceFunctionHandle) { + return Sequence.getSequenceFunctionProcessorProvider(); + } + return null; + }; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnectorFactory.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnectorFactory.java index 3d1c8e329188b..223684418a819 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnectorFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnectorFactory.java @@ -18,6 +18,7 @@ import com.facebook.presto.spi.connector.Connector; import com.facebook.presto.spi.connector.ConnectorContext; import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.procedure.Procedure; import com.google.common.collect.ImmutableSet; import jakarta.inject.Inject; @@ -32,12 +33,14 @@ public class GlobalSystemConnectorFactory { private final Set tables; private final Set procedures; + private final Set tableFunctions; @Inject - public GlobalSystemConnectorFactory(Set tables, Set procedures) + public GlobalSystemConnectorFactory(Set tables, Set procedures, Set tableFunctions) { this.tables = ImmutableSet.copyOf(requireNonNull(tables, "tables is null")); this.procedures = ImmutableSet.copyOf(requireNonNull(procedures, "procedures is null")); + this.tableFunctions = ImmutableSet.copyOf(requireNonNull(tableFunctions, "tableFunctions is null")); } @Override @@ -55,6 +58,6 @@ public ConnectorHandleResolver getHandleResolver() @Override public Connector create(String catalogName, Map config, ConnectorContext context) { - return new GlobalSystemConnector(catalogName, tables, procedures); + return new GlobalSystemConnector(catalogName, tables, procedures, tableFunctions); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorModule.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorModule.java index 19728ef156e78..40c974f6e7b90 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorModule.java @@ -27,7 +27,10 @@ import 
com.facebook.presto.connector.system.jdbc.TableTypeJdbcTable; import com.facebook.presto.connector.system.jdbc.TypesJdbcTable; import com.facebook.presto.connector.system.jdbc.UdtJdbcTable; +import com.facebook.presto.operator.table.ExcludeColumns; +import com.facebook.presto.operator.table.Sequence; import com.facebook.presto.spi.SystemTable; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.procedure.Procedure; import com.google.common.collect.ImmutableMap; import com.google.inject.Binder; @@ -77,6 +80,10 @@ public void configure(Binder binder) binder.bind(GlobalSystemConnectorFactory.class).in(Scopes.SINGLETON); binder.bind(SystemConnectorRegistrar.class).asEagerSingleton(); + + Multibinder tableFunctions = Multibinder.newSetBinder(binder, ConnectorTableFunction.class); + tableFunctions.addBinding().toProvider(ExcludeColumns.class).in(Scopes.SINGLETON); + tableFunctions.addBinding().toProvider(Sequence.class).in(Scopes.SINGLETON); } @ProvidesIntoSet diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java index 426011d14447e..f7e690c015860 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java @@ -649,7 +649,7 @@ private PlanRoot runCreateLogicalPlanAsync() private void createQueryScheduler(PlanRoot plan) { - CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager::getSplits); + CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager); // ensure split sources are closed stateMachine.addStateChangeListener(state -> { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java 
b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java index 899552d6feedf..0268c52825ce1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java @@ -1081,7 +1081,10 @@ public ListenableFuture processFor(Duration duration) @Override public String getInfo() { - return (partitionedSplit == null) ? "" : partitionedSplit.getSplit().getInfo().toString(); + if (partitionedSplit != null && partitionedSplit.getSplit() != null && partitionedSplit.getSplit().getInfo() != null) { + return partitionedSplit.getSplit().getInfo().toString(); + } + return ""; } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java index 1e04a550ab0ed..c8b5ca2cfc6b9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java @@ -34,6 +34,7 @@ import com.facebook.presto.common.type.TypeWithName; import com.facebook.presto.common.type.UserDefinedType; import com.facebook.presto.operator.window.WindowFunctionSupplier; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.StandardErrorCode; @@ -56,6 +57,8 @@ import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlFunctionSupplier; import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.type.TypeManagerContext; import com.facebook.presto.spi.type.TypeManagerFactory; import 
com.facebook.presto.sql.analyzer.FeaturesConfig; @@ -92,6 +95,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import java.util.regex.Pattern; import static com.facebook.presto.SystemSessionProperties.isExperimentalFunctionsEnabled; @@ -105,6 +109,7 @@ import static com.facebook.presto.metadata.FunctionSignatureMatcher.decideAndThrow; import static com.facebook.presto.metadata.SessionFunctionHandle.SESSION_NAMESPACE; import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables; +import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; import static com.facebook.presto.spi.StandardErrorCode.AMBIGUOUS_FUNCTION_CALL; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_NOT_FOUND; @@ -160,6 +165,7 @@ public class FunctionAndTypeManager private final AtomicReference>> servingTypeManagerParametricTypesSupplier; private final BuiltInWorkerFunctionNamespaceManager builtInWorkerFunctionNamespaceManager; private final BuiltInPluginFunctionNamespaceManager builtInPluginFunctionNamespaceManager; + private final ConcurrentHashMap> tableFunctionProcessorProviderMap = new ConcurrentHashMap<>(); private final FunctionsConfig functionsConfig; private final Set types; @@ -704,6 +710,24 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(FunctionHand return functionNamespaceManager.get().getScalarFunctionImplementation(functionHandle); } + public TableFunctionProcessorProvider getTableFunctionProcessorProvider(TableFunctionHandle tableFunctionHandle) + { + return tableFunctionProcessorProviderMap.get(tableFunctionHandle.getConnectorId()).apply(tableFunctionHandle.getFunctionHandle()); + } + + public void addTableFunctionProcessorProvider(ConnectorId connectorId, Function 
tableFunctionProcessorProvider) + { + if (tableFunctionProcessorProviderMap.putIfAbsent(connectorId, tableFunctionProcessorProvider) != null) { + throw new PrestoException(ALREADY_EXISTS, + format("TableFuncitonProcessorProvider already exists for connectorId %s. Overwriting is not supported.", connectorId.getCatalogName())); + } + } + + public void removeTableFunctionProcessorProvider(ConnectorId connectorId) + { + tableFunctionProcessorProviderMap.remove(connectorId); + } + public AggregationFunctionImplementation getAggregateFunctionImplementation(FunctionHandle functionHandle) { if (isBuiltInPluginFunctionHandle(functionHandle)) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java index 38621cb5c76f3..43c0ad100f528 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java @@ -51,6 +51,7 @@ public void configure(Binder binder) jsonBinder(binder).addModuleBinding().to(TransactionHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(PartitioningHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(FunctionHandleJacksonModule.class); + jsonBinder(binder).addModuleBinding().to(TableFunctionJacksonHandleModule.class); if (handleResolver == null) { binder.bind(HandleResolver.class).in(Scopes.SINGLETON); diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java index 1541a98ee6bf7..7d039db3d8d48 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java @@ -15,6 +15,8 @@ import com.facebook.presto.connector.informationSchema.InformationSchemaHandleResolver; 
import com.facebook.presto.connector.system.SystemHandleResolver; +import com.facebook.presto.operator.table.ExcludeColumns; +import com.facebook.presto.operator.table.Sequence; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorDeleteTableHandle; import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; @@ -30,12 +32,17 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import com.facebook.presto.split.EmptySplitHandleResolver; +import com.google.common.collect.ImmutableSet; import jakarta.inject.Inject; import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.function.Function; @@ -50,6 +57,8 @@ public class HandleResolver { private final ConcurrentMap handleResolvers = new ConcurrentHashMap<>(); private final ConcurrentMap functionHandleResolvers = new ConcurrentHashMap<>(); + private final ConcurrentMap> tableFunctionHandleResolvers = new ConcurrentHashMap<>(); + private final ConcurrentMap> tableFunctionSplitResolvers = new ConcurrentHashMap<>(); @Inject public HandleResolver() @@ -61,6 +70,17 @@ public HandleResolver() functionHandleResolvers.put("$static", new MaterializedFunctionHandleResolver(new BuiltInFunctionNamespaceHandleResolver())); functionHandleResolvers.put("$session", new MaterializedFunctionHandleResolver(new SessionFunctionHandleResolver())); + + tableFunctionHandleResolvers.put( + "$system", + new MaterializedResolver<>(() -> ImmutableSet.of( + ExcludeColumns.ExcludeColumnsFunctionHandle.class, + 
Sequence.SequenceFunctionHandle.class))); + + tableFunctionSplitResolvers.put( + "$system", + new MaterializedResolver<>(() -> + ImmutableSet.of(Sequence.SequenceFunctionSplit.class))); } public void addConnectorName(String name, ConnectorHandleResolver resolver) @@ -72,6 +92,32 @@ public void addConnectorName(String name, ConnectorHandleResolver resolver) "Connector '%s' is already assigned to resolver: %s", name, existingResolver); } + public void addTableFunctionNamespace(String name, TableFunctionHandleResolver resolver) + { + addNamespace(name, resolver::getTableFunctionHandleClasses, tableFunctionHandleResolvers); + } + + public void addTableFunctionSplitNamespace(String name, TableFunctionSplitResolver resolver) + { + addNamespace(name, resolver::getTableFunctionSplitClasses, tableFunctionSplitResolvers); + } + + private void addNamespace( + String name, + Supplier>> classSupplier, + ConcurrentMap> resolverMap) + { + requireNonNull(name, "name is null"); + requireNonNull(classSupplier, "classSupplier is null"); + + MaterializedResolver newResolver = new MaterializedResolver<>(classSupplier); + MaterializedResolver existingResolver = resolverMap.putIfAbsent(name, newResolver); + + checkState( + existingResolver == null || existingResolver.equals(newResolver), + "Name %s is already assigned to table function resolver: %s", name, existingResolver); + } + public void addFunctionNamespace(String name, FunctionHandleResolver resolver) { requireNonNull(name, "name is null"); @@ -98,6 +144,18 @@ public String getId(ColumnHandle columnHandle) public String getId(ConnectorSplit split) { + // First check if this is a table function split + for (Entry> entry : tableFunctionSplitResolvers.entrySet()) { + Optional id = entry.getValue().getClasses().stream() + .filter(clazz -> clazz.isInstance(split)) + .map(Class::getName) + .findFirst(); + if (id.isPresent()) { + return entry.getKey() + ":" + id.get(); + } + } + + // Fall back to regular connector splits return 
getId(split, MaterializedHandleResolver::getSplitClass); } @@ -146,6 +204,20 @@ public String getId(ConnectorMergeTableHandle mergeHandle) return getId(mergeHandle, MaterializedHandleResolver::getMergeTableHandleClass); } + public String getId(ConnectorTableFunctionHandle tableFunctionHandle) + { + for (Entry> entry : tableFunctionHandleResolvers.entrySet()) { + Optional id = entry.getValue().getClasses().stream() + .filter(clazz -> clazz.isInstance(tableFunctionHandle)) + .map(Class::getName) + .findFirst(); + if (id.isPresent()) { + return entry.getKey() + ":" + id.get(); + } + } + throw new IllegalArgumentException("No function namespace for table function handle: " + tableFunctionHandle); + } + public Class getTableHandleClass(String id) { return resolverFor(id).getTableHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); @@ -163,7 +235,17 @@ public Class getColumnHandleClass(String id) public Class getSplitClass(String id) { - return resolverFor(id).getSplitClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); + for (Entry> entry : tableFunctionSplitResolvers.entrySet()) { + MaterializedResolver resolver = entry.getValue(); + Optional> tableFunctionSplit = resolver.getClasses().stream() + .filter(handle -> (entry.getKey() + ":" + handle.getName()).equals(id)) + .findFirst(); + if (tableFunctionSplit.isPresent()) { + return tableFunctionSplit.get(); + } + } + return resolverFor(id).getSplitClass() + .orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } public Class getIndexHandleClass(String id) @@ -211,6 +293,20 @@ public Class getFunctionHandleClass(String id) return resolverForFunctionNamespace(id).getFunctionHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } + public Class getTableFunctionHandleClass(String id) + { + for (Entry> entry : tableFunctionHandleResolvers.entrySet()) { + MaterializedResolver resolver = entry.getValue(); + 
Optional> tableFunctionHandle = resolver.getClasses().stream() + .filter(handle -> (entry.getKey() + ":" + handle.getName()).equals(id)) + .findFirst(); + if (tableFunctionHandle.isPresent()) { + return tableFunctionHandle.get(); + } + } + throw new IllegalArgumentException("No handle resolver for table function namespace: " + id); + } + private MaterializedHandleResolver resolverFor(String id) { MaterializedHandleResolver resolver = handleResolvers.get(id); @@ -267,6 +363,7 @@ private static class MaterializedHandleResolver private final Optional> distributedProcedureHandle; private final Optional> partitioningHandle; private final Optional> transactionHandle; + private final Optional> tableFunctionHandle; public MaterializedHandleResolver(ConnectorHandleResolver resolver) { @@ -282,6 +379,7 @@ public MaterializedHandleResolver(ConnectorHandleResolver resolver) partitioningHandle = getHandleClass(resolver::getPartitioningHandleClass); transactionHandle = getHandleClass(resolver::getTransactionHandleClass); distributedProcedureHandle = getHandleClass(resolver::getDistributedProcedureHandleClass); + tableFunctionHandle = getHandleClass(resolver::getTableFunctionHandleClass); } private static Optional> getHandleClass(Supplier> callable) @@ -354,6 +452,11 @@ public Optional> getTransactionHandl return transactionHandle; } + public Optional> getTableFunctionHandleClass() + { + return tableFunctionHandle; + } + @Override public boolean equals(Object o) { @@ -374,13 +477,14 @@ public boolean equals(Object o) Objects.equals(deleteTableHandle, that.deleteTableHandle) && Objects.equals(mergeTableHandle, that.mergeTableHandle) && Objects.equals(partitioningHandle, that.partitioningHandle) && - Objects.equals(transactionHandle, that.transactionHandle); + Objects.equals(transactionHandle, that.transactionHandle) && + Objects.equals(tableFunctionHandle, that.tableFunctionHandle); } @Override public int hashCode() { - return Objects.hash(tableHandle, layoutHandle, columnHandle, 
split, indexHandle, outputTableHandle, insertTableHandle, deleteTableHandle, mergeTableHandle, partitioningHandle, transactionHandle); + return Objects.hash(tableHandle, layoutHandle, columnHandle, split, indexHandle, outputTableHandle, insertTableHandle, deleteTableHandle, mergeTableHandle, partitioningHandle, transactionHandle, tableFunctionHandle); } } @@ -427,4 +531,48 @@ public int hashCode() return Objects.hash(functionHandle); } } + + private static class MaterializedResolver + { + private final Set> classes; + + public MaterializedResolver(Supplier>> classSupplier) + { + this.classes = getSafe(classSupplier); + } + + private static Set> getSafe(Supplier>> classSupplier) + { + try { + return classSupplier.get(); + } + catch (UnsupportedOperationException e) { + return ImmutableSet.of(); + } + } + + public Set> getClasses() + { + return classes; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + MaterializedResolver that = (MaterializedResolver) o; + return Objects.equals(classes, that.classes); + } + + @Override + public int hashCode() + { + return Objects.hash(classes); + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java index c79a55df796df..5fed0540cc72a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java @@ -609,5 +609,10 @@ default boolean isPushdownSupportedForFilter(Session session, TableHandle tableH String normalizeIdentifier(Session session, String catalogName, String identifier); + /** + * Attempt to push down the table function invocation into the connector. 
+ * @return {@link Optional#empty()} if the connector doesn't support table function invocation pushdown, + * or an {@code Optional>} containing the table handle that will be used in place of the table function invocation. + */ Optional> applyTableFunction(Session session, TableFunctionHandle handle); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionJacksonHandleModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionJacksonHandleModule.java new file mode 100644 index 0000000000000..9f289f4ac491f --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionJacksonHandleModule.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.metadata; + +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import jakarta.inject.Provider; + +public class TableFunctionJacksonHandleModule + extends AbstractTypedJacksonModule +{ + @Inject + public TableFunctionJacksonHandleModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) + { + super(ConnectorTableFunctionHandle.class, + handleResolver::getId, + handleResolver::getTableFunctionHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorTableFunctionHandleCodec)); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java index da7e23f9bbf2b..d624b8364e35b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java @@ -101,6 +101,9 @@ public static List toPath(Session session, QualifiedN // add resolved path items names.add(new CatalogSchemaFunctionName(currentCatalog, currentSchema, parts.get(0))); + + // add builtin path items + names.add(new CatalogSchemaFunctionName("system", "builtin", parts.get(0))); return names.build(); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/EmptyTableFunctionPartition.java b/presto-main-base/src/main/java/com/facebook/presto/operator/EmptyTableFunctionPartition.java new file mode 100644 index 0000000000000..bda83ae6319d4 --- /dev/null +++ 
b/presto-main-base/src/main/java/com/facebook/presto/operator/EmptyTableFunctionPartition.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.RunLengthEncodedBlock; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; + +import java.util.List; + +import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +/** + * This is a class representing empty input to a table function. An EmptyTableFunctionPartition is created + * when the table function has KEEP WHEN EMPTY property, which means that the function should be executed + * even if the input is empty, and all the table arguments are empty relations. + *

+ * An EmptyTableFunctionPartition is created and processed once per node. To avoid duplicated execution, + * a table function having KEEP WHEN EMPTY property must have single distribution. + */ +public class EmptyTableFunctionPartition + implements TableFunctionPartition +{ + private final TableFunctionDataProcessor tableFunction; + private final int properChannelsCount; + private final int passThroughSourcesCount; + private final Type[] passThroughTypes; + + public EmptyTableFunctionPartition(TableFunctionDataProcessor tableFunction, int properChannelsCount, int passThroughSourcesCount, List passThroughTypes) + { + this.tableFunction = requireNonNull(tableFunction, "tableFunction is null"); + this.properChannelsCount = properChannelsCount; + this.passThroughSourcesCount = passThroughSourcesCount; + this.passThroughTypes = passThroughTypes.toArray(new Type[] {}); + } + + @Override + public WorkProcessor toOutputPages() + { + return WorkProcessor.create(() -> { + TableFunctionProcessorState state = tableFunction.process(null); + if (state == FINISHED) { + return WorkProcessor.ProcessState.finished(); + } + if (state instanceof TableFunctionProcessorState.Blocked) { + return WorkProcessor.ProcessState.blocked(toListenableFuture(((TableFunctionProcessorState.Blocked) state).getFuture())); + } + TableFunctionProcessorState.Processed processed = (TableFunctionProcessorState.Processed) state; + if (processed.getResult() != null) { + return WorkProcessor.ProcessState.ofResult(appendNullsForPassThroughColumns(processed.getResult())); + } + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "When function got no input, it should either produce output or return Blocked state"); + }); + } + + private Page appendNullsForPassThroughColumns(Page page) + { + if (page.getChannelCount() != properChannelsCount + passThroughSourcesCount) { + throw new PrestoException( + FUNCTION_IMPLEMENTATION_ERROR, + format( + "Table function returned a page containing %s channels. 
Expected channel number: %s (%s proper columns, %s pass-through index columns)", + page.getChannelCount(), + properChannelsCount + passThroughSourcesCount, + properChannelsCount, + passThroughSourcesCount)); + } + + Block[] resultBlocks = new Block[properChannelsCount + passThroughTypes.length]; + + // proper outputs first + for (int channel = 0; channel < properChannelsCount; channel++) { + resultBlocks[channel] = page.getBlock(channel); + } + + // pass-through columns next + // because no input was processed, all pass-through indexes in the result page must be null (there are no input rows they could refer to). + // for performance reasons this is not checked. All pass-through columns are filled with nulls. + int channel = properChannelsCount; + for (Type type : passThroughTypes) { + resultBlocks[channel] = RunLengthEncodedBlock.create(type, null, page.getPositionCount()); + channel++; + } + + // pass the position count so that the Page can be successfully created in the case when there are no output channels (resultBlocks is empty) + return new Page(page.getPositionCount(), resultBlocks); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/LeafTableFunctionOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/LeafTableFunctionOperator.java new file mode 100644 index 0000000000000..3eb272cef09ec --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/LeafTableFunctionOperator.java @@ -0,0 +1,205 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.execution.ScheduledSplit; +import com.facebook.presto.metadata.Split; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.UpdatablePageSource; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState.Blocked; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed; +import com.facebook.presto.spi.function.table.TableFunctionSplitProcessor; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.google.common.util.concurrent.ListenableFuture; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; + +import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class LeafTableFunctionOperator + implements SourceOperator +{ + public static class LeafTableFunctionOperatorFactory + implements SourceOperatorFactory + { + private final int operatorId; + private final PlanNodeId sourceId; + private final TableFunctionProcessorProvider tableFunctionProvider; + private final ConnectorTableFunctionHandle functionHandle; + private boolean closed; + + public LeafTableFunctionOperatorFactory(int operatorId, PlanNodeId sourceId, TableFunctionProcessorProvider tableFunctionProvider, ConnectorTableFunctionHandle functionHandle) + { + this.operatorId = 
operatorId; + this.sourceId = requireNonNull(sourceId, "sourceId is null"); + this.tableFunctionProvider = requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); + } + + @Override + public PlanNodeId getSourceId() + { + return sourceId; + } + + @Override + public SourceOperator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, sourceId, LeafTableFunctionOperator.class.getSimpleName()); + return new LeafTableFunctionOperator(operatorContext, sourceId, tableFunctionProvider, functionHandle); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + } + + private final OperatorContext operatorContext; + private final PlanNodeId sourceId; + private final TableFunctionProcessorProvider tableFunctionProvider; + private final ConnectorTableFunctionHandle functionHandle; + + private ConnectorSplit currentSplit; + private final List pendingSplits = new ArrayList<>(); + private boolean noMoreSplits; + + private TableFunctionSplitProcessor processor; + private boolean processorUsedData; + private boolean processorFinishedSplit = true; + private ListenableFuture processorBlocked = NOT_BLOCKED; + + public LeafTableFunctionOperator(OperatorContext operatorContext, PlanNodeId sourceId, TableFunctionProcessorProvider tableFunctionProvider, ConnectorTableFunctionHandle functionHandle) + { + this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); + this.sourceId = requireNonNull(sourceId, "sourceId is null"); + this.tableFunctionProvider = requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); + } + + private void resetProcessor() + { + this.processor = tableFunctionProvider.getSplitProcessor(functionHandle); + 
this.processorUsedData = false; + this.processorFinishedSplit = false; + this.processorBlocked = NOT_BLOCKED; + } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public PlanNodeId getSourceId() + { + return sourceId; + } + + @Override + public boolean needsInput() + { + return false; + } + + @Override + public void addInput(Page page) + { + throw new UnsupportedOperationException(getClass().getName() + " does not take input"); + } + + @Override + public Supplier> addSplit(ScheduledSplit split) + { + Split curSplit = requireNonNull(split, "split is null").getSplit(); + checkState(!noMoreSplits, "no more splits expected"); + ConnectorSplit curConnectorSplit = curSplit.getConnectorSplit(); + pendingSplits.add(curConnectorSplit); + return Optional::empty; + } + + @Override + public void noMoreSplits() + { + noMoreSplits = true; + } + + @Override + public Page getOutput() + { + if (processorFinishedSplit) { + // start processing a new split + if (pendingSplits.isEmpty()) { + // no more splits to process at the moment + return null; + } + currentSplit = pendingSplits.remove(0); + resetProcessor(); + } + else { + // a split is being processed + requireNonNull(currentSplit, "currentSplit is null"); + } + + TableFunctionProcessorState state = processor.process(processorUsedData ? null : currentSplit); + if (state == FINISHED) { + processorFinishedSplit = true; + } + if (state instanceof Blocked) { + Blocked blocked = (Blocked) state; + processorBlocked = toListenableFuture(blocked.getFuture()); + } + if (state instanceof Processed) { + Processed processed = (Processed) state; + if (processed.isUsedInput()) { + processorUsedData = true; + } + if (processed.getResult() != null) { + return processed.getResult(); + } + } + return null; + } + + @Override + public ListenableFuture isBlocked() + { + return processorBlocked; + } + + @Override + public void finish() + { + // this method is redundant. 
the operator takes no input at all. noMoreSplits() should be called instead. + } + + @Override + public boolean isFinished() + { + return processorFinishedSplit && pendingSplits.isEmpty() && noMoreSplits; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PageBuffer.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PageBuffer.java new file mode 100644 index 0000000000000..82fe581a61db3 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PageBuffer.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import jakarta.annotation.Nullable; + +import static com.facebook.presto.operator.WorkProcessor.ProcessState.finished; +import static com.facebook.presto.operator.WorkProcessor.ProcessState.ofResult; +import static com.facebook.presto.operator.WorkProcessor.ProcessState.yield; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class PageBuffer +{ + @Nullable + private Page page; + private boolean finished; + + public WorkProcessor pages() + { + return WorkProcessor.create(() -> { + if (isFinished() && isEmpty()) { + return finished(); + } + + if (!isEmpty()) { + Page result = page; + page = null; + return ofResult(result); + } + + return yield(); + }); + } + + public boolean isEmpty() + { + return page == null; + } + + public boolean isFinished() + { + return finished; + } + + public void add(Page page) + { + checkState(isEmpty(), "page buffer is not empty"); + checkState(!isFinished(), "page buffer is finished"); + requireNonNull(page, "page is null"); + + if (page.getPositionCount() == 0) { + return; + } + + this.page = page; + } + + public void finish() + { + finished = true; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java index 640ae9919ca90..b0b82f340b5b3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java @@ -270,9 +270,12 @@ public void swap(int a, int b) valueAddresses.swap(a, b); } - public int buildPage(int position, int[] outputChannels, PageBuilder pageBuilder) + public int buildPage(int position, int endPosition, int[] outputChannels, PageBuilder pageBuilder) { - while (!pageBuilder.isFull() && position < positionCount) { + // Check both endPosition (for 
range-based iteration) and positionCount (to handle concurrent clear()). + // If clear() is called while an iterator is consuming pages, positionCount becomes 0, + // allowing the loop to exit gracefully instead of accessing cleared internal arrays. + while (!pageBuilder.isFull() && position < endPosition && position < positionCount) { long pageAddress = valueAddresses.get(position); int blockIndex = decodeSliceIndex(pageAddress); int blockPosition = decodePosition(pageAddress); @@ -562,10 +565,29 @@ protected Page computeNext() } public Iterator getSortedPages() + { + return getSortedPagesFromRange(0, positionCount); + } + + /** + * Get sorted pages from the specified section of the PagesIndex. + * + * @param start start position of the section, inclusive + * @param end end position of the section, exclusive + * @return iterator of pages + */ + public Iterator getSortedPages(int start, int end) + { + checkArgument(start >= 0 && end <= positionCount, "position range out of bounds"); + checkArgument(start <= end, "invalid position range"); + return getSortedPagesFromRange(start, end); + } + + private Iterator getSortedPagesFromRange(int start, int end) { return new AbstractIterator() { - private int currentPosition; + private int currentPosition = start; private final PageBuilder pageBuilder = new PageBuilder(types); private final int[] outputChannels = new int[types.size()]; @@ -576,7 +598,7 @@ public Iterator getSortedPages() @Override public Page computeNext() { - currentPosition = buildPage(currentPosition, outputChannels, pageBuilder); + currentPosition = buildPage(currentPosition, end, outputChannels, pageBuilder); if (pageBuilder.isEmpty()) { return endOfData(); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/RegularTableFunctionPartition.java b/presto-main-base/src/main/java/com/facebook/presto/operator/RegularTableFunctionPartition.java new file mode 100644 index 0000000000000..5d0376f3be7ee --- /dev/null +++ 
b/presto-main-base/src/main/java/com/facebook/presto/operator/RegularTableFunctionPartition.java @@ -0,0 +1,438 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.RunLengthEncodedBlock; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.primitives.Ints; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.presto.common.Utils.checkState; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import 
static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class RegularTableFunctionPartition + implements TableFunctionPartition +{ + private final PagesIndex pagesIndex; + private final int partitionStart; + private final int partitionEnd; + private final Iterator sortedPages; + + private final TableFunctionDataProcessor tableFunction; + private final int properChannelsCount; + private final int passThroughSourcesCount; + + // channels required by the table function, listed by source in order of argument declarations + private final int[][] requiredChannels; + + // for each input channel, the end position of actual data in that channel (exclusive) relative to partition. The remaining rows are "filler" rows, and should not be passed to table function or passed-through + private final int[] endOfData; + + // a builder for each pass-through column, in order of argument declarations + private final PassThroughColumnProvider[] passThroughProviders; + + // number of processed input positions from partition start. all sources have been processed up to this position, except the sources whose partitions ended earlier. 
+ private int processedPositions; + + public RegularTableFunctionPartition( + PagesIndex pagesIndex, + int partitionStart, + int partitionEnd, + TableFunctionDataProcessor tableFunction, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications) + + { + checkArgument(pagesIndex.getPositionCount() != 0, "PagesIndex is empty for regular table function partition"); + this.pagesIndex = pagesIndex; + this.partitionStart = partitionStart; + this.partitionEnd = partitionEnd; + this.sortedPages = pagesIndex.getSortedPages(partitionStart, partitionEnd); + this.tableFunction = requireNonNull(tableFunction, "tableFunction is null"); + this.properChannelsCount = properChannelsCount; + this.passThroughSourcesCount = passThroughSourcesCount; + this.requiredChannels = requiredChannels.stream() + .map(Ints::toArray) + .toArray(int[][]::new); + this.endOfData = findEndOfData(markerChannels, requiredChannels, passThroughSpecifications); + for (List channels : requiredChannels) { + checkState( + channels.stream() + .mapToInt(channel -> endOfData[channel]) + .distinct() + .count() <= 1, + "end-of-data position is inconsistent within a table function source"); + } + this.passThroughProviders = new PassThroughColumnProvider[passThroughSpecifications.size()]; + for (int i = 0; i < passThroughSpecifications.size(); i++) { + passThroughProviders[i] = createColumnProvider(passThroughSpecifications.get(i)); + } + } + + @Override + public WorkProcessor toOutputPages() + { + return WorkProcessor.create(new WorkProcessor.Process() + { + List> inputPages = prepareInputPages(); + + @Override + public WorkProcessor.ProcessState process() + { + TableFunctionProcessorState state = tableFunction.process(inputPages); + boolean functionGotNoData = inputPages == null; + if (state == FINISHED) { + return WorkProcessor.ProcessState.finished(); + } + if (state instanceof TableFunctionProcessorState.Blocked) { + return 
WorkProcessor.ProcessState.blocked(toListenableFuture(((TableFunctionProcessorState.Blocked) state).getFuture())); + } + TableFunctionProcessorState.Processed processed = (TableFunctionProcessorState.Processed) state; + if (processed.isUsedInput()) { + inputPages = prepareInputPages(); + } + if (processed.getResult() != null) { + return WorkProcessor.ProcessState.ofResult(appendPassThroughColumns(processed.getResult())); + } + if (functionGotNoData) { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "When function got no input, it should either produce output or return Blocked state"); + } + return WorkProcessor.ProcessState.blocked(immediateFuture(null)); + } + }); + } + + /** + * Iterate over the partition by page and extract pages for each table function source from the input page. + * For each source, project the columns required by the table function. + * If for some source all data in the partition has been consumed, Optional.empty() is returned for that source. + * It happens when the partition of this source is shorter than the partition of some other source. + * The overall length of the table function partition is equal to the length of the longest source partition. + * When all sources are fully consumed, this method returns null. + *

+ * NOTE: There are two types of table function's source semantics: set and row. The two types of sources should be handled + * by the TableFunctionDataProcessor in different ways. For a source with set semantics, the whole partition can be used for computations, + * while for a source with row semantics, each row should be processed independently from all other rows. + * To enforce that behavior, we could pass to the TableFunctionDataProcessor only one row from a table with row semantics. + * However, for performance reasons, we handle sources with row and set semantics in the same way: the TableFunctionDataProcessor + * gets a page of data from each source. The TableFunctionDataProcessor is responsible for using the provided data according + * to the declared source semantics (set or row). + * + * @return A List containing: + * - Optional Page for every source that is not fully consumed + * - Optional.empty() for every source that is fully consumed + * or null if all sources are fully consumed. 
+ */ + private List> prepareInputPages() + { + if (!sortedPages.hasNext()) { + return null; + } + + Page inputPage = sortedPages.next(); + ImmutableList.Builder> sourcePages = ImmutableList.builder(); + + for (int[] channelsForSource : requiredChannels) { + if (channelsForSource.length == 0) { + sourcePages.add(Optional.of(new Page(inputPage.getPositionCount()))); + } + else { + int endOfDataForSource = endOfData[channelsForSource[0]]; // end-of-data position is validated to be consistent for all channels from source + if (endOfDataForSource <= processedPositions) { + // all data for this source was already processed + sourcePages.add(Optional.empty()); + } + else { + Block[] sourceBlocks = new Block[channelsForSource.length]; + if (endOfDataForSource < processedPositions + inputPage.getPositionCount()) { + // data for this source ends within the current page + for (int i = 0; i < channelsForSource.length; i++) { + int inputChannel = channelsForSource[i]; + sourceBlocks[i] = inputPage.getBlock(inputChannel).getRegion(0, endOfDataForSource - processedPositions); + } + } + else { + // data for this source does not end within the current page + for (int i = 0; i < channelsForSource.length; i++) { + int inputChannel = channelsForSource[i]; + sourceBlocks[i] = inputPage.getBlock(inputChannel); + } + } + sourcePages.add(Optional.of(new Page(sourceBlocks))); + } + } + } + + processedPositions += inputPage.getPositionCount(); + + return sourcePages.build(); + } + + /** + * There are two types of table function's source semantics: set and row. + *

+ * For a source with set semantics, the table function result depends on the whole partition, + * so it is not always possible to associate an output row with a specific input row. + * The TableFunctionDataProcessor can return null as the pass-through index to indicate that + * the output row is not associated with any row from the given source. + *

+ * For a source with row semantics, the output is determined on a row-by-row basis, so every + * output row is associated with a specific input row. In such case, the pass-through index + * should never be null. + *

+ * In our implementation, we handle sources with row and set semantics in the same way. + * For performance reasons, we do not validate the null pass-through indexes. + * The TableFunctionDataProcessor is responsible for using the pass-through capability + * accordingly to the declared source semantics (set or rows). + */ + private Page appendPassThroughColumns(Page page) + { + if (page.getChannelCount() != properChannelsCount + passThroughSourcesCount) { + throw new PrestoException( + FUNCTION_IMPLEMENTATION_ERROR, + format( + "Table function returned a page containing %s channels. Expected channel number: %s (%s proper columns, %s pass-through index columns)", + page.getChannelCount(), + properChannelsCount + passThroughSourcesCount, + properChannelsCount, + passThroughSourcesCount)); + } + // TODO is it possible to verify types of columns returned by TF? + + Block[] resultBlocks = new Block[properChannelsCount + passThroughProviders.length]; + + // proper outputs first + for (int channel = 0; channel < properChannelsCount; channel++) { + resultBlocks[channel] = page.getBlock(channel); + } + + // pass-through columns next + int channel = properChannelsCount; + for (PassThroughColumnProvider provider : passThroughProviders) { + resultBlocks[channel] = provider.getPassThroughColumn(page); + channel++; + } + + // pass the position count so that the Page can be successfully created in the case when there are no output channels (resultBlocks is empty) + return new Page(page.getPositionCount(), resultBlocks); + } + + private int[] findEndOfData(Optional> markerChannels, List> requiredChannels, List passThroughSpecifications) + { + Set referencedChannels = ImmutableSet.builder() + .addAll(requiredChannels.stream() + .flatMap(Collection::stream) + .collect(toImmutableList())) + .addAll(passThroughSpecifications.stream() + .map(PassThroughColumnSpecification::getInputChannel) + .collect(toImmutableList())) + .build(); + + if (referencedChannels.isEmpty()) { + // no 
required or pass-through channels + return null; + } + + int maxInputChannel = referencedChannels.stream() + .mapToInt(Integer::intValue) + .max() + .orElseThrow(NoSuchElementException::new); + + int[] result = new int[maxInputChannel + 1]; + Arrays.fill(result, -1); + + // if table function had one source, adding a marker channel was not necessary. + // end-of-data position is equal to partition end for each input channel + if (!markerChannels.isPresent()) { + referencedChannels.stream() + .forEach(channel -> result[channel] = partitionEnd - partitionStart); + return result; + } + + // if table function had more than one source, the markers map shall be present, and it shall contain mapping for each input channel + ImmutableMap.Builder endOfDataPerMarkerBuilder = ImmutableMap.builder(); + for (int markerChannel : ImmutableSet.copyOf(markerChannels.orElseThrow(NoSuchElementException::new).values())) { + endOfDataPerMarkerBuilder.put(markerChannel, findFirstNullPosition(markerChannel)); + } + Map endOfDataPerMarker = endOfDataPerMarkerBuilder.buildOrThrow(); + referencedChannels.stream() + .forEach(channel -> result[channel] = endOfDataPerMarker.get(markerChannels.orElseThrow(NoSuchElementException::new).get(channel)) - partitionStart); + + return result; + } + + private int findFirstNullPosition(int markerChannel) + { + if (pagesIndex.isNull(markerChannel, partitionStart)) { + return partitionStart; + } + if (!pagesIndex.isNull(markerChannel, partitionEnd - 1)) { + return partitionEnd; + } + + int start = partitionStart; + int end = partitionEnd; + // value at start is not null, value at end is null + while (end - start > 1) { + int mid = (start + end) >>> 1; + if (pagesIndex.isNull(markerChannel, mid)) { + end = mid; + } + else { + start = mid; + } + } + return end; + } + + public static class PassThroughColumnSpecification + { + private final boolean isPartitioningColumn; + private final int inputChannel; + private final int indexChannel; + + public 
PassThroughColumnSpecification(boolean isPartitioningColumn, int inputChannel, int indexChannel) + { + this.isPartitioningColumn = isPartitioningColumn; + this.inputChannel = inputChannel; + this.indexChannel = indexChannel; + } + + public boolean isPartitioningColumn() + { + return isPartitioningColumn; + } + + public int getInputChannel() + { + return inputChannel; + } + + public int getIndexChannel() + { + return indexChannel; + } + } + + private PassThroughColumnProvider createColumnProvider(PassThroughColumnSpecification specification) + { + if (specification.isPartitioningColumn()) { + return new PartitioningColumnProvider(pagesIndex.getSingleValueBlock(specification.getInputChannel(), partitionStart)); + } + return new NonPartitioningColumnProvider(specification.getInputChannel(), specification.getIndexChannel()); + } + + private interface PassThroughColumnProvider + { + Block getPassThroughColumn(Page page); + } + + private static class PartitioningColumnProvider + implements PassThroughColumnProvider + { + private final Block partitioningValue; + + private PartitioningColumnProvider(Block partitioningValue) + { + this.partitioningValue = requireNonNull(partitioningValue, "partitioningValue is null"); + } + + @Override + public Block getPassThroughColumn(Page page) + { + return new RunLengthEncodedBlock(partitioningValue, page.getPositionCount()); + } + + public Block getPartitioningValue() + { + return partitioningValue; + } + } + + private final class NonPartitioningColumnProvider + implements PassThroughColumnProvider + { + private final int inputChannel; + private final Type type; + private final int indexChannel; + + public NonPartitioningColumnProvider(int inputChannel, int indexChannel) + { + this.inputChannel = inputChannel; + this.type = pagesIndex.getType(inputChannel); + this.indexChannel = indexChannel; + } + + @Override + public Block getPassThroughColumn(Page page) + { + Block indexes = page.getBlock(indexChannel); + BlockBuilder builder = 
type.createBlockBuilder(null, page.getPositionCount()); + for (int position = 0; position < page.getPositionCount(); position++) { + if (indexes.isNull(position)) { + builder.appendNull(); + } + else { + // table function returns index from partition start + long index = BIGINT.getLong(indexes, position); + // validate index + if (index < 0 || index >= endOfData[inputChannel] || index >= processedPositions) { + int end = min(endOfData[inputChannel], processedPositions) - 1; + if (end >= 0) { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, format("Index of a pass-through row: %s out of processed portion of partition [0, %s]", index, end)); + } + else { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "Index of a pass-through row must be null when no input data from the partition was processed. Actual: " + index); + } + } + // index in PagesIndex + long absoluteIndex = partitionStart + index; + pagesIndex.appendTo(inputChannel, toIntExact(absoluteIndex), builder); + } + } + + return builder.build(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionOperator.java new file mode 100644 index 0000000000000..947389f342ed5 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionOperator.java @@ -0,0 +1,635 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.memory.context.LocalMemoryContext; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.ListenableFuture; +import jakarta.annotation.Nullable; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkPositionIndex; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.concat; +import static java.util.Collections.nCopies; +import static java.util.Objects.requireNonNull; + +public class TableFunctionOperator + implements Operator +{ + public static class TableFunctionOperatorFactory + implements OperatorFactory + { + private final int operatorId; + private final PlanNodeId planNodeId; + + // a provider of table function processor to be called once per partition + private final TableFunctionProcessorProvider tableFunctionProvider; + + // all information necessary to execute the table function collected during analysis + private final ConnectorTableFunctionHandle functionHandle; + + // number of proper 
columns produced by the table function + private final int properChannelsCount; + + // number of input tables declared as pass-through + private final int passThroughSourcesCount; + + // columns required by the table function, in order of input tables + private final List> requiredChannels; + + // map from input channel to marker channel + // for each input table, there is a channel that marks which rows contain original data, and which are "filler" rows. + // the "filler" rows are part of the algorithm, and they should not be processed by the table function, or passed-through. + // In this map, every original column from the input table is associated with the appropriate marker. + private final Optional> markerChannels; + + // necessary information to build a pass-through column, for all pass-through columns, ordered as expected on the output + // it includes columns from sources declared as pass-through as well as partitioning columns from other sources + private final List passThroughSpecifications; + + // specifies whether the function should be pruned or executed when the input is empty + // pruneWhenEmpty is false if and only if all original input tables are KEEP WHEN EMPTY + private final boolean pruneWhenEmpty; + + // partitioning channels from all sources + private final List partitionChannels; + + // subset of partition channels that are already grouped + private final List prePartitionedChannels; + + // channels necessary to sort all sources: + // - for a single source, these are the source's sort channels + // - for multiple sources, this is a single synthesized row number channel + private final List sortChannels; + private final List sortOrders; + + // number of leading sort channels that are already sorted + private final int preSortedPrefix; + + private final List sourceTypes; + private final int expectedPositions; + private final PagesIndex.Factory pagesIndexFactory; + + private boolean closed; + + public TableFunctionOperatorFactory( + int 
operatorId, + PlanNodeId planNodeId, + TableFunctionProcessorProvider tableFunctionProvider, + ConnectorTableFunctionHandle functionHandle, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications, + boolean pruneWhenEmpty, + List partitionChannels, + List prePartitionedChannels, + List sortChannels, + List sortOrders, + int preSortedPrefix, + List sourceTypes, + int expectedPositions, + PagesIndex.Factory pagesIndexFactory) + { + requireNonNull(planNodeId, "planNodeId is null"); + requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + requireNonNull(functionHandle, "functionHandle is null"); + requireNonNull(requiredChannels, "requiredChannels is null"); + requireNonNull(markerChannels, "markerChannels is null"); + requireNonNull(passThroughSpecifications, "passThroughSpecifications is null"); + requireNonNull(partitionChannels, "partitionChannels is null"); + requireNonNull(prePartitionedChannels, "prePartitionedChannels is null"); + checkArgument(partitionChannels.containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels"); + requireNonNull(sortChannels, "sortChannels is null"); + requireNonNull(sortOrders, "sortOrders is null"); + checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders"); + checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels"); + checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped"); + requireNonNull(sourceTypes, "sourceTypes is null"); + requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); + + this.operatorId = operatorId; + this.planNodeId = planNodeId; + 
this.tableFunctionProvider = tableFunctionProvider; + this.functionHandle = functionHandle; + this.properChannelsCount = properChannelsCount; + this.passThroughSourcesCount = passThroughSourcesCount; + this.requiredChannels = requiredChannels.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerChannels = markerChannels.map(ImmutableMap::copyOf); + this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); + this.pruneWhenEmpty = pruneWhenEmpty; + this.partitionChannels = ImmutableList.copyOf(partitionChannels); + this.prePartitionedChannels = ImmutableList.copyOf(prePartitionedChannels); + this.sortChannels = ImmutableList.copyOf(sortChannels); + this.sortOrders = ImmutableList.copyOf(sortOrders); + this.preSortedPrefix = preSortedPrefix; + this.sourceTypes = ImmutableList.copyOf(sourceTypes); + this.expectedPositions = expectedPositions; + this.pagesIndexFactory = pagesIndexFactory; + } + + @Override + public Operator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + + OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, TableFunctionOperator.class.getSimpleName()); + return new TableFunctionOperator( + operatorContext, + tableFunctionProvider, + functionHandle, + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications, + pruneWhenEmpty, + partitionChannels, + prePartitionedChannels, + sortChannels, + sortOrders, + preSortedPrefix, + sourceTypes, + expectedPositions, + pagesIndexFactory); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + + @Override + public OperatorFactory duplicate() + { + return new TableFunctionOperatorFactory( + operatorId, + planNodeId, + tableFunctionProvider, + functionHandle, + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications, + pruneWhenEmpty, + 
partitionChannels, + prePartitionedChannels, + sortChannels, + sortOrders, + preSortedPrefix, + sourceTypes, + expectedPositions, + pagesIndexFactory); + } + } + + private final OperatorContext operatorContext; + + private final PageBuffer pageBuffer = new PageBuffer(); + private final WorkProcessor outputPages; + private final boolean processEmptyInput; + + @Nullable + private Page pendingInput; + private boolean operatorFinishing; + + public TableFunctionOperator( + OperatorContext operatorContext, + TableFunctionProcessorProvider tableFunctionProvider, + ConnectorTableFunctionHandle functionHandle, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications, + boolean pruneWhenEmpty, + List partitionChannels, + List prePartitionedChannels, + List sortChannels, + List sortOrders, + int preSortedPrefix, + List sourceTypes, + int expectedPositions, + PagesIndex.Factory pagesIndexFactory) + { + requireNonNull(operatorContext, "operatorContext is null"); + requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + requireNonNull(functionHandle, "functionHandle is null"); + requireNonNull(requiredChannels, "requiredChannels is null"); + requireNonNull(markerChannels, "markerChannels is null"); + requireNonNull(passThroughSpecifications, "passThroughSpecifications is null"); + requireNonNull(partitionChannels, "partitionChannels is null"); + requireNonNull(prePartitionedChannels, "prePartitionedChannels is null"); + checkArgument(partitionChannels.containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels"); + requireNonNull(sortChannels, "sortChannels is null"); + requireNonNull(sortOrders, "sortOrders is null"); + checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders"); + checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted 
channels must be lower or equal to the number of sort channels"); + checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped"); + requireNonNull(sourceTypes, "sourceTypes is null"); + requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); + + this.operatorContext = operatorContext; + + this.processEmptyInput = !pruneWhenEmpty; + + PagesIndex pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions); + HashStrategies hashStrategies = new HashStrategies(pagesIndex, partitionChannels, prePartitionedChannels, sortChannels, sortOrders, preSortedPrefix); + + this.outputPages = pageBuffer.pages() + .transform(new PartitionAndSort(pagesIndex, hashStrategies, processEmptyInput)) + .flatMap(groupPagesIndex -> pagesIndexToTableFunctionPartitions( + groupPagesIndex, + hashStrategies, + tableFunctionProvider, + functionHandle, + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications, + processEmptyInput)) + .flatMap(TableFunctionPartition::toOutputPages); + } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public void finish() + { + pageBuffer.finish(); + } + + @Override + public boolean isFinished() + { + return outputPages.isFinished(); + } + + @Override + public ListenableFuture isBlocked() + { + if (outputPages.isBlocked()) { + return outputPages.getBlockedFuture(); + } + + return NOT_BLOCKED; + } + + @Override + public boolean needsInput() + { + return pageBuffer.isEmpty() && !pageBuffer.isFinished(); + } + + @Override + public void addInput(Page page) + { + pageBuffer.add(page); + } + + @Override + public Page getOutput() + { + if (!outputPages.process()) { + return null; + } + + if (outputPages.isFinished()) { + return null; + } + + return outputPages.getResult(); + 
} + + private static class HashStrategies + { + final PagesHashStrategy prePartitionedStrategy; + final PagesHashStrategy remainingPartitionStrategy; + final PagesHashStrategy preSortedStrategy; + final List remainingPartitionAndSortChannels; + final List remainingSortOrders; + final int[] prePartitionedChannelsArray; + + public HashStrategies( + PagesIndex pagesIndex, + List partitionChannels, + List prePartitionedChannels, + List sortChannels, + List sortOrders, + int preSortedPrefix) + { + this.prePartitionedStrategy = pagesIndex.createPagesHashStrategy(prePartitionedChannels, OptionalInt.empty()); + + List remainingPartitionChannels = partitionChannels.stream() + .filter(channel -> !prePartitionedChannels.contains(channel)) + .collect(toImmutableList()); + this.remainingPartitionStrategy = pagesIndex.createPagesHashStrategy(remainingPartitionChannels, OptionalInt.empty()); + + List preSortedChannels = sortChannels.stream() + .limit(preSortedPrefix) + .collect(toImmutableList()); + this.preSortedStrategy = pagesIndex.createPagesHashStrategy(preSortedChannels, OptionalInt.empty()); + + if (preSortedPrefix > 0) { + // preSortedPrefix > 0 implies that all partition channels are already pre-partitioned (enforced by check in the constructor), so we only need to do the remaining sort + this.remainingPartitionAndSortChannels = ImmutableList.copyOf(Iterables.skip(sortChannels, preSortedPrefix)); + this.remainingSortOrders = ImmutableList.copyOf(Iterables.skip(sortOrders, preSortedPrefix)); + } + else { + // we need to sort by the remaining partition channels so that the input is fully partitioned, + // and then we need to sort by all the sort channels so that the input is fully sorted + this.remainingPartitionAndSortChannels = ImmutableList.copyOf(concat(remainingPartitionChannels, sortChannels)); + this.remainingSortOrders = ImmutableList.copyOf(concat(nCopies(remainingPartitionChannels.size(), ASC_NULLS_LAST), sortOrders)); + } + + this.prePartitionedChannelsArray = 
Ints.toArray(prePartitionedChannels); + } + } + + private class PartitionAndSort + implements WorkProcessor.Transformation + { + private final PagesIndex pagesIndex; + private final HashStrategies hashStrategies; + private final LocalMemoryContext memoryContext; + + private boolean resetPagesIndex; + private int inputPosition; + private boolean processEmptyInput; + + public PartitionAndSort(PagesIndex pagesIndex, HashStrategies hashStrategies, boolean processEmptyInput) + { + this.pagesIndex = pagesIndex; + this.hashStrategies = hashStrategies; + this.memoryContext = operatorContext.aggregateUserMemoryContext().newLocalMemoryContext(PartitionAndSort.class.getSimpleName()); + this.processEmptyInput = processEmptyInput; + } + + @Override + public WorkProcessor.TransformationState process(Optional input) + { + if (resetPagesIndex) { + pagesIndex.clear(); + updateMemoryUsage(); + resetPagesIndex = false; + } + + if (!input.isPresent() && pagesIndex.getPositionCount() == 0) { + if (processEmptyInput) { + // it can only happen at the first call to process(), which implies that there is no input. Empty PagesIndex can be passed on only once. + processEmptyInput = false; + return WorkProcessor.TransformationState.ofResult(pagesIndex, false); + } + else { + memoryContext.close(); + return WorkProcessor.TransformationState.finished(); + } + } + + // there is input, so we are not interested in processing empty input + processEmptyInput = false; + + if (input.isPresent()) { + // append rows from input which belong to the current group wrt pre-partitioned columns + // it might be one or more partitions + inputPosition = appendCurrentGroup(pagesIndex, hashStrategies, input.get(), inputPosition); + updateMemoryUsage(); + + if (inputPosition >= input.get().getPositionCount()) { + inputPosition = 0; + return WorkProcessor.TransformationState.needsMoreData(); + } + } + + // we have unused input or the input is finished. 
we have buffered a full group + // the group contains one or more partitions, as it was determined by the pre-partitioned columns + // sorting serves two purposes: + // - sort by the remaining partition channels so that the input is fully partitioned, + // - sort by all the sort channels so that the input is fully sorted + sortCurrentGroup(pagesIndex, hashStrategies); + resetPagesIndex = true; + return WorkProcessor.TransformationState.ofResult(pagesIndex, false); + } + + void updateMemoryUsage() + { + memoryContext.setBytes(pagesIndex.getEstimatedSize().toBytes()); + } + } + + private static int appendCurrentGroup(PagesIndex pagesIndex, HashStrategies hashStrategies, Page page, int startPosition) + { + checkArgument(page.getPositionCount() > startPosition); + + PagesHashStrategy prePartitionedStrategy = hashStrategies.prePartitionedStrategy; + Page prePartitionedPage = page.extractChannels(hashStrategies.prePartitionedChannelsArray); + + if (pagesIndex.getPositionCount() == 0 || pagesIndex.positionEqualsRow(prePartitionedStrategy, 0, startPosition, prePartitionedPage)) { + // we are within the current group. 
find the position where the pre-grouped columns change + int groupEnd = findGroupEnd(prePartitionedPage, prePartitionedStrategy, startPosition); + + // add the section of the page that contains values for the current group + pagesIndex.addPage(page.getRegion(startPosition, groupEnd - startPosition)); + + if (page.getPositionCount() - groupEnd > 0) { + // the remaining part of the page contains the next group + return groupEnd; + } + // page fully consumed: it contains the current group only + return page.getPositionCount(); + } + + // we had previous results buffered, but the remaining page starts with new group values + return startPosition; + } + + private static void sortCurrentGroup(PagesIndex pagesIndex, HashStrategies hashStrategies) + { + PagesHashStrategy preSortedStrategy = hashStrategies.preSortedStrategy; + List remainingPartitionAndSortChannels = hashStrategies.remainingPartitionAndSortChannels; + List remainingSortOrders = hashStrategies.remainingSortOrders; + + if (pagesIndex.getPositionCount() > 1 && !remainingPartitionAndSortChannels.isEmpty()) { + int startPosition = 0; + while (startPosition < pagesIndex.getPositionCount()) { + int endPosition = findGroupEnd(pagesIndex, preSortedStrategy, startPosition); + pagesIndex.sort(remainingPartitionAndSortChannels, remainingSortOrders, startPosition, endPosition); + startPosition = endPosition; + } + } + } + + // Assumes input grouped on relevant pagesHashStrategy columns + private static int findGroupEnd(Page page, PagesHashStrategy pagesHashStrategy, int startPosition) + { + checkArgument(page.getPositionCount() > 0, "Must have at least one position"); + checkPositionIndex(startPosition, page.getPositionCount(), "startPosition out of bounds"); + + return findEndPosition(startPosition, page.getPositionCount(), (firstPosition, secondPosition) -> pagesHashStrategy.rowEqualsRow(firstPosition, page, secondPosition, page)); + } + + // Assumes input grouped on relevant pagesHashStrategy columns + private static 
int findGroupEnd(PagesIndex pagesIndex, PagesHashStrategy pagesHashStrategy, int startPosition) + { + checkArgument(pagesIndex.getPositionCount() > 0, "Must have at least one position"); + checkPositionIndex(startPosition, pagesIndex.getPositionCount(), "startPosition out of bounds"); + + return findEndPosition(startPosition, pagesIndex.getPositionCount(), (firstPosition, secondPosition) -> pagesIndex.positionEqualsPosition(pagesHashStrategy, firstPosition, secondPosition)); + } + + /** + * @param startPosition - inclusive + * @param endPosition - exclusive + * @param comparator - returns true if positions given as parameters are equal + * @return the end of the group position exclusive (the position the very next group starts) + */ + @VisibleForTesting + static int findEndPosition(int startPosition, int endPosition, PositionComparator comparator) + { + checkArgument(startPosition >= 0, "startPosition must be greater or equal than zero: %s", startPosition); + checkArgument(startPosition < endPosition, "startPosition (%s) must be less than endPosition (%s)", startPosition, endPosition); + + int left = startPosition; + int right = endPosition; + + while (right - left > 1) { + int middle = (left + right) >>> 1; + + if (comparator.test(startPosition, middle)) { + left = middle; + } + else { + right = middle; + } + } + + return right; + } + + private interface PositionComparator + { + boolean test(int first, int second); + } + + private WorkProcessor pagesIndexToTableFunctionPartitions( + PagesIndex pagesIndex, + HashStrategies hashStrategies, + TableFunctionProcessorProvider tableFunctionProvider, + ConnectorTableFunctionHandle functionHandle, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications, + boolean processEmptyInput) + { + // pagesIndex contains the full grouped and sorted data for one or more partitions + + PagesHashStrategy remainingPartitionStrategy = 
hashStrategies.remainingPartitionStrategy; + + return WorkProcessor.create(new WorkProcessor.Process() + { + private int partitionStart; + private boolean processEmpty = processEmptyInput; + + @Override + public WorkProcessor.ProcessState process() + { + if (partitionStart == pagesIndex.getPositionCount()) { + if (processEmpty && pagesIndex.getPositionCount() == 0) { + // empty PagesIndex can only be passed once as the result of PartitionAndSort. Neither this nor any future instance of Process will ever get an empty PagesIndex again. + processEmpty = false; + return WorkProcessor.ProcessState.ofResult(new EmptyTableFunctionPartition( + tableFunctionProvider.getDataProcessor(functionHandle), + properChannelsCount, + passThroughSourcesCount, + passThroughSpecifications.stream() + .map(RegularTableFunctionPartition.PassThroughColumnSpecification::getInputChannel) + .map(pagesIndex::getType) + .collect(toImmutableList()))); + } + return WorkProcessor.ProcessState.finished(); + } + + // there is input, so we are not interested in processing empty input + processEmpty = false; + + int partitionEnd = findGroupEnd(pagesIndex, remainingPartitionStrategy, partitionStart); + + RegularTableFunctionPartition partition = new RegularTableFunctionPartition( + pagesIndex, + partitionStart, + partitionEnd, + tableFunctionProvider.getDataProcessor(functionHandle), + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications); + + partitionStart = partitionEnd; + return WorkProcessor.ProcessState.ofResult(partition); + } + }); + } + + private class PagesSource + implements WorkProcessor.Process + { + @Override + public WorkProcessor.ProcessState process() + { + if (operatorFinishing && pendingInput == null) { + return WorkProcessor.ProcessState.finished(); + } + + if (pendingInput != null) { + Page result = pendingInput; + pendingInput = null; + return WorkProcessor.ProcessState.ofResult(result); + } + + return 
WorkProcessor.ProcessState.yield(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionPartition.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionPartition.java new file mode 100644 index 0000000000000..1876b352bd251 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionPartition.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; + +public interface TableFunctionPartition +{ + WorkProcessor toOutputPages(); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/table/ExcludeColumns.java b/presto-main-base/src/main/java/com/facebook/presto/operator/table/ExcludeColumns.java new file mode 100644 index 0000000000000..a098602b07bac --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/table/ExcludeColumns.java @@ -0,0 +1,169 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.table; + +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.AbstractConnectorTableFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.DescriptorArgumentSpecification; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.function.table.TableArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; + +import javax.inject.Provider; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static 
com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInputAndProduced; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.stream.Collectors.joining; + +public class ExcludeColumns + implements Provider +{ + public static final String NAME = "exclude_columns"; + + @Override + public ConnectorTableFunction get() + { + return new ExcludeColumnsFunction(); + } + + public static class ExcludeColumnsFunction + extends AbstractConnectorTableFunction + { + private static final String TABLE_ARGUMENT_NAME = "INPUT"; + private static final String DESCRIPTOR_ARGUMENT_NAME = "COLUMNS"; + + public ExcludeColumnsFunction() + { + super( + "builtin", + NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name(TABLE_ARGUMENT_NAME) + .rowSemantics() + .build(), + DescriptorArgumentSpecification.builder() + .name(DESCRIPTOR_ARGUMENT_NAME) + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + DescriptorArgument excludedColumns = (DescriptorArgument) arguments.get(DESCRIPTOR_ARGUMENT_NAME); + if (excludedColumns.equals(NULL_DESCRIPTOR)) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "COLUMNS descriptor is null"); + } + Descriptor excludedColumnsDescriptor = excludedColumns.getDescriptor().orElseThrow(() -> new PrestoException(INVALID_ARGUMENTS, "Missing exclude columns descriptor")); + if (excludedColumnsDescriptor.getFields().stream().anyMatch(field -> field.getType().isPresent())) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "COLUMNS descriptor contains types"); + } + + // column names in DescriptorArgument are canonical wrt SQL identifier semantics. 
+ // column names in TableArgument are not canonical wrt SQL identifier semantics, as they are taken from the corresponding RelationType. + // because of that, we match the excluded columns names case-insensitive + // TODO: apply proper identifier semantics + Set excludedNames = excludedColumnsDescriptor.getFields().stream() + .map(Descriptor.Field::getName) + .map(name -> name.orElseThrow(() -> new PrestoException(INVALID_ARGUMENTS, "Missing Descriptor field name")).toLowerCase(ENGLISH)) + .collect(toImmutableSet()); + + List inputSchema = ((TableArgument) arguments.get(TABLE_ARGUMENT_NAME)).getRowType().getFields(); + Set inputNames = inputSchema.stream() + .map(RowType.Field::getName) + .filter(Optional::isPresent) + .map(Optional::get) + .map(name -> name.toLowerCase(ENGLISH)) + .collect(toImmutableSet()); + + if (!inputNames.containsAll(excludedNames)) { + String missingColumns = Sets.difference(excludedNames, inputNames).stream() + .collect(joining(", ", "[", "]")); + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Excluded columns: %s not present in the table", missingColumns)); + } + + ImmutableList.Builder requiredColumns = ImmutableList.builder(); + ImmutableList.Builder returnedColumns = ImmutableList.builder(); + + for (int i = 0; i < inputSchema.size(); i++) { + Optional name = inputSchema.get(i).getName(); + if (!name.isPresent() || !excludedNames.contains(name.orElseThrow(() -> new PrestoException(INVALID_FUNCTION_ARGUMENT, "Missing schema name")).toLowerCase(ENGLISH))) { + requiredColumns.add(i); + // per SQL standard, all columns produced by a table function must be named. We allow anonymous columns. 
+ returnedColumns.add(new Descriptor.Field(name, Optional.of(inputSchema.get(i).getType()))); + } + } + + List returnedType = returnedColumns.build(); + if (returnedType.isEmpty()) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "All columns are excluded"); + } + + return TableFunctionAnalysis.builder() + .requiredColumns(TABLE_ARGUMENT_NAME, requiredColumns.build()) + .returnedType(new Descriptor(returnedType)) + .handle(new ExcludeColumnsFunctionHandle()) + .build(); + } + } + + public static TableFunctionProcessorProvider getExcludeColumnsFunctionProcessorProvider() + { + return new TableFunctionProcessorProvider() + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return input -> { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(getOnlyElement(input).orElseThrow(() -> new PrestoException(INVALID_ARGUMENTS, "Missing data processor input"))); + }; + } + }; + } + + public static class ExcludeColumnsFunctionHandle + implements ConnectorTableFunctionHandle + { + // there's no information to remember. All logic is effectively delegated to the engine via `requiredColumns`. + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/table/Sequence.java b/presto-main-base/src/main/java/com/facebook/presto/operator/table/Sequence.java new file mode 100644 index 0000000000000..f32f850e1632a --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/table/Sequence.java @@ -0,0 +1,325 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.table; + +import com.facebook.presto.common.PageBuilder; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.AbstractConnectorTableFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.ReturnTypeSpecification.DescribedTable; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.ScalarArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.facebook.presto.spi.function.table.TableFunctionSplitProcessor; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import javax.inject.Provider; + +import java.math.BigInteger; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static 
com.facebook.presto.operator.table.Sequence.SequenceFunctionSplit.MAX_SPLIT_SIZE; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.function.table.Descriptor.descriptor; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.produced; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInputAndProduced; +import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; + +public class Sequence + implements Provider +{ + public static final String NAME = "sequence"; + + @Override + public ConnectorTableFunction get() + { + return new SequenceFunction(); + } + + public static class SequenceFunction + extends AbstractConnectorTableFunction + { + private static final String START_ARGUMENT_NAME = "START"; + private static final String STOP_ARGUMENT_NAME = "STOP"; + private static final String STEP_ARGUMENT_NAME = "STEP"; + + public SequenceFunction() + { + super( + "builtin", + NAME, + ImmutableList.of( + ScalarArgumentSpecification.builder() + .name(START_ARGUMENT_NAME) + .type(BIGINT) + .defaultValue(0L) + .build(), + ScalarArgumentSpecification.builder() + .name(STOP_ARGUMENT_NAME) + .type(BIGINT) + .build(), + ScalarArgumentSpecification.builder() + .name(STEP_ARGUMENT_NAME) + .type(BIGINT) + .defaultValue(1L) + .build()), + new DescribedTable(descriptor(ImmutableList.of("sequential_number"), ImmutableList.of(BIGINT)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + Object startValue = ((ScalarArgument) arguments.get(START_ARGUMENT_NAME)).getValue(); + if (startValue == null) { + throw new 
PrestoException(INVALID_FUNCTION_ARGUMENT, "Start is null"); + } + + Object stopValue = ((ScalarArgument) arguments.get(STOP_ARGUMENT_NAME)).getValue(); + if (stopValue == null) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Stop is null"); + } + + Object stepValue = ((ScalarArgument) arguments.get(STEP_ARGUMENT_NAME)).getValue(); + if (stepValue == null) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Step is null"); + } + + long start = (long) startValue; + long stop = (long) stopValue; + long step = (long) stepValue; + + if (start < stop && step <= 0) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Step must be positive for sequence [%s, %s]", start, stop)); + } + + if (start > stop && step >= 0) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Step must be negative for sequence [%s, %s]", start, stop)); + } + + return TableFunctionAnalysis.builder() + .handle(new SequenceFunctionHandle(start, stop, start == stop ? 0 : step)) + .build(); + } + } + + public static class SequenceFunctionHandle + implements ConnectorTableFunctionHandle + { + private final long start; + private final long stop; + private final long step; + + @JsonCreator + public SequenceFunctionHandle(@JsonProperty("start") long start, @JsonProperty("stop") long stop, @JsonProperty("step") long step) + { + this.start = start; + this.stop = stop; + this.step = step; + } + + @JsonProperty + public long start() + { + return start; + } + + @JsonProperty + public long stop() + { + return stop; + } + + @JsonProperty + public long step() + { + return step; + } + } + + public static ConnectorSplitSource getSequenceFunctionSplitSource(SequenceFunctionHandle handle) + { + // using BigInteger to avoid long overflow since it's not in the main data processing loop + BigInteger start = BigInteger.valueOf(handle.start()); + BigInteger stop = BigInteger.valueOf(handle.stop()); + BigInteger step = BigInteger.valueOf(handle.step()); + + if 
(step.equals(BigInteger.ZERO)) { + checkArgument(start.equals(stop), "start is not equal to stop for step = 0"); + return new FixedSplitSource(ImmutableList.of(new SequenceFunctionSplit(start.longValueExact(), stop.longValueExact()))); + } + + ImmutableList.Builder splits = ImmutableList.builder(); + + BigInteger totalSteps = stop.subtract(start).divide(step).add(BigInteger.ONE); + BigInteger totalSplits = totalSteps.divide(BigInteger.valueOf(MAX_SPLIT_SIZE)).add(BigInteger.ONE); + BigInteger[] stepsPerSplit = totalSteps.divideAndRemainder(totalSplits); + BigInteger splitJump = stepsPerSplit[0].subtract(BigInteger.ONE).multiply(step); + + BigInteger splitStart = start; + for (BigInteger i = BigInteger.ZERO; i.compareTo(totalSplits) < 0; i = i.add(BigInteger.ONE)) { + BigInteger splitStop = splitStart.add(splitJump); + // distribute the remaining steps between the initial splits, one step per split + if (i.compareTo(stepsPerSplit[1]) < 0) { + splitStop = splitStop.add(step); + } + splits.add(new SequenceFunctionSplit(splitStart.longValueExact(), splitStop.longValueExact())); + splitStart = splitStop.add(step); + } + + return new FixedSplitSource(splits.build()); + } + + public static class SequenceFunctionSplit + implements ConnectorSplit + { + public static final int DEFAULT_SPLIT_SIZE = 1000000; + public static final int MAX_SPLIT_SIZE = 1000000; + + // the first value of sub-sequence + private final long start; + + // the last value of sub-sequence. this value is aligned so that it belongs to the sequence. 
+ private final long stop; + + @JsonCreator + public SequenceFunctionSplit(@JsonProperty("start") long start, @JsonProperty("stop") long stop) + { + this.start = start; + this.stop = stop; + } + + @JsonProperty + public long getStart() + { + return start; + } + + @JsonProperty + public long getStop() + { + return stop; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return ImmutableMap.builder() + .put("start", start) + .put("stop", stop) + .buildOrThrow(); + } + } + + public static TableFunctionProcessorProvider getSequenceFunctionProcessorProvider() + { + return new TableFunctionProcessorProvider() { + @Override + public TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + return new SequenceFunctionProcessor(((SequenceFunctionHandle) handle).step()); + } + }; + } + + public static class SequenceFunctionProcessor + implements TableFunctionSplitProcessor + { + private final PageBuilder page = new PageBuilder(ImmutableList.of(BIGINT)); + private final long step; + private long start; + private long stop; + private boolean finished; + + public SequenceFunctionProcessor(long step) + { + this.step = step; + } + + @Override + public TableFunctionProcessorState process(ConnectorSplit split) + { + if (split != null) { + SequenceFunctionSplit sequenceSplit = (SequenceFunctionSplit) split; + start = sequenceSplit.getStart(); + stop = sequenceSplit.getStop(); + BlockBuilder block = page.getBlockBuilder(0); + while (start != stop && !page.isFull()) { + page.declarePosition(); + BIGINT.writeLong(block, start); + start += step; + } + if (!page.isFull()) { + page.declarePosition(); + BIGINT.writeLong(block, start); + finished = true; + return usedInputAndProduced(page.build()); + } + return usedInputAndProduced(page.build()); 
+ } + + if (finished) { + return FINISHED; + } + + page.reset(); + BlockBuilder block = page.getBlockBuilder(0); + while (start != stop && !page.isFull()) { + page.declarePosition(); + BIGINT.writeLong(block, start); + start += step; + } + if (!page.isFull()) { + page.declarePosition(); + BIGINT.writeLong(block, start); + finished = true; + return produced(page.build()); + } + return produced(page.build()); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java b/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java index 8bcd62f6006ef..f7f6afc493848 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.Session; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy; @@ -32,14 +33,14 @@ public class CloseableSplitSourceProvider { private static final Logger log = Logger.get(CloseableSplitSourceProvider.class); - private final SplitSourceProvider delegate; + private final SplitManager delegate; @GuardedBy("this") private List splitSources = new ArrayList<>(); @GuardedBy("this") private boolean closed; - public CloseableSplitSourceProvider(SplitSourceProvider delegate) + public CloseableSplitSourceProvider(SplitManager delegate) { this.delegate = requireNonNull(delegate, "delegate is null"); } @@ -53,6 +54,15 @@ public synchronized SplitSource getSplits(Session session, TableHandle tableHand return splitSource; } + @Override + public synchronized SplitSource getSplits(Session session, TableFunctionHandle tableFunctionHandle) + { + checkState(!closed, "split source 
provider is closed"); + SplitSource splitSource = delegate.getSplitsForTableFunction(session, tableFunctionHandle); + splitSources.add(splitSource); + return splitSource; + } + @Override public synchronized void close() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java b/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java index 8cd7efc5e2f9b..d334e5a5e6a70 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java @@ -18,6 +18,7 @@ import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.metadata.TableLayoutResult; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorSession; @@ -104,4 +105,17 @@ private ConnectorSplitManager getConnectorSplitManager(ConnectorId connectorId) checkArgument(result != null, "No split manager for connector '%s'", connectorId); return result; } + + public SplitSource getSplitsForTableFunction(Session session, TableFunctionHandle function) + { + ConnectorId connectorId = function.getConnectorId(); + ConnectorSplitManager splitManager = splitManagers.get(connectorId); + + ConnectorSplitSource source = splitManager.getSplits( + function.getTransactionHandle(), + session.toConnectorSession(connectorId), + function.getFunctionHandle()); + + return new ConnectorAwareSplitSource(connectorId, function.getTransactionHandle(), source); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java b/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java index 617fba7093613..30b54174c27b6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java @@ -14,6 +14,7 @@ package com.facebook.presto.split; import com.facebook.presto.Session; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy; @@ -21,4 +22,5 @@ public interface SplitSourceProvider { SplitSource getSplits(Session session, TableHandle tableHandle, SplitSchedulingStrategy splitSchedulingStrategy, WarningCollector warningCollector); + SplitSource getSplits(Session session, TableFunctionHandle tableFunctionHandle); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index 1fab75b489024..90a2421cf6c31 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -1622,7 +1622,7 @@ private void verifyRequiredColumns(TableFunctionInvocation node, Map column < 0 || column >= inputScope.getRelationType().getAllFieldCount()) // hidden columns can be required as well as visible columns + .filter(column -> column < 0 || column >= inputScope.getRelationType().getVisibleFieldCount()) .findFirst() .ifPresent(column -> { throw new SemanticException(TABLE_FUNCTION_IMPLEMENTATION_ERROR, "Invalid index: %s of required column from table argument %s", column, name); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java index 7683f51b60973..2326c62b56dfd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java @@ -51,6 +51,8 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.sanity.PlanChecker; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -296,6 +298,22 @@ public PlanNode visitValues(ValuesNode node, RewriteContext return context.defaultRewrite(node, context.get()); } + @Override + public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext context) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanNode visitTableFunctionProcessor(TableFunctionProcessorNode node, RewriteContext context) + { + if (!node.getSource().isPresent()) { + // context is mutable. The leaf node should set the PartitioningHandle. 
+ context.get().addSourceDistribution(node.getId(), SOURCE_DISTRIBUTION, metadata, session); + } + return context.defaultRewrite(node, context.get()); + } + @Override public PlanNode visitExchange(ExchangeNode exchange, RewriteContext context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index 878f46dc04fcb..13cc38caef3fe 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -69,6 +69,7 @@ import com.facebook.presto.operator.JoinBridgeManager; import com.facebook.presto.operator.JoinOperatorFactory; import com.facebook.presto.operator.JoinOperatorFactory.OuterOperatorFactoryResult; +import com.facebook.presto.operator.LeafTableFunctionOperator; import com.facebook.presto.operator.LimitOperator.LimitOperatorFactory; import com.facebook.presto.operator.LocalPlannerAware; import com.facebook.presto.operator.LookupJoinOperators; @@ -89,6 +90,7 @@ import com.facebook.presto.operator.PartitionFunction; import com.facebook.presto.operator.PartitionedLookupSourceFactory; import com.facebook.presto.operator.PipelineExecutionStrategy; +import com.facebook.presto.operator.RegularTableFunctionPartition; import com.facebook.presto.operator.RemoteProjectOperator.RemoteProjectOperatorFactory; import com.facebook.presto.operator.RowNumberOperator; import com.facebook.presto.operator.ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory; @@ -102,6 +104,7 @@ import com.facebook.presto.operator.StreamingAggregationOperator.StreamingAggregationOperatorFactory; import com.facebook.presto.operator.TableCommitContext; import com.facebook.presto.operator.TableFinishOperator.PageSinkCommitter; +import com.facebook.presto.operator.TableFunctionOperator; import 
com.facebook.presto.operator.TableScanOperator.TableScanOperatorFactory; import com.facebook.presto.operator.TableWriterMergeOperator.TableWriterMergeOperatorFactory; import com.facebook.presto.operator.TaskContext; @@ -145,11 +148,13 @@ import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.function.aggregation.LambdaProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.plan.AbstractJoinNode; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.AggregationNode.Step; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -218,6 +223,7 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -250,6 +256,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.OptionalInt; import java.util.Set; @@ -358,9 +365,11 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.DiscreteDomain.integers; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static 
com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Range.closedOpen; +import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.IntStream.range; @@ -1218,6 +1227,92 @@ public PhysicalOperation visitTableFunction(TableFunctionNode node, LocalExecuti throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); } + @Override + public PhysicalOperation visitTableFunctionProcessor(TableFunctionProcessorNode node, LocalExecutionPlanContext context) + { + TableFunctionProcessorProvider processorProvider = metadata.getFunctionAndTypeManager().getTableFunctionProcessorProvider(node.getHandle()); + + if (!node.getSource().isPresent()) { + OperatorFactory operatorFactory = new LeafTableFunctionOperator.LeafTableFunctionOperatorFactory(context.getNextOperatorId(), node.getId(), processorProvider, node.getHandle().getFunctionHandle()); + return new PhysicalOperation(operatorFactory, makeLayout(node), context, Optional.empty(), UNGROUPED_EXECUTION); + } + + PhysicalOperation source = node.getSource().orElseThrow(NoSuchElementException::new).accept(this, context); + + int properChannelsCount = node.getProperOutputs().size(); + + long passThroughSourcesCount = node.getPassThroughSpecifications().stream() + .filter(TableFunctionNode.PassThroughSpecification::isDeclaredAsPassThrough) + .count(); + + List> requiredChannels = node.getRequiredVariables().stream() + .map(list -> getChannelsForVariables(list, source.getLayout())) + .collect(toImmutableList()); + + Optional> markerChannels = node.getMarkerVariables() + .map(map -> map.entrySet().stream() + .collect(toImmutableMap(entry -> source.getLayout().get(entry.getKey()), entry -> source.getLayout().get(entry.getValue())))); + + int channel = 
properChannelsCount; + ImmutableList.Builder passThroughColumnSpecifications = ImmutableList.builder(); + for (TableFunctionNode.PassThroughSpecification specification : node.getPassThroughSpecifications()) { + // the table function produces one index channel for each source declared as pass-through. They are laid out after the proper channels. + int indexChannel = specification.isDeclaredAsPassThrough() ? channel++ : -1; + for (TableFunctionNode.PassThroughColumn column : specification.getColumns()) { + passThroughColumnSpecifications.add(new RegularTableFunctionPartition.PassThroughColumnSpecification(column.isPartitioningColumn(), source.getLayout().get(column.getOutputVariables()), indexChannel)); + } + } + + List partitionChannels = node.getSpecification() + .map(DataOrganizationSpecification::getPartitionBy) + .map(list -> getChannelsForVariables(list, source.getLayout())) + .orElse(ImmutableList.of()); + + List sortChannels = ImmutableList.of(); + List sortOrders = ImmutableList.of(); + if (node.getSpecification().flatMap(DataOrganizationSpecification::getOrderingScheme).isPresent()) { + OrderingScheme orderingScheme = node.getSpecification().flatMap(DataOrganizationSpecification::getOrderingScheme).orElseThrow(NoSuchElementException::new); + sortChannels = getChannelsForVariables(orderingScheme.getOrderByVariables(), source.getLayout()); + sortOrders = orderingScheme.getOrderingsMap().values().stream().collect(toImmutableList()); + } + + OperatorFactory operator = new TableFunctionOperator.TableFunctionOperatorFactory( + context.getNextOperatorId(), + node.getId(), + processorProvider, + node.getHandle().getFunctionHandle(), + properChannelsCount, + toIntExact(passThroughSourcesCount), + requiredChannels, + markerChannels, + passThroughColumnSpecifications.build(), + node.isPruneWhenEmpty(), + partitionChannels, + getChannelsForVariables(ImmutableList.copyOf(node.getPrePartitioned()), source.getLayout()), + sortChannels, + sortOrders, + node.getPreSorted(), 
+ source.getTypes(), + 10_000, + pagesIndexFactory); + + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + for (int i = 0; i < node.getProperOutputs().size(); i++) { + outputMappings.put(node.getProperOutputs().get(i), i); + } + List passThroughVariables = node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .collect(toImmutableList()); + int outputChannel = properChannelsCount; + for (VariableReferenceExpression passThroughVariable : passThroughVariables) { + outputMappings.put(passThroughVariable, outputChannel++); + } + + return new PhysicalOperation(operator, outputMappings.buildOrThrow(), context, source); + } + @Override public PhysicalOperation visitTopN(TopNNode node, LocalExecutionPlanContext context) { @@ -2942,7 +3037,7 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl Map aggregationMap = aggregation.getAggregations().entrySet() .stream().collect( - ImmutableMap.toImmutableMap( + toImmutableMap( Map.Entry::getKey, entry -> createAggregation(entry.getValue()))); if (groupingVariables.isEmpty()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index a12b350dcdd1e..76a38e2f760f3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -84,6 +84,8 @@ import com.facebook.presto.sql.planner.iterative.rule.PruneRedundantProjectionAssignments; import com.facebook.presto.sql.planner.iterative.rule.PruneSemiJoinColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneSemiJoinFilteringSourceColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneTableFunctionProcessorColumns; +import 
com.facebook.presto.sql.planner.iterative.rule.PruneTableFunctionProcessorSourceColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneTableScanColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneTopNColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneUpdateSourceColumns; @@ -123,6 +125,7 @@ import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantLimit; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantSort; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantSortColumns; +import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantTableFunctionProcessor; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantTopN; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantTopNColumns; import com.facebook.presto.sql.planner.iterative.rule.RemoveTrivialFilters; @@ -135,9 +138,9 @@ import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseExpressionPredicate; import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseToMap; import com.facebook.presto.sql.planner.iterative.rule.RewriteConstantArrayContainsToInExpression; +import com.facebook.presto.sql.planner.iterative.rule.RewriteExcludeColumnsFunctionToProjection; import com.facebook.presto.sql.planner.iterative.rule.RewriteFilterWithExternalFunctionToProject; import com.facebook.presto.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation; -import com.facebook.presto.sql.planner.iterative.rule.RewriteTableFunctionToTableScan; import com.facebook.presto.sql.planner.iterative.rule.RuntimeReorderJoinSides; import com.facebook.presto.sql.planner.iterative.rule.ScaledWriterRule; import com.facebook.presto.sql.planner.iterative.rule.SimplifyCardinalityMap; @@ -154,6 +157,8 @@ import com.facebook.presto.sql.planner.iterative.rule.TransformDistinctInnerJoinToLeftEarlyOutJoin; import 
com.facebook.presto.sql.planner.iterative.rule.TransformDistinctInnerJoinToRightEarlyOutJoin; import com.facebook.presto.sql.planner.iterative.rule.TransformExistsApplyToLateralNode; +import com.facebook.presto.sql.planner.iterative.rule.TransformTableFunctionProcessorToTableScan; +import com.facebook.presto.sql.planner.iterative.rule.TransformTableFunctionToTableFunctionProcessor; import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToDistinctInnerJoin; import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToSemiJoin; import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedLateralToJoin; @@ -319,6 +324,8 @@ public PlanOptimizers( new PruneValuesColumns(), new PruneWindowColumns(), new PruneLimitColumns(), + new PruneTableFunctionProcessorColumns(), + new PruneTableFunctionProcessorSourceColumns(), new PruneTableScanColumns()); builder.add(new LogicalCteOptimizer(metadata)); @@ -424,6 +431,7 @@ public PlanOptimizers( .addAll(predicatePushDownRules) .addAll(columnPruningRules) .addAll(ImmutableSet.of( + new TransformTableFunctionToTableFunctionProcessor(metadata), new MergeDuplicateAggregation(metadata.getFunctionAndTypeManager()), new RemoveRedundantIdentityProjections(), new RemoveFullSample(), @@ -446,6 +454,9 @@ public PlanOptimizers( new MergeLimitWithDistinct(), new PruneCountAggregationOverScalar(metadata.getFunctionAndTypeManager()), new PruneOrderByInAggregation(metadata.getFunctionAndTypeManager()), + new RemoveRedundantTableFunctionProcessor(), // must run after TransformTableFunctionToTableFunctionProcessor + new RewriteExcludeColumnsFunctionToProjection(), // must run after TransformTableFunctionToTableFunctionProcessor + new TransformTableFunctionProcessorToTableScan(metadata), // must run after TransformTableFunctionToTableFunctionProcessor new RewriteSpatialPartitioningAggregation(metadata))) .build()), new IterativeOptimizer( @@ -786,7 +797,11 @@ 
public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new RemoveRedundantIdentityProjections(), new PruneRedundantProjectionAssignments())), + ImmutableSet.of( + new RemoveRedundantIdentityProjections(), + new PruneRedundantProjectionAssignments(), + // Re-run RemoveRedundantTableFunctionProcessor after SimplifyPlanWithEmptyInput to optimize empty input tables to empty ValueNode + new RemoveRedundantTableFunctionProcessor())), new PushdownSubfields(metadata, expressionOptimizerManager)); builder.add(predicatePushDown); // Run predicate push down one more time in case we can leverage new information from layouts' effective predicate @@ -881,14 +896,6 @@ public PlanOptimizers( costCalculator, ImmutableSet.of(new ScaledWriterRule()))); - builder.add( - new IterativeOptimizer( - metadata, - ruleStats, - statsCalculator, - costCalculator, - ImmutableSet.of(new RewriteTableFunctionToTableScan(metadata)))); - if (!noExchange) { builder.add(new ReplicateSemiJoinInDelete()); // Must run before AddExchanges diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java index 826eaea10044b..4e32a6e6918ed 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java @@ -205,6 +205,9 @@ public static PlanNode addOverrideProjection(PlanNode source, PlanNodeIdAllocato || source.getOutputVariables().stream().distinct().count() != source.getOutputVariables().size()) { return source; } + if (source instanceof ProjectNode && ((ProjectNode) source).getAssignments().getMap().equals(variableMap)) { + return source; + } Assignments.Builder assignmentsBuilder = Assignments.builder(); assignmentsBuilder.putAll(source.getOutputVariables().stream().collect(toImmutableMap(identity(), x -> variableMap.containsKey(x) 
? variableMap.get(x) : x))); return new ProjectNode(source.getSourceLocation(), planNodeIdAllocator.getNextId(), source, assignmentsBuilder.build(), LOCAL); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index f2c102667e06f..c7deb6a31cdec 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -175,7 +175,7 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; -class QueryPlanner +public class QueryPlanner { private final Analysis analysis; private final VariableAllocator variableAllocator; @@ -891,7 +891,7 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression * * @return the new subplan and a mapping of each expression to the symbol representing the coercion or an existing symbol if a coercion wasn't needed */ - private PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata) + public PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata) { Assignments.Builder assignments = Assignments.builder(); assignments.putAll(subPlan.getRoot().getOutputVariables().stream().collect(toImmutableMap(Function.identity(), Function.identity()))); @@ -1713,15 +1713,18 @@ private RowExpression rowExpression(Expression expression, SqlPlannerContext con context.getTranslatorContext()); } - private static List toSymbolReferences(List variables) + public static List toSymbolReferences(List variables) { return variables.stream() - .map(variable -> new SymbolReference( - variable.getSourceLocation().map(location -> new NodeLocation(location.getLine(), 
location.getColumn())), - variable.getName())) + .map(QueryPlanner::toSymbolReference) .collect(toImmutableList()); } + public static SymbolReference toSymbolReference(VariableReferenceExpression variable) + { + return new SymbolReference(variable.getSourceLocation().map(location -> new NodeLocation(location.getLine(), location.getColumn())), variable.getName()); + } + public static class PlanAndMappings { private final PlanBuilder subPlan; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index be6b7851ec99b..e07fadcbf4666 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.MapType; @@ -31,12 +32,14 @@ import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.CteReferenceNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.MaterializedViewScanNode; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; @@ -52,14 +55,17 @@ 
import com.facebook.presto.sql.analyzer.Field; import com.facebook.presto.sql.analyzer.RelationId; import com.facebook.presto.sql.analyzer.RelationType; +import com.facebook.presto.sql.analyzer.ResolvedField; import com.facebook.presto.sql.analyzer.Scope; -import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils; import com.facebook.presto.sql.planner.optimizations.SampleNodeUtil; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.CoalesceExpression; @@ -87,11 +93,10 @@ import com.facebook.presto.sql.tree.Row; import com.facebook.presto.sql.tree.SampledRelation; import com.facebook.presto.sql.tree.SetOperation; +import com.facebook.presto.sql.tree.SortItem; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.Table; -import com.facebook.presto.sql.tree.TableFunctionDescriptorArgument; import com.facebook.presto.sql.tree.TableFunctionInvocation; -import com.facebook.presto.sql.tree.TableFunctionTableArgument; import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.Union; import com.facebook.presto.sql.tree.Unnest; @@ -99,6 +104,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; 
import com.google.common.collect.ListMultimap; import com.google.common.collect.UnmodifiableIterator; @@ -119,6 +125,7 @@ import static com.facebook.presto.SystemSessionProperties.getQueryAnalyzerTimeout; import static com.facebook.presto.common.type.TypeUtils.isEnumType; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_PLAN_ERROR; import static com.facebook.presto.spi.StandardErrorCode.QUERY_PLANNING_TIMEOUT; import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; @@ -127,7 +134,6 @@ import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isEqualComparisonExpression; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.resolveEnumLiteral; import static com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy.NONE; -import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException; import static com.facebook.presto.sql.planner.PlannerUtils.newVariable; import static com.facebook.presto.sql.planner.TranslateExpressionsUtil.toRowExpression; @@ -364,51 +370,185 @@ private RelationPlan planMaterializedView(Table node, Analysis.MaterializedViewI return new RelationPlan(mvScanNode, scope, outputVariables); } + /** + * Processes a {@code TableFunctionInvocation} node to construct and return a {@link RelationPlan}. + * This involves preparing the necessary plan nodes, variable mappings, and associated properties + * to represent the execution plan for the invoked table function. + * + * @param node The {@code TableFunctionInvocation} syntax tree node to be processed. + * @param context The SQL planner context used for planning and analysis tasks. 
+ * @return A {@link RelationPlan} encapsulating the execution plan for the table function invocation. + */ @Override protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node, SqlPlannerContext context) { - node.getArguments().stream() - .forEach(argument -> { - if (argument.getValue() instanceof TableFunctionTableArgument) { - throw new SemanticException(NOT_SUPPORTED, argument, "Table arguments are not yet supported for table functions"); - } - if (argument.getValue() instanceof TableFunctionDescriptorArgument) { - throw new SemanticException(NOT_SUPPORTED, argument, "Descriptor arguments are not yet supported for table functions"); - } - }); Analysis.TableFunctionInvocationAnalysis functionAnalysis = analysis.getTableFunctionAnalysis(node); + ImmutableList.Builder sources = ImmutableList.builder(); + ImmutableList.Builder sourceProperties = ImmutableList.builder(); + ImmutableList.Builder outputVariables = ImmutableList.builder(); + + // create new symbols for table function's proper columns + RelationType relationType = analysis.getScope(node).getRelationType(); + List properOutputs = IntStream.range(0, functionAnalysis.getProperColumnsCount()) + .mapToObj(relationType::getFieldByIndex) + .map(field -> variableAllocator.newVariable(getSourceLocation(node), field.getName().orElse("field"), field.getType())) + .collect(toImmutableList()); - // TODO handle input relations: - // 1. extract the input relations from node.getArguments() and plan them. Apply relation coercions if requested. - // 2. 
for each input relation, prepare the TableArgumentProperties record, consisting of: - // - row or set semantics (from the actualArgument) - // - prune when empty property (from the actualArgument) - // - pass through columns property (from the actualArgument) - // - optional Specification: ordering scheme and partitioning (from the node's argument) <- planned upon the source's RelationPlan (or combined RelationPlan from all sources) - // TODO add - argument name - // TODO add - mapping column name => Symbol // TODO mind the fields without names and duplicate field names in RelationType - List sources = ImmutableList.of(); - List inputRelationsProperties = ImmutableList.of(); - - Scope scope = analysis.getScope(node); + outputVariables.addAll(properOutputs); - ImmutableList.Builder outputVariablesBuilder = ImmutableList.builder(); - for (Field field : scope.getRelationType().getAllFields()) { - VariableReferenceExpression variable = variableAllocator.newVariable(getSourceLocation(node), field.getName().get(), field.getType()); - outputVariablesBuilder.add(variable); - } + processTableArguments(context, functionAnalysis, outputVariables, sources, sourceProperties); - List outputVariables = outputVariablesBuilder.build(); PlanNode root = new TableFunctionNode( idAllocator.getNextId(), functionAnalysis.getFunctionName(), functionAnalysis.getArguments(), - outputVariablesBuilder.build(), - sources.stream().map(RelationPlan::getRoot).collect(toImmutableList()), - inputRelationsProperties, - new TableFunctionHandle(functionAnalysis.getConnectorId(), functionAnalysis.getConnectorTableFunctionHandle(), functionAnalysis.getTransactionHandle())); + properOutputs, + sources.build(), + sourceProperties.build(), + functionAnalysis.getCopartitioningLists(), + new TableFunctionHandle( + functionAnalysis.getConnectorId(), + functionAnalysis.getConnectorTableFunctionHandle(), + functionAnalysis.getTransactionHandle())); + + return new RelationPlan(root, analysis.getScope(node), 
outputVariables.build()); + } + + private void processTableArguments(SqlPlannerContext context, + Analysis.TableFunctionInvocationAnalysis functionAnalysis, + ImmutableList.Builder outputVariables, + ImmutableList.Builder sources, + ImmutableList.Builder sourceProperties) + { + QueryPlanner partitionQueryPlanner = new QueryPlanner(analysis, variableAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session, context, sqlParser); + // process sources in order of argument declarations + for (Analysis.TableArgumentAnalysis tableArgument : functionAnalysis.getTableArgumentAnalyses()) { + RelationPlan sourcePlan = process(tableArgument.getRelation(), context); + PlanBuilder sourcePlanBuilder = initializePlanBuilder(sourcePlan); + + int[] fieldIndexForVisibleColumn = getFieldIndexesForVisibleColumns(sourcePlan); + + List requiredColumns = functionAnalysis.getRequiredColumns().get(tableArgument.getArgumentName()).stream() + .map(column -> fieldIndexForVisibleColumn[column]) + .map(sourcePlan::getVariable) + .collect(toImmutableList()); + + Optional specification = Optional.empty(); + + // if the table argument has set semantics, create Specification + if (!tableArgument.isRowSemantics()) { + // partition by + List partitionBy = ImmutableList.of(); + // if there are partitioning columns, they might have to be coerced for copartitioning + if (tableArgument.getPartitionBy().isPresent() && !tableArgument.getPartitionBy().get().isEmpty()) { + List partitioningColumns = tableArgument.getPartitionBy().get(); + for (Expression partitionColumn : partitioningColumns) { + if (!sourcePlanBuilder.canTranslate(partitionColumn)) { + ResolvedField partition = sourcePlan.getScope().tryResolveField(partitionColumn).orElseThrow(() -> new PrestoException(INVALID_PLAN_ERROR, "Missing equivalent alias")); + sourcePlanBuilder.getTranslations().put(partitionColumn, sourcePlan.getVariable(partition.getRelationFieldIndex())); + } + } + QueryPlanner.PlanAndMappings 
copartitionCoercions = partitionQueryPlanner.coerce(sourcePlanBuilder, partitioningColumns, analysis, idAllocator, variableAllocator, metadata); + sourcePlanBuilder = copartitionCoercions.getSubPlan(); + partitionBy = partitioningColumns.stream() + .map(copartitionCoercions::get) + .collect(toImmutableList()); + } - return new RelationPlan(root, scope, outputVariables); + // order by + Optional orderBy = getOrderingScheme(tableArgument, sourcePlanBuilder, sourcePlan); + specification = Optional.of(new DataOrganizationSpecification(partitionBy, orderBy)); + } + + // add output symbols passed from the table argument + ImmutableList.Builder passThroughColumns = ImmutableList.builder(); + addPassthroughColumns(outputVariables, tableArgument, sourcePlan, specification, passThroughColumns, sourcePlanBuilder); + sources.add(sourcePlanBuilder.getRoot()); + + sourceProperties.add(new TableArgumentProperties( + tableArgument.getArgumentName(), + tableArgument.isRowSemantics(), + tableArgument.isPruneWhenEmpty(), + new PassThroughSpecification(tableArgument.isPassThroughColumns(), passThroughColumns.build()), + requiredColumns, + specification)); + } + } + + private static int[] getFieldIndexesForVisibleColumns(RelationPlan sourcePlan) + { + // required columns are a subset of visible columns of the source. remap required column indexes to field indexes in source relation type. 
+ RelationType sourceRelationType = sourcePlan.getScope().getRelationType(); + int[] fieldIndexForVisibleColumn = new int[sourceRelationType.getVisibleFieldCount()]; + int visibleColumn = 0; + for (int i = 0; i < sourceRelationType.getAllFieldCount(); i++) { + if (!sourceRelationType.getFieldByIndex(i).isHidden()) { + fieldIndexForVisibleColumn[visibleColumn] = i; + visibleColumn++; + } + } + return fieldIndexForVisibleColumn; + } + + private static Optional getOrderingScheme(Analysis.TableArgumentAnalysis tableArgument, PlanBuilder sourcePlanBuilder, RelationPlan sourcePlan) + { + Optional orderBy = Optional.empty(); + if (tableArgument.getOrderBy().isPresent()) { + List sortItems = tableArgument.getOrderBy().get().getSortItems(); + + // Ensure all ORDER BY columns can be translated (populate missing translations if needed) + for (SortItem sortItem : sortItems) { + Expression sortKey = sortItem.getSortKey(); + if (!sourcePlanBuilder.canTranslate(sortKey)) { + Optional resolvedField = sourcePlan.getScope().tryResolveField(sortKey); + resolvedField.ifPresent(field -> sourcePlanBuilder.getTranslations().put( + sortKey, + sourcePlan.getVariable(field.getRelationFieldIndex()))); + } + } + + // The ordering symbols are coerced + List coerced = sortItems.stream() + .map(SortItem::getSortKey) + .map(sourcePlanBuilder::translate) + .collect(toImmutableList()); + + List sortOrders = sortItems.stream() + .map(PlannerUtils::toSortOrder) + .collect(toImmutableList()); + + orderBy = Optional.of(PlannerUtils.toOrderingScheme(coerced, sortOrders)); + } + return orderBy; + } + + private static void addPassthroughColumns(ImmutableList.Builder outputVariables, + Analysis.TableArgumentAnalysis tableArgument, RelationPlan sourcePlan, + Optional specification, + ImmutableList.Builder passThroughColumns, + PlanBuilder sourcePlanBuilder) + { + if (tableArgument.isPassThroughColumns()) { + // the original output symbols from the source node, not coerced + // note: hidden columns are 
included. They are present in sourcePlan.fieldMappings + outputVariables.addAll(sourcePlan.getFieldMappings()); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + sourcePlan.getFieldMappings().stream() + .map(variable -> new PassThroughColumn(variable, partitionBy.contains(variable))) + .forEach(passThroughColumns::add); + } + else if (tableArgument.getPartitionBy().isPresent()) { + tableArgument.getPartitionBy().get().stream() + .map(sourcePlanBuilder::translate) + // the original symbols for partitioning columns, not coerced + .forEach(variable -> { + outputVariables.add(variable); + passThroughColumns.add(new PassThroughColumn(variable, true)); + }); + } } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java index abb784cdaa298..471c797c426a8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java @@ -22,9 +22,11 @@ import com.facebook.presto.spi.plan.SpatialJoinNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.NoSuchElementException; import java.util.function.Consumer; public class SchedulingOrderVisitor @@ -88,5 +90,17 @@ public Void visitTableScan(TableScanNode node, Consumer schedulingOr schedulingOrder.accept(node.getId()); return null; } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Consumer schedulingOrder) + { + if (!node.getSource().isPresent()) { + schedulingOrder.accept(node.getId()); + } + 
else { + node.getSource().orElseThrow(NoSuchElementException::new).accept(this, schedulingOrder); + } + return null; + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java index 28d2bc98b1efb..387271aa94e16 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java @@ -61,6 +61,7 @@ import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -69,6 +70,7 @@ import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.function.Supplier; import static com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.GROUPED_SCHEDULING; @@ -283,6 +285,21 @@ public Map visitRowNumber(RowNumberNode node, Context c return node.getSource().accept(this, context); } + @Override + public Map visitTableFunctionProcessor(TableFunctionProcessorNode node, Context context) + { + if (!node.getSource().isPresent()) { + // this is a source node, so produce splits + SplitSource splitSource = splitSourceProvider.getSplits( + session, + node.getHandle()); + splitSources.add(splitSource); + return ImmutableMap.of(node.getId(), splitSource); + } + + return node.getSource().orElseThrow(NoSuchElementException::new).accept(this, context); + } + @Override public Map visitTopNRowNumber(TopNRowNumberNode node, Context context) { diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java new file mode 100644 index 0000000000000..c95212bf38c0d --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; + +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; +import static com.google.common.collect.ImmutableList.toImmutableList; + +/** + * TableFunctionProcessorNode has two kinds of outputs: + * - proper outputs, which are the columns produced by the table function, + * - pass-through outputs, which are the columns copied from table arguments. + * This rule filters out unreferenced pass-through symbols. 
+ * Unreferenced proper symbols are not pruned, because there is currently no way + * to communicate to the table function the request for not producing certain columns. + * // TODO prune table function's proper outputs + * Example: + *

+ * - Project
+ *   assignments={proper->proper1}
+ *   - TableFunctionProcessor
+ *     properOutputs=[proper1, proper2]
+ *     passThroughSymbols=[[passthrough1],[passthrough2]]
+ * 
+ * is transformed into + *
+ * - Project
+ *   assignments={proper->proper1}
+ *   - TableFunctionProcessor
+ *     properOutputs=[proper1, proper2]
+ *     passThroughSymbols=[]
+ * 
+ */ +public class PruneTableFunctionProcessorColumns + extends ProjectOffPushDownRule +{ + public PruneTableFunctionProcessorColumns() + { + super(tableFunctionProcessor()); + } + + @Override + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, TableFunctionProcessorNode node, Set referencedOutputs) + { + List prunedPassThroughSpecifications = node.getPassThroughSpecifications().stream() + .map(sourceSpecification -> { + List prunedPassThroughColumns = sourceSpecification.getColumns().stream() + .filter(column -> referencedOutputs.contains(column.getOutputVariables())) + .collect(toImmutableList()); + return new TableFunctionNode.PassThroughSpecification(sourceSpecification.isDeclaredAsPassThrough(), prunedPassThroughColumns); + }) + .collect(toImmutableList()); + + int originalPassThroughCount = node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .mapToInt(List::size) + .sum(); + + int prunedPassThroughCount = prunedPassThroughSpecifications.stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .mapToInt(List::size) + .sum(); + + if (originalPassThroughCount == prunedPassThroughCount) { + return Optional.empty(); + } + + return Optional.of(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + node.getSource(), + node.isPruneWhenEmpty(), + prunedPassThroughSpecifications, + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + node.getPrePartitioned(), + node.getPreSorted(), + node.getHashSymbol(), + node.getHandle())); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java new file mode 100644 index 0000000000000..d90f668d4c98f --- /dev/null +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java @@ -0,0 +1,128 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.google.common.collect.ImmutableSet; + +import java.util.Collection; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; +import static com.google.common.collect.Maps.filterKeys; + +/** + * This rule prunes unreferenced outputs of TableFunctionProcessorNode. + * First, it extracts all symbols required for: + * - pass-through + * - table function computation + * - partitioning and ordering (including the hashSymbol) + * Next, a mapping of input symbols to marker symbols is updated + * so that it only contains mappings for the required symbols. + * Last, all the remaining marker symbols are added to the collection + * of required symbols. 
+ * Any source output symbols not included in the required symbols + * can be pruned. + * Example: + *
+ * - TableFunctionProcessor
+ *   properOutputs=[proper]
+ *   passThroughSymbols=[[passthrough1],[passthrough2]]
+ *   requiredSymbols=[[require1], [require2]]
+ *   specification=[partition={[partition1]} orderby={[order1 ASC_NULLS_LAST]}]
+ *   hashSymbol=[hash]
+ *   markerVariables={passthrough1->marker1, require1->marker1, partition1->marker1, order1->marker1, passthrough2->marker2, require2->marker2, unreferenced->marker2}
+ *   - Source (which produces passthrough1, require1, partition1, order1, passthrough2, require2, marker1, marker2, hash, unreferenced)
+ * 
+ * is transformed into + *
+ * - TableFunctionProcessor
+ *   properOutputs=[proper]
+ *   passThroughSymbols=[[passthrough1],[passthrough2]]
+ *   requiredSymbols=[[require1], [require2]]
+ *   specification=[partition={[partition1]} orderby={[order1 ASC_NULLS_LAST]}]
+ *   hashSymbol=[hash]
+ *   markerVariables={passthrough1->marker1, require1->marker1, partition1->marker1, order1->marker1, passthrough2->marker2, require2->marker2}
+ *   - Project
+ *     assignments=[passthrough1, require1, partition1, order1, passthrough2, require2, marker1, marker2, hash]
+ *     - Source (which produces passthrough1, require1, partition1, order1, passthrough2, require2, marker1, marker2, hash, unreferenced)
+ * 
+ */ +public class PruneTableFunctionProcessorSourceColumns + implements Rule +{ + private static final Pattern PATTERN = tableFunctionProcessor(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) + { + if (!node.getSource().isPresent()) { + return Result.empty(); + } + + ImmutableSet.Builder requiredInputs = ImmutableSet.builder(); + + node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .forEach(requiredInputs::add); + + node.getRequiredVariables() + .forEach(requiredInputs::addAll); + + node.getSpecification().ifPresent(specification -> { + requiredInputs.addAll(specification.getPartitionBy()); + specification.getOrderingScheme().ifPresent(orderingScheme -> requiredInputs.addAll(orderingScheme.getOrderByVariables())); + }); + + node.getHashSymbol().ifPresent(requiredInputs::add); + + Optional> updatedMarkerSymbols = node.getMarkerVariables() + .map(mapping -> filterKeys(mapping, requiredInputs.build()::contains)); + + updatedMarkerSymbols.ifPresent(mapping -> requiredInputs.addAll(mapping.values())); + + return restrictOutputs(context.getIdAllocator(), node.getSource().orElseThrow(NoSuchElementException::new), requiredInputs.build()) + .map(child -> Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(child), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + updatedMarkerSymbols, + node.getSpecification(), + node.getPrePartitioned(), + node.getPreSorted(), + node.getHashSymbol(), + node.getHandle()))) + .orElse(Result.empty()); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunctionProcessor.java 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunctionProcessor.java new file mode 100644 index 0000000000000..48f58bce5952a --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunctionProcessor.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.google.common.collect.ImmutableList; + +import java.util.NoSuchElementException; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMost; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; + +/** + * Table function can take multiple table arguments. Each argument is either "prune when empty" or "keep when empty". + * "Prune when empty" means that if this argument has no rows, the function result is empty, so the function can be + * removed from the plan, and replaced with empty values. + * "Keep when empty" means that even if the argument has no rows, the function should still be executed, and it can + * return a non-empty result. 
+ * All the table arguments are combined into a single source of a TableFunctionProcessorNode. If either argument is + * "prune when empty", the overall result is "prune when empty". This rule removes a redundant TableFunctionProcessorNode + * based on the "prune when empty" property. + */ +public class RemoveRedundantTableFunctionProcessor + implements Rule +{ + private static final Pattern PATTERN = tableFunctionProcessor(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) + { + if (node.isPruneWhenEmpty() && node.getSource().isPresent()) { + if (isAtMost(node.getSource().orElseThrow(NoSuchElementException::new), context.getLookup(), 0)) { + return Result.ofPlanNode( + new ValuesNode(node.getSourceLocation(), + node.getId(), + node.getOutputVariables(), + ImmutableList.of(), + Optional.empty())); + } + } + + return Result.empty(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteExcludeColumnsFunctionToProjection.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteExcludeColumnsFunctionToProjection.java new file mode 100644 index 0000000000000..1857dafa9493c --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteExcludeColumnsFunctionToProjection.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.operator.table.ExcludeColumns.ExcludeColumnsFunctionHandle; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; + +import java.util.List; +import java.util.NoSuchElementException; + +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.Iterators.getOnlyElement; +/** + * Rewrite a TableFunctionProcessorNode into a Project node if the table function is exclude_columns. + *
+ * - TableFunctionProcessorNode
+ *   properOutputs=[A, B]
+ *   passthroughColumns=[C, D]
+ *   - (input) plan which produces symbols [A, B, C, D]
+ * 
+ * into + *
+ * - Project
+ *   assignments={A, B, C, D}
+ *   - (input) plan which produces symbols [A, B, C, D]
+ * 
+ */ +public class RewriteExcludeColumnsFunctionToProjection + implements Rule +{ + private static final Pattern PATTERN = tableFunctionProcessor(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) + { + if (!(node.getHandle().getFunctionHandle() instanceof ExcludeColumnsFunctionHandle)) { + return Result.empty(); + } + + List inputSymbols = getOnlyElement(node.getRequiredVariables().iterator()); + List outputSymbols = node.getOutputVariables(); + + checkState(inputSymbols.size() == outputSymbols.size(), "inputSymbols size differs from outputSymbols size"); + Assignments.Builder assignments = Assignments.builder(); + for (int i = 0; i < outputSymbols.size(); i++) { + assignments.put(outputSymbols.get(i), inputSymbols.get(i)); + } + + return Result.ofPlanNode(new ProjectNode( + node.getId(), + node.getSource().orElseThrow(NoSuchElementException::new), + assignments.build())); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionProcessorToTableScan.java similarity index 61% rename from presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java rename to presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionProcessorToTableScan.java index 2418377c7ac53..6cbf4378b334b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionProcessorToTableScan.java @@ -23,7 +23,7 @@ import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; 
import com.facebook.presto.sql.planner.iterative.Rule; -import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.google.common.collect.ImmutableMap; import java.util.List; @@ -31,65 +31,70 @@ import static com.facebook.presto.matching.Pattern.empty; import static com.facebook.presto.sql.planner.plan.Patterns.sources; -import static com.facebook.presto.sql.planner.plan.Patterns.tableFunction; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; /* - * This process converts connector-resolvable TableFunctionNodes into equivalent - * TableScanNodes by invoking the connector’s applyTableFunction() during planning. - * It allows table-valued functions whose results can be expressed as a ConnectorTableHandle - * to be treated like regular scans and benefit from normal scan optimizations. + * This rule converts connector-resolvable TableFunctionProcessorNodes into equivalent + * TableScanNodes by invoking the connector's applyTableFunction() method during query planning. + * + * It enables table-valued functions whose results can be represented as a ConnectorTableHandle + * to be treated like regular table scans, allowing them to benefit from standard scan optimizations. 
* * Example: * Before Transformation: * TableFunction(my_function(arg1, arg2)) * * After Transformation: - * TableScan(my_function(arg1, arg2)).applyTableFunction_tableHandle) - * assignments: {outputVar1 -> my_function(arg1, arg2)).applyTableFunction_colHandle1, - * outputVar2 -> my_function(arg1, arg2)).applyTableFunction_colHandle2} + * TableScan(my_function(arg1, arg2)) + * assignments: { + * outputVar1 -> my_function(arg1, arg2)_colHandle1, + * outputVar2 -> my_function(arg1, arg2)_colHandle2 + * } */ -public class RewriteTableFunctionToTableScan - implements Rule +public class TransformTableFunctionProcessorToTableScan + implements Rule { - private static final Pattern PATTERN = tableFunction() + private static final Pattern PATTERN = tableFunctionProcessor() .with(empty(sources())); private final Metadata metadata; - public RewriteTableFunctionToTableScan(Metadata metadata) + public TransformTableFunctionProcessorToTableScan(Metadata metadata) { this.metadata = requireNonNull(metadata, "metadata is null"); } @Override - public Pattern getPattern() + public Pattern getPattern() { return PATTERN; } @Override - public Result apply(TableFunctionNode tableFunctionNode, Captures captures, Context context) + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) { - Optional> result = metadata.applyTableFunction(context.getSession(), tableFunctionNode.getHandle()); + Optional> result = metadata.applyTableFunction(context.getSession(), node.getHandle()); if (!result.isPresent()) { return Result.empty(); } List columnHandles = result.get().getColumnHandles(); - checkState(tableFunctionNode.getOutputVariables().size() == columnHandles.size(), "returned table does not match the node's output"); + checkState(node.getOutputVariables().size() == columnHandles.size(), + "Connector returned %s columns but TableFunctionProcessorNode expects %s outputs", + columnHandles.size(), node.getOutputVariables().size()); ImmutableMap.Builder 
assignments = ImmutableMap.builder(); for (int i = 0; i < columnHandles.size(); i++) { - assignments.put(tableFunctionNode.getOutputVariables().get(i), columnHandles.get(i)); + assignments.put(node.getOutputVariables().get(i), columnHandles.get(i)); } return Result.ofPlanNode(new TableScanNode( - tableFunctionNode.getSourceLocation(), - tableFunctionNode.getId(), + node.getSourceLocation(), + node.getId(), result.get().getTableHandle(), - tableFunctionNode.getOutputVariables(), + node.getOutputVariables(), assignments.buildOrThrow(), TupleDomain.all(), TupleDomain.all(), Optional.empty())); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionToTableFunctionProcessor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionToTableFunctionProcessor.java new file mode 100644 index 0000000000000..8d143ea0d006e --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionToTableFunctionProcessor.java @@ -0,0 +1,1032 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.JoinNode; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.WindowNode; +import com.facebook.presto.spi.plan.WindowNode.Frame; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.tree.QualifiedName; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; + +import 
java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.plan.JoinType.FULL; +import static com.facebook.presto.spi.plan.JoinType.INNER; +import static com.facebook.presto.spi.plan.JoinType.LEFT; +import static com.facebook.presto.spi.plan.JoinType.RIGHT; +import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING; +import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING; +import static com.facebook.presto.spi.plan.WindowNode.Frame.WindowType.ROWS; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunction; +import static com.facebook.presto.sql.relational.Expressions.coalesce; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; 
+import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +/** + * This rule prepares cartesian product of partitions + * from all inputs of table function. + *

+ * It rewrites TableFunctionNode with potentially many sources + * into a TableFunctionProcessorNode. The new node has one + * source being a combination of the original sources. + *

+ * The original sources are combined with joins. The join + * conditions depend on the prune when empty property, and on + * the co-partitioning of sources. + *

+ * The resulting source should be partitioned and ordered + * according to combined schemas from the component sources. + *

+ * Example transformation for two sources, both with set semantics + * and KEEP WHEN EMPTY property: + *

+ * - TableFunction foo
+ *      - source T1(a1, b1) PARTITION BY a1 ORDER BY b1
+ *      - source T2(a2, b2) PARTITION BY a2
+ * 
+ * Is transformed into: + *
+ * - TableFunctionProcessor foo
+ *      PARTITION BY (a1, a2), ORDER BY combined_row_number
+ *      - Project
+ *          marker_1 <= IF(table1_row_number = combined_row_number, table1_row_number, CAST(null AS bigint))
+ *          marker_2 <= IF(table2_row_number = combined_row_number, table2_row_number, CAST(null AS bigint))
+ *          - Project
+ *              combined_row_number <= IF(COALESCE(table1_row_number, BIGINT '-1') > COALESCE(table2_row_number, BIGINT '-1'), table1_row_number, table2_row_number)
+ *              combined_partition_size <= IF(COALESCE(table1_partition_size, BIGINT '-1') > COALESCE(table2_partition_size, BIGINT '-1'), table1_partition_size, table2_partition_size)
+ *              - FULL Join
+ *                  [table1_row_number = table2_row_number OR
+ *                   table1_row_number > table2_partition_size AND table2_row_number = BIGINT '1' OR
+ *                   table2_row_number > table1_partition_size AND table1_row_number = BIGINT '1']
+ *                  - Window [PARTITION BY a1 ORDER BY b1]
+ *                      table1_row_number <= row_number()
+ *                      table1_partition_size <= count()
+ *                          - source T1(a1, b1)
+ *                  - Window [PARTITION BY a2]
+ *                      table2_row_number <= row_number()
+ *                      table2_partition_size <= count()
+ *                          - source T2(a2, b2)
+ * 
+ */ +public class TransformTableFunctionToTableFunctionProcessor + implements Rule +{ + private static final Pattern PATTERN = tableFunction(); + private static final Frame FULL_FRAME = new Frame( + ROWS, + UNBOUNDED_PRECEDING, + Optional.empty(), + Optional.empty(), + UNBOUNDED_FOLLOWING, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + private static final DataOrganizationSpecification UNORDERED_SINGLE_PARTITION = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); + + private final Metadata metadata; + + public TransformTableFunctionToTableFunctionProcessor(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionNode node, Captures captures, Context context) + { + if (node.getSources().isEmpty()) { + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.empty(), + false, + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + if (node.getSources().size() == 1) { + // Single source does not require pre-processing. + // If the source has row semantics, its specification is empty. + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // This property can be used later to choose optimal distribution. 
+ TableArgumentProperties sourceProperties = getOnlyElement(node.getTableArgumentProperties()); + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(getOnlyElement(node.getSources())), + sourceProperties.isPruneWhenEmpty(), + ImmutableList.of(sourceProperties.getPassThroughSpecification()), + ImmutableList.of(sourceProperties.getRequiredColumns()), + Optional.empty(), + sourceProperties.getSpecification(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + Map sources = mapSourcesByName(node.getSources(), node.getTableArgumentProperties()); + ImmutableList.Builder intermediateResultsBuilder = ImmutableList.builder(); + + FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager(); + + // Create call expression for row_number + FunctionHandle rowNumberFunctionHandle = functionAndTypeManager.resolveFunction(Optional.of(context.getSession().getSessionFunctions()), + context.getSession().getTransactionId(), + functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of("row_number")), + ImmutableList.of()); + + FunctionMetadata rowNumberFunctionMetadata = functionAndTypeManager.getFunctionMetadata(rowNumberFunctionHandle); + CallExpression rowNumberFunction = new CallExpression("row_number", rowNumberFunctionHandle, functionAndTypeManager.getType(rowNumberFunctionMetadata.getReturnType()), ImmutableList.of()); + + // Create call expression for count + FunctionHandle countFunctionHandle = functionAndTypeManager.resolveFunction(Optional.of(context.getSession().getSessionFunctions()), + context.getSession().getTransactionId(), + functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of("count")), + ImmutableList.of()); + + FunctionMetadata countFunctionMetadata = functionAndTypeManager.getFunctionMetadata(countFunctionHandle); + CallExpression countFunction = new CallExpression("count", 
countFunctionHandle, functionAndTypeManager.getType(countFunctionMetadata.getReturnType()), ImmutableList.of()); + + // handle co-partitioned sources + for (List copartitioningList : node.getCopartitioningLists()) { + List sourceList = copartitioningList.stream() + .map(sources::get) + .collect(toImmutableList()); + intermediateResultsBuilder.add(copartition(sourceList, rowNumberFunction, countFunction, context, metadata)); + } + + // prepare non-co-partitioned sources + Set copartitionedSources = node.getCopartitioningLists().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + sources.entrySet().stream() + .filter(entry -> !copartitionedSources.contains(entry.getKey())) + .map(entry -> planWindowFunctionsForSource(entry.getValue().source(), entry.getValue().properties(), rowNumberFunction, countFunction, context)) + .forEach(intermediateResultsBuilder::add); + + NodeWithVariables finalResultSource; + + List intermediateResultSources = intermediateResultsBuilder.build(); + if (intermediateResultSources.size() == 1) { + finalResultSource = getOnlyElement(intermediateResultSources); + } + else { + NodeWithVariables first = intermediateResultSources.get(0); + NodeWithVariables second = intermediateResultSources.get(1); + JoinedNodes joined = join(first, second, context, metadata); + + for (int i = 2; i < intermediateResultSources.size(); i++) { + NodeWithVariables joinedWithSymbols = appendHelperSymbolsForJoinedNodes(joined, context, metadata); + joined = join(joinedWithSymbols, intermediateResultSources.get(i), context, metadata); + } + + finalResultSource = appendHelperSymbolsForJoinedNodes(joined, context, metadata); + } + + // For each source, all source's output symbols are mapped to the source's row number symbol. + // The row number symbol will be later converted to a marker of "real" input rows vs "filler" input rows of the source. 
+ // The "filler" input rows are the rows appended while joining partitions of different lengths, + // to fill the smaller partition up to the bigger partition's size. They are a side effect of the algorithm, + // and should not be processed by the table function. + Map rowNumberSymbols = finalResultSource.rowNumberSymbolsMapping(); + + // The max row number symbol from all joined partitions. + VariableReferenceExpression finalRowNumberSymbol = finalResultSource.rowNumber(); + // Combined partitioning lists from all sources. + List finalPartitionBy = finalResultSource.partitionBy(); + + NodeWithMarkers marked = appendMarkerSymbols(finalResultSource.node(), ImmutableSet.copyOf(rowNumberSymbols.values()), finalRowNumberSymbol, context, metadata); + + // Remap the symbol mapping: replace the row number symbol with the corresponding marker symbol. + // In the new map, every source symbol is associated with the corresponding marker symbol. + // Null value of the marker indicates that the source value should be ignored by the table function. + ImmutableMap markerSymbols = rowNumberSymbols.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> marked.variableToMarker().get(entry.getValue()))); + + // Use the final row number symbol for ordering the combined sources. + // It runs along each partition in the cartesian product, numbering the partition's rows according to the expected ordering / orderings. + // note: ordering is necessary even if all the source tables are not ordered. Thanks to the ordering, the original rows + // of each input table come before the "filler" rows. 
+ ImmutableList.Builder newOrderings = ImmutableList.builder(); + newOrderings.add(new Ordering(finalRowNumberSymbol, ASC_NULLS_LAST)); + Optional finalOrderBy = Optional.of(new OrderingScheme(newOrderings.build())); + + // derive the prune when empty property + boolean pruneWhenEmpty = node.getTableArgumentProperties().stream().anyMatch(TableArgumentProperties::isPruneWhenEmpty); + + // Combine the pass through specifications from all sources + List passThroughSpecifications = node.getTableArgumentProperties().stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .collect(toImmutableList()); + + // Combine the required symbols from all sources + List> requiredVariables = node.getTableArgumentProperties().stream() + .map(TableArgumentProperties::getRequiredColumns) + .collect(toImmutableList()); + + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(marked.node()), + pruneWhenEmpty, + passThroughSpecifications, + requiredVariables, + Optional.of(markerSymbols), + Optional.of(new DataOrganizationSpecification(finalPartitionBy, finalOrderBy)), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + private static Map mapSourcesByName(List sources, List properties) + { + return Streams.zip(sources.stream(), properties.stream(), SourceWithProperties::new) + .collect(toImmutableMap(entry -> entry.properties().getArgumentName(), identity())); + } + + private static NodeWithVariables planWindowFunctionsForSource( + PlanNode source, + TableArgumentProperties argumentProperties, + CallExpression rowNumberFunction, + CallExpression countFunction, + Context context) + { + String argumentName = argumentProperties.getArgumentName(); + + VariableReferenceExpression rowNumber = context.getVariableAllocator().newVariable(argumentName + "_row_number", BIGINT); + Map rowNumberSymbolMapping = source.getOutputVariables().stream() + .collect(toImmutableMap(identity(), 
symbol -> rowNumber)); + + VariableReferenceExpression partitionSize = context.getVariableAllocator().newVariable(argumentName + "_partition_size", BIGINT); + + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // If the source has row semantics, its specification is empty. Currently, such source is processed + // as if it was a single partition. Alternatively, it could be split into smaller partitions of arbitrary size. + DataOrganizationSpecification specification = argumentProperties.getSpecification().orElse(UNORDERED_SINGLE_PARTITION); + + PlanNode innerWindow = new WindowNode( + source.getSourceLocation(), + context.getIdAllocator().getNextId(), + source, + specification, + ImmutableMap.of( + rowNumber, new WindowNode.Function(rowNumberFunction, FULL_FRAME, false)), + Optional.empty(), + ImmutableSet.of(), + 0); + PlanNode window = new WindowNode( + innerWindow.getSourceLocation(), + context.getIdAllocator().getNextId(), + innerWindow, + specification, + ImmutableMap.of( + partitionSize, new WindowNode.Function(countFunction, FULL_FRAME, false)), + Optional.empty(), + ImmutableSet.of(), + 0); + + return new NodeWithVariables(window, rowNumber, partitionSize, specification.getPartitionBy(), argumentProperties.isPruneWhenEmpty(), rowNumberSymbolMapping); + } + + private static NodeWithVariables copartition( + List sourceList, + CallExpression rowNumberFunction, + CallExpression countFunction, + Context context, + Metadata metadata) + { + checkArgument(sourceList.size() >= 2, "co-partitioning list should contain at least two tables"); + + // Reorder the co-partitioned sources to process the sources with prune when empty property first. + // It allows to use inner or side joins instead of outer joins. + sourceList = sourceList.stream() + .sorted(Comparator.comparingInt(source -> source.properties().isPruneWhenEmpty() ? 
-1 : 1)) + .collect(toImmutableList()); + + NodeWithVariables first = planWindowFunctionsForSource(sourceList.get(0).source(), sourceList.get(0).properties(), rowNumberFunction, countFunction, context); + NodeWithVariables second = planWindowFunctionsForSource(sourceList.get(1).source(), sourceList.get(1).properties(), rowNumberFunction, countFunction, context); + JoinedNodes copartitioned = copartition(first, second, context, metadata); + + for (int i = 2; i < sourceList.size(); i++) { + NodeWithVariables copartitionedWithSymbols = appendHelperSymbolsForCopartitionedNodes(copartitioned, context, metadata); + NodeWithVariables next = planWindowFunctionsForSource(sourceList.get(i).source(), sourceList.get(i).properties(), rowNumberFunction, countFunction, context); + copartitioned = copartition(copartitionedWithSymbols, next, context, metadata); + } + + return appendHelperSymbolsForCopartitionedNodes(copartitioned, context, metadata); + } + + private static JoinedNodes copartition(NodeWithVariables left, NodeWithVariables right, Context context, Metadata metadata) + { + checkArgument(left.partitionBy().size() == right.partitionBy().size(), "co-partitioning lists do not match"); + + // In StatementAnalyzer we require that co-partitioned tables have non-empty partitioning column lists. + // Co-partitioning tables with empty partition by would be ineffective. 
+ checkState(!left.partitionBy().isEmpty(), "co-partitioned tables must have partitioning columns"); + + FunctionResolution functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + + Optional copartitionConjuncts = Streams.zip( + left.partitionBy.stream(), + right.partitionBy.stream(), + (leftColumn, rightColumn) -> new CallExpression("NOT", + functionResolution.notFunction(), + BOOLEAN, + ImmutableList.of( + new CallExpression(IS_DISTINCT_FROM.name(), + functionResolution.comparisonFunction(IS_DISTINCT_FROM, leftColumn.getType(), rightColumn.getType()), + BOOLEAN, + ImmutableList.of(leftColumn, rightColumn))))) + .map(expr -> expr) + .reduce((expr, conjunct) -> new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of(expr, conjunct))); + + // Align matching partitions (co-partitions) from left and right source, according to row number. + // Matching partitions are identified by their corresponding partitioning columns being NOT DISTINCT from each other. + // If one or both sources are ordered, the row number reflects the ordering. + // The second and third disjunct in the join condition account for the situation when partitions have different sizes. + // It preserves the outstanding rows from the bigger partition, matching them to the first row from the smaller partition. + // + // (P1_1 IS NOT DISTINCT FROM P2_1) AND (P1_2 IS NOT DISTINCT FROM P2_2) AND ... 
+ // AND ( + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1)) + + SpecialFormExpression orExpression = new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.rowNumber)), + new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, new ConstantExpression(1L, BIGINT))))), + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, left.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, new ConstantExpression(1L, BIGINT))))))))); + RowExpression joinCondition = copartitionConjuncts.map( + conjunct -> new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of(conjunct, orExpression))) + .orElse(orExpression); + + // The join type depends on the prune when empty property of the sources. + // If a source is prune when empty, we should not process any co-partition which is not present in this source, + // so effectively the other source becomes inner side of the join. 
+ // + // example: + // table T1 partition by P1 table T2 partition by P2 + // P1 C1 P2 C2 + // ---------- ---------- + // 1 'a' 2 'c' + // 2 'b' 3 'd' + // + // co-partitioning results: + // 1) T1 is prune when empty: do LEFT JOIN to drop co-partition '3' + // P1 C1 P2 C2 + // ------------------------ + // 1 'a' null null + // 2 'b' 2 'c' + // + // 2) T2 is prune when empty: do RIGHT JOIN to drop co-partition '1' + // P1 C1 P2 C2 + // ------------------------ + // 2 'b' 2 'c' + // null null 3 'd' + // + // 3) T1 and T2 are both prune when empty: do INNER JOIN to drop co-partitions '1' and '3' + // P1 C1 P2 C2 + // ------------------------ + // 2 'b' 2 'c' + // + // 4) neither table is prune when empty: do FULL JOIN to preserve all co-partitions + // P1 C1 P2 C2 + // ------------------------ + // 1 'a' null null + // 2 'b' 2 'c' + // null null 3 'd' + JoinType joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + } + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + Optional.empty(), + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + Stream.concat(left.node().getOutputVariables().stream(), + right.node().getOutputVariables().stream()) + .collect(Collectors.toList()), + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private static NodeWithVariables appendHelperSymbolsForCopartitionedNodes( + JoinedNodes copartitionedNodes, + Context context, + Metadata metadata) + { + 
checkArgument(copartitionedNodes.leftPartitionBy().size() == copartitionedNodes.rightPartitionBy().size(), "co-partitioning lists do not match"); + + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + VariableReferenceExpression joinedRowNumber = context.getVariableAllocator().newVariable("combined_row_number", BIGINT); + RowExpression rowNumberExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.leftRowNumber(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.rightRowNumber(), + new ConstantExpression(-1L, BIGINT)))), + copartitionedNodes.leftRowNumber(), + copartitionedNodes.rightRowNumber())); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. 
+ VariableReferenceExpression joinedPartitionSize = context.getVariableAllocator().newVariable("combined_partition_size", BIGINT); + RowExpression partitionSizeExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.leftPartitionSize(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.rightPartitionSize(), + new ConstantExpression(-1L, BIGINT)))), + copartitionedNodes.leftPartitionSize(), + copartitionedNodes.rightPartitionSize())); + + // Derive partitioning columns for joined partitions. + // Either the combined partitioning columns are pairwise NOT DISTINCT (this is the co-partitioning rule), + // or one of them is null as a result of outer join. 
+ ImmutableList.Builder joinedPartitionBy = ImmutableList.builder(); + Assignments.Builder joinedPartitionByAssignments = Assignments.builder(); + for (int i = 0; i < copartitionedNodes.leftPartitionBy().size(); i++) { + VariableReferenceExpression leftColumn = copartitionedNodes.leftPartitionBy().get(i); + VariableReferenceExpression rightColumn = copartitionedNodes.rightPartitionBy().get(i); + Type type = context.getVariableAllocator().getVariables().get(leftColumn.getName()); + + VariableReferenceExpression joinedColumn = context.getVariableAllocator().newVariable("combined_partition_column", type); + joinedPartitionByAssignments.put(joinedColumn, coalesce(leftColumn, rightColumn)); + joinedPartitionBy.add(joinedColumn); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + copartitionedNodes.joinedNode(), + Assignments.builder() + .putAll( + copartitionedNodes.joinedNode().getOutputVariables().stream() + .collect(toImmutableMap(v -> v, v -> v))) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .putAll(joinedPartitionByAssignments.build()) + .build()); + boolean joinedPruneWhenEmpty = copartitionedNodes.leftPruneWhenEmpty() || copartitionedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(copartitionedNodes.leftRowNumberSymbolsMapping()) + .putAll(copartitionedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithVariables(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy.build(), joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static JoinedNodes join(NodeWithVariables left, NodeWithVariables right, Context context, Metadata metadata) + { + // Align rows from left and right source according to row number. Because every partition is row-numbered, this produces cartesian product of partitions. 
+ // If one or both sources are ordered, the row number reflects the ordering. + // The second and third disjunct in the join condition account for the situation when partitions have different sizes. It preserves the outstanding rows + // from the bigger partition, matching them to the first row from the smaller partition. + // + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1) + + FunctionResolution functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + RowExpression joinCondition = new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.rowNumber)), + new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, new ConstantExpression(1L, BIGINT))))), + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, left.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, new ConstantExpression(1L, BIGINT))))))))); + JoinType joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + 
} + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + Optional.empty(), + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + Stream.concat(left.node().getOutputVariables().stream(), + right.node().getOutputVariables().stream()) + .collect(Collectors.toList()), + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private static NodeWithVariables appendHelperSymbolsForJoinedNodes(JoinedNodes joinedNodes, Context context, Metadata metadata) + { + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + VariableReferenceExpression joinedRowNumber = context.getVariableAllocator().newVariable("combined_row_number", BIGINT); + RowExpression rowNumberExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.leftRowNumber(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.rightRowNumber(), + new ConstantExpression(-1L, BIGINT)))), + joinedNodes.leftRowNumber(), + joinedNodes.rightRowNumber())); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. 
+ VariableReferenceExpression joinedPartitionSize = context.getVariableAllocator().newVariable("combined_partition_size", BIGINT); + RowExpression partitionSizeExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.leftPartitionSize(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.rightPartitionSize(), + new ConstantExpression(-1L, BIGINT)))), + joinedNodes.leftPartitionSize(), + joinedNodes.rightPartitionSize())); + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + joinedNodes.joinedNode(), + Assignments.builder() + .putAll( + joinedNodes.joinedNode().getOutputVariables().stream() + .collect(toImmutableMap(v -> v, v -> v))) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .build()); + + List joinedPartitionBy = ImmutableList.builder() + .addAll(joinedNodes.leftPartitionBy()) + .addAll(joinedNodes.rightPartitionBy()) + .build(); + + boolean joinedPruneWhenEmpty = joinedNodes.leftPruneWhenEmpty() || joinedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(joinedNodes.leftRowNumberSymbolsMapping()) + .putAll(joinedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithVariables(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy, joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static NodeWithMarkers appendMarkerSymbols(PlanNode node, Set variables, VariableReferenceExpression referenceSymbol, Context context, Metadata metadata) + { + Assignments.Builder assignments = Assignments.builder(); + assignments.putAll( + node.getOutputVariables().stream() + 
.collect(toImmutableMap(v -> v, v -> v))); + + ImmutableMap.Builder variablesToMarkers = ImmutableMap.builder(); + + for (VariableReferenceExpression variable : variables) { + VariableReferenceExpression marker = context.getVariableAllocator().newVariable("marker", BIGINT); + variablesToMarkers.put(variable, marker); + RowExpression ifExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + EQUAL.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.EQUAL, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of(variable, referenceSymbol)), + variable, + new ConstantExpression(null, BIGINT))); + assignments.put(marker, ifExpression); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + node, + assignments.build()); + + return new NodeWithMarkers(project, variablesToMarkers.buildOrThrow()); + } + + private static class SourceWithProperties + { + private final PlanNode source; + private final TableArgumentProperties properties; + + public SourceWithProperties(PlanNode source, TableArgumentProperties properties) + { + this.source = requireNonNull(source, "source is null"); + this.properties = requireNonNull(properties, "properties is null"); + } + + public PlanNode source() + { + return source; + } + + public TableArgumentProperties properties() + { + return properties; + } + } + + public static final class NodeWithVariables + { + private final PlanNode node; + private final VariableReferenceExpression rowNumber; + private final VariableReferenceExpression partitionSize; + private final List partitionBy; + private final boolean pruneWhenEmpty; + private final Map rowNumberSymbolsMapping; + + public NodeWithVariables(PlanNode node, VariableReferenceExpression rowNumber, VariableReferenceExpression partitionSize, + List partitionBy, boolean pruneWhenEmpty, + Map rowNumberSymbolsMapping) + { + this.node = requireNonNull(node, "node is null"); + this.rowNumber = 
requireNonNull(rowNumber, "rowNumber is null"); + this.partitionSize = requireNonNull(partitionSize, "partitionSize is null"); + this.partitionBy = ImmutableList.copyOf(requireNonNull(partitionBy, "partitionBy is null")); + this.pruneWhenEmpty = pruneWhenEmpty; + this.rowNumberSymbolsMapping = ImmutableMap.copyOf(requireNonNull(rowNumberSymbolsMapping, "rowNumberSymbolsMapping is null")); + } + + public PlanNode node() + { + return node; + } + + public VariableReferenceExpression rowNumber() + { + return rowNumber; + } + + public VariableReferenceExpression partitionSize() + { + return partitionSize; + } + + public List partitionBy() + { + return partitionBy; + } + + public boolean pruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public Map rowNumberSymbolsMapping() + { + return rowNumberSymbolsMapping; + } + } + + private static class JoinedNodes + { + private final PlanNode joinedNode; + private final VariableReferenceExpression leftRowNumber; + private final VariableReferenceExpression leftPartitionSize; + private final List leftPartitionBy; + private final boolean leftPruneWhenEmpty; + private final Map leftRowNumberSymbolsMapping; + private final VariableReferenceExpression rightRowNumber; + private final VariableReferenceExpression rightPartitionSize; + private final List rightPartitionBy; + private final boolean rightPruneWhenEmpty; + private final Map rightRowNumberSymbolsMapping; + + public JoinedNodes( + PlanNode joinedNode, + VariableReferenceExpression leftRowNumber, + VariableReferenceExpression leftPartitionSize, + List leftPartitionBy, + boolean leftPruneWhenEmpty, + Map leftRowNumberSymbolsMapping, + VariableReferenceExpression rightRowNumber, + VariableReferenceExpression rightPartitionSize, + List rightPartitionBy, + boolean rightPruneWhenEmpty, + Map rightRowNumberSymbolsMapping) + { + this.joinedNode = requireNonNull(joinedNode, "joinedNode is null"); + this.leftRowNumber = requireNonNull(leftRowNumber, "leftRowNumber is null"); + this.leftPartitionSize = requireNonNull(leftPartitionSize, "leftPartitionSize is null"); +
this.leftPartitionBy = ImmutableList.copyOf(requireNonNull(leftPartitionBy, "leftPartitionBy is null")); + this.leftPruneWhenEmpty = leftPruneWhenEmpty; + this.leftRowNumberSymbolsMapping = ImmutableMap.copyOf(requireNonNull(leftRowNumberSymbolsMapping, "leftRowNumberSymbolsMapping is null")); + this.rightRowNumber = requireNonNull(rightRowNumber, "rightRowNumber is null"); + this.rightPartitionSize = requireNonNull(rightPartitionSize, "rightPartitionSize is null"); + this.rightPartitionBy = ImmutableList.copyOf(requireNonNull(rightPartitionBy, "rightPartitionBy is null")); + this.rightPruneWhenEmpty = rightPruneWhenEmpty; + this.rightRowNumberSymbolsMapping = ImmutableMap.copyOf(requireNonNull(rightRowNumberSymbolsMapping, "rightRowNumberSymbolsMapping is null")); + } + + public PlanNode joinedNode() + { + return joinedNode; + } + public VariableReferenceExpression leftRowNumber() + { + return leftRowNumber; + } + public VariableReferenceExpression leftPartitionSize() + { + return leftPartitionSize; + } + public List leftPartitionBy() + { + return leftPartitionBy; + } + public boolean leftPruneWhenEmpty() + { + return leftPruneWhenEmpty; + } + public Map leftRowNumberSymbolsMapping() + { + return leftRowNumberSymbolsMapping; + } + public VariableReferenceExpression rightRowNumber() + { + return rightRowNumber; + } + public VariableReferenceExpression rightPartitionSize() + { + return rightPartitionSize; + } + public List rightPartitionBy() + { + return rightPartitionBy; + } + public boolean rightPruneWhenEmpty() + { + return rightPruneWhenEmpty; + } + public Map rightRowNumberSymbolsMapping() + { + return rightRowNumberSymbolsMapping; + } + } + + private static class NodeWithMarkers + { + private final PlanNode node; + private final Map variableToMarker; + + public NodeWithMarkers(PlanNode node, Map variableToMarker) + { + this.node = requireNonNull(node, "node is null"); + this.variableToMarker = ImmutableMap.copyOf(requireNonNull(variableToMarker, 
"variableToMarker is null")); + } + + public PlanNode node() + { + return node; + } + + public Map variableToMarker() + { + return variableToMarker; + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index 0b7021f66baf4..047326398832a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -80,6 +80,7 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.google.common.annotations.VisibleForTesting; import com.google.common.cache.CacheBuilder; @@ -100,6 +101,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -417,7 +419,59 @@ public PlanWithProperties visitWindow(WindowNode node, PreferredProperties prefe @Override public PlanWithProperties visitTableFunction(TableFunctionNode node, PreferredProperties preferredProperties) { - throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanWithProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, PreferredProperties preferredProperties) + { + if (!node.getSource().isPresent()) { + return new PlanWithProperties(node, deriveProperties(node, ImmutableList.of())); + }
+ + if (!node.getSpecification().isPresent()) { + // node.getSpecification.isEmpty() indicates that there were no sources or a single source with row semantics. + // The case of no sources was addressed above. + // The case of a single source with row semantics is addressed here. A single source with row semantics can be distributed arbitrarily. + PlanWithProperties child = planChild(node, PreferredProperties.any()); + return rebaseAndDeriveProperties(node, child); + } + + List partitionBy = node.getSpecification().orElseThrow(NoSuchElementException::new).getPartitionBy(); + List> desiredProperties = new ArrayList<>(); + if (!partitionBy.isEmpty()) { + desiredProperties.add(new GroupingProperty<>(partitionBy)); + } + node.getSpecification().orElseThrow(NoSuchElementException::new) + .getOrderingScheme() + .ifPresent(orderingScheme -> + orderingScheme.getOrderByVariables().stream() + .map(variable -> new SortingProperty<>(variable, orderingScheme.getOrdering(variable))) + .forEach(desiredProperties::add)); + + PlanWithProperties child = planChild(node, PreferredProperties.partitionedWithLocal(ImmutableSet.copyOf(partitionBy), desiredProperties)); + + // TODO do not gather if already gathered + if (!node.isPruneWhenEmpty()) { + child = withDerivedProperties( + gatheringExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode()), + child.getProperties()); + } + else if (!isStreamPartitionedOn(child.getProperties(), partitionBy) && + !isNodePartitionedOn(child.getProperties(), partitionBy)) { + if (partitionBy.isEmpty()) { + child = withDerivedProperties( + gatheringExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode()), + child.getProperties()); + } + else { + child = withDerivedProperties( + partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode(), Partitioning.create(FIXED_HASH_DISTRIBUTION, partitionBy), node.getHashSymbol()), + child.getProperties()); + } + } + + return rebaseAndDeriveProperties(node, child); } 
@Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index 1081944a4b064..46d17de0d459c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -59,6 +60,8 @@ import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.google.common.collect.ImmutableList; @@ -67,6 +70,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; @@ -111,6 +115,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -500,6 +505,87 @@ public PlanWithProperties 
visitDelete(DeleteNode node, StreamPreferredProperties return deriveProperties(result, child.getProperties()); } + @Override + public PlanWithProperties visitTableFunction(TableFunctionNode node, StreamPreferredProperties parentPreferences) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanWithProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, StreamPreferredProperties parentPreferences) + { + if (!node.getSource().isPresent()) { + return deriveProperties(node, ImmutableList.of()); + } + + if (!node.getSpecification().isPresent()) { + // node.getSpecification.isEmpty() indicates that there were no sources or a single source with row semantics. + // The case of no sources was addressed above. + // The case of a single source with row semantics is addressed here. Source's properties do not hold after the TableFunctionProcessorNode + PlanWithProperties child = planAndEnforce(node.getSource().orElseThrow(NoSuchElementException::new), StreamPreferredProperties.any(), StreamPreferredProperties.any()); + return rebaseAndDeriveProperties(node, ImmutableList.of(child)); + } + + List partitionBy = node.getSpecification().orElseThrow(NoSuchElementException::new).getPartitionBy(); + StreamPreferredProperties childRequirements; + if (!node.isPruneWhenEmpty()) { + childRequirements = singleStream(); + } + else { + childRequirements = parentPreferences + .constrainTo(node.getSource().orElseThrow(NoSuchElementException::new).getOutputVariables()) + .withDefaultParallelism(session) + .withPartitioning(partitionBy); + } + + PlanWithProperties child = planAndEnforce(node.getSource().orElseThrow(NoSuchElementException::new), childRequirements, childRequirements); + + List> desiredProperties = new ArrayList<>(); + if (!partitionBy.isEmpty()) { + desiredProperties.add(new GroupingProperty<>(partitionBy)); + } + node.getSpecification() + 
.flatMap(DataOrganizationSpecification::getOrderingScheme) + .ifPresent(orderingScheme -> + orderingScheme.getOrderByVariables().stream() + .map(variable -> new SortingProperty<>(variable, orderingScheme.getOrdering(variable))) + .forEach(desiredProperties::add)); + Iterator>> matchIterator = LocalProperties.match(child.getProperties().getLocalProperties(), desiredProperties).iterator(); + + Set prePartitionedInputs = ImmutableSet.of(); + if (!partitionBy.isEmpty()) { + Optional> groupingRequirement = matchIterator.next(); + Set unPartitionedInputs = groupingRequirement.map(LocalProperty::getColumns).orElse(ImmutableSet.of()); + prePartitionedInputs = partitionBy.stream() + .filter(symbol -> !unPartitionedInputs.contains(symbol)) + .collect(toImmutableSet()); + } + + int preSortedOrderPrefix = 0; + if (prePartitionedInputs.equals(ImmutableSet.copyOf(partitionBy))) { + while (matchIterator.hasNext() && !matchIterator.next().isPresent()) { + preSortedOrderPrefix++; + } + } + + TableFunctionProcessorNode result = new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(child.getNode()), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + prePartitionedInputs, + preSortedOrderPrefix, + node.getHashSymbol(), + node.getHandle()); + + return deriveProperties(result, child.getProperties()); + } + @Override public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, StreamPreferredProperties parentPreferences) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java index 9dff1d3fb86d9..7dc181ca661e0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.SortingProperty; import com.facebook.presto.spi.UniqueProperty; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -73,6 +74,8 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -111,6 +114,7 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.String.format; import static java.util.stream.Collectors.toMap; public class PropertyDerivations @@ -287,6 +291,50 @@ public ActualProperties visitWindow(WindowNode node, List inpu .build(); } + @Override + public ActualProperties visitTableFunction(TableFunctionNode node, List inputProperties) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public ActualProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, List inputProperties) + { + ImmutableList.Builder> localProperties = ImmutableList.builder(); + + if (node.getSource().isPresent()) { + ActualProperties properties = Iterables.getOnlyElement(inputProperties); + + // Only the partitioning 
properties of the source are passed-through, because the pass-through mechanism preserves the partitioning values. + // Sorting properties might be broken because input rows can be shuffled or nulls can be inserted as the result of pass-through. + // Constant properties might be broken because nulls can be inserted as the result of pass-through. + if (!node.getPrePartitioned().isEmpty()) { + GroupingProperty prePartitionedProperty = new GroupingProperty<>(node.getPrePartitioned()); + for (LocalProperty localProperty : properties.getLocalProperties()) { + if (!prePartitionedProperty.isSimplifiedBy(localProperty)) { + break; + } + localProperties.add(localProperty); + } + } + } + + List partitionBy = node.getSpecification() + .map(DataOrganizationSpecification::getPartitionBy) + .orElse(ImmutableList.of()); + if (!partitionBy.isEmpty()) { + localProperties.add(new GroupingProperty<>(partitionBy)); + } + + // TODO add global single stream property when there's Specification present with no partitioning columns + + return ActualProperties.builder() + .local(localProperties.build()) + .build() + // Crop properties to output columns. + .translateVariable(variable -> node.getOutputVariables().contains(variable) ? 
Optional.of(variable) : Optional.empty()); + } + @Override public ActualProperties visitGroupId(GroupIdNode node, List inputProperties) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index 9ec4a67577777..201ec219823af 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -68,6 +68,7 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -1084,5 +1085,25 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext> context) + { + return node.getSource().map(source -> new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(context.rewrite(source, ImmutableSet.copyOf(source.getOutputVariables()))), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + node.getPrePartitioned(), + node.getPreSorted(), + node.getHashSymbol(), + node.getHandle() + )).orElse(node); + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java index ffd4806665c2c..cd4d5207fccc8 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; @@ -102,6 +103,12 @@ public Range visitEnforceSingleRow(EnforceSingleRowNode node, Void context return Range.singleton(1L); } + @Override + public Range visitWindow(WindowNode node, Void context) + { + return node.getSource().accept(this, null); + } + @Override public Range visitAggregation(AggregationNode node, Void context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index 612fb584566bb..3effedf6638fb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -63,6 +63,8 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -71,11 +73,13 @@ import 
com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; +import com.google.common.collect.Sets; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -607,6 +611,33 @@ public StreamProperties visitWindow(WindowNode node, List inpu return Iterables.getOnlyElement(inputProperties); } + @Override + public StreamProperties visitTableFunction(TableFunctionNode node, List inputProperties) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public StreamProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, List inputProperties) + { + if (!node.getSource().isPresent()) { + return StreamProperties.singleStream(); // TODO allow multiple; return partitioning properties + } + + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + + Set passThroughInputs = Sets.intersection(ImmutableSet.copyOf(node.getSource().orElseThrow(NoSuchElementException::new).getOutputVariables()), ImmutableSet.copyOf(node.getOutputVariables())); + StreamProperties translatedProperties = properties.translate(column -> { + if (passThroughInputs.contains(column)) { + return Optional.of(column); + } + return Optional.empty(); + }); + // Mark as unordered since table functions have opaque logic that may reorder, generate, or filter rows + // even though partitioning properties are preserved for pass-through columns + return translatedProperties.unordered(true); + } + @Override public StreamProperties visitRowNumber(RowNumberNode node, List inputProperties) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java 
index 6ab46ef344a13..824c392ce9619 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.ExchangeEncoding; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; @@ -41,6 +42,8 @@ import com.facebook.presto.sql.planner.plan.MergeProcessorNode; import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; @@ -55,12 +58,14 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.spi.StandardWarningCode.MULTIPLE_ORDER_BY; import static com.facebook.presto.spi.plan.AggregationNode.groupingSets; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getNodeLocation; import static com.facebook.presto.sql.planner.optimizations.PartitioningUtils.translateVariable; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -146,6 +151,27 @@ public RowExpression 
rewriteVariableReference(VariableReferenceExpression variab }, value); } + public OrderingSchemeWithPreSortedPrefix map(OrderingScheme orderingScheme, int preSorted) + { + ImmutableList.Builder newOrderings = ImmutableList.builder(); + int newPreSorted = preSorted; + + Set added = new HashSet<>(orderingScheme.getOrderBy().size()); + + for (int i = 0; i < orderingScheme.getOrderBy().size(); i++) { + VariableReferenceExpression variable = orderingScheme.getOrderBy().get(i).getVariable(); + VariableReferenceExpression canonical = map(variable); + if (added.add(canonical)) { + newOrderings.add(new Ordering(canonical, orderingScheme.getOrdering(variable))); + } + else if (i < preSorted) { + newPreSorted--; + } + } + + return new OrderingSchemeWithPreSortedPrefix(new OrderingScheme(newOrderings.build()), newPreSorted); + } + public OrderingScheme map(OrderingScheme orderingScheme) { // SymbolMapper inlines symbol with multiple level reference (SymbolInliner only inline single level). @@ -388,6 +414,68 @@ public TableWriterMergeNode map(TableWriterMergeNode node, PlanNode source) node.getStatisticsAggregation().map(this::map)); } + public TableFunctionProcessorNode map(TableFunctionProcessorNode node, PlanNode source) + { + // rewrite and deduplicate pass-through specifications + // note: Potentially, pass-through symbols from different sources might be recognized as semantically identical, and rewritten + // to the same symbol. Currently, we retrieve the first occurrence of a symbol, and skip all the following occurrences. + // For better performance, we could pick the occurrence with "isPartitioningColumn" property, since the pass-through mechanism + // is more efficient for partitioning columns which are guaranteed to be constant within partition. 
+ // TODO choose a partitioning column to be retrieved while deduplicating + ImmutableList.Builder newPassThroughSpecifications = ImmutableList.builder(); + Set newPassThroughVariables = new HashSet<>(); + for (TableFunctionNode.PassThroughSpecification specification : node.getPassThroughSpecifications()) { + ImmutableList.Builder newColumns = ImmutableList.builder(); + for (TableFunctionNode.PassThroughColumn column : specification.getColumns()) { + VariableReferenceExpression newVariable = map(column.getOutputVariables()); + if (newPassThroughVariables.add(newVariable)) { + newColumns.add(new TableFunctionNode.PassThroughColumn(newVariable, column.isPartitioningColumn())); + } + } + newPassThroughSpecifications.add(new TableFunctionNode.PassThroughSpecification(specification.isDeclaredAsPassThrough(), newColumns.build())); + } + + // rewrite required symbols without deduplication. the table function expects specific input layout + List> newRequiredVariables = node.getRequiredVariables().stream() + .map(list -> list.stream() + .map(this::map) + .collect(toImmutableList())) + .collect(toImmutableList()); + + // rewrite and deduplicate marker mapping + Optional> newMarkerVariables = node.getMarkerVariables() + .map(mapping -> mapping.entrySet().stream() + .collect(toImmutableMap( + entry -> map(entry.getKey()), + entry -> map(entry.getValue()), + (first, second) -> { + checkState(first.equals(second), "Ambiguous marker symbols: %s and %s", first, second); + return first; + }))); + + // rewrite and deduplicate specification + Optional newSpecification = node.getSpecification().map(specification -> mapAndDistinct(specification, node.getPreSorted())); + + return new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs().stream() + .map(this::map) + .collect(toImmutableList()), + Optional.of(source), + node.isPruneWhenEmpty(), + newPassThroughSpecifications.build(), + newRequiredVariables, + newMarkerVariables, + 
newSpecification.map(SpecificationWithPreSortedPrefix::getSpecification), + node.getPrePartitioned().stream() + .map(this::map) + .collect(toImmutableSet()), + newSpecification.map(SpecificationWithPreSortedPrefix::getPreSorted).orElse(node.getPreSorted()), + node.getHashSymbol().map(this::map), + node.getHandle()); + } + private PartitioningScheme canonicalize(PartitioningScheme scheme, PlanNode source) { return new PartitioningScheme(translateVariable(scheme.getPartitioning(), this::map), @@ -437,6 +525,25 @@ private List mapAndDistinctVariable(List newOrderingScheme = specification.getOrderingScheme() + .map(orderingScheme -> map(orderingScheme, preSorted)); + + return new SpecificationWithPreSortedPrefix( + new DataOrganizationSpecification( + mapAndDistinctVariable(specification.getPartitionBy()), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::getOrderingScheme)), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::getPreSorted).orElse(preSorted)); + } + + DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecification specification) + { + return new DataOrganizationSpecification( + mapAndDistinctVariable(specification.getPartitionBy()), + specification.getOrderingScheme().map(this::map)); + } + public static SymbolMapper.Builder builder(WarningCollector warningCollector) { return new Builder(warningCollector); @@ -468,4 +575,48 @@ public void put(VariableReferenceExpression from, VariableReferenceExpression to mappingsBuilder.put(from, to); } } + + private static class OrderingSchemeWithPreSortedPrefix + { + private final OrderingScheme orderingScheme; + private final int preSorted; + + public OrderingSchemeWithPreSortedPrefix(OrderingScheme orderingScheme, int preSorted) + { + this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); + this.preSorted = preSorted; + } + + public OrderingScheme getOrderingScheme() + { + return orderingScheme; + } + + public int getPreSorted() + { + return preSorted; + } + } + 
+ private static class SpecificationWithPreSortedPrefix + { + private final DataOrganizationSpecification specification; + private final int preSorted; + + public SpecificationWithPreSortedPrefix(DataOrganizationSpecification specification, int preSorted) + { + this.specification = requireNonNull(specification, "specification is null"); + this.preSorted = preSorted; + } + + public DataOrganizationSpecification getSpecification() + { + return specification; + } + + public int getPreSorted() + { + return preSorted; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index f0a6f30a8db29..1c5e2e506f282 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -78,6 +78,7 @@ import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -86,6 +87,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; @@ -161,6 +163,11 @@ private Rewriter(TypeProvider types, FunctionAndTypeManager functionAndTypeManag this.warningCollector = warningCollector; } + public Map getMapping() + { + return mapping; + } + @Override public 
PlanNode visitAggregation(AggregationNode node, RewriteContext context) { @@ -500,18 +507,78 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext cont @Override public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext context) { + SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); + + List newProperOutputs = node.getOutputVariables().stream() + .map(mapper::map) + .collect(toImmutableList()); + + ImmutableList.Builder newSources = ImmutableList.builder(); + ImmutableList.Builder newTableArgumentProperties = ImmutableList.builder(); + + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode newSource = context.rewrite(node.getSources().get(i)); + newSources.add(newSource); + + // Use the mapping state from after processing the source for the input properties + SymbolMapper inputMapper = new SymbolMapper(mapping, types, warningCollector); + + TableFunctionNode.TableArgumentProperties properties = node.getTableArgumentProperties().get(i); + + Optional newSpecification = properties.getSpecification().map(inputMapper::mapAndDistinct); + TableFunctionNode.PassThroughSpecification newPassThroughSpecification = new TableFunctionNode.PassThroughSpecification( + properties.getPassThroughSpecification().isDeclaredAsPassThrough(), + properties.getPassThroughSpecification().getColumns().stream() + .map(column -> new TableFunctionNode.PassThroughColumn( + inputMapper.map(column.getOutputVariables()), + column.isPartitioningColumn())) + .collect(toImmutableList())); + newTableArgumentProperties.add(new TableFunctionNode.TableArgumentProperties( + properties.getArgumentName(), + properties.isRowSemantics(), + properties.isPruneWhenEmpty(), + newPassThroughSpecification, + inputMapper.map(properties.getRequiredColumns()), + newSpecification)); + } + return new TableFunctionNode( - node.getSourceLocation(), node.getId(), - Optional.empty(), node.getName(), node.getArguments(), - node.getOutputVariables(), - 
node.getSources(), - node.getTableArgumentProperties(), + newProperOutputs, + newSources.build(), + newTableArgumentProperties.build(), + node.getCopartitioningLists(), node.getHandle()); } + @Override + public PlanNode visitTableFunctionProcessor(TableFunctionProcessorNode node, RewriteContext context) + { + if (!node.getSource().isPresent()) { + SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); + return new TableFunctionProcessorNode( + node.getId(), + node.getName(), + mapper.map(node.getProperOutputs()), + Optional.empty(), + node.isPruneWhenEmpty(), + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + node.getHashSymbol().map(mapper::map), + node.getHandle()); + } + PlanNode rewrittenSource = context.rewrite(node.getSource().get()); + SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); + + return mapper.map(node, rewrittenSource); + } + @Override public PlanNode visitRowNumber(RowNumberNode node, RewriteContext context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java index a3e3ec3dc0d2c..6bfa05d32a9a8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java @@ -147,4 +147,9 @@ public R visitTableFunction(TableFunctionNode node, C context) { return visitPlan(node, context); } + + public R visitTableFunctionProcessor(TableFunctionProcessorNode node, C context) + { + return visitPlan(node, context); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java index a8af658030818..1db8a8b817eb5 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java @@ -245,6 +245,11 @@ public static Pattern tableFunction() return typeOf(TableFunctionNode.class); } + public static Pattern tableFunctionProcessor() + { + return typeOf(TableFunctionProcessorNode.class); + } + public static Pattern rowNumber() { return typeOf(RowNumberNode.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java index 22d4f18e42ff9..f87c1a1bba5c5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java @@ -61,6 +61,11 @@ public C get() return userContext; } + public SimplePlanRewriter getNodeRewriter() + { + return nodeRewriter; + } + /** * Invoke the rewrite logic recursively on children of the given node and swap it * out with an identical copy with the rewritten children diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java index 97892523498c0..8838e82b48c91 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java @@ -22,13 +22,17 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.Immutable; +import java.util.Collection; import java.util.List; 
import java.util.Map; import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @Immutable @@ -40,6 +44,7 @@ public class TableFunctionNode private final List outputVariables; private final List sources; private final List tableArgumentProperties; + private final List> copartitioningLists; private final TableFunctionHandle handle; @JsonCreator @@ -50,9 +55,10 @@ public TableFunctionNode( @JsonProperty("outputVariables") List outputVariables, @JsonProperty("sources") List sources, @JsonProperty("tableArgumentProperties") List tableArgumentProperties, + @JsonProperty("copartitioningLists") List> copartitioningLists, @JsonProperty("handle") TableFunctionHandle handle) { - this(Optional.empty(), id, Optional.empty(), name, arguments, outputVariables, sources, tableArgumentProperties, handle); + this(Optional.empty(), id, Optional.empty(), name, arguments, outputVariables, sources, tableArgumentProperties, copartitioningLists, handle); } public TableFunctionNode( @@ -64,14 +70,18 @@ public TableFunctionNode( List outputVariables, List sources, List tableArgumentProperties, + List> copartitioningLists, TableFunctionHandle handle) { super(sourceLocation, id, statsEquivalentPlanNode); this.name = requireNonNull(name, "name is null"); - this.arguments = requireNonNull(arguments, "arguments is null"); - this.outputVariables = requireNonNull(outputVariables, "outputVariables is null"); - this.sources = requireNonNull(sources, "sources is null"); - this.tableArgumentProperties = requireNonNull(tableArgumentProperties, "tableArgumentProperties is null"); + this.arguments = ImmutableMap.copyOf(arguments); + this.outputVariables = ImmutableList.copyOf(outputVariables); + this.sources = ImmutableList.copyOf(sources); + this.tableArgumentProperties = ImmutableList.copyOf(tableArgumentProperties); + this.copartitioningLists = 
copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); this.handle = requireNonNull(handle, "handle is null"); } @@ -87,8 +97,23 @@ public Map getArguments() return arguments; } - @JsonProperty + @Override public List getOutputVariables() + { + ImmutableList.Builder variables = ImmutableList.builder(); + variables.addAll(outputVariables); + + tableArgumentProperties.stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(PassThroughColumn::getOutputVariables) + .forEach(variables::add); + + return variables.build(); + } + + public List getProperOutputs() { return outputVariables; } @@ -99,6 +124,12 @@ public List getTableArgumentProperties() return tableArgumentProperties; } + @JsonProperty + public List> getCopartitioningLists() + { + return copartitioningLists; + } + @JsonProperty public TableFunctionHandle getHandle() { @@ -122,35 +153,47 @@ public R accept(InternalPlanVisitor visitor, C context) public PlanNode replaceChildren(List newSources) { checkArgument(sources.size() == newSources.size(), "wrong number of new children"); - return new TableFunctionNode(getId(), name, arguments, outputVariables, newSources, tableArgumentProperties, handle); + return new TableFunctionNode(getId(), name, arguments, outputVariables, newSources, tableArgumentProperties, copartitioningLists, handle); } @Override public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) { - return new TableFunctionNode(getSourceLocation(), getId(), statsEquivalentPlanNode, name, arguments, outputVariables, sources, tableArgumentProperties, handle); + return new TableFunctionNode(getSourceLocation(), getId(), statsEquivalentPlanNode, name, arguments, outputVariables, sources, tableArgumentProperties, copartitioningLists, handle); } public static class TableArgumentProperties { + private final String argumentName; private final boolean 
rowSemantics; private final boolean pruneWhenEmpty; - private final boolean passThroughColumns; + private final PassThroughSpecification passThroughSpecification; + private final List requiredColumns; private final Optional specification; @JsonCreator public TableArgumentProperties( + @JsonProperty("argumentName") String argumentName, @JsonProperty("rowSemantics") boolean rowSemantics, @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, - @JsonProperty("passThroughColumns") boolean passThroughColumns, + @JsonProperty("passThroughSpecification") PassThroughSpecification passThroughSpecification, + @JsonProperty("requiredColumns") List requiredColumns, @JsonProperty("specification") Optional specification) { + this.argumentName = requireNonNull(argumentName, "argumentName is null"); this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; - this.passThroughColumns = passThroughColumns; + this.passThroughSpecification = requireNonNull(passThroughSpecification, "passThroughSpecification is null"); + this.requiredColumns = ImmutableList.copyOf(requiredColumns); this.specification = requireNonNull(specification, "specification is null"); } + @JsonProperty + public String getArgumentName() + { + return argumentName; + } + @JsonProperty public boolean isRowSemantics() { @@ -164,15 +207,83 @@ public boolean isPruneWhenEmpty() } @JsonProperty - public boolean isPassThroughColumns() + public PassThroughSpecification getPassThroughSpecification() + { + return passThroughSpecification; + } + + @JsonProperty + public List getRequiredColumns() { - return passThroughColumns; + return requiredColumns; } @JsonProperty - public Optional specification() + public Optional getSpecification() { return specification; } } + + /** + * Specifies how columns from source tables are passed through to the output of a table function. + * This class manages both explicitly declared pass-through columns and partitioning columns + * that must be preserved in the output. 
+ */ + public static class PassThroughSpecification + { + private final boolean declaredAsPassThrough; + private final List columns; + + @JsonCreator + public PassThroughSpecification( + @JsonProperty("declaredAsPassThrough") boolean declaredAsPassThrough, + @JsonProperty("columns") List columns) + { + this.declaredAsPassThrough = declaredAsPassThrough; + this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + checkArgument( + declaredAsPassThrough || this.columns.stream().allMatch(PassThroughColumn::isPartitioningColumn), + "non-partitioning pass-through column for non-pass-through source of a table function"); + } + + @JsonProperty + public boolean isDeclaredAsPassThrough() + { + return declaredAsPassThrough; + } + + @JsonProperty + public List getColumns() + { + return columns; + } + } + + public static class PassThroughColumn + { + private final VariableReferenceExpression outputVariables; + private final boolean isPartitioningColumn; + + @JsonCreator + public PassThroughColumn( + @JsonProperty("outputVariables") VariableReferenceExpression outputVariables, + @JsonProperty("partitioningColumn") boolean isPartitioningColumn) + { + this.outputVariables = requireNonNull(outputVariables, "symbol is null"); + this.isPartitioningColumn = isPartitioningColumn; + } + + @JsonProperty + public VariableReferenceExpression getOutputVariables() + { + return outputVariables; + } + + @JsonProperty + public boolean isPartitioningColumn() + { + return isPartitioningColumn; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java new file mode 100644 index 0000000000000..851f776a2c90f --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java @@ -0,0 +1,286 @@ +/* + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.metadata.TableFunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorNode + extends InternalPlanNode +{ + private final String name; + + // symbols produced by the function + private final List properOutputs; + + // pre-planned sources + private final Optional source; + // TODO do we need the info of which source has row semantics, or is it already included in the joins / join 
distribution? + + // specifies whether the function should be pruned or executed when the input is empty + // pruneWhenEmpty is false if and only if all original input tables are KEEP WHEN EMPTY + private final boolean pruneWhenEmpty; + + // all source symbols to be produced on output, ordered as table argument specifications + private final List passThroughSpecifications; + + // symbols required from each source, ordered as table argument specifications + private final List> requiredVariables; + + // mapping from source symbol to helper "marker" symbol which indicates whether the source value is valid + // for processing or for pass-through. null value in the marker column indicates that the value at the same + // position in the source column should not be processed or passed-through. + // the mapping is only present if there are two or more sources. + // + // Example: + // Given two input tables T1(a,b) PARTITION BY a and T2(c, d) PARTITION BY c + // T1 partitions: T2 partitions: + // a | b c | d + // ---+--- ---+--- + // 1 | 10 5 | 50 + // 1 | 20 5 | 60 + // 1 | 30 6 | 90 + // 2 | 40 6 | 100 + // 2 | 50 6 | 110 + // + // TransformTableFunctionToTableFunctionProcessor creates a join that produces a cartesian product of partitions from each table, resulting in 4 partitions: + // + // Partition (a=1, c=5): + // a | b | marker_1 | c | d | marker_2 + // -----+------+----------+----+-----+---------- + // 1 | 10 | 1 | 5 | 50 | 1 (row 1 from both partitions) + // 1 | 20 | 2 | 5 | 60 | 2 (row 2 from both partitions) + // 1 | 30 | 3 | 5 | 50 | null (filler row for T2, real row 3 from T1) + // + // Partition (a=1, c=6): + // a | b | marker_1 | c | d | marker_2 + // -----+------+----------+----+-----+---------- + // 1 | 10 | 1 | 6 | 90 | 1 (row 1 from both partitions) + // 1 | 20 | 2 | 6 | 100 | 2 (row 2 from both partitions) + // 1 | 30 | 3 | 6 | 110 | 3 (row 3 from both partitions) + // + // Partition (a=2, c=5): + // a | b | marker_1 | c | d | marker_2 + // 
-----+------+----------+----+-----+---------- + // 2 | 40 | 1 | 5 | 50 | 1 (row 1 from both partitions) + // 2 | 50 | 2 | 5 | 60 | 2 (row 2 from both partitions) + // + // Partition (a=2, c=6): + // a | b | marker_1 | c | d | marker_2 + // -----+------+----------+----+-----+---------- + // 2 | 40 | 1 | 6 | 90 | 1 (row 1 from both partitions) + // 2 | 50 | 2 | 6 | 100 | 2 (row 2 from both partitions) + // 2 | 40 | null | 6 | 110 | 3 (filler row for T1, real row 3 from T2) + // + // markerVariables map: + // { + // VariableReferenceExpression(a) -> VariableReferenceExpression(marker_1), + // VariableReferenceExpression(b) -> VariableReferenceExpression(marker_1), + // VariableReferenceExpression(c) -> VariableReferenceExpression(marker_2), + // VariableReferenceExpression(d) -> VariableReferenceExpression(marker_2) + // } + // + // When marker_1 is null, columns a and b should not be processed or passed-through. + // When marker_2 is null, columns c and d should not be processed or passed-through. 
+ + private final Optional> markerVariables; + + private final Optional specification; + private final Set prePartitioned; + private final int preSorted; + private final Optional hashSymbol; + + private final TableFunctionHandle handle; + + @JsonCreator + public TableFunctionProcessorNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("name") String name, + @JsonProperty("properOutputs") List properOutputs, + @JsonProperty("source") Optional source, + @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, + @JsonProperty("passThroughSpecifications") List passThroughSpecifications, + @JsonProperty("requiredVariables") List> requiredVariables, + @JsonProperty("markerVariables") Optional> markerVariables, + @JsonProperty("specification") Optional specification, + @JsonProperty("prePartitioned") Set prePartitioned, + @JsonProperty("preSorted") int preSorted, + @JsonProperty("hashSymbol") Optional hashSymbol, + @JsonProperty("handle") TableFunctionHandle handle) + { + super(Optional.empty(), id, Optional.empty()); + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(requireNonNull(properOutputs, "properOutputs is null")); + this.source = requireNonNull(source, "source is null"); + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughSpecifications = ImmutableList.copyOf(requireNonNull(passThroughSpecifications, "passThroughSpecifications is null")); + this.requiredVariables = requireNonNull(requiredVariables, "requiredVariables is null").stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerVariables = requireNonNull(markerVariables, "markerVariables is null").map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + this.prePartitioned = ImmutableSet.copyOf(requireNonNull(prePartitioned, "prePartitioned is null")); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + checkArgument(partitionBy.containsAll(prePartitioned), "all pre-partitioned symbols must be contained in the partitioning list"); + this.preSorted = preSorted; + checkArgument( + specification + 
.flatMap(DataOrganizationSpecification::getOrderingScheme) + .map(OrderingScheme::getOrderBy) + .map(List::size) + .orElse(0) >= preSorted, + "the number of pre-sorted symbols cannot be greater than the number of all ordering symbols"); + checkArgument(preSorted == 0 || partitionBy.equals(prePartitioned), "to specify pre-sorted symbols, it is required that all partitioning symbols are pre-partitioned"); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + this.handle = requireNonNull(handle, "handle is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public List getProperOutputs() + { + return properOutputs; + } + + @JsonProperty + public Optional getSource() + { + return source; + } + + @JsonProperty + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + + @JsonProperty + public List getPassThroughSpecifications() + { + return passThroughSpecifications; + } + + @JsonProperty + public List> getRequiredVariables() + { + return requiredVariables; + } + + @JsonProperty + public Optional> getMarkerVariables() + { + return markerVariables; + } + + @JsonProperty + public Optional getSpecification() + { + return specification; + } + + @JsonProperty + public Set getPrePartitioned() + { + return prePartitioned; + } + + @JsonProperty + public int getPreSorted() + { + return preSorted; + } + + @JsonProperty + public Optional getHashSymbol() + { + return hashSymbol; + } + + @JsonProperty + public TableFunctionHandle getHandle() + { + return handle; + } + + @JsonProperty + @Override + public List getSources() + { + return source.map(ImmutableList::of).orElse(ImmutableList.of()); + } + + @Override + public List getOutputVariables() + { + ImmutableList.Builder variables = ImmutableList.builder(); + + variables.addAll(properOutputs); + + passThroughSpecifications.stream() + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + 
.map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .forEach(variables::add); + + return variables.build(); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return this; + } + + @Override + public PlanNode replaceChildren(List newSources) + { + Optional newSource = newSources.isEmpty() ? Optional.empty() : Optional.of(getOnlyElement(newSources)); + return new TableFunctionProcessorNode(getId(), name, properOutputs, newSource, pruneWhenEmpty, passThroughSpecifications, requiredVariables, markerVariables, specification, prePartitioned, preSorted, hashSymbol, handle); + } + + @Override + public R accept(InternalPlanVisitor visitor, C context) + { + return visitor.visitTableFunctionProcessor(this, context); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 93c049ba781ec..c35f28525bb09 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -35,6 +35,9 @@ import com.facebook.presto.spi.SourceLocation; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.ScalarArgument; import com.facebook.presto.spi.plan.AbstractJoinNode; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; @@ -96,13 +99,14 @@ import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.MergeProcessorNode; import com.facebook.presto.sql.planner.plan.MergeWriterNode; -import com.facebook.presto.sql.planner.plan.OffsetNode; 
import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -114,6 +118,7 @@ import com.google.common.base.Functions; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSortedMap; import com.google.common.collect.Iterables; @@ -122,11 +127,13 @@ import io.airlift.slice.Slice; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -139,6 +146,7 @@ import static com.facebook.presto.execution.StageInfo.getAllStages; import static com.facebook.presto.expressions.DynamicFilters.extractDynamicFilters; import static com.facebook.presto.metadata.CastType.CAST; +import static com.facebook.presto.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference; import static com.facebook.presto.sql.planner.SortExpressionExtractor.getSortExpressionContext; import static 
com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -152,10 +160,12 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.lang.String.format; import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.function.Function.identity; import static java.util.stream.Collectors.toList; public class PlanPrinter @@ -1356,13 +1366,6 @@ public Void visitLateralJoin(LateralJoinNode node, Void context) return processChildren(node, context); } - @Override - public Void visitOffset(OffsetNode node, Void context) - { - addNode(node, "Offset", format("[%s]", node.getCount())); - return processChildren(node, context); - } - @Override public Void visitTableFunction(TableFunctionNode node, Void context) { @@ -1371,9 +1374,177 @@ public Void visitTableFunction(TableFunctionNode node, Void context) "TableFunction", node.getName()); - checkArgument( - node.getSources().isEmpty() && node.getTableArgumentProperties().isEmpty(), - "Table or descriptor arguments are not yet supported in PlanPrinter"); + if (!node.getArguments().isEmpty()) { + nodeOutput.appendDetails("Arguments:"); + + Map tableArguments = node.getTableArgumentProperties().stream() + .collect(toImmutableMap(TableArgumentProperties::getArgumentName, identity())); + + node.getArguments().entrySet().stream() + .forEach(entry -> nodeOutput.appendDetailsLine(formatArgument(entry.getKey(), entry.getValue(), tableArguments))); + + if (!node.getCopartitioningLists().isEmpty()) { + nodeOutput.appendDetailsLine(node.getCopartitioningLists().stream() + .map(list -> list.stream().collect(Collectors.joining(", ", "(", ")"))) + .collect(Collectors.joining(", ", 
"Co-partition: [", "] "))); + } + } + + processChildren(node, context); + + return null; + } + + private String formatArgument(String argumentName, Argument argument, Map tableArguments) + { + if (argument instanceof ScalarArgument) { + ScalarArgument scalarArgument = (ScalarArgument) argument; + return formatScalarArgument(argumentName, scalarArgument); + } + if (argument instanceof DescriptorArgument) { + DescriptorArgument descriptorArgument = (DescriptorArgument) argument; + return formatDescriptorArgument(argumentName, descriptorArgument); + } + else { + TableArgumentProperties argumentProperties = tableArguments.get(argumentName); + return formatTableArgument(argumentName, argumentProperties); + } + } + + private String formatScalarArgument(String argumentName, ScalarArgument argument) + { + return format( + "%s => ScalarArgument{type=%s, value=%s}", + argumentName, + argument.getType().getDisplayName(), + argument.getValue()); + } + + private String formatDescriptorArgument(String argumentName, DescriptorArgument argument) + { + String descriptor; + if (argument.equals(NULL_DESCRIPTOR)) { + descriptor = "NULL"; + } + else { + descriptor = argument.getDescriptor().orElseThrow(() -> new IllegalStateException("Missing descriptor")).getFields().stream() + .map(field -> field.getName() + field.getType().map(type -> " " + type.getDisplayName()).orElse("")) + .collect(Collectors.joining(", ", "(", ")")); + } + return format("%s => DescriptorArgument{%s}", argumentName, descriptor); + } + + private String formatTableArgument(String argumentName, TableArgumentProperties argumentProperties) + { + List properties = new ArrayList<>(); + + if (argumentProperties.isRowSemantics()) { + properties.add("row semantics "); + } + argumentProperties.getSpecification().ifPresent(specification -> { + StringBuilder specificationBuilder = new StringBuilder(); + specificationBuilder + .append("partition by: [") + .append(Joiner.on(", ").join(specification.getPartitionBy())) + 
.append("]"); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + specificationBuilder + .append(", order by: ") + .append(formatOrderingScheme(orderingScheme)); + }); + properties.add(specificationBuilder.toString()); + }); + + properties.add("required columns: [" + + Joiner.on(", ").join(argumentProperties.getRequiredColumns()) + "]"); + + if (argumentProperties.isPruneWhenEmpty()) { + properties.add("prune when empty"); + } + + if (argumentProperties.getPassThroughSpecification().isDeclaredAsPassThrough()) { + properties.add("pass through columns"); + } + + return format("%s => TableArgument{%s}", argumentName, Joiner.on(", ").join(properties)); + } + + private String formatOrderingScheme(OrderingScheme orderingScheme) + { + return formatCollection(orderingScheme.getOrderByVariables(), variable -> variable + " " + orderingScheme.getOrdering(variable)); + } + + private String formatOrderingScheme(OrderingScheme orderingScheme, int preSortedOrderPrefix) + { + List orderBy = Stream.concat( + orderingScheme.getOrderByVariables().stream() + .limit(preSortedOrderPrefix) + .map(variable -> "<" + variable + " " + orderingScheme.getOrdering(variable) + ">"), + orderingScheme.getOrderByVariables().stream() + .skip(preSortedOrderPrefix) + .map(variable -> variable + " " + orderingScheme.getOrdering(variable))) + .collect(toImmutableList()); + return formatCollection(orderBy, Objects::toString); + } + + public String formatCollection(Collection collection, Function formatter) + { + return collection.stream() + .map(formatter) + .collect(Collectors.joining(", ", "[", "]")); + } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Void context) + { + ImmutableMap.Builder descriptor = ImmutableMap.builder(); + + descriptor.put("name", node.getName()); + + descriptor.put("properOutputs", format("[%s]", Joiner.on(", ").join(node.getProperOutputs()))); + + String specs = node.getPassThroughSpecifications().stream() + .map(spec 
-> spec.getColumns().stream() + .map(col -> col.getOutputVariables().toString()) + .collect(Collectors.joining(", ", "[", "]"))) + .collect(Collectors.joining(", ")); + descriptor.put("passThroughSymbols", format("[%s]", specs)); + + String requiredSymbols = node.getRequiredVariables().stream() + .map(vars -> vars.stream() + .map(VariableReferenceExpression::toString) + .collect(Collectors.joining(", ", "[", "]"))) + .collect(Collectors.joining(", ")); + descriptor.put("requiredSymbols", format("[%s]", requiredSymbols)); + + node.getSpecification().ifPresent(specification -> { + if (!specification.getPartitionBy().isEmpty()) { + List prePartitioned = specification.getPartitionBy().stream() + .filter(node.getPrePartitioned()::contains) + .collect(toImmutableList()); + + List notPrePartitioned = specification.getPartitionBy().stream() + .filter(column -> !node.getPrePartitioned().contains(column)) + .collect(toImmutableList()); + + StringBuilder builder = new StringBuilder(); + if (!prePartitioned.isEmpty()) { + builder.append(prePartitioned.stream() + .map(VariableReferenceExpression::toString) + .collect(Collectors.joining(", ", "<", ">"))); + if (!notPrePartitioned.isEmpty()) { + builder.append(", "); + } + } + if (!notPrePartitioned.isEmpty()) { + builder.append(Joiner.on(", ").join(notPrePartitioned)); + } + descriptor.put("partitionBy", format("[%s]", builder)); + } + specification.getOrderingScheme().ifPresent(orderingScheme -> descriptor.put("orderBy", formatOrderingScheme(orderingScheme, node.getPreSorted()))); + }); + + addNode(node, "TableFunctionProcessorNode" + descriptor.build()); return processChildren(node, context); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index bf93c5d23388f..1a88f259c882e 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -74,6 +74,8 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -88,6 +90,7 @@ import static com.facebook.presto.spi.plan.JoinNode.checkLeftOutputVariablesBeforeRight; import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractAggregationUniqueVariables; import static com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer.IndexKeyTracer; +import static com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -124,6 +127,117 @@ public Void visitPlan(PlanNode node, Set boundVaria @Override public Void visitTableFunction(TableFunctionNode node, Set boundSymbols) { + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode source = node.getSources().get(i); + source.accept(this, boundSymbols); + Set inputs = createInputs(source, boundSymbols); + TableFunctionNode.TableArgumentProperties argumentProperties = node.getTableArgumentProperties().get(i); + + checkDependencies( + inputs, + argumentProperties.getRequiredColumns(), + "Invalid node. 
Required input symbols from source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + argumentProperties.getRequiredColumns(), + source.getOutputVariables()); + argumentProperties.getSpecification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + specification.getPartitionBy(), + source.getOutputVariables()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderByVariables(), + "Invalid node. Order by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + orderingScheme.getOrderByVariables(), + source.getOutputVariables()); + }); + }); + Set passThroughVariable = argumentProperties.getPassThroughSpecification().getColumns().stream() + .map(PassThroughColumn::getOutputVariables) + .collect(toImmutableSet()); + checkDependencies( + inputs, + passThroughVariable, + "Invalid node. Pass-through symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + passThroughVariable, + source.getOutputVariables()); + } + return null; + } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Set boundVariables) + { + if (!node.getSource().isPresent()) { + return null; + } + + PlanNode source = node.getSource().get(); + source.accept(this, boundVariables); + + Set inputs = createInputs(source, boundVariables); + + Set passThroughSymbols = node.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(PassThroughColumn::getOutputVariables) + .collect(toImmutableSet()); + checkDependencies( + inputs, + passThroughSymbols, + "Invalid node. 
Pass-through symbols (%s) not in source plan output (%s)", + passThroughSymbols, + source.getOutputVariables()); + + Set requiredSymbols = node.getRequiredVariables().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + checkDependencies( + inputs, + requiredSymbols, + "Invalid node. Required symbols (%s) not in source plan output (%s)", + requiredSymbols, + source.getOutputVariables()); + + node.getMarkerVariables().ifPresent(mapping -> { + checkDependencies( + inputs, + mapping.keySet(), + "Invalid node. Source symbols (%s) not in source plan output (%s)", + mapping.keySet(), + source.getOutputVariables()); + checkDependencies( + inputs, + mapping.values(), + "Invalid node. Source marker symbols (%s) not in source plan output (%s)", + mapping.values(), + source.getOutputVariables()); + }); + + node.getSpecification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols (%s) not in source plan output (%s)", + specification.getPartitionBy(), + source.getOutputVariables()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderByVariables(), + "Invalid node. 
Order by symbols (%s) not in source plan output (%s)", + orderingScheme.getOrderByVariables(), + source.getOutputVariables()); + }); + }); + return null; } diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index ed686664b10d2..bbd8dcfa29c0b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -125,6 +125,7 @@ import com.facebook.presto.operator.TableCommitContext; import com.facebook.presto.operator.TaskContext; import com.facebook.presto.operator.index.IndexJoinLookupStats; +import com.facebook.presto.operator.table.ExcludeColumns; import com.facebook.presto.server.NodeStatusNotificationManager; import com.facebook.presto.server.PluginManager; import com.facebook.presto.server.PluginManagerConfig; @@ -529,7 +530,8 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new ColumnPropertiesSystemTable(transactionManager, metadata), new AnalyzePropertiesSystemTable(transactionManager, metadata), new TransactionsSystemTable(metadata.getFunctionAndTypeManager(), transactionManager)), - ImmutableSet.of()); + ImmutableSet.of(), + ImmutableSet.of(new ExcludeColumns.ExcludeColumnsFunction())); BuiltInQueryAnalyzer queryAnalyzer = new BuiltInQueryAnalyzer(metadata, sqlParser, accessControl, Optional.empty(), metadataExtractorExecutor); BuiltInAnalyzerProvider analyzerProvider = new BuiltInAnalyzerProvider(queryAnalyzer); @@ -781,7 +783,8 @@ public void installPlugin(Plugin plugin) @Override public void createCatalog(String catalogName, String connectorName, Map properties) { - throw new UnsupportedOperationException(); + nodeManager.addCurrentNodeConnector(new ConnectorId(catalogName)); + connectorManager.createConnection(catalogName, connectorName, properties); } @Override diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index da4a7fc2cd4e0..5368f400f4646 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -70,6 +70,8 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -136,6 +138,8 @@ private enum NodeType EXPLAIN_ANALYZE, UPDATE, MERGE, + TABLE_FUNCTION, + TABLE_FUNCTION_PROCESSOR } private static final Map NODE_COLORS = immutableEnumMap(ImmutableMap.builder() @@ -168,6 +172,8 @@ private enum NodeType .put(NodeType.EXPLAIN_ANALYZE, "cadetblue1") .put(NodeType.UPDATE, "blue") .put(NodeType.MERGE, "lightblue") + .put(NodeType.TABLE_FUNCTION, "mediumorchid3") + .put(NodeType.TABLE_FUNCTION_PROCESSOR, "steelblue3") .build()); static { @@ -409,6 +415,25 @@ public Void visitWindow(WindowNode node, Void context) return node.getSource().accept(this, context); } + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Void context) + { + printNode(node, "Table Function Processor", NODE_COLORS.get(NodeType.TABLE_FUNCTION_PROCESSOR)); + if (node.getSource().isPresent()) { + node.getSource().get().accept(this, context); + } + return null; + } + + @Override + public Void visitTableFunction(TableFunctionNode node, Void context) + { + printNode(node, "Table Function Node", NODE_COLORS.get(NodeType.TABLE_FUNCTION)); + 
// forEach is eager; a stream().map() with no terminal operation is lazy and would never visit the sources + node.getSources().forEach(source -> source.accept(this, context)); + return null; + } + @Override public Void visitRowNumber(RowNumberNode node, Void context) { diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorFactory.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorFactory.java index 2674c87d28cc6..f64d51a0c10f3 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorFactory.java +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorFactory.java @@ -47,8 +47,11 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.schedule.NodeSelectionStrategy; import com.facebook.presto.spi.statistics.TableStatistics; import com.facebook.presto.spi.transaction.IsolationLevel; @@ -57,6 +60,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import java.util.Collections; import java.util.List; @@ -69,6 +73,7 @@ import java.util.stream.IntStream; import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.connector.tvf.TestTVFConnectorFactory.TestTVFConnector.TestTVFConnectorSplit.TEST_TVF_CONNECTOR_SPLIT; import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; 
import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -84,6 +89,10 @@ public class TestTVFConnectorFactory private final Supplier getTableStatistics; private final ApplyTableFunction applyTableFunction; private final Set tableFunctions; + private final Function tableFunctionProcessorProvider; + private final TestTvfTableFunctionHandleResolver tableFunctionHandleResolver; + private final TestTvfTableFunctionSplitResolver tableFunctionSplitResolver; + private final Function tableFunctionSplitsSources; private TestTVFConnectorFactory( Function> listSchemaNames, @@ -92,7 +101,11 @@ private TestTVFConnectorFactory( BiFunction> getColumnHandles, Supplier getTableStatistics, ApplyTableFunction applyTableFunction, - Set tableFunctions) + Set tableFunctions, + Function getTableFunctionProcessorProvider, + TestTvfTableFunctionHandleResolver tableFunctionHandleResolver, + TestTvfTableFunctionSplitResolver tableFunctionSplitResolver, + Function tableFunctionSplitsSources) { this.listSchemaNames = requireNonNull(listSchemaNames, "listSchemaNames is null"); this.listTables = requireNonNull(listTables, "listTables is null"); @@ -101,6 +114,10 @@ private TestTVFConnectorFactory( this.getTableStatistics = requireNonNull(getTableStatistics, "getTableStatistics is null"); this.applyTableFunction = requireNonNull(applyTableFunction, "applyTableFunction is null"); this.tableFunctions = requireNonNull(tableFunctions, "tableFunctions is null"); + this.tableFunctionProcessorProvider = requireNonNull(getTableFunctionProcessorProvider, "tableFunctionProcessorProvider is null"); + this.tableFunctionHandleResolver = requireNonNull(tableFunctionHandleResolver, "tableFunctionHandleResolver is null"); + this.tableFunctionSplitResolver = requireNonNull(tableFunctionSplitResolver, "tableFunctionSplitResolver is null"); + this.tableFunctionSplitsSources = requireNonNull(tableFunctionSplitsSources, 
"tableFunctionSplitsSources is null"); } @Override @@ -115,10 +132,22 @@ public ConnectorHandleResolver getHandleResolver() return new TestTVFHandleResolver(); } + @Override + public Optional getTableFunctionHandleResolver() + { + return Optional.of(tableFunctionHandleResolver); + } + + @Override + public Optional getTableFunctionSplitResolver() + { + return Optional.of(tableFunctionSplitResolver); + } + @Override public Connector create(String catalogName, Map config, ConnectorContext context) { - return new TestTVFConnector(context, listSchemaNames, listTables, getViews, getColumnHandles, getTableStatistics, applyTableFunction, tableFunctions); + return new TestTVFConnector(context, listSchemaNames, listTables, getViews, getColumnHandles, getTableStatistics, applyTableFunction, tableFunctions, tableFunctionProcessorProvider, tableFunctionSplitsSources); } public static Builder builder() @@ -154,7 +183,9 @@ public static class TestTVFConnector private final BiFunction> getColumnHandles; private final Supplier getTableStatistics; private final ApplyTableFunction applyTableFunction; + private final Function tableFunctionProcessorProvider; private final Set tableFunctions; + private final Function tableFunctionSplitsSources; public TestTVFConnector( ConnectorContext context, @@ -164,7 +195,9 @@ public TestTVFConnector( BiFunction> getColumnHandles, Supplier getTableStatistics, ApplyTableFunction applyTableFunction, - Set tableFunctions) + Set tableFunctions, + Function getTableFunctionProcessorProvider, + Function tableFunctionSplitsSources) { this.context = requireNonNull(context, "context is null"); this.listSchemaNames = requireNonNull(listSchemaNames, "listSchemaNames is null"); @@ -174,6 +207,8 @@ public TestTVFConnector( this.getTableStatistics = requireNonNull(getTableStatistics, "getTableStatistics is null"); this.applyTableFunction = requireNonNull(applyTableFunction, "applyTableFunction is null"); this.tableFunctions = requireNonNull(tableFunctions, 
"tableFunctions is null"); + this.tableFunctionProcessorProvider = requireNonNull(getTableFunctionProcessorProvider, "tableFunctionProcessorProvider is null"); + this.tableFunctionSplitsSources = requireNonNull(tableFunctionSplitsSources, "tableFunctionSplitsSources is null"); } @Override @@ -220,7 +255,15 @@ public ConnectorSplitManager getSplitManager() @Override public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) { - return new FixedSplitSource(Collections.singleton(TestTVFConnectorSplit.TEST_TVF_CONNECTOR_SPLIT)); + return new FixedSplitSource(Collections.singleton(TEST_TVF_CONNECTOR_SPLIT)); + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorTableFunctionHandle functionHandle) + { + ConnectorSplitSource splits = tableFunctionSplitsSources.apply(functionHandle); + return requireNonNull(splits, "missing ConnectorSplitSource for table function handle " + + functionHandle.getClass().getSimpleName()); } }; } @@ -243,6 +286,12 @@ public Set getTableFunctions() return tableFunctions; } + @Override + public Function getTableFunctionProcessorProvider() + { + return tableFunctionProcessorProvider; + } + private class TestTVFConnectorMetadata implements ConnectorMetadata { @@ -382,6 +431,40 @@ private Map getColumnIndexes(SchemaTableName tableName) } } + public static class TestTvfTableFunctionHandleResolver + implements TableFunctionHandleResolver + { + Set> handles = Sets.newHashSet(); + + @Override + public Set> getTableFunctionHandleClasses() + { + return handles; + } + + public void addTableFunctionHandle(Class tableFunctionHandleClass) + { + handles.add(tableFunctionHandleClass); + } + } + + public static class TestTvfTableFunctionSplitResolver + implements TableFunctionSplitResolver + { + Set> handles = Sets.newHashSet(); + + @Override + public 
Set> getTableFunctionSplitClasses() + { + return handles; + } + + public void addSplitClass(Class splitClass) + { + handles.add(splitClass); + } + } + public static final class Builder { private Function> listSchemaNames = (session) -> ImmutableList.of(); @@ -396,6 +479,10 @@ public static final class Builder private Supplier getTableStatistics = TableStatistics::empty; private ApplyTableFunction applyTableFunction = (session, handle) -> Optional.empty(); private Set tableFunctions = ImmutableSet.of(); + private Function tableFunctionProcessorProvider = handle -> null; + private final TestTvfTableFunctionHandleResolver tableFunctionHandleResolver = new TestTvfTableFunctionHandleResolver(); + private TestTvfTableFunctionSplitResolver tableFunctionSplitResolver = new TestTvfTableFunctionSplitResolver(); + private Function tableFunctionSplitsSources = handle -> null; public Builder withListSchemaNames(Function> listSchemaNames) { @@ -439,14 +526,38 @@ public Builder withTableFunctions(Iterable tableFunction return this; } + public Builder withTableFunctionProcessorProvider(Function tableFunctionProcessorProvider) + { + this.tableFunctionProcessorProvider = tableFunctionProcessorProvider; + return this; + } + + public Builder withTableFunctionResolver(Class tableFunctionHandleclass) + { + this.tableFunctionHandleResolver.addTableFunctionHandle(tableFunctionHandleclass); + return this; + } + + public Builder withTableFunctionSplitResolver(Class splitClass) + { + this.tableFunctionSplitResolver.addSplitClass(splitClass); + return this; + } + public TestTVFConnectorFactory build() { - return new TestTVFConnectorFactory(listSchemaNames, listTables, getViews, getColumnHandles, getTableStatistics, applyTableFunction, tableFunctions); + return new TestTVFConnectorFactory(listSchemaNames, listTables, getViews, getColumnHandles, getTableStatistics, applyTableFunction, tableFunctions, tableFunctionProcessorProvider, tableFunctionHandleResolver, tableFunctionSplitResolver, 
tableFunctionSplitsSources); } private static T notSupported() { throw new UnsupportedOperationException(); } + + public Builder withTableFunctionSplitSource(Function sourceProvider) + { + tableFunctionSplitsSources = requireNonNull(sourceProvider, "sourceProvider is null"); + return this; + } } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java index 96373d826b50a..21e739cb21feb 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java @@ -13,8 +13,17 @@ */ package com.facebook.presto.connector.tvf; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.RowType; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.function.SchemaFunctionName; @@ -26,18 +35,30 @@ import com.facebook.presto.spi.function.table.ReturnTypeSpecification; import com.facebook.presto.spi.function.table.ScalarArgument; import com.facebook.presto.spi.function.table.ScalarArgumentSpecification; +import com.facebook.presto.spi.function.table.TableArgument; import com.facebook.presto.spi.function.table.TableArgumentSpecification; import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import 
com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.facebook.presto.spi.function.table.TableFunctionSplitProcessor; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; +import org.openjdk.jol.info.ClassLayout; import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; +import java.util.stream.IntStream; +import static com.facebook.presto.common.Utils.checkArgument; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.IntegerType.INTEGER; @@ -45,7 +66,15 @@ import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.DescribedTable; import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.OnlyPassThrough.ONLY_PASS_THROUGH; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.produced; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInput; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInputAndProduced; +import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; +import static 
com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; import static io.airlift.slice.Slices.utf8Slice; +import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class TestingTableFunctions @@ -67,18 +96,17 @@ public class TestingTableFunctions public static class TestConnectorTableFunction extends AbstractConnectorTableFunction { - private static final String TEST_FUNCTION = "test_function"; - + private static final String FUNCTION_NAME = "test_function"; public TestConnectorTableFunction() { - super(SCHEMA_NAME, TEST_FUNCTION, ImmutableList.of(), ReturnTypeSpecification.GenericTable.GENERIC_TABLE); + super(SCHEMA_NAME, FUNCTION_NAME, ImmutableList.of(), ReturnTypeSpecification.GenericTable.GENERIC_TABLE); } @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { return TableFunctionAnalysis.builder() - .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, TEST_FUNCTION))) + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field("c1", Optional.of(BOOLEAN))))) .build(); } @@ -87,11 +115,10 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class TestConnectorTableFunction2 extends AbstractConnectorTableFunction { - private static final String TEST_FUNCTION_2 = "test_function2"; - + private static final String FUNCTION_NAME = "test_function2"; public TestConnectorTableFunction2() { - super(SCHEMA_NAME, TEST_FUNCTION_2, ImmutableList.of(), ONLY_PASS_THROUGH); + super(SCHEMA_NAME, FUNCTION_NAME, ImmutableList.of(), ONLY_PASS_THROUGH); } @Override @@ -104,11 +131,10 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class NullArgumentsTableFunction extends 
AbstractConnectorTableFunction { - private static final String NULL_ARGUMENTS_FUNCTION = "null_arguments_function"; - + private static final String FUNCTION_NAME = "null_arguments_function"; public NullArgumentsTableFunction() { - super(SCHEMA_NAME, NULL_ARGUMENTS_FUNCTION, null, ONLY_PASS_THROUGH); + super(SCHEMA_NAME, FUNCTION_NAME, null, ONLY_PASS_THROUGH); } @Override @@ -121,12 +147,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class DuplicateArgumentsTableFunction extends AbstractConnectorTableFunction { - private static final String DUPLICATE_ARGUMENTS_FUNCTION = "duplicate_arguments_function"; + private static final String FUNCTION_NAME = "duplicate_arguments_function"; public DuplicateArgumentsTableFunction() { super( SCHEMA_NAME, - DUPLICATE_ARGUMENTS_FUNCTION, + FUNCTION_NAME, ImmutableList.of( ScalarArgumentSpecification.builder().name("a").type(INTEGER).build(), ScalarArgumentSpecification.builder().name("a").type(INTEGER).build()), @@ -143,12 +169,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class MultipleRSTableFunction extends AbstractConnectorTableFunction { - private static final String MULTIPLE_SOURCES_FUNCTION = "multiple_sources_function"; + private static final String FUNCTION_NAME = "multiple_sources_function"; public MultipleRSTableFunction() { super( SCHEMA_NAME, - MULTIPLE_SOURCES_FUNCTION, + FUNCTION_NAME, ImmutableList.of(TableArgumentSpecification.builder().name("t").rowSemantics().build(), TableArgumentSpecification.builder().name("t2").rowSemantics().build()), ONLY_PASS_THROUGH); @@ -172,7 +198,6 @@ public static class SimpleTableFunction { private static final String FUNCTION_NAME = "simple_table_function"; private static final String TABLE_NAME = "simple_table"; - public SimpleTableFunction() { super( @@ -227,11 +252,12 @@ public TestTVFConnectorTableHandle getTableHandle() public static class TwoScalarArgumentsFunction 
extends AbstractConnectorTableFunction { + private static final String FUNCTION_NAME = "two_scalar_arguments_function"; public TwoScalarArgumentsFunction() { super( SCHEMA_NAME, - "two_arguments_function", + FUNCTION_NAME, ImmutableList.of( ScalarArgumentSpecification.builder() .name("TEXT") @@ -256,7 +282,6 @@ public static class TableArgumentFunction extends AbstractConnectorTableFunction { public static final String FUNCTION_NAME = "table_argument_function"; - public TableArgumentFunction() { super( @@ -284,11 +309,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class DescriptorArgumentFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "descriptor_argument_function"; public DescriptorArgumentFunction() { super( SCHEMA_NAME, - "descriptor_argument_function", + FUNCTION_NAME, ImmutableList.of( DescriptorArgumentSpecification.builder() .name("SCHEMA") @@ -327,11 +353,16 @@ public TestTVFConnectorTableHandle getTableHandle() public static class TestingTableFunctionHandle implements ConnectorTableFunctionHandle { + private final TestTVFConnectorTableHandle tableHandle; private final SchemaFunctionName schemaFunctionName; @JsonCreator public TestingTableFunctionHandle(@JsonProperty("schemaFunctionName") SchemaFunctionName schemaFunctionName) { + this.tableHandle = new TestTVFConnectorTableHandle( + new SchemaTableName(SCHEMA_NAME, TABLE_NAME), + Optional.of(ImmutableList.of(new TestTVFConnectorColumnHandle(COLUMN_NAME, BOOLEAN))), + TupleDomain.all()); this.schemaFunctionName = requireNonNull(schemaFunctionName, "schemaFunctionName is null"); } @@ -340,16 +371,22 @@ public SchemaFunctionName getSchemaFunctionName() { return schemaFunctionName; } + + public TestTVFConnectorTableHandle getTableHandle() + { + return tableHandle; + } } public static class TableArgumentRowSemanticsFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = 
"table_argument_row_semantics_function"; public TableArgumentRowSemanticsFunction() { super( SCHEMA_NAME, - "table_argument_row_semantics_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") @@ -372,17 +409,20 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class TwoTableArgumentsFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "two_table_arguments_function"; public TwoTableArgumentsFunction() { super( SCHEMA_NAME, - "two_table_arguments_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT1") + .keepWhenEmpty() .build(), TableArgumentSpecification.builder() .name("INPUT2") + .keepWhenEmpty() .build()), GENERIC_TABLE); } @@ -402,11 +442,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class OnlyPassThroughFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "only_pass_through_function"; public OnlyPassThroughFunction() { super( SCHEMA_NAME, - "only_pass_through_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") @@ -425,11 +466,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class MonomorphicStaticReturnTypeFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "monomorphic_static_return_type_function"; public MonomorphicStaticReturnTypeFunction() { super( SCHEMA_NAME, - "monomorphic_static_return_type_function", + FUNCTION_NAME, ImmutableList.of(), new DescribedTable(Descriptor.descriptor( ImmutableList.of("a", "b"), @@ -448,11 +490,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class PolymorphicStaticReturnTypeFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME 
= "polymorphic_static_return_type_function"; public PolymorphicStaticReturnTypeFunction() { super( SCHEMA_NAME, - "polymorphic_static_return_type_function", + FUNCTION_NAME, ImmutableList.of(TableArgumentSpecification.builder() .name("INPUT") .build()), @@ -471,14 +514,16 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class PassThroughFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "pass_through_function"; public PassThroughFunction() { super( SCHEMA_NAME, - "pass_through_function", + FUNCTION_NAME, ImmutableList.of(TableArgumentSpecification.builder() .name("INPUT") .passThroughColumns() + .keepWhenEmpty() .build()), new DescribedTable(Descriptor.descriptor( ImmutableList.of("x"), @@ -495,14 +540,16 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class RequiredColumnsFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "required_columns_function"; public RequiredColumnsFunction() { super( SCHEMA_NAME, - "required_columns_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") + .keepWhenEmpty() .build()), GENERIC_TABLE); } @@ -517,4 +564,915 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact .build(); } } + + public static class DifferentArgumentTypesFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "different_arguments_function"; + public DifferentArgumentTypesFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT_1") + .passThroughColumns() + .keepWhenEmpty() + .build(), + DescriptorArgumentSpecification.builder() + .name("LAYOUT") + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .rowSemantics() + .passThroughColumns() + .build(), + 
ScalarArgumentSpecification.builder() + .name("ID") + .type(BIGINT) + .build(), + TableArgumentSpecification.builder() + .name("INPUT_3") + .pruneWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT_1", ImmutableList.of(0)) + .requiredColumns("INPUT_2", ImmutableList.of(0)) + .requiredColumns("INPUT_3", ImmutableList.of(0)) + .build(); + } + } + + // for testing execution by operator + + public static class IdentityFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "identity_function"; + public IdentityFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + List inputColumns = ((TableArgument) arguments.get("INPUT")).getRowType().getFields(); + Descriptor returnedType = new Descriptor(inputColumns.stream() + .map(field -> new Descriptor.Field(field.getName().orElse("anonymous_column"), Optional.of(field.getType()))) + .collect(toImmutableList())); + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .returnedType(returnedType) + .requiredColumns("INPUT", IntStream.range(0, inputColumns.size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class IdentityFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public 
TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return input -> { + if (input == null) { + return FINISHED; + } + Optional inputPage = getOnlyElement(input); + return inputPage.map(TableFunctionProcessorState.Processed::usedInputAndProduced).orElseThrow(NoSuchElementException::new); + }; + } + } + } + + public static class IdentityPassThroughFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "identity_pass_through_function"; + public IdentityPassThroughFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .passThroughColumns() + .keepWhenEmpty() + .build()), + ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", ImmutableList.of(0)) // per spec, function must require at least one column + .build(); + } + + public static class IdentityPassThroughFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new IdentityPassThroughFunctionDataProcessor(); + } + } + + public static class IdentityPassThroughFunctionDataProcessor + implements TableFunctionDataProcessor + { + private long processedPositions; // stateful + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + return FINISHED; + } + + Page page = getOnlyElement(input).orElseThrow(NoSuchElementException::new); + BlockBuilder builder = BIGINT.createBlockBuilder(null, page.getPositionCount()); + for (long index = processedPositions; index < processedPositions + page.getPositionCount(); index++) { + // TODO 
check for long overflow + builder.writeLong(index); + } + processedPositions = processedPositions + page.getPositionCount(); + return usedInputAndProduced(new Page(builder.build())); + } + } + } + + public static class RepeatFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "repeat"; + public RepeatFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .passThroughColumns() + .keepWhenEmpty() + .build(), + ScalarArgumentSpecification.builder() + .name("N") + .type(INTEGER) + .defaultValue(2L) + .build()), + ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + ScalarArgument count = (ScalarArgument) arguments.get("N"); + requireNonNull(count.getValue(), "count value for function repeat() is null"); + checkArgument((long) count.getValue() > 0, "count value for function repeat() must be positive"); + + return TableFunctionAnalysis.builder() + .handle(new RepeatFunctionHandle((long) count.getValue())) + .requiredColumns("INPUT", ImmutableList.of(0)) // per spec, function must require at least one column + .build(); + } + + public static class RepeatFunctionHandle + implements ConnectorTableFunctionHandle + { + private final long count; + + @JsonCreator + public RepeatFunctionHandle(@JsonProperty("count") long count) + { + this.count = count; + } + + @JsonProperty + public long getCount() + { + return count; + } + } + + public static class RepeatFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new RepeatFunctionDataProcessor(((RepeatFunctionHandle) handle).getCount()); + } + } + + public static class RepeatFunctionDataProcessor + implements TableFunctionDataProcessor + { + private final long count; 
+ + // stateful + private long processedPositions; + private long processedRounds; + private Block indexes; + boolean usedData; + + public RepeatFunctionDataProcessor(long count) + { + this.count = count; + } + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + if (processedRounds < count && indexes != null) { + processedRounds++; + return produced(new Page(indexes)); + } + return FINISHED; + } + + Page page = getOnlyElement(input).orElseThrow(NoSuchElementException::new); + if (processedRounds == 0) { + BlockBuilder builder = BIGINT.createBlockBuilder(null, page.getPositionCount()); + for (long index = processedPositions; index < processedPositions + page.getPositionCount(); index++) { + // TODO check for long overflow + builder.writeLong(index); + } + processedPositions = processedPositions + page.getPositionCount(); + indexes = builder.build(); + usedData = true; + } + else { + usedData = false; + } + processedRounds++; + + Page result = new Page(indexes); + + if (processedRounds == count) { + processedRounds = 0; + indexes = null; + } + + if (usedData) { + return usedInputAndProduced(result); + } + return produced(result); + } + } + } + + public static class EmptyOutputFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "empty_output"; + public EmptyOutputFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) 
arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class EmptyOutputProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new EmptyOutputDataProcessor(); + } + } + + // returns an empty Page (one column, zero rows) for each Page of input + private static class EmptyOutputDataProcessor + implements TableFunctionDataProcessor + { + private static final Page EMPTY_PAGE = new Page(BOOLEAN.createBlockBuilder(null, 0).build()); + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(EMPTY_PAGE); + } + } + } + + public static class EmptyOutputWithPassThroughFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "empty_output_with_pass_through"; + public EmptyOutputWithPassThroughFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .passThroughColumns() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class EmptyOutputWithPassThroughProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + 
return new EmptyOutputWithPassThroughDataProcessor(); + } + } + + // returns an empty Page (one proper column and pass-through, zero rows) for each Page of input + private static class EmptyOutputWithPassThroughDataProcessor + implements TableFunctionDataProcessor + { + // one proper channel, and one pass-through index channel + private static final Page EMPTY_PAGE = new Page( + BOOLEAN.createBlockBuilder(null, 0).build(), + BIGINT.createBlockBuilder(null, 0).build()); + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(EMPTY_PAGE); + } + } + } + + public static class TestInputsFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "test_inputs_function"; + public TestInputsFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .rowSemantics() + .name("INPUT_1") + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT_3") + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT_4") + .keepWhenEmpty() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("boolean_result", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT_1", IntStream.range(0, ((TableArgument) arguments.get("INPUT_1")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .requiredColumns("INPUT_2", IntStream.range(0, ((TableArgument) arguments.get("INPUT_2")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + 
.requiredColumns("INPUT_3", IntStream.range(0, ((TableArgument) arguments.get("INPUT_3")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .requiredColumns("INPUT_4", IntStream.range(0, ((TableArgument) arguments.get("INPUT_4")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class TestInputsFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + BlockBuilder resultBuilder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(resultBuilder, true); + + Page result = new Page(resultBuilder.build()); + + return input -> { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(result); + }; + } + } + } + + public static class PassThroughInputFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "pass_through"; + public PassThroughInputFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT_1") + .passThroughColumns() + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .passThroughColumns() + .keepWhenEmpty() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of( + new Descriptor.Field("input_1_present", Optional.of(BOOLEAN)), + new Descriptor.Field("input_2_present", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT_1", ImmutableList.of(0)) + .requiredColumns("INPUT_2", ImmutableList.of(0)) + .build(); + } + + public static class PassThroughInputProcessorProvider + implements 
TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new PassThroughInputDataProcessor(); + } + } + + private static class PassThroughInputDataProcessor + implements TableFunctionDataProcessor + { + private boolean input1Present; + private boolean input2Present; + private int input1EndIndex; + private int input2EndIndex; + private boolean finished; + + @Override + public TableFunctionProcessorState process(List> input) + { + if (finished) { + return FINISHED; + } + if (input == null) { + finished = true; + + // proper column input_1_present + BlockBuilder input1Builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(input1Builder, input1Present); + + // proper column input_2_present + BlockBuilder input2Builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(input2Builder, input2Present); + + // pass-through index for input_1 + BlockBuilder input1PassThroughBuilder = BIGINT.createBlockBuilder(null, 1); + if (input1Present) { + input1PassThroughBuilder.writeLong(input1EndIndex - 1); + } + else { + input1PassThroughBuilder.appendNull(); + } + + // pass-through index for input_2 + BlockBuilder input2PassThroughBuilder = BIGINT.createBlockBuilder(null, 1); + if (input2Present) { + input2PassThroughBuilder.writeLong(input2EndIndex - 1); + } + else { + input2PassThroughBuilder.appendNull(); + } + + return produced(new Page(input1Builder.build(), input2Builder.build(), input1PassThroughBuilder.build(), input2PassThroughBuilder.build())); + } + input.get(0).ifPresent(page -> { + input1Present = true; + input1EndIndex += page.getPositionCount(); + }); + input.get(1).ifPresent(page -> { + input2Present = true; + input2EndIndex += page.getPositionCount(); + }); + return usedInput(); + } + } + } + + public static class TestInputFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "test_input"; + public 
TestInputFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("got_input", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class TestInputProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new TestInputDataProcessor(); + } + } + + private static class TestInputDataProcessor + implements TableFunctionDataProcessor + { + private boolean processorGotInput; + private boolean finished; + + @Override + public TableFunctionProcessorState process(List> input) + { + if (finished) { + return FINISHED; + } + if (input == null) { + finished = true; + BlockBuilder builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(builder, processorGotInput); + return produced(new Page(builder.build())); + } + processorGotInput = true; + return usedInput(); + } + } + } + + public static class TestSingleInputRowSemanticsFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "test_single_input_function"; + public TestSingleInputRowSemanticsFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .rowSemantics() + .name("INPUT") + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("boolean_result", 
Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class TestSingleInputFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + BlockBuilder builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(builder, true); + Page result = new Page(builder.build()); + + return input -> { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(result); + }; + } + } + } + + public static class ConstantFunction + extends AbstractConnectorTableFunction + { + static final String FUNCTION_NAME = "constant"; + public ConstantFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + ScalarArgumentSpecification.builder() + .name("VALUE") + .type(INTEGER) + .build(), + ScalarArgumentSpecification.builder() + .name("N") + .type(INTEGER) + .defaultValue(1L) + .build()), + new DescribedTable(Descriptor.descriptor( + ImmutableList.of("constant_column"), + ImmutableList.of(INTEGER)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + ScalarArgument count = (ScalarArgument) arguments.get("N"); + requireNonNull(count.getValue(), "count value for function repeat() is null"); + checkArgument((long) count.getValue() > 0, "count value for function repeat() must be positive"); + + return TableFunctionAnalysis.builder() + .handle(new ConstantFunctionHandle((Long) ((ScalarArgument) 
arguments.get("VALUE")).getValue(), (long) count.getValue())) + .build(); + } + + public static class ConstantFunctionHandle + implements ConnectorTableFunctionHandle + { + private final Long value; + private final long count; + + @JsonCreator + public ConstantFunctionHandle(@JsonProperty("value") Long value, @JsonProperty("count") long count) + { + this.value = value; + this.count = count; + } + + @JsonProperty + public Long getValue() + { + return value; + } + + @JsonProperty + public long getCount() + { + return count; + } + } + + public static class ConstantFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + return new ConstantFunctionProcessor(((ConstantFunctionHandle) handle).getValue()); + } + } + + public static class ConstantFunctionProcessor + implements TableFunctionSplitProcessor + { + private static final int PAGE_SIZE = 1000; + + private final Long value; + + private long fullPagesCount; + private long processedPages; + private int reminder; + private Block block; + + public ConstantFunctionProcessor(Long value) + { + this.value = value; + } + + @Override + public TableFunctionProcessorState process(ConnectorSplit split) + { + boolean usedData = false; + + if (split != null) { + long count = ((ConstantFunctionSplit) split).getCount(); + this.fullPagesCount = count / PAGE_SIZE; + this.reminder = toIntExact(count % PAGE_SIZE); + if (fullPagesCount > 0) { + BlockBuilder builder = INTEGER.createBlockBuilder(null, PAGE_SIZE); + if (value == null) { + for (int i = 0; i < PAGE_SIZE; i++) { + builder.appendNull(); + } + } + else { + for (int i = 0; i < PAGE_SIZE; i++) { + builder.writeInt(toIntExact(value)); + } + } + this.block = builder.build(); + } + else { + BlockBuilder builder = INTEGER.createBlockBuilder(null, reminder); + if (value == null) { + for (int i = 0; i < reminder; i++) { + builder.appendNull(); + } + } + else { 
+ for (int i = 0; i < reminder; i++) { + builder.writeInt(toIntExact(value)); + } + } + this.block = builder.build(); + } + usedData = true; + } + + if (processedPages < fullPagesCount) { + processedPages++; + Page result = new Page(block); + if (usedData) { + return usedInputAndProduced(result); + } + return produced(result); + } + + if (reminder > 0) { + Page result = new Page(block.getRegion(0, toIntExact(reminder))); + reminder = 0; + if (usedData) { + return usedInputAndProduced(result); + } + return produced(result); + } + + return FINISHED; + } + } + + public static ConnectorSplitSource getConstantFunctionSplitSource(ConstantFunctionHandle handle) + { + long splitSize = ConstantFunctionSplit.DEFAULT_SPLIT_SIZE; + ImmutableList.Builder splits = ImmutableList.builder(); + for (long i = 0; i < handle.getCount() / splitSize; i++) { + splits.add(new ConstantFunctionSplit(splitSize)); + } + long remainingSize = handle.getCount() % splitSize; + if (remainingSize > 0) { + splits.add(new ConstantFunctionSplit(remainingSize)); + } + return new FixedSplitSource(splits.build()); + } + + public static final class ConstantFunctionSplit + implements ConnectorSplit + { + private static final int INSTANCE_SIZE = toIntExact(ClassLayout.parseClass(ConstantFunctionSplit.class).instanceSize()); + public static final int DEFAULT_SPLIT_SIZE = 5500; + + private final long count; + + @JsonCreator + public ConstantFunctionSplit(@JsonProperty("count") long count) + { + this.count = count; + } + + @JsonProperty + public long getCount() + { + return count; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return count; + } + } + } + + public static class EmptySourceFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = 
"empty_source"; + public EmptySourceFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .build(); + } + + public static class EmptySourceFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + return new EmptySourceFunctionProcessor(); + } + } + + public static class EmptySourceFunctionProcessor + implements TableFunctionSplitProcessor + { + private static final Page EMPTY_PAGE = new Page(BOOLEAN.createBlockBuilder(null, 0).build()); + + @Override + public TableFunctionProcessorState process(ConnectorSplit split) + { + if (split == null) { + return FINISHED; + } + + return usedInputAndProduced(EMPTY_PAGE); + } + } + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java index 5c28e36be9241..458c86f986203 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java @@ -587,11 +587,6 @@ protected void assertFails(SemanticErrorCode error, String message, @Language("S assertFails(CLIENT_SESSION, error, message, query, false); } - protected void assertFailsExact(SemanticErrorCode error, String message, @Language("SQL") String query) - { - assertFails(CLIENT_SESSION, error, message, query, true); - } - protected void assertFails(Session session, 
SemanticErrorCode error, @Language("SQL") String query) { assertFails(session, error, Optional.empty(), query); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java index a74d59c1e2d03..c3236d3e6627a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java @@ -33,6 +33,7 @@ import org.testng.annotations.Test; import java.util.List; +import java.util.regex.Pattern; import static com.facebook.presto.metadata.SessionPropertyManager.createTestingSessionPropertyManager; import static com.facebook.presto.spi.StandardWarningCode.PERFORMANCE_WARNING; @@ -1979,59 +1980,59 @@ public void testTableFunctionNotFound() @Test public void testTableFunctionArguments() { - assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, "line 1:51: Too many arguments. Expected at most 2 arguments, got 3 arguments", "SELECT * FROM TABLE(system.two_arguments_function(1, 2, 3))"); + assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, "line 1:58: Too many arguments. 
Expected at most 2 arguments, got 3 arguments", "SELECT * FROM TABLE(system.two_scalar_arguments_function(1, 2, 3))"); - analyze("SELECT * FROM TABLE(system.two_arguments_function('foo'))"); - analyze("SELECT * FROM TABLE(system.two_arguments_function(text => 'foo'))"); - analyze("SELECT * FROM TABLE(system.two_arguments_function('foo', 1))"); - analyze("SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', number => 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('foo'))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo'))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('foo', 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', number => 1))"); assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, - "line 1:51: All arguments must be passed by name or all must be passed positionally", - "SELECT * FROM TABLE(system.two_arguments_function('foo', number => 1))"); + "line 1:58: All arguments must be passed by name or all must be passed positionally", + "SELECT * FROM TABLE(system.two_scalar_arguments_function('foo', number => 1))"); assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, - "line 1:51: All arguments must be passed by name or all must be passed positionally", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', 1))"); + "line 1:58: All arguments must be passed by name or all must be passed positionally", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', 1))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:66: Duplicate argument name: TEXT", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', text => 'bar'))"); + "line 1:73: Duplicate argument name: TEXT", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', text => 'bar'))"); // argument names are resolved in the canonical form assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:66: 
Duplicate argument name: TEXT", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', TeXt => 'bar'))"); + "line 1:73: Duplicate argument name: TEXT", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', TeXt => 'bar'))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:66: Unexpected argument name: BAR", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', bar => 'bar'))"); + "line 1:73: Unexpected argument name: BAR", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', bar => 'bar'))"); assertFails(TABLE_FUNCTION_MISSING_ARGUMENT, - "line 1:51: Missing argument: TEXT", - "SELECT * FROM TABLE(system.two_arguments_function(number => 1))"); + "line 1:58: Missing argument: TEXT", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(number => 1))"); } @Test public void testScalarArgument() { - analyze("SELECT * FROM TABLE(system.two_arguments_function('foo', 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('foo', 1))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:64: Invalid argument NUMBER. Expected expression, got descriptor", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => DESCRIPTOR(x integer, y boolean)))"); + "line 1:71: Invalid argument NUMBER. 
Expected expression, got descriptor", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => DESCRIPTOR(x integer, y boolean)))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:64: 'descriptor' function is not allowed as a table function argument", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => DESCRIPTOR(1 + 2)))"); + "line 1:71: 'descriptor' function is not allowed as a table function argument", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => DESCRIPTOR(1 + 2)))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:64: Invalid argument NUMBER. Expected expression, got table", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => TABLE(t1)))"); + "line 1:71: Invalid argument NUMBER. Expected expression, got table", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => TABLE(t1)))"); assertFails(EXPRESSION_NOT_CONSTANT, - "line 1:74: Constant expression cannot contain a subquery", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => (SELECT 1)))"); + "line 1:81: Constant expression cannot contain a subquery", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => (SELECT 1)))"); } @Test @@ -2127,8 +2128,8 @@ public void testDescriptorArgument() { analyze("SELECT * FROM TABLE(system.descriptor_argument_function(schema => DESCRIPTOR(x integer, y boolean)))"); - assertFailsExact(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:57: Invalid descriptor argument SCHEMA. Descriptors should be formatted as 'DESCRIPTOR(name [type], ...)'", + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + Pattern.quote("line 1:57: Invalid descriptor argument SCHEMA. 
Descriptors should be formatted as 'DESCRIPTOR(name [type], ...)'"), "SELECT * FROM TABLE(system.descriptor_argument_function(schema => DESCRIPTOR(1 + 2)))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, @@ -2243,10 +2244,10 @@ public void testNullArguments() // the default value for the argument schema is null analyze("SELECT * FROM TABLE(system.descriptor_argument_function())"); - analyze("SELECT * FROM TABLE(system.two_arguments_function(null, null))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function(null, null))"); // the default value for the second argument is null - analyze("SELECT * FROM TABLE(system.two_arguments_function('a'))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('a'))"); } @Test @@ -2258,8 +2259,8 @@ public void testTableFunctionInvocationContext() "SELECT * FROM TABLE(system.only_pass_through_function(TABLE(t1))) f(x)"); // per SQL standard, relation alias is required for table function with GENERIC TABLE return type. We don't require it. - analyze("SELECT * FROM TABLE(system.two_arguments_function('a', 1)) f(x)"); - analyze("SELECT * FROM TABLE(system.two_arguments_function('a', 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('a', 1)) f(x)"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('a', 1))"); // per SQL standard, relation alias is required for table function with statically declared return type, only if the function is polymorphic. // We don't require aliasing polymorphic functions. 
@@ -2276,7 +2277,7 @@ public void testTableFunctionInvocationContext() // aliased + sampled assertFails(TABLE_FUNCTION_INVALID_TABLE_FUNCTION_INVOCATION, "line 1:15: Cannot apply sample to polymorphic table function invocation", - "SELECT * FROM TABLE(system.two_arguments_function('a', 1)) f(x) TABLESAMPLE BERNOULLI (10)"); + "SELECT * FROM TABLE(system.two_scalar_arguments_function('a', 1)) f(x) TABLESAMPLE BERNOULLI (10)"); } @Test @@ -2294,19 +2295,19 @@ public void testTableFunctionAliasing() analyze("SELECT * FROM TABLE(system.table_argument_function(TABLE(t1) t2)) T1(x)"); // the original returned relation type is ("column" : BOOLEAN) - analyze("SELECT column FROM TABLE(system.two_arguments_function('a', 1)) table_alias"); + analyze("SELECT column FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias"); - analyze("SELECT column_alias FROM TABLE(system.two_arguments_function('a', 1)) table_alias(column_alias)"); + analyze("SELECT column_alias FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(column_alias)"); - analyze("SELECT table_alias.column_alias FROM TABLE(system.two_arguments_function('a', 1)) table_alias(column_alias)"); + analyze("SELECT table_alias.column_alias FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(column_alias)"); assertFails(MISSING_ATTRIBUTE, "line 1:8: Column 'column' cannot be resolved", - "SELECT column FROM TABLE(system.two_arguments_function('a', 1)) table_alias(column_alias)"); + "SELECT column FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(column_alias)"); assertFails(MISMATCHED_COLUMN_ALIASES, "line 1:20: Column alias list has 3 entries but table function has 1 proper columns", - "SELECT column FROM TABLE(system.two_arguments_function('a', 1)) table_alias(col1, col2, col3)"); + "SELECT column FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(col1, col2, col3)"); // the original returned relation type is ("a" : BOOLEAN, "b" : INTEGER) 
analyze("SELECT column_alias_1, column_alias_2 FROM TABLE(system.monomorphic_static_return_type_function()) table_alias(column_alias_1, column_alias_2)"); @@ -2348,8 +2349,10 @@ public void testTableFunctionRequiredColumns() "Invalid index: 1 of required column from table argument INPUT", "SELECT * FROM TABLE(system.required_columns_function(input => TABLE(SELECT 1)))"); - // table s1.t5 has two columns. The second column is hidden. Table function can require a hidden column. - analyze("SELECT * FROM TABLE(system.required_columns_function(input => TABLE(s1.t5)))"); + // table s1.t5 has two columns. The second column is hidden. Table function cannot require a hidden column. + assertFails(TABLE_FUNCTION_IMPLEMENTATION_ERROR, + "Invalid index: 1 of required column from table argument INPUT", + "SELECT * FROM TABLE(system.required_columns_function(input => TABLE(s1.t5)))"); } @Test diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index a8a69731f4361..46f2f031e858b 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -1803,6 +1803,16 @@ public void testOffsetWithLimit() .withAlias("row_num", new RowNumberSymbolMatcher())))))); } + @Test + public void testRewriteExcludeColumnsFunctionToProjection() + { + assertPlan("SELECT *\n" + + "FROM TABLE(system.builtin.exclude_columns(\n" + + " INPUT => TABLE(orders),\n" + + " COLUMNS => DESCRIPTOR(comment)))\n", + output(tableScan("orders"))); + } + private Session noJoinReordering() { return Session.builder(this.getQueryRunner().getDefaultSession()) diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java 
new file mode 100644 index 0000000000000..6d236432e1d15 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java @@ -0,0 +1,272 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.connector.tvf.TestTVFConnectorFactory; +import com.facebook.presto.connector.tvf.TestTVFConnectorPlugin; +import com.facebook.presto.connector.tvf.TestingTableFunctions; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DescriptorArgumentFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DifferentArgumentTypesFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TestingTableFunctionHandle; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoScalarArgumentsFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoTableArgumentsFunction; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.Descriptor.Field; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.assertions.RowNumberSymbolMatcher; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.tree.LongLiteral; +import com.google.common.collect.ImmutableList; +import 
com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.sql.Optimizer.PlanStage.CREATED; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.rowNumber; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictOutput; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunction; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.descriptorArgument; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.nullDescriptor; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.TableArgumentValue.Builder.tableArgument; + +public class TestTableFunctionInvocation + extends BasePlanTest +{ + private static final String 
TESTING_CATALOG = "test"; + + @BeforeClass + public final void setup() + { + getQueryRunner().installPlugin(new TestTVFConnectorPlugin(TestTVFConnectorFactory.builder() + .withTableFunctions(ImmutableSet.of( + new DifferentArgumentTypesFunction(), + new TwoScalarArgumentsFunction(), + new TwoTableArgumentsFunction(), + new DescriptorArgumentFunction(), + new TestingTableFunctions.PassThroughFunction())) + .withApplyTableFunction((session, handle) -> { + if (handle instanceof TestingTableFunctionHandle) { + TestingTableFunctionHandle functionHandle = (TestingTableFunctionHandle) handle; + return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), functionHandle.getTableHandle().getColumns().orElseThrow(() -> new IllegalStateException("Missing columns")))); + } + throw new IllegalStateException("Unsupported table function handle: " + handle.getClass().getSimpleName()); + }) + .build())); + getQueryRunner().createCatalog(TESTING_CATALOG, "testTVF", ImmutableMap.of()); + } + + @Test + public void testTableFunctionInitialPlan() + { + assertPlan( + "SELECT * FROM TABLE(test.system.different_arguments_function(" + + "INPUT_1 => TABLE(SELECT 'a') t1(c1) PARTITION BY c1 ORDER BY c1," + + "INPUT_3 => TABLE(SELECT 'b') t3(c3) PARTITION BY c3," + + "INPUT_2 => TABLE(VALUES 1) t2(c2)," + + "ID => BIGINT '2001'," + + "LAYOUT => DESCRIPTOR (x boolean, y bigint)" + + "COPARTITION (t1, t3))) t", + CREATED, + anyTree(tableFunction(builder -> builder + .name("different_arguments_function") + .addTableArgument( + "INPUT_1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1"), ImmutableList.of("c1"), ImmutableMap.of("c1", ASC_NULLS_LAST))) + .passThroughVariables(ImmutableSet.of("c1")) + .passThroughColumns()) + .addTableArgument( + "INPUT_3", + tableArgument(2) + .specification(specification(ImmutableList.of("c3"), ImmutableList.of(), ImmutableMap.of())) + .pruneWhenEmpty() + .passThroughVariables(ImmutableSet.of("c3"))) + 
.addTableArgument( + "INPUT_2", + tableArgument(1) + .rowSemantics() + .passThroughVariables(ImmutableSet.of("c2")) + .passThroughColumns()) + .addScalarArgument("ID", 2001L) + .addDescriptorArgument( + "LAYOUT", + descriptorArgument(new Descriptor(ImmutableList.of( + new Field("X", Optional.of(BOOLEAN)), + new Field("Y", Optional.of(BIGINT)))))) + .addCopartitioning(ImmutableList.of("INPUT_1", "INPUT_3")) + .properOutputs(ImmutableList.of("OUTPUT")), + anyTree(project(ImmutableMap.of("c1", expression("'a'")), values(1))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new LongLiteral("1"))))), + anyTree(project(ImmutableMap.of("c3", expression("'b'")), values(1)))))); + } + + @Test + public void testTableFunctionInitialPlanWithCoercionForCopartitioning() + { + assertPlan("SELECT * FROM TABLE(test.system.two_table_arguments_function(" + + "INPUT1 => TABLE(VALUES SMALLINT '1') t1(c1) PARTITION BY c1," + + "INPUT2 => TABLE(VALUES INTEGER '2') t2(c2) PARTITION BY c2 " + + "COPARTITION (t1, t2))) t", + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_table_arguments_function") + .addTableArgument( + "INPUT1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1_coerced"), ImmutableList.of(), ImmutableMap.of())) + .passThroughVariables(ImmutableSet.of("c1"))) + .addTableArgument( + "INPUT2", + tableArgument(1) + .specification(specification(ImmutableList.of("c2"), ImmutableList.of(), ImmutableMap.of())) + .passThroughVariables(ImmutableSet.of("c2"))) + .addCopartitioning(ImmutableList.of("INPUT1", "INPUT2")) + .properOutputs(ImmutableList.of("COLUMN")), + project(ImmutableMap.of("c1_coerced", expression("CAST(c1 AS INTEGER)")), + anyTree(values(ImmutableList.of("c1"), ImmutableList.of(ImmutableList.of(new LongLiteral("1")))))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new LongLiteral("2")))))))); + } + + @Test + public void testNullScalarArgument() + { + // the argument 
NUMBER has null default value + assertPlan( + " SELECT * FROM TABLE(test.system.two_scalar_arguments_function(TEXT => null))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_scalar_arguments_function") + .addScalarArgument("TEXT", null) + .addScalarArgument("NUMBER", null) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } + + @Test + public void testNullDescriptorArgument() + { + assertPlan( + " SELECT * FROM TABLE(test.system.descriptor_argument_function(SCHEMA => CAST(null AS DESCRIPTOR)))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + + // the argument SCHEMA has null default value + assertPlan( + " SELECT * FROM TABLE(test.system.descriptor_argument_function())", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } + + @Test + public void testPruneTableFunctionColumns() + { + // all table function outputs are referenced with SELECT *, no pruning + assertPlan("SELECT * FROM TABLE(test.system.pass_through_function(input => TABLE(SELECT 1, true) t(a, b)))", + strictOutput( + ImmutableList.of("x", "a", "b"), + tableFunctionProcessor( + builder -> builder + .name("pass_through_function") + .properOutputs(ImmutableList.of("x")) + .passThroughSymbols( + ImmutableList.of(ImmutableList.of("a", "b"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("a"))) + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())), + project(ImmutableMap.of("a", expression("INTEGER'1'"), "b", expression("BOOLEAN'true'")), values(1))))); + + // no table function outputs are referenced. All pass-through symbols are pruned from the TableFunctionProcessorNode. The unused symbol "b" is pruned from the source values node. 
+ assertPlan("SELECT 'constant' c FROM TABLE(test.system.pass_through_function(input => TABLE(SELECT 1, true) t(a, b)))", + strictOutput( + ImmutableList.of("c"), + strictProject( + ImmutableMap.of("c", expression("VARCHAR'constant'")), + tableFunctionProcessor( + builder -> builder + .name("pass_through_function") + .properOutputs(ImmutableList.of("x")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("a"))) + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())), + project(ImmutableMap.of("a", expression("INTEGER'1'")), values(1)))))); + } + + @Test + public void testRemoveRedundantTableFunction() + { + assertPlan("SELECT * FROM TABLE(test.system.pass_through_function(input => TABLE(SELECT 1, true WHERE false) t(a, b) PRUNE WHEN EMPTY))", + output(values(ImmutableList.of("x", "a", "b")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) PRUNE WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false) t2(c, d) KEEP WHEN EMPTY))\n", + output(values(ImmutableList.of("column")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) PRUNE WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false WHERE false) t2(c, d) PRUNE WHEN EMPTY))\n", + output(values(ImmutableList.of("column")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) PRUNE WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false WHERE false) t2(c, d) KEEP WHEN EMPTY))\n", + output(values(ImmutableList.of("column")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) KEEP WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false WHERE 
false) t2(c, d) KEEP WHEN EMPTY))\n", + output( + node(TableFunctionProcessorNode.class, + values(ImmutableList.of("a", "marker_1", "c", "marker_2", "row_number"))))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) KEEP WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false) t2(c, d) PRUNE WHEN EMPTY))\n", + output( + node(TableFunctionProcessorNode.class, + project( + project( + rowNumber( + builder -> builder.partitionBy(ImmutableList.of()), + project( + ImmutableMap.of("c", expression("INTEGER'2'")), + values(1)) + ).withAlias("input_2_row_number", new RowNumberSymbolMatcher())))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index d882fe7e54a5f..7b41ae0b02bf6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -658,6 +658,11 @@ public static PlanMatchPattern values(Map aliasToIndex) return values(aliasToIndex, Optional.empty(), Optional.empty()); } + public static PlanMatchPattern values(int rowCount) + { + return values(ImmutableList.of(), nCopies(rowCount, ImmutableList.of())); + } + public static PlanMatchPattern values(String... aliases) { return values(ImmutableList.copyOf(aliases)); @@ -713,6 +718,27 @@ public static PlanMatchPattern remoteSource(List sourceFragmentI return node(RemoteSourceNode.class).with(new RemoteSourceMatcher(sourceFragmentIds, outputSymbolAliases)); } + public static PlanMatchPattern tableFunction(Consumer handler, PlanMatchPattern... 
sources) + { + TableFunctionMatcher.Builder builder = new TableFunctionMatcher.Builder(sources); + handler.accept(builder); + return builder.build(); + } + + public static PlanMatchPattern tableFunctionProcessor(Consumer handler, PlanMatchPattern source) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(source); + handler.accept(builder); + return builder.build(); + } + + public static PlanMatchPattern tableFunctionProcessor(Consumer handler) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(); + handler.accept(builder); + return builder.build(); + } + public PlanMatchPattern(List sourcePatterns) { requireNonNull(sourcePatterns, "sourcePatterns are null"); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java new file mode 100644 index 0000000000000..c14b68b443867 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java @@ -0,0 +1,412 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static com.facebook.presto.sql.planner.QueryPlanner.toSymbolReferences; +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.SpecificationProvider.matchSpecification; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; + +public class TableFunctionMatcher + implements Matcher +{ + private final String name; + private 
final Map arguments; + private final List properOutputs; + private final List> copartitioningLists; + + private TableFunctionMatcher( + String name, + Map arguments, + List properOutputs, + List> copartitioningLists) + { + this.name = requireNonNull(name, "name is null"); + this.arguments = ImmutableMap.copyOf(requireNonNull(arguments, "arguments is null")); + this.properOutputs = ImmutableList.copyOf(requireNonNull(properOutputs, "properOutputs is null")); + requireNonNull(copartitioningLists, "copartitioningLists is null"); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionNode tableFunctionNode = (TableFunctionNode) node; + + if (!name.equals(tableFunctionNode.getName())) { + return NO_MATCH; + } + + if (arguments.size() != tableFunctionNode.getArguments().size()) { + return NO_MATCH; + } + for (Map.Entry entry : arguments.entrySet()) { + String name = entry.getKey(); + Argument actual = tableFunctionNode.getArguments().get(name); + if (actual == null) { + return NO_MATCH; + } + ArgumentValue expected = entry.getValue(); + switch (expected.getType()) { + case DescriptorArgumentValue.type: + DescriptorArgumentValue expectedDescriptor = (DescriptorArgumentValue) expected; + if (!(actual instanceof DescriptorArgument) || !expectedDescriptor.getDescriptor().equals(((DescriptorArgument) actual).getDescriptor())) { + return NO_MATCH; + } + break; + case ScalarArgumentValue.type: + ScalarArgumentValue expectedScalar = (ScalarArgumentValue) expected; + if (!(actual instanceof 
ScalarArgument) || !Objects.equals(expectedScalar.getValue(), ((ScalarArgument) actual).getValue())) { + return NO_MATCH; + } + break; + default: + if (!(actual instanceof TableArgument) || getMatchResult(symbolAliases, (TableArgumentValue) expected, tableFunctionNode, name).equals(NO_MATCH)) { + return NO_MATCH; + } + } + } + + if (!ImmutableSet.copyOf(copartitioningLists).equals(ImmutableSet.copyOf(tableFunctionNode.getCopartitioningLists()))) { + return NO_MATCH; + } + + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + private MatchResult getMatchResult(SymbolAliases symbolAliases, TableArgumentValue expected, TableFunctionNode tableFunctionNode, String name) + { + TableArgumentValue expectedTableArgument = expected; + TableArgumentProperties argumentProperties = tableFunctionNode.getTableArgumentProperties().get(expectedTableArgument.sourceIndex()); + if (!name.equals(argumentProperties.getArgumentName())) { + return NO_MATCH; + } + if (expectedTableArgument.rowSemantics() != argumentProperties.isRowSemantics() || + expectedTableArgument.pruneWhenEmpty() != argumentProperties.isPruneWhenEmpty() || + expectedTableArgument.passThroughColumns() != argumentProperties.getPassThroughSpecification().isDeclaredAsPassThrough()) { + return NO_MATCH; + } + + if (expectedTableArgument.specification().isPresent() != argumentProperties.getSpecification().isPresent()) { + return NO_MATCH; + } + if (!expectedTableArgument.specification() + .map(expectedSpecification -> matchSpecification(argumentProperties.getSpecification().get(), expectedSpecification.getExpectedValue(symbolAliases))) + .orElse(true)) { + return NO_MATCH; + } + Set expectedPassThrough = expectedTableArgument.passThroughVariables().stream() + .map(symbolAliases::get) + .collect(toImmutableSet()); + Set actualPassThrough = toSymbolReferences( + 
argumentProperties.getPassThroughSpecification().getColumns().stream() + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .collect(Collectors.toList())) + .stream() + .map(SymbolReference.class::cast) + .collect(Collectors.toSet()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } + return match(symbolAliases); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("arguments", arguments) + .add("properOutputs", properOutputs) + .add("copartitioningLists", copartitioningLists) + .toString(); + } + + public static class Builder + { + private final PlanMatchPattern[] sources; + private String name; + private final ImmutableMap.Builder arguments = ImmutableMap.builder(); + private List properOutputs = ImmutableList.of(); + private final ImmutableList.Builder> copartitioningLists = ImmutableList.builder(); + + Builder(PlanMatchPattern... sources) + { + this.sources = Arrays.copyOf(sources, sources.length); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder addDescriptorArgument(String name, DescriptorArgumentValue descriptor) + { + this.arguments.put(name, descriptor); + return this; + } + + public Builder addScalarArgument(String name, Object value) + { + this.arguments.put(name, new ScalarArgumentValue(value)); + return this; + } + + public Builder addTableArgument(String name, TableArgumentValue.Builder tableArgument) + { + this.arguments.put(name, tableArgument.build()); + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder addCopartitioning(List copartitioning) + { + this.copartitioningLists.add(copartitioning); + return this; + } + + public PlanMatchPattern build() + { + return node(TableFunctionNode.class, sources) + .with(new TableFunctionMatcher(name, arguments.buildOrThrow(), properOutputs, 
copartitioningLists.build())); + } + } + + interface ArgumentValue + { + String getType(); + } + + public static class DescriptorArgumentValue + implements ArgumentValue + { + private final Optional descriptor; + public static final String type = "Descriptor"; + + public DescriptorArgumentValue(Optional descriptor) + { + this.descriptor = requireNonNull(descriptor, "descriptor is null"); + } + + public static DescriptorArgumentValue descriptorArgument(Descriptor descriptor) + { + return new DescriptorArgumentValue(Optional.of(requireNonNull(descriptor, "descriptor is null"))); + } + + public static DescriptorArgumentValue nullDescriptor() + { + return new DescriptorArgumentValue(Optional.empty()); + } + + public Optional getDescriptor() + { + return descriptor; + } + + @Override + public String getType() + { + return type; + } + } + + public static class ScalarArgumentValue + implements ArgumentValue + { + private final Object value; + public static final String type = "Scalar"; + + public ScalarArgumentValue(Object value) + { + this.value = value; + } + + public Object getValue() + { + return value; + } + + @Override + public String getType() + { + return type; + } + } + + public static class TableArgumentValue + implements ArgumentValue + { + private final int sourceIndex; + private final boolean rowSemantics; + private final boolean pruneWhenEmpty; + private final boolean passThroughColumns; + private final Optional> specification; + private final Set passThroughVariables; + public static final String type = "Table"; + + public TableArgumentValue(int sourceIndex, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns, Optional> specification, Set passThroughVariables) + { + this.sourceIndex = sourceIndex; + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughColumns = passThroughColumns; + this.specification = requireNonNull(specification, "specification is null"); + this.passThroughVariables = 
ImmutableSet.copyOf(passThroughVariables); + } + + public int sourceIndex() + { + return sourceIndex; + } + + public boolean rowSemantics() + { + return rowSemantics; + } + + public boolean pruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public boolean passThroughColumns() + { + return passThroughColumns; + } + + public Set passThroughVariables() + { + return passThroughVariables; + } + + public Optional> specification() + { + return specification; + } + + @Override + public String getType() + { + return type; + } + + public static class Builder + { + private final int sourceIndex; + private boolean rowSemantics; + private boolean pruneWhenEmpty; + private boolean passThroughColumns; + private Optional> specification = Optional.empty(); + private Set passThroughVariables = ImmutableSet.of(); + + private Builder(int sourceIndex) + { + this.sourceIndex = sourceIndex; + } + + public static Builder tableArgument(int sourceIndex) + { + return new Builder(sourceIndex); + } + + public Builder rowSemantics() + { + this.rowSemantics = true; + this.pruneWhenEmpty = true; + return this; + } + + public Builder pruneWhenEmpty() + { + this.pruneWhenEmpty = true; + return this; + } + + public Builder passThroughColumns() + { + this.passThroughColumns = true; + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + this.specification = Optional.of(specification); + return this; + } + + public Builder passThroughVariables(Set variables) + { + this.passThroughVariables = variables; + return this; + } + + private TableArgumentValue build() + { + return new TableArgumentValue(sourceIndex, rowSemantics, pruneWhenEmpty, passThroughColumns, specification, passThroughVariables); + } + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java new file mode 100644 
index 0000000000000..4891c3eb021dd --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java @@ -0,0 +1,239 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.planner.QueryPlanner; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.QueryPlanner.toSymbolReference; +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import 
static com.facebook.presto.sql.planner.assertions.SpecificationProvider.matchSpecification; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorMatcher + implements Matcher +{ + private final String name; + private final List properOutputs; + private final List> passThroughSymbols; + private final List> requiredSymbols; + private final Optional> markerSymbols; + private final Optional> specification; + private final Optional hashSymbol; + + private TableFunctionProcessorMatcher( + String name, + List properOutputs, + List> passThroughSymbols, + List> requiredSymbols, + Optional> markerSymbols, + Optional> specification, + Optional hashSymbol) + { + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(properOutputs); + this.passThroughSymbols = passThroughSymbols.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.requiredSymbols = requiredSymbols.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerSymbols = markerSymbols.map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionProcessorNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionProcessorNode tableFunctionProcessorNode = 
(TableFunctionProcessorNode) node; + + if (!name.equals(tableFunctionProcessorNode.getName())) { + return NO_MATCH; + } + + if (properOutputs.size() != tableFunctionProcessorNode.getProperOutputs().size()) { + return NO_MATCH; + } + + List> expectedPassThrough = passThroughSymbols.stream() + .map(list -> list.stream() + .map(symbolAliases::get) + .collect(toImmutableList())) + .collect(toImmutableList()); + List> actualPassThrough = tableFunctionProcessorNode.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::getColumns) + .map(list -> list.stream() + .map(PassThroughColumn::getOutputVariables) + .map(QueryPlanner::toSymbolReference) + .collect(toImmutableList())) + .collect(toImmutableList()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } + + if (markerSymbols.isPresent() != tableFunctionProcessorNode.getMarkerVariables().isPresent()) { + return NO_MATCH; + } + if (markerSymbols.isPresent()) { + Map expectedMapping = markerSymbols.get().entrySet().stream() + .collect(toImmutableMap(entry -> symbolAliases.get(entry.getKey()), entry -> symbolAliases.get(entry.getValue()))); + Map actualMapping = tableFunctionProcessorNode.getMarkerVariables().get().entrySet().stream() + .collect(toImmutableMap(entry -> toSymbolReference(entry.getKey()), entry -> toSymbolReference(entry.getValue()))); + if (!expectedMapping.equals(actualMapping)) { + return NO_MATCH; + } + } + + if (specification.isPresent() != tableFunctionProcessorNode.getSpecification().isPresent()) { + return NO_MATCH; + } + if (specification.isPresent()) { + if (!matchSpecification(specification.get().getExpectedValue(symbolAliases), tableFunctionProcessorNode.getSpecification().orElseThrow(NoSuchElementException::new))) { + return NO_MATCH; + } + } + if (hashSymbol.isPresent()) { + if (!hashSymbol.map(symbolAliases::get).equals(tableFunctionProcessorNode.getHashSymbol().map(QueryPlanner::toSymbolReference))) { + return NO_MATCH; + } + } + + 
ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + for (int i = 0; i < properOutputs.size(); i++) { + properOutputsMapping.put(properOutputs.get(i), toSymbolReference(tableFunctionProcessorNode.getProperOutputs().get(i))); + } + + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("properOutputs", properOutputs) + .add("passThroughSymbols", passThroughSymbols) + .add("requiredSymbols", requiredSymbols) + .add("markerSymbols", markerSymbols) + .add("specification", specification) + .add("hashSymbol", hashSymbol) + .toString(); + } + + public static class Builder + { + private final Optional source; + private String name; + private List properOutputs = ImmutableList.of(); + private List> passThroughSymbols = ImmutableList.of(); + private List> requiredSymbols = ImmutableList.of(); + private Optional> markerSymbols = Optional.empty(); + private Optional> specification = Optional.empty(); + private Optional hashSymbol = Optional.empty(); + + public Builder() + { + this.source = Optional.empty(); + } + + public Builder(PlanMatchPattern source) + { + this.source = Optional.of(source); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder passThroughSymbols(List> passThroughSymbols) + { + this.passThroughSymbols = passThroughSymbols; + return this; + } + + public Builder requiredSymbols(List> requiredSymbols) + { + this.requiredSymbols = requiredSymbols; + return this; + } + + public Builder markerSymbols(Map markerSymbols) + { + this.markerSymbols = Optional.of(markerSymbols); + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + this.specification = 
Optional.of(specification); + return this; + } + + public Builder hashSymbol(String hashSymbol) + { + this.hashSymbol = Optional.of(hashSymbol); + return this; + } + + public PlanMatchPattern build() + { + PlanMatchPattern[] sources = source.map(sourcePattern -> new PlanMatchPattern[] {sourcePattern}).orElse(new PlanMatchPattern[] {}); + return node(TableFunctionProcessorNode.class, sources) + .with(new TableFunctionProcessorMatcher(name, properOutputs, passThroughSymbols, requiredSymbols, markerSymbols, specification, hashSymbol)); + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java new file mode 100644 index 0000000000000..bcae22ae6c623 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneTableFunctionProcessorColumns + extends BaseRuleTest +{ + @Test + public void testDoNotPruneProperOutputs() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(p.variable("p")) + .source(p.values(p.variable("x")))))) + .doesNotFire(); + } + + @Test + public void testPrunePassThroughOutputs() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .matches(project( + ImmutableMap.of(), + tableFunctionProcessor(builder -> builder + .name("test_function") + 
.passThroughSymbols(ImmutableList.of(ImmutableList.of())), + values("a", "b")))); + + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression proper = p.variable("proper"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .matches(project( + ImmutableMap.of(), + tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("proper")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())), + values("a", "b")))); + } + + @Test + public void testReferencedPassThroughOutputs() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression x = p.variable("x"); + VariableReferenceExpression y = p.variable("y"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.builder().put(y, y).put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(x, y) + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .matches(project( + ImmutableMap.of("y", expression("y"), "b", expression("b")), + tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("x", "y")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("b"))), + values("a", "b")))); + } + + @Test + public void testAllPassThroughOutputsReferenced() + { + tester().assertThat(new 
PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.builder().put(a, a).put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .doesNotFire(); + + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression proper = p.variable("proper"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.builder().put(a, a).put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .doesNotFire(); + } + + @Test + public void testNoSource() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(p.variable("proper"))))) + .doesNotFire(); + } + + @Test + public void testMultipleTableArguments() + { + // multiple pass-through specifications indicate that the table function has multiple table arguments + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.project( + Assignments.builder().put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + 
.name("test_function") + .properOutputs(p.variable("proper")) + .passThroughSpecifications( + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(a, true))), + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(b, true))), + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, false)))) + .source(p.values(a, b, c, d)))); + }) + .matches(project( + ImmutableMap.of("b", expression("b")), + tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("proper")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of(), ImmutableList.of("b"), ImmutableList.of())), + values("a", "b", "c", "d")))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java new file mode 100644 index 0000000000000..68f56d320e396 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java @@ -0,0 +1,198 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_FIRST; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneTableFunctionProcessorSourceColumns + extends BaseRuleTest +{ + @Test + public void testPruneUnreferencedSymbol() + { + // symbols 'a', 'b', 'c', 'd', 'hash', and 'marker' are used by the node. + // symbol 'unreferenced' is pruned out. 
Also, the mapping for this symbol is removed from marker mappings + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression proper = p.variable("proper"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression unreferenced = p.variable("unreferenced"); + VariableReferenceExpression hash = p.variable("hash"); + VariableReferenceExpression marker = p.variable("marker"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(a, false)))) + .requiredSymbols(ImmutableList.of(ImmutableList.of(b))) + .markerSymbols(ImmutableMap.of( + a, marker, + b, marker, + c, marker, + d, marker, + unreferenced, marker)) + .specification(new DataOrganizationSpecification(ImmutableList.of(c), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_FIRST)))))) + .hashSymbol(hash) + .source(p.values(a, b, c, d, unreferenced, hash, marker))); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("proper")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("a"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("b"))) + .markerSymbols(ImmutableMap.of( + "a", "marker", + "b", "marker", + "c", "marker", + "d", "marker")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_FIRST))) + .hashSymbol("hash"), + project( + ImmutableMap.of( + "a", expression("a"), + "b", expression("b"), + "c", expression("c"), + "d", expression("d"), + "hash", expression("hash"), + "marker", expression("marker")), + values("a", "b", "c", "d", "unreferenced", "hash", "marker")))); + } + + 
@Test + public void testPruneUnusedMarkerSymbol() + { + // symbol 'unreferenced' is pruned out because the node does not use it. + // also, the mapping for this symbol is removed from marker mappings. + // because the marker symbol 'marker' is no longer used, it is pruned out too. + // note: currently a marker symbol cannot become unused because the function + // must use at least one symbol from each source. it might change in the future. + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression unreferenced = p.variable("unreferenced"); + VariableReferenceExpression marker = p.variable("marker"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .markerSymbols(ImmutableMap.of(unreferenced, marker)) + .source(p.values(unreferenced, marker))); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .markerSymbols(ImmutableMap.of()), + project( + ImmutableMap.of(), + values("unreferenced", "marker")))); + } + + @Test + public void testMultipleSources() + { + // multiple pass-through specifications indicate that the table function has multiple table arguments + // the third argument provides symbols 'e', 'f', and 'unreferenced'. 
those symbols are mapped to common marker symbol 'marker3' + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression marker1 = p.variable("marker1"); + VariableReferenceExpression marker2 = p.variable("marker2"); + VariableReferenceExpression marker3 = p.variable("marker3"); + VariableReferenceExpression unreferenced = p.variable("unreferenced"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .passThroughSpecifications( + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(a, false))), + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true)))) + .requiredSymbols(ImmutableList.of( + ImmutableList.of(b), + ImmutableList.of(d), + ImmutableList.of(f))) + .markerSymbols(ImmutableMap.of( + a, marker1, + b, marker1, + c, marker2, + d, marker2, + e, marker3, + f, marker3, + unreferenced, marker3)) + .source(p.values(a, b, c, d, e, f, marker1, marker2, marker3, unreferenced))); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .passThroughSymbols(ImmutableList.of(ImmutableList.of("a"), ImmutableList.of("c"), ImmutableList.of("e"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("b"), ImmutableList.of("d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "a", "marker1", + "b", "marker1", + "c", "marker2", + "d", "marker2", + "e", "marker3", + "f", "marker3")), + project( + ImmutableMap.of( + "a", expression("a"), + "b", expression("b"), + "c", expression("c"), + "d", expression("d"), + "e", 
expression("e"), + "f", expression("f"), + "marker1", expression("marker1"), + "marker2", expression("marker2"), + "marker3", expression("marker3")), + values("a", "b", "c", "d", "e", "f", "marker1", "marker2", "marker3", "unreferenced")))); + } + + @Test + public void allSymbolsReferenced() + { + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression marker = p.variable("marker"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .requiredSymbols(ImmutableList.of(ImmutableList.of(a))) + .markerSymbols(ImmutableMap.of(a, marker)) + .source(p.values(a, marker))); + }) + .doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunctionProcessor.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunctionProcessor.java new file mode 100644 index 0000000000000..86b6cc74e74af --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunctionProcessor.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestRemoveRedundantTableFunctionProcessor + extends BaseRuleTest +{ + @Test + public void testRemoveTableFunction() + { + tester().assertThat(new RemoveRedundantTableFunctionProcessor()) + .on(p -> { + VariableReferenceExpression passThrough = p.variable("pass_through"); + VariableReferenceExpression proper = p.variable("proper"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .pruneWhenEmpty() + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(passThrough, true)))) + .source(p.values(passThrough))); + }) + .matches(values("proper", "pass_through")); + } + + @Test + public void testDoNotRemoveKeepWhenEmpty() + { + tester().assertThat(new RemoveRedundantTableFunctionProcessor()) + .on(p -> { + VariableReferenceExpression passThrough = p.variable("pass_through"); + VariableReferenceExpression proper = p.variable("proper"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(passThrough, true)))) + .source(p.values(passThrough))); + }) + .doesNotFire(); + } + + @Test + public void testDoNotRemoveNonEmptyInput() + { + tester().assertThat(new RemoveRedundantTableFunctionProcessor()) + .on(p -> { + VariableReferenceExpression passThrough = 
p.variable("pass_through"); + VariableReferenceExpression proper = p.variable("proper"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .pruneWhenEmpty() + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(passThrough, true)))) + .source(p.values(5, passThrough))); + }) + .doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteExcludeColumnsFunctionToProjection.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteExcludeColumnsFunctionToProjection.java new file mode 100644 index 0000000000000..bfdcea0b8ca5f --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteExcludeColumnsFunctionToProjection.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.operator.table.ExcludeColumns.ExcludeColumnsFunctionHandle; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.SmallintType.SMALLINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestRewriteExcludeColumnsFunctionToProjection + extends BaseRuleTest +{ + @Test + public void rewriteExcludeColumnsFunction() + { + tester().assertThat(new RewriteExcludeColumnsFunctionToProjection()) + .on(p -> { + VariableReferenceExpression a = p.variable("a", BOOLEAN); + VariableReferenceExpression b = p.variable("b", BIGINT); + VariableReferenceExpression c = p.variable("c", SMALLINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + VariableReferenceExpression y = p.variable("y", SMALLINT); + return p.tableFunctionProcessor( + builder -> builder + .name("exclude_columns") + .properOutputs(x, y) + .pruneWhenEmpty() + .requiredSymbols(ImmutableList.of(ImmutableList.of(b, c))) + .connectorHandle(new ExcludeColumnsFunctionHandle()) + .source(p.values(a, b, c))); + }) + .matches(PlanMatchPattern.strictProject( + ImmutableMap.of( + "x", expression("b"), + "y", expression("c")), + values("a", "b", "c"))); + } + + @Test + public void doNotRewriteOtherFunction() + { + tester().assertThat(new RewriteExcludeColumnsFunctionToProjection()) + .on(p -> { + VariableReferenceExpression a = 
p.variable("a", BOOLEAN); + VariableReferenceExpression b = p.variable("b", BIGINT); + VariableReferenceExpression c = p.variable("c", SMALLINT); + return p.tableFunctionProcessor( + builder -> builder + .name("testing_function") + .requiredSymbols(ImmutableList.of(ImmutableList.of(b, c))) + .source(p.values(a, b, c))); + }).doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformTableFunctionToTableFunctionProcessor.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformTableFunctionToTableFunctionProcessor.java new file mode 100644 index 0000000000000..b6fb48a904e81 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformTableFunctionToTableFunctionProcessor.java @@ -0,0 +1,1404 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.block.SortOrder.DESC_NULLS_FIRST; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TinyintType.TINYINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window; +import static 
com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; + +public class TestTransformTableFunctionToTableFunctionProcessor + extends BaseRuleTest +{ + @Test + public void testNoSources() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> p.tableFunction( + "test_function", + ImmutableList.of(p.variable("a")), + ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of())) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a")))); + } + + @Test + public void testSingleSourceWithRowSemantics() + { + // no pass-through columns + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableFunctionNode.TableArgumentProperties( + "table_argument", + true, + true, + new TableFunctionNode.PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"))), + values("c"))); + + // pass-through columns + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new 
TableFunctionNode.TableArgumentProperties( + "table_argument", + true, + true, + new TableFunctionNode.PassThroughSpecification(true, ImmutableList.of(new TableFunctionNode.PassThroughColumn(c, false))), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"))), + values("c"))); + } + + @Test + public void testSingleSourceWithSetSemantics() + { + // no pass-through columns, no partition by + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_LAST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"))) + .specification(specification(ImmutableList.of(), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))), + values("c", "d"))); + + // no pass-through columns, partitioning column present + tester().assertThat(new 
TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new TableFunctionNode.PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_LAST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"))) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))), + values("c", "d"))); + + // pass-through columns + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, false))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty())))), + 
ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c", "d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"))) + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())), + values("c", "d"))); + } + + @Test + public void testTwoSourcesWithSetSemantics() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.empty())))), + ImmutableList.of()); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", 
+ "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + 
.specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testThreeSourcesWithSetSemantics() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression g = p.variable("g"); + VariableReferenceExpression h = p.variable("h"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f), + p.values(g, h)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(h), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(h, DESC_NULLS_FIRST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + 
.passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"), ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"), ImmutableList.of("f"), ImmutableList.of("h"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2", + "g", "marker_3", + "h", "marker_3")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("combined_row_number_1_2 = input_3_row_number OR " + + "(combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR " + + "input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1')"), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", 
expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f"))))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of("h"), ImmutableMap.of("h", DESC_NULLS_FIRST))) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of("h"), ImmutableMap.of("h", DESC_NULLS_FIRST)) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("g", "h")))))))); + } + + @Test + public void 
testTwoCoPartitionedSources() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(f, DESC_NULLS_FIRST)))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", 
expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM e) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST)) + 
.addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testCoPartitionJoinTypes() + { + // both sources are prune when empty, so they are combined using inner join + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, 
input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d")))))))); + + // only the left source is prune 
when empty, so sources are combined using left join + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + 
ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR " + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d")))))))); + + // only the right source is prune when empty. the sources are reordered so that the prune when empty source is first. 
they are combined using left join + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + 
"combined_row_number", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), input_2_row_number, input_1_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), input_2_partition_size, input_1_partition_size)"), + "combined_partition_column", expression("COALESCE(d, c)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (d IS DISTINCT FROM c) " + + "AND (" + + " input_2_row_number = input_1_row_number OR" + + " (input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' OR" + + " input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c")))))))); + + // neither source is prune when empty, so sources are combined using full join + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + 
VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", 
expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.FULL, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d)" + + " AND (" + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d")))))))); + } + + @Test + public void testThreeCoPartitionedSources() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + return p.tableFunction( + 
"test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2", "input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2_3"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, 
input_3_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)"), + "combined_partition_column_1_2_3", expression("COALESCE(combined_partition_column_1_2, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (combined_partition_column_1_2 IS DISTINCT FROM e) " + + "AND (" + + " combined_row_number_1_2 = input_3_row_number OR" + + " (combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR" + + " input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1'))"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND (" + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + 
.specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("e")))))))); + } + + @Test + public void testTwoCoPartitionLists() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression g = p.variable("g"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + 
p.values(d), + p.values(e), + p.values(f, g)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty()))), + new TableArgumentProperties( + "input_4", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(f, true))), + ImmutableList.of(g), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(f), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(g, DESC_NULLS_FIRST)))))))), + ImmutableList.of( + ImmutableList.of("input_1", "input_2"), + ImmutableList.of("input_3", "input_4"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"), ImmutableList.of("f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"), ImmutableList.of("g"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3", + "f", "marker_4", + "g", "marker_4")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2", "combined_partition_column_3_4"), ImmutableList.of("combined_row_number_1_2_3_4"), 
ImmutableMap.of("combined_row_number_1_2_3_4", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3_4, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3_4, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3_4, input_3_row_number, null)"), + "marker_4", expression("IF(input_4_row_number = combined_row_number_1_2_3_4, input_4_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3_4", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(combined_row_number_3_4, BIGINT '-1'), combined_row_number_1_2, combined_row_number_3_4)"), + "combined_partition_size_1_2_3_4", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(combined_partition_size_3_4, BIGINT '-1'), combined_partition_size_1_2, combined_partition_size_3_4)")), + join(// join nodes using helper symbols + JoinType.LEFT, + ImmutableList.of(), + Optional.of("combined_row_number_1_2 = combined_row_number_3_4 OR " + + "(combined_row_number_1_2 > combined_partition_size_3_4 AND combined_row_number_3_4 = BIGINT '1' OR " + + "combined_row_number_3_4 > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1')"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + 
ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))))), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_3_4", expression("IF(COALESCE(input_3_row_number, BIGINT '-1') > COALESCE(input_4_row_number, BIGINT '-1'), input_3_row_number, input_4_row_number)"), + "combined_partition_size_3_4", expression("IF(COALESCE(input_3_partition_size, BIGINT '-1') > COALESCE(input_4_partition_size, BIGINT '-1'), input_3_partition_size, input_4_partition_size)"), + "combined_partition_column_3_4", expression("COALESCE(e, f)")), + join(// co-partition nodes + JoinType.FULL, + ImmutableList.of(), + Optional.of("NOT (e IS DISTINCT FROM f) " + + "AND ( " + + "input_3_row_number = input_4_row_number OR " + + "(input_3_row_number > 
input_4_partition_size AND input_4_row_number = BIGINT '1' OR " + + "input_4_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("e"))), + window(// append helper symbols for source input_4 + builder -> builder + .specification(specification(ImmutableList.of("f"), ImmutableList.of("g"), ImmutableMap.of("g", DESC_NULLS_FIRST))) + .addFunction("input_4_partition_size", functionCall("count", ImmutableList.of())), + // input_4 + window(builder -> builder + .specification(ImmutableList.of("f"), ImmutableList.of("g"), ImmutableMap.of("g", DESC_NULLS_FIRST)) + .addFunction("input_4_row_number", functionCall("row_number", ImmutableList.of())), + values("f", "g")))))))))); + } + + @Test + public void testCoPartitionedAndNotCoPartitionedSources() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new 
DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_2", "input_3"))); + }) + /* Expected plan: input_2 and input_3 (the only co-partitioned pair) are combined first by a LEFT co-partitioning join; that result is then INNER-joined with input_1 on the row-number helper symbols, and the processor is partitioned by combined_partition_column_2_3 and c. */ .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_2_3", "c"), ImmutableList.of("combined_row_number_2_3_1"), ImmutableMap.of("combined_row_number_2_3_1", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_2_3_1, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_2_3_1, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_2_3_1, input_3_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_2_3_1", expression("IF(COALESCE(combined_row_number_2_3, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), combined_row_number_2_3, input_1_row_number)"), + "combined_partition_size_2_3_1", 
expression("IF(COALESCE(combined_partition_size_2_3, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), combined_partition_size_2_3, input_1_partition_size)")), + join(// join nodes using helper symbols + JoinType.INNER, + ImmutableList.of(), + Optional.of("combined_row_number_2_3 = input_1_row_number OR " + + "(combined_row_number_2_3 > input_1_partition_size AND input_1_row_number = BIGINT '1' OR " + + "input_1_row_number > combined_partition_size_2_3 AND combined_row_number_2_3 = BIGINT '1')"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_2_3", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), input_2_row_number, input_3_row_number)"), + "combined_partition_size_2_3", expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), input_2_partition_size, input_3_partition_size)"), + "combined_partition_column_2_3", expression("COALESCE(d, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (d IS DISTINCT FROM e) " + + "AND ( " + + "input_2_row_number = input_3_row_number OR " + + "(input_2_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1' OR " + + "input_3_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))), + window(// append helper symbols for source input_3 + builder -> builder + 
.specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("e"))))), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c")))))))); + } + + @Test + public void testCoerceForCopartitioning() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c", TINYINT); + VariableReferenceExpression cCoerced = p.variable("c_coerced", INTEGER); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e", INTEGER); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + // coerce column c for co-partitioning + p.project( + Assignments.builder() + .put(c, p.rowExpression("c")) + .put(d, p.rowExpression("d")) + .put(cCoerced, p.rowExpression("CAST(c AS INTEGER)")) + .build(), + p.values(c, d)), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + 
Optional.of(new DataOrganizationSpecification(ImmutableList.of(cCoerced), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(f, DESC_NULLS_FIRST)))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + /* Expected plan: the two sources are co-partitioned on c_coerced and e with a LEFT join and a single combined partition column COALESCE(c_coerced, e); c_coerced is mapped to marker_1 together with c and d, but it is not listed among the pass-through or required symbols. */ .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "c_coerced", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c_coerced, e)")), + join(// co-partition nodes + JoinType.LEFT, + 
ImmutableList.of(), + Optional.of("NOT (c_coerced IS DISTINCT FROM e) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c_coerced"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c_coerced"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + project( + ImmutableMap.of("c_coerced", expression("CAST(c AS INTEGER)")), + values("c", "d")))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST)) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testTwoCoPartitioningColumns() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + 
p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c, d), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e, f), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + /* Expected plan: each source is partitioned by two columns, so the processor is partitioned by two combined columns (combined_partition_column_1 and combined_partition_column_2), one per column pair, and ordered by combined_row_number. */ .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c", "d"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column_1", "combined_partition_column_2"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > 
COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1", expression("COALESCE(c, e)"), + "combined_partition_column_2", expression("COALESCE(d, f)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM e) " + + "AND NOT (d IS DISTINCT FROM f) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c", "d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c", "d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e", "f"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("e", "f"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testTwoSourcesWithRowAndSetSemantics() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + 
VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + true, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(e), + Optional.empty())), + ImmutableList.of()); + }) + /* Expected plan: the co-partitioning list is empty and input_2 has no DataOrganizationSpecification, so the processor is partitioned by c alone (the partitioning column of input_1) and ordered by combined_row_number. NOTE(review): the first boolean of TableArgumentProperties is presumably the row-semantics flag, which would make input_2 the row-semantics source - confirm against the constructor. */ .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), 
input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 9329b3326441f..73f200c33c3a6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -20,6 +20,7 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; +import 
com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.IndexHandle; @@ -29,6 +30,7 @@ import com.facebook.presto.spi.connector.RowChangeParadigm; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.AggregationNode.Step; @@ -92,6 +94,8 @@ import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; import com.facebook.presto.sql.tree.Expression; @@ -1006,6 +1010,32 @@ public WindowNode window(DataOrganizationSpecification specification, Map properOutputs, + List sources, + List tableArgumentProperties, + List> copartitioningLists) + + { + return new TableFunctionNode( + idAllocator.getNextId(), + name, + ImmutableMap.of(), + properOutputs, + sources, + tableArgumentProperties, + copartitioningLists, + new TableFunctionHandle(new ConnectorId("connector_id"), new ConnectorTableFunctionHandle() {}, TestingTransactionHandle.create())); + } + /** Creates a TableFunctionProcessorNode for tests: a fresh TableFunctionProcessorBuilder is handed to the given consumer for configuration and then built with this PlanBuilder's idAllocator. */ + public TableFunctionProcessorNode tableFunctionProcessor(Consumer consumer) + { + TableFunctionProcessorBuilder tableFunctionProcessorBuilder = new TableFunctionProcessorBuilder(); + consumer.accept(tableFunctionProcessorBuilder); + return tableFunctionProcessorBuilder.build(idAllocator); + } + public RowNumberNode rowNumber(List partitionBy, 
Optional maxRowCountPerPartition, VariableReferenceExpression rowNumberVariable, PlanNode source) { return new RowNumberNode( diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java new file mode 100644 index 0000000000000..404831b10f0ef --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java @@ -0,0 +1,140 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule.test; + +import com.facebook.presto.metadata.TableFunctionHandle; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.testing.TestingTransactionHandle; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + /** Fluent test builder for {@link TableFunctionProcessorNode}: every setter stores its argument and returns this builder, and build() passes the stored values to the node constructor. Fields that are never set keep the defaults initialised below: empty collections and Optionals, false, 0, an anonymous ConnectorTableFunctionHandle, and a null name. */ +public class TableFunctionProcessorBuilder +{ + private String name; + private List properOutputs = ImmutableList.of(); + private Optional source = Optional.empty(); + private boolean pruneWhenEmpty; + private List passThroughSpecifications = ImmutableList.of(); + private List> requiredSymbols = ImmutableList.of(); + private Optional> markerSymbols = Optional.empty(); + private Optional specification = Optional.empty(); + private Set prePartitioned = ImmutableSet.of(); + private int preSorted; + private Optional hashSymbol = Optional.empty(); + private ConnectorTableFunctionHandle connectorHandle = new ConnectorTableFunctionHandle() {}; + + public TableFunctionProcessorBuilder() {} + + public TableFunctionProcessorBuilder name(String name) + { + this.name = name; + return this; + } + /** Replaces the proper outputs with an immutable copy of the given varargs. */ + public TableFunctionProcessorBuilder properOutputs(VariableReferenceExpression... 
properOutputs) + { + this.properOutputs = ImmutableList.copyOf(properOutputs); + return this; + } + + public TableFunctionProcessorBuilder source(PlanNode source) + { + this.source = Optional.of(source); + return this; + } + /** Turns the pruneWhenEmpty flag on; it stays false unless this method is called. */ + public TableFunctionProcessorBuilder pruneWhenEmpty() + { + this.pruneWhenEmpty = true; + return this; + } + + public TableFunctionProcessorBuilder passThroughSpecifications(PassThroughSpecification... passThroughSpecifications) + { + this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); + return this; + } + /** Stores the given list reference as-is; unlike properOutputs, no defensive copy is made. */ + public TableFunctionProcessorBuilder requiredSymbols(List> requiredSymbols) + { + this.requiredSymbols = requiredSymbols; + return this; + } + + public TableFunctionProcessorBuilder markerSymbols(Map markerSymbols) + { + this.markerSymbols = Optional.of(markerSymbols); + return this; + } + + public TableFunctionProcessorBuilder specification(DataOrganizationSpecification specification) + { + this.specification = Optional.of(specification); + return this; + } + /** Stores the given set reference as-is; no defensive copy is made. */ + public TableFunctionProcessorBuilder prePartitioned(Set prePartitioned) + { + this.prePartitioned = prePartitioned; + return this; + } + + public TableFunctionProcessorBuilder preSorted(int preSorted) + { + this.preSorted = preSorted; + return this; + } + + public TableFunctionProcessorBuilder hashSymbol(VariableReferenceExpression hashSymbol) + { + this.hashSymbol = Optional.of(hashSymbol); + return this; + } + + public TableFunctionProcessorBuilder connectorHandle(ConnectorTableFunctionHandle connectorHandle) + { + this.connectorHandle = connectorHandle; + return this; + } + /** Builds the node with an id taken from idAllocator, passing every stored field through unchanged; the connector handle is wrapped in a TableFunctionHandle with the fixed connector id connector_id and a newly created TestingTransactionHandle. */ + public TableFunctionProcessorNode build(PlanNodeIdAllocator idAllocator) + { + return new TableFunctionProcessorNode( + idAllocator.getNextId(), + name, + properOutputs, + source, + pruneWhenEmpty, + passThroughSpecifications, + requiredSymbols, + markerSymbols, + specification, + prePartitioned, + preSorted, + hashSymbol, + new TableFunctionHandle(new 
ConnectorId("connector_id"), connectorHandle, TestingTransactionHandle.create())); + } +} diff --git a/presto-native-tests/src/test/java/com/facebook/presto/nativetests/operator/scalar/AbstractTestNativeFunctions.java b/presto-native-tests/src/test/java/com/facebook/presto/nativetests/operator/scalar/AbstractTestNativeFunctions.java index 434a1414ea1f0..6c99067f3f399 100644 --- a/presto-native-tests/src/test/java/com/facebook/presto/nativetests/operator/scalar/AbstractTestNativeFunctions.java +++ b/presto-native-tests/src/test/java/com/facebook/presto/nativetests/operator/scalar/AbstractTestNativeFunctions.java @@ -76,7 +76,7 @@ public void assertNotSupported(String projection, @Language("RegExp") String mes fail("expected exception"); } catch (RuntimeException ex) { - assertExceptionMessage(rewritten, ex, message, true); + assertExceptionMessage(rewritten, ex, message, true, false); } } diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java index 2a1b5b8828d13..4a4e91eb576c6 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java @@ -1689,7 +1689,7 @@ public Node visitDescriptorArgument(SqlBaseParser.DescriptorArgumentContext cont @Override public Node visitDescriptorField(SqlBaseParser.DescriptorFieldContext context) { - return new DescriptorField(getLocation(context), (Identifier) visit(context.identifier()), Optional.of(getType(context.type()))); + return new DescriptorField(getLocation(context), (Identifier) visit(context.identifier()), Optional.ofNullable(context.type()).map(this::getType)); } /** diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java index 593538ad6a242..9853b3a5c8289 100644 --- 
a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java @@ -223,7 +223,7 @@ private JavaPairRDD cre Optional taskSourceRdd; List sources = findTableScanNodes(fragment.getRoot()); if (!sources.isEmpty()) { - try (CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager::getSplits)) { + try (CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager)) { SplitSourceFactory splitSourceFactory = new SplitSourceFactory(splitSourceProvider, WarningCollector.NOOP); Map splitSources = splitSourceFactory.createSplitSources(fragment, session, tableWriteInfo); taskSourceRdd = Optional.of(createTaskSourcesRdd( diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java b/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java index 55bee3238e34d..5dca9ab76d47c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; public interface ConnectorHandleResolver { @@ -65,4 +66,9 @@ default Class getTransactionHandleClass() { throw new UnsupportedOperationException(); } + /** Returns the class of this connector's table function handle; the default implementation throws UnsupportedOperationException, like getTransactionHandleClass() above. NOTE(review): presumably a ConnectorTableFunctionHandle subtype - the generic bound is not visible here, confirm. */ + default Class getTableFunctionHandleClass() + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/Connector.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/Connector.java index d2d0ce779a77e..5fb86e8926ff0 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/Connector.java +++ 
b/presto-spi/src/main/java/com/facebook/presto/spi/connector/Connector.java @@ -15,6 +15,8 @@ import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.procedure.DistributedProcedure; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.session.PropertyMetadata; @@ -22,6 +24,7 @@ import java.util.List; import java.util.Set; +import java.util.function.Function; import static com.facebook.presto.spi.connector.EmptyConnectorCommitHandle.INSTANCE; import static java.util.Collections.emptyList; @@ -127,6 +130,16 @@ default Set getTableFunctions() return emptySet(); } + /** + * @return the table function processor provider for the connector + */ + default Function getTableFunctionProcessorProvider() + { + return handle -> { + throw new UnsupportedOperationException(); + }; + } + /** * @return the set of functions provided by this connector */ diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorCodecProvider.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorCodecProvider.java index 98e71be2e266d..60172b3ce9cec 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorCodecProvider.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorCodecProvider.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import java.util.Optional; @@ -84,6 +85,11 @@ default Optional> getConnectorIndexHandleCo return Optional.empty(); } + default Optional> getConnectorTableFunctionHandleCodec() + { + return Optional.empty(); 
+ } + default Optional> getConnectorDistributedProcedureHandleCodec() { return Optional.empty(); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorFactory.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorFactory.java index 07b36b4dca528..22e658686adee 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorFactory.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorFactory.java @@ -14,8 +14,14 @@ package com.facebook.presto.spi.connector; import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import java.util.Map; +import java.util.Optional; +import java.util.function.Function; public interface ConnectorFactory { @@ -24,4 +30,19 @@ public interface ConnectorFactory ConnectorHandleResolver getHandleResolver(); Connector create(String catalogName, Map config, ConnectorContext context); + + default Function getTableFunctionProcessorProvider() + { + return null; + } + + default Optional getTableFunctionHandleResolver() + { + return Optional.empty(); + } + + default Optional getTableFunctionSplitResolver() + { + return Optional.empty(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorSplitManager.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorSplitManager.java index 69ac79c9f7522..8736640c5f41d 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorSplitManager.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorSplitManager.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.ConnectorSplitSource; import 
com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import static java.util.Objects.requireNonNull; @@ -71,4 +72,12 @@ public WarningCollector getWarningCollector() return warningCollector; } } + + default ConnectorSplitSource getSplits( + ConnectorTransactionHandle transaction, + ConnectorSession session, + ConnectorTableFunctionHandle function) + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorSplitManager.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorSplitManager.java index 4efb85e07c088..815ba36ff4974 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorSplitManager.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorSplitManager.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import static java.util.Objects.requireNonNull; @@ -41,4 +42,12 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHand return delegate.getSplits(transactionHandle, session, layout, splitSchedulingContext); } } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorTableFunctionHandle function) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getSplits(transaction, session, function); + } + } } diff --git 
a/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionHandleResolver.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionHandleResolver.java new file mode 100644 index 0000000000000..fd24b9c694c50 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionHandleResolver.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; + +import java.util.Set; + +public interface TableFunctionHandleResolver +{ + Set> getTableFunctionHandleClasses(); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionSplitResolver.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionSplitResolver.java new file mode 100644 index 0000000000000..2a31b1a9aa113 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionSplitResolver.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import com.facebook.presto.spi.ConnectorSplit; + +import java.util.Set; + +public interface TableFunctionSplitResolver +{ + Set> getTableFunctionSplitClasses(); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionDataProcessor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionDataProcessor.java new file mode 100644 index 0000000000000..a9884ab826c14 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionDataProcessor.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.common.Page; + +import java.util.List; +import java.util.Optional; + +public interface TableFunctionDataProcessor +{ + /** + * This method processes a portion of data. It is called multiple times until the partition is fully processed. 
+ * + * @param input a tuple of {@link Page} including one page for each table function's input table. + * Pages list is ordered according to the corresponding argument specifications in {@link ConnectorTableFunction}. + * A page for an argument consists of columns requested during analysis (see {@link TableFunctionAnalysis#getRequiredColumns()}). + * If any of the sources is fully processed, {@code Optional.empty()} is returned for that source. + * If all sources are fully processed, the argument is {@code null}. + * @return {@link TableFunctionProcessorState} including the processor's state and optionally a portion of result. + * After the returned state is {@code FINISHED}, the method will not be called again. + */ + TableFunctionProcessorState process(List> input); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorProvider.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorProvider.java new file mode 100644 index 0000000000000..556e3828eb79c --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorProvider.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +public interface TableFunctionProcessorProvider +{ + /** + * This method returns a {@code TableFunctionDataProcessor}. 
All the necessary information collected during analysis is available + * in the form of {@link ConnectorTableFunctionHandle}. It is called once per each partition processed by the table function. + */ + default TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + throw new UnsupportedOperationException("this table function does not process input data"); + } + + /** + * This method returns a {@code TableFunctionSplitProcessor}. All the necessary information collected during analysis is available + * in the form of {@link ConnectorTableFunctionHandle}. It is called once per each split processed by the table function. + */ + default TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + throw new UnsupportedOperationException("this table function does not process splits"); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorState.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorState.java new file mode 100644 index 0000000000000..f12620e203d84 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorState.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.common.Page; +import jakarta.annotation.Nullable; + +import java.util.concurrent.CompletableFuture; + +import static java.util.Objects.requireNonNull; + +/** + * The result of processing input by {@link TableFunctionDataProcessor} or {@link TableFunctionSplitProcessor}. + * It can optionally include a portion of output data in the form of {@link Page} + * The returned {@link Page} should consist of: + * - proper columns produced by the table function + * - one column of type {@code BIGINT} for each table function's input table having the pass-through property (see {@link TableArgumentSpecification#isPassThroughColumns}), + * in order of the corresponding argument specifications. Entries in these columns are the indexes of input rows (from partition start) to be attached to output, + * or null to indicate that a row of nulls should be attached instead of an input row. The indexes are validated to be within the portion of the partition + * provided to the function so far. + * Note: when the input is empty, the only valid index value is null, because there are no input rows that could be attached to output. In such case, for performance + * reasons, the validation of indexes is skipped, and all pass-through columns are filled with nulls. 
+ */ +public interface TableFunctionProcessorState +{ + final class Blocked + implements TableFunctionProcessorState + { + private final CompletableFuture future; + + private Blocked(CompletableFuture future) + { + this.future = requireNonNull(future, "future is null"); + } + + public static Blocked blocked(CompletableFuture future) + { + return new Blocked(future); + } + + public CompletableFuture getFuture() + { + return future; + } + } + + final class Finished + implements TableFunctionProcessorState + { + public static final Finished FINISHED = new Finished(); + + private Finished() {} + } + + final class Processed + implements TableFunctionProcessorState + { + // Represents that the page has been consumed, and it may be released. + private final boolean usedInput; + private final Page result; + + private Processed(boolean usedInput, @Nullable Page result) + { + this.usedInput = usedInput; + this.result = result; + } + + public static Processed usedInput() + { + return new Processed(true, null); + } + + public static Processed produced(Page result) + { + requireNonNull(result, "result is null"); + return new Processed(false, result); + } + + public static Processed usedInputAndProduced(Page result) + { + requireNonNull(result, "result is null"); + return new Processed(true, result); + } + + public boolean isUsedInput() + { + return usedInput; + } + + public Page getResult() + { + return result; + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionSplitProcessor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionSplitProcessor.java new file mode 100644 index 0000000000000..ba4a07e470d75 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionSplitProcessor.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.spi.ConnectorSplit; +import jakarta.annotation.Nullable; + +/** + * Processes a single split for a table function. Each {@code TableFunctionSplitProcessor} instance + * is associated with exactly one split and is responsible for processing that split to completion. + * The {@link #process(ConnectorSplit)} method is called repeatedly until the processor returns + * {@link TableFunctionProcessorState.Finished}, at which point the split is considered fully processed. + */ +public interface TableFunctionSplitProcessor +{ + /** + * This method processes a split. It is called multiple times until the whole output for the split is produced. + * + * @param split a {@link ConnectorSplit} representing a subtask, or {@code null} when the table function + * has the KEEP WHEN EMPTY property and all table arguments are empty relations. In this case, + * the function is executed once with no input to produce output for the empty partition. + * @return {@link TableFunctionProcessorState} including the processor's state and optionally a portion of result. + * After the returned state is {@code FINISHED}, the method will not be called again. 
+ */ + TableFunctionProcessorState process(@Nullable ConnectorSplit split); +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index 3553f6c070a5c..44b1c60c82956 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -324,6 +324,11 @@ protected void assertQueryFails(@Language("SQL") String sql, @Language("RegExp") QueryAssertions.assertQueryFails(queryRunner, getSession(), sql, expectedMessageRegExp); } + protected void assertQueryFailsExact(@Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) + { + QueryAssertions.assertQueryFails(queryRunner, getSession(), sql, expectedMessageRegExp, false, true); + } + protected void assertQueryFails(QueryRunner queryRunner, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) { QueryAssertions.assertQueryFails(queryRunner, getSession(), sql, expectedMessageRegExp); @@ -331,7 +336,7 @@ protected void assertQueryFails(QueryRunner queryRunner, @Language("SQL") String protected void assertQueryFails(@Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp, boolean usePatternMatcher) { - QueryAssertions.assertQueryFails(queryRunner, getSession(), sql, expectedMessageRegExp, usePatternMatcher); + QueryAssertions.assertQueryFails(queryRunner, getSession(), sql, expectedMessageRegExp, usePatternMatcher, false); } protected void assertQueryFails(Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) @@ -356,7 +361,7 @@ protected void assertQueryError(@Language("SQL") String sql, @Language("RegExp") protected void assertQueryFails(Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp, boolean usePatternMatcher) { - 
QueryAssertions.assertQueryFails(queryRunner, session, sql, expectedMessageRegExp, usePatternMatcher); + QueryAssertions.assertQueryFails(queryRunner, session, sql, expectedMessageRegExp, usePatternMatcher, false); } protected void assertQueryReturnsEmptyResult(@Language("SQL") String sql) diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java b/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java index 207a52ed95bce..ed78d28aaf6ed 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java @@ -381,18 +381,18 @@ protected static void assertQueryFails(QueryRunner queryRunner, Session session, fail(format("Expected query to fail: %s", sql)); } catch (RuntimeException ex) { - assertExceptionMessage(sql, ex, expectedMessageRegExp, false); + assertExceptionMessage(sql, ex, expectedMessageRegExp, false, false); } } - protected static void assertQueryFails(QueryRunner queryRunner, Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp, boolean usePatternMatcher) + protected static void assertQueryFails(QueryRunner queryRunner, Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp, boolean usePatternMatcher, boolean exact) { try { queryRunner.execute(session, sql); fail(format("Expected query to fail: %s", sql)); } catch (RuntimeException ex) { - assertExceptionMessage(sql, ex, expectedMessageRegExp, usePatternMatcher); + assertExceptionMessage(sql, ex, expectedMessageRegExp, usePatternMatcher, exact); } } @@ -408,7 +408,7 @@ protected static void assertQueryReturnsEmptyResult(QueryRunner queryRunner, Ses } } - public static void assertExceptionMessage(String sql, Exception exception, @Language("RegExp") String regex, boolean usePatternMatcher) + public static void assertExceptionMessage(String sql, Exception exception, 
@Language("RegExp") String regex, boolean usePatternMatcher, boolean exact) { if (usePatternMatcher) { Pattern p = Pattern.compile(regex, Pattern.MULTILINE); @@ -417,7 +417,7 @@ public static void assertExceptionMessage(String sql, Exception exception, @Lang } } else { - if (!nullToEmpty(exception.getMessage()).matches(regex)) { + if (!(exact ? nullToEmpty(exception.getMessage()).equals(regex) : nullToEmpty(exception.getMessage()).matches(regex))) { fail(format("Expected exception message '%s' to match '%s' for query: %s", exception.getMessage(), regex, sql), exception); } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/TestExcludeColumnsFunction.java b/presto-tests/src/main/java/com/facebook/presto/tests/TestExcludeColumnsFunction.java new file mode 100644 index 0000000000000..f55c66626cefe --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/TestExcludeColumnsFunction.java @@ -0,0 +1,200 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.tests; + +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tpch.TpchPlugin; +import org.testng.annotations.Test; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; + +public class TestExcludeColumnsFunction + extends AbstractTestQueryFramework +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(testSessionBuilder().build()).build(); + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch"); + return queryRunner; + } + + @Override + protected QueryRunner createExpectedQueryRunner() + throws Exception + { + DistributedQueryRunner result = DistributedQueryRunner.builder(testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .build()) + .build(); + result.installPlugin(new TpchPlugin()); + result.createCatalog("tpch", "tpch"); + return result; + } + + @Test + public void testExcludeColumnsFunction() + { + assertQuery("SELECT * FROM tpch.tiny.nation", + "SELECT nationkey, name, regionkey, comment FROM tpch.tiny.nation"); + + assertQuery("SELECT * " + + "FROM TABLE(exclude_columns( " + + " input => TABLE(tpch.tiny.nation)," + + " columns => DESCRIPTOR(comment)))", + "SELECT nationkey, name, regionkey FROM tpch.tiny.nation"); + + assertQuery("SELECT * " + + "FROM TABLE(exclude_columns( " + + " input => TABLE(tpch.tiny.nation), " + + " columns => DESCRIPTOR(regionkey, nationkey)))", + "SELECT name, comment FROM tpch.tiny.nation"); + } + + @Test + public void testInvalidArgument() + { + assertQueryFails("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => CAST(null AS DESCRIPTOR)))\n", + "COLUMNS descriptor is null"); + + assertQueryFailsExact("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => 
TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR()))\n", + "line 4:21: Invalid descriptor argument COLUMNS. Descriptors should be formatted as 'DESCRIPTOR(name [type], ...)'"); + + assertQueryFailsExact("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(foo, comment, bar)))\n", + "Excluded columns: [foo, bar] not present in the table"); + + assertQueryFails("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(nationkey bigint, comment)))\n", + "COLUMNS descriptor contains types"); + + assertQueryFails("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(nationkey, name, regionkey, comment)))\n", + "All columns are excluded"); + } + + @Test + public void testColumnResolution() + { + // excluded column names are matched case-insensitive + assertQuery("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(SELECT 1, 2, 3, 4, 5) t(a, B, \"c\", \"D\", e),\n" + + " columns => DESCRIPTOR(\"A\", \"b\", C, d)))\n", + "SELECT 5"); + } + + @Test + public void testReturnedColumnNames() + { + // the function preserves the incoming column names. (However, due to how the analyzer handles identifiers, these are not the canonical names according to the SQL identifier semantics.) 
+ assertQuery("SELECT a, b, c, d\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(SELECT 1, 2, 3, 4, 5) t(a, B, \"c\", \"D\", e),\n" + + " columns => DESCRIPTOR(e)))\n", + "SELECT 1, 2, 3, 4"); + } + + @Test + public void testHiddenColumn() + { + assertQuery("SELECT row_number FROM tpch.tiny.region", + "SELECT * FROM UNNEST(sequence(0, 4))"); + + // the hidden column is not provided to the function + assertQueryFails("SELECT row_number\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(comment)))\n", + "line 1:8: Column 'row_number' cannot be resolved"); + + assertQueryFailsExact("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(row_number)))\n", + "Excluded columns: [row_number] not present in the table"); + } + + @Test + public void testAnonymousColumn() + { + // cannot exclude unnamed columns. the unnamed columns are passed on unnamed. + assertQuery("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(SELECT 1 a, 2, 3 c, 4),\n" + + " columns => DESCRIPTOR(a, c)))\n", + "SELECT 2, 4"); + } + + @Test + public void testDuplicateExcludedColumn() + { + // duplicates in excluded column names are allowed + assertQuery("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(comment, name, comment)))\n", + "SELECT nationkey, regionkey FROM tpch.tiny.nation"); + } + + @Test + public void testDuplicateInputColumn() + { + // all input columns with given name are excluded + assertQuery("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(SELECT 1, 2, 3, 4, 5) t(a, b, c, a, b),\n" + + " columns => DESCRIPTOR(a, b)))\n", + "SELECT 3"); + } + + @Test + public void testFunctionResolution() + { + assertQuery("SELECT *\n" + + "FROM TABLE(system.builtin.exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => 
DESCRIPTOR(comment)))\n", + "SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(comment)))\n"); + } + + @Test + public void testBigInput() + { + assertQuery("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.orders),\n" + + " columns => DESCRIPTOR(orderstatus, orderdate, orderpriority, clerk, shippriority, comment)))\n", + "SELECT orderkey, custkey, totalprice FROM tpch.tiny.orders"); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/TestSequenceFunction.java b/presto-tests/src/main/java/com/facebook/presto/tests/TestSequenceFunction.java new file mode 100644 index 0000000000000..dd6bc1b8946aa --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/TestSequenceFunction.java @@ -0,0 +1,291 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.tests; + +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tpch.TpchPlugin; +import org.testng.annotations.Test; + +import static com.facebook.presto.operator.table.Sequence.SequenceFunctionSplit.DEFAULT_SPLIT_SIZE; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static java.lang.String.format; + +public class TestSequenceFunction + extends AbstractTestQueryFramework +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(testSessionBuilder().build()).build(); + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch"); + return queryRunner; + } + + @Override + protected QueryRunner createExpectedQueryRunner() + throws Exception + { + DistributedQueryRunner result = DistributedQueryRunner.builder(testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .build()) + .build(); + result.installPlugin(new TpchPlugin()); + result.createCatalog("tpch", "tpch"); + return result; + } + + @Test + public void testSequence() + { + assertQuery("SELECT * FROM TABLE(sequence(0, 8000, 3))", + "SELECT * FROM UNNEST(sequence(0, 8000, 3))"); + + assertQuery("SELECT * FROM TABLE(sequence(1, 10, 3))", + "VALUES BIGINT '1', 4, 7, 10"); + + assertQuery("SELECT * FROM TABLE(sequence(1, 10, 6))", + "VALUES BIGINT '1', 7"); + + assertQuery("SELECT * FROM TABLE(sequence(-1, -10, -3))", + "VALUES BIGINT '-1', -4, -7, -10"); + + assertQuery("SELECT * FROM TABLE(sequence(-1, -10, -6))", + "VALUES BIGINT '-1', -7"); + + assertQuery("SELECT * FROM TABLE(sequence(-5, 5, 3))", + "VALUES BIGINT '-5', -2, 1, 4"); + + assertQuery("SELECT * FROM TABLE(sequence(5, -5, -3))", + "VALUES BIGINT '5', 2, -1, -4"); + + assertQuery("SELECT * FROM TABLE(sequence(0, 10, 3))", + "VALUES BIGINT '0', 3, 
6, 9"); + + assertQuery("SELECT * FROM TABLE(sequence(0, -10, -3))", + "VALUES BIGINT '0', -3, -6, -9"); + } + + @Test + public void testDefaultArguments() + { + assertQuery("SELECT * FROM TABLE(sequence(stop => 10))", + "SELECT * FROM UNNEST(sequence(0, 10, 1))"); + } + + @Test + public void testInvalidArgument() + { + assertQueryFailsExact("SELECT * " + + "FROM TABLE(sequence( " + + " start => -5," + + " stop => 10," + + " step => -2))", + "Step must be positive for sequence [-5, 10]"); + + assertQueryFailsExact("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => -5," + + " step => 2))", + "Step must be negative for sequence [10, -5]"); + + assertQueryFailsExact("SELECT * " + + "FROM TABLE(sequence(" + + " start => null," + + " stop => -5," + + " step => 2))", + "Start is null"); + + assertQueryFailsExact("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => null," + + " step => 2))", + "Stop is null"); + + assertQueryFailsExact("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => -5," + + " step => null))", + "Step is null"); + } + + @Test + public void testSingletonSequence() + { + assertQuery("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => 10," + + " step => 2))", + "VALUES BIGINT '10'"); + + assertQuery("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => 10," + + " step => -2))", + "VALUES BIGINT '10'"); + + assertQuery("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => 10," + + " step => 0))", + "VALUES BIGINT '10'"); + } + + @Test + public void testBigStep() + { + assertQuery(format("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => -5," + + " step => %s))", + Long.MIN_VALUE / (DEFAULT_SPLIT_SIZE - 1)), "VALUES BIGINT '10'"); + + assertQuery(format("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => -5," + + " step => %s))", + Long.MIN_VALUE / (DEFAULT_SPLIT_SIZE - 1) - 1), + "VALUES 
BIGINT '10'"); + + assertQuery(format("SELECT DISTINCT x - lag(x, 1) OVER(ORDER BY x DESC) \n" + + "FROM TABLE(sequence(\n" + + " start => %s,\n" + + " stop => BIGINT '%s',\n" + + " step => %s)) t(x)", + Long.MAX_VALUE, Long.MIN_VALUE, Long.MIN_VALUE / (DEFAULT_SPLIT_SIZE - 1) - 1), + format("VALUES (null), (%s)", Long.MIN_VALUE / (DEFAULT_SPLIT_SIZE - 1) - 1)); + + assertQuery(format("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => -5," + + " step => BIGINT '%s'))", Long.MIN_VALUE), + "VALUES BIGINT '10'"); + + assertQuery(format("SELECT * " + + "FROM TABLE(sequence(" + + " start => -5," + + " stop => 10," + + " step => %s))", Long.MAX_VALUE / (DEFAULT_SPLIT_SIZE - 1)), + "VALUES BIGINT '-5'"); + + assertQuery(format("SELECT * " + + "FROM TABLE(sequence(" + + " start => -5," + + " stop => 10," + + " step => %s))", Long.MAX_VALUE / (DEFAULT_SPLIT_SIZE - 1) + 1), + "VALUES BIGINT '-5'"); + + assertQuery(format("SELECT DISTINCT x - lag(x, 1) OVER(ORDER BY x) " + + "FROM TABLE(sequence(" + + " start => BIGINT '%s'," + + " stop => %s," + + " step => %s)) t(x)", + Long.MIN_VALUE, Long.MAX_VALUE, Long.MAX_VALUE / (DEFAULT_SPLIT_SIZE - 1) + 1), + format("VALUES (null), (%s)", Long.MAX_VALUE / (DEFAULT_SPLIT_SIZE - 1) + 1)); + + assertQuery(format("SELECT * " + + "FROM TABLE(sequence(" + + " start => -5," + + " stop => 10," + + " step => %s))", Long.MAX_VALUE), + "VALUES BIGINT '-5'"); + } + + @Test + public void testMultipleSplits() + { + long sequenceLength = DEFAULT_SPLIT_SIZE * 10 + DEFAULT_SPLIT_SIZE / 2; + long start = 10; + long step = 5; + long stop = start + (sequenceLength - 1) * step; + assertQuery(format("SELECT count(x), count(DISTINCT x), min(x), max(x) " + + "FROM TABLE(sequence( " + + " start => %s," + + " stop => %s," + + " step => %s)) t(x)", start, stop, step), + format("SELECT BIGINT '%s', BIGINT '%s', BIGINT '%s', BIGINT '%s'", sequenceLength, sequenceLength, start, stop)); + + sequenceLength = DEFAULT_SPLIT_SIZE * 4 + 
DEFAULT_SPLIT_SIZE / 2; + stop = start + (sequenceLength - 1) * step; + assertQuery(format("SELECT min(x), max(x) " + + "FROM TABLE(sequence(" + + " start => %s," + + " stop => %s," + + " step => %s)) t(x)", start, stop, step), + format("SELECT BIGINT '%s', BIGINT '%s'", start, stop)); + + step = -5; + stop = start + (sequenceLength - 1) * step; + assertQuery(format("SELECT max(x), min(x) " + + "FROM TABLE(sequence(" + + " start => %s," + + " stop => %s," + + " step => %s)) t(x)", start, stop, step), + format("SELECT BIGINT '%s', BIGINT '%s'", start, stop)); + } + + @Test + public void testEdgeValues() + { + long start = Long.MIN_VALUE + 15; + long stop = Long.MIN_VALUE + 3; + long step = -10; + assertQuery(format("SELECT * " + + "FROM TABLE(sequence( " + + " start => %s," + + " stop => %s," + + " step => %s))", start, stop, step), + format("VALUES (%s), (%s)", start, start + step)); + + start = Long.MIN_VALUE + 1 - (DEFAULT_SPLIT_SIZE - 1) * step; + stop = Long.MIN_VALUE + 1; + assertQuery(format("SELECT max(x), min(x) " + + "FROM TABLE(sequence( " + + " start => %s," + + " stop => %s," + + " step => %s)) t(x)", start, stop, step), + format("SELECT %s, %s", start, Long.MIN_VALUE + 1)); + + start = Long.MAX_VALUE - 15; + stop = Long.MAX_VALUE - 3; + step = 10; + assertQuery(format("SELECT * " + + "FROM TABLE(sequence( " + + " start => %s," + + " stop => %s," + + " step => %s))", start, stop, step), + format("VALUES (%s), (%s)", start, start + step)); + + start = Long.MAX_VALUE - 1 - (DEFAULT_SPLIT_SIZE - 1) * step; + stop = Long.MAX_VALUE - 1; + assertQuery(format("SELECT min(x), max(x) " + + "FROM TABLE(sequence(" + + " start => %s," + + " stop => %s," + + " step => %s)) t(x)", start, stop, step), + format("SELECT %s, %s", start, Long.MAX_VALUE - 1)); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/TestTableFunctionInvocation.java b/presto-tests/src/main/java/com/facebook/presto/tests/TestTableFunctionInvocation.java index 
067c750f09497..8f4bec9fe182a 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/TestTableFunctionInvocation.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/TestTableFunctionInvocation.java @@ -16,12 +16,17 @@ import com.facebook.presto.connector.tvf.TestTVFConnectorColumnHandle; import com.facebook.presto.connector.tvf.TestTVFConnectorFactory; import com.facebook.presto.connector.tvf.TestTVFConnectorPlugin; +import com.facebook.presto.connector.tvf.TestingTableFunctions; import com.facebook.presto.connector.tvf.TestingTableFunctions.SimpleTableFunction; import com.facebook.presto.connector.tvf.TestingTableFunctions.SimpleTableFunction.SimpleTableFunctionHandle; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.FixedSplitSource; import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.SchemaFunctionName; import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tpch.TpchPlugin; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -32,7 +37,10 @@ import java.util.stream.IntStream; import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.connector.tvf.TestTVFConnectorFactory.TestTVFConnector.TestTVFConnectorSplit.TEST_TVF_CONNECTOR_SPLIT; +import static com.facebook.presto.connector.tvf.TestingTableFunctions.ConstantFunction.getConstantFunctionSplitSource; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; import static com.google.common.collect.ImmutableMap.toImmutableMap; public class TestTableFunctionInvocation @@ -52,6 +60,20 @@ protected QueryRunner createQueryRunner() .build(); } + @Override + 
protected QueryRunner createExpectedQueryRunner() + throws Exception + { + DistributedQueryRunner result = DistributedQueryRunner.builder(testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .build()) + .build(); + result.installPlugin(new TpchPlugin()); + result.createCatalog("tpch", "tpch"); + return result; + } + @BeforeClass public void setUp() { @@ -63,16 +85,79 @@ public void setUp() .collect(toImmutableMap(column -> column, column -> new TestTVFConnectorColumnHandle(column, createUnboundedVarcharType()) {})); queryRunner.installPlugin(new TestTVFConnectorPlugin(TestTVFConnectorFactory.builder() - .withTableFunctions(ImmutableSet.of(new SimpleTableFunction())) + .withTableFunctions(ImmutableSet.of(new SimpleTableFunction(), + new TestingTableFunctions.IdentityFunction(), + new TestingTableFunctions.IdentityPassThroughFunction(), + new TestingTableFunctions.RepeatFunction(), + new TestingTableFunctions.EmptyOutputFunction(), + new TestingTableFunctions.EmptyOutputWithPassThroughFunction(), + new TestingTableFunctions.EmptySourceFunction(), + new TestingTableFunctions.TestInputsFunction(), + new TestingTableFunctions.PassThroughInputFunction(), + new TestingTableFunctions.TestInputFunction(), + new TestingTableFunctions.TestSingleInputRowSemanticsFunction(), + new TestingTableFunctions.ConstantFunction())) .withApplyTableFunction((session, handle) -> { if (handle instanceof SimpleTableFunctionHandle) { SimpleTableFunctionHandle functionHandle = (SimpleTableFunctionHandle) handle; return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), functionHandle.getTableHandle().getColumns().orElseThrow(() -> new IllegalStateException("Columns are missing")))); } - throw new IllegalStateException("Unsupported table function handle: " + handle.getClass().getSimpleName()); - }).withGetColumnHandles(getColumnHandles) + return Optional.empty(); + }) + .withGetColumnHandles(getColumnHandles) + 
.withTableFunctionProcessorProvider( + connectorTableFunctionHandle -> { + if (connectorTableFunctionHandle instanceof TestingTableFunctions.TestingTableFunctionHandle) { + switch (((TestingTableFunctions.TestingTableFunctionHandle) connectorTableFunctionHandle).getSchemaFunctionName().getFunctionName()) { + case "identity_function": + return new TestingTableFunctions.IdentityFunction.IdentityFunctionProcessorProvider(); + case "identity_pass_through_function": + return new TestingTableFunctions.IdentityPassThroughFunction.IdentityPassThroughFunctionProcessorProvider(); + case "empty_output": + return new TestingTableFunctions.EmptyOutputFunction.EmptyOutputProcessorProvider(); + case "empty_output_with_pass_through": + return new TestingTableFunctions.EmptyOutputWithPassThroughFunction.EmptyOutputWithPassThroughProcessorProvider(); + case "empty_source": + return new TestingTableFunctions.EmptySourceFunction.EmptySourceFunctionProcessorProvider(); + case "test_inputs_function": + return new TestingTableFunctions.TestInputsFunction.TestInputsFunctionProcessorProvider(); + case "pass_through": + return new TestingTableFunctions.PassThroughInputFunction.PassThroughInputProcessorProvider(); + case "test_input": + return new TestingTableFunctions.TestInputFunction.TestInputProcessorProvider(); + case "test_single_input_function": + return new TestingTableFunctions.TestSingleInputRowSemanticsFunction.TestSingleInputFunctionProcessorProvider(); + default: + throw new IllegalArgumentException("unexpected table function: " + ((TestingTableFunctions.TestingTableFunctionHandle) connectorTableFunctionHandle).getSchemaFunctionName()); + } + } + else if (connectorTableFunctionHandle instanceof TestingTableFunctions.RepeatFunction.RepeatFunctionHandle) { + return new TestingTableFunctions.RepeatFunction.RepeatFunctionProcessorProvider(); + } + else if (connectorTableFunctionHandle instanceof TestingTableFunctions.ConstantFunction.ConstantFunctionHandle) { + return new 
TestingTableFunctions.ConstantFunction.ConstantFunctionProcessorProvider(); + } + return null; + }) + .withTableFunctionResolver(TestingTableFunctions.RepeatFunction.RepeatFunctionHandle.class) + .withTableFunctionResolver(TestingTableFunctions.TestingTableFunctionHandle.class) + .withTableFunctionResolver(TestingTableFunctions.ConstantFunction.ConstantFunctionHandle.class) + .withTableFunctionSplitResolver(TestingTableFunctions.ConstantFunction.ConstantFunctionSplit.class) + .withTableFunctionSplitSource( + connectorTableFunctionHandle -> { + if (connectorTableFunctionHandle instanceof TestingTableFunctions.ConstantFunction.ConstantFunctionHandle) { + return getConstantFunctionSplitSource((TestingTableFunctions.ConstantFunction.ConstantFunctionHandle) connectorTableFunctionHandle); + } + else if (connectorTableFunctionHandle instanceof TestingTableFunctions.TestingTableFunctionHandle && ((TestingTableFunctions.TestingTableFunctionHandle) connectorTableFunctionHandle).getSchemaFunctionName().equals(new SchemaFunctionName("system", "empty_source"))) { + return new FixedSplitSource(ImmutableList.of(TEST_TVF_CONNECTOR_SPLIT)); + } + return null; + }) .build())); queryRunner.createCatalog(TESTING_CATALOG, "testTVF"); + + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch"); } @Test @@ -91,4 +176,465 @@ public void testNoArgumentsPassed() assertQuery("SELECT col FROM TABLE(system.simple_table_function())", "SELECT true WHERE false"); } + + @Test + public void testIdentityFunction() + { + assertQuery("SELECT b, a FROM TABLE(system.identity_function(input => TABLE(VALUES (1, 2), (3, 4), (5, 6)) T(a, b)))", + "VALUES (2, 1), (4, 3), (6, 5)"); + + assertQuery("SELECT b, a FROM TABLE(system.identity_pass_through_function(input => TABLE(VALUES (1, 2), (3, 4), (5, 6)) T(a, b)))", + "VALUES (2, 1), (4, 3), (6, 5)"); + + // null partitioning value + assertQuery("SELECT i.b, a FROM TABLE(system.identity_function(input => TABLE(VALUES ('x', 1), 
('y', 2), ('z', null)) T(a, b) PARTITION BY b)) i", + "VALUES (1, 'x'), (2, 'y'), (null, 'z')"); + + assertQuery("SELECT b, a FROM TABLE(system.identity_pass_through_function(input => TABLE(VALUES ('x', 1), ('y', 2), ('z', null)) T(a, b) PARTITION BY b))", + "VALUES (1, 'x'), (2, 'y'), (null, 'z')"); + + // the identity_function copies all input columns and outputs them as proper columns. + // the table tpch.tiny.orders has a hidden column row_number, which is not exposed to the function. + assertQuery("SELECT * FROM TABLE(system.identity_function(input => TABLE(tpch.tiny.region)))", + "SELECT * FROM tpch.tiny.region"); + + // the identity_pass_through_function passes all input columns on output using the pass-through mechanism (as opposed to producing proper columns). + // the table tpch.tiny.orders has a hidden column row_number, which is exposed to the pass-through mechanism. + // the passed-through column row_number preserves its hidden property. + assertQuery("SELECT row_number, * FROM TABLE(system.identity_pass_through_function(input => TABLE(tpch.tiny.orders)))", + "SELECT row_number, * FROM tpch.tiny.orders"); + } + + @Test + public void testRepeatFunction() + { + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES (1, 2), (3, 4), (5, 6))))", + "VALUES (1, 2), (1, 2), (3, 4), (3, 4), (5, 6), (5, 6)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES ('a', true), ('b', false)), 4))", + "VALUES ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES ('a', true), ('b', false)) t(x, y) PARTITION BY x,4))", + "VALUES ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES ('a', true), ('b', false)) t(x, y) ORDER BY y, 4))", + "VALUES ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', 
false), ('a', true), ('b', false)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES ('a', true), ('b', false)) t(x, y) PARTITION BY x ORDER BY y, 4))", + "VALUES ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(tpch.tiny.part), 3))", + "SELECT * FROM tpch.tiny.part UNION ALL TABLE tpch.tiny.part UNION ALL TABLE tpch.tiny.part"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(tpch.tiny.part) PARTITION BY type, 3))", + "SELECT * FROM tpch.tiny.part UNION ALL TABLE tpch.tiny.part UNION ALL TABLE tpch.tiny.part"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(tpch.tiny.part) ORDER BY size, 3))", + "SELECT * FROM tpch.tiny.part UNION ALL TABLE tpch.tiny.part UNION ALL TABLE tpch.tiny.part"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(tpch.tiny.part) PARTITION BY type ORDER BY size, 3))", + "SELECT * FROM tpch.tiny.part UNION ALL TABLE tpch.tiny.part UNION ALL TABLE tpch.tiny.part"); + } + + @Test + public void testFunctionsReturningEmptyPages() + { + // the functions empty_output and empty_output_with_pass_through return an empty Page for each processed input Page. 
the argument has KEEP WHEN EMPTY property + + // non-empty input, no pass-through columns + + assertQuery("SELECT * FROM TABLE(system.empty_output(TABLE(tpch.tiny.orders)))", + "SELECT true WHERE false"); + + // non-empty input, pass-through partitioning column + assertQuery("SELECT * FROM TABLE(system.empty_output(TABLE(tpch.tiny.orders) PARTITION BY orderstatus))", + "SELECT true, 'X' WHERE false"); + + // non-empty input, argument has pass-through columns + assertQuery("SELECT * FROM TABLE(system.empty_output_with_pass_through(TABLE(tpch.tiny.orders)))", + "SELECT true, * FROM tpch.tiny.orders WHERE false"); + + // non-empty input, argument has pass-through columns, partitioning column present + assertQuery("SELECT * FROM TABLE(system.empty_output_with_pass_through(TABLE(tpch.tiny.orders) PARTITION BY orderstatus))", + "SELECT true, * FROM tpch.tiny.orders WHERE false"); + + // empty input, no pass-through columns + assertQuery("SELECT * FROM TABLE(system.empty_output(TABLE(SELECT * FROM tpch.tiny.orders WHERE false)))", + "SELECT true WHERE false"); + + // empty input, pass-through partitioning column + assertQuery("SELECT * FROM TABLE(system.empty_output(TABLE(SELECT * FROM tpch.tiny.orders WHERE false) PARTITION BY orderstatus))", + "SELECT true, 'X' WHERE false"); + + // empty input, argument has pass-through columns + assertQuery("SELECT * FROM TABLE(system.empty_output_with_pass_through(TABLE(SELECT * FROM tpch.tiny.orders WHERE false)))", + "SELECT true, * FROM tpch.tiny.orders WHERE false"); + + // empty input, argument has pass-through columns, partitioning column present + assertQuery("SELECT * FROM TABLE(system.empty_output_with_pass_through(TABLE(SELECT * FROM tpch.tiny.orders WHERE false) PARTITION BY orderstatus)) ", + "SELECT true, * FROM tpch.tiny.orders WHERE false"); + + // function empty_source returns an empty Page for each Split it processes + assertQuery("SELECT * FROM TABLE(system.empty_source())", + "SELECT true WHERE false"); + } + + @Test +
public void testInputPartitioning() + { + // table function test_inputs_function has four table arguments. input_1 has row semantics. input_2, input_3 and input_4 have set semantics. + // the function outputs one row per each tuple of partition it processes. The row includes a true value, and partitioning values. + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 4, 5, 4, 5, 4) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 6, 7, 6) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 8, 9)))", + "VALUES (true, 4, 6), (true, 4, 7), (true, 5, 6), (true, 5, 7)"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 4, 5, 4, 5, 4) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 6, 7, 6) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 8, 9) t4(x4) PARTITION BY x4))", + "VALUES (true, 4, 6, 8), (true, 4, 6, 9), (true, 4, 7, 8), (true, 4, 7, 9), (true, 5, 6, 8), (true, 5, 6, 9), (true, 5, 7, 8), (true, 5, 7, 9)"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 4, 5, 4, 5, 4) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 6, 7, 6) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 8, 8) t4(x4) PARTITION BY x4))", + "VALUES (true, 4, 6, 8), (true, 4, 7, 8), (true, 5, 6, 8), (true, 5, 7, 8)"); + + // null partitioning values + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, null)," + + "input_2 => TABLE(VALUES 2, null, 2, null) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 3, null, 3, null) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES null, null) t4(x4) PARTITION BY x4))", + "VALUES (true, 2, 3, null), (true, 2, null, null), (true, null, 3, null), (true, null, null, null)"); + + assertQuery("SELECT * FROM 
TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 4, 5, 4, 5, 4)," + + "input_3 => TABLE(VALUES 6, 7, 6)," + + "input_4 => TABLE(VALUES 8, 9)))", + "VALUES true"); + + assertQuery("SELECT DISTINCT regionkey, nationkey FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(tpch.tiny.nation)," + + "input_2 => TABLE(tpch.tiny.nation) PARTITION BY regionkey ORDER BY name," + + "input_3 => TABLE(tpch.tiny.customer) PARTITION BY nationkey," + + "input_4 => TABLE(tpch.tiny.customer)))", + "SELECT DISTINCT n.regionkey, c.nationkey FROM tpch.tiny.nation n, tpch.tiny.customer c"); + } + + @Test + public void testEmptyPartitions() + { + // input_1 has row semantics, so it is prune when empty. input_2, input_3 and input_4 have set semantics, and are keep when empty by default + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false)," + + "input_3 => TABLE(SELECT 3 WHERE false)," + + "input_4 => TABLE(SELECT 4 WHERE false)))", + "VALUES true"); + + assertQueryReturnsEmptyResult("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(SELECT 1 WHERE false)," + + "input_2 => TABLE(VALUES 2)," + + "input_3 => TABLE(VALUES 3)," + + "input_4 => TABLE(VALUES 4)))"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(SELECT 3 WHERE false) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(SELECT 4 WHERE false) t4(x4) PARTITION BY x4))", + "VALUES (true, CAST(null AS integer), CAST(null AS integer), CAST(null AS integer))"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 3, 4, 4) t3(x3) PARTITION BY x3," + + "input_4 => 
TABLE(VALUES 4, 4, 4, 5, 5, 5, 5) t4(x4) PARTITION BY x4))", + "VALUES (true, CAST(null AS integer), 3, 4), (true, null, 4, 4), (true, null, 4, 5), (true, null, 3, 5)"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(SELECT 3 WHERE false) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 4, 5) t4(x4) PARTITION BY x4))", + "VALUES (true, CAST(null AS integer), CAST(null AS integer), 4), (true, null, null, 5)"); + + assertQueryReturnsEmptyResult("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(SELECT 3 WHERE false) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 4, 5) t4(x4) PARTITION BY x4))"); + } + + @Test + public void testCopartitioning() + { + // all tables are by default KEEP WHEN EMPTY.
If there is no matching partition, it is null-completed + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, 2, 2) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 4, 5) t3(x3)," + + "input_4 => TABLE(VALUES 2, 2, 2, 3) t4(x4) PARTITION BY x4 " + + "COPARTITION (t2, t4)))", + "VALUES (true, 1, null), (true, 2, 2), (true, null, 3)"); + + // partition `3` from input_4 is pruned because there is no matching partition in input_2 + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, 2, 2) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(VALUES 4, 5) t3(x3)," + + "input_4 => TABLE(VALUES 2, 2, 2, 3) t4(x4) PARTITION BY x4 " + + "COPARTITION (t2, t4)))", + "VALUES (true, 1, null), (true, 2, 2)"); + + // partition `1` from input_2 is pruned because there is no matching partition in input_4 + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, 2, 2) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 4, 5) t3(x3)," + + "input_4 => TABLE(VALUES 2, 2, 2, 3) t4(x4) PARTITION BY x4 PRUNE WHEN EMPTY " + + "COPARTITION (t2, t4)))", + "VALUES (true, 2, 2), (true, null, 3)"); + + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, 2, 2) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(VALUES 4, 5) t3(x3)," + + "input_4 => TABLE(VALUES 2, 2, 2, 3) t4(x4) PARTITION BY x4 PRUNE WHEN EMPTY " + + "COPARTITION (t2, t4)))", + "VALUES (true, 2, 2)"); + + // null partitioning values + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, null, null, 2, 2) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 
4, 5) t3(x3)," + + "input_4 => TABLE(VALUES null, 2, 2, 2, 3) t4(x4) PARTITION BY x4 " + + "COPARTITION (t2, t4)))", + "VALUES (true, 1, null), (true, 2, 2), (true, null, null), (true, null, 3)"); + + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, null, null, 2, 2) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(VALUES 4, 5) t3(x3)," + + "input_4 => TABLE(VALUES null, 2, 2, 2, 3) t4(x4) PARTITION BY x4 PRUNE WHEN EMPTY " + + "COPARTITION (t2, t4)))", + "VALUES (true, 2, 2), (true, null, null)"); + + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, null, null) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 2, 2, null) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 2, 3, 3) t4(x4) PARTITION BY x4 " + + "COPARTITION (t2, t4, t3)))", + "VALUES (true, 1, null, null), (true, null, null, null), (true, null, 2, 2), (true, null, null, 3)"); + + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, null, null) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 2, 2, null) t3(x3) PARTITION BY x3 PRUNE WHEN EMPTY," + + "input_4 => TABLE(VALUES 2, 3, 3) t4(x4) PARTITION BY x4 " + + "COPARTITION (t2, t4, t3)))", + "VALUES (true, CAST(null AS integer), null, null), (true, null, 2, 2)"); + + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, null, null) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(VALUES 2, 2, null) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 2, 3, 3) t4(x4) PARTITION BY x4 " + + "COPARTITION (t2, t4, t3)))", + "VALUES (true, 1, CAST(null AS integer), CAST(null AS integer)), (true, null, null, null)"); + + 
assertQueryReturnsEmptyResult( + "SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, null, null) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(VALUES 2, 2, null) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 2, 3, 3) t4(x4) PARTITION BY x4 PRUNE WHEN EMPTY " + + "COPARTITION (t2, t4, t3)))"); + } + + @Test + public void testPassThroughWithEmptyPartitions() + { + assertQuery("SELECT * FROM TABLE(system.pass_through(" + + "TABLE(VALUES (1, 'a'), (2, 'b')) t1(a1, b1) PARTITION BY a1," + + "TABLE(VALUES (2, 'x'), (3, 'y')) t2(a2, b2) PARTITION BY a2 " + + "COPARTITION (t1, t2)))", + "VALUES (true, false, 1, 'a', null, null), (true, true, 2, 'b', 2, 'x'), (false, true, null, null, 3, 'y')"); + + assertQuery("SELECT * FROM TABLE(system.pass_through(" + + "TABLE(VALUES (1, 'a'), (2, 'b')) t1(a1, b1) PARTITION BY a1," + + "TABLE(SELECT 2, 'x' WHERE false) t2(a2, b2) PARTITION BY a2 " + + "COPARTITION (t1, t2)))", + "VALUES (true, false, 1, 'a', CAST(null AS integer), CAST(null AS VARCHAR(1))), (true, false, 2, 'b', null, null)"); + + assertQuery("SELECT * FROM TABLE(system.pass_through(" + + "TABLE(VALUES (1, 'a'), (2, 'b')) t1(a1, b1) PARTITION BY a1," + + "TABLE(SELECT 2, 'x' WHERE false) t2(a2, b2) PARTITION BY a2))", + "VALUES (true, false, 1, 'a', CAST(null AS integer), CAST(null AS VARCHAR(1))), (true, false, 2, 'b', null, null)"); + } + + @Test + public void testPassThroughWithEmptyInput() + { + assertQuery("SELECT * FROM TABLE(system.pass_through(TABLE(SELECT 1, 'x' WHERE false) t1(a1, b1) PARTITION BY a1, TABLE(SELECT 2, 'y' WHERE false) t2(a2, b2) PARTITION BY a2 COPARTITION (t1, t2)))", + "VALUES (false, false, CAST(null AS integer), CAST(null AS VARCHAR(1)), CAST(null AS integer), CAST(null AS VARCHAR(1)))"); + + assertQuery("SELECT * FROM TABLE(system.pass_through(TABLE(SELECT 1, 'x' WHERE false) t1(a1, b1) PARTITION BY a1, TABLE(SELECT 2, 'y' WHERE false) 
t2(a2, b2) PARTITION BY a2))", + "VALUES (false, false, CAST(null AS integer), CAST(null AS VARCHAR(1)), CAST(null AS integer), CAST(null AS VARCHAR(1)))"); + } + + @Test + public void testInput() + { + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(VALUES 1)))", "VALUES true"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(VALUES 1, 2, 3) t(a) PARTITION BY a))", + "VALUES true, true, true"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(SELECT 1 WHERE false)))", "VALUES false"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(SELECT 1 WHERE false) t(a) PARTITION BY a))", + "VALUES false"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(SELECT * FROM tpch.tiny.orders WHERE false)))", "VALUES false"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(SELECT * FROM tpch.tiny.orders WHERE false) PARTITION BY orderstatus ORDER BY orderkey))", "VALUES false"); + } + + @Test + public void testSingleSourceWithRowSemantics() + { + assertQuery("SELECT * FROM TABLE(system.test_single_input_function(TABLE(VALUES (true), (false), (true))))", "VALUES true"); + } + + @Test + public void testConstantFunction() + { + assertQuery("SELECT * FROM TABLE(system.constant(5))", "VALUES 5"); + + assertQuery("SELECT * FROM TABLE(system.constant(2, 10))", "VALUES (2), (2), (2), (2), (2), (2), (2), (2), (2), (2)"); + + assertQuery("SELECT * FROM TABLE(system.constant(null, 3))", "VALUES (CAST(null AS integer)), (null), (null)"); + + // value as constant expression + assertQuery("SELECT * FROM TABLE(system.constant(5 * 4, 3))", "VALUES (20), (20), (20)"); + + assertQueryFails("SELECT * FROM TABLE(system.constant(2147483648, 3))", "line 1:37: Cannot cast type bigint to integer"); + + assertQuery("SELECT count(*), count(DISTINCT constant_column), min(constant_column) FROM TABLE(system.constant(2, 1000000))", "VALUES (BIGINT '1000000', BIGINT '1', 2)"); + } + + @Test + 
public void testPruneAllColumns() + { + // function identity_pass_through_function has no proper outputs. It outputs input columns using the pass-through mechanism. + // in this case, no pass-through columns are referenced, so they are all pruned. The function effectively produces no columns. + assertQuery("SELECT 'a' FROM TABLE(system.identity_pass_through_function(input => TABLE(VALUES 1, 2, 3)))", + "VALUES 'a', 'a', 'a'"); + + // all pass-through columns are pruned. Also, the input is empty, and it has KEEP WHEN EMPTY property, so the function is executed on empty partition. + assertQuery("SELECT 'a' FROM TABLE(system.identity_pass_through_function(input => TABLE(SELECT 1 WHERE false)))", + "SELECT 'a' WHERE false"); + + // all pass-through columns are pruned. Also, the input is empty, and it has PRUNE WHEN EMPTY property, so the function is pruned out. + assertQuery("SELECT 'a' FROM TABLE(system.identity_pass_through_function(input => TABLE(SELECT 1 WHERE false) PRUNE WHEN EMPTY))", + "SELECT 'a' WHERE false"); + } + + @Test + public void testPrunePassThroughColumns() + { + // function pass_through has 2 proper columns, and it outputs all columns from both inputs using the pass-through mechanism. + // all columns are referenced + assertQuery("SELECT p1, p2, x1, x2, y1, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(VALUES (1, 'a'), (2, 'b'), (3, 'c')) t1(x1, x2)," + + " TABLE(VALUES (4, 'd'), (5, 'e')) t2(y1, y2))) t(p1, p2)", + "VALUES (true, true, 3, 'c', 5, 'e')"); + + // all pass-through columns are referenced. Proper columns are not referenced, but they are not pruned. + assertQuery("SELECT x1, x2, y1, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(VALUES (1, 'a'), (2, 'b'), (3, 'c')) t1(x1, x2)," + + " TABLE(VALUES (4, 'd'), (5, 'e')) t2(y1, y2))) t(p1, p2)", + "VALUES (3, 'c', 5, 'e')"); + + // some pass-through columns are referenced. Unreferenced pass-through columns are pruned. 
+ assertQuery("SELECT x2, y2 " + + "FROM TABLE(system.pass_through(" + + " TABLE(VALUES (1, 'a'), (2, 'b'), (3, 'c')) t1(x1, x2)," + + " TABLE(VALUES (4, 'd'), (5, 'e')) t2(y1, y2))) t(p1, p2)", + "VALUES ('c', 'e')"); + + assertQuery("SELECT y1, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(VALUES (1, 'a'), (2, 'b'), (3, 'c')) t1(x1, x2)," + + " TABLE(VALUES (4, 'd'), (5, 'e')) t2(y1, y2))) t(p1, p2)", + "VALUES (5, 'e')"); + + // no pass-through columns are referenced. Unreferenced pass-through columns are pruned. + assertQuery("SELECT 'x' " + + "FROM TABLE(system.pass_through( " + + " TABLE(VALUES (1, 'a'), (2, 'b'), (3, 'c')) t1(x1, x2)," + + " TABLE(VALUES (4, 'd'), (5, 'e')) t2(y1, y2))) t(p1, p2)", + "VALUES ('x')"); + } + + @Test + public void testPrunePassThroughColumnsWithEmptyInput() + { + // function pass_through has 2 proper columns, and it outputs all columns from both inputs using the pass-through mechanism. + // all columns are referenced + assertQuery("SELECT p1, p2, x1, x2, y1, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(SELECT 1, 'a' WHERE FALSE) t1(x1, x2)," + + " TABLE(SELECT 4, 'd' WHERE FALSE) t2(y1, y2))) t(p1, p2)", + "VALUES (false, false, CAST(null AS integer), CAST(null AS varchar(1)), CAST(null AS integer), CAST(null AS varchar(1)))"); + + // all pass-through columns are referenced. Proper columns are not referenced, but they are not pruned. + assertQuery("SELECT x1, x2, y1, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(SELECT 1, 'a' WHERE FALSE) t1(x1, x2)," + + " TABLE(SELECT 4, 'd' WHERE FALSE) t2(y1, y2))) t(p1, p2) ", + "VALUES (CAST(null AS integer), CAST(null AS varchar(1)), CAST(null AS integer), CAST(null AS varchar(1)))"); + + // some pass-through columns are referenced. Unreferenced pass-through columns are pruned. 
+ assertQuery("SELECT x2, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(SELECT 1, 'a' WHERE FALSE) t1(x1, x2)," + + " TABLE(SELECT 4, 'd' WHERE FALSE) t2(y1, y2))) t(p1, p2)", + "VALUES (CAST(null AS varchar(1)), CAST(null AS varchar(1)))"); + + assertQuery("SELECT y1, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(SELECT 1, 'a' WHERE FALSE) t1(x1, x2)," + + " TABLE(SELECT 4, 'd' WHERE FALSE) t2(y1, y2))) t(p1, p2)", + "VALUES (CAST(null AS integer), CAST(null AS varchar(1)))"); + + // no pass-through columns are referenced. Unreferenced pass-through columns are pruned. + assertQuery("SELECT 'x' " + + "FROM TABLE(system.pass_through(" + + " TABLE(SELECT 1, 'a' WHERE FALSE) t1(x1, x2)," + + " TABLE(SELECT 4, 'd' WHERE FALSE) t2(y1, y2))) t(p1, p2)", + "VALUES ('x')"); + } }