diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/RowExpressionFormatter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/RowExpressionFormatter.java index 7f72119911e95..31104dcff1f0b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/RowExpressionFormatter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/planPrinter/RowExpressionFormatter.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarbinaryType; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.function.FunctionMetadata; @@ -79,8 +80,25 @@ else if (standardFunctionResolution.isSubscriptFunction(node.getFunctionHandle() return formatRowExpression(session, node.getArguments().get(0)) + "[" + formatRowExpression(session, node.getArguments().get(1)) + "]"; } else if (standardFunctionResolution.isBetweenFunction(node.getFunctionHandle())) { - List formattedExpresions = formatRowExpressions(session, node.getArguments()); - return String.format("%s BETWEEN (%s) AND (%s)", formattedExpresions.get(0), formattedExpresions.get(1), formattedExpresions.get(2)); + List formattedExpressions = formatRowExpressions(session, node.getArguments()); + return String.format("%s BETWEEN (%s) AND (%s)", formattedExpressions.get(0), formattedExpressions.get(1), formattedExpressions.get(2)); + } + else if (standardFunctionResolution.isLikeFunction(node.getFunctionHandle())) { + RowExpression value = node.getArguments().get(0); + CallExpression patternCallExpression = (CallExpression) node.getArguments().get(1); + + // second LIKE argument is: + // CAST(pattern as LikePattern), if escape is not present + // LIKE_PATTERN(pattern, escape), if escape is present + if (standardFunctionResolution.isCastFunction(patternCallExpression.getFunctionHandle())) { + RowExpression pattern = patternCallExpression.getArguments().get(0); + return String.format("%s LIKE %s", formatRowExpression(session, value), formatRowExpression(session, pattern)); + } + else { + RowExpression pattern = patternCallExpression.getArguments().get(0); + RowExpression escape = patternCallExpression.getArguments().get(1); + return String.format("%s LIKE %s ESCAPE %s", formatRowExpression(session, value), formatRowExpression(session, pattern), formatRowExpression(session, escape)); + } } FunctionMetadata metadata = functionMetadataManager.getFunctionMetadata(node.getFunctionHandle()); return node.getDisplayName() + (metadata.getVersion().hasVersion() ? ":" + metadata.getVersion() : "") + "(" + String.join(", ", formatRowExpressions(session, node.getArguments())) + ")"; @@ -123,12 +141,19 @@ public String visitConstant(ConstantExpression node, ConnectorSession session) } Type type = node.getType(); - if (node.getType().getJavaType() == Block.class) { + if (type.getJavaType() == Block.class) { Block block = (Block) value; // TODO: format block return format("[Block: position count: %s; size: %s bytes]", block.getPositionCount(), block.getRetainedSizeInBytes()); } - return type.getDisplayName().toUpperCase() + " " + value.toString(); + + String valueString = "'" + value.toString().replace("'", "''") + "'"; + + if (VarbinaryType.isVarbinaryType(type)) { + return "X" + valueString; + } + + return type.getTypeSignature().getBase().toUpperCase() + valueString; } } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionFormatter.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionFormatter.java index 0e26f7994f939..d6f52b2ab0d17 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionFormatter.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionFormatter.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.planPrinter.RowExpressionFormatter; +import com.facebook.presto.sql.relational.FunctionResolution; import com.google.common.collect.ImmutableList; import com.google.common.io.BaseEncoding; import io.airlift.slice.Slice; @@ -74,15 +75,19 @@ import static com.facebook.presto.type.ColorType.COLOR; import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static com.facebook.presto.type.IntervalYearMonthType.INTERVAL_YEAR_MONTH; +import static com.facebook.presto.type.LikePatternType.LIKE_PATTERN; +import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.floatToIntBits; import static org.testng.Assert.assertEquals; public class TestRowExpressionFormatter { private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = createTestFunctionAndTypeManager(); + private static final FunctionResolution FUNCTION_RESOLUTION = new FunctionResolution(FUNCTION_AND_TYPE_MANAGER); private static final RowExpressionFormatter FORMATTER = new RowExpressionFormatter(FUNCTION_AND_TYPE_MANAGER); private static final VariableReferenceExpression C_BIGINT = new VariableReferenceExpression("c_bigint", BIGINT); private static final VariableReferenceExpression C_BIGINT_ARRAY = new VariableReferenceExpression("c_bigint_array", new ArrayType(BIGINT)); + private static final VariableReferenceExpression C_VARCHAR = new VariableReferenceExpression("c_varchar", VARCHAR); @Test public void testConstants() @@ -93,66 +98,66 @@ public void testConstants() // boolean constantExpression = constant(true, BOOLEAN); - assertEquals(format(constantExpression), "BOOLEAN true"); + assertEquals(format(constantExpression), "BOOLEAN'true'"); // double constantExpression = constant(1.1, DOUBLE); - assertEquals(format(constantExpression), "DOUBLE 1.1"); + assertEquals(format(constantExpression), "DOUBLE'1.1'"); constantExpression = constant(Double.NaN, DOUBLE); - assertEquals(format(constantExpression), "DOUBLE NaN"); + assertEquals(format(constantExpression), "DOUBLE'NaN'"); constantExpression = constant(Double.POSITIVE_INFINITY, DOUBLE); - assertEquals(format(constantExpression), "DOUBLE Infinity"); + assertEquals(format(constantExpression), "DOUBLE'Infinity'"); // real constantExpression = constant((long) floatToIntBits(1.1f), REAL); - assertEquals(format(constantExpression), "REAL 1.1"); + assertEquals(format(constantExpression), "REAL'1.1'"); constantExpression = constant((long) floatToIntBits(Float.NaN), REAL); - assertEquals(format(constantExpression), "REAL NaN"); + assertEquals(format(constantExpression), "REAL'NaN'"); constantExpression = constant((long) floatToIntBits(Float.POSITIVE_INFINITY), REAL); - assertEquals(format(constantExpression), "REAL Infinity"); + assertEquals(format(constantExpression), "REAL'Infinity'"); // string - constantExpression = constant(Slices.utf8Slice("abcde"), VARCHAR); - assertEquals(format(constantExpression), "VARCHAR abcde"); - constantExpression = constant(Slices.utf8Slice("fgh"), createCharType(3)); - assertEquals(format(constantExpression), "CHAR(3) fgh"); + constantExpression = constant(utf8Slice("abcde"), VARCHAR); + assertEquals(format(constantExpression), "VARCHAR'abcde'"); + constantExpression = constant(utf8Slice("fgh"), createCharType(3)); + assertEquals(format(constantExpression), "CHAR'fgh'"); // integer constantExpression = constant(1L, TINYINT); - assertEquals(format(constantExpression), "TINYINT 1"); + assertEquals(format(constantExpression), "TINYINT'1'"); constantExpression = constant(1L, SMALLINT); - assertEquals(format(constantExpression), "SMALLINT 1"); + assertEquals(format(constantExpression), "SMALLINT'1'"); constantExpression = constant(1L, INTEGER); - assertEquals(format(constantExpression), "INTEGER 1"); + assertEquals(format(constantExpression), "INTEGER'1'"); constantExpression = constant(1L, BIGINT); - assertEquals(format(constantExpression), "BIGINT 1"); + assertEquals(format(constantExpression), "BIGINT'1'"); // varbinary - Slice value = Slices.wrappedBuffer(BaseEncoding.base16().decode("123456")); + Slice value = Slices.wrappedBuffer(BaseEncoding.base16().decode("123456AB")); constantExpression = constant(value, VARBINARY); - assertEquals(format(constantExpression), "VARBINARY 12 34 56"); + assertEquals(format(constantExpression), "X'12 34 56 ab'"); // color constantExpression = constant(256L, COLOR); - assertEquals(format(constantExpression), "COLOR 256"); + assertEquals(format(constantExpression), "COLOR'256'"); // long and short decimals constantExpression = constant(decimal("1.2345678910"), DecimalType.createDecimalType(11, 10)); - assertEquals(format(constantExpression), "DECIMAL(11,10) 1.2345678910"); + assertEquals(format(constantExpression), "DECIMAL'1.2345678910'"); constantExpression = constant(decimal("1.281734081274028174012432412423134"), DecimalType.createDecimalType(34, 33)); - assertEquals(format(constantExpression), "DECIMAL(34,33) 1.281734081274028174012432412423134"); + assertEquals(format(constantExpression), "DECIMAL'1.281734081274028174012432412423134'"); // time constantExpression = constant(662727600000L, TIMESTAMP); - assertEquals(format(constantExpression), "TIMESTAMP 1991-01-01 00:00:00.000"); + assertEquals(format(constantExpression), "TIMESTAMP'1991-01-01 00:00:00.000'"); constantExpression = constant(7670L, DATE); - assertEquals(format(constantExpression), "DATE 1991-01-01"); + assertEquals(format(constantExpression), "DATE'1991-01-01'"); // interval constantExpression = constant(24L, INTERVAL_DAY_TIME); - assertEquals(format(constantExpression), "INTERVAL DAY TO SECOND 0 00:00:00.024"); + assertEquals(format(constantExpression), "INTERVAL DAY TO SECOND'0 00:00:00.024'"); constantExpression = constant(25L, INTERVAL_YEAR_MONTH); - assertEquals(format(constantExpression), "INTERVAL YEAR TO MONTH 2-1"); + assertEquals(format(constantExpression), "INTERVAL YEAR TO MONTH'2-1'"); // block constantExpression = constant(new LongArrayBlockBuilder(null, 4).writeLong(1L).writeLong(2).build(), new ArrayType(BIGINT)); @@ -166,31 +171,31 @@ public void testCalls() // arithmetic callExpression = createCallExpression(ADD); - assertEquals(format(callExpression), "(c_bigint) + (BIGINT 5)"); + assertEquals(format(callExpression), "(c_bigint) + (BIGINT'5')"); callExpression = createCallExpression(SUBTRACT); - assertEquals(format(callExpression), "(c_bigint) - (BIGINT 5)"); + assertEquals(format(callExpression), "(c_bigint) - (BIGINT'5')"); callExpression = createCallExpression(MULTIPLY); - assertEquals(format(callExpression), "(c_bigint) * (BIGINT 5)"); + assertEquals(format(callExpression), "(c_bigint) * (BIGINT'5')"); callExpression = createCallExpression(DIVIDE); - assertEquals(format(callExpression), "(c_bigint) / (BIGINT 5)"); + assertEquals(format(callExpression), "(c_bigint) / (BIGINT'5')"); callExpression = createCallExpression(MODULUS); - assertEquals(format(callExpression), "(c_bigint) % (BIGINT 5)"); + assertEquals(format(callExpression), "(c_bigint) % (BIGINT'5')"); // comparison callExpression = createCallExpression(GREATER_THAN); - assertEquals(format(callExpression), "(c_bigint) > (BIGINT 5)"); + assertEquals(format(callExpression), "(c_bigint) > (BIGINT'5')"); callExpression = createCallExpression(LESS_THAN); - assertEquals(format(callExpression), "(c_bigint) < (BIGINT 5)"); + assertEquals(format(callExpression), "(c_bigint) < (BIGINT'5')"); callExpression = createCallExpression(GREATER_THAN_OR_EQUAL); - assertEquals(format(callExpression), "(c_bigint) >= (BIGINT 5)"); + assertEquals(format(callExpression), "(c_bigint) >= (BIGINT'5')"); callExpression = createCallExpression(LESS_THAN_OR_EQUAL); - assertEquals(format(callExpression), "(c_bigint) <= (BIGINT 5)"); + assertEquals(format(callExpression), "(c_bigint) <= (BIGINT'5')"); callExpression = createCallExpression(EQUAL); - assertEquals(format(callExpression), "(c_bigint) = (BIGINT 5)"); + assertEquals(format(callExpression), "(c_bigint) = (BIGINT'5')"); callExpression = createCallExpression(NOT_EQUAL); - assertEquals(format(callExpression), "(c_bigint) <> (BIGINT 5)"); + assertEquals(format(callExpression), "(c_bigint) <> (BIGINT'5')"); callExpression = createCallExpression(IS_DISTINCT_FROM); - assertEquals(format(callExpression), "(c_bigint) IS DISTINCT FROM (BIGINT 5)"); + assertEquals(format(callExpression), "(c_bigint) IS DISTINCT FROM (BIGINT'5')"); // negation RowExpression expression = createCallExpression(ADD); @@ -199,7 +204,7 @@ public void testCalls() FUNCTION_AND_TYPE_MANAGER.resolveOperator(NEGATION, fromTypes(expression.getType())), expression.getType(), expression); - assertEquals(format(callExpression), "-((c_bigint) + (BIGINT 5))"); + assertEquals(format(callExpression), "-((c_bigint) + (BIGINT'5'))"); // subscript ArrayType arrayType = (ArrayType) C_BIGINT_ARRAY.getType(); @@ -209,15 +214,15 @@ public void testCalls() elementType, ImmutableList.of(C_BIGINT_ARRAY, constant(0L, INTEGER))); callExpression = subscriptExpression; - assertEquals(format(callExpression), "c_bigint_array[INTEGER 0]"); + assertEquals(format(callExpression), "c_bigint_array[INTEGER'0']"); // cast callExpression = call( - CAST.name(), - FUNCTION_AND_TYPE_MANAGER.lookupCast(CastType.CAST, TINYINT.getTypeSignature(), BIGINT.getTypeSignature()), - BIGINT, - constant(1L, TINYINT)); - assertEquals(format(callExpression), "CAST(TINYINT 1 AS bigint)"); + CAST.name(), + FUNCTION_AND_TYPE_MANAGER.lookupCast(CastType.CAST, TINYINT.getTypeSignature(), BIGINT.getTypeSignature()), + BIGINT, + constant(1L, TINYINT)); + assertEquals(format(callExpression), "CAST(TINYINT'1' AS bigint)"); // between callExpression = call( @@ -227,7 +232,7 @@ public void testCalls() subscriptExpression, constant(1L, BIGINT), constant(5L, BIGINT)); - assertEquals(format(callExpression), "c_bigint_array[INTEGER 0] BETWEEN (BIGINT 1) AND (BIGINT 5)"); + assertEquals(format(callExpression), "c_bigint_array[INTEGER'0'] BETWEEN (BIGINT'1') AND (BIGINT'5')"); // other callExpression = call( @@ -235,7 +240,34 @@ public void testCalls() FUNCTION_AND_TYPE_MANAGER.resolveOperator(HASH_CODE, fromTypes(BIGINT)), BIGINT, constant(1L, BIGINT)); - assertEquals(format(callExpression), "HASH_CODE(BIGINT 1)"); + assertEquals(format(callExpression), "HASH_CODE(BIGINT'1')"); + + // like + callExpression = call( + "LIKE", + FUNCTION_RESOLUTION.likeVarcharFunction(), + BOOLEAN, + C_VARCHAR, + call( + CAST.name(), + FUNCTION_AND_TYPE_MANAGER.lookupCast(CastType.CAST, VARCHAR.getTypeSignature(), LIKE_PATTERN.getTypeSignature()), + LIKE_PATTERN, + constant(utf8Slice("prefix%"), VARCHAR))); + assertEquals(format(callExpression), "c_varchar LIKE VARCHAR'prefix%'"); + + // like escape + callExpression = call( + "LIKE", + FUNCTION_RESOLUTION.likeVarcharFunction(), + BOOLEAN, + C_VARCHAR, + call( + "LIKE_PATTERN", + FUNCTION_RESOLUTION.likePatternFunction(), + LIKE_PATTERN, + constant(utf8Slice("%escaped$_"), VARCHAR), + constant(utf8Slice("$"), VARCHAR))); + assertEquals(format(callExpression), "c_varchar LIKE VARCHAR'%escaped$_' ESCAPE VARCHAR'$'"); } @Test @@ -245,13 +277,13 @@ public void testSpecialForm() // or and and specialFormExpression = new SpecialFormExpression(OR, BOOLEAN, createCallExpression(NOT_EQUAL), createCallExpression(IS_DISTINCT_FROM)); - assertEquals(format(specialFormExpression), "((c_bigint) <> (BIGINT 5)) OR ((c_bigint) IS DISTINCT FROM (BIGINT 5))"); + assertEquals(format(specialFormExpression), "((c_bigint) <> (BIGINT'5')) OR ((c_bigint) IS DISTINCT FROM (BIGINT'5'))"); specialFormExpression = new SpecialFormExpression(AND, BOOLEAN, createCallExpression(EQUAL), createCallExpression(GREATER_THAN)); - assertEquals(format(specialFormExpression), "((c_bigint) = (BIGINT 5)) AND ((c_bigint) > (BIGINT 5))"); + assertEquals(format(specialFormExpression), "((c_bigint) = (BIGINT'5')) AND ((c_bigint) > (BIGINT'5'))"); // other specialFormExpression = new SpecialFormExpression(IS_NULL, BOOLEAN, createCallExpression(ADD)); - assertEquals(format(specialFormExpression), "IS_NULL((c_bigint) + (BIGINT 5))"); + assertEquals(format(specialFormExpression), "IS_NULL((c_bigint) + (BIGINT'5'))"); } @Test @@ -266,7 +298,7 @@ public void testComplex() BIGINT, C_BIGINT, expression); - assertEquals(format(complexExpression), "(c_bigint) - ((c_bigint) + (BIGINT 5))"); + assertEquals(format(complexExpression), "(c_bigint) - ((c_bigint) + (BIGINT'5'))"); RowExpression expression1 = createCallExpression(ADD); RowExpression expression2 = call( @@ -277,7 +309,7 @@ public void testComplex() C_BIGINT); RowExpression expression3 = createCallExpression(GREATER_THAN); complexExpression = new SpecialFormExpression(OR, BOOLEAN, expression2, expression3); - assertEquals(format(complexExpression), "(((c_bigint) + (BIGINT 5)) * (c_bigint)) OR ((c_bigint) > (BIGINT 5))"); + assertEquals(format(complexExpression), "(((c_bigint) + (BIGINT'5')) * (c_bigint)) OR ((c_bigint) > (BIGINT'5'))"); ArrayType arrayType = (ArrayType) C_BIGINT_ARRAY.getType(); Type elementType = arrayType.getElementType(); @@ -296,7 +328,7 @@ public void testComplex() BIGINT, expression2, constant(5L, BIGINT)); - assertEquals(format(expression3), "(-(c_bigint_array[INTEGER 5])) + (BIGINT 5)"); + assertEquals(format(expression3), "(-(c_bigint_array[INTEGER'5'])) + (BIGINT'5')"); } protected static Object decimal(String decimalString) diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java index c87f27cde16ff..8cfb0a05d6788 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/StandardFunctionResolution.java @@ -32,6 +32,8 @@ public interface StandardFunctionResolution FunctionHandle likeCharFunction(Type valueType); + boolean isLikeFunction(FunctionHandle functionHandle); + FunctionHandle arrayConstructor(List argumentTypes); FunctionHandle arithmeticFunction(OperatorType operator, Type leftType, Type rightType);