Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -841,44 +841,47 @@ protected RowExpression visitNullIfExpression(NullIfExpression node, Context con
{
RowExpression first = process(node.getFirst(), context);
RowExpression second = process(node.getSecond(), context);
Type returnType = getType(node);

if (isNative && !second.getType().equals(first.getType())) {
Optional<Type> commonType = functionAndTypeResolver.getCommonSuperType(first.getType(), second.getType());
if (!commonType.isPresent()) {
throw new SemanticException(TYPE_MISMATCH, node, "Types are not comparable with NULLIF: %s vs %s", first.getType(), second.getType());
}

Type returnType = getType(node);
if (isNative) {
// If the first type is unknown, as per presto's NULL_IF semantics we should not infer the type using second argument.
// Always return a null with unknown type.
if (first.getType().equals(UnknownType.UNKNOWN)) {
return constantNull(UnknownType.UNKNOWN);
}
RowExpression originalFirst = first;
// cast(first as <common type>)
if (!first.getType().equals(commonType.get())) {
first = call(
getSourceLocation(node),
CAST.name(),
functionAndTypeResolver.lookupCast(CAST.name(), first.getType(), commonType.get()),
commonType.get(), first);
}
// cast(second as <common type>)
if (!second.getType().equals(commonType.get())) {
second = call(
getSourceLocation(node),
CAST.name(),
functionAndTypeResolver.lookupCast(CAST.name(), second.getType(), commonType.get()),
commonType.get(), second);
RowExpression firstArgWithoutCast = first;

if (!second.getType().equals(first.getType())) {
Optional<Type> commonType = functionAndTypeResolver.getCommonSuperType(first.getType(), second.getType());
if (!commonType.isPresent()) {
throw new SemanticException(TYPE_MISMATCH, node, "Types are not comparable with NULLIF: %s vs %s", first.getType(), second.getType());
}

// cast(first as <common type>)
if (!first.getType().equals(commonType.get())) {
first = call(
getSourceLocation(node),
CAST.name(),
functionAndTypeResolver.lookupCast(CAST.name(), first.getType(), commonType.get()),
commonType.get(), first);
}
// cast(second as <common type>)
if (!second.getType().equals(commonType.get())) {
second = call(
getSourceLocation(node),
CAST.name(),
functionAndTypeResolver.lookupCast(CAST.name(), second.getType(), commonType.get()),
commonType.get(), second);
}
}
FunctionHandle equalsFunctionHandle = functionAndTypeResolver.resolveOperator(EQUAL, fromTypes(first.getType(), second.getType()));
// equal(cast(first as <common type>), cast(second as <common type>))
RowExpression equal = call(EQUAL.name(), equalsFunctionHandle, BOOLEAN, first, second);

// if (equal(cast(first as <common type>), cast(second as <common type>)), cast(null as firstType), first)
return specialForm(IF, returnType, equal, constantNull(originalFirst.getType()), originalFirst);
return specialForm(IF, returnType, equal, constantNull(returnType), firstArgWithoutCast);
}
return specialForm(getSourceLocation(node), NULL_IF, getType(node), first, second);
return specialForm(getSourceLocation(node), NULL_IF, returnType, first, second);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,22 +611,6 @@ TypedExprPtr convertDereferenceExpr(

return std::make_shared<FieldAccessTypedExpr>(returnType, input, childName);
}

TypedExprPtr convertNullIfExpr(
const velox::TypePtr& returnType,
const std::vector<TypedExprPtr>& args) {
VELOX_CHECK_EQ(args.size(), 2);

// Convert nullif(a, b) to if(a = b, null, a).

std::vector<TypedExprPtr> newArgs = {
std::make_shared<CallTypedExpr>(
velox::BOOLEAN(), args, "presto.default.eq"),
std::make_shared<ConstantTypedExpr>(
returnType, velox::variant::null(returnType->kind())),
args[0]};
return std::make_shared<CallTypedExpr>(returnType, newArgs, "if");
}
} // namespace

TypedExprPtr VeloxExprConverter::toVeloxExpr(
Expand Down Expand Up @@ -657,7 +641,7 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr(
}

if (pexpr->form == protocol::Form::NULL_IF) {
return convertNullIfExpr(returnType, args);
VELOX_UNREACHABLE("NULL_IF not supported in specialForm")
}

auto form = std::string(json(pexpr->form));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ public void testTopN()
assertQuery("SELECT linenumber, NULL FROM lineitem ORDER BY 1 LIMIT 23");
}

@Test (enabled = false)
@Test
public void testNullIf()
{
assertQuery("SELECT NULLIF(totalprice, 0) FROM (SELECT SUM(extendedprice) AS totalprice FROM lineitem WHERE shipdate >= '1995-09-01')");
Expand Down Expand Up @@ -963,7 +963,7 @@ public void testInsertIntoSpecialPartitionName()

// For special character in partition name, without correct handling, it would throw errors like 'Invalid partition spec: nationkey=A/B'
// In this test, verify those partition names can be successfully created
String[] specialCharacters = new String[]{"\"", "#", "%", "''", "*", "/", ":", "=", "?", "\\", "\\x7F", "{", "[", "]", "^"}; // escape single quote for sql
String[] specialCharacters = new String[] {"\"", "#", "%", "''", "*", "/", ":", "=", "?", "\\", "\\x7F", "{", "[", "]", "^"}; // escape single quote for sql
for (String specialCharacter : specialCharacters) {
getQueryRunner().execute(writeSession, String.format("INSERT INTO %s VALUES ('name', 'A%sB')", tmpTableName, specialCharacter));
assertQuery(String.format("SELECT nationkey FROM %s", tmpTableName), String.format("VALUES('A%sB')", specialCharacter));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ public static QueryRunner createNativeQueryRunner(
ImmutableMap.<String, String>builder()
.put("http-server.http.port", "8080")
.put("experimental.internal-communication.thrift-transport-enabled", String.valueOf(useThrift))
.put("native-execution-enabled", "true")
.putAll(getNativeWorkerSystemProperties())
.build(),
ImmutableMap.of(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,8 @@ public void testMapOnlyQueries()
{
assertQuery("SELECT * FROM orders");
assertQuery("SELECT orderkey, custkey FROM orders WHERE orderkey <= 200");
assertQuery("SELECT orderkey, custkey FROM orders ORDER BY orderkey LIMIT 4");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you remove this query? Is this accidental?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not removed, just moved around. It is the last line in the unit test
This was extracted out into a different unit test: 1d331e0

}

@Test (enabled = false)
public void testNullIf()
{
assertQuery("SELECT nullif(orderkey, custkey) FROM orders");
assertQuery("SELECT orderkey, custkey FROM orders ORDER BY orderkey LIMIT 4");
}

@Test
Expand Down