diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/SampledRelation.java b/presto-parser/src/main/java/io/prestosql/sql/tree/SampledRelation.java index 3d7e9ae78b7e..738846498dff 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/tree/SampledRelation.java +++ b/presto-parser/src/main/java/io/prestosql/sql/tree/SampledRelation.java @@ -110,4 +110,15 @@ public int hashCode() { return Objects.hash(relation, type, samplePercentage); } + + @Override + public boolean shallowEquals(Node other) + { + if (!sameClass(this, other)) { + return false; + } + + SampledRelation otherRelation = (SampledRelation) other; + return type == otherRelation.type && Objects.equals(samplePercentage, otherRelation.samplePercentage); + } } diff --git a/presto-tests/src/test/java/io/prestosql/tests/TestTablesample.java b/presto-tests/src/test/java/io/prestosql/tests/TestTablesample.java index 76c812ab08c1..7fed9367503f 100644 --- a/presto-tests/src/test/java/io/prestosql/tests/TestTablesample.java +++ b/presto-tests/src/test/java/io/prestosql/tests/TestTablesample.java @@ -110,4 +110,32 @@ public void testInvalidRatioType() .hasErrorCode(TYPE_MISMATCH) .hasMessage("line 1:61: Sample percentage should be a numeric expression"); } + + @Test + public void testInSubquery() + { + // zero sample + assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders WHERE orderkey IN (SELECT orderkey FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (0))")) + .matches("VALUES BIGINT '0'"); + + // full sample + assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders WHERE orderkey IN (SELECT orderkey FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (100))")) + .matches("VALUES BIGINT '15000'"); + + // 1% + assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders WHERE orderkey IN (SELECT orderkey FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (1))")) + .satisfies(result -> assertThat((Long) result.getOnlyValue()).isBetween(50L, 450L)); + + // 0.1% + assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders WHERE orderkey IN (SELECT orderkey FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (1e-1))")) + .satisfies(result -> assertThat((Long) result.getOnlyValue()).isBetween(3L, 45L)); + + // 0.1% as decimal + assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders WHERE orderkey IN (SELECT orderkey FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (0.1))")) + .satisfies(result -> assertThat((Long) result.getOnlyValue()).isBetween(3L, 45L)); + + // fraction as long decimal + assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders WHERE orderkey IN (SELECT orderkey FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (0.000000000000000000001))")) + .satisfies(result -> assertThat((Long) result.getOnlyValue()).isBetween(0L, 5L)); + } }