diff --git a/presto-benchto-benchmarks/src/test/java/com/facebook/presto/sql/planner/AbstractCostBasedPlanTest.java b/presto-benchto-benchmarks/src/test/java/com/facebook/presto/sql/planner/AbstractCostBasedPlanTest.java index 46ed546dcd607..39b0a85956935 100644 --- a/presto-benchto-benchmarks/src/test/java/com/facebook/presto/sql/planner/AbstractCostBasedPlanTest.java +++ b/presto-benchto-benchmarks/src/test/java/com/facebook/presto/sql/planner/AbstractCostBasedPlanTest.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner; +import com.facebook.presto.Session; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.JoinDistributionType; @@ -37,6 +38,7 @@ import java.nio.file.Paths; import java.util.stream.Stream; +import static com.facebook.presto.SystemSessionProperties.OPTIMIZER_USE_HISTOGRAMS; import static com.facebook.presto.spi.plan.JoinDistributionType.REPLICATED; import static com.facebook.presto.spi.plan.JoinType.INNER; import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED; @@ -76,11 +78,42 @@ public void test(String queryResourcePath) assertEquals(generateQueryPlan(read(queryResourcePath)), read(getQueryPlanResourcePath(queryResourcePath))); } + @Test(dataProvider = "getQueriesDataProvider") + public void histogramsPlansMatch(String queryResourcePath) + { + String sql = read(queryResourcePath); + Session histogramSession = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(OPTIMIZER_USE_HISTOGRAMS, "true") + .build(); + Session noHistogramSession = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(OPTIMIZER_USE_HISTOGRAMS, "false") + .build(); + String regularPlan = generateQueryPlan(sql, noHistogramSession); + String histogramPlan = generateQueryPlan(sql, histogramSession); + if (!regularPlan.equals(histogramPlan)) { + assertEquals(histogramPlan, 
read(getHistogramPlanResourcePath(getQueryPlanResourcePath(queryResourcePath)))); + } + } + private String getQueryPlanResourcePath(String queryResourcePath) { return queryResourcePath.replaceAll("\\.sql$", ".plan.txt"); } + private String getHistogramPlanResourcePath(String regularPlanResourcePath) + { + Path root = Paths.get(regularPlanResourcePath); + return root.getParent().resolve("histogram/" + root.getFileName()).toString(); + } + + private Path getResourceWritePath(String queryResourcePath) + { + return Paths.get( + getSourcePath().toString(), + "src/test/resources", + getQueryPlanResourcePath(queryResourcePath)); + } + public void generate() throws Exception { @@ -90,12 +123,24 @@ public void generate() .parallel() .forEach(queryResourcePath -> { try { - Path queryPlanWritePath = Paths.get( - getSourcePath().toString(), - "src/test/resources", - getQueryPlanResourcePath(queryResourcePath)); + Path queryPlanWritePath = getResourceWritePath(queryResourcePath); createParentDirs(queryPlanWritePath.toFile()); - write(generateQueryPlan(read(queryResourcePath)).getBytes(UTF_8), queryPlanWritePath.toFile()); + Session histogramSession = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(OPTIMIZER_USE_HISTOGRAMS, "true") + .build(); + Session noHistogramSession = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(OPTIMIZER_USE_HISTOGRAMS, "false") + .build(); + String sql = read(queryResourcePath); + String regularPlan = generateQueryPlan(sql, noHistogramSession); + String histogramPlan = generateQueryPlan(sql, histogramSession); + write(regularPlan.getBytes(UTF_8), queryPlanWritePath.toFile()); + // write out the histogram plan if it differs + if (!regularPlan.equals(histogramPlan)) { + Path histogramPlanWritePath = getResourceWritePath(getHistogramPlanResourcePath(queryResourcePath)); + createParentDirs(histogramPlanWritePath.toFile()); + write(histogramPlan.getBytes(UTF_8), histogramPlanWritePath.toFile()); + } 
System.out.println("Generated expected plan for query: " + queryResourcePath); } catch (IOException e) { @@ -119,11 +164,16 @@ private static String read(String resource) } private String generateQueryPlan(String query) + { + return generateQueryPlan(query, getQueryRunner().getDefaultSession()); + } + + private String generateQueryPlan(String query, Session session) { String sql = query.replaceAll("\\s+;\\s+$", "") .replace("${database}.${schema}.", "") .replace("\"${database}\".\"${schema}\".\"${prefix}", "\""); - Plan plan = plan(sql, OPTIMIZED_AND_VALIDATED, false); + Plan plan = plan(session, sql, OPTIMIZED_AND_VALIDATED, false); JoinOrderPrinter joinOrderPrinter = new JoinOrderPrinter(); plan.getRoot().accept(joinOrderPrinter, 0); diff --git a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/histogram/q85.plan.txt b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/histogram/q85.plan.txt new file mode 100644 index 0000000000000..9da44b017a7a7 --- /dev/null +++ b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/histogram/q85.plan.txt @@ -0,0 +1,38 @@ +local exchange (GATHER, SINGLE, []) + remote exchange (GATHER, SINGLE, []) + final aggregation over (r_reason_desc) + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, [r_reason_desc]) + partial aggregation over (r_reason_desc) + join (INNER, REPLICATED): + join (INNER, REPLICATED): + join (INNER, PARTITIONED): + remote exchange (REPARTITION, HASH, [cd_demo_sk, cd_education_status, cd_marital_status]) + scan customer_demographics + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, [cd_education_status_3, cd_marital_status_2, wr_refunded_cdemo_sk]) + join (INNER, PARTITIONED): + remote exchange (REPARTITION, HASH, [wr_refunded_addr_sk]) + join (INNER, PARTITIONED): + remote exchange (REPARTITION, HASH, [ws_item_sk, ws_order_number]) + join (INNER, REPLICATED): + scan web_sales + local exchange (GATHER, SINGLE, []) + remote 
exchange (REPLICATE, BROADCAST, []) + scan date_dim + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, [wr_item_sk, wr_order_number]) + join (INNER, REPLICATED): + scan web_returns + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan customer_demographics + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, [ca_address_sk]) + scan customer_address + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan web_page + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan reason diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java index 12afedfe53cba..08d73cb4204cf 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java @@ -4467,39 +4467,39 @@ public void testCollectColumnStatisticsOnCreateTable() assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p1')", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_bigint', null, 2.0E0, 0.5E0, null, '0', '1'), " + - "('c_double', null, 2.0E0, 0.5E0, null, '1.2', '2.2'), " + - "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_varchar', 8.0E0, 2.0E0, 0.5E0, null, null, null), " + - "('c_varbinary', 8.0E0, null, 0.5E0, null, null, null), " + - "('c_array', 176.0E0, null, 0.5, null, null, null), " + - "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null), " + - "(null, null, null, null, 4.0E0, null, null)"); + "('c_boolean', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_bigint', null, 2.0E0, 0.5E0, null, '0', '1', null), " + + "('c_double', null, 2.0E0, 0.5E0, null, '1.2', '2.2', null), " + + "('c_timestamp', null, 2.0E0, 
0.5E0, null, null, null, null), " + + "('c_varchar', 8.0E0, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_varbinary', 8.0E0, null, 0.5E0, null, null, null, null), " + + "('c_array', 176.0E0, null, 0.5, null, null, null, null), " + + "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null, null), " + + "(null, null, null, null, 4.0E0, null, null, null)"); assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p2')", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_bigint', null, 2.0E0, 0.5E0, null, '1', '2'), " + - "('c_double', null, 2.0E0, 0.5E0, null, '2.3', '3.3'), " + - "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_varchar', 8.0E0, 2.0E0, 0.5E0, null, null, null), " + - "('c_varbinary', 8.0E0, null, 0.5E0, null, null, null), " + - "('c_array', 96.0E0, null, 0.5, null, null, null), " + - "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null), " + - "(null, null, null, null, 4.0E0, null, null)"); + "('c_boolean', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_bigint', null, 2.0E0, 0.5E0, null, '1', '2', null), " + + "('c_double', null, 2.0E0, 0.5E0, null, '2.3', '3.3', null), " + + "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_varchar', 8.0E0, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_varbinary', 8.0E0, null, 0.5E0, null, null, null, null), " + + "('c_array', 96.0E0, null, 0.5, null, null, null, null), " + + "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null, null), " + + "(null, null, null, null, 4.0E0, null, null, null)"); // non existing partition assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p3')", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 0E0, 0E0, null, null, null), " + - "('c_bigint', null, 0E0, 0E0, null, null, null), " + - "('c_double', null, 0E0, 0E0, null, null, null), " + - "('c_timestamp', null, 0E0, 0E0, null, null, null), " + - "('c_varchar', 0E0, 0E0, 0E0, null, 
null, null), " + - "('c_varbinary', null, 0E0, 0E0, null, null, null), " + - "('c_array', null, 0E0, 0E0, null, null, null), " + - "('p_varchar', 0E0, 0E0, 0E0, null, null, null), " + - "(null, null, null, null, 0E0, null, null)"); + "('c_boolean', null, 0E0, 0E0, null, null, null, null), " + + "('c_bigint', null, 0E0, 0E0, null, null, null, null), " + + "('c_double', null, 0E0, 0E0, null, null, null, null), " + + "('c_timestamp', null, 0E0, 0E0, null, null, null, null), " + + "('c_varchar', 0E0, 0E0, 0E0, null, null, null, null), " + + "('c_varbinary', null, 0E0, 0E0, null, null, null, null), " + + "('c_array', null, 0E0, 0E0, null, null, null, null), " + + "('p_varchar', 0E0, 0E0, 0E0, null, null, null, null), " + + "(null, null, null, null, 0E0, null, null, null)"); assertUpdate(format("DROP TABLE %s", tableName)); } @@ -4540,39 +4540,39 @@ public void testCollectColumnStatisticsOnInsert() assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p1')", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_bigint', null, 2.0E0, 0.5E0, null, '0', '1'), " + - "('c_double', null, 2.0E0, 0.5E0, null, '1.2', '2.2'), " + - "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_varchar', 8.0E0, 2.0E0, 0.5E0, null, null, null), " + - "('c_varbinary', 8.0E0, null, 0.5E0, null, null, null), " + - "('c_array', 176.0E0, null, 0.5E0, null, null, null), " + - "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null), " + - "(null, null, null, null, 4.0E0, null, null)"); + "('c_boolean', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_bigint', null, 2.0E0, 0.5E0, null, '0', '1', null), " + + "('c_double', null, 2.0E0, 0.5E0, null, '1.2', '2.2', null), " + + "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_varchar', 8.0E0, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_varbinary', 8.0E0, null, 0.5E0, null, null, null, null), " + + "('c_array', 176.0E0, null, 0.5E0, null, 
null, null, null), " + + "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null, null), " + + "(null, null, null, null, 4.0E0, null, null, null)"); assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p2')", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_bigint', null, 2.0E0, 0.5E0, null, '1', '2'), " + - "('c_double', null, 2.0E0, 0.5E0, null, '2.3', '3.3'), " + - "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_varchar', 8.0E0, 2.0E0, 0.5E0, null, null, null), " + - "('c_varbinary', 8.0E0, null, 0.5E0, null, null, null), " + - "('c_array', 96.0E0, null, 0.5E0, null, null, null), " + - "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null), " + - "(null, null, null, null, 4.0E0, null, null)"); + "('c_boolean', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_bigint', null, 2.0E0, 0.5E0, null, '1', '2', null), " + + "('c_double', null, 2.0E0, 0.5E0, null, '2.3', '3.3', null), " + + "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_varchar', 8.0E0, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_varbinary', 8.0E0, null, 0.5E0, null, null, null, null), " + + "('c_array', 96.0E0, null, 0.5E0, null, null, null, null), " + + "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null, null), " + + "(null, null, null, null, 4.0E0, null, null, null)"); // non existing partition assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p3')", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 0E0, 0E0, null, null, null), " + - "('c_bigint', null, 0E0, 0E0, null, null, null), " + - "('c_double', null, 0E0, 0E0, null, null, null), " + - "('c_timestamp', null, 0E0, 0E0, null, null, null), " + - "('c_varchar', 0E0, 0E0, 0E0, null, null, null), " + - "('c_varbinary', null, 0E0, 0E0, null, null, null), " + - "('c_array', null, 0E0, 0E0, null, null, null), " + - "('p_varchar', 0E0, 0E0, 0E0, null, null, null), " + - "(null, null, null, null, 
0E0, null, null)"); + "('c_boolean', null, 0E0, 0E0, null, null, null, null), " + + "('c_bigint', null, 0E0, 0E0, null, null, null, null), " + + "('c_double', null, 0E0, 0E0, null, null, null, null), " + + "('c_timestamp', null, 0E0, 0E0, null, null, null, null), " + + "('c_varchar', 0E0, 0E0, 0E0, null, null, null, null), " + + "('c_varbinary', null, 0E0, 0E0, null, null, null, null), " + + "('c_array', null, 0E0, 0E0, null, null, null, null), " + + "('p_varchar', 0E0, 0E0, 0E0, null, null, null, null), " + + "(null, null, null, null, 0E0, null, null, null)"); assertUpdate(format("DROP TABLE %s", tableName)); } @@ -4646,109 +4646,109 @@ public void testAnalyzePartitionedTable() // No column stats before running analyze assertQuery("SHOW STATS FOR " + tableName, "SELECT * FROM VALUES " + - "('c_boolean', null, null, null, null, null, null), " + - "('c_bigint', null, null, null, null, null, null), " + - "('c_double', null, null, null, null, null, null), " + - "('c_timestamp', null, null, null, null, null, null), " + - "('c_varchar', null, null, null, null, null, null), " + - "('c_varbinary', null, null, null, null, null, null), " + - "('c_array', null, null, null, null, null, null), " + - "('p_varchar', 24.0, 3.0, 0.25, null, null, null), " + - "('p_bigint', null, 2.0, 0.25, null, '7', '8'), " + - "(null, null, null, null, 16.0, null, null)"); + "('c_boolean', null, null, null, null, null, null, null), " + + "('c_bigint', null, null, null, null, null, null, null), " + + "('c_double', null, null, null, null, null, null, null), " + + "('c_timestamp', null, null, null, null, null, null, null), " + + "('c_varchar', null, null, null, null, null, null, null), " + + "('c_varbinary', null, null, null, null, null, null, null), " + + "('c_array', null, null, null, null, null, null, null), " + + "('p_varchar', 24.0, 3.0, 0.25, null, null, null, null), " + + "('p_bigint', null, 2.0, 0.25, null, '7', '8', null), " + + "(null, null, null, null, 16.0, null, null, null)"); // No 
column stats after running an empty analyze assertUpdate(format("ANALYZE %s WITH (partitions = ARRAY[])", tableName), 0); assertQuery("SHOW STATS FOR " + tableName, "SELECT * FROM VALUES " + - "('c_boolean', null, null, null, null, null, null), " + - "('c_bigint', null, null, null, null, null, null), " + - "('c_double', null, null, null, null, null, null), " + - "('c_timestamp', null, null, null, null, null, null), " + - "('c_varchar', null, null, null, null, null, null), " + - "('c_varbinary', null, null, null, null, null, null), " + - "('c_array', null, null, null, null, null, null), " + - "('p_varchar', 24.0, 3.0, 0.25, null, null, null), " + - "('p_bigint', null, 2.0, 0.25, null, '7', '8'), " + - "(null, null, null, null, 16.0, null, null)"); + "('c_boolean', null, null, null, null, null, null, null), " + + "('c_bigint', null, null, null, null, null, null, null), " + + "('c_double', null, null, null, null, null, null, null), " + + "('c_timestamp', null, null, null, null, null, null, null), " + + "('c_varchar', null, null, null, null, null, null, null), " + + "('c_varbinary', null, null, null, null, null, null, null), " + + "('c_array', null, null, null, null, null, null, null), " + + "('p_varchar', 24.0, 3.0, 0.25, null, null, null, null), " + + "('p_bigint', null, 2.0, 0.25, null, '7', '8', null), " + + "(null, null, null, null, 16.0, null, null, null)"); // Run analyze on 3 partitions including a null partition and a duplicate partition assertUpdate(format("ANALYZE %s WITH (partitions = ARRAY[ARRAY['p1', '7'], ARRAY['p2', '7'], ARRAY['p2', '7'], ARRAY[NULL, NULL]])", tableName), 12); assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p1' AND p_bigint = 7)", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 2.0, 0.5, null, null, null), " + - "('c_bigint', null, 2.0, 0.5, null, '0', '1'), " + - "('c_double', null, 2.0, 0.5, null, '1.2', '2.2'), " + - "('c_timestamp', null, 2.0, 0.5, null, null, null), " + - "('c_varchar', 8.0, 
2.0, 0.5, null, null, null), " + - "('c_varbinary', 4.0, null, 0.5, null, null, null), " + - "('c_array', 176.0, null, 0.5, null, null, null), " + - "('p_varchar', 8.0, 1.0, 0.0, null, null, null), " + - "('p_bigint', null, 1.0, 0.0, null, '7', '7'), " + - "(null, null, null, null, 4.0, null, null)"); + "('c_boolean', null, 2.0, 0.5, null, null, null, null), " + + "('c_bigint', null, 2.0, 0.5, null, '0', '1', null), " + + "('c_double', null, 2.0, 0.5, null, '1.2', '2.2', null), " + + "('c_timestamp', null, 2.0, 0.5, null, null, null, null), " + + "('c_varchar', 8.0, 2.0, 0.5, null, null, null, null), " + + "('c_varbinary', 4.0, null, 0.5, null, null, null, null), " + + "('c_array', 176.0, null, 0.5, null, null, null, null), " + + "('p_varchar', 8.0, 1.0, 0.0, null, null, null, null), " + + "('p_bigint', null, 1.0, 0.0, null, '7', '7', null), " + + "(null, null, null, null, 4.0, null, null, null)"); assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p2' AND p_bigint = 7)", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 2.0, 0.5, null, null, null), " + - "('c_bigint', null, 2.0, 0.5, null, '1', '2'), " + - "('c_double', null, 2.0, 0.5, null, '2.3', '3.3'), " + - "('c_timestamp', null, 2.0, 0.5, null, null, null), " + - "('c_varchar', 8.0, 2.0, 0.5, null, null, null), " + - "('c_varbinary', 4.0, null, 0.5, null, null, null), " + - "('c_array', 96.0, null, 0.5, null, null, null), " + - "('p_varchar', 8.0, 1.0, 0.0, null, null, null), " + - "('p_bigint', null, 1.0, 0.0, null, '7', '7'), " + - "(null, null, null, null, 4.0, null, null)"); + "('c_boolean', null, 2.0, 0.5, null, null, null, null), " + + "('c_bigint', null, 2.0, 0.5, null, '1', '2', null), " + + "('c_double', null, 2.0, 0.5, null, '2.3', '3.3', null), " + + "('c_timestamp', null, 2.0, 0.5, null, null, null, null), " + + "('c_varchar', 8.0, 2.0, 0.5, null, null, null, null), " + + "('c_varbinary', 4.0, null, 0.5, null, null, null, null), " + + "('c_array', 96.0, null, 0.5, 
null, null, null, null), " + + "('p_varchar', 8.0, 1.0, 0.0, null, null, null, null), " + + "('p_bigint', null, 1.0, 0.0, null, '7', '7', null), " + + "(null, null, null, null, 4.0, null, null, null)"); assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar IS NULL AND p_bigint IS NULL)", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 1.0, 0.0, null, null, null), " + - "('c_bigint', null, 4.0, 0.0, null, '4', '7'), " + - "('c_double', null, 4.0, 0.0, null, '4.7', '7.7'), " + - "('c_timestamp', null, 4.0, 0.0, null, null, null), " + - "('c_varchar', 16.0, 4.0, 0.0, null, null, null), " + - "('c_varbinary', 8.0, null, 0.0, null, null, null), " + - "('c_array', 192.0, null, 0.0, null, null, null), " + - "('p_varchar', 0.0, 0.0, 1.0, null, null, null), " + - "('p_bigint', null, 0.0, 1.0, null, null, null), " + - "(null, null, null, null, 4.0, null, null)"); + "('c_boolean', null, 1.0, 0.0, null, null, null, null), " + + "('c_bigint', null, 4.0, 0.0, null, '4', '7', null), " + + "('c_double', null, 4.0, 0.0, null, '4.7', '7.7', null), " + + "('c_timestamp', null, 4.0, 0.0, null, null, null, null), " + + "('c_varchar', 16.0, 4.0, 0.0, null, null, null, null), " + + "('c_varbinary', 8.0, null, 0.0, null, null, null, null), " + + "('c_array', 192.0, null, 0.0, null, null, null, null), " + + "('p_varchar', 0.0, 0.0, 1.0, null, null, null, null), " + + "('p_bigint', null, 0.0, 1.0, null, null, null, null), " + + "(null, null, null, null, 4.0, null, null, null)"); // Partition [p3, 8], [e1, 9], [e2, 9] have no column stats assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p3' AND p_bigint = 8)", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, null, null, null, null, null), " + - "('c_bigint', null, null, null, null, null, null), " + - "('c_double', null, null, null, null, null, null), " + - "('c_timestamp', null, null, null, null, null, null), " + - "('c_varchar', null, null, null, null, null, null), " + - 
"('c_varbinary', null, null, null, null, null, null), " + - "('c_array', null, null, null, null, null, null), " + - "('p_varchar', 8.0, 1.0, 0.0, null, null, null), " + - "('p_bigint', null, 1.0, 0.0, null, '8', '8'), " + - "(null, null, null, null, 4.0, null, null)"); + "('c_boolean', null, null, null, null, null, null, null), " + + "('c_bigint', null, null, null, null, null, null, null), " + + "('c_double', null, null, null, null, null, null, null), " + + "('c_timestamp', null, null, null, null, null, null, null), " + + "('c_varchar', null, null, null, null, null, null, null), " + + "('c_varbinary', null, null, null, null, null, null, null), " + + "('c_array', null, null, null, null, null, null, null), " + + "('p_varchar', 8.0, 1.0, 0.0, null, null, null, null), " + + "('p_bigint', null, 1.0, 0.0, null, '8', '8', null), " + + "(null, null, null, null, 4.0, null, null, null)"); assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'e1' AND p_bigint = 9)", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, null, null, null, null, null), " + - "('c_bigint', null, null, null, null, null, null), " + - "('c_double', null, null, null, null, null, null), " + - "('c_timestamp', null, null, null, null, null, null), " + - "('c_varchar', null, null, null, null, null, null), " + - "('c_varbinary', null, null, null, null, null, null), " + - "('c_array', null, null, null, null, null, null), " + - "('p_varchar', 0.0, 0.0, 0.0, null, null, null), " + - "('p_bigint', null, 0.0, 0.0, null, null, null), " + - "(null, null, null, null, 0.0, null, null)"); + "('c_boolean', null, null, null, null, null, null, null), " + + "('c_bigint', null, null, null, null, null, null, null), " + + "('c_double', null, null, null, null, null, null, null), " + + "('c_timestamp', null, null, null, null, null, null, null), " + + "('c_varchar', null, null, null, null, null, null, null), " + + "('c_varbinary', null, null, null, null, null, null, null), " + + "('c_array', null, 
null, null, null, null, null, null), " + + "('p_varchar', 0.0, 0.0, 0.0, null, null, null, null), " + + "('p_bigint', null, 0.0, 0.0, null, null, null, null), " + + "(null, null, null, null, 0.0, null, null, null)"); assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'e2' AND p_bigint = 9)", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, null, null, null, null, null), " + - "('c_bigint', null, null, null, null, null, null), " + - "('c_double', null, null, null, null, null, null), " + - "('c_timestamp', null, null, null, null, null, null), " + - "('c_varchar', null, null, null, null, null, null), " + - "('c_varbinary', null, null, null, null, null, null), " + - "('c_array', null, null, null, null, null, null), " + - "('p_varchar', 0.0, 0.0, 0.0, null, null, null), " + - "('p_bigint', null, 0.0, 0.0, null, null, null), " + - "(null, null, null, null, 0.0, null, null)"); + "('c_boolean', null, null, null, null, null, null, null), " + + "('c_bigint', null, null, null, null, null, null, null), " + + "('c_double', null, null, null, null, null, null, null), " + + "('c_timestamp', null, null, null, null, null, null, null), " + + "('c_varchar', null, null, null, null, null, null, null), " + + "('c_varbinary', null, null, null, null, null, null, null), " + + "('c_array', null, null, null, null, null, null, null), " + + "('p_varchar', 0.0, 0.0, 0.0, null, null, null, null), " + + "('p_bigint', null, 0.0, 0.0, null, null, null, null), " + + "(null, null, null, null, 0.0, null, null, null)"); // Run analyze on the whole table assertUpdate("ANALYZE " + tableName, 16); @@ -4756,76 +4756,76 @@ public void testAnalyzePartitionedTable() // All partitions except empty partitions have column stats assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p1' AND p_bigint = 7)", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 2.0, 0.5, null, null, null), " + - "('c_bigint', null, 2.0, 0.5, null, '0', '1'), " + - "('c_double', 
null, 2.0, 0.5, null, '1.2', '2.2'), " + - "('c_timestamp', null, 2.0, 0.5, null, null, null), " + - "('c_varchar', 8.0, 2.0, 0.5, null, null, null), " + - "('c_varbinary', 4.0, null, 0.5, null, null, null), " + - "('c_array', 176.0, null, 0.5, null, null, null), " + - "('p_varchar', 8.0, 1.0, 0.0, null, null, null), " + - "('p_bigint', null, 1.0, 0.0, null, '7', '7'), " + - "(null, null, null, null, 4.0, null, null)"); + "('c_boolean', null, 2.0, 0.5, null, null, null, null), " + + "('c_bigint', null, 2.0, 0.5, null, '0', '1', null), " + + "('c_double', null, 2.0, 0.5, null, '1.2', '2.2', null), " + + "('c_timestamp', null, 2.0, 0.5, null, null, null, null), " + + "('c_varchar', 8.0, 2.0, 0.5, null, null, null, null), " + + "('c_varbinary', 4.0, null, 0.5, null, null, null, null), " + + "('c_array', 176.0, null, 0.5, null, null, null, null), " + + "('p_varchar', 8.0, 1.0, 0.0, null, null, null, null), " + + "('p_bigint', null, 1.0, 0.0, null, '7', '7', null), " + + "(null, null, null, null, 4.0, null, null, null)"); assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p2' AND p_bigint = 7)", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 2.0, 0.5, null, null, null), " + - "('c_bigint', null, 2.0, 0.5, null, '1', '2'), " + - "('c_double', null, 2.0, 0.5, null, '2.3', '3.3'), " + - "('c_timestamp', null, 2.0, 0.5, null, null, null), " + - "('c_varchar', 8.0, 2.0, 0.5, null, null, null), " + - "('c_varbinary', 4.0, null, 0.5, null, null, null), " + - "('c_array', 96.0, null, 0.5, null, null, null), " + - "('p_varchar', 8.0, 1.0, 0.0, null, null, null), " + - "('p_bigint', null, 1.0, 0.0, null, '7', '7'), " + - "(null, null, null, null, 4.0, null, null)"); + "('c_boolean', null, 2.0, 0.5, null, null, null, null), " + + "('c_bigint', null, 2.0, 0.5, null, '1', '2', null), " + + "('c_double', null, 2.0, 0.5, null, '2.3', '3.3', null), " + + "('c_timestamp', null, 2.0, 0.5, null, null, null, null), " + + "('c_varchar', 8.0, 2.0, 0.5, 
null, null, null, null), " + + "('c_varbinary', 4.0, null, 0.5, null, null, null, null), " + + "('c_array', 96.0, null, 0.5, null, null, null, null), " + + "('p_varchar', 8.0, 1.0, 0.0, null, null, null, null), " + + "('p_bigint', null, 1.0, 0.0, null, '7', '7', null), " + + "(null, null, null, null, 4.0, null, null, null)"); assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar IS NULL AND p_bigint IS NULL)", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 1.0, 0.0, null, null, null), " + - "('c_bigint', null, 4.0, 0.0, null, '4', '7'), " + - "('c_double', null, 4.0, 0.0, null, '4.7', '7.7'), " + - "('c_timestamp', null, 4.0, 0.0, null, null, null), " + - "('c_varchar', 16.0, 4.0, 0.0, null, null, null), " + - "('c_varbinary', 8.0, null, 0.0, null, null, null), " + - "('c_array', 192.0, null, 0.0, null, null, null), " + - "('p_varchar', 0.0, 0.0, 1.0, null, null, null), " + - "('p_bigint', null, 0.0, 1.0, null, null, null), " + - "(null, null, null, null, 4.0, null, null)"); + "('c_boolean', null, 1.0, 0.0, null, null, null, null), " + + "('c_bigint', null, 4.0, 0.0, null, '4', '7', null), " + + "('c_double', null, 4.0, 0.0, null, '4.7', '7.7', null), " + + "('c_timestamp', null, 4.0, 0.0, null, null, null, null), " + + "('c_varchar', 16.0, 4.0, 0.0, null, null, null, null), " + + "('c_varbinary', 8.0, null, 0.0, null, null, null, null), " + + "('c_array', 192.0, null, 0.0, null, null, null, null), " + + "('p_varchar', 0.0, 0.0, 1.0, null, null, null, null), " + + "('p_bigint', null, 0.0, 1.0, null, null, null, null), " + + "(null, null, null, null, 4.0, null, null, null)"); assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p3' AND p_bigint = 8)", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 2.0, 0.5, null, null, null), " + - "('c_bigint', null, 2.0, 0.5, null, '2', '3'), " + - "('c_double', null, 2.0, 0.5, null, '3.4', '4.4'), " + - "('c_timestamp', null, 2.0, 0.5, null, null, null), " + - 
"('c_varchar', 8.0, 2.0, 0.5, null, null, null), " + - "('c_varbinary', 4.0, null, 0.5, null, null, null), " + - "('c_array', 96.0, null, 0.5, null, null, null), " + - "('p_varchar', 8.0, 1.0, 0.0, null, null, null), " + - "('p_bigint', null, 1.0, 0.0, null, '8', '8'), " + - "(null, null, null, null, 4.0, null, null)"); + "('c_boolean', null, 2.0, 0.5, null, null, null, null), " + + "('c_bigint', null, 2.0, 0.5, null, '2', '3', null), " + + "('c_double', null, 2.0, 0.5, null, '3.4', '4.4', null), " + + "('c_timestamp', null, 2.0, 0.5, null, null, null, null), " + + "('c_varchar', 8.0, 2.0, 0.5, null, null, null, null), " + + "('c_varbinary', 4.0, null, 0.5, null, null, null, null), " + + "('c_array', 96.0, null, 0.5, null, null, null, null), " + + "('p_varchar', 8.0, 1.0, 0.0, null, null, null, null), " + + "('p_bigint', null, 1.0, 0.0, null, '8', '8', null), " + + "(null, null, null, null, 4.0, null, null, null)"); assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'e1' AND p_bigint = 9)", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 0.0, 0.0, null, null, null), " + - "('c_bigint', null, 0.0, 0.0, null, null, null), " + - "('c_double', null, 0.0, 0.0, null, null, null), " + - "('c_timestamp', null, 0.0, 0.0, null, null, null), " + - "('c_varchar', 0.0, 0.0, 0.0, null, null, null), " + - "('c_varbinary', 0.0, null, 0.0, null, null, null), " + - "('c_array', 0.0, null, 0.0, null, null, null), " + - "('p_varchar', 0.0, 0.0, 0.0, null, null, null), " + - "('p_bigint', null, 0.0, 0.0, null, null, null), " + - "(null, null, null, null, 0.0, null, null)"); + "('c_boolean', null, 0.0, 0.0, null, null, null, null), " + + "('c_bigint', null, 0.0, 0.0, null, null, null, null), " + + "('c_double', null, 0.0, 0.0, null, null, null, null), " + + "('c_timestamp', null, 0.0, 0.0, null, null, null, null), " + + "('c_varchar', 0.0, 0.0, 0.0, null, null, null, null), " + + "('c_varbinary', 0.0, null, 0.0, null, null, null, null), " + + 
"('c_array', 0.0, null, 0.0, null, null, null, null), " + + "('p_varchar', 0.0, 0.0, 0.0, null, null, null, null), " + + "('p_bigint', null, 0.0, 0.0, null, null, null, null), " + + "(null, null, null, null, 0.0, null, null, null)"); assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'e2' AND p_bigint = 9)", tableName), "SELECT * FROM VALUES " + - "('c_boolean', null, 0.0, 0.0, null, null, null), " + - "('c_bigint', null, 0.0, 0.0, null, null, null), " + - "('c_double', null, 0.0, 0.0, null, null, null), " + - "('c_timestamp', null, 0.0, 0.0, null, null, null), " + - "('c_varchar', 0.0, 0.0, 0.0, null, null, null), " + - "('c_varbinary', 0.0, null, 0.0, null, null, null), " + - "('c_array', 0.0, null, 0.0, null, null, null), " + - "('p_varchar', 0.0, 0.0, 0.0, null, null, null), " + - "('p_bigint', null, 0.0, 0.0, null, null, null), " + - "(null, null, null, null, 0.0, null, null)"); + "('c_boolean', null, 0.0, 0.0, null, null, null, null), " + + "('c_bigint', null, 0.0, 0.0, null, null, null, null), " + + "('c_double', null, 0.0, 0.0, null, null, null, null), " + + "('c_timestamp', null, 0.0, 0.0, null, null, null, null), " + + "('c_varchar', 0.0, 0.0, 0.0, null, null, null, null), " + + "('c_varbinary', 0.0, null, 0.0, null, null, null, null), " + + "('c_array', 0.0, null, 0.0, null, null, null, null), " + + "('p_varchar', 0.0, 0.0, 0.0, null, null, null, null), " + + "('p_bigint', null, 0.0, 0.0, null, null, null, null), " + + "(null, null, null, null, 0.0, null, null, null)"); // Drop the partitioned test table assertUpdate(format("DROP TABLE %s", tableName)); @@ -4840,32 +4840,32 @@ public void testAnalyzeUnpartitionedTable() // No column stats before running analyze assertQuery("SHOW STATS FOR " + tableName, "SELECT * FROM VALUES " + - "('c_boolean', null, null, null, null, null, null), " + - "('c_bigint', null, null, null, null, null, null), " + - "('c_double', null, null, null, null, null, null), " + - "('c_timestamp', null, null, null, 
null, null, null), " + - "('c_varchar', null, null, null, null, null, null), " + - "('c_varbinary', null, null, null, null, null, null), " + - "('c_array', null, null, null, null, null, null), " + - "('p_varchar', null, null, null, null, null, null), " + - "('p_bigint', null, null, null, null, null, null), " + - "(null, null, null, null, 16.0, null, null)"); + "('c_boolean', null, null, null, null, null, null, null), " + + "('c_bigint', null, null, null, null, null, null, null), " + + "('c_double', null, null, null, null, null, null, null), " + + "('c_timestamp', null, null, null, null, null, null, null), " + + "('c_varchar', null, null, null, null, null, null, null), " + + "('c_varbinary', null, null, null, null, null, null, null), " + + "('c_array', null, null, null, null, null, null, null), " + + "('p_varchar', null, null, null, null, null, null, null), " + + "('p_bigint', null, null, null, null, null, null, null), " + + "(null, null, null, null, 16.0, null, null, null)"); // Run analyze on the whole table assertUpdate("ANALYZE " + tableName, 16); assertQuery("SHOW STATS FOR " + tableName, "SELECT * FROM VALUES " + - "('c_boolean', null, 2.0, 0.375, null, null, null), " + - "('c_bigint', null, 8.0, 0.375, null, '0', '7'), " + - "('c_double', null, 10.0, 0.375, null, '1.2', '7.7'), " + - "('c_timestamp', null, 10.0, 0.375, null, null, null), " + - "('c_varchar', 40.0, 10.0, 0.375, null, null, null), " + - "('c_varbinary', 20.0, null, 0.375, null, null, null), " + - "('c_array', 560.0, null, 0.375, null, null, null), " + - "('p_varchar', 24.0, 3.0, 0.25, null, null, null), " + - "('p_bigint', null, 2.0, 0.25, null, '7', '8'), " + - "(null, null, null, null, 16.0, null, null)"); + "('c_boolean', null, 2.0, 0.375, null, null, null, null), " + + "('c_bigint', null, 8.0, 0.375, null, '0', '7', null), " + + "('c_double', null, 10.0, 0.375, null, '1.2', '7.7', null), " + + "('c_timestamp', null, 10.0, 0.375, null, null, null, null), " + + "('c_varchar', 40.0, 10.0, 
0.375, null, null, null, null), " + + "('c_varbinary', 20.0, null, 0.375, null, null, null, null), " + + "('c_array', 560.0, null, 0.375, null, null, null, null), " + + "('p_varchar', 24.0, 3.0, 0.25, null, null, null, null), " + + "('p_bigint', null, 2.0, 0.25, null, '7', '8', null), " + + "(null, null, null, null, 16.0, null, null, null)"); // Drop the unpartitioned test table assertUpdate(format("DROP TABLE %s", tableName)); @@ -4950,11 +4950,11 @@ public void testInsertMultipleColumnsFromSameChannel() assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar_1 = '2' AND p_varchar_2 = '2')", tableName), "SELECT * FROM VALUES " + - "('c_bigint_1', null, 1.0E0, 0.0E0, null, '1', '1'), " + - "('c_bigint_2', null, 1.0E0, 0.0E0, null, '1', '1'), " + - "('p_varchar_1', 1.0E0, 1.0E0, 0.0E0, null, null, null), " + - "('p_varchar_2', 1.0E0, 1.0E0, 0.0E0, null, null, null), " + - "(null, null, null, null, 1.0E0, null, null)"); + "('c_bigint_1', null, 1.0E0, 0.0E0, null, '1', '1', null), " + + "('c_bigint_2', null, 1.0E0, 0.0E0, null, '1', '1', null), " + + "('p_varchar_1', 1.0E0, 1.0E0, 0.0E0, null, null, null, null), " + + "('p_varchar_2', 1.0E0, 1.0E0, 0.0E0, null, null, null, null), " + + "(null, null, null, null, 1.0E0, null, null, null)"); assertUpdate(format("" + "INSERT INTO %s (c_bigint_1, c_bigint_2, p_varchar_1, p_varchar_2) " + @@ -4964,11 +4964,11 @@ public void testInsertMultipleColumnsFromSameChannel() assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar_1 = 'O' AND p_varchar_2 = 'O')", tableName), "SELECT * FROM VALUES " + - "('c_bigint_1', null, 1.0E0, 0.0E0, null, '15008', '15008'), " + - "('c_bigint_2', null, 1.0E0, 0.0E0, null, '15008', '15008'), " + - "('p_varchar_1', 1.0E0, 1.0E0, 0.0E0, null, null, null), " + - "('p_varchar_2', 1.0E0, 1.0E0, 0.0E0, null, null, null), " + - "(null, null, null, null, 1.0E0, null, null)"); + "('c_bigint_1', null, 1.0E0, 0.0E0, null, '15008', '15008', null), " + + "('c_bigint_2', null, 1.0E0, 
0.0E0, null, '15008', '15008', null), " + + "('p_varchar_1', 1.0E0, 1.0E0, 0.0E0, null, null, null, null), " + + "('p_varchar_2', 1.0E0, 1.0E0, 0.0E0, null, null, null, null), " + + "(null, null, null, null, 1.0E0, null, null, null)"); assertUpdate(format("DROP TABLE %s", tableName)); } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestParquetDistributedQueries.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestParquetDistributedQueries.java index 58ad85ad2b7d4..d8bb271cebcd0 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestParquetDistributedQueries.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestParquetDistributedQueries.java @@ -73,22 +73,22 @@ public void testQuickStats() // Since no stats were collected during write, all column stats will be null assertQuery("SHOW STATS FOR test_quick_stats", "SELECT * FROM (VALUES " + - " ('orderkey', null, null, null, null, null, null), " + - " ('linenumber', null, null, null, null, null, null), " + - " ('shipdate', null, null, null, null, null, null), " + - " ('arr', null, null, null, null, null, null), " + - " ('rrow', null, null, null, null, null, null), " + - " (null, null, null, null, 60175.0, null, null))"); + " ('orderkey', null, null, null, null, null, null, null), " + + " ('linenumber', null, null, null, null, null, null, null), " + + " ('shipdate', null, null, null, null, null, null, null), " + + " ('arr', null, null, null, null, null, null, null), " + + " ('rrow', null, null, null, null, null, null, null), " + + " (null, null, null, null, 60175.0, null, null, null))"); // With quick stats enabled, we should get nulls_fraction, low_value and high_value for the non-nested columns assertQuery(session, "SHOW STATS FOR test_quick_stats", "SELECT * FROM (VALUES " + - " ('orderkey', null, null, 0.0, null, '1', '60000'), " + - " ('linenumber', null, null, 0.0, null, '1', '7'), " + - " ('shipdate', null, null, 0.0, null, '1992-01-04', '1998-11-29'), " + - " 
('arr', null, null, null, null, null, null), " + - " ('rrow', null, null, null, null, null, null), " + - " (null, null, null, null, 60175.0, null, null))"); + " ('orderkey', null, null, 0.0, null, '1', '60000', null), " + + " ('linenumber', null, null, 0.0, null, '1', '7', null), " + + " ('shipdate', null, null, 0.0, null, '1992-01-04', '1998-11-29', null), " + + " ('arr', null, null, null, null, null, null, null), " + + " ('rrow', null, null, null, null, null, null, null), " + + " (null, null, null, null, 60175.0, null, null, null))"); } finally { getQueryRunner().execute("DROP TABLE test_quick_stats"); @@ -131,29 +131,29 @@ public void testQuickStatsPartitionedTable() // Since no stats were collected during write, only the partitioned columns will have stats assertQuery("SHOW STATS FOR test_quick_stats_partitioned", "SELECT * FROM (VALUES " + - " ('suppkey', null, null, null, null, null, null), " + - " ('linenumber', null, null, null, null, null, null), " + - " ('orderkey', null, 10.0, 0.0, null, 100, 109), " + - " ('partkey', null, 10.0, 0.0, null, 1000, 1009), " + - " (null, null, null, null, 10.0, null, null))"); + " ('suppkey', null, null, null, null, null, null, null), " + + " ('linenumber', null, null, null, null, null, null, null), " + + " ('orderkey', null, 10.0, 0.0, null, 100, 109, null), " + + " ('partkey', null, 10.0, 0.0, null, 1000, 1009, null), " + + " (null, null, null, null, 10.0, null, null, null))"); // With quick stats enabled, we should get nulls_fraction, low_value and high_value for all columns assertQuery(session, "SHOW STATS FOR test_quick_stats_partitioned", "SELECT * FROM (VALUES " + - " ('suppkey', null, null, 0.0, null, 1, 10), " + - " ('linenumber', null, null, 0.0, null, 1, 10), " + - " ('orderkey', null, 10.0, 0.0, null, 100, 109), " + - " ('partkey', null, 10.0, 0.0, null, 1000, 1009), " + - " (null, null, null, null, 10.0, null, null))"); + " ('suppkey', null, null, 0.0, null, 1, 10, null), " + + " ('linenumber', null, null, 0.0, 
null, 1, 10, null), " + + " ('orderkey', null, 10.0, 0.0, null, 100, 109, null), " + + " ('partkey', null, 10.0, 0.0, null, 1000, 1009, null), " + + " (null, null, null, null, 10.0, null, null, null))"); // If a query targets a specific partition, stats are correctly limited to that partition assertQuery(session, "show stats for (select * from test_quick_stats_partitioned where partkey = 1009)", "SELECT * FROM (VALUES " + - " ('suppkey', null, null, 0.0, null, 10, 10), " + - " ('linenumber', null, null, 0.0, null, 10, 10), " + - " ('orderkey', null, 1.0, 0.0, null, 109, 109), " + - " ('partkey', null, 1.0, 0.0, null, 1009, 1009), " + - " (null, null, null, null, 1.0, null, null))"); + " ('suppkey', null, null, 0.0, null, 10, 10, null), " + + " ('linenumber', null, null, 0.0, null, 10, 10, null), " + + " ('orderkey', null, 1.0, 0.0, null, 109, 109, null), " + + " ('partkey', null, 1.0, 0.0, null, 1009, 1009, null), " + + " (null, null, null, null, 1.0, null, null, null))"); } finally { getQueryRunner().execute("DROP TABLE test_quick_stats_partitioned"); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestShowStats.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestShowStats.java index 07274f6ea61c8..2969c5a12f7f0 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestShowStats.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestShowStats.java @@ -44,97 +44,97 @@ public void testShowStats() { assertQuery("SHOW STATS FOR nation_partitioned", "SELECT * FROM (VALUES " + - " ('regionkey', null, 5.0, 0.0, null, 0, 4), " + - " ('nationkey', null, 5.0, 0.0, null, 0, 24), " + - " ('name', 177.0, 5.0, 0.0, null, null, null), " + - " ('comment', 1857.0, 5.0, 0.0, null, null, null), " + - " (null, null, null, null, 25.0, null, null))"); + " ('regionkey', null, 5.0, 0.0, null, 0, 4, null), " + + " ('nationkey', null, 5.0, 0.0, null, 0, 24, null), " + + " ('name', 177.0, 5.0, 0.0, null, null, null, null), " + + " ('comment', 
1857.0, 5.0, 0.0, null, null, null, null), " + + " (null, null, null, null, 25.0, null, null, null))"); assertQuery("SHOW STATS FOR (SELECT * FROM nation_partitioned)", "SELECT * FROM (VALUES " + - " ('regionkey', null, 5.0, 0.0, null, 0, 4), " + - " ('nationkey', null, 5.0, 0.0, null, 0, 24), " + - " ('name', 177.0, 5.0, 0.0, null, null, null), " + - " ('comment', 1857.0, 5.0, 0.0, null, null, null), " + - " (null, null, null, null, 25.0, null, null))"); + " ('regionkey', null, 5.0, 0.0, null, 0, 4, null), " + + " ('nationkey', null, 5.0, 0.0, null, 0, 24, null), " + + " ('name', 177.0, 5.0, 0.0, null, null, null, null), " + + " ('comment', 1857.0, 5.0, 0.0, null, null, null, null), " + + " (null, null, null, null, 25.0, null, null, null))"); assertQuery("SHOW STATS FOR (SELECT regionkey, name FROM nation_partitioned)", "SELECT * FROM (VALUES " + - " ('regionkey', null, 5.0, 0.0, null, 0, 4), " + - " ('name', 177.0, 5.0, 0.0, null, null, null), " + - " (null, null, null, null, 25.0, null, null))"); + " ('regionkey', null, 5.0, 0.0, null, 0, 4, null), " + + " ('name', 177.0, 5.0, 0.0, null, null, null, null), " + + " (null, null, null, null, 25.0, null, null, null))"); assertQuery("SHOW STATS FOR (SELECT * FROM nation_partitioned WHERE regionkey IS NOT NULL)", "SELECT * FROM (VALUES " + - " ('regionkey', null, 5.0, 0.0, null, 0, 4), " + - " ('nationkey', null, 5.0, 0.0, null, 0, 24), " + - " ('name', 177.0, 5.0, 0.0, null, null, null), " + - " ('comment', 1857.0, 5.0, 0.0, null, null, null), " + - " (null, null, null, null, 25.0, null, null))"); + " ('regionkey', null, 5.0, 0.0, null, 0, 4, null), " + + " ('nationkey', null, 5.0, 0.0, null, 0, 24, null), " + + " ('name', 177.0, 5.0, 0.0, null, null, null, null), " + + " ('comment', 1857.0, 5.0, 0.0, null, null, null, null), " + + " (null, null, null, null, 25.0, null, null, null))"); assertQuery("SHOW STATS FOR (SELECT * FROM nation_partitioned WHERE regionkey IS NULL)", "SELECT * FROM (VALUES " + - " ('regionkey', 
null, 0.0, 0.0, null, null, null), " + - " ('nationkey', null, 0.0, 0.0, null, null, null), " + - " ('name', 0.0, 0.0, 0.0, null, null, null), " + - " ('comment', 0.0, 0.0, 0.0, null, null, null), " + - " (null, null, null, null, 0.0, null, null))"); + " ('regionkey', null, 0.0, 0.0, null, null, null, null), " + + " ('nationkey', null, 0.0, 0.0, null, null, null, null), " + + " ('name', 0.0, 0.0, 0.0, null, null, null, null), " + + " ('comment', 0.0, 0.0, 0.0, null, null, null, null), " + + " (null, null, null, null, 0.0, null, null, null))"); assertQuery("SHOW STATS FOR (SELECT * FROM nation_partitioned WHERE regionkey = 1)", "SELECT * FROM (VALUES " + - " ('regionkey', null, 1.0, 0.0, null, 1, 1), " + - " ('nationkey', null, 5.0, 0.0, null, 1, 24), " + - " ('name', 38.0, 5.0, 0.0, null, null, null), " + - " ('comment', 500.0, 5.0, 0.0, null, null, null), " + - " (null, null, null, null, 5.0, null, null))"); + " ('regionkey', null, 1.0, 0.0, null, 1, 1, null), " + + " ('nationkey', null, 5.0, 0.0, null, 1, 24, null), " + + " ('name', 38.0, 5.0, 0.0, null, null, null, null), " + + " ('comment', 500.0, 5.0, 0.0, null, null, null, null), " + + " (null, null, null, null, 5.0, null, null, null))"); assertQuery("SHOW STATS FOR (SELECT * FROM nation_partitioned WHERE regionkey IN (1, 3))", "SELECT * FROM (VALUES " + - " ('regionkey', null, 2.0, 0.0, null, 1, 3), " + - " ('nationkey', null, 5.0, 0.0, null, 1, 24), " + - " ('name', 78.0, 5.0, 0.0, null, null, null), " + - " ('comment', 847.0, 5.0, 0.0, null, null, null), " + - " (null, null, null, null, 10.0, null, null))"); + " ('regionkey', null, 2.0, 0.0, null, 1, 3, null), " + + " ('nationkey', null, 5.0, 0.0, null, 1, 24, null), " + + " ('name', 78.0, 5.0, 0.0, null, null, null, null), " + + " ('comment', 847.0, 5.0, 0.0, null, null, null, null), " + + " (null, null, null, null, 10.0, null, null, null))"); assertQuery("SHOW STATS FOR (SELECT * FROM nation_partitioned WHERE regionkey BETWEEN 1 AND 1 + 2)", "SELECT * 
FROM (VALUES " + - " ('regionkey', null, 3.0, 0.0, null, 1, 3), " + - " ('nationkey', null, 5.0, 0.0, null, 1, 24), " + - " ('name', 109.0, 5.0, 0.0, null, null, null), " + - " ('comment', 1199.0, 5.0, 0.0, null, null, null), " + - " (null, null, null, null, 15.0, null, null))"); + " ('regionkey', null, 3.0, 0.0, null, 1, 3, null), " + + " ('nationkey', null, 5.0, 0.0, null, 1, 24, null), " + + " ('name', 109.0, 5.0, 0.0, null, null, null, null), " + + " ('comment', 1199.0, 5.0, 0.0, null, null, null, null), " + + " (null, null, null, null, 15.0, null, null, null))"); assertQuery("SHOW STATS FOR (SELECT * FROM nation_partitioned WHERE regionkey > 3)", "SELECT * FROM (VALUES " + - " ('regionkey', null, 1.0, 0.0, null, 4, 4), " + - " ('nationkey', null, 5.0, 0.0, null, 4, 20), " + - " ('name', 31.0, 5.0, 0.0, null, null, null), " + - " ('comment', 348.0, 5.0, 0.0, null, null, null), " + - " (null, null, null, null, 5.0, null, null))"); + " ('regionkey', null, 1.0, 0.0, null, 4, 4, null), " + + " ('nationkey', null, 5.0, 0.0, null, 4, 20, null), " + + " ('name', 31.0, 5.0, 0.0, null, null, null, null), " + + " ('comment', 348.0, 5.0, 0.0, null, null, null, null), " + + " (null, null, null, null, 5.0, null, null, null))"); assertQuery("SHOW STATS FOR (SELECT * FROM nation_partitioned WHERE regionkey < 1)", "SELECT * FROM (VALUES " + - " ('regionkey', null, 1.0, 0.0, null, 0, 0), " + - " ('nationkey', null, 5.0, 0.0, null, 0, 16), " + - " ('name', 37.0, 5.0, 0.0, null, null, null), " + - " ('comment', 310.0, 5.0, 0.0, null, null, null), " + - " (null, null, null, null, 5.0, null, null))"); + " ('regionkey', null, 1.0, 0.0, null, 0, 0, null), " + + " ('nationkey', null, 5.0, 0.0, null, 0, 16, null), " + + " ('name', 37.0, 5.0, 0.0, null, null, null, null), " + + " ('comment', 310.0, 5.0, 0.0, null, null, null, null), " + + " (null, null, null, null, 5.0, null, null, null))"); assertQuery("SHOW STATS FOR (SELECT * FROM nation_partitioned WHERE regionkey > 0 and regionkey 
< 4)", "SELECT * FROM (VALUES " + - " ('regionkey', null, 3.0, 0.0, null, 1, 3), " + - " ('nationkey', null, 5.0, 0.0, null, 1, 24), " + - " ('name', 109.0, 5.0, 0.0, null, null, null), " + - " ('comment', 1199.0, 5.0, 0.0, null, null, null), " + - " (null, null, null, null, 15.0, null, null))"); + " ('regionkey', null, 3.0, 0.0, null, 1, 3, null), " + + " ('nationkey', null, 5.0, 0.0, null, 1, 24, null), " + + " ('name', 109.0, 5.0, 0.0, null, null, null, null), " + + " ('comment', 1199.0, 5.0, 0.0, null, null, null, null), " + + " (null, null, null, null, 15.0, null, null, null))"); assertQuery("SHOW STATS FOR (SELECT * FROM nation_partitioned WHERE regionkey > 10 or regionkey < 0)", "SELECT * FROM (VALUES " + - " ('regionkey', null, 0.0, 0.0, null, null, null), " + - " ('nationkey', null, 0.0, 0.0, null, null, null), " + - " ('name', 0.0, 0.0, 0.0, null, null, null), " + - " ('comment', 0.0, 0.0, 0.0, null, null, null), " + - " (null, null, null, null, 0.0, null, null))"); + " ('regionkey', null, 0.0, 0.0, null, null, null, null), " + + " ('nationkey', null, 0.0, 0.0, null, null, null, null), " + + " ('name', 0.0, 0.0, 0.0, null, null, null, null), " + + " ('comment', 0.0, 0.0, 0.0, null, null, null, null), " + + " (null, null, null, null, 0.0, null, null, null))"); } @Test diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java index 3dac9dc34e352..6b0d16080afb4 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java @@ -1024,21 +1024,21 @@ public void testBasicTableStatistics() assertQuery(session, "SHOW STATS FOR " + tableName, "VALUES " + - " ('col', null, null, null, NULL, NULL, NULL), " + - " (NULL, NULL, NULL, NULL, 0e0, NULL, NULL)"); + " ('col', null, null, 
null, NULL, NULL, NULL, NULL), " + + " (NULL, NULL, NULL, NULL, 0e0, NULL, NULL, NULL)"); assertUpdate("INSERT INTO " + tableName + " VALUES -10", 1); assertUpdate("INSERT INTO " + tableName + " VALUES 100", 1); assertQuery(session, "SHOW STATS FOR " + tableName, "VALUES " + - " ('col', NULL, NULL, 0.0, NULL, '-10.0', '100.0'), " + - " (NULL, NULL, NULL, NULL, 2e0, NULL, NULL)"); + " ('col', NULL, NULL, 0.0, NULL, '-10.0', '100.0', NULL), " + + " (NULL, NULL, NULL, NULL, 2e0, NULL, NULL, NULL)"); assertUpdate("INSERT INTO " + tableName + " VALUES 200", 1); assertQuery(session, "SHOW STATS FOR " + tableName, "VALUES " + - " ('col', NULL, NULL, 0.0, NULL, '-10.0', '200.0'), " + - " (NULL, NULL, NULL, NULL, 3e0, NULL, NULL)"); + " ('col', NULL, NULL, 0.0, NULL, '-10.0', '200.0', NULL), " + + " (NULL, NULL, NULL, NULL, 3e0, NULL, NULL, NULL)"); dropTable(session, tableName); } @@ -1186,16 +1186,16 @@ public void testTableStatisticsTimestamp() assertQuery(session, "SHOW STATS FOR " + tableName, "VALUES " + - " ('col', null, null, null, NULL, NULL, NULL), " + - " (NULL, NULL, NULL, NULL, 0e0, NULL, NULL)"); + " ('col', null, null, null, NULL, NULL, NULL, NULL), " + + " (NULL, NULL, NULL, NULL, 0e0, NULL, NULL, NULL)"); assertUpdate(session, "INSERT INTO " + tableName + " VALUES TIMESTAMP '2021-01-02 09:04:05.321'", 1); assertUpdate(session, "INSERT INTO " + tableName + " VALUES TIMESTAMP '2022-12-22 10:07:08.456'", 1); assertQuery(session, "SHOW STATS FOR " + tableName, "VALUES " + - " ('col', NULL, NULL, 0.0, NULL, '2021-01-02 09:04:05.321', '2022-12-22 10:07:08.456'), " + - " (NULL, NULL, NULL, NULL, 2e0, NULL, NULL)"); + " ('col', NULL, NULL, 0.0, NULL, '2021-01-02 09:04:05.321', '2022-12-22 10:07:08.456', NULL), " + + " (NULL, NULL, NULL, NULL, 2e0, NULL, NULL, NULL)"); dropTable(session, tableName); } diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java 
index 067997132c1ca..d9a88ba8ced8a 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -347,6 +347,7 @@ public final class SystemSessionProperties public static final String NATIVE_DEBUG_VALIDATE_OUTPUT_FROM_OPERATORS = "native_debug_validate_output_from_operators"; public static final String DEFAULT_VIEW_SECURITY_MODE = "default_view_security_mode"; public static final String JOIN_PREFILTER_BUILD_SIDE = "join_prefilter_build_side"; + public static final String OPTIMIZER_USE_HISTOGRAMS = "optimizer_use_histograms"; private final List> sessionProperties; @@ -1910,7 +1911,7 @@ public SystemSessionProperties( GENERATE_DOMAIN_FILTERS, "Infer predicates from column domains during predicate pushdown", featuresConfig.getGenerateDomainFilters(), - false), + false), booleanProperty( REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION, "Rewrite left join with is null check to semi join", @@ -1938,6 +1939,10 @@ public SystemSessionProperties( JOIN_PREFILTER_BUILD_SIDE, "Prefiltering the build/inner side of a join with keys from the other side", false, + false), + booleanProperty(OPTIMIZER_USE_HISTOGRAMS, + "whether or not to use histograms in the CBO", + featuresConfig.isUseHistograms(), false)); } @@ -3230,4 +3235,9 @@ public static boolean isPrintEstimatedStatsFromCacheEnabled(Session session) { return session.getSystemProperty(PRINT_ESTIMATED_STATS_FROM_CACHE, Boolean.class); } + + public static boolean shouldOptimizerUseHistograms(Session session) + { + return session.getSystemProperty(OPTIMIZER_USE_HISTOGRAMS, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ComparisonStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/ComparisonStatsCalculator.java index f695c5abb8dce..306e00cfb9079 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ComparisonStatsCalculator.java +++ 
b/presto-main/src/main/java/com/facebook/presto/cost/ComparisonStatsCalculator.java @@ -13,12 +13,17 @@ */ package com.facebook.presto.cost; +import com.facebook.airlift.log.Logger; +import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.statistics.Estimate; import com.facebook.presto.sql.tree.ComparisonExpression; import java.util.Optional; import java.util.OptionalDouble; +import static com.facebook.presto.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT; import static com.facebook.presto.cost.VariableStatsEstimate.buildFrom; import static com.facebook.presto.util.MoreMath.firstNonNaN; import static com.facebook.presto.util.MoreMath.max; @@ -28,12 +33,20 @@ import static java.lang.Double.POSITIVE_INFINITY; import static java.lang.Double.isFinite; import static java.lang.Double.isNaN; +import static java.util.Objects.requireNonNull; public final class ComparisonStatsCalculator { - private ComparisonStatsCalculator() {} + private static final Logger log = Logger.get(ComparisonStatsCalculator.class); + private final boolean useHistograms; - public static PlanNodeStatsEstimate estimateExpressionToLiteralComparison( + public ComparisonStatsCalculator(Session session) + { + requireNonNull(session, "session is null"); + this.useHistograms = SystemSessionProperties.shouldOptimizerUseHistograms(session); + } + + public PlanNodeStatsEstimate estimateExpressionToLiteralComparison( PlanNodeStatsEstimate inputStatistics, VariableStatsEstimate expressionStatistics, Optional expressionVariable, @@ -46,11 +59,13 @@ public static PlanNodeStatsEstimate estimateExpressionToLiteralComparison( case NOT_EQUAL: return estimateExpressionNotEqualToLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue); case LESS_THAN: + return estimateExpressionLessThanLiteral(inputStatistics, expressionStatistics, expressionVariable, 
literalValue, false); case LESS_THAN_OR_EQUAL: - return estimateExpressionLessThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue); + return estimateExpressionLessThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue, true); case GREATER_THAN: + return estimateExpressionGreaterThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue, false); case GREATER_THAN_OR_EQUAL: - return estimateExpressionGreaterThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue); + return estimateExpressionGreaterThanLiteral(inputStatistics, expressionStatistics, expressionVariable, literalValue, true); case IS_DISTINCT_FROM: return PlanNodeStatsEstimate.unknown(); default: @@ -58,7 +73,7 @@ public static PlanNodeStatsEstimate estimateExpressionToLiteralComparison( } } - private static PlanNodeStatsEstimate estimateExpressionEqualToLiteral( + private PlanNodeStatsEstimate estimateExpressionEqualToLiteral( PlanNodeStatsEstimate inputStatistics, VariableStatsEstimate expressionStatistics, Optional expressionVariable, @@ -66,7 +81,7 @@ private static PlanNodeStatsEstimate estimateExpressionEqualToLiteral( { StatisticRange filterRange; if (literalValue.isPresent()) { - filterRange = new StatisticRange(literalValue.getAsDouble(), literalValue.getAsDouble(), 1); + filterRange = new StatisticRange(literalValue.getAsDouble(), false, literalValue.getAsDouble(), false, 1); } else { filterRange = new StatisticRange(NEGATIVE_INFINITY, POSITIVE_INFINITY, 1); @@ -74,23 +89,20 @@ private static PlanNodeStatsEstimate estimateExpressionEqualToLiteral( return estimateFilterRange(inputStatistics, expressionStatistics, expressionVariable, filterRange); } - private static PlanNodeStatsEstimate estimateExpressionNotEqualToLiteral( + private PlanNodeStatsEstimate estimateExpressionNotEqualToLiteral( PlanNodeStatsEstimate inputStatistics, VariableStatsEstimate expressionStatistics, Optional 
expressionVariable, OptionalDouble literalValue) { - StatisticRange expressionRange = StatisticRange.from(expressionStatistics); - StatisticRange filterRange; if (literalValue.isPresent()) { - filterRange = new StatisticRange(literalValue.getAsDouble(), literalValue.getAsDouble(), 1); + filterRange = new StatisticRange(literalValue.getAsDouble(), false, literalValue.getAsDouble(), false, 1); } else { - filterRange = new StatisticRange(NEGATIVE_INFINITY, POSITIVE_INFINITY, 1); + filterRange = new StatisticRange(NEGATIVE_INFINITY, true, POSITIVE_INFINITY, true, 1); } - StatisticRange intersectRange = expressionRange.intersect(filterRange); - double filterFactor = 1 - expressionRange.overlapPercentWith(intersectRange); + double filterFactor = 1 - calculateFilterFactor(expressionStatistics, filterRange); PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics); estimate.setOutputRowCount(filterFactor * (1 - expressionStatistics.getNullsFraction()) * inputStatistics.getOutputRowCount()); @@ -104,51 +116,83 @@ private static PlanNodeStatsEstimate estimateExpressionNotEqualToLiteral( return estimate.build(); } - private static PlanNodeStatsEstimate estimateExpressionLessThanLiteral( + private PlanNodeStatsEstimate estimateExpressionLessThanLiteral( PlanNodeStatsEstimate inputStatistics, VariableStatsEstimate expressionStatistics, Optional expressionVariable, - OptionalDouble literalValue) + OptionalDouble literalValue, + boolean equals) { - StatisticRange filterRange = new StatisticRange(NEGATIVE_INFINITY, literalValue.orElse(POSITIVE_INFINITY), NaN); + StatisticRange filterRange = new StatisticRange(NEGATIVE_INFINITY, true, literalValue.orElse(POSITIVE_INFINITY), !equals, NaN); return estimateFilterRange(inputStatistics, expressionStatistics, expressionVariable, filterRange); } - private static PlanNodeStatsEstimate estimateExpressionGreaterThanLiteral( + private PlanNodeStatsEstimate estimateExpressionGreaterThanLiteral( 
PlanNodeStatsEstimate inputStatistics, VariableStatsEstimate expressionStatistics, Optional expressionVariable, - OptionalDouble literalValue) + OptionalDouble literalValue, + boolean equals) { - StatisticRange filterRange = new StatisticRange(literalValue.orElse(NEGATIVE_INFINITY), POSITIVE_INFINITY, NaN); + StatisticRange filterRange = new StatisticRange(literalValue.orElse(NEGATIVE_INFINITY), !equals, POSITIVE_INFINITY, true, NaN); return estimateFilterRange(inputStatistics, expressionStatistics, expressionVariable, filterRange); } - private static PlanNodeStatsEstimate estimateFilterRange( + private PlanNodeStatsEstimate estimateFilterRange( PlanNodeStatsEstimate inputStatistics, VariableStatsEstimate expressionStatistics, Optional expressionVariable, StatisticRange filterRange) { + double filterFactor = calculateFilterFactor(expressionStatistics, filterRange); + StatisticRange expressionRange = StatisticRange.from(expressionStatistics); StatisticRange intersectRange = expressionRange.intersect(filterRange); - - double filterFactor = expressionRange.overlapPercentWith(intersectRange); - PlanNodeStatsEstimate estimate = inputStatistics.mapOutputRowCount(rowCount -> filterFactor * (1 - expressionStatistics.getNullsFraction()) * rowCount); if (expressionVariable.isPresent()) { - VariableStatsEstimate symbolNewEstimate = + VariableStatsEstimate.Builder symbolNewEstimate = VariableStatsEstimate.builder() .setAverageRowSize(expressionStatistics.getAverageRowSize()) .setStatisticsRange(intersectRange) - .setNullsFraction(0.0) - .build(); - estimate = estimate.mapVariableColumnStatistics(expressionVariable.get(), oldStats -> symbolNewEstimate); + .setNullsFraction(0.0); + if (useHistograms) { + symbolNewEstimate.setHistogram(expressionStatistics.getHistogram().map(expressionHistogram -> DisjointRangeDomainHistogram.addConjunction(expressionHistogram, intersectRange))); + } + + estimate = estimate.mapVariableColumnStatistics(expressionVariable.get(), oldStats -> 
symbolNewEstimate.build()); } return estimate; } - public static PlanNodeStatsEstimate estimateExpressionToExpressionComparison( + private double calculateFilterFactor(VariableStatsEstimate variableStatistics, StatisticRange filterRange) + { + StatisticRange variableRange = StatisticRange.from(variableStatistics); + StatisticRange intersectRange = variableRange.intersect(filterRange); + Estimate filterEstimate; + if (useHistograms) { + Estimate distinctEstimate = isNaN(variableStatistics.getDistinctValuesCount()) ? Estimate.unknown() : Estimate.of(variableRange.getDistinctValuesCount()); + filterEstimate = HistogramCalculator.calculateFilterFactor(intersectRange, variableStatistics.getHistogram().orElse(new UniformDistributionHistogram(variableStatistics.getLowValue(), variableStatistics.getHighValue())), distinctEstimate, true); + if (log.isDebugEnabled()) { + double expressionFilter = variableRange.overlapPercentWith(intersectRange); + if (!Double.isNaN(expressionFilter) && + !filterEstimate.fuzzyEquals(Estimate.of(expressionFilter), .0001)) { + log.debug(String.format("histogram-calculated filter factor differs from the uniformity assumption:" + + "expression range: %s%n" + + "intersect range: %s%n" + + "overlapPercent: %s%n" + + "histogram: %s%n" + + "histogramFilterIntersect: %s%n", variableRange, intersectRange, expressionFilter, variableStatistics.getHistogram(), filterEstimate)); + } + } + } + else { + filterEstimate = Estimate.estimateFromDouble(variableRange.overlapPercentWith(intersectRange)); + } + + return filterEstimate.orElse(() -> UNKNOWN_FILTER_COEFFICIENT); + } + + public PlanNodeStatsEstimate estimateExpressionToExpressionComparison( PlanNodeStatsEstimate inputStatistics, VariableStatsEstimate leftExpressionStatistics, Optional leftExpressionVariable, @@ -172,7 +216,7 @@ public static PlanNodeStatsEstimate estimateExpressionToExpressionComparison( } } - private static PlanNodeStatsEstimate estimateExpressionEqualToExpression( + private 
PlanNodeStatsEstimate estimateExpressionEqualToExpression( PlanNodeStatsEstimate inputStatistics, VariableStatsEstimate leftExpressionStatistics, Optional leftExpressionVariable, @@ -210,7 +254,7 @@ private static PlanNodeStatsEstimate estimateExpressionEqualToExpression( return estimate.build(); } - private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression( + private PlanNodeStatsEstimate estimateExpressionNotEqualToExpression( PlanNodeStatsEstimate inputStatistics, VariableStatsEstimate leftExpressionStatistics, Optional leftExpressionVariable, diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ConnectorFilterStatsCalculatorService.java b/presto-main/src/main/java/com/facebook/presto/cost/ConnectorFilterStatsCalculatorService.java index 2cdf660ce4b77..9222c9bf22451 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ConnectorFilterStatsCalculatorService.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ConnectorFilterStatsCalculatorService.java @@ -14,6 +14,8 @@ package com.facebook.presto.cost; +import com.facebook.presto.FullConnectorSession; +import com.facebook.presto.Session; import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSession; @@ -53,7 +55,10 @@ public TableStatistics filterStats( Map columnTypes) { PlanNodeStatsEstimate tableStats = toPlanNodeStats(tableStatistics, columnNames, columnTypes); - PlanNodeStatsEstimate filteredStats = filterStatsCalculator.filterStats(tableStats, predicate, session); + // TODO: Consider re-designing the filter calculator APIs so that a proper Session instance + // can be more easily populated + Session filterSession = ((FullConnectorSession) session).getSession(); + PlanNodeStatsEstimate filteredStats = filterStatsCalculator.filterStats(tableStats, predicate, session, filterSession); if (filteredStats.isOutputRowCountUnknown()) { filteredStats = tableStats.mapOutputRowCount(sourceRowCount 
-> tableStats.getOutputRowCount() * UNKNOWN_FILTER_COEFFICIENT); diff --git a/presto-main/src/main/java/com/facebook/presto/cost/DisjointRangeDomainHistogram.java b/presto-main/src/main/java/com/facebook/presto/cost/DisjointRangeDomainHistogram.java new file mode 100644 index 0000000000000..1d4b44140de41 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/DisjointRangeDomainHistogram.java @@ -0,0 +1,359 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.cost; + +import com.facebook.presto.spi.statistics.ConnectorHistogram; +import com.facebook.presto.spi.statistics.Estimate; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Suppliers; +import com.google.common.collect.BoundType; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Range; +import com.google.common.collect.RangeSet; +import com.google.common.collect.TreeRangeSet; + +import java.util.Collection; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.Supplier; + +import static com.facebook.presto.cost.HistogramCalculator.calculateFilterFactor; +import static com.facebook.presto.util.MoreMath.max; +import static com.facebook.presto.util.MoreMath.min; +import static com.google.common.base.MoreObjects.toStringHelper; +import static 
com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.Double.NEGATIVE_INFINITY; +import static java.lang.Double.POSITIVE_INFINITY; +import static java.lang.Double.isFinite; +import static java.util.Objects.hash; +import static java.util.Objects.requireNonNull; + +/** + * This class represents a set of disjoint ranges that span an input domain. + * Each range is used to represent filters over the domain of an original + * "source" histogram. + *
+ * For example, assume a source histogram represents a uniform distribution + * over the range [0, 100]. Next, assume a query with multiple filters such as + * x < 10 OR x > 85. This translates to two disjoint ranges over + * the histogram of [0, 10) and (85, 100], representing roughly 25% of the + * values in the original dataset. Using the example above, a cumulative + * probability for value 5 represents 5% of the original dataset, but 20% (1/5) + * of the constrained dataset. Similarly, all values in [10, 85] should + * compute their cumulative probability as 40% (2/5). + *
+ * The goal of this class is to implement the {@link ConnectorHistogram} API + * given a source histogram whose domain has been constrained by a set of filter + * ranges. + *
+ * This class is intended to be immutable. Changing the set of ranges should + * result in a new copy being created. + */ +public class DisjointRangeDomainHistogram + implements ConnectorHistogram +{ + private final ConnectorHistogram source; + // use RangeSet as the internal representation of the ranges, but the constructor arguments + // use StatisticRange to support serialization and deserialization. + private final Supplier> rangeSet; + private final Set> ranges; + + @JsonCreator + public DisjointRangeDomainHistogram(@JsonProperty("source") ConnectorHistogram source, @JsonProperty("ranges") Collection ranges) + { + this(source, ranges.stream().map(StatisticRange::toRange).collect(toImmutableSet())); + } + + public DisjointRangeDomainHistogram(ConnectorHistogram source, Set> ranges) + { + this.source = requireNonNull(source, "source is null"); + this.ranges = requireNonNull(ranges, "ranges is null"); + this.rangeSet = Suppliers.memoize(() -> { + RangeSet rangeSet = TreeRangeSet.create(); + rangeSet.addAll(ranges); + return rangeSet.subRangeSet(getSourceSpan(this.source)); + }); + } + + private static Range getSourceSpan(ConnectorHistogram source) + { + return Range.closed( + source.inverseCumulativeProbability(0.0).orElse(() -> NEGATIVE_INFINITY), + source.inverseCumulativeProbability(1.0).orElse(() -> POSITIVE_INFINITY)); + } + + @JsonProperty + public ConnectorHistogram getSource() + { + return source; + } + + @JsonProperty + public Set getRanges() + { + return rangeSet.get().asRanges().stream().map(StatisticRange::fromRange).collect(toImmutableSet()); + } + + public DisjointRangeDomainHistogram(ConnectorHistogram source) + { + this(source, ImmutableSet.>of()); + } + + @Override + public Estimate cumulativeProbability(double value, boolean inclusive) + { + // 1. compute the total probability for every existing range on the source + // 2. find the range, r, where `value` falls + // 3. 
compute the cumulative probability across all ranges that intersect [min, value] + // 4. divide the result from (3) by the result from (1) to get the true cumulative + // probability of the disjoint domains over the original histogram + if (Double.isNaN(value)) { + return Estimate.unknown(); + } + Optional> optionalSpan = getSpan(); + if (!optionalSpan.isPresent()) { + return Estimate.of(0.0); + } + Range span = optionalSpan.get(); + if (value <= span.lowerEndpoint()) { + return Estimate.of(0.0); + } + Range input = Range.range(span.lowerEndpoint(), span.lowerBoundType(), value, inclusive ? BoundType.CLOSED : BoundType.OPEN); + Estimate fullSetOverlap = calculateRangeSetOverlap(rangeSet.get()); + RangeSet spanned = rangeSet.get().subRangeSet(input); + Estimate spannedOverlap = calculateRangeSetOverlap(spanned); + + return spannedOverlap.flatMap(spannedProbability -> + fullSetOverlap.map(fullSetProbability -> { + if (fullSetProbability == 0.0) { + return 0.0; + } + return min(spannedProbability / fullSetProbability, 1.0); + })); + } + + private Estimate calculateRangeSetOverlap(RangeSet ranges) + { + // we require knowing bounds on all ranges + double cumulativeTotal = 0.0; + for (Range range : ranges.asRanges()) { + Estimate rangeProbability = getRangeProbability(range); + if (rangeProbability.isUnknown()) { + return Estimate.unknown(); + } + cumulativeTotal += rangeProbability.getValue(); + } + return Estimate.of(cumulativeTotal); + } + + /** + * Calculates the percent of the source distribution that {@code range} + * spans. 
+ * + * @param range the range over the source domain + * @return estimate of the total probability the range covers in the source + */ + private Estimate getRangeProbability(Range range) + { + return calculateFilterFactor(StatisticRange.fromRange(range), source, Estimate.unknown(), false); + } + + @Override + public Estimate inverseCumulativeProbability(double percentile) + { + checkArgument(percentile >= 0.0 && percentile <= 1.0, "percentile must fall in [0.0, 1.0]"); + // 1. compute the probability for each range on the source in order until reaching a range + // where the cumulative total exceeds the percentile argument (totalCumulative) + // 2. compute the source probability of the left endpoint of the given range (percentileLow) + // 3. compute the "true" source percentile: + // rangedPercentile = percentile - percentileLow + // + // percentileLow + (rangedPercentile * rangePercentileLength) + Optional> optionalSpan = getSpan(); + if (!optionalSpan.isPresent()) { + return Estimate.unknown(); + } + Range span = optionalSpan.get(); + if (percentile == 0.0 && isFinite(span.lowerEndpoint())) { + return source.inverseCumulativeProbability(0.0).map(sourceMin -> max(span.lowerEndpoint(), sourceMin)); + } + + if (percentile == 1.0 && isFinite(span.upperEndpoint())) { + return source.inverseCumulativeProbability(1.0).map(sourceMax -> min(span.upperEndpoint(), sourceMax)); + } + + Estimate totalCumulativeEstimate = calculateRangeSetOverlap(rangeSet.get()); + if (totalCumulativeEstimate.isUnknown()) { + return Estimate.unknown(); + } + double totalCumulativeProbabilitySourceDomain = totalCumulativeEstimate.getValue(); + if (totalCumulativeProbabilitySourceDomain == 0.0) { + // calculations will fail with NaN + return Estimate.unknown(); + } + double cumulativeProbabilityNewDomain = 0.0; + double lastRangeEstimateSourceDomain = 0.0; + Range currentRange = null; + // find the range where the percentile falls + for (Range range : rangeSet.get().asRanges()) { + Estimate 
rangeEstimate = getRangeProbability(range); + if (rangeEstimate.isUnknown()) { + return Estimate.unknown(); + } + currentRange = range; + lastRangeEstimateSourceDomain = rangeEstimate.getValue(); + cumulativeProbabilityNewDomain += lastRangeEstimateSourceDomain / totalCumulativeProbabilitySourceDomain; + if (cumulativeProbabilityNewDomain >= percentile) { + break; + } + } + if (currentRange == null) { + // no ranges to iterate over. Did a constraint cut the entire domain of values? + return Estimate.unknown(); + } + Estimate rangeLeftSourceEstimate = source.cumulativeProbability(currentRange.lowerEndpoint(), currentRange.lowerBoundType() == BoundType.OPEN); + if (rangeLeftSourceEstimate.isUnknown()) { + return Estimate.unknown(); + } + double rangeLeftSource = rangeLeftSourceEstimate.getValue(); + double lastRangeProportionalProbability = lastRangeEstimateSourceDomain / totalCumulativeProbabilitySourceDomain; + double percentileLeftFromNewDomain = percentile - cumulativeProbabilityNewDomain + lastRangeProportionalProbability; + double percentilePoint = lastRangeEstimateSourceDomain * percentileLeftFromNewDomain / lastRangeProportionalProbability; + double finalPercentile = rangeLeftSource + percentilePoint; + + return source.inverseCumulativeProbability(min(max(finalPercentile, 0.0), 1.0)); + } + + /** + * Adds a new domain (logical disjunction) to the existing set. + * + * @param other the new range to add to the set. + * @return a new {@link DisjointRangeDomainHistogram} + */ + public DisjointRangeDomainHistogram addDisjunction(StatisticRange other) + { + Set> ranges = ImmutableSet.>builder() + .addAll(this.ranges) + .add(other.toRange()) + .build(); + return new DisjointRangeDomainHistogram(source, ranges); + } + + /** + * Adds a constraint (logical conjunction). This will constrain all ranges + * in the set to ones that are contained by the argument range. + * + * @param other the range that should enclose the set. 
+ * @return a new {@link DisjointRangeDomainHistogram} whose ranges are limited to {@code other} + */ + public DisjointRangeDomainHistogram addConjunction(StatisticRange other) + { + return new DisjointRangeDomainHistogram(source, rangeSet.get().subRangeSet(other.toRange()).asRanges()); + } + + /** + * Adds a new range to the available ranges that this histogram computes over + *
+ * e.g. if the source histogram represents values [0, 100], and an existing + * range in the set constrains it to [0, 25], and this method is called with + * a range of [50, 75], then it will attempt to push [50, 75] down onto the + * existing histogram to expand the set of intervals that are used to + * compute probabilities to [[0, 25], [50, 75]]. + *
+ * This method should be called for cases where we want to calculate plan + * statistics for queries that have multiple filters combined with OR. + * + * @param histogram the source histogram to add the range disjunction to + * @param range the range representing the disjunction to add + * @return a new histogram with the disjunction applied. + */ + public static ConnectorHistogram addDisjunction(ConnectorHistogram histogram, StatisticRange range) + { + if (histogram instanceof DisjointRangeDomainHistogram) { + return ((DisjointRangeDomainHistogram) histogram).addDisjunction(range); + } + + return new DisjointRangeDomainHistogram(histogram, ImmutableSet.of(range.toRange())); + } + + /** + * Similar to {@link #addDisjunction(ConnectorHistogram, StatisticRange)} this method constrains + * the entire domain such that all ranges in the set intersect with the given range + * argument to this method. + *
+ * This should be used when an AND clause is present in the query and all tuples MUST satisfy + * the condition. + * + * @param histogram the source histogram + * @param range the range of values that the entire histogram's domain must fall within + * @return a histogram with the new range constraint + */ + public static ConnectorHistogram addConjunction(ConnectorHistogram histogram, StatisticRange range) + { + if (histogram instanceof DisjointRangeDomainHistogram) { + return ((DisjointRangeDomainHistogram) histogram).addConjunction(range); + } + + return new DisjointRangeDomainHistogram(histogram, ImmutableSet.of(range.toRange())); + } + + /** + * @return the span if it exists, empty otherwise + */ + private Optional> getSpan() + { + try { + return Optional.of(rangeSet.get().span()); + } + catch (NoSuchElementException e) { + return Optional.empty(); + } + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("source", this.source) + .add("rangeSet", this.rangeSet) + .toString(); + } + + @Override + public boolean equals(Object o) + { + if (o == this) { + return true; + } + if (!(o instanceof DisjointRangeDomainHistogram)) { + return false; + } + DisjointRangeDomainHistogram other = (DisjointRangeDomainHistogram) o; + return Objects.equals(source, other.source) && + // getRanges() flattens and creates the minimal range set which + // determines whether two histograms are truly equal + Objects.equals(getRanges(), other.getRanges()); + } + + @Override + public int hashCode() + { + return hash(source, getRanges()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java index 13add075e9945..903884c8e5fe7 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ExchangeStatsRule.java @@ -24,8 +24,8 @@ import java.util.List; import 
java.util.Optional; +import static com.facebook.presto.SystemSessionProperties.shouldOptimizerUseHistograms; import static com.facebook.presto.cost.PlanNodeStatsEstimate.buildFrom; -import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStatsAndMaxDistinctValues; import static com.facebook.presto.sql.planner.plan.Patterns.exchange; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; @@ -63,7 +63,8 @@ protected Optional doCalculate(ExchangeNode node, StatsPr PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputVariables(sourceStats, node.getInputs().get(i), node.getOutputVariables()); if (estimate.isPresent()) { - estimate = Optional.of(addStatsAndMaxDistinctValues(estimate.get(), sourceStatsWithMappedSymbols)); + PlanNodeStatsEstimateMath calculator = new PlanNodeStatsEstimateMath(shouldOptimizerUseHistograms(session)); + estimate = Optional.of(calculator.addStatsAndMaxDistinctValues(estimate.get(), sourceStatsWithMappedSymbols)); } else { estimate = Optional.of(sourceStatsWithMappedSymbols); diff --git a/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java index 9e35fa478501a..347e5d71df8b5 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java @@ -14,6 +14,7 @@ package com.facebook.presto.cost; import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.FunctionAndTypeManager; @@ -65,12 +66,8 @@ import java.util.Optional; import java.util.OptionalDouble; +import static com.facebook.presto.SystemSessionProperties.shouldOptimizerUseHistograms; import static 
com.facebook.presto.common.type.BooleanType.BOOLEAN; -import static com.facebook.presto.cost.ComparisonStatsCalculator.estimateExpressionToExpressionComparison; -import static com.facebook.presto.cost.ComparisonStatsCalculator.estimateExpressionToLiteralComparison; -import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues; -import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.capStats; -import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.subtractSubsetStats; import static com.facebook.presto.cost.StatsUtil.toStatsRepresentation; import static com.facebook.presto.expressions.DynamicFilters.isDynamicFilter; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; @@ -135,10 +132,16 @@ public PlanNodeStatsEstimate filterStats( public PlanNodeStatsEstimate filterStats( PlanNodeStatsEstimate statsEstimate, RowExpression predicate, - ConnectorSession session) + ConnectorSession session, + /* TODO: this session parameter is optional because this method can be called from the + ConnectorFilterStatsCalculatorService which only has access to a ConnectorSession + object. When the ConnectorSession API allows access to the underlying session, this + method should be updated to just accept a ConnectorSession. 
+ */ + Session systemSession) { RowExpression simplifiedExpression = simplifyExpression(session, predicate); - return new FilterRowExpressionStatsCalculatingVisitor(statsEstimate, session, metadata.getFunctionAndTypeManager()).process(simplifiedExpression); + return new FilterRowExpressionStatsCalculatingVisitor(statsEstimate, session, metadata.getFunctionAndTypeManager(), systemSession).process(simplifiedExpression); } public PlanNodeStatsEstimate filterStats( @@ -146,7 +149,7 @@ public PlanNodeStatsEstimate filterStats( RowExpression predicate, Session session) { - return filterStats(statsEstimate, predicate, session.toConnectorSession()); + return filterStats(statsEstimate, predicate, session.toConnectorSession(), session); } private Expression simplifyExpression(Session session, Expression predicate, TypeProvider types) @@ -196,12 +199,16 @@ private class FilterExpressionStatsCalculatingVisitor private final PlanNodeStatsEstimate input; private final Session session; private final TypeProvider types; + private final PlanNodeStatsEstimateMath calculator; + private final ComparisonStatsCalculator comparisonCalculator; FilterExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, Session session, TypeProvider types) { this.input = input; this.session = session; this.types = types; + this.calculator = new PlanNodeStatsEstimateMath(shouldOptimizerUseHistograms(session)); + this.comparisonCalculator = new ComparisonStatsCalculator(session); } @Override @@ -222,7 +229,7 @@ protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void cont if (node.getValue() instanceof IsNullPredicate) { return process(new IsNotNullPredicate(((IsNullPredicate) node.getValue()).getValue())); } - return subtractSubsetStats(input, process(node.getValue())); + return calculator.subtractSubsetStats(input, process(node.getValue())); } @Override @@ -285,9 +292,9 @@ private PlanNodeStatsEstimate estimateLogicalOr(Expression left, Expression righ return 
PlanNodeStatsEstimate.unknown(); } - return capStats( - subtractSubsetStats( - addStatsAndSumDistinctValues(leftEstimate, rightEstimate), + return calculator.capStats( + calculator.subtractSubsetStats( + calculator.addStatsAndSumDistinctValues(leftEstimate, rightEstimate), andEstimate), input); } @@ -384,7 +391,7 @@ protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) } PlanNodeStatsEstimate inEstimate = equalityEstimates.stream() - .reduce(PlanNodeStatsEstimateMath::addStatsAndSumDistinctValues) + .reduce(calculator::addStatsAndSumDistinctValues) .orElse(PlanNodeStatsEstimate.unknown()); if (inEstimate.isOutputRowCountUnknown()) { @@ -441,17 +448,17 @@ protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression n return visitBooleanLiteral(FALSE_LITERAL, null); } OptionalDouble literal = toStatsRepresentation(metadata, session, getType(left), literalValue); - return estimateExpressionToLiteralComparison(input, leftStats, leftVariable, literal, operator); + return comparisonCalculator.estimateExpressionToLiteralComparison(input, leftStats, leftVariable, literal, operator); } VariableStatsEstimate rightStats = getExpressionStats(right); if (rightStats.isSingleValue()) { OptionalDouble value = isNaN(rightStats.getLowValue()) ? OptionalDouble.empty() : OptionalDouble.of(rightStats.getLowValue()); - return estimateExpressionToLiteralComparison(input, leftStats, leftVariable, value, operator); + return comparisonCalculator.estimateExpressionToLiteralComparison(input, leftStats, leftVariable, value, operator); } Optional rightVariable = right instanceof SymbolReference ? 
Optional.of(toVariable(right)) : Optional.empty(); - return estimateExpressionToExpressionComparison(input, leftStats, leftVariable, rightStats, rightVariable, operator); + return comparisonCalculator.estimateExpressionToExpressionComparison(input, leftStats, leftVariable, rightStats, rightVariable, operator); } private Type getType(Expression expression) @@ -493,13 +500,19 @@ private class FilterRowExpressionStatsCalculatingVisitor { private final PlanNodeStatsEstimate input; private final ConnectorSession session; + private final Session systemSession; private final FunctionAndTypeManager functionAndTypeManager; + private final PlanNodeStatsEstimateMath calculator; + private final ComparisonStatsCalculator comparisonCalculator; - FilterRowExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, ConnectorSession session, FunctionAndTypeManager functionAndTypeManager) + FilterRowExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, ConnectorSession session, FunctionAndTypeManager functionAndTypeManager, Session systemSession) { this.input = requireNonNull(input, "input is null"); this.session = requireNonNull(session, "session is null"); this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionManager is null"); + this.systemSession = requireNonNull(systemSession, "systemSession is null"); + this.calculator = new PlanNodeStatsEstimateMath(SystemSessionProperties.shouldOptimizerUseHistograms(systemSession)); + this.comparisonCalculator = new ComparisonStatsCalculator(systemSession); } @Override @@ -586,17 +599,17 @@ public PlanNodeStatsEstimate visitCall(CallExpression node, Void context) return visitConstant(constantNull(right.getSourceLocation(), BOOLEAN), null); } OptionalDouble literal = toStatsRepresentation(metadata.getFunctionAndTypeManager(), session, right.getType(), rightValue); - return estimateExpressionToLiteralComparison(input, leftStats, leftVariable, literal, getComparisonOperator(operatorType)); + return 
comparisonCalculator.estimateExpressionToLiteralComparison(input, leftStats, leftVariable, literal, getComparisonOperator(operatorType)); } VariableStatsEstimate rightStats = getRowExpressionStats(right); if (rightStats.isSingleValue()) { OptionalDouble value = isNaN(rightStats.getLowValue()) ? OptionalDouble.empty() : OptionalDouble.of(rightStats.getLowValue()); - return estimateExpressionToLiteralComparison(input, leftStats, leftVariable, value, getComparisonOperator(operatorType)); + return comparisonCalculator.estimateExpressionToLiteralComparison(input, leftStats, leftVariable, value, getComparisonOperator(operatorType)); } Optional rightVariable = right instanceof VariableReferenceExpression ? Optional.of((VariableReferenceExpression) right) : Optional.empty(); - return estimateExpressionToExpressionComparison(input, leftStats, leftVariable, rightStats, rightVariable, getComparisonOperator(operatorType)); + return comparisonCalculator.estimateExpressionToExpressionComparison(input, leftStats, leftVariable, rightStats, rightVariable, getComparisonOperator(operatorType)); } // NOT case @@ -615,7 +628,7 @@ public PlanNodeStatsEstimate visitCall(CallExpression node, Void context) } return PlanNodeStatsEstimate.unknown(); } - return subtractSubsetStats(input, process(argument)); + return calculator.subtractSubsetStats(input, process(argument)); } // BETWEEN case @@ -676,7 +689,7 @@ public PlanNodeStatsEstimate visitInputReference(InputReferenceExpression node, private FilterRowExpressionStatsCalculatingVisitor newEstimate(PlanNodeStatsEstimate input) { - return new FilterRowExpressionStatsCalculatingVisitor(input, session, functionAndTypeManager); + return new FilterRowExpressionStatsCalculatingVisitor(input, session, functionAndTypeManager, systemSession); } private PlanNodeStatsEstimate process(RowExpression rowExpression) @@ -731,9 +744,9 @@ private PlanNodeStatsEstimate estimateLogicalOr(RowExpression left, RowExpressio return PlanNodeStatsEstimate.unknown(); 
} - return capStats( - subtractSubsetStats( - addStatsAndSumDistinctValues(leftEstimate, rightEstimate), + return calculator.capStats( + calculator.subtractSubsetStats( + calculator.addStatsAndSumDistinctValues(leftEstimate, rightEstimate), andEstimate), input); } @@ -749,7 +762,7 @@ private PlanNodeStatsEstimate estimateIn(RowExpression value, List + * The filter factor is a fractional value in [0.0, 1.0] that represents the proportion of + * tuples in the source column that would be included in the result of a filter where the valid + * values in the filter are represented by the {@code range} parameter of this function. + * + * @param range the intersecting range with the histogram + * @param histogram the source histogram + * @param totalDistinctValues the total number of distinct values in the domain of the histogram + * @param useHeuristics whether to return heuristic values based on constants and/or distinct + * value counts. If false, {@link Estimate#unknown()} will be returned in any case a + * heuristic would have been used + * @return an estimate, x, where 0.0 <= x <= 1.0. + */ + public static Estimate calculateFilterFactor(StatisticRange range, ConnectorHistogram histogram, Estimate totalDistinctValues, boolean useHeuristics) + { + boolean openHigh = range.getOpenHigh(); + boolean openLow = range.getOpenLow(); + Estimate min = histogram.inverseCumulativeProbability(0.0); + Estimate max = histogram.inverseCumulativeProbability(1.0); + + // range is either above or below histogram + if ((!max.isUnknown() && (openHigh ? max.getValue() <= range.getLow() : max.getValue() < range.getLow())) + || (!min.isUnknown() && (openLow ? 
min.getValue() >= range.getHigh() : min.getValue() > range.getHigh()))) { + return Estimate.of(0.0); + } + + // one of the max/min bounds can't be determined + if ((max.isUnknown() && !min.isUnknown()) || (!max.isUnknown() && min.isUnknown())) { + // when the range length is 0, the filter factor should be 1/distinct value count + if (!useHeuristics) { + return Estimate.unknown(); + } + + if (range.length() == 0.0) { + return totalDistinctValues.map(distinct -> 1.0 / distinct); + } + + if (isFinite(range.length())) { + return Estimate.of(StatisticRange.INFINITE_TO_FINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR); + } + return Estimate.of(StatisticRange.INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR); + } + + // we know the bounds are both known, so calculate the percentile for each bound + // The inclusivity arguments can be derived from the open-ness of the interval we're + // calculating the filter factor for + // e.g. given a variable with values in [0, 10] to calculate the filter of + // [1, 9) (openness: false, true) we need the percentile from + // [0.0 to 1.0) (openness: false, true) and from [0.0, 9.0) (openness: false, true) + // thus for the "lowPercentile" calculation we should pass "false" to be non-inclusive + // (same as openness) however, on the high-end we want the inclusivity to be the opposite + // of the openness since if it's open, we _don't_ want to include the bound. 
+ Estimate lowPercentile = histogram.cumulativeProbability(range.getLow(), openLow); + Estimate highPercentile = histogram.cumulativeProbability(range.getHigh(), !openHigh); + + // both bounds are probably infinity, use the infinite-infinite heuristic + if (lowPercentile.isUnknown() || highPercentile.isUnknown()) { + if (!useHeuristics) { + return Estimate.unknown(); + } + // in the case the histogram has no values + if (totalDistinctValues.equals(Estimate.zero()) || range.getDistinctValuesCount() == 0.0) { + return Estimate.of(0.0); + } + + // in the case only one is unknown + if (((lowPercentile.isUnknown() && !highPercentile.isUnknown()) || + (!lowPercentile.isUnknown() && highPercentile.isUnknown())) && + isFinite(range.length())) { + return Estimate.of(StatisticRange.INFINITE_TO_FINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR); + } + + if (range.length() == 0.0) { + return totalDistinctValues.map(distinct -> 1.0 / distinct); + } + + if (!isNaN(range.getDistinctValuesCount())) { + return totalDistinctValues.map(distinct -> min(1.0, range.getDistinctValuesCount() / distinct)); + } + + return Estimate.of(StatisticRange.INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR); + } + + // in the case the range is a single value, this can occur if the input + // filter range is a single value (low == high) OR in the case that the + // bounds of the filter or this histogram are infinite. + // in the case of infinite bounds, we should return an estimate that + // correlates to the overlapping distinct values. 
+ if (lowPercentile.equals(highPercentile)) { + if (!useHeuristics) { + return Estimate.zero(); + } + + return totalDistinctValues.map(distinct -> 1.0 / distinct); + } + + // in the case that we return the entire range, the returned factor percent should be + // proportional to the number of distinct values in the range + if (lowPercentile.equals(Estimate.zero()) && highPercentile.equals(Estimate.of(1.0)) && min.isUnknown() && max.isUnknown()) { + if (!useHeuristics) { + return Estimate.unknown(); + } + + return totalDistinctValues.flatMap(totalDistinct -> { + if (DoubleMath.fuzzyEquals(totalDistinct, 0.0, 1E-6)) { + return Estimate.of(1.0); + } + return Estimate.of(min(1.0, range.getDistinctValuesCount() / totalDistinct)); + }) + // in the case totalDistinct is NaN or 0 + .or(() -> Estimate.of(StatisticRange.INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR)); + } + + return lowPercentile.flatMap(lowPercent -> highPercentile.map(highPercent -> highPercent - lowPercent)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/IntersectStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/IntersectStatsRule.java index c500ec8a6701e..dc3bc31c26e7e 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/IntersectStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/IntersectStatsRule.java @@ -23,7 +23,7 @@ import java.util.Optional; -import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStatsAndIntersect; +import static com.facebook.presto.SystemSessionProperties.shouldOptimizerUseHistograms; import static com.facebook.presto.sql.planner.plan.Patterns.intersect; import static com.google.common.base.Preconditions.checkArgument; @@ -57,7 +57,8 @@ protected Optional doCalculate( PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputSymbols(sourceStats, node, i); if (estimate.isPresent()) { - estimate = Optional.of(addStatsAndIntersect(estimate.get(), 
sourceStatsWithMappedSymbols)); + PlanNodeStatsEstimateMath calculator = new PlanNodeStatsEstimateMath(shouldOptimizerUseHistograms(session)); + estimate = Optional.of(calculator.addStatsAndIntersect(estimate.get(), sourceStatsWithMappedSymbols)); } else { estimate = Optional.of(sourceStatsWithMappedSymbols); diff --git a/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java index bf040ea81fa74..6547c0b463e1c 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/JoinStatsRule.java @@ -33,6 +33,8 @@ import java.util.Queue; import static com.facebook.presto.SystemSessionProperties.getDefaultJoinSelectivityCoefficient; +import static com.facebook.presto.SystemSessionProperties.shouldOptimizerUseHistograms; +import static com.facebook.presto.cost.DisjointRangeDomainHistogram.addConjunction; import static com.facebook.presto.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT; import static com.facebook.presto.cost.VariableStatsEstimate.buildFrom; import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts; @@ -219,13 +221,14 @@ private PlanNodeStatsEstimate filterByEquiJoinClauses( { ComparisonExpression drivingPredicate = new ComparisonExpression(EQUAL, new SymbolReference(getNodeLocation(drivingClause.getLeft().getSourceLocation()), drivingClause.getLeft().getName()), new SymbolReference(getNodeLocation(drivingClause.getRight().getSourceLocation()), drivingClause.getRight().getName())); PlanNodeStatsEstimate filteredStats = filterStatsCalculator.filterStats(stats, drivingPredicate, session, types); + boolean useHistograms = shouldOptimizerUseHistograms(session); for (EquiJoinClause clause : remainingClauses) { - filteredStats = filterByAuxiliaryClause(filteredStats, clause); + filteredStats = filterByAuxiliaryClause(filteredStats, clause, useHistograms); } return 
filteredStats; } - private PlanNodeStatsEstimate filterByAuxiliaryClause(PlanNodeStatsEstimate stats, EquiJoinClause clause) + private PlanNodeStatsEstimate filterByAuxiliaryClause(PlanNodeStatsEstimate stats, EquiJoinClause clause, boolean useHistograms) { // we just clear null fraction and adjust ranges here // selectivity is mostly handled by driving clause. We just scale heuristically by UNKNOWN_FILTER_COEFFICIENT here. @@ -242,22 +245,26 @@ private PlanNodeStatsEstimate filterByAuxiliaryClause(PlanNodeStatsEstimate stat double rightNdvInRange = rightFilterValue * rightRange.getDistinctValuesCount(); double retainedNdv = MoreMath.min(leftNdvInRange, rightNdvInRange); - VariableStatsEstimate newLeftStats = buildFrom(leftStats) + VariableStatsEstimate.Builder newLeftStats = buildFrom(leftStats) .setNullsFraction(0) .setStatisticsRange(intersect) - .setDistinctValuesCount(retainedNdv) - .build(); + .setDistinctValuesCount(retainedNdv); + if (useHistograms) { + newLeftStats.setHistogram(leftStats.getHistogram().map(leftHistogram -> addConjunction(leftHistogram, intersect))); + } - VariableStatsEstimate newRightStats = buildFrom(rightStats) + VariableStatsEstimate.Builder newRightStats = buildFrom(rightStats) .setNullsFraction(0) .setStatisticsRange(intersect) - .setDistinctValuesCount(retainedNdv) - .build(); + .setDistinctValuesCount(retainedNdv); + if (useHistograms) { + newRightStats.setHistogram(rightStats.getHistogram().map(rightHistogram -> addConjunction(rightHistogram, intersect))); + } PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(stats) .setOutputRowCount(stats.getOutputRowCount() * UNKNOWN_FILTER_COEFFICIENT) - .addVariableStatistics(clause.getLeft(), newLeftStats) - .addVariableStatistics(clause.getRight(), newRightStats); + .addVariableStatistics(clause.getLeft(), newLeftStats.build()) + .addVariableStatistics(clause.getRight(), newRightStats.build()); return normalizer.normalize(result.build()); } diff --git 
a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java index 5217a69b60898..1b2797e18a8af 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimateMath.java @@ -13,6 +13,11 @@ */ package com.facebook.presto.cost; +import com.facebook.presto.spi.statistics.ConnectorHistogram; + +import java.util.Optional; + +import static com.facebook.presto.cost.DisjointRangeDomainHistogram.addConjunction; import static com.google.common.base.Preconditions.checkArgument; import static java.lang.Double.NaN; import static java.lang.Double.isNaN; @@ -22,14 +27,17 @@ public class PlanNodeStatsEstimateMath { - private PlanNodeStatsEstimateMath() - {} + private final boolean shouldUseHistograms; + public PlanNodeStatsEstimateMath(boolean shouldUseHistograms) + { + this.shouldUseHistograms = shouldUseHistograms; + } /** * Subtracts subset stats from supersets stats. * It is assumed that each NDV from subset has a matching NDV in superset. 
*/ - public static PlanNodeStatsEstimate subtractSubsetStats(PlanNodeStatsEstimate superset, PlanNodeStatsEstimate subset) + public PlanNodeStatsEstimate subtractSubsetStats(PlanNodeStatsEstimate superset, PlanNodeStatsEstimate subset) { if (superset.isOutputRowCountUnknown() || subset.isOutputRowCountUnknown()) { return PlanNodeStatsEstimate.unknown(); @@ -100,7 +108,7 @@ else if (subsetDistinctValues == 0) { return result.build(); } - public static PlanNodeStatsEstimate capStats(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap) + public PlanNodeStatsEstimate capStats(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap) { if (stats.isOutputRowCountUnknown() || cap.isOutputRowCountUnknown()) { return PlanNodeStatsEstimate.unknown(); @@ -119,16 +127,20 @@ public static PlanNodeStatsEstimate capStats(PlanNodeStatsEstimate stats, PlanNo // for simplicity keep the average row size the same as in the input // in most cases the average row size doesn't change after applying filters newSymbolStats.setAverageRowSize(symbolStats.getAverageRowSize()); - newSymbolStats.setDistinctValuesCount(min(symbolStats.getDistinctValuesCount(), capSymbolStats.getDistinctValuesCount())); - newSymbolStats.setLowValue(max(symbolStats.getLowValue(), capSymbolStats.getLowValue())); - newSymbolStats.setHighValue(min(symbolStats.getHighValue(), capSymbolStats.getHighValue())); + double newLow = max(symbolStats.getLowValue(), capSymbolStats.getLowValue()); + double newHigh = min(symbolStats.getHighValue(), capSymbolStats.getHighValue()); + newSymbolStats.setLowValue(newLow); + newSymbolStats.setHighValue(newHigh); double numberOfNulls = stats.getOutputRowCount() * symbolStats.getNullsFraction(); double capNumberOfNulls = cap.getOutputRowCount() * capSymbolStats.getNullsFraction(); double cappedNumberOfNulls = min(numberOfNulls, capNumberOfNulls); double cappedNullsFraction = cappedRowCount == 0 ? 
1 : cappedNumberOfNulls / cappedRowCount; newSymbolStats.setNullsFraction(cappedNullsFraction); + if (shouldUseHistograms) { + newSymbolStats.setHistogram(symbolStats.getHistogram().map(symbolHistogram -> addConjunction(symbolHistogram, new StatisticRange(newLow, newHigh, 0)))); + } result.addVariableStatistics(symbol, newSymbolStats.build()); }); @@ -144,28 +156,47 @@ private static PlanNodeStatsEstimate createZeroStats(PlanNodeStatsEstimate stats return result.build(); } + protected enum RangeAdditionStrategy + { + ADD_AND_SUM_DISTINCT(StatisticRange::addAndSumDistinctValues), + ADD_AND_MAX_DISTINCT(StatisticRange::addAndMaxDistinctValues), + ADD_AND_COLLAPSE_DISTINCT(StatisticRange::addAndCollapseDistinctValues), + INTERSECT(StatisticRange::intersect); + private final RangeAdditionFunction rangeAdditionFunction; + + RangeAdditionStrategy(RangeAdditionFunction rangeAdditionFunction) + { + this.rangeAdditionFunction = rangeAdditionFunction; + } + + public RangeAdditionFunction getRangeAdditionFunction() + { + return rangeAdditionFunction; + } + } + @FunctionalInterface - private interface RangeAdditionStrategy + protected interface RangeAdditionFunction { StatisticRange add(StatisticRange leftRange, StatisticRange rightRange); } - public static PlanNodeStatsEstimate addStatsAndSumDistinctValues(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) + public PlanNodeStatsEstimate addStatsAndSumDistinctValues(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) { - return addStats(left, right, StatisticRange::addAndSumDistinctValues); + return addStats(left, right, RangeAdditionStrategy.ADD_AND_SUM_DISTINCT); } - public static PlanNodeStatsEstimate addStatsAndMaxDistinctValues(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) + public PlanNodeStatsEstimate addStatsAndMaxDistinctValues(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) { - return addStats(left, right, StatisticRange::addAndMaxDistinctValues); + return addStats(left, right, 
RangeAdditionStrategy.ADD_AND_MAX_DISTINCT); } - public static PlanNodeStatsEstimate addStatsAndCollapseDistinctValues(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) + public PlanNodeStatsEstimate addStatsAndCollapseDistinctValues(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) { - return addStats(left, right, StatisticRange::addAndCollapseDistinctValues); + return addStats(left, right, RangeAdditionStrategy.ADD_AND_COLLAPSE_DISTINCT); } - public static PlanNodeStatsEstimate addStatsAndIntersect(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) + public PlanNodeStatsEstimate addStatsAndIntersect(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) { if (left.isOutputRowCountUnknown() || right.isOutputRowCountUnknown()) { return PlanNodeStatsEstimate.unknown(); @@ -185,15 +216,15 @@ public static PlanNodeStatsEstimate addStatsAndIntersect(PlanNodeStatsEstimate l right.getOutputRowCount() * rstats.overlapPercentWith(lstats)); }).reduce(Math::min).orElse(estimatedRowCount); - buildVariableStatistics(left, right, statsBuilder, rowCount, StatisticRange::intersect); + buildVariableStatistics(left, right, statsBuilder, rowCount, RangeAdditionStrategy.INTERSECT); return statsBuilder.setOutputRowCount(rowCount).build(); } - private static PlanNodeStatsEstimate addStats( + private PlanNodeStatsEstimate addStats( PlanNodeStatsEstimate left, PlanNodeStatsEstimate right, - RangeAdditionStrategy rangeAdder) + RangeAdditionStrategy strategy) { double rowCount = left.getOutputRowCount() + right.getOutputRowCount(); double totalSize = left.getTotalSize() + right.getTotalSize(); @@ -203,18 +234,18 @@ private static PlanNodeStatsEstimate addStats( } PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder(); - buildVariableStatistics(left, right, statsBuilder, rowCount, rangeAdder); + buildVariableStatistics(left, right, statsBuilder, rowCount, strategy); return statsBuilder.setOutputRowCount(rowCount) .setTotalSize(totalSize).build(); 
} - private static void buildVariableStatistics( + private void buildVariableStatistics( PlanNodeStatsEstimate left, PlanNodeStatsEstimate right, PlanNodeStatsEstimate.Builder statsBuilder, double estimatedRowCount, - RangeAdditionStrategy rangeAdder) + RangeAdditionStrategy strategy) { concat(left.getVariablesWithKnownStatistics().stream(), right.getVariablesWithKnownStatistics().stream()) .distinct() @@ -230,13 +261,13 @@ else if (estimatedRowCount > 0) { right.getVariableStatistics(symbol), right.getOutputRowCount(), estimatedRowCount, - rangeAdder); + strategy); } statsBuilder.addVariableStatistics(symbol, symbolStats); }); } - private static VariableStatsEstimate addColumnStats( + private VariableStatsEstimate addColumnStats( VariableStatsEstimate leftStats, double leftRows, VariableStatsEstimate rightStats, @@ -249,7 +280,7 @@ private static VariableStatsEstimate addColumnStats( StatisticRange leftRange = StatisticRange.from(leftStats); StatisticRange rightRange = StatisticRange.from(rightStats); - StatisticRange sum = strategy.add(leftRange, rightRange); + StatisticRange sum = strategy.getRangeAdditionFunction().add(leftRange, rightRange); double nullsCountRight = rightStats.getNullsFraction() * rightRows; double nullsCountLeft = leftStats.getNullsFraction() * leftRows; double totalSizeLeft = (leftRows - nullsCountLeft) * leftStats.getAverageRowSize(); @@ -259,11 +290,17 @@ private static VariableStatsEstimate addColumnStats( // FIXME, weights to average. left and right should be equal in most cases anyway double newAverageRowSize = newNonNullsRowCount == 0 ? 
0 : ((totalSizeLeft + totalSizeRight) / newNonNullsRowCount); - - return VariableStatsEstimate.builder() + VariableStatsEstimate.Builder statistics = VariableStatsEstimate.builder() .setStatisticsRange(sum) .setAverageRowSize(newAverageRowSize) - .setNullsFraction(newNullsFraction) - .build(); + .setNullsFraction(newNullsFraction); + if (shouldUseHistograms) { + Optional newHistogram = RangeAdditionStrategy.INTERSECT == strategy ? + leftStats.getHistogram().map(leftHistogram -> DisjointRangeDomainHistogram.addConjunction(leftHistogram, rightRange)) : + leftStats.getHistogram().map(leftHistogram -> DisjointRangeDomainHistogram.addDisjunction(leftHistogram, rightRange)); + statistics.setHistogram(newHistogram); + } + + return statistics.build(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/RemoteSourceStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/RemoteSourceStatsRule.java index 72bd512a3ca0e..1699428bf0adf 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/RemoteSourceStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/RemoteSourceStatsRule.java @@ -23,6 +23,7 @@ import java.util.Optional; +import static com.facebook.presto.SystemSessionProperties.shouldOptimizerUseHistograms; import static com.facebook.presto.sql.planner.plan.Patterns.remoteSource; import static java.util.Objects.requireNonNull; @@ -49,8 +50,9 @@ public Pattern getPattern() protected Optional doCalculate(RemoteSourceNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) { QueryId queryId = session.getQueryId(); + PlanNodeStatsEstimateMath calculator = new PlanNodeStatsEstimateMath(shouldOptimizerUseHistograms(session)); return node.getSourceFragmentIds().stream() .map(fragmentId -> fragmentStatsProvider.getStats(queryId, fragmentId)) - .reduce(PlanNodeStatsEstimateMath::addStatsAndCollapseDistinctValues); + .reduce(calculator::addStatsAndCollapseDistinctValues); } } diff --git 
a/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java b/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java index b613f655dc01d..4d80e13cb92cc 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/StatisticRange.java @@ -13,11 +13,19 @@ */ package com.facebook.presto.cost; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.BoundType; +import com.google.common.collect.Range; + import java.util.Objects; +import static com.facebook.presto.util.MoreMath.nearlyEqual; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.Double.NEGATIVE_INFINITY; import static java.lang.Double.NaN; +import static java.lang.Double.POSITIVE_INFINITY; import static java.lang.Double.isFinite; import static java.lang.Double.isInfinite; import static java.lang.Double.isNaN; @@ -28,22 +36,36 @@ public class StatisticRange { - private static final double INFINITE_TO_FINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR = 0.25; - private static final double INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR = 0.5; + protected static final double INFINITE_TO_FINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR = 0.25; + protected static final double INFINITE_TO_INFINITE_RANGE_INTERSECT_OVERLAP_HEURISTIC_FACTOR = 0.5; // TODO unify field and method names with SymbolStatsEstimate /** * {@code NaN} represents empty range ({@code high} must be {@code NaN} too) */ private final double low; + + /** + * Whether the low side of the range is open. e.g. The value is *not* included in the range. + */ + private final boolean openLow; /** * {@code NaN} represents empty range ({@code low} must be {@code NaN} too) */ private final double high; + /** + * Whether the high side of the range is open. 
e.g. the value is *not* included in the range. + */ + private final boolean openHigh; private final double distinctValues; - public StatisticRange(double low, double high, double distinctValues) + @JsonCreator + public StatisticRange(@JsonProperty("low") double low, + @JsonProperty("openLow") boolean openLow, + @JsonProperty("high") double high, + @JsonProperty("openHigh") boolean openHigh, + @JsonProperty("distinctValuesCount") double distinctValues) { checkArgument( low <= high || (isNaN(low) && isNaN(high)), @@ -52,36 +74,58 @@ public StatisticRange(double low, double high, double distinctValues) high); this.low = low; this.high = high; + this.openLow = openLow; + this.openHigh = openHigh; checkArgument(distinctValues >= 0 || isNaN(distinctValues), "Distinct values count should be non-negative, got: %s", distinctValues); this.distinctValues = distinctValues; } + public StatisticRange(double low, double high, double distinctValues) + { + this(low, false, high, false, distinctValues); + } + public static StatisticRange empty() { - return new StatisticRange(NaN, NaN, 0); + return new StatisticRange(NaN, false, NaN, false, 0); } public static StatisticRange from(VariableStatsEstimate estimate) { - return new StatisticRange(estimate.getLowValue(), estimate.getHighValue(), estimate.getDistinctValuesCount()); + return new StatisticRange(estimate.getLowValue(), false, estimate.getHighValue(), false, estimate.getDistinctValuesCount()); } + @JsonProperty public double getLow() { return low; } + @JsonProperty public double getHigh() { return high; } + @JsonProperty public double getDistinctValuesCount() { return distinctValues; } + @JsonProperty + public boolean getOpenLow() + { + return openLow; + } + + @JsonProperty + public boolean getOpenHigh() + { + return openHigh; + } + public double length() { return high - low; @@ -142,9 +186,14 @@ private double overlappingDistinctValues(StatisticRange other) public StatisticRange intersect(StatisticRange other) { double newLow = 
max(low, other.low); + boolean newOpenLow = newLow == low ? openLow : other.openLow; + // epsilon is an arbitrary choice + newOpenLow = nearlyEqual(low, other.low, 1E-10) ? openLow || other.openLow : newOpenLow; double newHigh = min(high, other.high); + boolean newOpenHigh = newHigh == high ? openHigh : other.openHigh; + newOpenHigh = nearlyEqual(high, other.high, 1E-10) ? openHigh || other.openHigh : newOpenHigh; if (newLow <= newHigh) { - return new StatisticRange(newLow, newHigh, overlappingDistinctValues(other)); + return new StatisticRange(newLow, newOpenLow, newHigh, newOpenHigh, overlappingDistinctValues(other)); } return empty(); } @@ -152,13 +201,13 @@ public StatisticRange intersect(StatisticRange other) public StatisticRange addAndSumDistinctValues(StatisticRange other) { double newDistinctValues = distinctValues + other.distinctValues; - return new StatisticRange(minExcludeNaN(low, other.low), maxExcludeNaN(high, other.high), newDistinctValues); + return expandRangeWithNewDistinct(newDistinctValues, other); } public StatisticRange addAndMaxDistinctValues(StatisticRange other) { double newDistinctValues = max(distinctValues, other.distinctValues); - return new StatisticRange(minExcludeNaN(low, other.low), maxExcludeNaN(high, other.high), newDistinctValues); + return expandRangeWithNewDistinct(newDistinctValues, other); } public StatisticRange addAndCollapseDistinctValues(StatisticRange other) @@ -170,7 +219,41 @@ public StatisticRange addAndCollapseDistinctValues(StatisticRange other) double maxOverlappingValues = max(overlapDistinctValuesThis, overlapDistinctValuesOther); double newDistinctValues = maxOverlappingValues + (1 - overlapPercentOfThis) * distinctValues + (1 - overlapPercentOfOther) * other.distinctValues; - return new StatisticRange(minExcludeNaN(low, other.low), maxExcludeNaN(high, other.high), newDistinctValues); + return expandRangeWithNewDistinct(newDistinctValues, other); + } + + public Range toRange() + { + return Range.range(low, 
openLow ? BoundType.OPEN : BoundType.CLOSED, high, openHigh ? BoundType.OPEN : BoundType.CLOSED); + } + + public static StatisticRange fromRange(Range range) + { + return new StatisticRange( + range.hasLowerBound() ? range.lowerEndpoint() : NEGATIVE_INFINITY, + !range.hasLowerBound() || range.lowerBoundType() == BoundType.OPEN, + range.hasUpperBound() ? range.upperEndpoint() : POSITIVE_INFINITY, + !range.hasUpperBound() || range.upperBoundType() == BoundType.OPEN, + NaN); + } + + private StatisticRange expandRangeWithNewDistinct(double newDistinctValues, StatisticRange other) + { + double newLow = minExcludeNaN(low, other.low); + boolean newOpenLow = getNewEndpointOpennessLow(this, other, newLow); + double newHigh = maxExcludeNaN(high, other.high); + boolean newOpenHigh = getNewEndpointOpennessHigh(this, other, newHigh); + return new StatisticRange(newLow, newOpenLow, newHigh, newOpenHigh, newDistinctValues); + } + + private static boolean getNewEndpointOpennessLow(StatisticRange first, StatisticRange second, double newLow) + { + return newLow == first.low ? first.openLow : second.openLow; + } + + private static boolean getNewEndpointOpennessHigh(StatisticRange first, StatisticRange second, double newHigh) + { + return newHigh == first.high ? first.openHigh : second.openHigh; } private static double minExcludeNaN(double v1, double v2) @@ -206,21 +289,23 @@ public boolean equals(Object o) } StatisticRange that = (StatisticRange) o; return Double.compare(that.low, low) == 0 && + that.openLow == openLow && Double.compare(that.high, high) == 0 && + that.openHigh == openHigh && Double.compare(that.distinctValues, distinctValues) == 0; } @Override public int hashCode() { - return Objects.hash(low, high, distinctValues); + return Objects.hash(low, openLow, high, openHigh, distinctValues); } @Override public String toString() { return toStringHelper(this) - .add("range", format("[%s-%s]", low, high)) + .add("range", format("%s%s..%s%s", openLow ? 
"(" : "[", low, high, openHigh ? ")" : "]")) .add("ndv", distinctValues) .toString(); } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/StatsUtil.java b/presto-main/src/main/java/com/facebook/presto/cost/StatsUtil.java index 3e33257e3f57e..a71de776f6f44 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/StatsUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/StatsUtil.java @@ -90,6 +90,7 @@ public static VariableStatsEstimate toVariableStatsEstimate(TableStatistics tabl result.setLowValue(range.getMin()); result.setHighValue(range.getMax()); }); + result.setHistogram(columnStatistics.getHistogram()); return result.build(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/UniformDistributionHistogram.java b/presto-main/src/main/java/com/facebook/presto/cost/UniformDistributionHistogram.java new file mode 100644 index 0000000000000..d06232d1fb608 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/UniformDistributionHistogram.java @@ -0,0 +1,148 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.cost; + +import com.facebook.presto.spi.statistics.ConnectorHistogram; +import com.facebook.presto.spi.statistics.Estimate; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static java.lang.Double.isInfinite; +import static java.lang.Double.isNaN; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.util.Objects.hash; + +/** + * This {@link ConnectorHistogram} implementation returns values assuming a + * uniform distribution between a given high and low value. + *
+ * In the case that statistics don't exist for a particular table, the Presto + * optimizer will fall back on this uniform distribution assumption. + */ +public class UniformDistributionHistogram + implements ConnectorHistogram +{ + private final double lowValue; + private final double highValue; + + @JsonCreator + public UniformDistributionHistogram( + @JsonProperty("lowValue") double lowValue, + @JsonProperty("highValue") double highValue) + { + verify(isNaN(lowValue) || isNaN(highValue) || (lowValue <= highValue), "lowValue must be <= highValue"); + this.lowValue = lowValue; + this.highValue = highValue; + } + + @JsonProperty + public double getLowValue() + { + return lowValue; + } + + @JsonProperty + public double getHighValue() + { + return highValue; + } + + @Override + public Estimate cumulativeProbability(double value, boolean inclusive) + { + if (isNaN(lowValue) || + isNaN(highValue) || + isNaN(value)) { + return Estimate.unknown(); + } + + if (value >= highValue) { + return Estimate.of(1.0); + } + + if (value <= lowValue) { + return Estimate.of(0.0); + } + + if (isInfinite(lowValue) || isInfinite(highValue)) { + return Estimate.unknown(); + } + + return Estimate.of(min(1.0, max(0.0, ((value - lowValue) / (highValue - lowValue))))); + } + + @Override + public Estimate inverseCumulativeProbability(double percentile) + { + checkArgument(percentile >= 0.0 && percentile <= 1.0, "percentile must be in [0.0, 1.0]: " + percentile); + if (isNaN(lowValue) || + isNaN(highValue)) { + return Estimate.unknown(); + } + + if (percentile == 0.0 && !isInfinite(lowValue)) { + return Estimate.of(lowValue); + } + + if (percentile == 1.0 && !isInfinite(highValue)) { + return Estimate.of(highValue); + } + + if (isInfinite(lowValue) || isInfinite(highValue)) { + return Estimate.unknown(); + } + + return Estimate.of(lowValue + (percentile * (highValue - lowValue))); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("lowValue", lowValue) + 
.add("highValue", highValue) + .toString(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (!(o instanceof UniformDistributionHistogram)) { + return false; + } + + UniformDistributionHistogram other = (UniformDistributionHistogram) o; + return equalsOrBothNaN(lowValue, other.lowValue) && + equalsOrBothNaN(highValue, other.highValue); + } + + @Override + public int hashCode() + { + return hash(lowValue, highValue); + } + + private static boolean equalsOrBothNaN(Double first, Double second) + { + return first.equals(second) || (Double.isNaN(first) && Double.isNaN(second)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/UnionStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/UnionStatsRule.java index d5b9ebad66589..3ae9f62eb8b14 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/UnionStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/UnionStatsRule.java @@ -23,7 +23,7 @@ import java.util.Optional; -import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStatsAndCollapseDistinctValues; +import static com.facebook.presto.SystemSessionProperties.shouldOptimizerUseHistograms; import static com.facebook.presto.sql.planner.plan.Patterns.union; import static com.google.common.base.Preconditions.checkArgument; @@ -56,7 +56,8 @@ protected final Optional doCalculate(UnionNode node, Stat PlanNodeStatsEstimate sourceStatsWithMappedSymbols = mapToOutputSymbols(sourceStats, node, i); if (estimate.isPresent()) { - estimate = Optional.of(addStatsAndCollapseDistinctValues(estimate.get(), sourceStatsWithMappedSymbols)); + PlanNodeStatsEstimateMath calculator = new PlanNodeStatsEstimateMath(shouldOptimizerUseHistograms(session)); + estimate = Optional.of(calculator.addStatsAndCollapseDistinctValues(estimate.get(), sourceStatsWithMappedSymbols)); } else { estimate = Optional.of(sourceStatsWithMappedSymbols); diff --git 
a/presto-main/src/main/java/com/facebook/presto/cost/VariableStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/VariableStatsEstimate.java index 2e13cf5a0fde7..4c94b1d6b65b7 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/VariableStatsEstimate.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/VariableStatsEstimate.java @@ -13,10 +13,13 @@ */ package com.facebook.presto.cost; +import com.facebook.presto.spi.statistics.ConnectorHistogram; import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Objects; +import java.util.Optional; import java.util.function.Function; import static com.google.common.base.MoreObjects.toStringHelper; @@ -27,6 +30,7 @@ import static java.lang.Double.isInfinite; import static java.lang.Double.isNaN; import static java.lang.String.format; +import static java.util.Objects.requireNonNull; public class VariableStatsEstimate { @@ -39,6 +43,7 @@ public class VariableStatsEstimate private final double nullsFraction; private final double averageRowSize; private final double distinctValuesCount; + private final Optional histogram; public static VariableStatsEstimate unknown() { @@ -50,13 +55,13 @@ public static VariableStatsEstimate zero() return ZERO; } - @JsonCreator public VariableStatsEstimate( - @JsonProperty("lowValue") double lowValue, - @JsonProperty("highValue") double highValue, - @JsonProperty("nullsFraction") double nullsFraction, - @JsonProperty("averageRowSize") double averageRowSize, - @JsonProperty("distinctValuesCount") double distinctValuesCount) + double lowValue, + double highValue, + double nullsFraction, + double averageRowSize, + double distinctValuesCount, + Optional histogram) { checkArgument( lowValue <= highValue || (isNaN(lowValue) && isNaN(highValue)), @@ -79,6 +84,18 @@ public VariableStatsEstimate( checkArgument(distinctValuesCount >= 0 || 
isNaN(distinctValuesCount), "Distinct values count should be non-negative, got: %s", distinctValuesCount); // TODO normalize distinctValuesCount for an empty range (or validate it is already normalized) this.distinctValuesCount = distinctValuesCount; + this.histogram = requireNonNull(histogram, "histogram is null"); + } + + @JsonCreator + public VariableStatsEstimate( + @JsonProperty("lowValue") double lowValue, + @JsonProperty("highValue") double highValue, + @JsonProperty("nullsFraction") double nullsFraction, + @JsonProperty("averageRowSize") double averageRowSize, + @JsonProperty("distinctValuesCount") double distinctValuesCount) + { + this(lowValue, highValue, nullsFraction, averageRowSize, distinctValuesCount, Optional.empty()); } @JsonProperty @@ -99,6 +116,15 @@ public double getNullsFraction() return nullsFraction; } + // We ignore the histogram during serialization because histograms can be + // quite large. Histograms are not used outside the coordinator, so there + // isn't a need to serialize them + @JsonIgnore + public Optional getHistogram() + { + return histogram; + } + public StatisticRange statisticRange() { return new StatisticRange(lowValue, highValue, distinctValuesCount); @@ -153,6 +179,8 @@ public boolean equals(Object o) return false; } VariableStatsEstimate that = (VariableStatsEstimate) o; + // histograms are explicitly left out because equals calculations would + // be expensive. 
return Double.compare(nullsFraction, that.nullsFraction) == 0 && Double.compare(averageRowSize, that.averageRowSize) == 0 && Double.compare(distinctValuesCount, that.distinctValuesCount) == 0 && @@ -174,6 +202,7 @@ public String toString() .add("nulls", nullsFraction) .add("ndv", distinctValuesCount) .add("rowSize", averageRowSize) + .add("histogram", histogram) .toString(); } @@ -189,7 +218,8 @@ public static Builder buildFrom(VariableStatsEstimate other) .setHighValue(other.getHighValue()) .setNullsFraction(other.getNullsFraction()) .setAverageRowSize(other.getAverageRowSize()) - .setDistinctValuesCount(other.getDistinctValuesCount()); + .setDistinctValuesCount(other.getDistinctValuesCount()) + .setHistogram(other.getHistogram()); } public static final class Builder @@ -199,6 +229,7 @@ public static final class Builder private double nullsFraction = NaN; private double averageRowSize = NaN; private double distinctValuesCount = NaN; + private Optional histogram = Optional.empty(); public Builder setStatisticsRange(StatisticRange range) { @@ -237,9 +268,16 @@ public Builder setDistinctValuesCount(double distinctValuesCount) return this; } + public Builder setHistogram(Optional histogram) + { + this.histogram = histogram; + return this; + } + public VariableStatsEstimate build() { - return new VariableStatsEstimate(lowValue, highValue, nullsFraction, averageRowSize, distinctValuesCount); + return new VariableStatsEstimate(lowValue, highValue, nullsFraction, averageRowSize, distinctValuesCount, + histogram); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index c4cf91b70999c..60e0d4930cee3 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -309,6 +309,7 @@ public class FeaturesConfig private boolean 
generateDomainFilters; private boolean printEstimatedStatsFromCache; private CreateView.Security defaultViewSecurityMode = DEFINER; + private boolean useHistograms; public enum PartitioningPrecisionStrategy { @@ -3115,4 +3116,17 @@ public FeaturesConfig setPrintEstimatedStatsFromCache(boolean printEstimatedStat this.printEstimatedStatsFromCache = printEstimatedStatsFromCache; return this; } + + public boolean isUseHistograms() + { + return useHistograms; + } + + @Config("optimizer.use-histograms") + @ConfigDescription("Use histogram statistics in cost-based calculations in the optimizer") + public FeaturesConfig setUseHistograms(boolean useHistograms) + { + this.useHistograms = useHistograms; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java index f1679c77eea61..7a7396824e091 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java @@ -265,6 +265,7 @@ private static List buildColumnsNames() .add("row_count") .add("low_value") .add("high_value") + .add("histogram") .build(); } @@ -310,6 +311,7 @@ private Row createColumnStatsRow(String columnName, Type type, ColumnStatistics rowValues.add(NULL_DOUBLE); rowValues.add(toStringLiteral(type, columnStatistics.getRange().map(DoubleRange::getMin))); rowValues.add(toStringLiteral(type, columnStatistics.getRange().map(DoubleRange::getMax))); + rowValues.add(columnStatistics.getHistogram().map(Objects::toString).map(StringLiteral::new).orElse(NULL_VARCHAR)); return new Row(rowValues.build()); } @@ -323,6 +325,7 @@ private Expression createEmptyColumnStatsRow(String columnName) rowValues.add(NULL_DOUBLE); rowValues.add(NULL_VARCHAR); rowValues.add(NULL_VARCHAR); + rowValues.add(NULL_VARCHAR); return new Row(rowValues.build()); } @@ -336,6 +339,7 @@ private static 
Row createTableStatsRow(TableStatistics tableStatistics) rowValues.add(createEstimateRepresentation(tableStatistics.getRowCount())); rowValues.add(NULL_VARCHAR); rowValues.add(NULL_VARCHAR); + rowValues.add(NULL_VARCHAR); return new Row(rowValues.build()); } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/AbstractTestComparisonStatsCalculator.java similarity index 99% rename from presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculator.java rename to presto-main/src/test/java/com/facebook/presto/cost/AbstractTestComparisonStatsCalculator.java index 310b35dcfc1e4..66c2ceef405fd 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/AbstractTestComparisonStatsCalculator.java @@ -37,6 +37,7 @@ import java.util.Optional; import java.util.function.Consumer; +import static com.facebook.presto.SystemSessionProperties.OPTIMIZER_USE_HISTOGRAMS; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN; @@ -51,7 +52,8 @@ import static java.lang.String.format; import static java.util.stream.Collectors.joining; -public class TestComparisonStatsCalculator +@Test +public abstract class AbstractTestComparisonStatsCalculator { private FilterStatsCalculator filterStatsCalculator; private Session session; @@ -68,11 +70,17 @@ public class TestComparisonStatsCalculator private VariableStatsEstimate emptyRangeStats; private VariableStatsEstimate varcharStats; + public AbstractTestComparisonStatsCalculator(boolean withHistograms) + { + session = testSessionBuilder() + .setSystemProperty(OPTIMIZER_USE_HISTOGRAMS, Boolean.toString(withHistograms)) + .build(); + } + @BeforeClass public void 
setUp() throws Exception { - session = testSessionBuilder().build(); MetadataManager metadata = MetadataManager.createTestMetadataManager(); filterStatsCalculator = new FilterStatsCalculator(metadata, new ScalarStatsCalculator(metadata), new StatsNormalizer()); diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/AbstractTestFilterStatsCalculator.java similarity index 97% rename from presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculator.java rename to presto-main/src/test/java/com/facebook/presto/cost/AbstractTestFilterStatsCalculator.java index 50528418e55c8..c8f919326f92d 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/AbstractTestFilterStatsCalculator.java @@ -27,6 +27,7 @@ import java.util.Optional; +import static com.facebook.presto.SystemSessionProperties.OPTIMIZER_USE_HISTOGRAMS; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; @@ -36,7 +37,7 @@ import static java.lang.String.format; import static org.testng.Assert.assertEquals; -public class TestFilterStatsCalculator +public abstract class AbstractTestFilterStatsCalculator { private static final VarcharType MEDIUM_VARCHAR_TYPE = VarcharType.createVarcharType(100); @@ -54,6 +55,13 @@ public class TestFilterStatsCalculator private Session session; private TestingRowExpressionTranslator translator; + public AbstractTestFilterStatsCalculator(boolean withHistograms) + { + session = testSessionBuilder() + .setSystemProperty(OPTIMIZER_USE_HISTOGRAMS, Boolean.toString(withHistograms)) + .build(); + } + @BeforeClass public void setUp() throws Exception @@ -137,7 +145,6 @@ public void setUp() .add(new 
VariableReferenceExpression(Optional.empty(), "mediumVarchar", MEDIUM_VARCHAR_TYPE)) .build()); - session = testSessionBuilder().build(); MetadataManager metadata = MetadataManager.createTestMetadataManager(); statsCalculator = new FilterStatsCalculator(metadata, new ScalarStatsCalculator(metadata), new StatsNormalizer()); translator = new TestingRowExpressionTranslator(MetadataManager.createTestMetadataManager()); @@ -376,6 +383,33 @@ public void testIsNotNullFilter() .variableStats(new VariableReferenceExpression(Optional.empty(), "emptyRange", DOUBLE), VariableStatsAssertion::empty); } + @Test + public void testBetweenOperatorFilterLeftOpen() + { + // Left side open, cut on open side + assertExpression("leftOpen BETWEEN DOUBLE '-10' AND 10e0") + .outputRowsCount(180.0) + .variableStats(new VariableReferenceExpression(Optional.empty(), "leftOpen", DOUBLE), variableStats -> + variableStats.distinctValuesCount(10.0) + .lowValue(-10.0) + .highValue(10.0) + .nullsFraction(0.0)); + } + + @Test + public void testBetweenOperatorFilterRightOpen() + { + // Right side open, cut on open side + // (same expectations as the left-open case) + assertExpression("rightOpen BETWEEN DOUBLE '-10' AND 10e0") + .outputRowsCount(180.0) + .variableStats(new VariableReferenceExpression(Optional.empty(), "rightOpen", DOUBLE), variableStats -> + variableStats.distinctValuesCount(10.0) + .lowValue(-10.0) + .highValue(10.0) + .nullsFraction(0.0)); + } + @Test public void testBetweenOperatorFilter() { @@ -422,24 +456,6 @@ public void testBetweenOperatorFilter() .highValue(3.14) .nullsFraction(0.0)); - // Left side open, cut on open side - assertExpression("leftOpen BETWEEN DOUBLE '-10' AND 10e0") - .outputRowsCount(180.0) - .variableStats(new VariableReferenceExpression(Optional.empty(), "leftOpen", DOUBLE), variableStats -> - variableStats.distinctValuesCount(10.0) - .lowValue(-10.0) - .highValue(10.0) - .nullsFraction(0.0)); - - // Right side open, cut on open side - assertExpression("rightOpen
BETWEEN DOUBLE '-10' AND 10e0") - .outputRowsCount(180.0) - .variableStats(new VariableReferenceExpression(Optional.empty(), "rightOpen", DOUBLE), variableStats -> - variableStats.distinctValuesCount(10.0) - .lowValue(-10.0) - .highValue(10.0) - .nullsFraction(0.0)); - // Filter all assertExpression("y BETWEEN 27.5e0 AND 107e0") .outputRowsCount(0.0) @@ -588,7 +604,7 @@ public void testInPredicateFilter() .nullsFraction(0.0)); } - private PlanNodeStatsAssertion assertExpression(String expression) + protected PlanNodeStatsAssertion assertExpression(String expression) { return assertExpression(expression(expression)); } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculatorHistograms.java b/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculatorHistograms.java new file mode 100644 index 0000000000000..e46090a11d333 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculatorHistograms.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +public class TestComparisonStatsCalculatorHistograms + extends AbstractTestComparisonStatsCalculator +{ + public TestComparisonStatsCalculatorHistograms() + { + super(true); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculatorNoHistograms.java b/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculatorNoHistograms.java new file mode 100644 index 0000000000000..1fb353fd9fa21 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestComparisonStatsCalculatorNoHistograms.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +public class TestComparisonStatsCalculatorNoHistograms + extends AbstractTestComparisonStatsCalculator +{ + public TestComparisonStatsCalculatorNoHistograms() + { + super(false); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestConnectorFilterStatsCalculatorService.java b/presto-main/src/test/java/com/facebook/presto/cost/TestConnectorFilterStatsCalculatorService.java index 5bc06c514ddbb..584f55140e3a7 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestConnectorFilterStatsCalculatorService.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestConnectorFilterStatsCalculatorService.java @@ -98,20 +98,20 @@ public void testTableStatisticsAfterFilter() TableStatistics filteredToZeroStatistics = TableStatistics.builder() .setRowCount(Estimate.zero()) .setTotalSize(Estimate.zero()) - .setColumnStatistics(xColumn, new ColumnStatistics(Estimate.of(1.0), Estimate.zero(), Estimate.zero(), Optional.empty())) + .setColumnStatistics(xColumn, new ColumnStatistics(Estimate.of(1.0), Estimate.zero(), Estimate.zero(), Optional.empty(), Optional.empty())) .build(); assertPredicate("false", originalTableStatistics, filteredToZeroStatistics); TableStatistics filteredStatistics = TableStatistics.builder() .setRowCount(Estimate.of(37.5)) .setTotalSize(Estimate.of(300)) - .setColumnStatistics(xColumn, new ColumnStatistics(Estimate.zero(), Estimate.of(20), Estimate.unknown(), Optional.of(new DoubleRange(-10, 0)))) + .setColumnStatistics(xColumn, new ColumnStatistics(Estimate.zero(), Estimate.of(20), Estimate.unknown(), Optional.of(new DoubleRange(-10, 0)), Optional.empty())) .build(); assertPredicate("x < 0", originalTableStatistics, filteredStatistics); TableStatistics filteredStatisticsWithoutTotalSize = TableStatistics.builder() .setRowCount(Estimate.of(37.5)) - .setColumnStatistics(xColumn, new ColumnStatistics(Estimate.zero(), Estimate.of(20), Estimate.unknown(), Optional.of(new 
DoubleRange(-10, 0)))) + .setColumnStatistics(xColumn, new ColumnStatistics(Estimate.zero(), Estimate.of(20), Estimate.unknown(), Optional.of(new DoubleRange(-10, 0)), Optional.empty())) .build(); assertPredicate("x < 0", originalTableStatisticsWithoutTotalSize, filteredStatisticsWithoutTotalSize); } diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestDisjointRangeDomainHistogram.java b/presto-main/src/test/java/com/facebook/presto/cost/TestDisjointRangeDomainHistogram.java new file mode 100644 index 0000000000000..1cbddcc781c58 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestDisjointRangeDomainHistogram.java @@ -0,0 +1,288 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.cost; + +import com.facebook.presto.spi.statistics.ConnectorHistogram; +import com.facebook.presto.spi.statistics.Estimate; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Range; +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.distribution.RealDistribution; +import org.apache.commons.math3.distribution.UniformRealDistribution; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.testng.Assert.assertEquals; + +public class TestDisjointRangeDomainHistogram + extends TestHistogram +{ + /** + * A uniform base with 2 ranges that are fully within the range of the uniform histogram. + */ + @Test + public void testBasicDisjointRanges() + { + ConnectorHistogram source = new UniformDistributionHistogram(0, 100); + ConnectorHistogram constrained = DisjointRangeDomainHistogram + .addDisjunction(source, StatisticRange.fromRange(Range.open(0d, 25d))); + constrained = DisjointRangeDomainHistogram + .addDisjunction(constrained, StatisticRange.fromRange(Range.open(75d, 100d))); + assertEquals(constrained.inverseCumulativeProbability(0.75).getValue(), 87.5); + assertEquals(constrained.inverseCumulativeProbability(0.0).getValue(), 0.0); + assertEquals(constrained.inverseCumulativeProbability(1.0).getValue(), 100); + assertEquals(constrained.inverseCumulativeProbability(0.5).getValue(), 25); + } + + /** + * A uniform base with a range that (1) doesn't have any overlap with the base distribution (2) + * has partial overlap (both ends of the base) and (3) complete overlap. 
+ */ + @Test + public void testSingleDisjointRange() + { + ConnectorHistogram source = new UniformDistributionHistogram(0, 10); + + // no overlap, left bound + ConnectorHistogram constrained = DisjointRangeDomainHistogram + .addDisjunction(source, StatisticRange.fromRange(Range.open(-10d, -5d))); + for (int i = -11; i < 12; i++) { + assertEquals(constrained.cumulativeProbability(i, true).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(i, false).getValue(), 0.0, 1E-8); + } + assertEquals(constrained.inverseCumulativeProbability(0.0), Estimate.unknown()); + assertEquals(constrained.inverseCumulativeProbability(1.0), Estimate.unknown()); + + // partial overlap left bound + constrained = new DisjointRangeDomainHistogram(source, ImmutableSet.of(Range.open(-2d, 2d))); + assertEquals(constrained.cumulativeProbability(-3, false).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(-1, false).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(0, false).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(1, false).getValue(), 0.5, 1E-8); + assertEquals(constrained.cumulativeProbability(1.5, false).getValue(), 0.75, 1E-8); + assertEquals(constrained.cumulativeProbability(2, false).getValue(), 1.0, 1E-8); + assertEquals(constrained.cumulativeProbability(4, false).getValue(), 1.0, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(0.0).getValue(), 0d, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(0.5).getValue(), 1d, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(0.75).getValue(), 1.5d, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(1.0).getValue(), 2d, 1E-8); + + //full overlap + constrained = new DisjointRangeDomainHistogram(source, ImmutableSet.of(Range.open(3d, 4d))); + assertEquals(constrained.cumulativeProbability(-3, false).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(0, false).getValue(), 0.0, 
1E-8); + assertEquals(constrained.cumulativeProbability(1, false).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(3, false).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(3.5, false).getValue(), 0.5, 1E-8); + assertEquals(constrained.cumulativeProbability(4, false).getValue(), 1.0, 1E-8); + assertEquals(constrained.cumulativeProbability(4.5, false).getValue(), 1.0, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(0.0).getValue(), 3d, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(0.5).getValue(), 3.5d, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(0.75).getValue(), 3.75d, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(1.0).getValue(), 4d, 1E-8); + + //right side overlap + constrained = new DisjointRangeDomainHistogram(source, ImmutableSet.of(Range.open(8d, 12d))); + assertEquals(constrained.cumulativeProbability(-3, false).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(0, false).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(5, false).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(8, false).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(9, false).getValue(), 0.5, 1E-8); + assertEquals(constrained.cumulativeProbability(9.5, false).getValue(), 0.75, 1E-8); + assertEquals(constrained.cumulativeProbability(10, false).getValue(), 1.0, 1E-8); + assertEquals(constrained.cumulativeProbability(11, false).getValue(), 1.0, 1E-8); + assertEquals(constrained.cumulativeProbability(12, false).getValue(), 1.0, 1E-8); + assertEquals(constrained.cumulativeProbability(13, false).getValue(), 1.0, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(0.0).getValue(), 8d, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(0.5).getValue(), 9d, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(0.75).getValue(), 9.5d, 1E-8); + 
assertEquals(constrained.inverseCumulativeProbability(1.0).getValue(), 10d, 1E-8); + + // no overlap, right bound + constrained = DisjointRangeDomainHistogram + .addDisjunction(source, StatisticRange.fromRange(Range.open(15d, 20d))); + for (int i = 15; i < 20; i++) { + assertEquals(constrained.cumulativeProbability(i, true).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(i, false).getValue(), 0.0, 1E-8); + } + assertEquals(constrained.inverseCumulativeProbability(0.0), Estimate.unknown()); + assertEquals(constrained.inverseCumulativeProbability(1.0), Estimate.unknown()); + } + + /** + * Tests that calculations across N > 1 disjunctions applied to the source histogram are + * calculated properly. + */ + @Test + public void testMultipleDisjunction() + { + StandardNormalHistogram source = new StandardNormalHistogram(); + RealDistribution dist = source.getDistribution(); + ConnectorHistogram constrained = disjunction(source, Range.closed(-2d, -1d)); + constrained = disjunction(constrained, Range.closed(1d, 2d)); + double rangeLeftProb = dist.cumulativeProbability(-1) - dist.cumulativeProbability(-2); + double rangeRightProb = dist.cumulativeProbability(2) - dist.cumulativeProbability(1); + double sumRangeProb = rangeLeftProb + rangeRightProb; + assertEquals(constrained.cumulativeProbability(-2, true).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(-1.5, true).getValue(), (dist.cumulativeProbability(-1.5d) - dist.cumulativeProbability(-2)) / sumRangeProb, 1E-8); + assertEquals(constrained.cumulativeProbability(-1, true).getValue(), 0.5, 1E-8); + assertEquals(constrained.cumulativeProbability(1, true).getValue(), 0.5, 1E-8); + assertEquals(constrained.cumulativeProbability(1.5, true).getValue(), (rangeLeftProb / sumRangeProb) + ((dist.cumulativeProbability(1.5) - dist.cumulativeProbability(1.0)) / sumRangeProb)); + assertEquals(constrained.cumulativeProbability(2, true).getValue(), 1.0, 1E-8); + 
assertEquals(constrained.cumulativeProbability(3, true).getValue(), 1.0, 1E-8); + } + + /** + * Ensures assumptions made in tests for uniform distributions apply correctly for + * a non-uniform distribution. + */ + @Test + public void testNormalDistribution() + { + // standard normal + StandardNormalHistogram source = new StandardNormalHistogram(); + RealDistribution dist = source.getDistribution(); + ConnectorHistogram constrained = new DisjointRangeDomainHistogram(source, ImmutableSet.of(Range.open(-1d, 1d))); + assertEquals(constrained.cumulativeProbability(-1.0, true).getValue(), 0.0, 1E-8); + assertEquals(constrained.cumulativeProbability(0.0, true).getValue(), 0.5, 1E-8); + assertEquals(constrained.cumulativeProbability(1.0, true).getValue(), 1.0, 1E-8); + double probability = (dist.cumulativeProbability(-0.5) - dist.cumulativeProbability(-1.0)) / (dist.cumulativeProbability(1.0) - dist.cumulativeProbability(-1)); + assertEquals(constrained.cumulativeProbability(-0.5, true).getValue(), probability, 1E-8); + assertEquals(constrained.cumulativeProbability(0.5, true).getValue(), probability + (1.0 - (2 * probability)), 1E-8); + + assertEquals(constrained.inverseCumulativeProbability(0.0).getValue(), -1.0d, 1E-8); + probability = dist.inverseCumulativeProbability(dist.cumulativeProbability(-1) + 0.25 * (dist.cumulativeProbability(1) - dist.cumulativeProbability(-1))); + assertEquals(constrained.inverseCumulativeProbability(0.25).getValue(), -0.44177054668d, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(0.5).getValue(), 0.0d, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(0.75).getValue(), -1 * probability, 1E-8); + assertEquals(constrained.inverseCumulativeProbability(1.0).getValue(), 1.0d, 1E-8); + } + + /** + * Ensures disjunctions of ranges works properly + */ + @Test + public void testAddDisjunction() + { + ConnectorHistogram source = new UniformDistributionHistogram(0, 100); + DisjointRangeDomainHistogram constrained = 
disjunction(source, Range.open(-1d, 2d)); + assertEquals(constrained.getRanges().size(), 1); + assertEquals(ranges(constrained).get(0), Range.closedOpen(0d, 2d)); + constrained = disjunction(constrained, Range.open(1d, 10d)); + assertEquals(ranges(constrained).size(), 1); + assertEquals(ranges(constrained).get(0), Range.closedOpen(0d, 10d)); + constrained = disjunction(constrained, Range.closedOpen(50d, 100d)); + assertEquals(ranges(constrained).size(), 2); + assertEquals(ranges(constrained).get(0), Range.closedOpen(0d, 10d)); + assertEquals(ranges(constrained).get(1), Range.closedOpen(50d, 100d)); + } + + /** + * Ensures conjunctions of ranges work properly + */ + @Test + public void testAddConjunction() + { + ConnectorHistogram source = new UniformDistributionHistogram(0, 100); + DisjointRangeDomainHistogram constrained = disjunction(source, Range.open(10d, 90d)); + assertEquals(constrained.getRanges().size(), 1); + assertEquals(ranges(constrained).get(0), Range.open(10d, 90d)); + constrained = conjunction(constrained, Range.atMost(50d)); + assertEquals(ranges(constrained).size(), 1); + assertEquals(ranges(constrained).get(0), Range.openClosed(10d, 50d)); + constrained = conjunction(constrained, Range.atLeast(25d)); + assertEquals(ranges(constrained).size(), 1); + assertEquals(ranges(constrained).get(0), Range.closed(25d, 50d)); + } + + private static DisjointRangeDomainHistogram disjunction(ConnectorHistogram source, Range<Double> range) + { + return (DisjointRangeDomainHistogram) DisjointRangeDomainHistogram.addDisjunction(source, StatisticRange.fromRange(range)); + } + + private static DisjointRangeDomainHistogram conjunction(ConnectorHistogram source, Range<Double> range) + { + return (DisjointRangeDomainHistogram) DisjointRangeDomainHistogram.addConjunction(source, StatisticRange.fromRange(range)); + } + + private static List<Range<Double>> ranges(DisjointRangeDomainHistogram hist) + { + return hist.getRanges().stream().map(StatisticRange::toRange).collect(Collectors.toList()); + } + +
private static class StandardNormalHistogram + implements ConnectorHistogram + { + private final NormalDistribution distribution = new NormalDistribution(); + + public NormalDistribution getDistribution() + { + return distribution; + } + + @Override + public Estimate cumulativeProbability(double value, boolean inclusive) + { + return Estimate.of(distribution.cumulativeProbability(value)); + } + + @Override + public Estimate inverseCumulativeProbability(double percentile) + { + // assume the lower/upper limits are -10/10, in order to not throw an + // exception, even though the true bounds are technically + // infinite + if (percentile <= 0.0) { + return Estimate.of(-10); + } + if (percentile >= 1.0) { + return Estimate.of(10); + } + return Estimate.of(distribution.inverseCumulativeProbability(percentile)); + } + } + + @Override + ConnectorHistogram createHistogram() + { + RealDistribution distribution = getDistribution(); + return new DisjointRangeDomainHistogram( + new UniformDistributionHistogram( + distribution.getSupportLowerBound(), distribution.getSupportUpperBound())) + .addDisjunction(new StatisticRange(0.0, 100.0, 0.0)); + } + + @Override + double getDistinctValues() + { + return 100; + } + + @Override + RealDistribution getDistribution() + { + return new UniformRealDistribution(0.0, 100.0); + } + + /** + * Support depends on the underlying distribution. + */ + @Override + public void testInclusiveExclusive() + { + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculatorHistograms.java b/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculatorHistograms.java new file mode 100644 index 0000000000000..1f71fe5ebb78c --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculatorHistograms.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cost; + +public class TestFilterStatsCalculatorHistograms + extends AbstractTestFilterStatsCalculator +{ + public TestFilterStatsCalculatorHistograms() + { + super(true); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculatorNoHistograms.java b/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculatorNoHistograms.java new file mode 100644 index 0000000000000..f101def2d8ce2 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestFilterStatsCalculatorNoHistograms.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +public class TestFilterStatsCalculatorNoHistograms + extends AbstractTestFilterStatsCalculator +{ + public TestFilterStatsCalculatorNoHistograms() + { + super(false); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestHistogram.java b/presto-main/src/test/java/com/facebook/presto/cost/TestHistogram.java new file mode 100644 index 0000000000000..26c68b7e5730e --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestHistogram.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.cost; + +import com.facebook.presto.spi.statistics.ConnectorHistogram; +import org.apache.commons.math3.distribution.RealDistribution; +import org.testng.annotations.Test; + +import static java.lang.Double.NEGATIVE_INFINITY; +import static java.lang.Double.POSITIVE_INFINITY; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +public abstract class TestHistogram +{ + abstract ConnectorHistogram createHistogram(); + + abstract RealDistribution getDistribution(); + + abstract double getDistinctValues(); + + @Test + public void testInverseCumulativeProbability() + { + ConnectorHistogram hist = createHistogram(); + RealDistribution dist = getDistribution(); + assertThrows(IllegalArgumentException.class, () -> hist.inverseCumulativeProbability(Double.NaN)); + assertThrows(IllegalArgumentException.class, () -> hist.inverseCumulativeProbability(-1.0)); + assertThrows(IllegalArgumentException.class, () -> hist.inverseCumulativeProbability(2.0)); + assertEquals(hist.inverseCumulativeProbability(0.0).getValue(), dist.getSupportLowerBound(), .001); + assertEquals(hist.inverseCumulativeProbability(0.25).getValue(), dist.inverseCumulativeProbability(0.25), .001); + assertEquals(hist.inverseCumulativeProbability(0.5).getValue(), dist.getNumericalMean(), .001); + assertEquals(hist.inverseCumulativeProbability(1.0).getValue(), dist.getSupportUpperBound(), .001); + } + + @Test + public void testCumulativeProbability() + { + ConnectorHistogram hist = createHistogram(); + RealDistribution dist = getDistribution(); + + assertTrue(hist.cumulativeProbability(Double.NaN, true).isUnknown()); + assertEquals(hist.cumulativeProbability(NEGATIVE_INFINITY, true).getValue(), 0.0, .001); + assertEquals(hist.cumulativeProbability(NEGATIVE_INFINITY, false).getValue(), 0.0, .001); + assertEquals(hist.cumulativeProbability(POSITIVE_INFINITY, true).getValue(), 1.0, .001); + 
assertEquals(hist.cumulativeProbability(POSITIVE_INFINITY, false).getValue(), 1.0, .001); + + assertEquals(hist.cumulativeProbability(dist.getSupportLowerBound() - 1, true).getValue(), 0.0, .001); + assertEquals(hist.cumulativeProbability(dist.getSupportLowerBound(), true).getValue(), 0.0, .001); + assertEquals(hist.cumulativeProbability(dist.getSupportUpperBound() + 1, true).getValue(), 1.0, .001); + assertEquals(hist.cumulativeProbability(dist.getSupportUpperBound(), true).getValue(), 1.0, .001); + assertEquals(hist.cumulativeProbability(dist.getNumericalMean(), true).getValue(), 0.5, .001); + for (int i = 0; i < 10; i++) { + assertEquals(hist.cumulativeProbability(dist.inverseCumulativeProbability(0.1 * i), true).getValue(), dist.cumulativeProbability(dist.inverseCumulativeProbability(0.1 * i)), .001); + } + } + + @Test + public void testInclusiveExclusive() + { + double ndvs = getDistinctValues(); + ConnectorHistogram hist = createHistogram(); + // test maximums + assertEquals(hist.cumulativeProbability(hist.inverseCumulativeProbability(1.0).getValue(), false).getValue(), 1.0 - (1.0 / ndvs), .0001); + assertEquals(hist.cumulativeProbability(hist.inverseCumulativeProbability(1.0).getValue(), true).getValue(), 1.0, .0001); + + // test minimums + assertEquals(hist.cumulativeProbability(hist.inverseCumulativeProbability(0.0).getValue(), false).getValue(), 0.0, .0001); + assertEquals(hist.cumulativeProbability(hist.inverseCumulativeProbability(0.0).getValue(), true).getValue(), 0.0, .0001); + + // test non-max/min + double midPercent = hist.inverseCumulativeProbability(0.5).getValue(); + assertEquals(hist.cumulativeProbability(midPercent, true).getValue() - hist.cumulativeProbability(midPercent, false).getValue(), 1.0 / ndvs, .0001); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestHistogramCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestHistogramCalculator.java new file mode 100644 index 
0000000000000..ddccfdfe3c065 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestHistogramCalculator.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.cost; + +import com.facebook.presto.spi.statistics.ConnectorHistogram; +import com.facebook.presto.spi.statistics.Estimate; +import org.testng.annotations.Test; + +import static com.facebook.presto.cost.HistogramCalculator.calculateFilterFactor; +import static java.lang.Double.NEGATIVE_INFINITY; +import static java.lang.Double.NaN; +import static java.lang.Double.POSITIVE_INFINITY; +import static org.testng.Assert.assertEquals; + +public class TestHistogramCalculator +{ + @Test + public void testCalculateFilterFactor() + { + StatisticRange zeroToTen = range(0, 10, 10); + StatisticRange empty = StatisticRange.empty(); + + // Equal ranges + assertFilterFactor(Estimate.of(1.0), zeroToTen, uniformHist(0, 10), 5); + assertFilterFactor(Estimate.of(1.0), zeroToTen, uniformHist(0, 10), 20); + + // Some overlap + assertFilterFactor(Estimate.of(0.5), range(5, 3000, 5), uniformHist(zeroToTen), zeroToTen.getDistinctValuesCount()); + + // Single value overlap + assertFilterFactor(Estimate.of(1.0 / zeroToTen.getDistinctValuesCount()), range(3, 3, 1), uniformHist(zeroToTen), zeroToTen.getDistinctValuesCount()); + assertFilterFactor(Estimate.of(1.0 / zeroToTen.getDistinctValuesCount()), range(10, 100, 357), uniformHist(zeroToTen), 
zeroToTen.getDistinctValuesCount()); + + // No overlap + assertFilterFactor(Estimate.zero(), range(20, 30, 10), uniformHist(zeroToTen), zeroToTen.getDistinctValuesCount()); + + // Empty ranges + assertFilterFactor(Estimate.zero(), zeroToTen, uniformHist(empty), empty.getDistinctValuesCount()); + assertFilterFactor(Estimate.zero(), empty, uniformHist(zeroToTen), zeroToTen.getDistinctValuesCount()); + + // no test for (empty, empty) since any return value is correct + assertFilterFactor(Estimate.zero(), unboundedRange(10), uniformHist(empty), empty.getDistinctValuesCount()); + assertFilterFactor(Estimate.zero(), empty, uniformHist(unboundedRange(10)), 10); + + // Unbounded (infinite), NDV-based + assertFilterFactor(Estimate.of(0.5), unboundedRange(10), uniformHist(unboundedRange(20)), 20); + assertFilterFactor(Estimate.of(1.0), unboundedRange(20), uniformHist(unboundedRange(10)), 10); + + // NEW TESTS (TPC-H Q2) + // unbounded ranges + assertFilterFactor(Estimate.of(.5), unboundedRange(0.5), uniformHist(unboundedRange(NaN)), NaN); + // unbounded ranges with limited distinct values + assertFilterFactor(Estimate.of(0.2), unboundedRange(1.0), + domainConstrained(unboundedRange(5.0), uniformHist(unboundedRange(7.0))), 5.0); + } + + private static StatisticRange range(double low, double high, double distinctValues) + { + return new StatisticRange(low, high, distinctValues); + } + + private static StatisticRange unboundedRange(double distinctValues) + { + return new StatisticRange(NEGATIVE_INFINITY, POSITIVE_INFINITY, distinctValues); + } + + private static void assertFilterFactor(Estimate expected, StatisticRange range, ConnectorHistogram histogram, double totalDistinctValues) + { + assertEquals( + calculateFilterFactor(range, histogram, Estimate.estimateFromDouble(totalDistinctValues), true), + expected); + } + + private static ConnectorHistogram uniformHist(StatisticRange range) + { + return uniformHist(range.getLow(), range.getHigh()); + } + + private static 
ConnectorHistogram uniformHist(double low, double high) + { + return new UniformDistributionHistogram(low, high); + } + + private static ConnectorHistogram domainConstrained(StatisticRange range, ConnectorHistogram source) + { + return DisjointRangeDomainHistogram.addDisjunction(source, range); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestPlanNodeStatsEstimateMath.java b/presto-main/src/test/java/com/facebook/presto/cost/TestPlanNodeStatsEstimateMath.java index ba67a40f3054d..dc2e60bb46e8d 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestPlanNodeStatsEstimateMath.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestPlanNodeStatsEstimateMath.java @@ -14,15 +14,13 @@ package com.facebook.presto.cost; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.statistics.ConnectorHistogram; import org.testng.annotations.Test; import java.util.Optional; +import java.util.function.BiFunction; import static com.facebook.presto.common.type.BigintType.BIGINT; -import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStatsAndMaxDistinctValues; -import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues; -import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.capStats; -import static com.facebook.presto.cost.PlanNodeStatsEstimateMath.subtractSubsetStats; import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static java.lang.Double.NEGATIVE_INFINITY; import static java.lang.Double.NaN; @@ -32,6 +30,7 @@ public class TestPlanNodeStatsEstimateMath { private static final VariableReferenceExpression VARIABLE = new VariableReferenceExpression(Optional.empty(), "variable", BIGINT); private static final StatisticRange NON_EMPTY_RANGE = openRange(1); + private final PlanNodeStatsEstimateMath calculator = new PlanNodeStatsEstimateMath(true); @Test public void testAddRowCount() @@ -40,10 +39,10 @@ 
public void testAddRowCount() PlanNodeStatsEstimate first = statistics(10, NaN, NaN, NaN, StatisticRange.empty()); PlanNodeStatsEstimate second = statistics(20, NaN, NaN, NaN, StatisticRange.empty()); - assertEquals(addStatsAndSumDistinctValues(unknownStats, unknownStats), PlanNodeStatsEstimate.unknown()); - assertEquals(addStatsAndSumDistinctValues(first, unknownStats), PlanNodeStatsEstimate.unknown()); - assertEquals(addStatsAndSumDistinctValues(unknownStats, second), PlanNodeStatsEstimate.unknown()); - assertEquals(addStatsAndSumDistinctValues(first, second).getOutputRowCount(), 30.0); + assertEquals(calculator.addStatsAndSumDistinctValues(unknownStats, unknownStats), PlanNodeStatsEstimate.unknown()); + assertEquals(calculator.addStatsAndSumDistinctValues(first, unknownStats), PlanNodeStatsEstimate.unknown()); + assertEquals(calculator.addStatsAndSumDistinctValues(unknownStats, second), PlanNodeStatsEstimate.unknown()); + assertEquals(calculator.addStatsAndSumDistinctValues(first, second).getOutputRowCount(), 30.0); } @Test @@ -53,10 +52,10 @@ public void testAddTotalSize() PlanNodeStatsEstimate first = statistics(NaN, 10, NaN, NaN, StatisticRange.empty()); PlanNodeStatsEstimate second = statistics(NaN, 20, NaN, NaN, StatisticRange.empty()); - assertEquals(addStatsAndSumDistinctValues(unknownStats, unknownStats), PlanNodeStatsEstimate.unknown()); - assertEquals(addStatsAndSumDistinctValues(first, unknownStats), PlanNodeStatsEstimate.unknown()); - assertEquals(addStatsAndSumDistinctValues(unknownStats, second), PlanNodeStatsEstimate.unknown()); - assertEquals(addStatsAndSumDistinctValues(first, second).getTotalSize(), 30.0); + assertEquals(calculator.addStatsAndSumDistinctValues(unknownStats, unknownStats), PlanNodeStatsEstimate.unknown()); + assertEquals(calculator.addStatsAndSumDistinctValues(first, unknownStats), PlanNodeStatsEstimate.unknown()); + assertEquals(calculator.addStatsAndSumDistinctValues(unknownStats, second), PlanNodeStatsEstimate.unknown()); + 
assertEquals(calculator.addStatsAndSumDistinctValues(first, second).getTotalSize(), 30.0); } @Test @@ -78,9 +77,9 @@ public void testAddNullsFraction() assertAddNullsFraction(fractionalRowCountFirst, fractionalRowCountSecond, 0.2333333333333333); } - private static void assertAddNullsFraction(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) + private void assertAddNullsFraction(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) { - assertEquals(addStatsAndSumDistinctValues(first, second).getVariableStatistics(VARIABLE).getNullsFraction(), expected); + assertEquals(calculator.addStatsAndSumDistinctValues(first, second).getVariableStatistics(VARIABLE).getNullsFraction(), expected); } @Test @@ -104,9 +103,9 @@ public void testAddAverageRowSize() assertAddAverageRowSize(fractionalRowCountFirst, fractionalRowCountSecond, 0.3608695652173913); } - private static void assertAddAverageRowSize(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) + private void assertAddAverageRowSize(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) { - assertEquals(addStatsAndSumDistinctValues(first, second).getVariableStatistics(VARIABLE).getAverageRowSize(), expected); + assertEquals(calculator.addStatsAndSumDistinctValues(first, second).getVariableStatistics(VARIABLE).getAverageRowSize(), expected); } @Test @@ -125,9 +124,9 @@ public void testSumNumberOfDistinctValues() assertSumNumberOfDistinctValues(first, second, 5); } - private static void assertSumNumberOfDistinctValues(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) + private void assertSumNumberOfDistinctValues(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) { - assertEquals(addStatsAndSumDistinctValues(first, second).getVariableStatistics(VARIABLE).getDistinctValuesCount(), expected); + assertEquals(calculator.addStatsAndSumDistinctValues(first, 
second).getVariableStatistics(VARIABLE).getDistinctValuesCount(), expected); } @Test @@ -146,9 +145,9 @@ public void testMaxNumberOfDistinctValues() assertMaxNumberOfDistinctValues(first, second, 3); } - private static void assertMaxNumberOfDistinctValues(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) + private void assertMaxNumberOfDistinctValues(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) { - assertEquals(addStatsAndMaxDistinctValues(first, second).getVariableStatistics(VARIABLE).getDistinctValuesCount(), expected); + assertEquals(calculator.addStatsAndMaxDistinctValues(first, second).getVariableStatistics(VARIABLE).getDistinctValuesCount(), expected); } @Test @@ -167,9 +166,9 @@ public void testAddRange() assertAddRange(first, second, 12, 200); } - private static void assertAddRange(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expectedLow, double expectedHigh) + private void assertAddRange(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expectedLow, double expectedHigh) { - VariableStatsEstimate statistics = addStatsAndMaxDistinctValues(first, second).getVariableStatistics(VARIABLE); + VariableStatsEstimate statistics = calculator.addStatsAndMaxDistinctValues(first, second).getVariableStatistics(VARIABLE); assertEquals(statistics.getLowValue(), expectedLow); assertEquals(statistics.getHighValue(), expectedHigh); } @@ -181,10 +180,10 @@ public void testSubtractRowCount() PlanNodeStatsEstimate first = statistics(40, NaN, NaN, NaN, StatisticRange.empty()); PlanNodeStatsEstimate second = statistics(10, NaN, NaN, NaN, StatisticRange.empty()); - assertEquals(subtractSubsetStats(unknownStats, unknownStats), PlanNodeStatsEstimate.unknown()); - assertEquals(subtractSubsetStats(first, unknownStats), PlanNodeStatsEstimate.unknown()); - assertEquals(subtractSubsetStats(unknownStats, second), PlanNodeStatsEstimate.unknown()); - assertEquals(subtractSubsetStats(first, 
second).getOutputRowCount(), 30.0); + assertEquals(calculator.subtractSubsetStats(unknownStats, unknownStats), PlanNodeStatsEstimate.unknown()); + assertEquals(calculator.subtractSubsetStats(first, unknownStats), PlanNodeStatsEstimate.unknown()); + assertEquals(calculator.subtractSubsetStats(unknownStats, second), PlanNodeStatsEstimate.unknown()); + assertEquals(calculator.subtractSubsetStats(first, second).getOutputRowCount(), 30.0); } @Test @@ -205,9 +204,9 @@ public void testSubtractNullsFraction() assertSubtractNullsFraction(fractionalRowCountFirst, fractionalRowCountSecond, 0.019999999999999993); } - private static void assertSubtractNullsFraction(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) + private void assertSubtractNullsFraction(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) { - assertEquals(subtractSubsetStats(first, second).getVariableStatistics(VARIABLE).getNullsFraction(), expected); + assertEquals(calculator.subtractSubsetStats(first, second).getVariableStatistics(VARIABLE).getNullsFraction(), expected); } @Test @@ -229,9 +228,9 @@ public void testSubtractNumberOfDistinctValues() assertSubtractNumberOfDistinctValues(second, third, 5); } - private static void assertSubtractNumberOfDistinctValues(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) + private void assertSubtractNumberOfDistinctValues(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, double expected) { - assertEquals(subtractSubsetStats(first, second).getVariableStatistics(VARIABLE).getDistinctValuesCount(), expected); + assertEquals(calculator.subtractSubsetStats(first, second).getVariableStatistics(VARIABLE).getDistinctValuesCount(), expected); } @Test @@ -246,11 +245,11 @@ public void testSubtractRange() assertSubtractRange(0, 2, 0.5, 1, 0, 2); } - private static void assertSubtractRange(double supersetLow, double supersetHigh, double subsetLow, double subsetHigh, double expectedLow, double 
expectedHigh) + private void assertSubtractRange(double supersetLow, double supersetHigh, double subsetLow, double subsetHigh, double expectedLow, double expectedHigh) { PlanNodeStatsEstimate first = statistics(30, NaN, NaN, NaN, new StatisticRange(supersetLow, supersetHigh, 10)); PlanNodeStatsEstimate second = statistics(20, NaN, NaN, NaN, new StatisticRange(subsetLow, subsetHigh, 5)); - VariableStatsEstimate statistics = subtractSubsetStats(first, second).getVariableStatistics(VARIABLE); + VariableStatsEstimate statistics = calculator.subtractSubsetStats(first, second).getVariableStatistics(VARIABLE); assertEquals(statistics.getLowValue(), expectedLow); assertEquals(statistics.getHighValue(), expectedHigh); } @@ -262,11 +261,11 @@ public void testCapRowCount() PlanNodeStatsEstimate first = statistics(20, NaN, NaN, NaN, NON_EMPTY_RANGE); PlanNodeStatsEstimate second = statistics(10, NaN, NaN, NaN, NON_EMPTY_RANGE); - assertEquals(capStats(unknownRowCount, unknownRowCount).getOutputRowCount(), NaN); - assertEquals(capStats(first, unknownRowCount).getOutputRowCount(), NaN); - assertEquals(capStats(unknownRowCount, second).getOutputRowCount(), NaN); - assertEquals(capStats(first, second).getOutputRowCount(), 10.0); - assertEquals(capStats(second, first).getOutputRowCount(), 10.0); + assertEquals(calculator.capStats(unknownRowCount, unknownRowCount).getOutputRowCount(), NaN); + assertEquals(calculator.capStats(first, unknownRowCount).getOutputRowCount(), NaN); + assertEquals(calculator.capStats(unknownRowCount, second).getOutputRowCount(), NaN); + assertEquals(calculator.capStats(first, second).getOutputRowCount(), 10.0); + assertEquals(calculator.capStats(second, first).getOutputRowCount(), 10.0); } @Test @@ -286,9 +285,9 @@ public void testCapAverageRowSize() assertCapAverageRowSize(first, second, 10); } - private static void assertCapAverageRowSize(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap, double expected) + private void 
assertCapAverageRowSize(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap, double expected) { - assertEquals(capStats(stats, cap).getVariableStatistics(VARIABLE).getAverageRowSize(), expected); + assertEquals(calculator.capStats(stats, cap).getVariableStatistics(VARIABLE).getAverageRowSize(), expected); } @Test @@ -306,9 +305,9 @@ public void testCapNumberOfDistinctValues() assertCapNumberOfDistinctValues(first, second, 5); } - private static void assertCapNumberOfDistinctValues(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap, double expected) + private void assertCapNumberOfDistinctValues(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap, double expected) { - assertEquals(capStats(stats, cap).getVariableStatistics(VARIABLE).getDistinctValuesCount(), expected); + assertEquals(calculator.capStats(stats, cap).getVariableStatistics(VARIABLE).getDistinctValuesCount(), expected); } @Test @@ -327,9 +326,9 @@ public void testCapRange() assertCapRange(first, second, 13, 99); } - private static void assertCapRange(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap, double expectedLow, double expectedHigh) + private void assertCapRange(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap, double expectedLow, double expectedHigh) { - VariableStatsEstimate symbolStats = capStats(stats, cap).getVariableStatistics(VARIABLE); + VariableStatsEstimate symbolStats = calculator.capStats(stats, cap).getVariableStatistics(VARIABLE); assertEquals(symbolStats.getLowValue(), expectedLow); assertEquals(symbolStats.getHighValue(), expectedHigh); } @@ -351,9 +350,64 @@ public void testCapNullsFraction() assertCapNullsFraction(first, third, 1); } - private static void assertCapNullsFraction(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap, double expected) + private void assertCapNullsFraction(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap, double expected) { - assertEquals(capStats(stats, cap).getVariableStatistics(VARIABLE).getNullsFraction(), expected); 
+ assertEquals(calculator.capStats(stats, cap).getVariableStatistics(VARIABLE).getNullsFraction(), expected); + } + + @Test + public void testAddHistograms() + { + StatisticRange zeroToTen = new StatisticRange(0, 10, 1); + StatisticRange zeroToFive = new StatisticRange(0, 5, 1); + StatisticRange fiveToTen = new StatisticRange(5, 10, 1); + StatisticRange threeToSeven = new StatisticRange(3, 7, 1); + + PlanNodeStatsEstimate unknownRowCount = statistics(NaN, NaN, NaN, NaN, zeroToTen); + PlanNodeStatsEstimate unknownNullsFraction = statistics(10, NaN, NaN, NaN, zeroToTen); + PlanNodeStatsEstimate first = statistics(50, NaN, 0.25, NaN, zeroToTen); + PlanNodeStatsEstimate second = statistics(25, NaN, 0.6, NaN, zeroToFive); + PlanNodeStatsEstimate third = statistics(25, NaN, 0.6, NaN, fiveToTen); + PlanNodeStatsEstimate fourth = statistics(20, NaN, 0.6, NaN, threeToSeven); + + // no histogram on unknown + assertEquals(calculator.addStatsAndCollapseDistinctValues(unknownRowCount, unknownRowCount).getVariableStatistics(VARIABLE).getHistogram(), Optional.empty()); + + // check when rows are available histograms are added properly. 
+ ConnectorHistogram addedSameRange = DisjointRangeDomainHistogram.addDisjunction(unknownNullsFraction.getVariableStatistics(VARIABLE).getHistogram().get(), zeroToTen); + assertAddStatsHistogram(unknownNullsFraction, unknownNullsFraction, calculator::addStatsAndSumDistinctValues, addedSameRange); + assertAddStatsHistogram(unknownNullsFraction, unknownNullsFraction, calculator::addStatsAndCollapseDistinctValues, addedSameRange); + assertAddStatsHistogram(unknownNullsFraction, unknownNullsFraction, calculator::addStatsAndMaxDistinctValues, addedSameRange); + assertAddStatsHistogram(unknownNullsFraction, unknownNullsFraction, calculator::addStatsAndIntersect, addedSameRange); + + // check when only a sub-range is added, that the histogram still represents the full range + ConnectorHistogram fullRangeFirst = DisjointRangeDomainHistogram.addDisjunction(first.getVariableStatistics(VARIABLE).getHistogram().get(), zeroToTen); + ConnectorHistogram intersectedRangeSecond = DisjointRangeDomainHistogram.addConjunction(first.getVariableStatistics(VARIABLE).getHistogram().get(), zeroToFive); + assertAddStatsHistogram(first, second, calculator::addStatsAndSumDistinctValues, fullRangeFirst); + assertAddStatsHistogram(first, second, calculator::addStatsAndCollapseDistinctValues, fullRangeFirst); + assertAddStatsHistogram(first, second, calculator::addStatsAndMaxDistinctValues, fullRangeFirst); + assertAddStatsHistogram(first, second, calculator::addStatsAndIntersect, intersectedRangeSecond); + + // check when two ranges overlap, the new stats span both ranges + ConnectorHistogram fullRangeSecondThird = DisjointRangeDomainHistogram.addDisjunction(second.getVariableStatistics(VARIABLE).getHistogram().get(), fiveToTen); + ConnectorHistogram intersectedRangeSecondThird = DisjointRangeDomainHistogram.addConjunction(second.getVariableStatistics(VARIABLE).getHistogram().get(), fiveToTen); + assertAddStatsHistogram(second, third, calculator::addStatsAndSumDistinctValues, 
fullRangeSecondThird); + assertAddStatsHistogram(second, third, calculator::addStatsAndCollapseDistinctValues, fullRangeSecondThird); + assertAddStatsHistogram(second, third, calculator::addStatsAndMaxDistinctValues, fullRangeSecondThird); + assertAddStatsHistogram(second, third, calculator::addStatsAndIntersect, intersectedRangeSecondThird); + + // check when two ranges partially overlap, the addition/intersection is applied correctly + ConnectorHistogram fullRangeThirdFourth = DisjointRangeDomainHistogram.addDisjunction(third.getVariableStatistics(VARIABLE).getHistogram().get(), threeToSeven); + ConnectorHistogram intersectedRangeThirdFourth = DisjointRangeDomainHistogram.addConjunction(third.getVariableStatistics(VARIABLE).getHistogram().get(), threeToSeven); + assertAddStatsHistogram(third, fourth, calculator::addStatsAndSumDistinctValues, fullRangeThirdFourth); + assertAddStatsHistogram(third, fourth, calculator::addStatsAndCollapseDistinctValues, fullRangeThirdFourth); + assertAddStatsHistogram(third, fourth, calculator::addStatsAndMaxDistinctValues, fullRangeThirdFourth); + assertAddStatsHistogram(third, fourth, calculator::addStatsAndIntersect, intersectedRangeThirdFourth); + } + + private static void assertAddStatsHistogram(PlanNodeStatsEstimate first, PlanNodeStatsEstimate second, BiFunction<PlanNodeStatsEstimate, PlanNodeStatsEstimate, PlanNodeStatsEstimate> function, ConnectorHistogram expected) + { + assertEquals(function.apply(first, second).getVariableStatistics(VARIABLE).getHistogram().get(), expected); /* compares the histogram the given operation produces for VARIABLE against the expected one */ } private static PlanNodeStatsEstimate statistics(double rowCount, double totalSize, double nullsFraction, double averageRowSize, StatisticRange range) @@ -365,6 +419,7 @@ private static PlanNodeStatsEstimate statistics(double rowCount, double totalSiz .setNullsFraction(nullsFraction) .setAverageRowSize(averageRowSize) .setStatisticsRange(range) + .setHistogram(Optional.of(DisjointRangeDomainHistogram.addConjunction(new UniformDistributionHistogram(range.getLow(), range.getHigh()), range))) .build()) .build(); } diff
--git a/presto-main/src/test/java/com/facebook/presto/cost/TestStatisticRange.java b/presto-main/src/test/java/com/facebook/presto/cost/TestStatisticRange.java index d88780ecf8ccc..5975237dde8d0 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestStatisticRange.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestStatisticRange.java @@ -20,6 +20,8 @@ import static java.lang.Double.NaN; import static java.lang.Double.POSITIVE_INFINITY; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; public class TestStatisticRange { @@ -104,11 +106,66 @@ public void testAddAndCollapseDistinctValues() assertEquals(range(0, 3, 3).addAndCollapseDistinctValues(range(2, 6, 4)), range(0, 6, 6)); } + @Test + public void testIntersectOpenness() + { + StatisticRange first = range(0, true, 10, true, 10); + StatisticRange second = range(0, true, 5, true, 5); + StatisticRange intersect = first.intersect(second); + assertTrue(intersect.getOpenLow()); + assertTrue(intersect.getOpenHigh()); + intersect = second.intersect(first); + assertTrue(intersect.getOpenLow()); + assertTrue(intersect.getOpenHigh()); + + // second range bounds only on high + second = range(-1, true, 5, false, 5); + intersect = first.intersect(second); + assertTrue(intersect.getOpenLow()); + assertFalse(intersect.getOpenHigh()); + intersect = second.intersect(first); + assertTrue(intersect.getOpenLow()); + assertFalse(intersect.getOpenHigh()); + + // second range bounds on low and high + second = range(1, false, 5, false, 5); + intersect = first.intersect(second); + assertFalse(intersect.getOpenLow()); + assertFalse(intersect.getOpenHigh()); + intersect = second.intersect(first); + assertFalse(intersect.getOpenLow()); + assertFalse(intersect.getOpenHigh()); + + // second range bounds only on low + second = range(1, false, 5, true, 5); + intersect = first.intersect(second); + assertFalse(intersect.getOpenLow()); 
+ assertTrue(intersect.getOpenHigh()); + intersect = second.intersect(first); + assertFalse(intersect.getOpenLow()); + assertTrue(intersect.getOpenHigh()); + + // same bounds but one is open and one is closed + first = range(0, false, 5, false, 5); + second = range(0, true, 5, true, 5); + intersect = first.intersect(second); + assertTrue(intersect.getOpenLow()); + assertTrue(intersect.getOpenHigh()); + intersect = second.intersect(first); + assertTrue(intersect.getOpenLow()); + assertTrue(intersect.getOpenHigh()); + } + private static StatisticRange range(double low, double high, double distinctValues) { return new StatisticRange(low, high, distinctValues); } + private static StatisticRange range(double low, boolean openLow, double high, boolean openHigh, double distinctValues) + { + return new StatisticRange(low, openLow, high, openHigh, distinctValues); + } + private static StatisticRange unboundedRange(double distinctValues) { return new StatisticRange(NEGATIVE_INFINITY, POSITIVE_INFINITY, distinctValues); diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestUniformHistogram.java b/presto-main/src/test/java/com/facebook/presto/cost/TestUniformHistogram.java new file mode 100644 index 0000000000000..395bc3f6e7518 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestUniformHistogram.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.cost; + +import com.facebook.presto.spi.statistics.ConnectorHistogram; +import com.facebook.presto.spi.statistics.Estimate; +import com.google.common.base.VerifyException; +import org.apache.commons.math3.distribution.RealDistribution; +import org.apache.commons.math3.distribution.UniformRealDistribution; +import org.testng.annotations.Test; + +import static java.lang.Double.POSITIVE_INFINITY; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +public class TestUniformHistogram + extends TestHistogram +{ + ConnectorHistogram createHistogram() + { + return new UniformDistributionHistogram(0, 1); + } + + RealDistribution getDistribution() + { + return new UniformRealDistribution(); + } + + @Override + double getDistinctValues() + { + return 100; + } + + @Test + public void testInvalidConstruction() + { + assertThrows(VerifyException.class, () -> new UniformDistributionHistogram(2.0, 1.0)); + } + + @Test + public void testNanRangeValues() + { + ConnectorHistogram hist = new UniformDistributionHistogram(Double.NaN, 2); + assertTrue(hist.inverseCumulativeProbability(0.5).isUnknown()); + + hist = new UniformDistributionHistogram(1.0, Double.NaN); + assertTrue(hist.inverseCumulativeProbability(0.5).isUnknown()); + + hist = new UniformDistributionHistogram(1.0, 2.0); + assertEquals(hist.inverseCumulativeProbability(0.5).getValue(), 1.5); + } + + @Test + public void testInfiniteRangeValues() + { + // test low value as infinite + ConnectorHistogram hist = new UniformDistributionHistogram(Double.NEGATIVE_INFINITY, 2); + + assertTrue(hist.inverseCumulativeProbability(0.5).isUnknown()); + assertEquals(hist.inverseCumulativeProbability(0.0), Estimate.unknown()); + assertEquals(hist.inverseCumulativeProbability(1.0).getValue(), 2.0); + + assertEquals(hist.cumulativeProbability(0.0, true), Estimate.unknown()); + assertEquals(hist.cumulativeProbability(1.0, 
true), Estimate.unknown()); + assertEquals(hist.cumulativeProbability(2.0, true).getValue(), 1.0); + assertEquals(hist.cumulativeProbability(2.5, true).getValue(), 1.0); + + // test high value as infinite + hist = new UniformDistributionHistogram(1.0, POSITIVE_INFINITY); + + assertTrue(hist.inverseCumulativeProbability(0.5).isUnknown()); + assertEquals(hist.inverseCumulativeProbability(0.0).getValue(), 1.0); + assertEquals(hist.inverseCumulativeProbability(1.0), Estimate.unknown()); + + assertEquals(hist.cumulativeProbability(0.0, true).getValue(), 0.0); + assertEquals(hist.cumulativeProbability(1.0, true).getValue(), 0.0); + assertEquals(hist.cumulativeProbability(1.5, true), Estimate.unknown()); + } + + @Test + public void testSingleValueRange() + { + UniformDistributionHistogram hist = new UniformDistributionHistogram(1.0, 1.0); + + assertEquals(hist.inverseCumulativeProbability(0.0).getValue(), 1.0); + assertEquals(hist.inverseCumulativeProbability(1.0).getValue(), 1.0); + assertEquals(hist.inverseCumulativeProbability(0.5).getValue(), 1.0); + + assertEquals(hist.cumulativeProbability(0.0, true).getValue(), 0.0); + assertEquals(hist.cumulativeProbability(0.5, true).getValue(), 0.0); + assertEquals(hist.cumulativeProbability(1.0, true).getValue(), 1.0); + assertEquals(hist.cumulativeProbability(1.5, true).getValue(), 1.0); + } + + /** + * {@link UniformDistributionHistogram} does not support the inclusive/exclusive arguments + */ + @Override + public void testInclusiveExclusive() + { + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestVariableStatsEstimate.java b/presto-main/src/test/java/com/facebook/presto/cost/TestVariableStatsEstimate.java new file mode 100644 index 0000000000000..b26665eb30af4 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestVariableStatsEstimate.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance 
with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.cost; + +import com.facebook.airlift.json.JsonCodec; +import com.google.common.collect.Range; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestVariableStatsEstimate +{ + @Test + public void testSkipHistogramSerialization() + { + JsonCodec codec = JsonCodec.jsonCodec(VariableStatsEstimate.class); + VariableStatsEstimate estimate = VariableStatsEstimate.builder() + .setAverageRowSize(100) + .setDistinctValuesCount(100) + .setStatisticsRange(StatisticRange.fromRange(Range.open(1.0d, 2.0d))) + .setHistogram(Optional.of(new UniformDistributionHistogram(55, 65))) + .setNullsFraction(0.1) + .build(); + VariableStatsEstimate serialized = codec.fromBytes(codec.toBytes(estimate)); + assertEquals(serialized.getAverageRowSize(), estimate.getAverageRowSize()); + assertEquals(serialized.getDistinctValuesCount(), estimate.getDistinctValuesCount()); + assertEquals(serialized.getLowValue(), estimate.getLowValue()); + assertEquals(serialized.getHighValue(), estimate.getHighValue()); + assertTrue(estimate.getHistogram().isPresent()); + assertFalse(serialized.getHistogram().isPresent()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 6a195ce96f25f..ff3c8c4034b6b 100644 --- 
a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -271,7 +271,8 @@ public void testDefaults() .setDefaultViewSecurityMode(DEFINER) .setCteHeuristicReplicationThreshold(4) .setLegacyJsonCast(true) - .setPrintEstimatedStatsFromCache(false)); + .setPrintEstimatedStatsFromCache(false) + .setUseHistograms(false)); } @Test @@ -487,6 +488,7 @@ public void testExplicitPropertyMappings() .put("default-view-security-mode", INVOKER.name()) .put("cte-heuristic-replication-threshold", "2") .put("optimizer.print-estimated-stats-from-cache", "true") + .put("optimizer.use-histograms", "true") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -699,7 +701,8 @@ public void testExplicitPropertyMappings() .setDefaultViewSecurityMode(INVOKER) .setCteHeuristicReplicationThreshold(2) .setLegacyJsonCast(false) - .setPrintEstimatedStatsFromCache(true); + .setPrintEstimatedStatsFromCache(true) + .setUseHistograms(true); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java index 8303fad9c7824..5e90e80c18258 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java @@ -309,9 +309,14 @@ protected Plan plan(String sql, Optimizer.PlanStage stage) } protected Plan plan(String sql, Optimizer.PlanStage stage, boolean forceSingleNode) + { + return plan(queryRunner.getDefaultSession(), sql, stage, forceSingleNode); + } + + protected Plan plan(Session session, String sql, Optimizer.PlanStage stage, boolean forceSingleNode) { try { - return queryRunner.inTransaction(transactionSession -> queryRunner.createPlan(transactionSession, sql, stage, forceSingleNode, 
WarningCollector.NOOP)); + return queryRunner.inTransaction(session, transactionSession -> queryRunner.createPlan(transactionSession, sql, stage, forceSingleNode, WarningCollector.NOOP)); } catch (RuntimeException e) { throw new AssertionError("Planning failed for SQL: " + sql, e); diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java index 9d550991da7ba..bd4f8afe34ae3 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java @@ -258,10 +258,10 @@ public void testAnalyzeStats() // column_name | data_size | distinct_values_count | nulls_fraction | row_count | low_value | high_value assertQuery("SHOW STATS FOR region", "SELECT * FROM (VALUES" + - "('regionkey', NULL, 5.0, 0.0, NULL, '0', '4')," + - "('name', 54.0, 5.0, 0.0, NULL, NULL, NULL)," + - "('comment', 350.0, 5.0, 0.0, NULL, NULL, NULL)," + - "(NULL, NULL, NULL, NULL, 5.0, NULL, NULL))"); + "('regionkey', NULL, 5.0, 0.0, NULL, '0', '4', NULL)," + + "('name', 54.0, 5.0, 0.0, NULL, NULL, NULL, NULL)," + + "('comment', 350.0, 5.0, 0.0, NULL, NULL, NULL, NULL)," + + "(NULL, NULL, NULL, NULL, 5.0, NULL, NULL, NULL))"); // Create a partitioned table and run analyze on it. 
String tmpTableName = generateRandomTableName(); @@ -275,17 +275,17 @@ public void testAnalyzeStats() assertUpdate(String.format("ANALYZE %s", tmpTableName), 25); assertQuery(String.format("SHOW STATS for %s", tmpTableName), "SELECT * FROM (VALUES" + - "('name', 277.0, 1.0, 0.0, NULL, NULL, NULL)," + - "('regionkey', NULL, 5.0, 0.0, NULL, '0', '4')," + - "('nationkey', NULL, 25.0, 0.0, NULL, '0', '24')," + - "(NULL, NULL, NULL, NULL, 25.0, NULL, NULL))"); + "('name', 277.0, 1.0, 0.0, NULL, NULL, NULL, NULL)," + + "('regionkey', NULL, 5.0, 0.0, NULL, '0', '4', NULL)," + + "('nationkey', NULL, 25.0, 0.0, NULL, '0', '24', NULL)," + + "(NULL, NULL, NULL, NULL, 25.0, NULL, NULL, NULL))"); assertUpdate(String.format("ANALYZE %s WITH (partitions = ARRAY[ARRAY['0','0'],ARRAY['4', '11']])", tmpTableName), 2); assertQuery(String.format("SHOW STATS for (SELECT * FROM %s where regionkey=4 and nationkey=11)", tmpTableName), "SELECT * FROM (VALUES" + - "('name', 8.0, 1.0, 0.0, NULL, NULL, NULL)," + - "('regionkey', NULL, 1.0, 0.0, NULL, '4', '4')," + - "('nationkey', NULL, 1.0, 0.0, NULL, '11', '11')," + - "(NULL, NULL, NULL, NULL, 1.0, NULL, NULL))"); + "('name', 8.0, 1.0, 0.0, NULL, NULL, NULL, NULL)," + + "('regionkey', NULL, 1.0, 0.0, NULL, '4', '4', NULL)," + + "('nationkey', NULL, 1.0, 0.0, NULL, '11', '11', NULL)," + + "(NULL, NULL, NULL, NULL, 1.0, NULL, NULL, NULL))"); } finally { dropTableIfExists(tmpTableName); @@ -305,9 +305,9 @@ public void testAnalyzeStatsOnDecimals() assertUpdate(String.format("ANALYZE %s", tmpTableName), 7); assertQuery(String.format("SHOW STATS for %s", tmpTableName), "SELECT * FROM (VALUES" + - "('c0', NULL,4.0 , 0.2857142857142857, NULL, '-542392.89', '1000000.12')," + - "('c1', NULL,4.0 , 0.2857142857142857, NULL, '-6.72398239210929E12', '2.823982323232357E13')," + - "(NULL, NULL, NULL, NULL, 7.0, NULL, NULL))"); + "('c0', NULL,4.0 , 0.2857142857142857, NULL, '-542392.89', '1000000.12', NULL)," + + "('c1', NULL,4.0 , 0.2857142857142857, NULL, 
'-6.72398239210929E12', '2.823982323232357E13', NULL)," + + "(NULL, NULL, NULL, NULL, 7.0, NULL, NULL, NULL))"); } finally { dropTableIfExists(tmpTableName); diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestWriter.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestWriter.java index 23c642e35a84f..a95a7b2c2ea7f 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestWriter.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestWriter.java @@ -278,36 +278,36 @@ public void testCollectColumnStatisticsOnCreateTable() assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p1')", tmpTableName), "SELECT * FROM (VALUES " + - "('c_boolean', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_bigint', null, 2.0E0, 0.5E0, null, '0', '1'), " + - "('c_double', null, 2.0E0, 0.5E0, null, '1.2', '2.2'), " + - "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_varchar', 16.0E0, 2.0E0, 0.5E0, null, null, null), " + // 8.0 - "('c_array', 184.0E0, null, 0.5, null, null, null), " + // 176 - "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null), " + - "(null, null, null, null, 4.0E0, null, null)) AS x (c_boolean, c_bigint, c_double, c_timestamp, c_varchar, c_array, p_varchar)"); + "('c_boolean', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_bigint', null, 2.0E0, 0.5E0, null, '0', '1', null), " + + "('c_double', null, 2.0E0, 0.5E0, null, '1.2', '2.2', null), " + + "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_varchar', 16.0E0, 2.0E0, 0.5E0, null, null, null, null), " + // 8.0 + "('c_array', 184.0E0, null, 0.5, null, null, null, null), " + // 176 + "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null, null), " + + "(null, null, null, null, 4.0E0, null, null, null)) AS x (c_boolean, c_bigint, c_double, c_timestamp, c_varchar, c_array, p_varchar, h_varchar)"); 
assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p2')", tmpTableName), "SELECT * FROM (VALUES " + - "('c_boolean', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_bigint', null, 2.0E0, 0.5E0, null, '1', '2'), " + - "('c_double', null, 2.0E0, 0.5E0, null, '2.3', '3.3'), " + - "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_varchar', 16.0E0, 2.0E0, 0.5E0, null, null, null), " + // 8 - "('c_array', 104.0E0, null, 0.5, null, null, null), " + // 96 - "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null), " + - "(null, null, null, null, 4.0E0, null, null)) AS x (c_boolean, c_bigint, c_double, c_timestamp, c_varchar, c_array, p_varchar)"); + "('c_boolean', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_bigint', null, 2.0E0, 0.5E0, null, '1', '2', null), " + + "('c_double', null, 2.0E0, 0.5E0, null, '2.3', '3.3', null), " + + "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_varchar', 16.0E0, 2.0E0, 0.5E0, null, null, null, null), " + // 8 + "('c_array', 104.0E0, null, 0.5, null, null, null, null), " + // 96 + "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null, null), " + + "(null, null, null, null, 4.0E0, null, null, null)) AS x (c_boolean, c_bigint, c_double, c_timestamp, c_varchar, c_array, p_varchar, h_varchar)"); // non existing partition assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p3')", tmpTableName), "SELECT * FROM (VALUES " + - "('c_boolean', null, 0E0, 0E0, null, null, null), " + - "('c_bigint', null, 0E0, 0E0, null, null, null), " + - "('c_double', null, 0E0, 0E0, null, null, null), " + - "('c_timestamp', null, 0E0, 0E0, null, null, null), " + - "('c_varchar', 0E0, 0E0, 0E0, null, null, null), " + - "('c_array', null, 0E0, 0E0, null, null, null), " + - "('p_varchar', 0E0, 0E0, 0E0, null, null, null), " + - "(null, null, null, null, 0E0, null, null)) AS x (c_boolean, c_bigint, c_double, c_timestamp, c_varchar, c_array, p_varchar)"); + "('c_boolean', null, 
0E0, 0E0, null, null, null, null), " + + "('c_bigint', null, 0E0, 0E0, null, null, null, null), " + + "('c_double', null, 0E0, 0E0, null, null, null, null), " + + "('c_timestamp', null, 0E0, 0E0, null, null, null, null), " + + "('c_varchar', 0E0, 0E0, 0E0, null, null, null, null), " + + "('c_array', null, 0E0, 0E0, null, null, null, null), " + + "('p_varchar', 0E0, 0E0, 0E0, null, null, null, null), " + + "(null, null, null, null, 0E0, null, null, null)) AS x (c_boolean, c_bigint, c_double, c_timestamp, c_varchar, c_array, p_varchar, h_varchar)"); dropTableIfExists(tmpTableName); } @@ -348,36 +348,36 @@ public void testCollectColumnStatisticsOnInsert() assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p1')", tmpTableName), "SELECT * FROM (VALUES " + - "('c_boolean', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_bigint', null, 2.0E0, 0.5E0, null, '0', '1'), " + - "('c_double', null, 2.0E0, 0.5E0, null, '1.2', '2.2'), " + - "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_varchar', 16.0E0, 2.0E0, 0.5E0, null, null, null), " + // 8 - "('c_array', 184.0E0, null, 0.5E0, null, null, null), " + // 176 - "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null), " + - "(null, null, null, null, 4.0E0, null, null)) AS x (c_boolean, c_bigint, c_double, c_timestamp, c_varchar, c_array, p_varchar)"); + "('c_boolean', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_bigint', null, 2.0E0, 0.5E0, null, '0', '1', null), " + + "('c_double', null, 2.0E0, 0.5E0, null, '1.2', '2.2', null), " + + "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_varchar', 16.0E0, 2.0E0, 0.5E0, null, null, null, null), " + // 8 + "('c_array', 184.0E0, null, 0.5E0, null, null, null, null), " + // 176 + "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null, null), " + + "(null, null, null, null, 4.0E0, null, null, null)) AS x (c_boolean, c_bigint, c_double, c_timestamp, c_varchar, c_array, p_varchar, p_varchar)"); assertQuery(format("SHOW 
STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p2')", tmpTableName), "SELECT * FROM (VALUES " + - "('c_boolean', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_bigint', null, 2.0E0, 0.5E0, null, '1', '2'), " + - "('c_double', null, 2.0E0, 0.5E0, null, '2.3', '3.3'), " + - "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null), " + - "('c_varchar', 16.0E0, 2.0E0, 0.5E0, null, null, null), " + // 8 - "('c_array', 104.0E0, null, 0.5, null, null, null), " + // 96 - "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null), " + - "(null, null, null, null, 4.0E0, null, null)) AS x (c_boolean, c_bigint, c_double, c_timestamp, c_varchar, c_array, p_varchar)"); + "('c_boolean', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_bigint', null, 2.0E0, 0.5E0, null, '1', '2', null), " + + "('c_double', null, 2.0E0, 0.5E0, null, '2.3', '3.3', null), " + + "('c_timestamp', null, 2.0E0, 0.5E0, null, null, null, null), " + + "('c_varchar', 16.0E0, 2.0E0, 0.5E0, null, null, null, null), " + // 8 + "('c_array', 104.0E0, null, 0.5, null, null, null, null), " + // 96 + "('p_varchar', 8.0E0, 1.0E0, 0.0E0, null, null, null, null), " + + "(null, null, null, null, 4.0E0, null, null, null)) AS x (c_boolean, c_bigint, c_double, c_timestamp, c_varchar, c_array, p_varchar, p_varchar)"); // non existing partition assertQuery(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_varchar = 'p3')", tmpTableName), "SELECT * FROM (VALUES " + - "('c_boolean', null, 0E0, 0E0, null, null, null), " + - "('c_bigint', null, 0E0, 0E0, null, null, null), " + - "('c_double', null, 0E0, 0E0, null, null, null), " + - "('c_timestamp', null, 0E0, 0E0, null, null, null), " + - "('c_varchar', 0E0, 0E0, 0E0, null, null, null), " + - "('c_array', null, 0E0, 0E0, null, null, null), " + - "('p_varchar', 0E0, 0E0, 0E0, null, null, null), " + - "(null, null, null, null, 0E0, null, null)) AS x (c_boolean, c_bigint, c_double, c_timestamp, c_varchar, c_array, p_varchar)"); + "('c_boolean', null, 0E0, 0E0, null, null, null, 
null), " + + "('c_bigint', null, 0E0, 0E0, null, null, null, null), " + + "('c_double', null, 0E0, 0E0, null, null, null, null), " + + "('c_timestamp', null, 0E0, 0E0, null, null, null, null), " + + "('c_varchar', 0E0, 0E0, 0E0, null, null, null, null), " + + "('c_array', null, 0E0, 0E0, null, null, null, null), " + + "('p_varchar', 0E0, 0E0, 0E0, null, null, null, null), " + + "(null, null, null, null, 0E0, null, null, null)) AS x (c_boolean, c_bigint, c_double, c_timestamp, c_varchar, c_array, p_varchar, p_varchar)"); dropTableIfExists(tmpTableName); } diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestExternalHiveTable.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestExternalHiveTable.java index aac3355d38f0d..5a880eb631522 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestExternalHiveTable.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestExternalHiveTable.java @@ -55,19 +55,19 @@ public void testShowStatisticsForExternalTable() onHive().executeQuery("ANALYZE TABLE " + EXTERNAL_TABLE_NAME + " PARTITION (p_regionkey) COMPUTE STATISTICS"); assertThat(query("SHOW STATS FOR " + EXTERNAL_TABLE_NAME)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null, "1", "1"), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "1", "1", null), + row(null, null, null, null, 5.0, null, null, null)); onHive().executeQuery("ANALYZE TABLE " + EXTERNAL_TABLE_NAME + " PARTITION (p_regionkey) COMPUTE STATISTICS FOR COLUMNS"); assertThat(query("SHOW STATS FOR " + 
EXTERNAL_TABLE_NAME)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 38.0, 5.0, 0.0, null, null, null), - row("p_comment", 499.0, 5.0, 0.0, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null, "1", "1"), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 38.0, 5.0, 0.0, null, null, null, null), + row("p_comment", 499.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "1", "1", null), + row(null, null, null, null, 5.0, null, null, null)); } @Test diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveTableStatistics.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveTableStatistics.java index 7562179670e7d..59071737e25cc 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveTableStatistics.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveTableStatistics.java @@ -102,58 +102,58 @@ public Requirement getRequirements(Configuration configuration) .build(); private static final List ALL_TYPES_TABLE_STATISTICS = ImmutableList.of( - row("c_tinyint", null, 2.0, 0.0, null, "121", "127"), - row("c_smallint", null, 2.0, 0.0, null, "32761", "32767"), - row("c_int", null, 2.0, 0.0, null, "2147483641", "2147483647"), - row("c_bigint", null, 2.0, 0.0, null, "9223372036854775807", "9223372036854775807"), - row("c_float", null, 2.0, 0.0, null, "123.341", "123.345"), - row("c_double", null, 2.0, 0.0, null, "234.561", "235.567"), - row("c_decimal", null, 2.0, 0.0, null, "345.0", "346.0"), - row("c_decimal_w_params", null, 2.0, 0.0, null, "345.671", "345.678"), - row("c_timestamp", null, 2.0, 0.0, null, null, null), - row("c_date", null, 2.0, 0.0, null, "2015-05-09", "2015-06-10"), - row("c_string", 22.0, 2.0, 0.0, null, null, null), - row("c_varchar", 20.0, 2.0, 0.0, null, null, null), - row("c_char", 12.0, 
2.0, 0.0, null, null, null), - row("c_boolean", null, 2.0, 0.0, null, null, null), - row("c_binary", 23.0, null, 0.0, null, null, null), - row(null, null, null, null, 2.0, null, null)); + row("c_tinyint", null, 2.0, 0.0, null, "121", "127", null), + row("c_smallint", null, 2.0, 0.0, null, "32761", "32767", null), + row("c_int", null, 2.0, 0.0, null, "2147483641", "2147483647", null), + row("c_bigint", null, 2.0, 0.0, null, "9223372036854775807", "9223372036854775807", null), + row("c_float", null, 2.0, 0.0, null, "123.341", "123.345", null), + row("c_double", null, 2.0, 0.0, null, "234.561", "235.567", null), + row("c_decimal", null, 2.0, 0.0, null, "345.0", "346.0", null), + row("c_decimal_w_params", null, 2.0, 0.0, null, "345.671", "345.678", null), + row("c_timestamp", null, 2.0, 0.0, null, null, null, null), + row("c_date", null, 2.0, 0.0, null, "2015-05-09", "2015-06-10", null), + row("c_string", 22.0, 2.0, 0.0, null, null, null, null), + row("c_varchar", 20.0, 2.0, 0.0, null, null, null, null), + row("c_char", 12.0, 2.0, 0.0, null, null, null, null), + row("c_boolean", null, 2.0, 0.0, null, null, null, null), + row("c_binary", 23.0, null, 0.0, null, null, null, null), + row(null, null, null, null, 2.0, null, null, null)); private static final List ALL_TYPES_ALL_NULL_TABLE_STATISTICS = ImmutableList.of( - row("c_tinyint", null, 0.0, 1.0, null, null, null), - row("c_smallint", null, 0.0, 1.0, null, null, null), - row("c_int", null, 0.0, 1.0, null, null, null), - row("c_bigint", null, 0.0, 1.0, null, null, null), - row("c_float", null, 0.0, 1.0, null, null, null), - row("c_double", null, 0.0, 1.0, null, null, null), - row("c_decimal", null, 0.0, 1.0, null, null, null), - row("c_decimal_w_params", null, 0.0, 1.0, null, null, null), - row("c_timestamp", null, 0.0, 1.0, null, null, null), - row("c_date", null, 0.0, 1.0, null, null, null), - row("c_string", 0.0, 0.0, 1.0, null, null, null), - row("c_varchar", 0.0, 0.0, 1.0, null, null, null), - row("c_char", 0.0, 
0.0, 1.0, null, null, null), - row("c_boolean", null, 0.0, 1.0, null, null, null), - row("c_binary", 0.0, null, 1.0, null, null, null), - row(null, null, null, null, 1.0, null, null)); + row("c_tinyint", null, 0.0, 1.0, null, null, null, null), + row("c_smallint", null, 0.0, 1.0, null, null, null, null), + row("c_int", null, 0.0, 1.0, null, null, null, null), + row("c_bigint", null, 0.0, 1.0, null, null, null, null), + row("c_float", null, 0.0, 1.0, null, null, null, null), + row("c_double", null, 0.0, 1.0, null, null, null, null), + row("c_decimal", null, 0.0, 1.0, null, null, null, null), + row("c_decimal_w_params", null, 0.0, 1.0, null, null, null, null), + row("c_timestamp", null, 0.0, 1.0, null, null, null, null), + row("c_date", null, 0.0, 1.0, null, null, null, null), + row("c_string", 0.0, 0.0, 1.0, null, null, null, null), + row("c_varchar", 0.0, 0.0, 1.0, null, null, null, null), + row("c_char", 0.0, 0.0, 1.0, null, null, null, null), + row("c_boolean", null, 0.0, 1.0, null, null, null, null), + row("c_binary", 0.0, null, 1.0, null, null, null, null), + row(null, null, null, null, 1.0, null, null, null)); private static final List ALL_TYPES_EMPTY_TABLE_STATISTICS = ImmutableList.of( - row("c_tinyint", null, 0.0, 0.0, null, null, null), - row("c_smallint", null, 0.0, 0.0, null, null, null), - row("c_int", null, 0.0, 0.0, null, null, null), - row("c_bigint", null, 0.0, 0.0, null, null, null), - row("c_float", null, 0.0, 0.0, null, null, null), - row("c_double", null, 0.0, 0.0, null, null, null), - row("c_decimal", null, 0.0, 0.0, null, null, null), - row("c_decimal_w_params", null, 0.0, 0.0, null, null, null), - row("c_timestamp", null, 0.0, 0.0, null, null, null), - row("c_date", null, 0.0, 0.0, null, null, null), - row("c_string", 0.0, 0.0, 0.0, null, null, null), - row("c_varchar", 0.0, 0.0, 0.0, null, null, null), - row("c_char", 0.0, 0.0, 0.0, null, null, null), - row("c_boolean", null, 0.0, 0.0, null, null, null), - row("c_binary", 0.0, null, 0.0, 
null, null, null), - row(null, null, null, null, 0.0, null, null)); + row("c_tinyint", null, 0.0, 0.0, null, null, null, null), + row("c_smallint", null, 0.0, 0.0, null, null, null, null), + row("c_int", null, 0.0, 0.0, null, null, null, null), + row("c_bigint", null, 0.0, 0.0, null, null, null, null), + row("c_float", null, 0.0, 0.0, null, null, null, null), + row("c_double", null, 0.0, 0.0, null, null, null, null), + row("c_decimal", null, 0.0, 0.0, null, null, null, null), + row("c_decimal_w_params", null, 0.0, 0.0, null, null, null, null), + row("c_timestamp", null, 0.0, 0.0, null, null, null, null), + row("c_date", null, 0.0, 0.0, null, null, null, null), + row("c_string", 0.0, 0.0, 0.0, null, null, null, null), + row("c_varchar", 0.0, 0.0, 0.0, null, null, null, null), + row("c_char", 0.0, 0.0, 0.0, null, null, null, null), + row("c_boolean", null, 0.0, 0.0, null, null, null, null), + row("c_binary", 0.0, null, 0.0, null, null, null, null), + row(null, null, null, null, 0.0, null, null, null)); private static final class AllTypesTable implements RequirementsProvider @@ -179,33 +179,33 @@ public void testStatisticsForUnpartitionedTable() // table not analyzed assertThat(query(showStatsWholeTable)).containsOnly( - row("n_nationkey", null, null, anyOf(null, 0.0), null, null, null), - row("n_name", null, null, anyOf(null, 0.0), null, null, null), - row("n_regionkey", null, null, anyOf(null, 0.0), null, null, null), - row("n_comment", null, null, anyOf(null, 0.0), null, null, null), - row(null, null, null, null, anyOf(null, 0.0), null, null)); // anyOf because of different behaviour on HDP (hive 1.2) and CDH (hive 1.1) + row("n_nationkey", null, null, anyOf(null, 0.0), null, null, null, null), + row("n_name", null, null, anyOf(null, 0.0), null, null, null, null), + row("n_regionkey", null, null, anyOf(null, 0.0), null, null, null, null), + row("n_comment", null, null, anyOf(null, 0.0), null, null, null, null), + row(null, null, null, null, anyOf(null, 0.0), null, 
null, null)); // anyOf because of different behaviour on HDP (hive 1.2) and CDH (hive 1.1) // basic analysis onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("n_nationkey", null, null, null, null, null, null), - row("n_name", null, null, null, null, null, null), - row("n_regionkey", null, null, null, null, null, null), - row("n_comment", null, null, null, null, null, null), - row(null, null, null, null, 25.0, null, null)); + row("n_nationkey", null, null, null, null, null, null, null), + row("n_name", null, null, null, null, null, null, null), + row("n_regionkey", null, null, null, null, null, null, null), + row("n_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 25.0, null, null, null)); // column analysis onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS FOR COLUMNS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("n_nationkey", null, 19.0, 0.0, null, "0", "24"), - row("n_name", 177.0, 24.0, 0.0, null, null, null), - row("n_regionkey", null, 5.0, 0.0, null, "0", "4"), - row("n_comment", 1857.0, 25.0, 0.0, null, null, null), - row(null, null, null, null, 25.0, null, null)); + row("n_nationkey", null, 19.0, 0.0, null, "0", "24", null), + row("n_name", 177.0, 24.0, 0.0, null, null, null, null), + row("n_regionkey", null, 5.0, 0.0, null, "0", "4", null), + row("n_comment", 1857.0, 25.0, 0.0, null, null, null, null), + row(null, null, null, null, 25.0, null, null, null)); } @Test @@ -221,118 +221,118 @@ public void testStatisticsForTablePartitionedByBigint() // table not analyzed assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, null, null, 
null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, null, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, null, null, null, null)); // basic analysis for single partition onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey = \"1\") COMPUTE STATISTICS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, 3.0, 0.0, null, "1", "3"), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, 15.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, 3.0, 0.0, null, "1", "3", null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 15.0, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null, "1", "1"), - row("p_comment", null, null, null, null, null, 
null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "1", "1", null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, null, null, null, null)); // basic analysis for all partitions onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey) COMPUTE STATISTICS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, 3.0, 0.0, null, "1", "3"), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, 15.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, 3.0, 0.0, null, "1", "3", null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 15.0, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null, "1", "1"), - 
row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "1", "1", null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null, "2", "2"), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "2", "2", null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); // column analysis for single partition onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey = \"1\") COMPUTE STATISTICS FOR COLUMNS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 114.0, 5.0, 0.0, null, null, null), - row("p_regionkey", null, 3.0, 0.0, null, "1", "3"), - row("p_comment", 1497.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 15.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 114.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", null, 3.0, 0.0, null, "1", "3", null), + row("p_comment", 1497.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 15.0, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 38.0, 5.0, 0.0, null, null, null), - row("p_regionkey", 
null, 1.0, 0.0, null, "1", "1"), - row("p_comment", 499.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 38.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "1", "1", null), + row("p_comment", 499.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null, "2", "2"), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "2", "2", null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); // column analysis for all partitions onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey) COMPUTE STATISTICS FOR COLUMNS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 109.0, 5.0, 0.0, null, null, null), - row("p_regionkey", null, 3.0, 0.0, null, "1", "3"), - row("p_comment", 1197.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 15.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 109.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", null, 3.0, 0.0, null, "1", "3", null), + row("p_comment", 1197.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 15.0, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 38.0, 5.0, 0.0, null, null, null), - 
row("p_regionkey", null, 1.0, 0.0, null, "1", "1"), - row("p_comment", 499.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 38.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "1", "1", null), + row("p_comment", 499.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, 4.0, 0.0, null, "8", "21"), - row("p_name", 31.0, 5.0, 0.0, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null, "2", "2"), - row("p_comment", 351.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 4.0, 0.0, null, "8", "21", null), + row("p_name", 31.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "2", "2", null), + row("p_comment", 351.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); } @Test @@ -348,118 +348,118 @@ public void testStatisticsForTablePartitionedByVarchar() // table not analyzed assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, null, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, 
null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, null, null, null, null)); // basic analysis for single partition onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey = \"AMERICA\") COMPUTE STATISTICS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", 85.0, 3.0, 0.0, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, 15.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", 85.0, 3.0, 0.0, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 15.0, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", 35.0, 1.0, 0.0, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", 35.0, 1.0, 0.0, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, 
null, null, null, null, null), - row("p_regionkey", null, null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, null, null, null, null)); // basic analysis for all partitions onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey) COMPUTE STATISTICS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", 85.0, 3.0, 0.0, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, 15.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", 85.0, 3.0, 0.0, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 15.0, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", 35.0, 1.0, 0.0, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", 35.0, 1.0, 0.0, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, null, null, 
null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", 20.0, 1.0, 0.0, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", 20.0, 1.0, 0.0, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); // column analysis for single partition onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey = \"AMERICA\") COMPUTE STATISTICS FOR COLUMNS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 114.0, 5.0, 0.0, null, null, null), - row("p_regionkey", 85.0, 3.0, 0.0, null, null, null), - row("p_comment", 1497.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 15.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 114.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", 85.0, 3.0, 0.0, null, null, null, null), + row("p_comment", 1497.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 15.0, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 38.0, 5.0, 0.0, null, null, null), - row("p_regionkey", 35.0, 1.0, 0.0, null, null, null), - row("p_comment", 499.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 38.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", 35.0, 1.0, 0.0, null, null, null, null), + row("p_comment", 499.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); 
assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", 20.0, 1.0, 0.0, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", 20.0, 1.0, 0.0, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); // column analysis for all partitions onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " PARTITION (p_regionkey) COMPUTE STATISTICS FOR COLUMNS"); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 109.0, 5.0, 0.0, null, null, null), - row("p_regionkey", 85.0, 3.0, 0.0, null, null, null), - row("p_comment", 1197.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 15.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 109.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", 85.0, 3.0, 0.0, null, null, null, null), + row("p_comment", 1197.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 15.0, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 38.0, 5.0, 0.0, null, null, null), - row("p_regionkey", 35.0, 1.0, 0.0, null, null, null), - row("p_comment", 499.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 38.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", 35.0, 1.0, 0.0, null, null, null, null), + row("p_comment", 499.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, 
null, 5.0, null, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, 4.0, 0.0, null, "8", "21"), - row("p_name", 31.0, 5.0, 0.0, null, null, null), - row("p_regionkey", 20.0, 1.0, 0.0, null, null, null), - row("p_comment", 351.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 4.0, 0.0, null, "8", "21", null), + row("p_name", 31.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", 20.0, 1.0, 0.0, null, null, null, null), + row("p_comment", 351.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); } // This covers also stats calculation for unpartitioned table @@ -472,43 +472,43 @@ public void testStatisticsForAllDataTypes() onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS"); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, null, null, null, null, null), - row("c_smallint", null, null, null, null, null, null), - row("c_int", null, null, null, null, null, null), - row("c_bigint", null, null, null, null, null, null), - row("c_float", null, null, null, null, null, null), - row("c_double", null, null, null, null, null, null), - row("c_decimal", null, null, null, null, null, null), - row("c_decimal_w_params", null, null, null, null, null, null), - row("c_timestamp", null, null, null, null, null, null), - row("c_date", null, null, null, null, null, null), - row("c_string", null, null, null, null, null, null), - row("c_varchar", null, null, null, null, null, null), - row("c_char", null, null, null, null, null, null), - row("c_boolean", null, null, null, null, null, null), - row("c_binary", null, null, null, null, null, null), - row(null, null, null, null, 2.0, null, null)); + row("c_tinyint", null, null, null, null, null, null, null), + row("c_smallint", null, null, null, null, null, null, null), + row("c_int", null, null, null, null, null, null, 
null), + row("c_bigint", null, null, null, null, null, null, null), + row("c_float", null, null, null, null, null, null, null), + row("c_double", null, null, null, null, null, null, null), + row("c_decimal", null, null, null, null, null, null, null), + row("c_decimal_w_params", null, null, null, null, null, null, null), + row("c_timestamp", null, null, null, null, null, null, null), + row("c_date", null, null, null, null, null, null, null), + row("c_string", null, null, null, null, null, null, null), + row("c_varchar", null, null, null, null, null, null, null), + row("c_char", null, null, null, null, null, null, null), + row("c_boolean", null, null, null, null, null, null, null), + row("c_binary", null, null, null, null, null, null, null), + row(null, null, null, null, 2.0, null, null, null)); onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS FOR COLUMNS"); // SHOW STATS FORMAT: column_name, data_size, distinct_values_count, nulls_fraction, row_count assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, 2.0, 0.0, null, "121", "127"), - row("c_smallint", null, 2.0, 0.0, null, "32761", "32767"), - row("c_int", null, 2.0, 0.0, null, "2147483641", "2147483647"), - row("c_bigint", null, 2.0, 0.0, null, "9223372036854775807", "9223372036854775807"), - row("c_float", null, 2.0, 0.0, null, "123.341", "123.345"), - row("c_double", null, 2.0, 0.0, null, "234.561", "235.567"), - row("c_decimal", null, 2.0, 0.0, null, "345.0", "346.0"), - row("c_decimal_w_params", null, 2.0, 0.0, null, "345.671", "345.678"), - row("c_timestamp", null, 2.0, 0.0, null, null, null), // timestamp is shifted by hive.time-zone on read - row("c_date", null, 2.0, 0.0, null, "2015-05-09", "2015-06-10"), - row("c_string", 22.0, 2.0, 0.0, null, null, null), - row("c_varchar", 20.0, 2.0, 0.0, null, null, null), - row("c_char", 12.0, 2.0, 0.0, null, null, null), - row("c_boolean", null, 2.0, 0.0, null, null, null), - 
row("c_binary", 23.0, null, 0.0, null, null, null), - row(null, null, null, null, 2.0, null, null)); + row("c_tinyint", null, 2.0, 0.0, null, "121", "127", null), + row("c_smallint", null, 2.0, 0.0, null, "32761", "32767", null), + row("c_int", null, 2.0, 0.0, null, "2147483641", "2147483647", null), + row("c_bigint", null, 2.0, 0.0, null, "9223372036854775807", "9223372036854775807", null), + row("c_float", null, 2.0, 0.0, null, "123.341", "123.345", null), + row("c_double", null, 2.0, 0.0, null, "234.561", "235.567", null), + row("c_decimal", null, 2.0, 0.0, null, "345.0", "346.0", null), + row("c_decimal_w_params", null, 2.0, 0.0, null, "345.671", "345.678", null), + row("c_timestamp", null, 2.0, 0.0, null, null, null, null), // timestamp is shifted by hive.time-zone on read + row("c_date", null, 2.0, 0.0, null, "2015-05-09", "2015-06-10", null), + row("c_string", 22.0, 2.0, 0.0, null, null, null, null), + row("c_varchar", 20.0, 2.0, 0.0, null, null, null, null), + row("c_char", 12.0, 2.0, 0.0, null, null, null, null), + row("c_boolean", null, 2.0, 0.0, null, null, null, null), + row("c_binary", 23.0, null, 0.0, null, null, null, null), + row(null, null, null, null, 2.0, null, null, null)); } @Test(groups = {SKIP_ON_CDH}) // skip on cdh due to no support for date column and stats @@ -520,42 +520,42 @@ public void testStatisticsForAllDataTypesNoData() onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS"); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, null, null, null, null, null), - row("c_smallint", null, null, null, null, null, null), - row("c_int", null, null, null, null, null, null), - row("c_bigint", null, null, null, null, null, null), - row("c_float", null, null, null, null, null, null), - row("c_double", null, null, null, null, null, null), - row("c_decimal", null, null, null, null, null, null), - row("c_decimal_w_params", null, null, null, null, null, null), - 
row("c_timestamp", null, null, null, null, null, null), - row("c_date", null, null, null, null, null, null), - row("c_string", null, null, null, null, null, null), - row("c_varchar", null, null, null, null, null, null), - row("c_char", null, null, null, null, null, null), - row("c_boolean", null, null, null, null, null, null), - row("c_binary", null, null, null, null, null, null), - row(null, null, null, null, 0.0, null, null)); + row("c_tinyint", null, null, null, null, null, null, null), + row("c_smallint", null, null, null, null, null, null, null), + row("c_int", null, null, null, null, null, null, null), + row("c_bigint", null, null, null, null, null, null, null), + row("c_float", null, null, null, null, null, null, null), + row("c_double", null, null, null, null, null, null, null), + row("c_decimal", null, null, null, null, null, null, null), + row("c_decimal_w_params", null, null, null, null, null, null, null), + row("c_timestamp", null, null, null, null, null, null, null), + row("c_date", null, null, null, null, null, null, null), + row("c_string", null, null, null, null, null, null, null), + row("c_varchar", null, null, null, null, null, null, null), + row("c_char", null, null, null, null, null, null, null), + row("c_boolean", null, null, null, null, null, null, null), + row("c_binary", null, null, null, null, null, null, null), + row(null, null, null, null, 0.0, null, null, null)); onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS FOR COLUMNS"); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, 0.0, 0.0, null, null, null), - row("c_smallint", null, 0.0, 0.0, null, null, null), - row("c_int", null, 0.0, 0.0, null, null, null), - row("c_bigint", null, 0.0, 0.0, null, null, null), - row("c_float", null, 0.0, 0.0, null, null, null), - row("c_double", null, 0.0, 0.0, null, null, null), - row("c_decimal", null, 0.0, 0.0, null, null, null), - row("c_decimal_w_params", null, 0.0, 
0.0, null, null, null), - row("c_timestamp", null, 0.0, 0.0, null, null, null), - row("c_date", null, 0.0, 0.0, null, null, null), - row("c_string", 0.0, 0.0, 0.0, null, null, null), - row("c_varchar", 0.0, 0.0, 0.0, null, null, null), - row("c_char", 0.0, 0.0, 0.0, null, null, null), - row("c_boolean", null, 0.0, 0.0, null, null, null), - row("c_binary", 0.0, null, 0.0, null, null, null), - row(null, null, null, null, 0.0, null, null)); + row("c_tinyint", null, 0.0, 0.0, null, null, null, null), + row("c_smallint", null, 0.0, 0.0, null, null, null, null), + row("c_int", null, 0.0, 0.0, null, null, null, null), + row("c_bigint", null, 0.0, 0.0, null, null, null, null), + row("c_float", null, 0.0, 0.0, null, null, null, null), + row("c_double", null, 0.0, 0.0, null, null, null, null), + row("c_decimal", null, 0.0, 0.0, null, null, null, null), + row("c_decimal_w_params", null, 0.0, 0.0, null, null, null, null), + row("c_timestamp", null, 0.0, 0.0, null, null, null, null), + row("c_date", null, 0.0, 0.0, null, null, null, null), + row("c_string", 0.0, 0.0, 0.0, null, null, null, null), + row("c_varchar", 0.0, 0.0, 0.0, null, null, null, null), + row("c_char", 0.0, 0.0, 0.0, null, null, null, null), + row("c_boolean", null, 0.0, 0.0, null, null, null, null), + row("c_binary", 0.0, null, 0.0, null, null, null, null), + row(null, null, null, null, 0.0, null, null, null)); } @Test(groups = {SKIP_ON_CDH}) // skip on cdh due to no support for date column and stats @@ -568,42 +568,42 @@ public void testStatisticsForAllDataTypesOnlyNulls() onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS"); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, null, null, null, null, null), - row("c_smallint", null, null, null, null, null, null), - row("c_int", null, null, null, null, null, null), - row("c_bigint", null, null, null, null, null, null), - row("c_float", null, null, null, null, null, null), - 
row("c_double", null, null, null, null, null, null), - row("c_decimal", null, null, null, null, null, null), - row("c_decimal_w_params", null, null, null, null, null, null), - row("c_timestamp", null, null, null, null, null, null), - row("c_date", null, null, null, null, null, null), - row("c_string", null, null, null, null, null, null), - row("c_varchar", null, null, null, null, null, null), - row("c_char", null, null, null, null, null, null), - row("c_boolean", null, null, null, null, null, null), - row("c_binary", null, null, null, null, null, null), - row(null, null, null, null, 1.0, null, null)); + row("c_tinyint", null, null, null, null, null, null, null), + row("c_smallint", null, null, null, null, null, null, null), + row("c_int", null, null, null, null, null, null, null), + row("c_bigint", null, null, null, null, null, null, null), + row("c_float", null, null, null, null, null, null, null), + row("c_double", null, null, null, null, null, null, null), + row("c_decimal", null, null, null, null, null, null, null), + row("c_decimal_w_params", null, null, null, null, null, null, null), + row("c_timestamp", null, null, null, null, null, null, null), + row("c_date", null, null, null, null, null, null, null), + row("c_string", null, null, null, null, null, null, null), + row("c_varchar", null, null, null, null, null, null, null), + row("c_char", null, null, null, null, null, null, null), + row("c_boolean", null, null, null, null, null, null, null), + row("c_binary", null, null, null, null, null, null, null), + row(null, null, null, null, 1.0, null, null, null)); onHive().executeQuery("ANALYZE TABLE " + tableNameInDatabase + " COMPUTE STATISTICS FOR COLUMNS"); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, 0.0, 1.0, null, null, null), - row("c_smallint", null, 0.0, 1.0, null, null, null), - row("c_int", null, 0.0, 1.0, null, null, null), - row("c_bigint", null, 0.0, 1.0, null, null, null), - row("c_float", null, 
0.0, 1.0, null, null, null), - row("c_double", null, 0.0, 1.0, null, null, null), - row("c_decimal", null, 0.0, 1.0, null, null, null), - row("c_decimal_w_params", null, 0.0, 1.0, null, null, null), - row("c_timestamp", null, 0.0, 1.0, null, null, null), - row("c_date", null, 0.0, 1.0, null, null, null), - row("c_string", 0.0, 0.0, 1.0, null, null, null), - row("c_varchar", 0.0, 0.0, 1.0, null, null, null), - row("c_char", 0.0, 0.0, 1.0, null, null, null), - row("c_boolean", null, 0.0, 1.0, null, null, null), - row("c_binary", 0.0, null, 1.0, null, null, null), - row(null, null, null, null, 1.0, null, null)); + row("c_tinyint", null, 0.0, 1.0, null, null, null, null), + row("c_smallint", null, 0.0, 1.0, null, null, null, null), + row("c_int", null, 0.0, 1.0, null, null, null, null), + row("c_bigint", null, 0.0, 1.0, null, null, null, null), + row("c_float", null, 0.0, 1.0, null, null, null, null), + row("c_double", null, 0.0, 1.0, null, null, null, null), + row("c_decimal", null, 0.0, 1.0, null, null, null, null), + row("c_decimal_w_params", null, 0.0, 1.0, null, null, null, null), + row("c_timestamp", null, 0.0, 1.0, null, null, null, null), + row("c_date", null, 0.0, 1.0, null, null, null, null), + row("c_string", 0.0, 0.0, 1.0, null, null, null, null), + row("c_varchar", 0.0, 0.0, 1.0, null, null, null, null), + row("c_char", 0.0, 0.0, 1.0, null, null, null, null), + row("c_boolean", null, 0.0, 1.0, null, null, null, null), + row("c_binary", 0.0, null, 1.0, null, null, null, null), + row(null, null, null, null, 1.0, null, null, null)); } @Test @@ -616,22 +616,22 @@ public void testStatisticsForSkewedTable() onHive().executeQuery("INSERT INTO TABLE " + tableName + " VALUES ('c1', 1), ('c1', 2)"); assertThat(query("SHOW STATS FOR " + tableName)).containsOnly( - row("c_string", null, null, null, null, null, null), - row("c_int", null, null, null, null, null, null), - row(null, null, null, null, 2.0, null, null)); + row("c_string", null, null, null, null, null, 
null, null), + row("c_int", null, null, null, null, null, null, null), + row(null, null, null, null, 2.0, null, null, null)); onHive().executeQuery("ANALYZE TABLE " + tableName + " COMPUTE STATISTICS"); assertThat(query("SHOW STATS FOR " + tableName)).containsOnly( - row("c_string", null, null, null, null, null, null), - row("c_int", null, null, null, null, null, null), - row(null, null, null, null, 2.0, null, null)); + row("c_string", null, null, null, null, null, null, null), + row("c_int", null, null, null, null, null, null, null), + row(null, null, null, null, 2.0, null, null, null)); onHive().executeQuery("ANALYZE TABLE " + tableName + " COMPUTE STATISTICS FOR COLUMNS"); assertThat(query("SHOW STATS FOR " + tableName)).containsOnly( - row("c_string", 4.0, 1.0, 0.0, null, null, null), - row("c_int", null, 2.0, 0.0, null, "1", "2"), - row(null, null, null, null, 2.0, null, null)); + row("c_string", 4.0, 1.0, 0.0, null, null, null, null), + row("c_int", null, 2.0, 0.0, null, "1", "2", null), + row(null, null, null, null, 2.0, null, null, null)); } @Test @@ -644,15 +644,15 @@ public void testAnalyzesForSkewedTable() onHive().executeQuery("INSERT INTO TABLE " + tableName + " VALUES ('c1', 1), ('c1', 2)"); assertThat(query("SHOW STATS FOR " + tableName)).containsOnly( - row("c_string", null, null, null, null, null, null), - row("c_int", null, null, null, null, null, null), - row(null, null, null, null, 2.0, null, null)); + row("c_string", null, null, null, null, null, null, null), + row("c_int", null, null, null, null, null, null, null), + row(null, null, null, null, 2.0, null, null, null)); assertThat(query("ANALYZE " + tableName)).containsExactly(row(2)); assertThat(query("SHOW STATS FOR " + tableName)).containsOnly( - row("c_string", 4.0, 1.0, 0.0, null, null, null), - row("c_int", null, 2.0, 0.0, null, "1", "2"), - row(null, null, null, null, 2.0, null, null)); + row("c_string", 4.0, 1.0, 0.0, null, null, null, null), + row("c_int", null, 2.0, 0.0, null, "1", 
"2", null), + row(null, null, null, null, 2.0, null, null, null)); } @Test @@ -665,20 +665,20 @@ public void testAnalyzeForUnpartitionedTable() // table not analyzed assertThat(query(showStatsWholeTable)).containsOnly( - row("n_nationkey", null, null, anyOf(null, 0.0), null, null, null), - row("n_name", null, null, anyOf(null, 0.0), null, null, null), - row("n_regionkey", null, null, anyOf(null, 0.0), null, null, null), - row("n_comment", null, null, anyOf(null, 0.0), null, null, null), - row(null, null, null, null, anyOf(null, 0.0), null, null)); // anyOf because of different behaviour on HDP (hive 1.2) and CDH (hive 1.1) + row("n_nationkey", null, null, anyOf(null, 0.0), null, null, null, null), + row("n_name", null, null, anyOf(null, 0.0), null, null, null, null), + row("n_regionkey", null, null, anyOf(null, 0.0), null, null, null, null), + row("n_comment", null, null, anyOf(null, 0.0), null, null, null, null), + row(null, null, null, null, anyOf(null, 0.0), null, null, null)); // anyOf because of different behaviour on HDP (hive 1.2) and CDH (hive 1.1) assertThat(query("ANALYZE " + tableNameInDatabase)).containsExactly(row(25)); assertThat(query(showStatsWholeTable)).containsOnly( - row("n_nationkey", null, 25.0, 0.0, null, "0", "24"), - row("n_name", 177.0, 25.0, 0.0, null, null, null), - row("n_regionkey", null, 5.0, 0.0, null, "0", "4"), - row("n_comment", 1857.0, 25.0, 0.0, null, null, null), - row(null, null, null, null, 25.0, null, null)); + row("n_nationkey", null, 25.0, 0.0, null, "0", "24", null), + row("n_name", 177.0, 25.0, 0.0, null, null, null, null), + row("n_regionkey", null, 5.0, 0.0, null, "0", "4", null), + row("n_comment", 1857.0, 25.0, 0.0, null, null, null, null), + row(null, null, null, null, 25.0, null, null, null)); } @Test @@ -693,10 +693,10 @@ public void testAnalyzeForTableWithNonPrimitiveTypes() assertThat(query("ANALYZE " + tableName)).containsExactly(row(1)); assertThat(query(showStatsTable)).containsOnly( - row("c_row", null, 
null, null, null, null, null), - row("c_char", 1.0, 1.0, 0.0, null, null, null), - row("c_int", null, 1.0, 0.0, null, "3", "3"), - row(null, null, null, null, 1.0, null, null)); + row("c_row", null, null, null, null, null, null, null), + row("c_char", 1.0, 1.0, 0.0, null, null, null, null), + row("c_int", null, 1.0, 0.0, null, "3", "3", null), + row(null, null, null, null, 1.0, null, null, null)); } @Test @@ -715,11 +715,11 @@ public void testAnalyzeForPartitionedTableWithNonPrimitiveTypes() assertThat(query("ANALYZE " + tableName)).containsExactly(row(3)); assertThat(query(showStatsTable)).containsOnly( - row("c_row", null, null, null, null, null, null), - row("c_char", 3.0, 2.0, 0.0, null, null, null), - row("c_int", null, 1.0, 0.0, null, "3", "5"), - row("c_part", 3.0, 2.0, 0.0, null, null, null), - row(null, null, null, null, 3.0, null, null)); + row("c_row", null, null, null, null, null, null, null), + row("c_char", 3.0, 2.0, 0.0, null, null, null, null), + row("c_int", null, 1.0, 0.0, null, "3", "5", null), + row("c_part", 3.0, 2.0, 0.0, null, null, null, null), + row(null, null, null, null, 3.0, null, null, null)); } @Test @@ -735,68 +735,68 @@ public void testAnalyzeForTablePartitionedByBigint() // table not analyzed assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, null, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - 
row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, null, null, null, null)); // analyze for single partition assertThat(query("ANALYZE " + tableNameInDatabase + " WITH (partitions = ARRAY[ARRAY['1']])")).containsExactly(row(5)); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 114.0, 5.0, 0.0, null, null, null), - row("p_regionkey", null, 3.0, 0.0, null, "1", "3"), - row("p_comment", 1497.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 15.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 114.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", null, 3.0, 0.0, null, "1", "3", null), + row("p_comment", 1497.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 15.0, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 38.0, 5.0, 0.0, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null, "1", "1"), - row("p_comment", 499.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 38.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "1", "1", null), + row("p_comment", 499.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, null, null, 
null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, null, null, null, null)); // analyze for all partitions assertThat(query("ANALYZE " + tableNameInDatabase)).containsExactly(row(15)); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 109.0, 5.0, 0.0, null, null, null), - row("p_regionkey", null, 3.0, 0.0, null, "1", "3"), - row("p_comment", 1197.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 15.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 109.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", null, 3.0, 0.0, null, "1", "3", null), + row("p_comment", 1197.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 15.0, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 38.0, 5.0, 0.0, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null, "1", "1"), - row("p_comment", 499.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 38.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "1", "1", null), + row("p_comment", 499.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "8", "21"), - 
row("p_name", 31.0, 5.0, 0.0, null, null, null), - row("p_regionkey", null, 1.0, 0.0, null, "2", "2"), - row("p_comment", 351.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "8", "21", null), + row("p_name", 31.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", null, 1.0, 0.0, null, "2", "2", null), + row("p_comment", 351.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); } @Test @@ -812,68 +812,68 @@ public void testAnalyzeForTablePartitionedByVarchar() // table not analyzed assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, null, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, null, null, null, null)); // analyze for single partition assertThat(query("ANALYZE " + tableNameInDatabase + " WITH (partitions = 
ARRAY[ARRAY['AMERICA']])")).containsExactly(row(5)); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 114.0, 5.0, 0.0, null, null, null), - row("p_regionkey", 85.0, 3.0, 0.0, null, null, null), - row("p_comment", 1497.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 15.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 114.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", 85.0, 3.0, 0.0, null, null, null, null), + row("p_comment", 1497.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 15.0, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 38.0, 5.0, 0.0, null, null, null), - row("p_regionkey", 35.0, 1.0, 0.0, null, null, null), - row("p_comment", 499.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 38.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", 35.0, 1.0, 0.0, null, null, null, null), + row("p_comment", 499.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, null, null, null, null, null), - row("p_name", null, null, null, null, null, null), - row("p_regionkey", null, null, null, null, null, null), - row("p_comment", null, null, null, null, null, null), - row(null, null, null, null, null, null, null)); + row("p_nationkey", null, null, null, null, null, null, null), + row("p_name", null, null, null, null, null, null, null), + row("p_regionkey", null, null, null, null, null, null, null), + row("p_comment", null, null, null, null, null, null, null), + row(null, null, null, null, null, null, null, null)); // column analysis for all partitions assertThat(query("ANALYZE " + 
tableNameInDatabase)).containsExactly(row(15)); assertThat(query(showStatsWholeTable)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 109.0, 5.0, 0.0, null, null, null), - row("p_regionkey", 85.0, 3.0, 0.0, null, null, null), - row("p_comment", 1197.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 15.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 109.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", 85.0, 3.0, 0.0, null, null, null, null), + row("p_comment", 1197.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 15.0, null, null, null)); assertThat(query(showStatsPartitionOne)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "1", "24"), - row("p_name", 38.0, 5.0, 0.0, null, null, null), - row("p_regionkey", 35.0, 1.0, 0.0, null, null, null), - row("p_comment", 499.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "1", "24", null), + row("p_name", 38.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", 35.0, 1.0, 0.0, null, null, null, null), + row("p_comment", 499.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); assertThat(query(showStatsPartitionTwo)).containsOnly( - row("p_nationkey", null, 5.0, 0.0, null, "8", "21"), - row("p_name", 31.0, 5.0, 0.0, null, null, null), - row("p_regionkey", 20.0, 1.0, 0.0, null, null, null), - row("p_comment", 351.0, 5.0, 0.0, null, null, null), - row(null, null, null, null, 5.0, null, null)); + row("p_nationkey", null, 5.0, 0.0, null, "8", "21", null), + row("p_name", 31.0, 5.0, 0.0, null, null, null, null), + row("p_regionkey", 20.0, 1.0, 0.0, null, null, null, null), + row("p_comment", 351.0, 5.0, 0.0, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null)); } // This covers also stats calculation for unpartitioned table @@ -884,43 +884,43 @@ public 
void testAnalyzeForAllDataTypes() String tableNameInDatabase = mutableTablesState().get(ALL_TYPES_TABLE_NAME).getNameInDatabase(); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, null, null, null, null, null), - row("c_smallint", null, null, null, null, null, null), - row("c_int", null, null, null, null, null, null), - row("c_bigint", null, null, null, null, null, null), - row("c_float", null, null, null, null, null, null), - row("c_double", null, null, null, null, null, null), - row("c_decimal", null, null, null, null, null, null), - row("c_decimal_w_params", null, null, null, null, null, null), - row("c_timestamp", null, null, null, null, null, null), - row("c_date", null, null, null, null, null, null), - row("c_string", null, null, null, null, null, null), - row("c_varchar", null, null, null, null, null, null), - row("c_char", null, null, null, null, null, null), - row("c_boolean", null, null, null, null, null, null), - row("c_binary", null, null, null, null, null, null), - row(null, null, null, null, 0.0, null, null)); + row("c_tinyint", null, null, null, null, null, null, null), + row("c_smallint", null, null, null, null, null, null, null), + row("c_int", null, null, null, null, null, null, null), + row("c_bigint", null, null, null, null, null, null, null), + row("c_float", null, null, null, null, null, null, null), + row("c_double", null, null, null, null, null, null, null), + row("c_decimal", null, null, null, null, null, null, null), + row("c_decimal_w_params", null, null, null, null, null, null, null), + row("c_timestamp", null, null, null, null, null, null, null), + row("c_date", null, null, null, null, null, null, null), + row("c_string", null, null, null, null, null, null, null), + row("c_varchar", null, null, null, null, null, null, null), + row("c_char", null, null, null, null, null, null, null), + row("c_boolean", null, null, null, null, null, null, null), + row("c_binary", null, null, null, null, 
null, null, null), + row(null, null, null, null, 0.0, null, null, null)); assertThat(query("ANALYZE " + tableNameInDatabase)).containsExactly(row(2)); // SHOW STATS FORMAT: column_name, data_size, distinct_values_count, nulls_fraction, row_count assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, 2.0, 0.0, null, "121", "127"), - row("c_smallint", null, 2.0, 0.0, null, "32761", "32767"), - row("c_int", null, 2.0, 0.0, null, "2147483641", "2147483647"), - row("c_bigint", null, 2.0, 0.0, null, "9223372036854775807", "9223372036854775807"), - row("c_float", null, 2.0, 0.0, null, "123.341", "123.345"), - row("c_double", null, 2.0, 0.0, null, "234.561", "235.567"), - row("c_decimal", null, 2.0, 0.0, null, "345.0", "346.0"), - row("c_decimal_w_params", null, 2.0, 0.0, null, "345.671", "345.678"), - row("c_timestamp", null, 2.0, 0.0, null, null, null), - row("c_date", null, 2.0, 0.0, null, "2015-05-09", "2015-06-10"), - row("c_string", 22.0, 2.0, 0.0, null, null, null), - row("c_varchar", 20.0, 2.0, 0.0, null, null, null), - row("c_char", 12.0, 2.0, 0.0, null, null, null), - row("c_boolean", null, 2.0, 0.0, null, null, null), - row("c_binary", 23.0, null, 0.0, null, null, null), - row(null, null, null, null, 2.0, null, null)); + row("c_tinyint", null, 2.0, 0.0, null, "121", "127", null), + row("c_smallint", null, 2.0, 0.0, null, "32761", "32767", null), + row("c_int", null, 2.0, 0.0, null, "2147483641", "2147483647", null), + row("c_bigint", null, 2.0, 0.0, null, "9223372036854775807", "9223372036854775807", null), + row("c_float", null, 2.0, 0.0, null, "123.341", "123.345", null), + row("c_double", null, 2.0, 0.0, null, "234.561", "235.567", null), + row("c_decimal", null, 2.0, 0.0, null, "345.0", "346.0", null), + row("c_decimal_w_params", null, 2.0, 0.0, null, "345.671", "345.678", null), + row("c_timestamp", null, 2.0, 0.0, null, null, null, null), + row("c_date", null, 2.0, 0.0, null, "2015-05-09", "2015-06-10", null), + 
row("c_string", 22.0, 2.0, 0.0, null, null, null, null), + row("c_varchar", 20.0, 2.0, 0.0, null, null, null, null), + row("c_char", 12.0, 2.0, 0.0, null, null, null, null), + row("c_boolean", null, 2.0, 0.0, null, null, null, null), + row("c_binary", 23.0, null, 0.0, null, null, null, null), + row(null, null, null, null, 2.0, null, null, null)); } @Test(groups = {SKIP_ON_CDH}) // skip on cdh due to no support for date column and stats @@ -930,42 +930,42 @@ public void testAnalyzeForAllDataTypesNoData() String tableNameInDatabase = mutableTablesState().get(EMPTY_ALL_TYPES_TABLE_NAME).getNameInDatabase(); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, null, null, null, null, null), - row("c_smallint", null, null, null, null, null, null), - row("c_int", null, null, null, null, null, null), - row("c_bigint", null, null, null, null, null, null), - row("c_float", null, null, null, null, null, null), - row("c_double", null, null, null, null, null, null), - row("c_decimal", null, null, null, null, null, null), - row("c_decimal_w_params", null, null, null, null, null, null), - row("c_timestamp", null, null, null, null, null, null), - row("c_date", null, null, null, null, null, null), - row("c_string", null, null, null, null, null, null), - row("c_varchar", null, null, null, null, null, null), - row("c_char", null, null, null, null, null, null), - row("c_boolean", null, null, null, null, null, null), - row("c_binary", null, null, null, null, null, null), - row(null, null, null, null, 0.0, null, null)); + row("c_tinyint", null, null, null, null, null, null, null), + row("c_smallint", null, null, null, null, null, null, null), + row("c_int", null, null, null, null, null, null, null), + row("c_bigint", null, null, null, null, null, null, null), + row("c_float", null, null, null, null, null, null, null), + row("c_double", null, null, null, null, null, null, null), + row("c_decimal", null, null, null, null, null, null, null), + 
row("c_decimal_w_params", null, null, null, null, null, null, null), + row("c_timestamp", null, null, null, null, null, null, null), + row("c_date", null, null, null, null, null, null, null), + row("c_string", null, null, null, null, null, null, null), + row("c_varchar", null, null, null, null, null, null, null), + row("c_char", null, null, null, null, null, null, null), + row("c_boolean", null, null, null, null, null, null, null), + row("c_binary", null, null, null, null, null, null, null), + row(null, null, null, null, 0.0, null, null, null)); assertThat(query("ANALYZE " + tableNameInDatabase)).containsExactly(row(0)); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, 0.0, 0.0, null, null, null), - row("c_smallint", null, 0.0, 0.0, null, null, null), - row("c_int", null, 0.0, 0.0, null, null, null), - row("c_bigint", null, 0.0, 0.0, null, null, null), - row("c_float", null, 0.0, 0.0, null, null, null), - row("c_double", null, 0.0, 0.0, null, null, null), - row("c_decimal", null, 0.0, 0.0, null, null, null), - row("c_decimal_w_params", null, 0.0, 0.0, null, null, null), - row("c_timestamp", null, 0.0, 0.0, null, null, null), - row("c_date", null, 0.0, 0.0, null, null, null), - row("c_string", 0.0, 0.0, 0.0, null, null, null), - row("c_varchar", 0.0, 0.0, 0.0, null, null, null), - row("c_char", 0.0, 0.0, 0.0, null, null, null), - row("c_boolean", null, 0.0, 0.0, null, null, null), - row("c_binary", 0.0, null, 0.0, null, null, null), - row(null, null, null, null, 0.0, null, null)); + row("c_tinyint", null, 0.0, 0.0, null, null, null, null), + row("c_smallint", null, 0.0, 0.0, null, null, null, null), + row("c_int", null, 0.0, 0.0, null, null, null, null), + row("c_bigint", null, 0.0, 0.0, null, null, null, null), + row("c_float", null, 0.0, 0.0, null, null, null, null), + row("c_double", null, 0.0, 0.0, null, null, null, null), + row("c_decimal", null, 0.0, 0.0, null, null, null, null), + row("c_decimal_w_params", 
null, 0.0, 0.0, null, null, null, null), + row("c_timestamp", null, 0.0, 0.0, null, null, null, null), + row("c_date", null, 0.0, 0.0, null, null, null, null), + row("c_string", 0.0, 0.0, 0.0, null, null, null, null), + row("c_varchar", 0.0, 0.0, 0.0, null, null, null, null), + row("c_char", 0.0, 0.0, 0.0, null, null, null, null), + row("c_boolean", null, 0.0, 0.0, null, null, null, null), + row("c_binary", 0.0, null, 0.0, null, null, null, null), + row(null, null, null, null, 0.0, null, null, null)); } @Test(groups = {SKIP_ON_CDH}) // skip on cdh due to no support for date column and stats @@ -978,42 +978,42 @@ public void testAnalyzeForAllDataTypesOnlyNulls() onHive().executeQuery("INSERT INTO TABLE " + tableNameInDatabase + " VALUES(null, null, null, null, null, null, null, null, null, null, null, null, null, null, null)"); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, null, null, null, null, null), - row("c_smallint", null, null, null, null, null, null), - row("c_int", null, null, null, null, null, null), - row("c_bigint", null, null, null, null, null, null), - row("c_float", null, null, null, null, null, null), - row("c_double", null, null, null, null, null, null), - row("c_decimal", null, null, null, null, null, null), - row("c_decimal_w_params", null, null, null, null, null, null), - row("c_timestamp", null, null, null, null, null, null), - row("c_date", null, null, null, null, null, null), - row("c_string", null, null, null, null, null, null), - row("c_varchar", null, null, null, null, null, null), - row("c_char", null, null, null, null, null, null), - row("c_boolean", null, null, null, null, null, null), - row("c_binary", null, null, null, null, null, null), - row(null, null, null, null, 1.0, null, null)); + row("c_tinyint", null, null, null, null, null, null, null), + row("c_smallint", null, null, null, null, null, null, null), + row("c_int", null, null, null, null, null, null, null), + row("c_bigint", 
null, null, null, null, null, null, null), + row("c_float", null, null, null, null, null, null, null), + row("c_double", null, null, null, null, null, null, null), + row("c_decimal", null, null, null, null, null, null, null), + row("c_decimal_w_params", null, null, null, null, null, null, null), + row("c_timestamp", null, null, null, null, null, null, null), + row("c_date", null, null, null, null, null, null, null), + row("c_string", null, null, null, null, null, null, null), + row("c_varchar", null, null, null, null, null, null, null), + row("c_char", null, null, null, null, null, null, null), + row("c_boolean", null, null, null, null, null, null, null), + row("c_binary", null, null, null, null, null, null, null), + row(null, null, null, null, 1.0, null, null, null)); assertThat(query("ANALYZE " + tableNameInDatabase)).containsExactly(row(1)); assertThat(query("SHOW STATS FOR " + tableNameInDatabase)).containsOnly( - row("c_tinyint", null, 0.0, 1.0, null, null, null), - row("c_smallint", null, 0.0, 1.0, null, null, null), - row("c_int", null, 0.0, 1.0, null, null, null), - row("c_bigint", null, 0.0, 1.0, null, null, null), - row("c_float", null, 0.0, 1.0, null, null, null), - row("c_double", null, 0.0, 1.0, null, null, null), - row("c_decimal", null, 0.0, 1.0, null, null, null), - row("c_decimal_w_params", null, 0.0, 1.0, null, null, null), - row("c_timestamp", null, 0.0, 1.0, null, null, null), - row("c_date", null, 0.0, 1.0, null, null, null), - row("c_string", 0.0, 0.0, 1.0, null, null, null), - row("c_varchar", 0.0, 0.0, 1.0, null, null, null), - row("c_char", 0.0, 0.0, 1.0, null, null, null), - row("c_boolean", null, 0.0, 1.0, null, null, null), - row("c_binary", 0.0, null, 1.0, null, null, null), - row(null, null, null, null, 1.0, null, null)); + row("c_tinyint", null, 0.0, 1.0, null, null, null, null), + row("c_smallint", null, 0.0, 1.0, null, null, null, null), + row("c_int", null, 0.0, 1.0, null, null, null, null), + row("c_bigint", null, 0.0, 1.0, null, 
null, null, null), + row("c_float", null, 0.0, 1.0, null, null, null, null), + row("c_double", null, 0.0, 1.0, null, null, null, null), + row("c_decimal", null, 0.0, 1.0, null, null, null, null), + row("c_decimal_w_params", null, 0.0, 1.0, null, null, null, null), + row("c_timestamp", null, 0.0, 1.0, null, null, null, null), + row("c_date", null, 0.0, 1.0, null, null, null, null), + row("c_string", 0.0, 0.0, 1.0, null, null, null, null), + row("c_varchar", 0.0, 0.0, 1.0, null, null, null, null), + row("c_char", 0.0, 0.0, 1.0, null, null, null, null), + row("c_boolean", null, 0.0, 1.0, null, null, null, null), + row("c_binary", 0.0, null, 1.0, null, null, null, null), + row(null, null, null, null, 1.0, null, null, null)); } @Test @@ -1049,22 +1049,22 @@ public void testComputeTableStatisticsOnInsert() query(format("INSERT INTO %s SELECT * FROM %s", tableName, allTypesAllNullTable)); query(format("INSERT INTO %s SELECT * FROM %s", tableName, allTypesAllNullTable)); assertThat(query("SHOW STATS FOR " + tableName)).containsOnly( - row("c_tinyint", null, 2.0, 0.5, null, "121", "127"), - row("c_smallint", null, 2.0, 0.5, null, "32761", "32767"), - row("c_int", null, 2.0, 0.5, null, "2147483641", "2147483647"), - row("c_bigint", null, 2.0, 0.5, null, "9223372036854775807", "9223372036854775807"), - row("c_float", null, 2.0, 0.5, null, "123.341", "123.345"), - row("c_double", null, 2.0, 0.5, null, "234.561", "235.567"), - row("c_decimal", null, 2.0, 0.5, null, "345.0", "346.0"), - row("c_decimal_w_params", null, 2.0, 0.5, null, "345.671", "345.678"), - row("c_timestamp", null, 2.0, 0.5, null, null, null), - row("c_date", null, 2.0, 0.5, null, "2015-05-09", "2015-06-10"), - row("c_string", 22.0, 2.0, 0.5, null, null, null), - row("c_varchar", 20.0, 2.0, 0.5, null, null, null), - row("c_char", 12.0, 2.0, 0.5, null, null, null), - row("c_boolean", null, 2.0, 0.5, null, null, null), - row("c_binary", 23.0, null, 0.5, null, null, null), - row(null, null, null, null, 4.0, null, 
null)); + row("c_tinyint", null, 2.0, 0.5, null, "121", "127", null), + row("c_smallint", null, 2.0, 0.5, null, "32761", "32767", null), + row("c_int", null, 2.0, 0.5, null, "2147483641", "2147483647", null), + row("c_bigint", null, 2.0, 0.5, null, "9223372036854775807", "9223372036854775807", null), + row("c_float", null, 2.0, 0.5, null, "123.341", "123.345", null), + row("c_double", null, 2.0, 0.5, null, "234.561", "235.567", null), + row("c_decimal", null, 2.0, 0.5, null, "345.0", "346.0", null), + row("c_decimal_w_params", null, 2.0, 0.5, null, "345.671", "345.678", null), + row("c_timestamp", null, 2.0, 0.5, null, null, null, null), + row("c_date", null, 2.0, 0.5, null, "2015-05-09", "2015-06-10", null), + row("c_string", 22.0, 2.0, 0.5, null, null, null, null), + row("c_varchar", 20.0, 2.0, 0.5, null, null, null, null), + row("c_char", 12.0, 2.0, 0.5, null, null, null, null), + row("c_boolean", null, 2.0, 0.5, null, null, null, null), + row("c_binary", 23.0, null, 0.5, null, null, null, null), + row(null, null, null, null, 4.0, null, null, null)); query(format("INSERT INTO %s VALUES( " + "TINYINT '120', " + @@ -1084,22 +1084,22 @@ public void testComputeTableStatisticsOnInsert() "CAST('cGllcyBiaW5hcm54' as VARBINARY))", tableName)); assertThat(query("SHOW STATS FOR " + tableName)).containsOnly(ImmutableList.of( - row("c_tinyint", null, 2.0, 0.4, null, "120", "127"), - row("c_smallint", null, 2.0, 0.4, null, "32760", "32767"), - row("c_int", null, 2.0, 0.4, null, "2147483640", "2147483647"), - row("c_bigint", null, 2.0, 0.4, null, "9223372036854775807", "9223372036854775807"), - row("c_float", null, 2.0, 0.4, null, "123.34", "123.345"), - row("c_double", null, 2.0, 0.4, null, "234.56", "235.567"), - row("c_decimal", null, 2.0, 0.4, null, "343.0", "346.0"), - row("c_decimal_w_params", null, 2.0, 0.4, null, "345.67", "345.678"), - row("c_timestamp", null, 2.0, 0.4, null, null, null), - row("c_date", null, 2.0, 0.4, null, "2015-05-08", "2015-06-10"), - 
row("c_string", 32.0, 2.0, 0.4, null, null, null), - row("c_varchar", 29.0, 2.0, 0.4, null, null, null), - row("c_char", 17.0, 2.0, 0.4, null, null, null), - row("c_boolean", null, 2.0, 0.4, null, null, null), - row("c_binary", 39.0, null, 0.4, null, null, null), - row(null, null, null, null, 5.0, null, null))); + row("c_tinyint", null, 2.0, 0.4, null, "120", "127", null), + row("c_smallint", null, 2.0, 0.4, null, "32760", "32767", null), + row("c_int", null, 2.0, 0.4, null, "2147483640", "2147483647", null), + row("c_bigint", null, 2.0, 0.4, null, "9223372036854775807", "9223372036854775807", null), + row("c_float", null, 2.0, 0.4, null, "123.34", "123.345", null), + row("c_double", null, 2.0, 0.4, null, "234.56", "235.567", null), + row("c_decimal", null, 2.0, 0.4, null, "343.0", "346.0", null), + row("c_decimal_w_params", null, 2.0, 0.4, null, "345.67", "345.678", null), + row("c_timestamp", null, 2.0, 0.4, null, null, null, null), + row("c_date", null, 2.0, 0.4, null, "2015-05-08", "2015-06-10", null), + row("c_string", 32.0, 2.0, 0.4, null, null, null, null), + row("c_varchar", 29.0, 2.0, 0.4, null, null, null, null), + row("c_char", 17.0, 2.0, 0.4, null, null, null, null), + row("c_boolean", null, 2.0, 0.4, null, null, null, null), + row("c_binary", 39.0, null, 0.4, null, null, null, null), + row(null, null, null, null, 5.0, null, null, null))); } finally { query(format("DROP TABLE IF EXISTS %s", tableName)); @@ -1160,44 +1160,44 @@ public void testComputePartitionStatisticsOnCreateTable() ") AS t (c_tinyint, c_smallint, c_int, c_bigint, c_float, c_double, c_decimal, c_decimal_w_params, c_timestamp, c_date, c_string, c_varchar, c_char, c_boolean, c_binary, p_bigint, p_varchar)", tableName)); assertThat(query(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_bigint = 1 AND p_varchar = 'partition1')", tableName))).containsOnly(ImmutableList.of( - row("c_tinyint", null, 1.0, 0.5, null, "120", "120"), - row("c_smallint", null, 1.0, 0.5, null, "32760", "32760"), - 
row("c_int", null, 1.0, 0.5, null, "2147483640", "2147483640"), - row("c_bigint", null, 1.0, 0.5, null, "9223372036854775807", "9223372036854775807"), - row("c_float", null, 1.0, 0.5, null, "123.34", "123.34"), - row("c_double", null, 1.0, 0.5, null, "234.56", "234.56"), - row("c_decimal", null, 1.0, 0.5, null, "343.0", "343.0"), - row("c_decimal_w_params", null, 1.0, 0.5, null, "345.67", "345.67"), - row("c_timestamp", null, 1.0, 0.5, null, null, null), - row("c_date", null, 1.0, 0.5, null, "2015-05-08", "2015-05-08"), - row("c_string", 10.0, 1.0, 0.5, null, null, null), - row("c_varchar", 10.0, 1.0, 0.5, null, null, null), - row("c_char", 9.0, 1.0, 0.5, null, null, null), - row("c_boolean", null, 1.0, 0.5, null, null, null), - row("c_binary", 9.0, null, 0.5, null, null, null), - row("p_bigint", null, 1.0, 0.0, null, "1", "1"), - row("p_varchar", 20.0, 1.0, 0.0, null, null, null), - row(null, null, null, null, 2.0, null, null))); + row("c_tinyint", null, 1.0, 0.5, null, "120", "120", null), + row("c_smallint", null, 1.0, 0.5, null, "32760", "32760", null), + row("c_int", null, 1.0, 0.5, null, "2147483640", "2147483640", null), + row("c_bigint", null, 1.0, 0.5, null, "9223372036854775807", "9223372036854775807", null), + row("c_float", null, 1.0, 0.5, null, "123.34", "123.34", null), + row("c_double", null, 1.0, 0.5, null, "234.56", "234.56", null), + row("c_decimal", null, 1.0, 0.5, null, "343.0", "343.0", null), + row("c_decimal_w_params", null, 1.0, 0.5, null, "345.67", "345.67", null), + row("c_timestamp", null, 1.0, 0.5, null, null, null, null), + row("c_date", null, 1.0, 0.5, null, "2015-05-08", "2015-05-08", null), + row("c_string", 10.0, 1.0, 0.5, null, null, null, null), + row("c_varchar", 10.0, 1.0, 0.5, null, null, null, null), + row("c_char", 9.0, 1.0, 0.5, null, null, null, null), + row("c_boolean", null, 1.0, 0.5, null, null, null, null), + row("c_binary", 9.0, null, 0.5, null, null, null, null), + row("p_bigint", null, 1.0, 0.0, null, "1", "1", 
null), + row("p_varchar", 20.0, 1.0, 0.0, null, null, null, null), + row(null, null, null, null, 2.0, null, null, null))); assertThat(query(format("SHOW STATS FOR (SELECT * FROM %s WHERE p_bigint = 2 AND p_varchar = 'partition2')", tableName))).containsOnly(ImmutableList.of( - row("c_tinyint", null, 1.0, 0.5, null, "99", "99"), - row("c_smallint", null, 1.0, 0.5, null, "333", "333"), - row("c_int", null, 1.0, 0.5, null, "444", "444"), - row("c_bigint", null, 1.0, 0.5, null, "555", "555"), - row("c_float", null, 1.0, 0.5, null, "666.34", "666.34"), - row("c_double", null, 1.0, 0.5, null, "777.56", "777.56"), - row("c_decimal", null, 1.0, 0.5, null, "888.0", "888.0"), - row("c_decimal_w_params", null, 1.0, 0.5, null, "999.67", "999.67"), - row("c_timestamp", null, 1.0, 0.5, null, null, null), - row("c_date", null, 1.0, 0.5, null, "2015-05-09", "2015-05-09"), - row("c_string", 10.0, 1.0, 0.5, null, null, null), - row("c_varchar", 10.0, 1.0, 0.5, null, null, null), - row("c_char", 9.0, 1.0, 0.5, null, null, null), - row("c_boolean", null, 1.0, 0.5, null, null, null), - row("c_binary", 9.0, null, 0.5, null, null, null), - row("p_bigint", null, 1.0, 0.0, null, "2", "2"), - row("p_varchar", 20.0, 1.0, 0.0, null, null, null), - row(null, null, null, null, 2.0, null, null))); + row("c_tinyint", null, 1.0, 0.5, null, "99", "99", null), + row("c_smallint", null, 1.0, 0.5, null, "333", "333", null), + row("c_int", null, 1.0, 0.5, null, "444", "444", null), + row("c_bigint", null, 1.0, 0.5, null, "555", "555", null), + row("c_float", null, 1.0, 0.5, null, "666.34", "666.34", null), + row("c_double", null, 1.0, 0.5, null, "777.56", "777.56", null), + row("c_decimal", null, 1.0, 0.5, null, "888.0", "888.0", null), + row("c_decimal_w_params", null, 1.0, 0.5, null, "999.67", "999.67", null), + row("c_timestamp", null, 1.0, 0.5, null, null, null, null), + row("c_date", null, 1.0, 0.5, null, "2015-05-09", "2015-05-09", null), + row("c_string", 10.0, 1.0, 0.5, null, null, null, null), 
+ row("c_varchar", 10.0, 1.0, 0.5, null, null, null, null), + row("c_char", 9.0, 1.0, 0.5, null, null, null, null), + row("c_boolean", null, 1.0, 0.5, null, null, null, null), + row("c_binary", 9.0, null, 0.5, null, null, null, null), + row("p_bigint", null, 1.0, 0.0, null, "2", "2", null), + row("p_varchar", 20.0, 1.0, 0.0, null, null, null, null), + row(null, null, null, null, 2.0, null, null, null))); } finally { query(format("DROP TABLE IF EXISTS %s", tableName)); @@ -1246,90 +1246,90 @@ public void testComputePartitionStatisticsOnInsert() String showStatsPartitionTwo = format("SHOW STATS FOR (SELECT * FROM %s WHERE p_bigint = 2 AND p_varchar = 'partition2')", tableName); assertThat(query(showStatsPartitionOne)).containsOnly(ImmutableList.of( - row("c_tinyint", null, 1.0, 0.5, null, "120", "120"), - row("c_smallint", null, 1.0, 0.5, null, "32760", "32760"), - row("c_int", null, 1.0, 0.5, null, "2147483640", "2147483640"), - row("c_bigint", null, 1.0, 0.5, null, "9223372036854775807", "9223372036854775807"), - row("c_float", null, 1.0, 0.5, null, "123.34", "123.34"), - row("c_double", null, 1.0, 0.5, null, "234.56", "234.56"), - row("c_decimal", null, 1.0, 0.5, null, "343.0", "343.0"), - row("c_decimal_w_params", null, 1.0, 0.5, null, "345.67", "345.67"), - row("c_timestamp", null, 1.0, 0.5, null, null, null), - row("c_date", null, 1.0, 0.5, null, "2015-05-08", "2015-05-08"), - row("c_string", 10.0, 1.0, 0.5, null, null, null), - row("c_varchar", 10.0, 1.0, 0.5, null, null, null), - row("c_char", 9.0, 1.0, 0.5, null, null, null), - row("c_boolean", null, 1.0, 0.5, null, null, null), - row("c_binary", 9.0, null, 0.5, null, null, null), - row("p_bigint", null, 1.0, 0.0, null, "1", "1"), - row("p_varchar", 20.0, 1.0, 0.0, null, null, null), - row(null, null, null, null, 2.0, null, null))); + row("c_tinyint", null, 1.0, 0.5, null, "120", "120", null), + row("c_smallint", null, 1.0, 0.5, null, "32760", "32760", null), + row("c_int", null, 1.0, 0.5, null, 
"2147483640", "2147483640", null), + row("c_bigint", null, 1.0, 0.5, null, "9223372036854775807", "9223372036854775807", null), + row("c_float", null, 1.0, 0.5, null, "123.34", "123.34", null), + row("c_double", null, 1.0, 0.5, null, "234.56", "234.56", null), + row("c_decimal", null, 1.0, 0.5, null, "343.0", "343.0", null), + row("c_decimal_w_params", null, 1.0, 0.5, null, "345.67", "345.67", null), + row("c_timestamp", null, 1.0, 0.5, null, null, null, null), + row("c_date", null, 1.0, 0.5, null, "2015-05-08", "2015-05-08", null), + row("c_string", 10.0, 1.0, 0.5, null, null, null, null), + row("c_varchar", 10.0, 1.0, 0.5, null, null, null, null), + row("c_char", 9.0, 1.0, 0.5, null, null, null, null), + row("c_boolean", null, 1.0, 0.5, null, null, null, null), + row("c_binary", 9.0, null, 0.5, null, null, null, null), + row("p_bigint", null, 1.0, 0.0, null, "1", "1", null), + row("p_varchar", 20.0, 1.0, 0.0, null, null, null, null), + row(null, null, null, null, 2.0, null, null, null))); assertThat(query(showStatsPartitionTwo)).containsOnly(ImmutableList.of( - row("c_tinyint", null, 1.0, 0.5, null, "99", "99"), - row("c_smallint", null, 1.0, 0.5, null, "333", "333"), - row("c_int", null, 1.0, 0.5, null, "444", "444"), - row("c_bigint", null, 1.0, 0.5, null, "555", "555"), - row("c_float", null, 1.0, 0.5, null, "666.34", "666.34"), - row("c_double", null, 1.0, 0.5, null, "777.56", "777.56"), - row("c_decimal", null, 1.0, 0.5, null, "888.0", "888.0"), - row("c_decimal_w_params", null, 1.0, 0.5, null, "999.67", "999.67"), - row("c_timestamp", null, 1.0, 0.5, null, null, null), - row("c_date", null, 1.0, 0.5, null, "2015-05-09", "2015-05-09"), - row("c_string", 10.0, 1.0, 0.5, null, null, null), - row("c_varchar", 10.0, 1.0, 0.5, null, null, null), - row("c_char", 9.0, 1.0, 0.5, null, null, null), - row("c_boolean", null, 1.0, 0.5, null, null, null), - row("c_binary", 9.0, null, 0.5, null, null, null), - row("p_bigint", null, 1.0, 0.0, null, "2", "2"), - 
row("p_varchar", 20.0, 1.0, 0.0, null, null, null), - row(null, null, null, null, 2.0, null, null))); + row("c_tinyint", null, 1.0, 0.5, null, "99", "99", null), + row("c_smallint", null, 1.0, 0.5, null, "333", "333", null), + row("c_int", null, 1.0, 0.5, null, "444", "444", null), + row("c_bigint", null, 1.0, 0.5, null, "555", "555", null), + row("c_float", null, 1.0, 0.5, null, "666.34", "666.34", null), + row("c_double", null, 1.0, 0.5, null, "777.56", "777.56", null), + row("c_decimal", null, 1.0, 0.5, null, "888.0", "888.0", null), + row("c_decimal_w_params", null, 1.0, 0.5, null, "999.67", "999.67", null), + row("c_timestamp", null, 1.0, 0.5, null, null, null, null), + row("c_date", null, 1.0, 0.5, null, "2015-05-09", "2015-05-09", null), + row("c_string", 10.0, 1.0, 0.5, null, null, null, null), + row("c_varchar", 10.0, 1.0, 0.5, null, null, null, null), + row("c_char", 9.0, 1.0, 0.5, null, null, null, null), + row("c_boolean", null, 1.0, 0.5, null, null, null, null), + row("c_binary", 9.0, null, 0.5, null, null, null, null), + row("p_bigint", null, 1.0, 0.0, null, "2", "2", null), + row("p_varchar", 20.0, 1.0, 0.0, null, null, null, null), + row(null, null, null, null, 2.0, null, null, null))); query(format("INSERT INTO %s VALUES( TINYINT '119', SMALLINT '32759', INTEGER '2147483639', BIGINT '9223372036854775799', REAL '122.340', DOUBLE '233.560', CAST(342.0 AS DECIMAL(10, 0)), CAST(344.670 AS DECIMAL(10, 5)), TIMESTAMP '2015-05-10 12:15:29', DATE '2015-05-07', 'p1 varchar', CAST('p1 varchar10' AS VARCHAR(10)), CAST('p1 char10' AS CHAR(10)), true, CAST('p1 binary' as VARBINARY), BIGINT '1', 'partition1')", tableName)); query(format("INSERT INTO %s VALUES( null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, BIGINT '1', 'partition1')", tableName)); assertThat(query(showStatsPartitionOne)).containsOnly(ImmutableList.of( - row("c_tinyint", null, 1.0, 0.5, null, "119", "120"), - row("c_smallint", null, 1.0, 0.5, null, 
"32759", "32760"), - row("c_int", null, 1.0, 0.5, null, "2147483639", "2147483640"), - row("c_bigint", null, 1.0, 0.5, null, "9223372036854775807", "9223372036854775807"), - row("c_float", null, 1.0, 0.5, null, "122.34", "123.34"), - row("c_double", null, 1.0, 0.5, null, "233.56", "234.56"), - row("c_decimal", null, 1.0, 0.5, null, "342.0", "343.0"), - row("c_decimal_w_params", null, 1.0, 0.5, null, "344.67", "345.67"), - row("c_timestamp", null, 1.0, 0.5, null, null, null), - row("c_date", null, 1.0, 0.5, null, "2015-05-07", "2015-05-08"), - row("c_string", 20.0, 1.0, 0.5, null, null, null), - row("c_varchar", 20.0, 1.0, 0.5, null, null, null), - row("c_char", 18.0, 1.0, 0.5, null, null, null), - row("c_boolean", null, 2.0, 0.5, null, null, null), - row("c_binary", 18.0, null, 0.5, null, null, null), - row("p_bigint", null, 1.0, 0.0, null, "1", "1"), - row("p_varchar", 40.0, 1.0, 0.0, null, null, null), - row(null, null, null, null, 4.0, null, null))); + row("c_tinyint", null, 1.0, 0.5, null, "119", "120", null), + row("c_smallint", null, 1.0, 0.5, null, "32759", "32760", null), + row("c_int", null, 1.0, 0.5, null, "2147483639", "2147483640", null), + row("c_bigint", null, 1.0, 0.5, null, "9223372036854775807", "9223372036854775807", null), + row("c_float", null, 1.0, 0.5, null, "122.34", "123.34", null), + row("c_double", null, 1.0, 0.5, null, "233.56", "234.56", null), + row("c_decimal", null, 1.0, 0.5, null, "342.0", "343.0", null), + row("c_decimal_w_params", null, 1.0, 0.5, null, "344.67", "345.67", null), + row("c_timestamp", null, 1.0, 0.5, null, null, null, null), + row("c_date", null, 1.0, 0.5, null, "2015-05-07", "2015-05-08", null), + row("c_string", 20.0, 1.0, 0.5, null, null, null, null), + row("c_varchar", 20.0, 1.0, 0.5, null, null, null, null), + row("c_char", 18.0, 1.0, 0.5, null, null, null, null), + row("c_boolean", null, 2.0, 0.5, null, null, null, null), + row("c_binary", 18.0, null, 0.5, null, null, null, null), + row("p_bigint", null, 1.0, 
0.0, null, "1", "1", null), + row("p_varchar", 40.0, 1.0, 0.0, null, null, null, null), + row(null, null, null, null, 4.0, null, null, null))); query(format("INSERT INTO %s VALUES( TINYINT '100', SMALLINT '334', INTEGER '445', BIGINT '556', REAL '667.340', DOUBLE '778.560', CAST(889.0 AS DECIMAL(10, 0)), CAST(1000.670 AS DECIMAL(10, 5)), TIMESTAMP '2015-05-10 12:45:31', DATE '2015-05-10', CAST('p2 varchar' AS VARCHAR), CAST('p2 varchar10' AS VARCHAR(10)), CAST('p2 char10' AS CHAR(10)), true, CAST('p2 binary' as VARBINARY), BIGINT '2', 'partition2')", tableName)); query(format("INSERT INTO %s VALUES( null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, BIGINT '2', 'partition2')", tableName)); assertThat(query(showStatsPartitionTwo)).containsOnly(ImmutableList.of( - row("c_tinyint", null, 1.0, 0.5, null, "99", "100"), - row("c_smallint", null, 1.0, 0.5, null, "333", "334"), - row("c_int", null, 1.0, 0.5, null, "444", "445"), - row("c_bigint", null, 1.0, 0.5, null, "555", "556"), - row("c_float", null, 1.0, 0.5, null, "666.34", "667.34"), - row("c_double", null, 1.0, 0.5, null, "777.56", "778.56"), - row("c_decimal", null, 1.0, 0.5, null, "888.0", "889.0"), - row("c_decimal_w_params", null, 1.0, 0.5, null, "999.67", "1000.67"), - row("c_timestamp", null, 1.0, 0.5, null, null, null), - row("c_date", null, 1.0, 0.5, null, "2015-05-09", "2015-05-10"), - row("c_string", 20.0, 1.0, 0.5, null, null, null), - row("c_varchar", 20.0, 1.0, 0.5, null, null, null), - row("c_char", 18.0, 1.0, 0.5, null, null, null), - row("c_boolean", null, 1.0, 0.5, null, null, null), - row("c_binary", 18.0, null, 0.5, null, null, null), - row("p_bigint", null, 1.0, 0.0, null, "2", "2"), - row("p_varchar", 40.0, 1.0, 0.0, null, null, null), - row(null, null, null, null, 4.0, null, null))); + row("c_tinyint", null, 1.0, 0.5, null, "99", "100", null), + row("c_smallint", null, 1.0, 0.5, null, "333", "334", null), + row("c_int", null, 1.0, 0.5, null, "444", 
"445", null), + row("c_bigint", null, 1.0, 0.5, null, "555", "556", null), + row("c_float", null, 1.0, 0.5, null, "666.34", "667.34", null), + row("c_double", null, 1.0, 0.5, null, "777.56", "778.56", null), + row("c_decimal", null, 1.0, 0.5, null, "888.0", "889.0", null), + row("c_decimal_w_params", null, 1.0, 0.5, null, "999.67", "1000.67", null), + row("c_timestamp", null, 1.0, 0.5, null, null, null, null), + row("c_date", null, 1.0, 0.5, null, "2015-05-09", "2015-05-10", null), + row("c_string", 20.0, 1.0, 0.5, null, null, null, null), + row("c_varchar", 20.0, 1.0, 0.5, null, null, null, null), + row("c_char", 18.0, 1.0, 0.5, null, null, null, null), + row("c_boolean", null, 1.0, 0.5, null, null, null, null), + row("c_binary", 18.0, null, 0.5, null, null, null, null), + row("p_bigint", null, 1.0, 0.0, null, "2", "2", null), + row("p_varchar", 40.0, 1.0, 0.0, null, null, null, null), + row(null, null, null, null, 4.0, null, null, null))); } finally { query(format("DROP TABLE IF EXISTS %s", tableName)); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java index 793c5acec9606..20b9261e7bc25 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java @@ -21,7 +21,8 @@ public enum ColumnStatisticType NUMBER_OF_DISTINCT_VALUES("approx_distinct"), NUMBER_OF_NON_NULL_VALUES("count"), NUMBER_OF_TRUE_VALUES("count_if"), - TOTAL_SIZE_IN_BYTES("sum_data_size_for_stats"); + TOTAL_SIZE_IN_BYTES("sum_data_size_for_stats"), + HISTOGRAM("tdigest_agg"); private final String functionName; ColumnStatisticType(String functionName) diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatistics.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatistics.java index 
8ae5cfc25ce41..ec4ee420e46bd 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatistics.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatistics.java @@ -23,13 +23,15 @@ public final class ColumnStatistics { - private static final ColumnStatistics EMPTY = new ColumnStatistics(Estimate.unknown(), Estimate.unknown(), Estimate.unknown(), Optional.empty()); + private static final ColumnStatistics EMPTY = new ColumnStatistics(Estimate.unknown(), Estimate.unknown(), Estimate.unknown(), Optional.empty(), Optional.empty()); private final Estimate nullsFraction; private final Estimate distinctValuesCount; private final Estimate dataSize; private final Optional range; + private final Optional histogram; + public static ColumnStatistics empty() { return EMPTY; @@ -39,7 +41,8 @@ public ColumnStatistics( Estimate nullsFraction, Estimate distinctValuesCount, Estimate dataSize, - Optional range) + Optional range, + Optional histogram) { this.nullsFraction = requireNonNull(nullsFraction, "nullsFraction is null"); if (!nullsFraction.isUnknown()) { @@ -56,6 +59,7 @@ public ColumnStatistics( throw new IllegalArgumentException(format("dataSize must be greater than or equal to 0: %s", dataSize.getValue())); } this.range = requireNonNull(range, "range is null"); + this.histogram = requireNonNull(histogram, "histogram is null"); } @JsonProperty @@ -82,6 +86,12 @@ public Optional getRange() return range; } + @JsonProperty + public Optional getHistogram() + { + return histogram; + } + @Override public boolean equals(Object o) { @@ -95,13 +105,14 @@ public boolean equals(Object o) return Objects.equals(nullsFraction, that.nullsFraction) && Objects.equals(distinctValuesCount, that.distinctValuesCount) && Objects.equals(dataSize, that.dataSize) && - Objects.equals(range, that.range); + Objects.equals(range, that.range) && + Objects.equals(histogram, that.histogram); } @Override public int hashCode() { - return 
Objects.hash(nullsFraction, distinctValuesCount, dataSize, range); + return Objects.hash(nullsFraction, distinctValuesCount, dataSize, range, histogram); } @Override @@ -112,6 +123,7 @@ public String toString() ", distinctValuesCount=" + distinctValuesCount + ", dataSize=" + dataSize + ", range=" + range + + ", histogram=" + histogram + '}'; } @@ -124,7 +136,8 @@ public static Builder builder() * If one of the estimates below is unspecified, the default "unknown" estimate value * (represented by floating point NaN) may cause the resulting symbol statistics * to be "unknown" as well. - * @see SymbolStatsEstimate + * + * @see VariableStatsEstimate */ public static final class Builder { @@ -133,6 +146,8 @@ public static final class Builder private Estimate dataSize = Estimate.unknown(); private Optional range = Optional.empty(); + private Optional histogram = Optional.empty(); + public Builder setNullsFraction(Estimate nullsFraction) { this.nullsFraction = requireNonNull(nullsFraction, "nullsFraction is null"); @@ -178,9 +193,40 @@ public Builder setRange(Optional range) return this; } + public Builder setHistogram(Optional histogram) + { + this.histogram = histogram; + return this; + } + + public Builder mergeWith(Builder other) + { + if (nullsFraction.isUnknown()) { + this.nullsFraction = other.nullsFraction; + } + + if (distinctValuesCount.isUnknown()) { + this.distinctValuesCount = other.distinctValuesCount; + } + + if (dataSize.isUnknown()) { + this.dataSize = other.dataSize; + } + + if (!range.isPresent()) { + this.range = other.range; + } + + if (!histogram.isPresent()) { + this.histogram = other.histogram; + } + + return this; + } + public ColumnStatistics build() { - return new ColumnStatistics(nullsFraction, distinctValuesCount, dataSize, range); + return new ColumnStatistics(nullsFraction, distinctValuesCount, dataSize, range, histogram); } } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ConnectorHistogram.java 
b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ConnectorHistogram.java new file mode 100644 index 0000000000000..0febeb7f0d3fa --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ConnectorHistogram.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.statistics; + +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +/** + * This interface contains functions which the Presto optimizer can use to + * answer questions about a particular column's data distribution. These + * functions will be used to return answers to the query optimizer to build + * more realistic cost models for joins and filter predicates. + *
+ * Currently, this interface supports representing histograms of columns whose + * domains map to real values. + *
+ * Null values should not be represented in the underlying histogram implementation. + * When calculating filter statistics, the {@link ColumnStatisticType#NUMBER_OF_NON_NULL_VALUES} + * statistic is used to account for nulls in cost-based calculations. + * + * @see ColumnStatisticType#NUMBER_OF_NON_NULL_VALUES + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.MINIMAL_CLASS, property = "@class") +public interface ConnectorHistogram +{ + /** + * Calculates an estimate for the percentile at which a particular value + * falls in a distribution. + *
+ * Put another way, this function returns the value of F(x) where F(x) + * represents the CDF of this particular distribution. Traditionally, the + * true CDF of a random variable X is represented by F(x) = P(X <= x). This + * function signature allows for a slight modification by using the + * {@code inclusive} parameter to return the value for F(x) = P(X < x) + * should the underlying implementation support it. + * + * @param value the value whose percentile is calculated + * @param inclusive whether this calculation should be inclusive or exclusive of the value (<= or <) + * @return an {@link Estimate} of the percentile + */ + Estimate cumulativeProbability(double value, boolean inclusive); + + /** + * Calculates the value which occurs at a particular percentile in the given + * distribution. + *
+ * Put another way, calculates the inverse CDF. Given F(x) is the CDF of + * a particular distribution, this function computes F^(-1)(x). + * + * @param percentile the percentile. Must be in the range [0.0, 1.0] + * @return the value in the distribution corresponding to the percentile + */ + Estimate inverseCumulativeProbability(double percentile); +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 6162d3cc10a09..c36e843529365 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -55,6 +55,7 @@ import static com.facebook.presto.SystemSessionProperties.FIELD_NAMES_IN_JSON_CAST_ENABLED; import static com.facebook.presto.SystemSessionProperties.GENERATE_DOMAIN_FILTERS; import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT; +import static com.facebook.presto.SystemSessionProperties.ITERATIVE_OPTIMIZER_TIMEOUT; import static com.facebook.presto.SystemSessionProperties.JOIN_PREFILTER_BUILD_SIDE; import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_ENABLED; import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_FUNCTION; @@ -63,6 +64,7 @@ import static com.facebook.presto.SystemSessionProperties.MERGE_AGGREGATIONS_WITH_AND_WITHOUT_FILTER; import static com.facebook.presto.SystemSessionProperties.MERGE_DUPLICATE_AGGREGATIONS; import static com.facebook.presto.SystemSessionProperties.OFFSET_CLAUSE_ENABLED; +import static com.facebook.presto.SystemSessionProperties.OPTIMIZER_USE_HISTOGRAMS; import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_CASE_EXPRESSION_PREDICATE; import static com.facebook.presto.SystemSessionProperties.PREFILTER_FOR_GROUPBY_LIMIT; import static 
com.facebook.presto.SystemSessionProperties.PREFILTER_FOR_GROUPBY_LIMIT_TIMEOUT_MS; @@ -1027,7 +1029,7 @@ public void testDistinctLimitInternal(Session session) assertQuerySucceeds(session, "SELECT DISTINCT custkey FROM orders LIMIT 10000"); assertQuery(session, "" + - "SELECT DISTINCT x FROM (VALUES 1) t(x) JOIN (VALUES 10, 20) u(a) ON t.x < u.a LIMIT 100", + "SELECT DISTINCT x FROM (VALUES 1) t(x) JOIN (VALUES 10, 20) u(a) ON t.x < u.a LIMIT 100", "SELECT 1"); } @@ -1807,9 +1809,9 @@ public void testRowNumberSpecialFilters() public void testMinMaxN() { assertQuery("" + - "SELECT x FROM (" + - "SELECT min(orderkey, 3) t FROM orders" + - ") CROSS JOIN UNNEST(t) AS a(x)", + "SELECT x FROM (" + + "SELECT min(orderkey, 3) t FROM orders" + + ") CROSS JOIN UNNEST(t) AS a(x)", "VALUES 1, 2, 3"); assertQuery( @@ -2360,23 +2362,45 @@ public void testLargeIn() String longValues = range(0, 5000) .mapToObj(Integer::toString) .collect(joining(", ")); - assertQuery("SELECT orderkey FROM orders WHERE orderkey IN (" + longValues + ")"); - assertQuery("SELECT orderkey FROM orders WHERE orderkey NOT IN (" + longValues + ")"); + Session session = Session.builder(getSession()) + .setSystemProperty(ITERATIVE_OPTIMIZER_TIMEOUT, "15000ms") + .build(); + assertQuery(session, "SELECT orderkey FROM orders WHERE orderkey IN (" + longValues + ")"); + assertQuery(session, "SELECT orderkey FROM orders WHERE orderkey NOT IN (" + longValues + ")"); - assertQuery("SELECT orderkey FROM orders WHERE orderkey IN (mod(1000, orderkey), " + longValues + ")"); - assertQuery("SELECT orderkey FROM orders WHERE orderkey NOT IN (mod(1000, orderkey), " + longValues + ")"); + assertQuery(session, "SELECT orderkey FROM orders WHERE orderkey IN (mod(1000, orderkey), " + longValues + ")"); + assertQuery(session, "SELECT orderkey FROM orders WHERE orderkey NOT IN (mod(1000, orderkey), " + longValues + ")"); String varcharValues = range(0, 5000) .mapToObj(i -> "'" + i + "'") .collect(joining(", ")); - 
assertQuery("SELECT orderkey FROM orders WHERE cast(orderkey AS VARCHAR) IN (" + varcharValues + ")"); - assertQuery("SELECT orderkey FROM orders WHERE cast(orderkey AS VARCHAR) NOT IN (" + varcharValues + ")"); + assertQuery(session, "SELECT orderkey FROM orders WHERE cast(orderkey AS VARCHAR) IN (" + varcharValues + ")"); + assertQuery(session, "SELECT orderkey FROM orders WHERE cast(orderkey AS VARCHAR) NOT IN (" + varcharValues + ")"); String arrayValues = range(0, 5000) .mapToObj(i -> format("ARRAY[%s, %s, %s]", i, i + 1, i + 2)) .collect(joining(", ")); - assertQuery("SELECT ARRAY[0, 0, 0] in (ARRAY[0, 0, 0], " + arrayValues + ")", "values true"); - assertQuery("SELECT ARRAY[0, 0, 0] in (" + arrayValues + ")", "values false"); + assertQuery(session, "SELECT ARRAY[0, 0, 0] in (ARRAY[0, 0, 0], " + arrayValues + ")", "values true"); + assertQuery(session, "SELECT ARRAY[0, 0, 0] in (" + arrayValues + ")", "values false"); + } + + @Test + public void testLargeInWithHistograms() + { + String longValues = range(0, 10_000) + .mapToObj(Integer::toString) + .collect(joining(", ")); + String query = "select orderpriority, sum(totalprice) from lineitem join orders on lineitem.orderkey = orders.orderkey where orders.orderkey in (" + longValues + ") group by 1"; + Session session = Session.builder(getSession()) + .setSystemProperty(ITERATIVE_OPTIMIZER_TIMEOUT, "30000ms") + .setSystemProperty(OPTIMIZER_USE_HISTOGRAMS, "true") + .build(); + assertQuerySucceeds(session, query); + session = Session.builder(getSession()) + .setSystemProperty(ITERATIVE_OPTIMIZER_TIMEOUT, "20000ms") + .setSystemProperty(OPTIMIZER_USE_HISTOGRAMS, "false") + .build(); + assertQuerySucceeds(session, query); } @Test diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestLocalQueries.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestLocalQueries.java index a8a6d54296744..730d87e30ce4e 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestLocalQueries.java 
+++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestLocalQueries.java @@ -88,12 +88,12 @@ public void testShowColumnStats() MaterializedResult result = computeActual("SHOW STATS FOR nation"); MaterializedResult expectedStatistics = - resultBuilder(getSession(), VARCHAR, DOUBLE, DOUBLE, DOUBLE, DOUBLE, VARCHAR, VARCHAR) - .row("nationkey", null, 25.0, 0.0, null, "0", "24") - .row("name", 177.0, 25.0, 0.0, null, null, null) - .row("regionkey", null, 5.0, 0.0, null, "0", "4") - .row("comment", 1857.0, 25.0, 0.0, null, null, null) - .row(null, null, null, null, 25.0, null, null) + resultBuilder(getSession(), VARCHAR, DOUBLE, DOUBLE, DOUBLE, DOUBLE, VARCHAR, VARCHAR, VARCHAR) + .row("nationkey", null, 25.0, 0.0, null, "0", "24", null) + .row("name", 177.0, 25.0, 0.0, null, null, null, null) + .row("regionkey", null, 5.0, 0.0, null, "0", "4", null) + .row("comment", 1857.0, 25.0, 0.0, null, null, null, null) + .row(null, null, null, null, 25.0, null, null, null) .build(); assertEquals(result, expectedStatistics);