diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/SampledRelation.java b/core/trino-parser/src/main/java/io/trino/sql/tree/SampledRelation.java index ed25dc777de0..e529480278bd 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/SampledRelation.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/SampledRelation.java @@ -110,4 +110,15 @@ public int hashCode() { return Objects.hash(relation, type, samplePercentage); } + + @Override + public boolean shallowEquals(Node other) + { + if (!sameClass(this, other)) { + return false; + } + + SampledRelation otherRelation = (SampledRelation) other; + return type == otherRelation.type && Objects.equals(samplePercentage, otherRelation.samplePercentage); + } } diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestTablesample.java b/testing/trino-tests/src/test/java/io/trino/tests/TestTablesample.java index d9b89a4a0429..aa52f5b3a7ef 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestTablesample.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestTablesample.java @@ -110,4 +110,32 @@ public void testInvalidRatioType() .hasErrorCode(TYPE_MISMATCH) .hasMessage("line 1:61: Sample percentage should be a numeric expression"); } + + @Test + public void testInSubquery() + { + // zero sample + assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders WHERE orderkey IN (SELECT orderkey FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (0))")) + .matches("VALUES BIGINT '0'"); + + // full sample + assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders WHERE orderkey IN (SELECT orderkey FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (100))")) + .matches("VALUES BIGINT '15000'"); + + // 1% + assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders WHERE orderkey IN (SELECT orderkey FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (1))")) + .satisfies(result -> assertThat((Long) result.getOnlyValue()).isBetween(50L, 450L)); + + // 0.1% + assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders WHERE orderkey IN (SELECT orderkey FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (1e-1))")) + .satisfies(result -> assertThat((Long) result.getOnlyValue()).isBetween(3L, 45L)); + + // 0.1% as decimal + assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders WHERE orderkey IN (SELECT orderkey FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (0.1))")) + .satisfies(result -> assertThat((Long) result.getOnlyValue()).isBetween(3L, 45L)); + + // fraction as long decimal + assertThat(assertions.query("SELECT count(*) FROM tpch.tiny.orders WHERE orderkey IN (SELECT orderkey FROM tpch.tiny.orders TABLESAMPLE BERNOULLI (0.000000000000000000001))")) + .satisfies(result -> assertThat((Long) result.getOnlyValue()).isBetween(0L, 5L)); + } }