Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.metadata.GlobalFunctionCatalog.isBuiltinFunctionName;
import static io.trino.metadata.LanguageFunctionManager.isInlineFunction;
import static io.trino.operator.scalar.JsonStringToArrayCast.JSON_STRING_TO_ARRAY_NAME;
import static io.trino.operator.scalar.JsonStringToMapCast.JSON_STRING_TO_MAP_NAME;
import static io.trino.operator.scalar.JsonStringToRowCast.JSON_STRING_TO_ROW_NAME;
import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME;
Expand Down Expand Up @@ -292,9 +295,18 @@ protected Optional<Expression> translateCall(io.trino.spi.expression.Call call)
return translateInPredicate(call.getArguments().get(0), call.getArguments().get(1));
}

ResolvedFunction resolved = plannerContext.getMetadata().resolveBuiltinFunction(
call.getFunctionName().getName(),
fromTypes(call.getArguments().stream().map(ConnectorExpression::getType).collect(toImmutableList())));
ResolvedFunction resolved;
if (JSON_STRING_TO_MAP_NAME.equals(call.getFunctionName().getName()) ||
JSON_STRING_TO_ARRAY_NAME.equals(call.getFunctionName().getName()) ||
JSON_STRING_TO_ROW_NAME.equals(call.getFunctionName().getName())) {
// These are special functions that currently need to be resolved via getCoercion() -- TODO: fix this
resolved = plannerContext.getMetadata().getCoercion(builtinFunctionName(call.getFunctionName().getName()), call.getArguments().get(0).getType(), call.getType());
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't getCoercion API is a bit tricky - The internal logic which we do to resolving the function - but the API looks like we are applying coercion from one type to another

}
else {
resolved = plannerContext.getMetadata().resolveBuiltinFunction(
call.getFunctionName().getName(),
fromTypes(call.getArguments().stream().map(ConnectorExpression::getType).collect(toImmutableList())));
}

return translateCall(call.getFunctionName().getName(), resolved, call.getArguments());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import io.trino.spi.expression.Variable;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.ir.Between;
Expand Down Expand Up @@ -59,6 +61,9 @@
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.operator.scalar.JoniRegexpCasts.joniRegexp;
import static io.trino.operator.scalar.JsonStringToArrayCast.JSON_STRING_TO_ARRAY_NAME;
import static io.trino.operator.scalar.JsonStringToMapCast.JSON_STRING_TO_MAP_NAME;
import static io.trino.operator.scalar.JsonStringToRowCast.JSON_STRING_TO_ROW_NAME;
import static io.trino.spi.expression.StandardFunctions.ADD_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME;
import static io.trino.spi.expression.StandardFunctions.ARRAY_CONSTRUCTOR_FUNCTION_NAME;
Expand Down Expand Up @@ -95,6 +100,7 @@
import static io.trino.sql.planner.ConnectorExpressionTranslator.translate;
import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT;
import static io.trino.testing.TransactionBuilder.transaction;
import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
import static io.trino.type.JoniRegexpType.JONI_REGEXP;
import static io.trino.type.JsonPathType.JSON_PATH;
import static io.trino.type.LikeFunctions.likePattern;
Expand Down Expand Up @@ -510,6 +516,46 @@ public void testTranslateIn()
new io.trino.spi.expression.Constant(Slices.wrappedBuffer(value.getBytes(UTF_8)), VARCHAR_TYPE))))));
}

@Test
public void testTranslateCastPlusJsonParse()
{
TransactionManager transactionManager = new TestingTransactionManager();
Metadata metadata = MetadataManager.testMetadataManagerBuilder().withTransactionManager(transactionManager).build();
transaction(transactionManager, metadata, new AllowAllAccessControl())
.readOnly()
.execute(TEST_SESSION, transactionSession -> {
assertTranslationRoundTrips(
transactionSession,
new Call(
PLANNER_CONTEXT.getMetadata().getCoercion(builtinFunctionName(JSON_STRING_TO_ARRAY_NAME), VARCHAR, new ArrayType(VARCHAR_TYPE)),
ImmutableList.of(new Reference(VARCHAR_TYPE, "varchar_symbol_1"))),
new io.trino.spi.expression.Call(
new ArrayType(VARCHAR_TYPE),
new FunctionName(JSON_STRING_TO_ARRAY_NAME),
List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE))));

assertTranslationRoundTrips(
transactionSession,
new Call(
PLANNER_CONTEXT.getMetadata().getCoercion(builtinFunctionName(JSON_STRING_TO_MAP_NAME), VARCHAR, new MapType(VARCHAR_TYPE, VARCHAR_TYPE, TESTING_TYPE_MANAGER.getTypeOperators())),
ImmutableList.of(new Reference(VARCHAR_TYPE, "varchar_symbol_1"))),
new io.trino.spi.expression.Call(
new MapType(VARCHAR_TYPE, VARCHAR_TYPE, TESTING_TYPE_MANAGER.getTypeOperators()),
new FunctionName(JSON_STRING_TO_MAP_NAME),
List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE))));

assertTranslationRoundTrips(
transactionSession,
new Call(
PLANNER_CONTEXT.getMetadata().getCoercion(builtinFunctionName(JSON_STRING_TO_ROW_NAME), VARCHAR, RowType.anonymousRow(VARCHAR)),
ImmutableList.of(new Reference(VARCHAR_TYPE, "varchar_symbol_1"))),
new io.trino.spi.expression.Call(
RowType.anonymousRow(VARCHAR),
new FunctionName(JSON_STRING_TO_ROW_NAME),
List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE))));
});
}

private void assertTranslationRoundTrips(Expression expression, ConnectorExpression connectorExpression)
{
assertTranslationRoundTrips(TEST_SESSION, expression, connectorExpression);
Expand Down