Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,12 @@ private static boolean dependsOn(WindowNode parent, WindowNode child)
|| parent.getWindowFunctions().values().stream()
.map(function -> VariablesExtractor.extractUnique(function.getFunctionCall().getArguments()))
.flatMap(Collection::stream)
.anyMatch(child.getCreatedVariable()::contains);
.anyMatch(child.getCreatedVariable()::contains)
|| parent.getWindowFunctions().values().stream()
.map(function -> function.getFrame())
.map(frame -> ImmutableList.of(frame.getStartValue(), frame.getEndValue()))
.flatMap(Collection::stream)
.anyMatch(x -> x.isPresent() && child.getCreatedVariable().contains(x.get()));
}

public static class MergeAdjacentWindowsOverProjects
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.Ordering;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
Expand All @@ -41,10 +44,13 @@
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.CURRENT_ROW;
import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.PRECEDING;
import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING;
import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE;
import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.ROWS;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.google.common.collect.ImmutableList.toImmutableList;

public class TestMergeAdjacentWindows
extends BaseRuleTest
Expand All @@ -58,9 +64,19 @@ public class TestMergeAdjacentWindows
Optional.empty(),
Optional.empty());

private static final WindowNode.Frame frameWithRowOffset = new WindowNode.Frame(
ROWS,
PRECEDING,
Optional.of(new VariableReferenceExpression(Optional.empty(), "startValue", BIGINT)),
CURRENT_ROW,
Optional.empty(),
Optional.of("startValue"),
Optional.empty());

private static final FunctionHandle SUM_FUNCTION_HANDLE = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("sum", fromTypes(DOUBLE));
private static final FunctionHandle AVG_FUNCTION_HANDLE = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("avg", fromTypes(DOUBLE));
private static final FunctionHandle LAG_FUNCTION_HANDLE = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("lag", fromTypes(DOUBLE));
private static final FunctionHandle RANK_FUNCTION_HANDLE = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("rank", ImmutableList.of());
private static final String columnAAlias = "ALIAS_A";
private static final ExpectedValueProvider<WindowNode.Specification> specificationA =
specification(ImmutableList.of(columnAAlias), ImmutableList.of(), ImmutableMap.of());
Expand Down Expand Up @@ -132,6 +148,21 @@ public void testDependentAdjacentWindowsIdenticalSpecifications()
.doesNotFire();
}

@Test
public void testDependentAdjacentWindowsIdenticalSpecificationsWithOffset()
{
tester().assertThat(new GatherAndMergeWindows.MergeAdjacentWindowsOverProjects(0))
.on(p ->
p.window(
newWindowNodeSpecification(p, "a", "sortkey"),
ImmutableMap.of(p.variable("avg_1"), newWindowNodeFunction("avg", AVG_FUNCTION_HANDLE, frameWithRowOffset, "a")),
p.window(
newWindowNodeSpecification(p, "a", "sortkey"),
ImmutableMap.of(p.variable("startValue"), newWindowNodeFunction("rank", RANK_FUNCTION_HANDLE)),
p.values(p.variable("a"), p.variable("sortkey")))))
.doesNotFire();
}

