From 3449a8ca5e4dfb23a3f50e0baa306c57a8c897c0 Mon Sep 17 00:00:00 2001 From: Natasha Sehgal Date: Mon, 27 Jan 2025 12:05:38 -0800 Subject: [PATCH] [native] Fix varchar cast for json (#24396) Summary: https://www.internalfb.com/tasks/?t=211442303 There was an error in running query on Prestissimo not Presto - "Scalar function presto.default.substr not registered with arguments: (JSON, BIGINT, BIGINT)". This is not due to missing function, as the function signature does not exist in Presto. It occurs when attempting to cast JSON as varchar of capped length. Related Diff: https://www.internalfb.com/diff/D59531026 Note: Exception is still raised for try_cast() behavior. Alignment is out of scope for this PR Differential Revision: D68353517 --- .../main/types/PrestoToVeloxExpr.cpp | 11 +++++--- .../main/types/tests/RowExpressionTest.cpp | 25 +++++++++++++++++++ .../AbstractTestNativeGeneralQueries.java | 1 + 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index 2c2b2a3c5ea00..ba552c8ac1e78 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -181,8 +181,13 @@ std::optional convertCastToVarcharWithMaxLength( VELOX_DCHECK(end == returnType.data() + returnType.size() - 1); VELOX_DCHECK_EQ(args.size(), 1); - const auto arg = args[0]; + auto arg = args[0]; + // If the argument is of JSON type, convert it to VARCHAR before applying + // substr. + if (velox::isJsonType(arg->type())) { + arg = std::make_shared(velox::VARCHAR(), arg, false); + } return std::make_shared( arg->type(), std::vector{ @@ -256,8 +261,8 @@ std::optional tryConvertCast( } // When the return type is varchar with max length, truncate if only the - // argument type is varchar (or varchar with max length). Non-varchar argument - // types are not truncated. + // argument type is varchar, or varchar with max length or json. Non-varchar + // argument types are not truncated. if (returnType.find(kVarchar) == 0 && args[0]->type()->kind() == TypeKind::VARCHAR && returnType.size() > strlen(kVarchar)) { diff --git a/presto-native-execution/presto_cpp/main/types/tests/RowExpressionTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/RowExpressionTest.cpp index 176b4e1fff7cf..b0a078c593e55 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/RowExpressionTest.cpp +++ b/presto-native-execution/presto_cpp/main/types/tests/RowExpressionTest.cpp @@ -18,6 +18,7 @@ #include "presto_cpp/presto_protocol/core/presto_protocol_core.h" #include "velox/core/Expressions.h" #include "velox/type/Type.h" +#include "velox/functions/prestosql/types/JsonType.h" using namespace facebook::presto; using namespace facebook::velox; @@ -30,6 +31,7 @@ class RowExpressionTest : public ::testing::Test { } void SetUp() override { + registerJsonType(); pool_ = memory::MemoryManager::getInstance()->addLeafPool(); converter_ = std::make_unique(pool_.get(), &typeParser_); @@ -626,6 +628,29 @@ TEST_F(RowExpressionTest, castToVarchar) { ASSERT_TRUE(returnExpr->nullOnFailure()); ASSERT_EQ(returnExpr->type()->toString(), "VARCHAR"); } + // CAST(json AS varchar(3)) + { + std::shared_ptr p = + json::parse(makeCastToVarchar(false, "json", "varchar(3)")); + auto expr = converter_->toVeloxExpr(p); + auto returnExpr = std::dynamic_pointer_cast(expr); + + ASSERT_NE(returnExpr, nullptr); + ASSERT_EQ(returnExpr->name(), "presto.default.substr"); + + auto returnArg1 = std::dynamic_pointer_cast( + returnExpr->inputs()[0]); + auto returnArg2 = std::dynamic_pointer_cast( + returnExpr->inputs()[1]); + auto returnArg3 = std::dynamic_pointer_cast( + returnExpr->inputs()[2]); + + ASSERT_EQ(returnArg1->type()->toString(), "VARCHAR"); + ASSERT_EQ(returnArg2->type()->toString(), "BIGINT"); + ASSERT_EQ(returnArg2->value().toJson(returnArg2->type()), "1"); + ASSERT_EQ(returnArg3->type()->toString(), "BIGINT"); + ASSERT_EQ(returnArg3->value().toJson(returnArg3->type()), "3"); + } } TEST_F(RowExpressionTest, special) { diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java index 5717ef30926f3..ac82d3cf2733e 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java @@ -510,6 +510,7 @@ public void testCast() // Round-trip tests of casts for Json. assertQuery("SELECT cast(cast(name as JSON) as VARCHAR), cast(cast(size as JSON) as INTEGER), cast(cast(size + 0.01 as JSON) as DOUBLE), cast(cast(size > 5 as JSON) as BOOLEAN) FROM part"); assertQuery("SELECT cast(cast(array[suppkey, nationkey] as JSON) as ARRAY(INTEGER)), cast(cast(map(array[name, address, phone], array[1.1, 2.2, 3.3]) as JSON) as MAP(VARCHAR(40), DOUBLE)), cast(cast(map(array[name], array[phone]) as JSON) as MAP(VARCHAR(25), JSON)), cast(cast(array[array[suppkey], array[nationkey]] as JSON) as ARRAY(JSON)) from supplier"); + assertQuery("SELECT cast(json_extract(x, '$.a') AS varchar(255)) AS extracted_value FROM (VALUES ('{\"a\": \"Some long string\"}')) AS t(x)"); // Cast from date to timestamp assertQuery("SELECT CAST(date(shipdate) AS timestamp) FROM lineitem");