diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidQueryGenerator.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidQueryGenerator.java index fe77e91c889f9..c59626d2fa1c3 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidQueryGenerator.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidQueryGenerator.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanVisitor; import com.facebook.presto.spi.plan.ProjectNode; @@ -202,6 +203,15 @@ public DruidQueryGeneratorContext visitProject(ProjectNode node, DruidQueryGener return context.withProject(newSelections); } + @Override + public DruidQueryGeneratorContext visitLimit(LimitNode node, DruidQueryGeneratorContext context) + { + checkArgument(!node.isPartial(), "Druid query generator cannot handle partial limit"); + context = node.getSource().accept(this, context); + requireNonNull(context, "context is null"); + return context.withLimit(node.getCount()).withOutputColumns(node.getOutputVariables()); + } + @Override public DruidQueryGeneratorContext visitTableScan(TableScanNode node, DruidQueryGeneratorContext contextIn) { diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidQueryGeneratorContext.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidQueryGeneratorContext.java index a6f69b0f97258..0934e25b333ec 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidQueryGeneratorContext.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidQueryGeneratorContext.java @@ -20,11 +20,12 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.OptionalLong; import java.util.stream.Collectors; import static com.facebook.presto.druid.DruidErrorCode.DRUID_QUERY_GENERATOR_FAILURE; import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; public class DruidQueryGeneratorContext @@ -32,6 +33,7 @@ public class DruidQueryGeneratorContext private final Map selections; private final Optional from; private final Optional filter; + private final OptionalLong limit; @Override public String toString() @@ -40,6 +42,7 @@ public String toString() .add("selections", selections) .add("from", from) .add("filter", filter) + .add("limit", limit) .toString(); } @@ -55,26 +58,30 @@ public String toString() this( selections, Optional.ofNullable(from), - Optional.empty()); + Optional.empty(), + OptionalLong.empty()); } private DruidQueryGeneratorContext( Map selections, Optional from, - Optional filter) + Optional filter, + OptionalLong limit) { this.selections = new LinkedHashMap<>(requireNonNull(selections, "selections can't be null")); this.from = requireNonNull(from, "source can't be null"); this.filter = requireNonNull(filter, "filter is null"); + this.limit = requireNonNull(limit, "limit is null"); } public DruidQueryGeneratorContext withFilter(String filter) { - checkArgument(!hasFilter(), "Druid doesn't support filters at multiple levels"); + checkState(!hasFilter(), "Druid doesn't support filters at multiple levels"); return new DruidQueryGeneratorContext( selections, from, - Optional.of(filter)); + Optional.of(filter), + limit); } public DruidQueryGeneratorContext withProject(Map newSelections) @@ -82,7 +89,26 @@ public DruidQueryGeneratorContext withProject(Map Long.MAX_VALUE) { + throw new PrestoException(DRUID_QUERY_GENERATOR_FAILURE, "Invalid limit: " + limit); + } + checkState(!hasLimit(), "Limit already exists. Druid doesn't support limit on top of another limit"); + return new DruidQueryGeneratorContext( + selections, + from, + filter, + OptionalLong.of(limit)); + } + + private boolean hasLimit() + { + return limit.isPresent(); } private boolean hasFilter() @@ -113,6 +139,11 @@ public DruidQueryGenerator.GeneratedDql toQuery() query += " WHERE " + filter.get(); pushdown = true; } + + if (limit.isPresent()) { + query += " LIMIT " + limit.getAsLong(); + pushdown = true; + } return new DruidQueryGenerator.GeneratedDql(tableName, query, pushdown); } @@ -133,7 +164,7 @@ public DruidQueryGeneratorContext withOutputColumns(List newSelections = new LinkedHashMap<>(); outputColumns.forEach(o -> newSelections.put(o, requireNonNull(selections.get(o), "Cannot find the selection " + o + " in the original context " + this))); - return new DruidQueryGeneratorContext(newSelections, from, filter); + return new DruidQueryGeneratorContext(newSelections, from, filter, limit); } public enum Origin diff --git a/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidQueryGenerator.java b/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidQueryGenerator.java index 2dbce2754db37..423178e4edd23 100644 --- a/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidQueryGenerator.java +++ b/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidQueryGenerator.java @@ -80,16 +80,16 @@ private PlanNode buildPlan(Function consumer) @Test public void testSimpleSelectStar() { - testDQL(planBuilder -> tableScan(planBuilder, druidTable, regionId, city, fare, secondsSinceEpoch), - "SELECT regionId, city, fare, secondsSinceEpoch FROM realtimeOnly"); - testDQL(planBuilder -> tableScan(planBuilder, druidTable, regionId, secondsSinceEpoch), - "SELECT regionId, secondsSinceEpoch FROM realtimeOnly"); + testDQL(planBuilder -> limit(planBuilder, 50L, tableScan(planBuilder, druidTable, regionId, city, fare, secondsSinceEpoch)), + "SELECT regionId, city, fare, secondsSinceEpoch FROM realtimeOnly LIMIT 50"); + testDQL(planBuilder -> limit(planBuilder, 10L, tableScan(planBuilder, druidTable, regionId, secondsSinceEpoch)), + "SELECT regionId, secondsSinceEpoch FROM realtimeOnly LIMIT 10"); } @Test - public void testSimpleSelectWithFilter() + public void testSimpleSelectWithFilterLimit() { - testDQL(planBuilder -> project(planBuilder, filter(planBuilder, tableScan(planBuilder, druidTable, regionId, city, fare, secondsSinceEpoch), getRowExpression("secondssinceepoch > 20", defaultSessionHolder)), ImmutableList.of("city", "secondssinceepoch")), - "SELECT city, secondsSinceEpoch FROM realtimeOnly WHERE (secondsSinceEpoch > 20)"); + testDQL(planBuilder -> limit(planBuilder, 30L, project(planBuilder, filter(planBuilder, tableScan(planBuilder, druidTable, regionId, city, fare, secondsSinceEpoch), getRowExpression("secondssinceepoch > 20", defaultSessionHolder)), ImmutableList.of("city", "secondssinceepoch"))), + "SELECT city, secondsSinceEpoch FROM realtimeOnly WHERE (secondsSinceEpoch > 20) LIMIT 30"); } }