@Test
public void testDependentAdjacentWindowsDistinctSpecifications()
{
Expand Down Expand Up @@ -219,6 +250,13 @@ private static WindowNode.Specification newWindowNodeSpecification(PlanBuilder p
return new WindowNode.Specification(ImmutableList.of(planBuilder.variable(symbolName, BIGINT)), Optional.empty());
}

private static WindowNode.Specification newWindowNodeSpecification(PlanBuilder planBuilder, String symbolName, String sortkey)
{
return new WindowNode.Specification(ImmutableList.of(planBuilder.variable(symbolName, BIGINT)),
Optional.of(new OrderingScheme(
ImmutableList.of(new Ordering(planBuilder.variable(sortkey, BIGINT), SortOrder.ASC_NULLS_FIRST)))));
}

private WindowNode.Function newWindowNodeFunction(String name, FunctionHandle functionHandle, String... symbols)
{
return new WindowNode.Function(
Expand All @@ -230,4 +268,16 @@ private WindowNode.Function newWindowNodeFunction(String name, FunctionHandle fu
frame,
false);
}

private WindowNode.Function newWindowNodeFunction(String name, FunctionHandle functionHandle, WindowNode.Frame frame, String... symbols)
{
return new WindowNode.Function(
call(
name,
functionHandle,
BIGINT,
Arrays.stream(symbols).map(symbol -> new VariableReferenceExpression(Optional.empty(), symbol, BIGINT)).collect(toImmutableList())),
frame,
false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.plan.Ordering;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider;
Expand All @@ -38,8 +41,10 @@
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window;
import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.CURRENT_ROW;
import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.PRECEDING;
import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING;
import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.RANGE;
import static com.facebook.presto.sql.planner.plan.WindowNode.Frame.WindowType.ROWS;
import static com.facebook.presto.sql.relational.Expressions.call;

public class TestSwapAdjacentWindowsBySpecifications
Expand Down Expand Up @@ -135,6 +140,49 @@ public void dependentWindowsAreNotReordered()
.doesNotFire();
}

@Test
public void dependentWindowsAreNotReorderedWithOffset()
{
FunctionHandle rankFunction = createTestMetadataManager().getFunctionAndTypeManager().lookupFunction("rank", ImmutableList.of());
WindowNode.Function windowFunction = new WindowNode.Function(
call(
"rank",
rankFunction,
BIGINT,
ImmutableList.of()),
frame,
false);
WindowNode.Frame frameWithRowOffset = new WindowNode.Frame(
ROWS,
PRECEDING,
Optional.of(new VariableReferenceExpression(Optional.empty(), "startValue", BIGINT)),
CURRENT_ROW,
Optional.empty(),
Optional.of("startValue"),
Optional.empty());
WindowNode.Function functionWithOffset = new WindowNode.Function(
call(
"avg",
functionHandle,
BIGINT,
ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "a", BIGINT))),
frameWithRowOffset,
false);

tester().assertThat(new GatherAndMergeWindows.SwapAdjacentWindowsBySpecifications(0))
.on(p ->
p.window(new WindowNode.Specification(
ImmutableList.of(p.variable("a")),
Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(p.variable("sortkey", BIGINT), SortOrder.ASC_NULLS_FIRST))))),
ImmutableMap.of(p.variable("avg_1"), functionWithOffset),
p.window(new WindowNode.Specification(
ImmutableList.of(p.variable("a"), p.variable("b")),
Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(p.variable("sortkey", BIGINT), SortOrder.ASC_NULLS_FIRST))))),
ImmutableMap.of(p.variable("startValue"), windowFunction),
p.values(p.variable("a"), p.variable("b"), p.variable("sortkey")))))
.doesNotFire();
}

private WindowNode.Function newWindowNodeFunction(List<Symbol> symbols)
{
return new WindowNode.Function(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6343,4 +6343,24 @@ public void testDuplicateUnnestItem()
assertQuery("select * from (SELECT * from unnest(ARRAY[2, 3], ARRAY[2, 3]) WITH ORDINALITY AS r(r1, r2, ord)) cross join unnest(ARRAY[2, 3], ARRAY[2, 3])",
"VALUES (2, 2, 1, 2, 2), (2, 2, 1, 3, 3), (3, 3, 2, 2, 2), (3, 3, 2, 3, 3)");
}

@Test
public void testDependentWindowFunction()
{
// rank() from window function is used as input to parent window function
String sql = "SELECT a, b, c, rnk, SUM(c) OVER (PARTITION BY a ORDER BY b rows BETWEEN rnk PRECEDING AND rnk FOLLOWING)" +
"FROM (" +
" SELECT" +
" a, b, c, RANK() OVER (PARTITION BY a ORDER BY b) AS rnk" +
" FROM (" +
" VALUES (1, 1, 1), (1, 2, 1), (1, 3, 1), (2, 1, 1), (2, 2, 1), (2, 3, 1)" +
" ) AS t(a, b, c)" +
")";
assertQuery(sql, "VALUES (1, 1, 1, 1, 2), (1, 2, 1, 2, 3), (1, 3, 1, 3, 3), (2, 1, 1, 1, 2), (2, 2, 1, 2, 3), (2, 3, 1, 3, 3)");

sql = "select orderkey, orderpriority, totalprice, rnk, " +
"avg(totalprice) over (partition by orderpriority order by orderkey rows between rnk preceding and rnk following) " +
"from (select orderkey, orderpriority, totalprice, rank() over(partition by orderpriority order by orderkey) as rnk from orders)";
assertQuery(sql);
}
}