Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ public PlanOptimizers(
.addAll(new UnwrapRowSubscript().rules())
.addAll(new PushCastIntoRow().rules())
.addAll(ImmutableSet.of(
new ImplementTableFunctionSource(metadata),
new UnwrapSingleColumnRowInApply(typeAnalyzer),
new RemoveEmptyUnionBranches(),
new EvaluateEmptyIntersect(),
Expand Down Expand Up @@ -603,7 +604,7 @@ public PlanOptimizers(
.add(new PushAggregationIntoTableScan(plannerContext, typeAnalyzer))
.add(new PushDistinctLimitIntoTableScan(plannerContext, typeAnalyzer))
.add(new PushTopNIntoTableScan(metadata))
.add(new RewriteTableFunctionToTableScan(plannerContext))
.add(new RewriteTableFunctionToTableScan(plannerContext)) // must run after ImplementTableFunctionSource
.build();
IterativeOptimizer pushIntoTableScanOptimizer = new IterativeOptimizer(
plannerContext,
Expand Down Expand Up @@ -633,11 +634,7 @@ public PlanOptimizers(
costCalculator,
// Temporary hack: separate optimizer step to avoid the sample node being replaced by filter before pushing
// it to table scan node
ImmutableSet.of(
new ImplementBernoulliSampleAsFilter(metadata),
// Must run after RewriteTableFunctionToTableScan because that rule applies to TableFunctionNode.
// While the node gets rewritten to TableFunctionProcessorNode, we can no longer pushdown the function to the connector.
new ImplementTableFunctionSource(metadata))),
ImmutableSet.of(new ImplementBernoulliSampleAsFilter(metadata))),
columnPruningOptimizer,
new IterativeOptimizer(
plannerContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.TableFunctionNode;
import io.trino.sql.planner.plan.TableFunctionProcessorNode;
import io.trino.sql.planner.plan.TableScanNode;

import java.util.List;
Expand All @@ -32,13 +32,13 @@
import static com.google.common.base.Preconditions.checkState;
import static io.trino.matching.Pattern.empty;
import static io.trino.sql.planner.plan.Patterns.sources;
import static io.trino.sql.planner.plan.Patterns.tableFunction;
import static io.trino.sql.planner.plan.Patterns.tableFunctionProcessor;
import static java.util.Objects.requireNonNull;

public class RewriteTableFunctionToTableScan
implements Rule<TableFunctionNode>
implements Rule<TableFunctionProcessorNode>
{
private static final Pattern<TableFunctionNode> PATTERN = tableFunction()
private static final Pattern<TableFunctionProcessorNode> PATTERN = tableFunctionProcessor()
.with(empty(sources()));

private final PlannerContext plannerContext;
Expand All @@ -49,31 +49,31 @@ public RewriteTableFunctionToTableScan(PlannerContext plannerContext)
}

@Override
public Pattern<TableFunctionNode> getPattern()
public Pattern<TableFunctionProcessorNode> getPattern()
{
return PATTERN;
}

@Override
public Result apply(TableFunctionNode tableFunctionNode, Captures captures, Context context)
public Result apply(TableFunctionProcessorNode node, Captures captures, Context context)
{
Optional<TableFunctionApplicationResult<TableHandle>> result = plannerContext.getMetadata().applyTableFunction(context.getSession(), tableFunctionNode.getHandle());
Optional<TableFunctionApplicationResult<TableHandle>> result = plannerContext.getMetadata().applyTableFunction(context.getSession(), node.getHandle());

if (result.isEmpty()) {
return Result.empty();
}

List<ColumnHandle> columnHandles = result.get().getColumnHandles();
checkState(tableFunctionNode.getOutputSymbols().size() == columnHandles.size(), "returned table does not match the node's output");
checkState(node.getOutputSymbols().size() == columnHandles.size(), "returned table does not match the node's output");
ImmutableMap.Builder<Symbol, ColumnHandle> assignments = ImmutableMap.builder();
for (int i = 0; i < columnHandles.size(); i++) {
assignments.put(tableFunctionNode.getOutputSymbols().get(i), columnHandles.get(i));
assignments.put(node.getOutputSymbols().get(i), columnHandles.get(i));
}

return Result.ofPlanNode(new TableScanNode(
tableFunctionNode.getId(),
node.getId(),
result.get().getTableHandle(),
tableFunctionNode.getOutputSymbols(),
node.getOutputSymbols(),
assignments.buildOrThrow(),
TupleDomain.all(),
Optional.empty(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ public static Pattern<TableFunctionNode> tableFunction()
return typeOf(TableFunctionNode.class);
}

public static Pattern<TableFunctionProcessorNode> tableFunctionProcessor()
{
return typeOf(TableFunctionProcessorNode.class);
}

public static Pattern<RowNumberNode> rowNumber()
{
return typeOf(RowNumberNode.class);
Expand Down