diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index 096332d17159..73453ec31a5e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -69,6 +69,9 @@ import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.metadata.GlobalFunctionCatalog.isBuiltinFunctionName; import static io.trino.metadata.LanguageFunctionManager.isInlineFunction; +import static io.trino.operator.scalar.JsonStringToArrayCast.JSON_STRING_TO_ARRAY_NAME; +import static io.trino.operator.scalar.JsonStringToMapCast.JSON_STRING_TO_MAP_NAME; +import static io.trino.operator.scalar.JsonStringToRowCast.JSON_STRING_TO_ROW_NAME; import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME; @@ -292,9 +295,18 @@ protected Optional translateCall(io.trino.spi.expression.Call call) return translateInPredicate(call.getArguments().get(0), call.getArguments().get(1)); } - ResolvedFunction resolved = plannerContext.getMetadata().resolveBuiltinFunction( - call.getFunctionName().getName(), - fromTypes(call.getArguments().stream().map(ConnectorExpression::getType).collect(toImmutableList()))); + ResolvedFunction resolved; + if (JSON_STRING_TO_MAP_NAME.equals(call.getFunctionName().getName()) || + JSON_STRING_TO_ARRAY_NAME.equals(call.getFunctionName().getName()) || + JSON_STRING_TO_ROW_NAME.equals(call.getFunctionName().getName())) { + // These are special functions that currently need to be resolved via getCoercion() -- TODO: fix this + resolved = plannerContext.getMetadata().getCoercion(builtinFunctionName(call.getFunctionName().getName()), call.getArguments().get(0).getType(), call.getType()); + } + else { + resolved = plannerContext.getMetadata().resolveBuiltinFunction( + call.getFunctionName().getName(), + fromTypes(call.getArguments().stream().map(ConnectorExpression::getType).collect(toImmutableList()))); + } return translateCall(call.getFunctionName().getName(), resolved, call.getArguments()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index bea838cb935f..054f504d9a8e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -30,6 +30,8 @@ import io.trino.spi.expression.Variable; import io.trino.spi.function.OperatorType; import io.trino.spi.type.ArrayType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import io.trino.sql.ir.Between; @@ -59,6 +61,9 @@ import static io.airlift.slice.Slices.utf8Slice; import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName; import static io.trino.operator.scalar.JoniRegexpCasts.joniRegexp; +import static io.trino.operator.scalar.JsonStringToArrayCast.JSON_STRING_TO_ARRAY_NAME; +import static io.trino.operator.scalar.JsonStringToMapCast.JSON_STRING_TO_MAP_NAME; +import static io.trino.operator.scalar.JsonStringToRowCast.JSON_STRING_TO_ROW_NAME; import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME; @@ -95,6 +100,7 @@ import static io.trino.sql.planner.ConnectorExpressionTranslator.translate; import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT; import static io.trino.testing.TransactionBuilder.transaction; +import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; import static io.trino.type.JoniRegexpType.JONI_REGEXP; import static io.trino.type.JsonPathType.JSON_PATH; import static io.trino.type.LikeFunctions.likePattern; @@ -510,6 +516,46 @@ public void testTranslateIn() new io.trino.spi.expression.Constant(Slices.wrappedBuffer(value.getBytes(UTF_8)), VARCHAR_TYPE)))))); } + @Test + public void testTranslateCastPlusJsonParse() + { + TransactionManager transactionManager = new TestingTransactionManager(); + Metadata metadata = MetadataManager.testMetadataManagerBuilder().withTransactionManager(transactionManager).build(); + transaction(transactionManager, metadata, new AllowAllAccessControl()) + .readOnly() + .execute(TEST_SESSION, transactionSession -> { + assertTranslationRoundTrips( + transactionSession, + new Call( + PLANNER_CONTEXT.getMetadata().getCoercion(builtinFunctionName(JSON_STRING_TO_ARRAY_NAME), VARCHAR, new ArrayType(VARCHAR_TYPE)), + ImmutableList.of(new Reference(VARCHAR_TYPE, "varchar_symbol_1"))), + new io.trino.spi.expression.Call( + new ArrayType(VARCHAR_TYPE), + new FunctionName(JSON_STRING_TO_ARRAY_NAME), + List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE)))); + + assertTranslationRoundTrips( + transactionSession, + new Call( + PLANNER_CONTEXT.getMetadata().getCoercion(builtinFunctionName(JSON_STRING_TO_MAP_NAME), VARCHAR, new MapType(VARCHAR_TYPE, VARCHAR_TYPE, TESTING_TYPE_MANAGER.getTypeOperators())), + ImmutableList.of(new Reference(VARCHAR_TYPE, "varchar_symbol_1"))), + new io.trino.spi.expression.Call( + new MapType(VARCHAR_TYPE, VARCHAR_TYPE, TESTING_TYPE_MANAGER.getTypeOperators()), + new FunctionName(JSON_STRING_TO_MAP_NAME), + List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE)))); + + assertTranslationRoundTrips( + transactionSession, + new Call( + PLANNER_CONTEXT.getMetadata().getCoercion(builtinFunctionName(JSON_STRING_TO_ROW_NAME), VARCHAR, RowType.anonymousRow(VARCHAR)), + ImmutableList.of(new Reference(VARCHAR_TYPE, "varchar_symbol_1"))), + new io.trino.spi.expression.Call( + RowType.anonymousRow(VARCHAR), + new FunctionName(JSON_STRING_TO_ROW_NAME), + List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE)))); + }); + } + private void assertTranslationRoundTrips(Expression expression, ConnectorExpression connectorExpression) { assertTranslationRoundTrips(TEST_SESSION, expression, connectorExpression);