diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestExtract.java b/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestExtract.java index 27bb54a41e95..e25720198d85 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestExtract.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestExtract.java @@ -159,7 +159,7 @@ protected void testUnsupportedExtract(String extractField) { types().forEach(type -> { String expression = format("EXTRACT(%s FROM CAST(NULL AS %s))", extractField, type); - assertThatThrownBy(() -> assertions.expression(expression), expression) + assertThatThrownBy(() -> assertions.expression(expression).evaluate(), expression) .as(expression) .isInstanceOf(TrinoException.class) .hasMessageMatching(format("line 1:\\d+:\\Q Cannot extract %s from %s", extractField, type)); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestFunctions.java index 09f1a5b5f292..5202dd415146 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/AbstractTestFunctions.java @@ -85,11 +85,19 @@ public final void destroyTestFunctions() functionAssertions = null; } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#function(String, String...)} + */ + @Deprecated protected void assertFunction(@Language("SQL") String projection, Type expectedType, Object expected) { functionAssertions.assertFunction(projection, expectedType, expected); } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#operator(OperatorType, String...)} + */ + @Deprecated protected void assertOperator(OperatorType operator, String value, Type expectedType, Object expected) { functionAssertions.assertFunction(format("\"%s\"(%s)", mangleOperatorName(operator), value), expectedType, expected); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/FunctionAssertions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/FunctionAssertions.java index 35ec196b159a..f2fdcbb55734 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/FunctionAssertions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/FunctionAssertions.java @@ -291,6 +291,10 @@ public void installPlugin(Plugin plugin) runner.installPlugin(plugin); } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#function(String, String...)} + */ + @Deprecated public void assertFunction(String projection, Type expectedType, Object expected) { if (expected instanceof Slice) { @@ -301,17 +305,29 @@ public void assertFunction(String projection, Type expectedType, Object expected assertEquals(actual, expected); } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#function(String, String...)} + */ + @Deprecated public void assertFunctionString(String projection, Type expectedType, String expected) { Object actual = selectSingleValue(projection, expectedType, runner.getExpressionCompiler()); assertEquals(actual.toString(), expected); } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#expression(String)} + */ + @Deprecated public void tryEvaluate(String expression, Type expectedType) { tryEvaluate(expression, expectedType, session); } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#expression(String)} + */ + @Deprecated public void tryEvaluate(String expression, Type expectedType, Session session) { selectUniqueValue(expression, expectedType, session, runner.getExpressionCompiler()); @@ -322,6 +338,10 @@ public void tryEvaluateWithAll(String expression, Type expectedType) tryEvaluateWithAll(expression, expectedType, session); } + /** + * @deprecated Use {@link io.trino.sql.query.QueryAssertions#expression(String)} + */ + @Deprecated public void tryEvaluateWithAll(String expression, Type expectedType, Session session) { executeProjectionWithAll(expression, expectedType, session, runner.getExpressionCompiler()); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayContainsSequence.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayContainsSequence.java index 3e7b0404f83b..6937ddd8a363 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayContainsSequence.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayContainsSequence.java @@ -13,30 +13,72 @@ */ package io.trino.operator.scalar; -import org.testng.annotations.Test; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; -import static io.trino.spi.type.BooleanType.BOOLEAN; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestArrayContainsSequence - extends AbstractTestFunctions { + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + @Test public void testBasic() { - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[1,2])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[3,4])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[5,6])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[1,2,4])", BOOLEAN, false); - assertFunction("contains_sequence(ARRAY [1,2,3,NULL,4,5,6], ARRAY[3,NULL,4])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[1,2,3,4,5,6])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY ['1','2','3'], ARRAY['1','2'])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [1.1,2.2,3.3], ARRAY[1.1,2.2])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [ARRAY[1,2],ARRAY[3],ARRAY[4,5]], ARRAY[ARRAY[1,2],ARRAY[3]])", BOOLEAN, true); - assertFunction("contains_sequence(ARRAY [ARRAY[1,2],ARRAY[3],ARRAY[4,5]], ARRAY[ARRAY[1,2],ARRAY[4]])", BOOLEAN, false); + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[1, 2]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[3, 4]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[5, 6]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[1, 2, 4]")) + .isEqualTo(false); + + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, NULL, 4, 5, 6]", "ARRAY[3, NULL, 4]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[1, 2, 3, 4, 5, 6]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY['1', '2', '3']", "ARRAY['1', '2']")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[1.1, 2.2, 3.3]", "ARRAY[1.1, 2.2]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[ARRAY[1,2], ARRAY[3], ARRAY[4,5]]", "ARRAY[ARRAY[1,2], ARRAY[3]]")) + .isEqualTo(true); + + assertThat(assertions.function("contains_sequence", "ARRAY[ARRAY[1,2], ARRAY[3], ARRAY[4,5]]", "ARRAY[ARRAY[1,2], ARRAY[4]]")) + .isEqualTo(false); for (int i = 1; i <= 6; i++) { - assertFunction("contains_sequence(ARRAY [1,2,3,4,5,6], ARRAY[" + i + "])", BOOLEAN, true); + assertThat(assertions.function("contains_sequence", "ARRAY[1, 2, 3, 4, 5, 6]", "ARRAY[%d]".formatted(i))) + .isEqualTo(true); } } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayExceptFunction.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayExceptFunction.java index 011500c6e539..d90fc4bff773 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayExceptFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayExceptFunction.java @@ -13,64 +13,112 @@ */ package io.trino.operator.scalar; -import com.google.common.collect.ImmutableList; import io.trino.spi.type.ArrayType; -import org.testng.annotations.Test; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; -import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.type.UnknownType.UNKNOWN; -import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestArrayExceptFunction - extends AbstractTestFunctions { + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + @Test public void testBasic() { - assertFunction("array_except(ARRAY[1, 5, 3], ARRAY[3])", new ArrayType(INTEGER), ImmutableList.of(1, 5)); - assertFunction("array_except(ARRAY[CAST(1 as BIGINT), 5, 3], ARRAY[5])", new ArrayType(BIGINT), ImmutableList.of(1L, 3L)); - assertFunction("array_except(ARRAY[VARCHAR 'x', 'y', 'z'], ARRAY['x'])", new ArrayType(VARCHAR), ImmutableList.of("y", "z")); - assertFunction("array_except(ARRAY[true, false, null], ARRAY[true])", new ArrayType(BOOLEAN), asList(false, null)); - assertFunction("array_except(ARRAY[1.1E0, 5.4E0, 3.9E0], ARRAY[5, 5.4E0])", new ArrayType(DOUBLE), ImmutableList.of(1.1, 3.9)); + assertThat(assertions.function("array_except", "ARRAY[1, 5, 3]", "ARRAY[3]")) + .matches("ARRAY[1, 5]"); + + assertThat(assertions.function("array_except", "ARRAY[BIGINT '1', 5, 3]", "ARRAY[5]")) + .matches("ARRAY[BIGINT '1', BIGINT '3']"); + + assertThat(assertions.function("array_except", "ARRAY[VARCHAR 'x', 'y', 'z']", "ARRAY['x']")) + .matches("ARRAY[VARCHAR 'y', VARCHAR 'z']"); + + assertThat(assertions.function("array_except", "ARRAY[true, false, null]", "ARRAY[true]")) + .matches("ARRAY[false, null]"); + + assertThat(assertions.function("array_except", "ARRAY[1.1E0, 5.4E0, 3.9E0]", "ARRAY[5, 5.4E0]")) + .matches("ARRAY[1.1E0, 3.9E0]"); } @Test public void testEmpty() { - assertFunction("array_except(ARRAY[], ARRAY[])", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("array_except(ARRAY[], ARRAY[1, 3])", new ArrayType(INTEGER), ImmutableList.of()); - assertFunction("array_except(ARRAY[VARCHAR 'abc'], ARRAY[])", new ArrayType(VARCHAR), ImmutableList.of("abc")); + assertThat(assertions.function("array_except", "ARRAY[]", "ARRAY[]")) + .matches("ARRAY[]"); + + assertThat(assertions.function("array_except", "ARRAY[]", "ARRAY[1, 3]")) + .matches("CAST(ARRAY[] AS array(integer))"); + + assertThat(assertions.function("array_except", "ARRAY[VARCHAR 'abc']", "ARRAY[]")) + .matches("ARRAY[VARCHAR 'abc']"); } @Test public void testNull() { - assertFunction("array_except(ARRAY[NULL], NULL)", new ArrayType(UNKNOWN), null); - assertFunction("array_except(NULL, NULL)", new ArrayType(UNKNOWN), null); - assertFunction("array_except(NULL, ARRAY[NULL])", new ArrayType(UNKNOWN), null); - assertFunction("array_except(ARRAY[NULL], ARRAY[NULL])", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("array_except(ARRAY[], ARRAY[NULL])", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("array_except(ARRAY[NULL], ARRAY[])", new ArrayType(UNKNOWN), singletonList(null)); + assertThat(assertions.function("array_except", "ARRAY[NULL]", "NULL")) + .isNull(new ArrayType(UNKNOWN)); + + assertThat(assertions.function("array_except", "NULL", "NULL")) + .isNull(new ArrayType(UNKNOWN)); + + assertThat(assertions.function("array_except", "NULL", "ARRAY[NULL]")) + .isNull(new ArrayType(UNKNOWN)); + + assertThat(assertions.function("array_except", "ARRAY[NULL]", "ARRAY[NULL]")) + .matches("ARRAY[]"); + + assertThat(assertions.function("array_except", "ARRAY[]", "ARRAY[NULL]")) + .matches("ARRAY[]"); + + assertThat(assertions.function("array_except", "ARRAY[NULL]", "ARRAY[]")) + .matches("ARRAY[NULL]"); } @Test public void testDuplicates() { - assertFunction("array_except(ARRAY[1, 5, 3, 5, 1], ARRAY[3])", new ArrayType(INTEGER), ImmutableList.of(1, 5)); - assertFunction("array_except(ARRAY[CAST(1 as BIGINT), 5, 5, 3, 3, 3, 1], ARRAY[3, 5])", new ArrayType(BIGINT), ImmutableList.of(1L)); - assertFunction("array_except(ARRAY[VARCHAR 'x', 'x', 'y', 'z'], ARRAY['x', 'y', 'x'])", new ArrayType(VARCHAR), ImmutableList.of("z")); - assertFunction("array_except(ARRAY[true, false, null, true, false, null], ARRAY[true, true, true])", new ArrayType(BOOLEAN), asList(false, null)); + assertThat(assertions.function("array_except", "ARRAY[1, 5, 3, 5, 1]", "ARRAY[3]")) + .matches("ARRAY[1, 5]"); + + assertThat(assertions.function("array_except", "ARRAY[BIGINT '1', 5, 5, 3, 3, 3, 1]", "ARRAY[3, 5]")) + .matches("ARRAY[BIGINT '1']"); + + assertThat(assertions.function("array_except", "ARRAY[VARCHAR 'x', 'x', 'y', 'z']", "ARRAY['x', 'y', 'x']")) + .matches("ARRAY[VARCHAR 'z']"); + + assertThat(assertions.function("array_except", "ARRAY[true, false, null, true, false, null]", "ARRAY[true, true, true]")) + .matches("ARRAY[false, null]"); } @Test public void testNonDistinctNonEqualValues() { - assertFunction("array_except(ARRAY[NaN()], ARRAY[NaN()])", new ArrayType(DOUBLE), ImmutableList.of()); - assertFunction("array_except(ARRAY[1, NaN(), 3], ARRAY[NaN(), 3])", new ArrayType(DOUBLE), ImmutableList.of(1.0)); + assertThat(assertions.function("array_except", "ARRAY[NaN()]", "ARRAY[NaN()]")) + .matches("CAST(ARRAY[] AS array(double))"); + + assertThat(assertions.function("array_except", "ARRAY[1, NaN(), 3]", "ARRAY[NaN(), 3]")) + .matches("ARRAY[1E0]"); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFilterFunction.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFilterFunction.java index 3642546486e7..bd20cfb31c42 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFilterFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFilterFunction.java @@ -13,73 +13,142 @@ */ package io.trino.operator.scalar; -import com.google.common.collect.ImmutableList; -import io.trino.spi.type.ArrayType; -import org.testng.annotations.Test; - -import static io.trino.spi.type.BooleanType.BOOLEAN; -import static io.trino.spi.type.DoubleType.DOUBLE; -import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.spi.type.TimestampType.createTimestampType; -import static io.trino.spi.type.VarcharType.createVarcharType; -import static io.trino.type.UnknownType.UNKNOWN; -import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; + +@TestInstance(PER_CLASS) public class TestArrayFilterFunction - extends AbstractTestFunctions { + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + @Test public void testBasic() { - assertFunction("filter(ARRAY [5, 6], x -> x = 5)", new ArrayType(INTEGER), ImmutableList.of(5)); - assertFunction("filter(ARRAY [5 + RANDOM(1), 6 + RANDOM(1)], x -> x = 5)", new ArrayType(INTEGER), ImmutableList.of(5)); - assertFunction("filter(ARRAY [true, false, true, false], x -> nullif(x, false))", new ArrayType(BOOLEAN), ImmutableList.of(true, true)); - assertFunction("filter(ARRAY [true, false, null, true, false, null], x -> not x)", new ArrayType(BOOLEAN), ImmutableList.of(false, false)); - assertFunction( - "filter(ARRAY [TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789'], t -> year(t) = 1111)", - new ArrayType(createTimestampType(9)), - ImmutableList.of(timestamp(9, "1111-05-10 12:34:56.123456789"))); + assertThat(assertions.expression("filter(a, x -> x = 5)") + .binding("a", "ARRAY[5, 6]")) + .matches("ARRAY[5]"); + + assertThat(assertions.expression("filter(a, x -> x = 5)") + .binding("a", "ARRAY[5 + random(1), 6 + random(1)]")) + .matches("ARRAY[5]"); + + assertThat(assertions.expression("filter(a, x -> nullif(x, false))") + .binding("a", "ARRAY[true, false, true, false]")) + .matches("ARRAY[true, true]"); + + assertThat(assertions.expression("filter(a, x -> not x)") + .binding("a", "ARRAY[true, false, null, true, false, null]")) + .matches("ARRAY[false, false]"); + + assertThat(assertions.expression("filter(a, t -> year(t) = 1111)") + .binding("a", "ARRAY[TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789']")) + .matches("ARRAY[TIMESTAMP '1111-05-10 12:34:56.123456789']"); } @Test public void testEmpty() { - assertFunction("filter(ARRAY [], x -> true)", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("filter(ARRAY [], x -> false)", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("filter(ARRAY [], x -> CAST (null AS BOOLEAN))", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("filter(CAST (ARRAY [] AS ARRAY(INTEGER)), x -> true)", new ArrayType(INTEGER), ImmutableList.of()); + assertThat(assertions.expression("filter(a, x -> true)") + .binding("a", "ARRAY[]")) + .matches("ARRAY[]"); + + assertThat(assertions.expression("filter(a, x -> false)") + .binding("a", "ARRAY[]")) + .matches("ARRAY[]"); + + assertThat(assertions.expression("filter(a, x -> CAST(null AS boolean))") + .binding("a", "ARRAY[]")) + .matches("ARRAY[]"); + + assertThat(assertions.expression("filter(a, x -> true)") + .binding("a", "CAST(ARRAY[] AS array(integer))")) + .matches("CAST(ARRAY[] AS array(integer))"); } @Test public void testNull() { - assertFunction("filter(ARRAY [NULL], x -> x IS NULL)", new ArrayType(UNKNOWN), singletonList(null)); - assertFunction("filter(ARRAY [NULL], x -> x IS NOT NULL)", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("filter(ARRAY [CAST (NULL AS INTEGER)], x -> x IS NULL)", new ArrayType(INTEGER), singletonList(null)); - assertFunction("filter(ARRAY [NULL, NULL, NULL], x -> x IS NULL)", new ArrayType(UNKNOWN), asList(null, null, null)); - assertFunction("filter(ARRAY [NULL, NULL, NULL], x -> x IS NOT NULL)", new ArrayType(UNKNOWN), ImmutableList.of()); - - assertFunction("filter(ARRAY [25, 26, NULL], x -> x % 2 = 1 OR x IS NULL)", new ArrayType(INTEGER), asList(25, null)); - assertFunction("filter(ARRAY [25.6E0, 37.3E0, NULL], x -> x < 30.0E0 OR x IS NULL)", new ArrayType(DOUBLE), asList(25.6, null)); - assertFunction("filter(ARRAY [true, false, NULL], x -> not x OR x IS NULL)", new ArrayType(BOOLEAN), asList(false, null)); - assertFunction("filter(ARRAY ['abc', 'def', NULL], x -> substr(x, 1, 1) = 'a' OR x IS NULL)", new ArrayType(createVarcharType(3)), asList("abc", null)); - assertFunction( - "filter(ARRAY [ARRAY ['abc', null, '123'], NULL], x -> x[2] IS NULL OR x IS NULL)", - new ArrayType(new ArrayType(createVarcharType(3))), - asList(asList("abc", null, "123"), null)); + assertThat(assertions.expression("filter(a, x -> x IS NULL)") + .binding("a", "ARRAY[NULL]")) + .matches("ARRAY[NULL]"); + + assertThat(assertions.expression("filter(a, x -> x IS NOT NULL)") + .binding("a", "ARRAY[NULL]")) + .matches("ARRAY[]"); + + assertThat(assertions.expression("filter(a, x -> x IS NULL)") + .binding("a", "ARRAY[CAST(NULL AS integer)]")) + .matches("CAST(ARRAY[NULL] AS array(integer))"); + + assertThat(assertions.expression("filter(a, x -> x IS NULL)") + .binding("a", "ARRAY[NULL, NULL, NULL]")) + .matches("ARRAY[NULL, NULL, NULL]"); + + assertThat(assertions.expression("filter(a, x -> x IS NOT NULL)") + .binding("a", "ARRAY[NULL, NULL, NULL]")) + .matches("ARRAY[]"); + + assertThat(assertions.expression("filter(a, x -> x % 2 = 1 OR x IS NULL)") + .binding("a", "ARRAY[25, 26, NULL]")) + .matches("ARRAY[25, NULL]"); + + assertThat(assertions.expression("filter(a, x -> x < 30.0E0 OR x IS NULL)") + .binding("a", "ARRAY[25.6E0, 37.3E0, NULL]")) + .matches("ARRAY[25.6E0, NULL]"); + + assertThat(assertions.expression("filter(a, x -> NOT x OR x IS NULL)") + .binding("a", "ARRAY[true, false, NULL]")) + .matches("ARRAY[false, NULL]"); + + assertThat(assertions.expression("filter(a, x -> substr(x, 1, 1) = 'a' OR x IS NULL)") + .binding("a", "ARRAY['abc', 'def', NULL]")) + .matches("ARRAY['abc', NULL]"); + + assertThat(assertions.expression("filter(a, x -> x[2] IS NULL OR x IS NULL)") + .binding("a", "ARRAY[ARRAY['abc', NULL, '123']]")) + .matches("ARRAY[ARRAY['abc', NULL, '123']]"); } @Test public void testTypeCombinations() { - assertFunction("filter(ARRAY [25, 26, 27], x -> x % 2 = 1)", new ArrayType(INTEGER), ImmutableList.of(25, 27)); - assertFunction("filter(ARRAY [25.6E0, 37.3E0, 28.6E0], x -> x < 30.0E0)", new ArrayType(DOUBLE), ImmutableList.of(25.6, 28.6)); - assertFunction("filter(ARRAY [true, false, true], x -> not x)", new ArrayType(BOOLEAN), ImmutableList.of(false)); - assertFunction("filter(ARRAY ['abc', 'def', 'ayz'], x -> substr(x, 1, 1) = 'a')", new ArrayType(createVarcharType(3)), ImmutableList.of("abc", "ayz")); - assertFunction( - "filter(ARRAY [ARRAY ['abc', null, '123'], ARRAY ['def', 'x', '456']], x -> x[2] IS NULL)", - new ArrayType(new ArrayType(createVarcharType(3))), - ImmutableList.of(asList("abc", null, "123"))); + assertThat(assertions.expression("filter(a, x -> x % 2 = 1)") + .binding("a", "ARRAY[25, 26, 27]")) + .matches("ARRAY[25, 27]"); + + assertThat(assertions.expression("filter(a, x -> x < 30.0E0)") + .binding("a", "ARRAY[25.6E0, 37.3E0, 28.6E0]")) + .matches("ARRAY[25.6E0, 28.6E0]"); + + assertThat(assertions.expression("filter(a, x -> NOT x)") + .binding("a", "ARRAY[true, false, true]")) + .matches("ARRAY[false]"); + + assertThat(assertions.expression("filter(a, x -> substr(x, 1, 1) = 'a' OR x IS NULL)") + .binding("a", "ARRAY['abc', 'def', 'ayz']")) + .matches("ARRAY['abc', 'ayz']"); + + assertThat(assertions.expression("filter(a, x -> x[2] IS NULL)") + .binding("a", "ARRAY[ARRAY['abc', NULL, '123'], ARRAY ['def', 'x', '456']]")) + .matches("ARRAY[ARRAY['abc', NULL, '123']]"); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFunctions.java index b10300100cd4..14cb70d41bda 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayFunctions.java @@ -14,34 +14,68 @@ package io.trino.operator.scalar; import com.google.common.base.Joiner; +import io.trino.spi.TrinoException; import io.trino.spi.type.ArrayType; -import org.testng.annotations.Test; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; -import static io.trino.spi.StandardErrorCode.TOO_MANY_ARGUMENTS; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; import static java.util.Collections.nCopies; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestArrayFunctions - extends AbstractTestFunctions { + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + @Test public void testArrayConstructor() { - tryEvaluateWithAll("array[" + Joiner.on(", ").join(nCopies(254, "rand()")) + "]", new ArrayType(DOUBLE)); - assertInvalidFunction( - "array[" + Joiner.on(", ").join(nCopies(255, "rand()")) + "]", - TOO_MANY_ARGUMENTS, - "Too many arguments for array constructor"); + assertThat(assertions.expression("array[" + Joiner.on(", ").join(nCopies(254, "rand()")) + "]")) + .hasType(new ArrayType(DOUBLE)); + + assertThat(assertions.expression("array[a, b, c]") + .binding("a", "1") + .binding("b", "2") + .binding("c", "3")) + .matches("ARRAY[1, 2, 3]"); + + assertThatThrownBy(() -> assertions.expression("array[" + Joiner.on(", ").join(nCopies(255, "rand()")) + "]").evaluate()) + .isInstanceOf(TrinoException.class) + .hasMessage("Too many arguments for array constructor"); } @Test public void testArrayConcat() { - assertFunction("CONCAT(" + Joiner.on(", ").join(nCopies(127, "array[1]")) + ")", new ArrayType(INTEGER), nCopies(127, 1)); - assertInvalidFunction( - "CONCAT(" + Joiner.on(", ").join(nCopies(128, "array[1]")) + ")", - TOO_MANY_ARGUMENTS, - "line 1:1: Too many arguments for function call concat()"); + assertThat(assertions.expression("CONCAT(" + Joiner.on(", ").join(nCopies(127, "array[1]")) + ")")) + .hasType(new ArrayType(INTEGER)) + .matches("ARRAY[%s]".formatted(Joiner.on(",").join(nCopies(127, 1)))); + + assertThat(assertions.function("concat", "ARRAY[1]", "ARRAY[2]", "ARRAY[3]")) + .matches("ARRAY[1, 2, 3]"); + + assertThatThrownBy(() -> assertions.expression("CONCAT(" + Joiner.on(", ").join(nCopies(128, "array[1]")) + ")").evaluate()) + .isInstanceOf(TrinoException.class) + .hasMessage("line 1:8: Too many arguments for function call concat()"); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayMatchFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayMatchFunctions.java index e01f4c37b86e..e21ee971a6b0 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayMatchFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayMatchFunctions.java @@ -13,54 +13,163 @@ */ package io.trino.operator.scalar; -import io.trino.spi.type.BooleanType; -import org.testng.annotations.Test; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; + +@TestInstance(PER_CLASS) public class TestArrayMatchFunctions - extends AbstractTestFunctions { + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } + @Test public void testAllMatch() { - assertFunction("all_match(ARRAY [5, 7, 9], x -> x % 2 = 1)", BooleanType.BOOLEAN, true); - assertFunction("all_match(ARRAY [true, false, true], x -> x)", BooleanType.BOOLEAN, false); - assertFunction("all_match(ARRAY ['abc', 'ade', 'afg'], x -> substr(x, 1, 1) = 'a')", BooleanType.BOOLEAN, true); - assertFunction("all_match(ARRAY [], x -> true)", BooleanType.BOOLEAN, true); - assertFunction("all_match(ARRAY [true, true, NULL], x -> x)", BooleanType.BOOLEAN, null); - assertFunction("all_match(ARRAY [true, false, NULL], x -> x)", BooleanType.BOOLEAN, false); - assertFunction("all_match(ARRAY [NULL, NULL, NULL], x -> x > 1)", BooleanType.BOOLEAN, null); - assertFunction("all_match(ARRAY [NULL, NULL, NULL], x -> x IS NULL)", BooleanType.BOOLEAN, true); - assertFunction("all_match(ARRAY [MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])], x -> cardinality(x) > 1)", BooleanType.BOOLEAN, true); - assertFunction("all_match(ARRAY [TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789'], t -> month(t) = 5)", BooleanType.BOOLEAN, true); + assertThat(assertions.expression("all_match(a, x -> x % 2 = 1)") + .binding("a", "ARRAY[5, 7, 9]")) + .isEqualTo(true); + + assertThat(assertions.expression("all_match(a, x -> x)") + .binding("a", "ARRAY[true, false, true]")) + .isEqualTo(false); + + assertThat(assertions.expression("all_match(a, x -> substr(x, 1, 1) = 'a')") + .binding("a", "ARRAY['abc', 'ade', 'afg']")) + .isEqualTo(true); + + assertThat(assertions.expression("all_match(a, x -> true)") + .binding("a", "ARRAY[]")) + .isEqualTo(true); + + assertThat(assertions.expression("all_match(a, x -> x)") + .binding("a", "ARRAY[true, true, NULL]")) + .matches("CAST(NULL AS boolean)"); + + assertThat(assertions.expression("all_match(a, x -> x)") + .binding("a", "ARRAY[true, false, NULL]")) + .isEqualTo(false); + + assertThat(assertions.expression("all_match(a, x -> x > 1)") + .binding("a", "ARRAY[NULL, NULL, NULL]")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("all_match(a, x -> x IS NULL)") + .binding("a", "ARRAY[NULL, NULL, NULL]")) + .isEqualTo(true); + + assertThat(assertions.expression("all_match(a, x -> cardinality(x) > 1)") + .binding("a", "ARRAY[MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])]")) + .isEqualTo(true); + + assertThat(assertions.expression("all_match(a, t -> month(t) = 5)") + .binding("a", "ARRAY[TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789']")) + .isEqualTo(true); } @Test public void testAnyMatch() { - assertFunction("any_match(ARRAY [5, 8, 10], x -> x % 2 = 1)", BooleanType.BOOLEAN, true); - assertFunction("any_match(ARRAY [false, false, false], x -> x)", BooleanType.BOOLEAN, false); - assertFunction("any_match(ARRAY ['abc', 'def', 'ghi'], x -> substr(x, 1, 1) = 'a')", BooleanType.BOOLEAN, true); - assertFunction("any_match(ARRAY [], x -> true)", BooleanType.BOOLEAN, false); - assertFunction("any_match(ARRAY [false, false, NULL], x -> x)", BooleanType.BOOLEAN, null); - assertFunction("any_match(ARRAY [true, false, NULL], x -> x)", BooleanType.BOOLEAN, true); - assertFunction("any_match(ARRAY [NULL, NULL, NULL], x -> x > 1)", BooleanType.BOOLEAN, null); - assertFunction("any_match(ARRAY [true, false, NULL], x -> x IS NULL)", BooleanType.BOOLEAN, true); - assertFunction("any_match(ARRAY [MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])], x -> cardinality(x) > 4)", BooleanType.BOOLEAN, false); - assertFunction("any_match(ARRAY [TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789'], t -> year(t) = 2020)", BooleanType.BOOLEAN, true); + assertThat(assertions.expression("any_match(a, x -> x % 2 = 1)") + .binding("a", "ARRAY[5, 8, 10]")) + .isEqualTo(true); + + assertThat(assertions.expression("any_match(a, x -> x)") + .binding("a", "ARRAY[false, false, false]")) + .isEqualTo(false); + + assertThat(assertions.expression("any_match(a, x -> substr(x, 1, 1) = 'a')") + .binding("a", "ARRAY['abc', 'def', 'ghi']")) + .isEqualTo(true); + + assertThat(assertions.expression("any_match(a, x -> true)") + .binding("a", "ARRAY[]")) + .isEqualTo(false); + + assertThat(assertions.expression("any_match(a, x -> x)") + .binding("a", "ARRAY[false, false, NULL]")) + .matches("CAST(NULL AS boolean)"); + + assertThat(assertions.expression("any_match(a, x -> x)") + .binding("a", "ARRAY[true, false, NULL]")) + .isEqualTo(true); + + assertThat(assertions.expression("any_match(a, x -> x > 1)") + .binding("a", "ARRAY[NULL, NULL, NULL]")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("any_match(a, x -> x IS NULL)") + .binding("a", "ARRAY[true, false, NULL]")) + .isEqualTo(true); + + assertThat(assertions.expression("any_match(a, x -> cardinality(x) > 4)") + .binding("a", "ARRAY[MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])]")) + .isEqualTo(false); + + assertThat(assertions.expression("any_match(a, t -> year(t) = 2020)") + .binding("a", "ARRAY[TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789']")) + .isEqualTo(true); } @Test public void testNoneMatch() { - assertFunction("none_match(ARRAY [5, 8, 10], x -> x % 2 = 1)", BooleanType.BOOLEAN, false); - assertFunction("none_match(ARRAY [false, false, false], x -> x)", BooleanType.BOOLEAN, true); - assertFunction("none_match(ARRAY ['abc', 'def', 'ghi'], x -> substr(x, 1, 1) = 'a')", BooleanType.BOOLEAN, false); - assertFunction("none_match(ARRAY [], x -> true)", BooleanType.BOOLEAN, true); - assertFunction("none_match(ARRAY [false, false, NULL], x -> x)", BooleanType.BOOLEAN, null); - assertFunction("none_match(ARRAY [true, false, NULL], x -> x)", BooleanType.BOOLEAN, false); - assertFunction("none_match(ARRAY [NULL, NULL, NULL], x -> x > 1)", BooleanType.BOOLEAN, null); - assertFunction("none_match(ARRAY [true, false, NULL], x -> x IS NULL)", BooleanType.BOOLEAN, false); - assertFunction("none_match(ARRAY [MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])], x -> cardinality(x) > 4)", BooleanType.BOOLEAN, true); - assertFunction("none_match(ARRAY [TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789'], t -> month(t) = 10)", BooleanType.BOOLEAN, true); + assertThat(assertions.expression("none_match(a, x -> x % 2 = 1)") + .binding("a", "ARRAY[5, 8, 10]")) + .isEqualTo(false); + + assertThat(assertions.expression("none_match(a, x -> x)") + .binding("a", "ARRAY[false, false, false]")) + .isEqualTo(true); + + assertThat(assertions.expression("none_match(a, x -> substr(x, 1, 1) = 'a')") + .binding("a", "ARRAY['abc', 'def', 'ghi']")) + .isEqualTo(false); + + assertThat(assertions.expression("none_match(a, x -> true)") + .binding("a", "ARRAY[]")) + .isEqualTo(true); + + assertThat(assertions.expression("none_match(a, x -> x)") + .binding("a", "ARRAY[false, false, NULL]")) + .matches("CAST(NULL AS boolean)"); + + assertThat(assertions.expression("none_match(a, x -> x)") + .binding("a", "ARRAY[true, false, NULL]")) + .isEqualTo(false); + + assertThat(assertions.expression("none_match(a, x -> x > 1)") + .binding("a", "ARRAY[NULL, NULL, NULL]")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("none_match(a, x -> x IS NULL)") + .binding("a", "ARRAY[true, false, NULL]")) + .isEqualTo(false); + + assertThat(assertions.expression("none_match(a, x -> cardinality(x) > 4)") + .binding("a", "ARRAY[MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])]")) + .isEqualTo(true); + + assertThat(assertions.expression("none_match(a, t -> month(t) = 10)") + .binding("a", "ARRAY[TIMESTAMP '2020-05-10 12:34:56.123456789', TIMESTAMP '1111-05-10 12:34:56.123456789']")) + .isEqualTo(true); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestColorFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestColorFunctions.java index b64ed5621650..a125133cd2cd 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestColorFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestColorFunctions.java @@ -13,6 +13,7 @@ */ package io.trino.operator.scalar; +import io.trino.sql.query.QueryAssertions; import org.testng.annotations.Test; import static io.airlift.slice.Slices.utf8Slice; @@ -25,12 +26,11 @@ import static io.trino.operator.scalar.ColorFunctions.render; import static io.trino.operator.scalar.ColorFunctions.rgb; import static io.trino.spi.function.OperatorType.INDETERMINATE; -import static io.trino.spi.type.BooleanType.BOOLEAN; import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; public class TestColorFunctions - extends AbstractTestFunctions { @Test public void testParseRgb() @@ -156,7 +156,12 @@ public void testInterpolate() @Test public void testIndeterminate() { - assertOperator(INDETERMINATE, "color(null)", BOOLEAN, true); - assertOperator(INDETERMINATE, "color('black')", BOOLEAN, false); + try (QueryAssertions assertions = new QueryAssertions()) { + assertThat(assertions.operator(INDETERMINATE, "color(null)")) + .isEqualTo(true); + + assertThat(assertions.operator(INDETERMINATE, "color('black')")) + .isEqualTo(false); + } } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestConditions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestConditions.java index 92c8cdf8527d..718a17863bba 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestConditions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestConditions.java @@ -13,145 +13,519 @@ */ package io.trino.operator.scalar; -import org.testng.annotations.Test; +import io.trino.spi.TrinoException; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; -import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DecimalType.createDecimalType; -import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; -import static io.trino.spi.type.SqlDecimal.decimal; -import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.spi.type.VarcharType.createVarcharType; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestConditions - extends AbstractTestFunctions { - @Test - public void testLike() - { - assertFunction("'_monkey_' like 'X_monkeyX_' escape 'X'", BOOLEAN, true); - - assertFunction("'monkey' like 'monkey'", BOOLEAN, true); - assertFunction("'monkey' like 'mon%'", BOOLEAN, true); - assertFunction("'monkey' like 'mon_ey'", BOOLEAN, true); - assertFunction("'monkey' like 'm____y'", BOOLEAN, true); - - assertFunction("'monkey' like 'dain'", BOOLEAN, false); - assertFunction("'monkey' like 'key'", BOOLEAN, false); - - assertFunction("'_monkey_' like '\\_monkey\\_'", BOOLEAN, false); - assertFunction("'_monkey_' like 'X_monkeyX_' escape 'X'", BOOLEAN, true); - - assertFunction("null like 'monkey'", BOOLEAN, null); - assertFunction("'monkey' like null", BOOLEAN, null); - assertFunction("'monkey' like 'monkey' escape null", BOOLEAN, null); + private QueryAssertions assertions; - assertFunction("'_monkey_' not like 'X_monkeyX_' escape 'X'", BOOLEAN, false); - - assertFunction("'monkey' not like 'monkey'", BOOLEAN, false); - assertFunction("'monkey' not like 'mon%'", BOOLEAN, false); - assertFunction("'monkey' not like 'mon_ey'", BOOLEAN, false); - assertFunction("'monkey' not like 'm____y'", BOOLEAN, false); - - assertFunction("'monkey' not like 'dain'", BOOLEAN, true); - assertFunction("'monkey' not like 'key'", BOOLEAN, true); - - assertFunction("'_monkey_' not like '\\_monkey\\_'", BOOLEAN, true); - assertFunction("'_monkey_' not like 'X_monkeyX_' escape 'X'", BOOLEAN, false); + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } - assertFunction("null not like 'monkey'", BOOLEAN, null); - assertFunction("'monkey' not like null", BOOLEAN, null); - assertFunction("'monkey' not like 'monkey' escape null", BOOLEAN, null); + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } - assertInvalidFunction("'monkey' like 'monkey' escape 'foo'", "Escape string must be a single character"); + @Test + public void testLike() + { + // like + assertThat(assertions.expression("a like 'X_monkeyX_' escape 'X'") + .binding("a", "'_monkey_'")) + .isEqualTo(true); + + assertThat(assertions.expression("a like 'monkey'") + .binding("a", "'monkey'")) + .isEqualTo(true); + + assertThat(assertions.expression("a like 'mon%'") + .binding("a", "'monkey'")) + .isEqualTo(true); + + assertThat(assertions.expression("a like '%key'") + .binding("a", "'monkey'")) + .isEqualTo(true); + + assertThat(assertions.expression("a like 'm____y'") + .binding("a", "'monkey'")) + .isEqualTo(true); + + assertThat(assertions.expression("a like 'lion'") + .binding("a", "'monkey'")) + .isEqualTo(false); + + assertThat(assertions.expression("a like 'monkey'") + .binding("a", "null")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("a like null") + .binding("a", "'monkey'")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("a like 'monkey' escape null") + .binding("a", "'monkey'")) + .isNull(BOOLEAN); + + // not like + assertThat(assertions.expression("a not like 'X_monkeyX_' escape 'X'") + .binding("a", "'_monkey_'")) + .isEqualTo(false); + + assertThat(assertions.expression("a not like 'monkey'") + .binding("a", "'monkey'")) + .isEqualTo(false); + + assertThat(assertions.expression("a not like 'mon%'") + .binding("a", "'monkey'")) + .isEqualTo(false); + + assertThat(assertions.expression("a not like '%key'") + .binding("a", "'monkey'")) + .isEqualTo(false); + + assertThat(assertions.expression("a not like 'm____y'") + .binding("a", "'monkey'")) + .isEqualTo(false); + + assertThat(assertions.expression("a not like 'lion'") + .binding("a", "'monkey'")) + .isEqualTo(true); + + assertThat(assertions.expression("a not like 'monkey'") + .binding("a", "null")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("a not like null") + .binding("a", "'monkey'")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("a not like 'monkey' escape null") + .binding("a", "'monkey'")) + .isNull(BOOLEAN); + + assertThatThrownBy(() -> assertions.expression("a like 'monkey' escape 'foo'") + .binding("a", "'monkey'") + .evaluate()) + .isInstanceOf(TrinoException.class) + .hasMessage("Escape string must be a single character"); } @Test public void testDistinctFrom() { - assertFunction("NULL IS DISTINCT FROM NULL", BOOLEAN, false); - assertFunction("NULL IS DISTINCT FROM 1", BOOLEAN, true); - assertFunction("1 IS DISTINCT FROM NULL", BOOLEAN, true); - assertFunction("1 IS DISTINCT FROM 1", BOOLEAN, false); - assertFunction("1 IS DISTINCT FROM 2", BOOLEAN, true); - - assertFunction("NULL IS NOT DISTINCT FROM NULL", BOOLEAN, true); - assertFunction("NULL IS NOT DISTINCT FROM 1", BOOLEAN, false); - assertFunction("1 IS NOT DISTINCT FROM NULL", BOOLEAN, false); - assertFunction("1 IS NOT DISTINCT FROM 1", BOOLEAN, true); - assertFunction("1 IS NOT DISTINCT FROM 2", BOOLEAN, false); + // distinct from + assertThat(assertions.expression("a IS DISTINCT FROM b") + .binding("a", "null") + .binding("b", "null")) + .isEqualTo(false); + + assertThat(assertions.expression("a IS DISTINCT FROM b") + .binding("a", "null") + .binding("b", "1")) + .isEqualTo(true); + + assertThat(assertions.expression("a IS DISTINCT FROM b") + .binding("a", "1") + .binding("b", "null")) + .isEqualTo(true); + + assertThat(assertions.expression("a IS DISTINCT FROM b") + .binding("a", "1") + .binding("b", "1")) + .isEqualTo(false); + + assertThat(assertions.expression("a IS DISTINCT FROM b") + .binding("a", "1") + .binding("b", "2")) + .isEqualTo(true); + + // not distinct from + assertThat(assertions.expression("a IS NOT DISTINCT FROM b") + .binding("a", "null") + .binding("b", "null")) + .isEqualTo(true); + + assertThat(assertions.expression("a IS NOT DISTINCT FROM b") + .binding("a", "null") + .binding("b", "1")) + .isEqualTo(false); + + assertThat(assertions.expression("a IS NOT DISTINCT FROM b") + .binding("a", "1") + .binding("b", "null")) + .isEqualTo(false); + + assertThat(assertions.expression("a IS NOT DISTINCT FROM b") + .binding("a", "1") + .binding("b", "1")) + .isEqualTo(true); + + assertThat(assertions.expression("a IS NOT DISTINCT FROM b") + .binding("a", "1") + .binding("b", "2")) + .isEqualTo(false); } @Test public void testBetween() { - assertFunction("3 between 2 and 4", BOOLEAN, true); - assertFunction("3 between 3 and 3", BOOLEAN, true); - assertFunction("3 between 2 and 3", BOOLEAN, true); - assertFunction("3 between 3 and 4", BOOLEAN, true); - assertFunction("3 between 4 and 2", BOOLEAN, false); - assertFunction("2 between 3 and 4", BOOLEAN, false); - assertFunction("5 between 3 and 4", BOOLEAN, false); - assertFunction("null between 2 and 4", BOOLEAN, null); - assertFunction("3 between null and 4", BOOLEAN, null); - assertFunction("3 between 2 and null", BOOLEAN, null); - - assertFunction("3 between 3 and 4000000000", BOOLEAN, true); - assertFunction("5 between 3 and 4000000000", BOOLEAN, true); - assertFunction("3 between BIGINT '3' and 4", BOOLEAN, true); - assertFunction("BIGINT '3' between 3 and 4", BOOLEAN, true); - - assertFunction("'c' between 'b' and 'd'", BOOLEAN, true); - assertFunction("'c' between 'c' and 'c'", BOOLEAN, true); - assertFunction("'c' between 'b' and 'c'", BOOLEAN, true); - assertFunction("'c' between 'c' and 'd'", BOOLEAN, true); - assertFunction("'c' between 'd' and 'b'", BOOLEAN, false); - assertFunction("'b' between 'c' and 'd'", BOOLEAN, false); - assertFunction("'e' between 'c' and 'd'", BOOLEAN, false); - assertFunction("null between 'b' and 'd'", BOOLEAN, null); - assertFunction("'c' between null and 'd'", BOOLEAN, null); - assertFunction("'c' between 'b' and null", BOOLEAN, null); - - assertFunction("3 not between 2 and 4", BOOLEAN, false); - assertFunction("3 not between 3 and 3", BOOLEAN, false); - assertFunction("3 not between 2 and 3", BOOLEAN, false); - assertFunction("3 not between 3 and 4", BOOLEAN, false); - assertFunction("3 not between 4 and 2", BOOLEAN, true); - assertFunction("2 not between 3 and 4", BOOLEAN, true); - assertFunction("5 not between 3 and 4", BOOLEAN, true); - assertFunction("null not between 2 and 4", BOOLEAN, null); - assertFunction("3 not between null and 4", BOOLEAN, null); - assertFunction("3 not between 2 and null", BOOLEAN, null); - - assertFunction("'c' not between 'b' and 'd'", BOOLEAN, false); - assertFunction("'c' not between 'c' and 'c'", BOOLEAN, false); - assertFunction("'c' not between 'b' and 'c'", BOOLEAN, false); - assertFunction("'c' not between 'c' and 'd'", BOOLEAN, false); - assertFunction("'c' not between 'd' and 'b'", BOOLEAN, true); - assertFunction("'b' not between 'c' and 'd'", BOOLEAN, true); - assertFunction("'e' not between 'c' and 'd'", BOOLEAN, true); - assertFunction("null not between 'b' and 'd'", BOOLEAN, null); - assertFunction("'c' not between null and 'd'", BOOLEAN, null); - assertFunction("'c' not between 'b' and null", BOOLEAN, null); + // between + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "2") + .binding("high", "4")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "3") + .binding("high", "3")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "2") + .binding("high", "3")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "4") + .binding("high", "2")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "2") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "5") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "null") + .binding("low", "3") + .binding("high", "4")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "null") + .binding("high", "4")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "2") + .binding("high", "null")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value between low and high") + .binding("value", "3") + .binding("low", "3") + .binding("high", "4000000000")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "5") + .binding("low", "3") + .binding("high", "4000000000")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "5") + .binding("low", "BIGINT '3'") + .binding("high", "4")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "BIGINT '3'") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "'b'") + .binding("high", "'b'")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "'c'") + .binding("high", "'c'")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "'b'") + .binding("high", "'c'")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "'c'") + .binding("high", "'d'")) + .isEqualTo(true); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "'d'") + .binding("high", "'b'")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'b'") + .binding("low", "'c'") + .binding("high", "'d'")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'e'") + .binding("low", "'c'") + .binding("high", "'d'")) + .isEqualTo(false); + + assertThat(assertions.expression("value between low and high") + .binding("value", "null") + .binding("low", "'b'") + .binding("high", "'d'")) + .matches("CAST(null AS boolean)"); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "null") + .binding("high", "'d'")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value between low and high") + .binding("value", "'c'") + .binding("low", "'b'") + .binding("high", "null")) + .isNull(BOOLEAN); + + // not between + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "2") + .binding("high", "4")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "3") + .binding("high", "3")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "2") + .binding("high", "3")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "4") + .binding("high", "2")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "2") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "5") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "null") + .binding("low", "3") + .binding("high", "4")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "null") + .binding("high", "4")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "2") + .binding("high", "null")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "3") + .binding("low", "3") + .binding("high", "4000000000")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "5") + .binding("low", "3") + .binding("high", "4000000000")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "5") + .binding("low", "BIGINT '3'") + .binding("high", "4")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "BIGINT '3'") + .binding("low", "3") + .binding("high", "4")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "'b'") + .binding("high", "'b'")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "'c'") + .binding("high", "'c'")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "'b'") + .binding("high", "'c'")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "'c'") + .binding("high", "'d'")) + .isEqualTo(false); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "'d'") + .binding("high", "'b'")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'b'") + .binding("low", "'c'") + .binding("high", "'d'")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'e'") + .binding("low", "'c'") + .binding("high", "'d'")) + .isEqualTo(true); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "null") + .binding("low", "'b'") + .binding("high", "'d'")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "null") + .binding("high", "'d'")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value not between low and high") + .binding("value", "'c'") + .binding("low", "'b'") + .binding("high", "null")) + .isNull(BOOLEAN); } @Test public void testIn() { - assertFunction("3 in (2, 4, 3, 5)", BOOLEAN, true); - assertFunction("3 not in (2, 4, 3, 5)", BOOLEAN, false); - assertFunction("3 in (2, 4, 9, 5)", BOOLEAN, false); - assertFunction("3 in (2, null, 3, 5)", BOOLEAN, true); + assertThat(assertions.expression("value in (2, 4, 3, 5)") + .binding("value", "3")) + .isEqualTo(true); + + assertThat(assertions.expression("value not in (2, 4, 3, 5)") + .binding("value", "3")) + .isEqualTo(false); + + assertThat(assertions.expression("value in (2, 4, 9, 5)") + .binding("value", "3")) + .isEqualTo(false); + + assertThat(assertions.expression("value in (2, null, 3, 5)") + .binding("value", "3")) + .isEqualTo(true); + + assertThat(assertions.expression("value in ('bar', 'baz', 'foo', 'blah')") + .binding("value", "'foo'")) + .isEqualTo(true); + + assertThat(assertions.expression("value in ('bar', 'baz', 'buz', 'blah')") + .binding("value", "'foo'")) + .isEqualTo(false); + + assertThat(assertions.expression("value in ('bar', null, 'foo', 'blah')") + .binding("value", "'foo'")) + .isEqualTo(true); + + assertThat(assertions.expression("value in (2, null, 3, 5)") + .binding("value", "null")) + .isNull(BOOLEAN); + + assertThat(assertions.expression("value in (2, null)") + .binding("value", "3")) + .isNull(BOOLEAN); - assertFunction("'foo' in ('bar', 'baz', 'foo', 'blah')", BOOLEAN, true); - assertFunction("'foo' in ('bar', 'baz', 'buz', 'blah')", BOOLEAN, false); - assertFunction("'foo' in ('bar', null, 'foo', 'blah')", BOOLEAN, true); + assertThat(assertions.expression("value not in (2, null, 3, 5)") + .binding("value", "null")) + .isNull(BOOLEAN); - assertFunction("(null in (2, null, 3, 5)) is null", BOOLEAN, true); - assertFunction("(3 in (2, null)) is null", BOOLEAN, true); - assertFunction("(null not in (2, null, 3, 5)) is null", BOOLEAN, true); - assertFunction("(3 not in (2, null)) is null", BOOLEAN, true); + assertThat(assertions.expression("value not in (2, null)") + .binding("value", "3")) + .isNull(BOOLEAN); // Because of the failing in-list item 5 / 0, the in-predicate cannot be simplified. // It is instead processed with the use of generated code which applies the short-circuit @@ -162,296 +536,376 @@ public void testIn() @Test public void testSearchCase() { - assertFunction("case " + - "when true then 33 " + - "end", - INTEGER, - 33); - - assertFunction("case " + - "when true then BIGINT '33' " + - "end", - BIGINT, - 33L); - - assertFunction("case " + - "when false then 1 " + - "else 33 " + - "end", - INTEGER, - 33); - - assertFunction("case " + - "when false then 10000000000 " + - "else 33 " + - "end", - BIGINT, - 33L); - - assertFunction("case " + - "when false then 1 " + - "when false then 1 " + - "when true then 33 " + - "else 1 " + - "end", - INTEGER, - 33); - - assertFunction("case " + - "when false then BIGINT '1' " + - "when false then 1 " + - "when true then 33 " + - "else 1 " + - "end", - BIGINT, - 33L); - - assertFunction("case " + - "when false then 10000000000 " + - "when false then 1 " + - "when true then 33 " + - "else 1 " + - "end", - BIGINT, - 33L); - - assertFunction("case " + - "when false then 1 " + - "end", - INTEGER, - null); - - assertFunction("case " + - "when true then null " + - "else 'foo' " + - "end", - createVarcharType(3), - null); - - assertFunction("case " + - "when null then 1 " + - "when true then 33 " + - "end", - INTEGER, - 33); - - assertFunction("case " + - "when null then 10000000000 " + - "when true then 33 " + - "end", - BIGINT, - 33L); - - assertFunction("case " + - "when false then 1.0E0 " + - "when true then 33 " + - "end", - DOUBLE, - 33.0); - - assertDecimalFunction("case " + - "when false then DECIMAL '2.2' " + - "when true then DECIMAL '2.2' " + - "end", - decimal("2.2", createDecimalType(2, 1))); - - assertDecimalFunction("case " + - "when false then DECIMAL '1234567890.0987654321' " + - "when true then DECIMAL '3.3' " + - "end", - decimal("0000000003.3000000000", createDecimalType(20, 10))); - - assertDecimalFunction("case " + - "when false then 1 " + - "when true then DECIMAL '2.2' " + - "end", - decimal("0000000002.2", createDecimalType(11, 1))); - - assertDecimalFunction("case " + - "when false then 2.2 " + - "when true then 2.2 " + - "end", - decimal("2.2", createDecimalType(2, 1))); - - assertDecimalFunction("case " + - "when false then 1234567890.0987654321 " + - "when true then 3.3 " + - "end", - decimal("0000000003.3000000000", createDecimalType(20, 10))); - - assertDecimalFunction("case " + - "when false then 1 " + - "when true then 2.2 " + - "end", - decimal("0000000002.2", createDecimalType(11, 1))); - - assertFunction("case " + - "when false then DECIMAL '1.1' " + - "when true then 33.0E0 " + - "end", - DOUBLE, - 33.0); - - assertFunction("case " + - "when false then 1.1 " + - "when true then 33.0E0 " + - "end", - DOUBLE, - 33.0); + assertThat(assertions.expression(""" + case + when value then 33 + end + """) + .binding("value", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case + when value then BIGINT '33' + end + """) + .binding("value", "true")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case + when value then 1 + else 33 + end + """) + .binding("value", "false")) + .matches("33"); + + assertThat(assertions.expression(""" + case + when value then 10000000000 + else 33 + end + """) + .binding("value", "false")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case + when condition1 then 1 + when condition2 then 1 + when condition3 then 33 + else 1 + end + """) + .binding("condition1", "false") + .binding("condition2", "false") + .binding("condition3", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case + when condition1 then BIGINT '1' + when condition2 then 1 + when condition3 then 33 + else 1 + end + """) + .binding("condition1", "false") + .binding("condition2", "false") + .binding("condition3", "true")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case + when condition1 then 10000000000 + when condition2 then 1 + when condition3 then 33 + else 1 + end + """) + .binding("condition1", "false") + .binding("condition2", "false") + .binding("condition3", "true")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case + when value then 1 + end + """) + .binding("value", "false")) + .matches("CAST(null AS integer)"); + + assertThat(assertions.expression(""" + case + when value then null + else 'foo' + end + """) + .binding("value", "true")) + .isNull(createVarcharType(3)); + + assertThat(assertions.expression(""" + case + when condition1 then 1 + when condition2 then 33 + end + """) + .binding("condition1", "null") + .binding("condition2", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case + when condition1 then 10000000000 + when condition2 then 33 + end + """) + .binding("condition1", "null") + .binding("condition2", "true")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case + when condition1 then 1.0E0 + when condition2 then 33 + end + """) + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("33E0"); + + assertThat(assertions.expression(""" + case + when condition1 then 2.2 + when condition2 then 2.2 + end + """) + .binding("condition1", "false") + .binding("condition2", "true")) + .hasType(createDecimalType(2, 1)) + .matches("2.2"); + + assertThat(assertions.expression(""" + case + when condition1 then 1234567890.0987654321 + when condition2 then 3.3 + end + """) + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("CAST(3.3 AS decimal(20, 10))"); + + assertThat(assertions.expression(""" + case + when condition1 then 1 + when condition2 then 2.2 + end + """) + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("CAST(2.2 AS decimal(11, 1))"); + + assertThat(assertions.expression(""" + case + when condition1 then 1.1 + when condition2 then 33E0 + end + """) + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("33E0"); } @Test public void testSimpleCase() { - assertFunction("case true " + - "when true then cast(null as varchar) " + - "else 'foo' " + - "end", - VARCHAR, - null); - - assertFunction("case true " + - "when true then 33 " + - "end", - INTEGER, - 33); - - assertFunction("case true " + - "when true then BIGINT '33' " + - "end", - BIGINT, - 33L); - - assertFunction("case true " + - "when false then 1 " + - "else 33 " + - "end", - INTEGER, - 33); - - assertFunction("case true " + - "when false then 10000000000 " + - "else 33 " + - "end", - BIGINT, - 33L); - - assertFunction("case true " + - "when false then 1 " + - "when false then 1 " + - "when true then 33 " + - "else 1 " + - "end", - INTEGER, - 33); - - assertFunction("case true " + - "when false then 1 " + - "end", - INTEGER, - null); - - assertFunction("case true " + - "when true then null " + - "else 'foo' " + - "end", - createVarcharType(3), - null); - - assertFunction("case true " + - "when null then 10000000000 " + - "when true then 33 " + - "end", - BIGINT, - 33L); - - assertFunction("case true " + - "when null then 1 " + - "when true then 33 " + - "end", - INTEGER, - 33); - - assertFunction("case null " + - "when true then 1 " + - "else 33 " + - "end", - INTEGER, - 33); - - assertFunction("case true " + - "when false then 1.0E0 " + - "when true then 33 " + - "end", - DOUBLE, - 33.0); - - assertDecimalFunction("case true " + - "when false then DECIMAL '2.2' " + - "when true then DECIMAL '2.2' " + - "end", - decimal("2.2", createDecimalType(2, 1))); - - assertDecimalFunction("case true " + - "when false then DECIMAL '1234567890.0987654321' " + - "when true then DECIMAL '3.3' " + - "end", - decimal("0000000003.3000000000", createDecimalType(20, 10))); - - assertDecimalFunction("case true " + - "when false then 1 " + - "when true then DECIMAL '2.2' " + - "end", - decimal("0000000002.2", createDecimalType(11, 1))); - - assertFunction("case true " + - "when false then DECIMAL '1.1' " + - "when true then 33.0E0 " + - "end", - DOUBLE, - 33.0); - - assertDecimalFunction("case true " + - "when false then 2.2 " + - "when true then 2.2 " + - "end", - decimal("2.2", createDecimalType(2, 1))); - - assertDecimalFunction("case true " + - "when false then 1234567890.0987654321 " + - "when true then 3.3 " + - "end", - decimal("0000000003.3000000000", createDecimalType(20, 10))); - - assertDecimalFunction("case true " + - "when false then 1 " + - "when true then 2.2 " + - "end", - decimal("0000000002.2", createDecimalType(11, 1))); - - assertFunction("case true " + - "when false then 1.1 " + - "when true then 33.0E0 " + - "end", - DOUBLE, - 33.0); + assertThat(assertions.expression(""" + case value + when condition then CAST(null AS varchar) + else 'foo' + end + """) + .binding("value", "true") + .binding("condition", "true")) + .matches("CAST(null AS varchar)"); + + assertThat(assertions.expression(""" + case value + when condition then 33 + end + """) + .binding("value", "true") + .binding("condition", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case value + when condition then BIGINT '33' + end + """) + .binding("value", "true") + .binding("condition", "true")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case value + when condition then 1 + else 33 + end + """) + .binding("value", "true") + .binding("condition", "false")) + .matches("33"); + + assertThat(assertions.expression(""" + case value + when condition then 10000000000 + else 33 + end + """) + .binding("value", "true") + .binding("condition", "false")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case value + when condition1 then 1 + when condition2 then 1 + when condition3 then 33 + else 1 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("condition2", "false") + .binding("condition3", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case value + when condition then 1 + end + """) + .binding("value", "true") + .binding("condition", "false")) + .isNull(INTEGER); + + assertThat(assertions.expression(""" + case value + when condition then null + else 'foo' + end + """) + .binding("value", "true") + .binding("condition", "true")) + .isNull(createVarcharType(3)); + + assertThat(assertions.expression(""" + case value + when condition1 then 10000000000 + when condition2 then 33 + end + """) + .binding("value", "true") + .binding("condition1", "null") + .binding("condition2", "true")) + .matches("BIGINT '33'"); + + assertThat(assertions.expression(""" + case value + when condition1 then 1 + when condition2 then 33 + end + """) + .binding("value", "true") + .binding("condition1", "null") + .binding("condition2", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case value + when condition then 1 + else 33 + end + """) + .binding("value", "null") + .binding("condition", "true")) + .matches("33"); + + assertThat(assertions.expression(""" + case value + when condition1 then 1E0 + when condition2 then 33 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("33E0"); + + assertThat(assertions.expression(""" + case value + when condition1 then 2.2 + when condition2 then 2.2 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("2.2"); + + assertThat(assertions.expression(""" + case value + when condition1 then 1234567890.0987654321 + when condition2 then 3.3 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("CAST(3.3 AS decimal(20, 10))"); + + assertThat(assertions.expression(""" + case value + when condition1 then 1 + when condition2 then 2.2 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("CAST(2.2 AS decimal(11, 1))"); + + assertThat(assertions.expression(""" + case value + when condition1 then 1.1 + when condition2 then 33E0 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("condition2", "true")) + .matches("33E0"); + + assertThat(assertions.expression(""" + case value + when condition1 then result1 + when condition2 then result2 + end + """) + .binding("value", "true") + .binding("condition1", "false") + .binding("result1", "1.1") + .binding("condition2", "true") + .binding("result2", "33.0E0")) + .matches("33.0E0"); } @Test public void testSimpleCaseWithCoercions() { - assertFunction("case 8 " + - "when double '76.1' then 1 " + - "when real '8.1' then 2 " + - "end", - INTEGER, - null); - - assertFunction("case 8 " + - "when 9 then 1 " + - "when cast(null as decimal) then 2 " + - "end", - INTEGER, - null); + assertThat(assertions.expression(""" + case value + when condition1 then 1 + when condition2 then 2 + end + """) + .binding("value", "8") + .binding("condition1", "double '76.1'") + .binding("condition2", "real '8.1'")) + .isNull(INTEGER); + + assertThat(assertions.expression(""" + case value + when condition1 then 1 + when condition2 then 2 + end + """) + .binding("value", "8") + .binding("condition1", "9") + .binding("condition2", "cast(NULL as decimal)")) + .isNull(INTEGER); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestDataSizeFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestDataSizeFunctions.java index 288ec7856f03..7a27180c0427 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestDataSizeFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestDataSizeFunctions.java @@ -13,46 +13,114 @@ */ package io.trino.operator.scalar; -import io.trino.spi.type.Type; -import org.testng.annotations.Test; +import io.trino.spi.type.DecimalType; +import io.trino.sql.query.QueryAssertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import static io.trino.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.SqlDecimal.decimal; +import static io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +@TestInstance(PER_CLASS) public class TestDataSizeFunctions - extends AbstractTestFunctions { - private static final Type DECIMAL = createDecimalType(38, 0); + private static final DecimalType DECIMAL = createDecimalType(38, 0); + + private QueryAssertions assertions; + + @BeforeAll + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterAll + public void teardown() + { + assertions.close(); + assertions = null; + } @Test public void testParseDataSize() { - assertFunction("parse_data_size('0B')", DECIMAL, decimal("0", createDecimalType(38))); - assertFunction("parse_data_size('1B')", DECIMAL, decimal("1", createDecimalType(38))); - assertFunction("parse_data_size('1.2B')", DECIMAL, decimal("1", createDecimalType(38))); - assertFunction("parse_data_size('1.9B')", DECIMAL, decimal("1", createDecimalType(38))); - assertFunction("parse_data_size('2.2kB')", DECIMAL, decimal("2252", createDecimalType(38))); - assertFunction("parse_data_size('2.23kB')", DECIMAL, decimal("2283", createDecimalType(38))); - assertFunction("parse_data_size('2.23kB')", DECIMAL, decimal("2283", createDecimalType(38))); - assertFunction("parse_data_size('2.234kB')", DECIMAL, decimal("2287", createDecimalType(38))); - assertFunction("parse_data_size('3MB')", DECIMAL, decimal("3145728", createDecimalType(38))); - assertFunction("parse_data_size('4GB')", DECIMAL, decimal("4294967296", createDecimalType(38))); - assertFunction("parse_data_size('4TB')", DECIMAL, decimal("4398046511104", createDecimalType(38))); - assertFunction("parse_data_size('5PB')", DECIMAL, decimal("5629499534213120", createDecimalType(38))); - assertFunction("parse_data_size('6EB')", DECIMAL, decimal("6917529027641081856", createDecimalType(38))); - assertFunction("parse_data_size('7ZB')", DECIMAL, decimal("8264141345021879123968", createDecimalType(38))); - assertFunction("parse_data_size('8YB')", DECIMAL, decimal("9671406556917033397649408", createDecimalType(38))); - assertFunction("parse_data_size('6917529027641081856EB')", DECIMAL, decimal("7975367974709495237422842361682067456", createDecimalType(38))); - assertFunction("parse_data_size('69175290276410818560EB')", DECIMAL, decimal("79753679747094952374228423616820674560", createDecimalType(38))); - - assertInvalidFunction("parse_data_size('')", "Invalid data size: ''"); - assertInvalidFunction("parse_data_size('0')", "Invalid data size: '0'"); - assertInvalidFunction("parse_data_size('10KB')", "Invalid data size: '10KB'"); - assertInvalidFunction("parse_data_size('KB')", "Invalid data size: 'KB'"); - assertInvalidFunction("parse_data_size('-1B')", "Invalid data size: '-1B'"); - assertInvalidFunction("parse_data_size('12345K')", "Invalid data size: '12345K'"); - assertInvalidFunction("parse_data_size('A12345B')", "Invalid data size: 'A12345B'"); - assertInvalidFunction("parse_data_size('99999999999999YB')", NUMERIC_VALUE_OUT_OF_RANGE, "Value out of range: '99999999999999YB' ('120892581961461708544797985370825293824B')"); + assertThat(assertions.function("parse_data_size", "'0B'")) + .isEqualTo(decimal("0", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'1B'")) + .isEqualTo(decimal("1", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'1.2B'")) + .isEqualTo(decimal("1", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'1.9B'")) + .isEqualTo(decimal("1", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'2.2kB'")) + .isEqualTo(decimal("2252", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'2.23kB'")) + .isEqualTo(decimal("2283", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'2.234kB'")) + .isEqualTo(decimal("2287", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'3MB'")) + .isEqualTo(decimal("3145728", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'4GB'")) + .isEqualTo(decimal("4294967296", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'4TB'")) + .isEqualTo(decimal("4398046511104", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'5PB'")) + .isEqualTo(decimal("5629499534213120", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'6EB'")) + .isEqualTo(decimal("6917529027641081856", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'7ZB'")) + .isEqualTo(decimal("8264141345021879123968", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'8YB'")) + .isEqualTo(decimal("9671406556917033397649408", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'6917529027641081856EB'")) + .isEqualTo(decimal("7975367974709495237422842361682067456", DECIMAL)); + + assertThat(assertions.function("parse_data_size", "'69175290276410818560EB'")) + .isEqualTo(decimal("79753679747094952374228423616820674560", DECIMAL)); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "''").evaluate()) + .hasMessage("Invalid data size: ''"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'0'").evaluate()) + .hasMessage("Invalid data size: '0'"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'10KB'").evaluate()) + .hasMessage("Invalid data size: '10KB'"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'KB'").evaluate()) + .hasMessage("Invalid data size: 'KB'"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'-1B'").evaluate()) + .hasMessage("Invalid data size: '-1B'"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'12345K'").evaluate()) + .hasMessage("Invalid data size: '12345K'"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'A12345B'").evaluate()) + .hasMessage("Invalid data size: 'A12345B'"); + + assertTrinoExceptionThrownBy(() -> assertions.function("parse_data_size", "'99999999999999YB'").evaluate()) + .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE) + .hasMessage("Value out of range: '99999999999999YB' ('120892581961461708544797985370825293824B')"); } } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalDayTime.java b/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalDayTime.java index 9138043059c9..c1e82b14262e 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalDayTime.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalDayTime.java @@ -132,19 +132,19 @@ public void testLiterals() assertThat(assertions.expression("INTERVAL '32' SECOND")) .isEqualTo(interval(0, 0, 0, 32, 0)); - assertThatThrownBy(() -> assertions.expression("INTERVAL '12X' DAY")) + assertThatThrownBy(() -> assertions.expression("INTERVAL '12X' DAY").evaluate()) .hasMessage("line 1:8: '12X' is not a valid interval literal"); - assertThatThrownBy(() -> assertions.expression("INTERVAL '12 10' DAY")) + assertThatThrownBy(() -> assertions.expression("INTERVAL '12 10' DAY").evaluate()) .hasMessage("line 1:8: '12 10' is not a valid interval literal"); - assertThatThrownBy(() -> assertions.expression("INTERVAL '12 X' DAY TO HOUR")) + assertThatThrownBy(() -> assertions.expression("INTERVAL '12 X' DAY TO HOUR").evaluate()) .hasMessage("line 1:8: '12 X' is not a valid interval literal"); - assertThatThrownBy(() -> assertions.expression("INTERVAL '12 -10' DAY TO HOUR")) + assertThatThrownBy(() -> assertions.expression("INTERVAL '12 -10' DAY TO HOUR").evaluate()) .hasMessage("line 1:8: '12 -10' is not a valid interval literal"); - assertThatThrownBy(() -> assertions.expression("INTERVAL '--12 -10' DAY TO HOUR")) + assertThatThrownBy(() -> assertions.expression("INTERVAL '--12 -10' DAY TO HOUR").evaluate()) .hasMessage("line 1:8: '--12 -10' is not a valid interval literal"); } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalYearMonth.java b/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalYearMonth.java index 210d0f4374d7..e8b89968681b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalYearMonth.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/interval/TestIntervalYearMonth.java @@ -60,19 +60,19 @@ public void testLiterals() assertThat(assertions.expression("INTERVAL '32767-32767' YEAR TO MONTH")) .isEqualTo(interval(32767, 32767)); - assertThatThrownBy(() -> assertions.expression("INTERVAL '124X' YEAR")) + assertThatThrownBy(() -> assertions.expression("INTERVAL '124X' YEAR").evaluate()) .hasMessage("line 1:8: '124X' is not a valid interval literal"); - assertThatThrownBy(() -> assertions.expression("INTERVAL '124-30' YEAR")) + assertThatThrownBy(() -> assertions.expression("INTERVAL '124-30' YEAR").evaluate()) .hasMessage("line 1:8: '124-30' is not a valid interval literal"); - assertThatThrownBy(() -> assertions.expression("INTERVAL '124-X' YEAR TO MONTH")) + assertThatThrownBy(() -> assertions.expression("INTERVAL '124-X' YEAR TO MONTH").evaluate()) .hasMessage("line 1:8: '124-X' is not a valid interval literal"); - assertThatThrownBy(() -> assertions.expression("INTERVAL '124--30' YEAR TO MONTH")) + assertThatThrownBy(() -> assertions.expression("INTERVAL '124--30' YEAR TO MONTH").evaluate()) .hasMessage("line 1:8: '124--30' is not a valid interval literal"); - assertThatThrownBy(() -> assertions.expression("INTERVAL '--124--30' YEAR TO MONTH")) + assertThatThrownBy(() -> assertions.expression("INTERVAL '--124--30' YEAR TO MONTH").evaluate()) .hasMessage("line 1:8: '--124--30' is not a valid interval literal"); } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestExtract.java b/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestExtract.java index ea20765cabf9..3b7dc94bd03c 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestExtract.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestExtract.java @@ -135,7 +135,7 @@ public void testSecond() @Test public void testMillisecond() { - assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIME '12:34:56')")) + assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIME '12:34:56')").evaluate()) .isInstanceOf(ParsingException.class) .hasMessage("line 1:8: Invalid EXTRACT field: MILLISECOND"); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestTime.java b/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestTime.java index 1b2015b5137a..3f1e17bb411a 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestTime.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/time/TestTime.java @@ -104,19 +104,19 @@ public void testLiterals() .hasType(createTimeType(12)) .isEqualTo(time(12, 12, 34, 56, 123_456_789_123L)); - assertThatThrownBy(() -> assertions.expression("TIME '12:34:56.1234567891234'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:34:56.1234567891234'").evaluate()) .hasMessage("line 1:8: TIME precision must be in range [0, 12]: 13"); - assertThatThrownBy(() -> assertions.expression("TIME '25:00:00'")) + assertThatThrownBy(() -> assertions.expression("TIME '25:00:00'").evaluate()) .hasMessage("line 1:8: '25:00:00' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:65:00'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:65:00'").evaluate()) .hasMessage("line 1:8: '12:65:00' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:65'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:65'").evaluate()) .hasMessage("line 1:8: '12:00:65' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME 'xxx'")) + assertThatThrownBy(() -> assertions.expression("TIME 'xxx'").evaluate()) .hasMessage("line 1:8: 'xxx' is not a valid time literal"); } @@ -1461,31 +1461,31 @@ public void testCastFromVarchar() assertThat(assertions.expression("CAST('23:59:59.999999999999' AS TIME(11))")).matches("TIME '00:00:00.00000000000'"); // > 12 digits of precision - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(0))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(0))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(1))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(1))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(2))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(2))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(3))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(3))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(4))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(4))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(5))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(5))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(6))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(6))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(7))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(7))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(8))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(8))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(9))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(9))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(10))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(10))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(11))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(11))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); - assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('12:34:56.1111111111111' AS TIME(12))").evaluate()) .hasMessage("Value cannot be cast to time: 12:34:56.1111111111111"); } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestExtract.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestExtract.java index c34b821731c0..29c1b22e52e2 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestExtract.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestExtract.java @@ -324,7 +324,7 @@ public void testSecond() @Test public void testMillisecond() { - assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIMESTAMP '2020-05-10 12:34:56')")) + assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIMESTAMP '2020-05-10 12:34:56')").evaluate()) .isInstanceOf(ParsingException.class) .hasMessage("line 1:8: Invalid EXTRACT field: MILLISECOND"); @@ -478,7 +478,7 @@ public void testQuarter() @Test public void testWeekOfYear() { - assertThatThrownBy(() -> assertions.expression("EXTRACT(WEEK_OF_YEAR FROM TIMESTAMP '2020-05-10 12:34:56')")) + assertThatThrownBy(() -> assertions.expression("EXTRACT(WEEK_OF_YEAR FROM TIMESTAMP '2020-05-10 12:34:56')").evaluate()) .isInstanceOf(ParsingException.class) .hasMessage("line 1:8: Invalid EXTRACT field: WEEK_OF_YEAR"); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestTimestamp.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestTimestamp.java index 774b956f0ed2..bf40774ee3f7 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestTimestamp.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/TestTimestamp.java @@ -134,13 +134,13 @@ public void testLiterals() .hasType(createTimestampType(12)) .isEqualTo(timestamp(12, 2020, 5, 1, 12, 34, 56, 123_456_789_012L)); - assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-05-01 12:34:56.1234567890123'")) + assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-05-01 12:34:56.1234567890123'").evaluate()) .hasMessage("line 1:8: TIMESTAMP precision must be in range [0, 12]: 13"); - assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-13-01'")) + assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-13-01'").evaluate()) .hasMessage("line 1:8: '2020-13-01' is not a valid timestamp literal"); - assertThatThrownBy(() -> assertions.expression("TIMESTAMP 'xxx'")) + assertThatThrownBy(() -> assertions.expression("TIMESTAMP 'xxx'").evaluate()) .hasMessage("line 1:8: 'xxx' is not a valid timestamp literal"); // negative epoch @@ -1464,9 +1464,9 @@ public void testCastToTimestampWithTimeZone() assertThat(assertions.expression("CAST(TIMESTAMP '-12001-05-01 12:34:56' AS TIMESTAMP(0) WITH TIME ZONE)")).matches("TIMESTAMP '-12001-05-01 12:34:56 Pacific/Apia'"); // Overflow - assertThatThrownBy(() -> assertions.expression("CAST(TIMESTAMP '123001-05-01 12:34:56' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST(TIMESTAMP '123001-05-01 12:34:56' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Out of range for timestamp with time zone: 3819379822496000"); - assertThatThrownBy(() -> assertions.expression("CAST(TIMESTAMP '-123001-05-01 12:34:56' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST(TIMESTAMP '-123001-05-01 12:34:56' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Out of range for timestamp with time zone: -3943693439888000"); } @@ -2674,30 +2674,30 @@ public void testAtTimeZone() @Test public void testCastInvalidTimestamp() { - assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP)")) + assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP)").evaluate()) .hasMessage("Value cannot be cast to timestamp: ABC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00' AS TIMESTAMP)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00' AS TIMESTAMP)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-00 00:00:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00' AS TIMESTAMP)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00' AS TIMESTAMP)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-00-01 00:00:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00' AS TIMESTAMP)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00' AS TIMESTAMP)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 25:00:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00' AS TIMESTAMP)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00' AS TIMESTAMP)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:61:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61' AS TIMESTAMP)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61' AS TIMESTAMP)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:00:61"); - assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: ABC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-00 00:00:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-00-01 00:00:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 25:00:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:61:00"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:00:61"); } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestExtract.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestExtract.java index a6c921eb7a6c..a2cddd826330 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestExtract.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestExtract.java @@ -295,7 +295,7 @@ public void testSecond() @Test public void testMillisecond() { - assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIMESTAMP '2020-05-10 12:34:56 Asia/Kathmandu')")) + assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIMESTAMP '2020-05-10 12:34:56 Asia/Kathmandu')").evaluate()) .isInstanceOf(ParsingException.class) .hasMessage("line 1:8: Invalid EXTRACT field: MILLISECOND"); @@ -600,7 +600,7 @@ public void testQuarter() @Test public void testWeekOfYear() { - assertThatThrownBy(() -> assertions.expression("EXTRACT(WEEK_OF_YEAR FROM TIMESTAMP '2020-05-10 12:34:56 Asia/Kathmandu')")) + assertThatThrownBy(() -> assertions.expression("EXTRACT(WEEK_OF_YEAR FROM TIMESTAMP '2020-05-10 12:34:56 Asia/Kathmandu')").evaluate()) .isInstanceOf(ParsingException.class) .hasMessage("line 1:8: Invalid EXTRACT field: WEEK_OF_YEAR"); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestTimestampWithTimeZone.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestTimestampWithTimeZone.java index 7900c185655c..ced42bbdea31 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestTimestampWithTimeZone.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamptz/TestTimestampWithTimeZone.java @@ -106,10 +106,10 @@ public void testLiterals() .hasType(createTimestampWithTimeZoneType(12)) .isEqualTo(timestampWithTimeZone(12, 2020, 5, 1, 12, 34, 56, 123_456_789_012L, getTimeZoneKey("Asia/Kathmandu"))); - assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-05-01 12:34:56.1234567890123 Asia/Kathmandu'")) + assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-05-01 12:34:56.1234567890123 Asia/Kathmandu'").evaluate()) .hasMessage("line 1:8: TIMESTAMP WITH TIME ZONE precision must be in range [0, 12]: 13"); - assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-13-01 Asia/Kathmandu'")) + assertThatThrownBy(() -> assertions.expression("TIMESTAMP '2020-13-01 Asia/Kathmandu'").evaluate()) .hasMessage("line 1:8: '2020-13-01 Asia/Kathmandu' is not a valid timestamp literal"); // negative epoch @@ -2488,34 +2488,34 @@ public void testJoin() @Test public void testCastInvalidTimestamp() { - assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: ABC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00 UTC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00 UTC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-00 00:00:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00 UTC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00 UTC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-00-01 00:00:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00 UTC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00 UTC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 25:00:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00 UTC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00 UTC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:61:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61 UTC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61 UTC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:00:61 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:00 ABC' AS TIMESTAMP WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:00 ABC' AS TIMESTAMP WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:00:00 ABC"); - assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP(12))")) + assertThatThrownBy(() -> assertions.expression("CAST('ABC' AS TIMESTAMP(12))").evaluate()) .hasMessage("Value cannot be cast to timestamp: ABC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-00 00:00:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-00 00:00:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-00-01 00:00:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-00-01 00:00:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 25:00:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 25:00:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:61:00 UTC' AS TIMESTAMP(12) WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:61:00 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61 UTC' AS TIMESTAMP(12) WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:61 UTC' AS TIMESTAMP(12) WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:00:61 UTC"); - assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:00 ABC' AS TIMESTAMP(12) WITH TIME ZONE)")) + assertThatThrownBy(() -> assertions.expression("CAST('2022-01-01 00:00:00 ABC' AS TIMESTAMP(12) WITH TIME ZONE)").evaluate()) .hasMessage("Value cannot be cast to timestamp: 2022-01-01 00:00:00 ABC"); } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestExtract.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestExtract.java index ced4bebe8661..0426c69fc1d9 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestExtract.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestExtract.java @@ -135,7 +135,7 @@ public void testSecond() @Test public void testMillisecond() { - assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIME '12:34:56+08:35')")) + assertThatThrownBy(() -> assertions.expression("EXTRACT(MILLISECOND FROM TIME '12:34:56+08:35')").evaluate()) .isInstanceOf(ParsingException.class) .hasMessage("line 1:8: Invalid EXTRACT field: MILLISECOND"); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestTimeWithTimeZone.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestTimeWithTimeZone.java index ca98cb410b29..6d47ed319ed6 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestTimeWithTimeZone.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timetz/TestTimeWithTimeZone.java @@ -257,34 +257,34 @@ public void testLiterals() .hasType(createTimeWithTimeZoneType(12)) .isEqualTo(timeWithTimeZone(12, 12, 34, 56, 123_456_789_123L, -14 * 60)); - assertThatThrownBy(() -> assertions.expression("TIME '12:34:56.1234567891234+08:35'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:34:56.1234567891234+08:35'").evaluate()) .hasMessage("line 1:8: TIME WITH TIME ZONE precision must be in range [0, 12]: 13"); - assertThatThrownBy(() -> assertions.expression("TIME '25:00:00+08:35'")) + assertThatThrownBy(() -> assertions.expression("TIME '25:00:00+08:35'").evaluate()) .hasMessage("line 1:8: '25:00:00+08:35' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:65:00+08:35'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:65:00+08:35'").evaluate()) .hasMessage("line 1:8: '12:65:00+08:35' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:65+08:35'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:65+08:35'").evaluate()) .hasMessage("line 1:8: '12:00:65+08:35' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:00+15:00'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:00+15:00'").evaluate()) .hasMessage("line 1:8: '12:00:00+15:00' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:00-15:00'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:00-15:00'").evaluate()) .hasMessage("line 1:8: '12:00:00-15:00' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:00+14:01'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:00+14:01'").evaluate()) .hasMessage("line 1:8: '12:00:00+14:01' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:00-14:01'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:00-14:01'").evaluate()) .hasMessage("line 1:8: '12:00:00-14:01' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:00+13:60'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:00+13:60'").evaluate()) .hasMessage("line 1:8: '12:00:00+13:60' is not a valid time literal"); - assertThatThrownBy(() -> assertions.expression("TIME '12:00:00-13:60'")) + assertThatThrownBy(() -> assertions.expression("TIME '12:00:00-13:60'").evaluate()) .hasMessage("line 1:8: '12:00:00-13:60' is not a valid time literal"); } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java b/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java index df402d455f47..8b518bcc4051 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/QueryAssertions.java @@ -13,10 +13,12 @@ */ package io.trino.sql.query; +import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import io.trino.Session; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.FunctionBundle; +import io.trino.spi.function.OperatorType; import io.trino.spi.type.SqlTime; import io.trino.spi.type.SqlTimeWithTimeZone; import io.trino.spi.type.SqlTimestamp; @@ -42,7 +44,9 @@ import java.io.Closeable; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.function.BiFunction; import java.util.function.Consumer; @@ -53,8 +57,8 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder; import static io.trino.cost.StatsCalculator.noopStatsCalculator; +import static io.trino.metadata.OperatorNameUtil.mangleOperatorName; import static io.trino.sql.planner.assertions.PlanAssert.assertPlan; -import static io.trino.sql.query.QueryAssertions.ExpressionAssert.newExpressionAssert; import static io.trino.sql.query.QueryAssertions.QueryAssert.newQueryAssert; import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -114,14 +118,38 @@ public AssertProvider query(Session session, @Language("SQL") Strin return newQueryAssert(query, runner, session); } - public AssertProvider expression(@Language("SQL") String expression) + public ExpressionAssertProvider expression(@Language("SQL") String expression) { return expression(expression, runner.getDefaultSession()); } - public AssertProvider expression(@Language("SQL") String expression, Session session) + public ExpressionAssertProvider operator(OperatorType operator, @Language("SQL") String... arguments) { - return newExpressionAssert(expression, runner, session); + return function(mangleOperatorName(operator), arguments); + } + + public ExpressionAssertProvider function(String name, @Language("SQL") String... arguments) + { + ImmutableList.Builder builder = ImmutableList.builder(); + for (int i = 0; i < arguments.length; i++) { + builder.add("a" + i); + } + + List names = builder.build(); + ExpressionAssertProvider assertion = expression("\"%s\"(%s)".formatted( + name, + String.join(",", names))); + + for (int i = 0; i < arguments.length; i++) { + assertion.binding(names.get(i), arguments[i]); + } + + return assertion; + } + + public ExpressionAssertProvider expression(@Language("SQL") String expression, Session session) + { + return new ExpressionAssertProvider(runner, session, expression); } public void assertQueryAndPlan( @@ -527,6 +555,103 @@ public QueryAssert hasCorrectResultsRegardlessOfPushdown() } } + public static class ExpressionAssertProvider + implements AssertProvider + { + private final QueryRunner runner; + private final String expression; + private final Session session; + + private final Map bindings = new HashMap<>(); + + public ExpressionAssertProvider(QueryRunner runner, Session session, String expression) + { + this.runner = runner; + this.session = session; + this.expression = expression; + } + + public ExpressionAssertProvider binding(String variable, @Language("SQL") String value) + { + String previous = bindings.put(variable, value); + if (previous != null) { + fail("%s already bound to: %s".formatted(variable, value)); + } + return this; + } + + public Result evaluate() + { + if (bindings.isEmpty()) { + return run("VALUES %s".formatted(expression)); + } + else { + List> entries = ImmutableList.copyOf(bindings.entrySet()); + + List columns = entries.stream() + .map(Map.Entry::getKey) + .collect(toList()); + + List values = entries.stream() + .map(Map.Entry::getValue) + .collect(toList()); + + // Evaluate the expression using two modes: + // 1. Avoid constant folding -> exercises the compiler and evaluation engine + // 2. Force constant folding -> exercises the interpreter + + Result full = run(""" + SELECT %s + FROM ( + VALUES (%s) + ) t(%s) + WHERE rand() >= 0 + """ + .formatted( + expression, + Joiner.on(",").join(values), + Joiner.on(",").join(columns))); + + Result withConstantFolding = run(""" + SELECT %s + FROM ( + VALUES (%s) + ) t(%s) + """ + .formatted( + expression, + Joiner.on(",").join(values), + Joiner.on(",").join(columns))); + + if (!full.type().equals(withConstantFolding.type())) { + fail("Mismatched types between interpreter and evaluation engine: %s vs %s".formatted(full.type(), withConstantFolding.type())); + } + + if (!Objects.equals(full.value(), withConstantFolding.value())) { + fail("Mismatched results between interpreter and evaluation engine: %s vs %s".formatted(full.value(), withConstantFolding.value())); + } + + return new Result(full.type(), full.value); + } + } + + private Result run(String query) + { + MaterializedResult result = runner.execute(session, query); + return new Result(result.getTypes().get(0), result.getOnlyColumnAsSet().iterator().next()); + } + + @Override + public ExpressionAssert assertThat() + { + Result result = evaluate(); + return new ExpressionAssert(runner, session, result.value(), result.type()) + .withRepresentation(ExpressionAssert.TYPE_RENDERER); + } + + record Result(Type type, Object value) {} + } + public static class ExpressionAssert extends AbstractAssert { @@ -575,15 +700,6 @@ public String toStringOf(Object object) private final Session session; private final Type actualType; - static AssertProvider newExpressionAssert(String expression, QueryRunner runner, Session session) - { - MaterializedResult result = runner.execute(session, "VALUES " + expression); - Type type = result.getTypes().get(0); - Object value = result.getOnlyColumnAsSet().iterator().next(); - return () -> new ExpressionAssert(runner, session, value, type) - .withRepresentation(TYPE_RENDERER); - } - public ExpressionAssert(QueryRunner runner, Session session, Object actual, Type actualType) { super(actual, Object.class); @@ -613,6 +729,20 @@ public ExpressionAssert matches(@Language("SQL") String expression) }); } + /** + * Syntactic sugar for: + * + *
{@code
+         *     assertThat(...)
+         *         .hasType(type)
+         *         .isNull()
+         * }
+ */ + public void isNull(Type type) + { + hasType(type).isNull(); + } + public ExpressionAssert hasType(Type type) { objects.assertEqual(info, actualType, type);