From 5f23708ba1930a26785c09f1289dbdb32e155346 Mon Sep 17 00:00:00 2001 From: Pranjal Shankhdhar Date: Wed, 30 Jun 2021 17:50:32 +0100 Subject: [PATCH] Support distinct types --- .../presto/execution/CreateTypeTask.java | 22 +- .../presto/metadata/SignatureBinder.java | 14 +- .../facebook/presto/tests/H2QueryRunner.java | 289 +++++++----------- .../presto/tests/TestSqlFunctions.java | 16 +- .../presto/tests/TestUserDefinedTypes.java | 12 +- 5 files changed, 145 insertions(+), 208 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/execution/CreateTypeTask.java b/presto-main/src/main/java/com/facebook/presto/execution/CreateTypeTask.java index 6c37143497352..c37320b0bbf64 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/CreateTypeTask.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/CreateTypeTask.java @@ -21,7 +21,6 @@ import com.facebook.presto.common.type.UserDefinedType; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.security.AccessControl; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.CreateType; import com.facebook.presto.sql.tree.Expression; @@ -36,7 +35,6 @@ import static com.facebook.presto.common.type.StandardTypes.ROW; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; -import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.Futures.immediateFuture; import static java.lang.String.format; @@ -68,16 +66,20 @@ public String explain(CreateType statement, List parameters) @Override public ListenableFuture execute(CreateType statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, QueryStateMachine stateMachine, List parameters) { + TypeSignature signature; + if (statement.getDistinctType().isPresent()) { - throw new PrestoException(NOT_SUPPORTED, "Creating distinct types is not yet supported"); + signature = new TypeSignature(statement.getDistinctType().get()); + } + else { + List typeParameters = + Streams.zip( + statement.getParameterNames().stream(), + statement.getParameterTypes().stream(), + (name, type) -> TypeSignatureParameter.of(new NamedTypeSignature(Optional.of(new RowFieldName(name, false)), parseTypeSignature(type)))) + .collect(toImmutableList()); + signature = new TypeSignature(ROW, typeParameters); } - - List typeParameters = Streams.zip( - statement.getParameterNames().stream(), - statement.getParameterTypes().stream(), - (name, type) -> TypeSignatureParameter.of(new NamedTypeSignature(Optional.of(new RowFieldName(name, false)), parseTypeSignature(type)))) - .collect(toImmutableList()); - TypeSignature signature = new TypeSignature(ROW, typeParameters); UserDefinedType userDefinedType = new UserDefinedType(QualifiedObjectName.valueOf(statement.getTypeName().toString()), signature); metadata.getFunctionAndTypeManager().addUserDefinedType(userDefinedType); diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/SignatureBinder.java b/presto-main/src/main/java/com/facebook/presto/metadata/SignatureBinder.java index 47ffd1182ff78..31572056413a7 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/SignatureBinder.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/SignatureBinder.java @@ -534,9 +534,19 @@ private boolean satisfiesCoercion(boolean allowCoercion, Type fromType, TypeSign if (allowCoercion) { return functionAndTypeManager.canCoerce(fromType, functionAndTypeManager.getType(toTypeSignature)); } - else { - return fromType.getTypeSignature().equals(toTypeSignature); + else if (fromType.getTypeSignature().equals(toTypeSignature)) { + return true; + } + // Convert user defined type to base type + else if (fromType instanceof TypeWithName) { + return fromType.getTypeSignature().getTypeSignatureBase().getStandardTypeBase().equals(toTypeSignature.getTypeSignatureBase()); } + // Convert base type to user defined type + if (toTypeSignature.getTypeSignatureBase().hasTypeName() && toTypeSignature.getTypeSignatureBase().hasStandardType()) { + return fromType.getTypeSignature().getBase().equals(toTypeSignature.getTypeSignatureBase().getStandardTypeBase()); + } + + return false; } private static List getLambdaArgumentTypeSignatures(TypeSignature lambdaTypeSignature) diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/H2QueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/H2QueryRunner.java index 332ac34b95af5..acc25ac9efbb1 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/H2QueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/H2QueryRunner.java @@ -20,6 +20,7 @@ import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.TimestampType; import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeWithName; import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorTableMetadata; @@ -210,6 +211,113 @@ private static RowMapper rowMapper(List types) { return new RowMapper() { + private Object getValue(Type type, ResultSet resultSet, int position) + throws SQLException + { + if (BOOLEAN.equals(type)) { + boolean booleanValue = resultSet.getBoolean(position); + return resultSet.wasNull() ? null : booleanValue; + } + else if (TINYINT.equals(type)) { + byte byteValue = resultSet.getByte(position); + return resultSet.wasNull() ? null : byteValue; + } + else if (SMALLINT.equals(type)) { + short shortValue = resultSet.getShort(position); + return resultSet.wasNull() ? null : shortValue; + } + else if (INTEGER.equals(type)) { + int intValue = resultSet.getInt(position); + return resultSet.wasNull() ? null : intValue; + } + else if (BIGINT.equals(type)) { + long longValue = resultSet.getLong(position); + return resultSet.wasNull() ? null : longValue; + } + else if (REAL.equals(type)) { + float floatValue = resultSet.getFloat(position); + return resultSet.wasNull() ? null : floatValue; + } + else if (DOUBLE.equals(type)) { + double doubleValue = resultSet.getDouble(position); + return resultSet.wasNull() ? null : doubleValue; + } + else if (isVarcharType(type)) { + String stringValue = resultSet.getString(position); + return resultSet.wasNull() ? null : stringValue; + } + else if (isCharType(type)) { + String stringValue = resultSet.getString(position); + return resultSet.wasNull() ? null : padEnd(stringValue, ((CharType) type).getLength(), ' '); + } + else if (VARBINARY.equals(type)) { + byte[] binary = resultSet.getBytes(position); + return resultSet.wasNull() ? null : binary; + } + else if (DATE.equals(type)) { + // resultSet.getDate(i) doesn't work if JVM's zone skipped day being retrieved (e.g. 2011-12-30 and Pacific/Apia zone) + LocalDate dateValue = resultSet.getObject(position, LocalDate.class); + return resultSet.wasNull() ? null : dateValue; + } + else if (TIME.equals(type)) { + // resultSet.getTime(i) doesn't work if JVM's zone had forward offset change during 1970-01-01 (e.g. America/Hermosillo zone) + LocalTime timeValue = resultSet.getObject(position, LocalTime.class); + return resultSet.wasNull() ? null : timeValue; + } + else if (TIME_WITH_TIME_ZONE.equals(type)) { + throw new UnsupportedOperationException("H2 does not support TIME WITH TIME ZONE"); + } + else if (TIMESTAMP.equals(type)) { + // resultSet.getTimestamp(i) doesn't work if JVM's zone had forward offset at the date/time being retrieved + LocalDateTime timestampValue; + try { + timestampValue = resultSet.getObject(position, LocalDateTime.class); + } + catch (SQLException first) { + // H2 cannot convert DATE to LocalDateTime in their JDBC driver (even though it can convert to java.sql.Timestamp), we need to do this manually + try { + timestampValue = Optional.ofNullable(resultSet.getObject(position, LocalDate.class)).map(LocalDate::atStartOfDay).orElse(null); + } + catch (RuntimeException e) { + first.addSuppressed(e); + throw first; + } + } + return resultSet.wasNull() ? null : timestampValue; + } + else if (TIMESTAMP_WITH_TIME_ZONE.equals(type)) { + // H2 supports TIMESTAMP WITH TIME ZONE via org.h2.api.TimestampWithTimeZone, but it represent only a fixed-offset TZ (not named) + // This means H2 is unsuitable for testing TIMESTAMP WITH TIME ZONE-bearing queries. Those need to be tested manually. + throw new UnsupportedOperationException(); + } + else if (UNKNOWN.equals(type)) { + Object objectValue = resultSet.getObject(position); + checkState(resultSet.wasNull(), "Expected a null value, but got %s", objectValue); + return null; + } + else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType) type; + BigDecimal decimalValue = resultSet.getBigDecimal(position); + return resultSet.wasNull() ? null : decimalValue + .setScale(decimalType.getScale(), BigDecimal.ROUND_HALF_UP) + .round(new MathContext(decimalType.getPrecision())); + } + else if (type instanceof ArrayType) { + Array array = resultSet.getArray(position); + return resultSet.wasNull() ? null : newArrayList(mapArrayValues(((ArrayType) type), (Object[]) array.getArray())); + } + else if (type instanceof RowType) { + Array array = resultSet.getArray(position); + return resultSet.wasNull() ? null : newArrayList(mapRowValues((RowType) type, (Object[]) array.getArray())); + } + else if (type instanceof TypeWithName) { + return getValue(((TypeWithName) type).getType(), resultSet, position); + } + else { + throw new AssertionError("unhandled type: " + type); + } + } + @Override public MaterializedRow map(ResultSet resultSet, StatementContext context) throws SQLException @@ -218,186 +326,7 @@ public MaterializedRow map(ResultSet resultSet, StatementContext context) checkArgument(types.size() == count, "expected types count (%s) does not match actual column count (%s)", types.size(), count); List row = new ArrayList<>(count); for (int i = 1; i <= count; i++) { - Type type = types.get(i - 1); - if (BOOLEAN.equals(type)) { - boolean booleanValue = resultSet.getBoolean(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(booleanValue); - } - } - else if (TINYINT.equals(type)) { - byte byteValue = resultSet.getByte(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(byteValue); - } - } - else if (SMALLINT.equals(type)) { - short shortValue = resultSet.getShort(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(shortValue); - } - } - else if (INTEGER.equals(type)) { - int intValue = resultSet.getInt(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(intValue); - } - } - else if (BIGINT.equals(type)) { - long longValue = resultSet.getLong(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(longValue); - } - } - else if (REAL.equals(type)) { - float floatValue = resultSet.getFloat(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(floatValue); - } - } - else if (DOUBLE.equals(type)) { - double doubleValue = resultSet.getDouble(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(doubleValue); - } - } - else if (isVarcharType(type)) { - String stringValue = resultSet.getString(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(stringValue); - } - } - else if (isCharType(type)) { - String stringValue = resultSet.getString(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(padEnd(stringValue, ((CharType) type).getLength(), ' ')); - } - } - else if (VARBINARY.equals(type)) { - byte[] binary = resultSet.getBytes(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(binary); - } - } - else if (DATE.equals(type)) { - // resultSet.getDate(i) doesn't work if JVM's zone skipped day being retrieved (e.g. 2011-12-30 and Pacific/Apia zone) - LocalDate dateValue = resultSet.getObject(i, LocalDate.class); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(dateValue); - } - } - else if (TIME.equals(type)) { - // resultSet.getTime(i) doesn't work if JVM's zone had forward offset change during 1970-01-01 (e.g. America/Hermosillo zone) - LocalTime timeValue = resultSet.getObject(i, LocalTime.class); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(timeValue); - } - } - else if (TIME_WITH_TIME_ZONE.equals(type)) { - throw new UnsupportedOperationException("H2 does not support TIME WITH TIME ZONE"); - } - else if (TIMESTAMP.equals(type)) { - // resultSet.getTimestamp(i) doesn't work if JVM's zone had forward offset at the date/time being retrieved - LocalDateTime timestampValue; - try { - timestampValue = resultSet.getObject(i, LocalDateTime.class); - } - catch (SQLException first) { - // H2 cannot convert DATE to LocalDateTime in their JDBC driver (even though it can convert to java.sql.Timestamp), we need to do this manually - try { - timestampValue = Optional.ofNullable(resultSet.getObject(i, LocalDate.class)).map(LocalDate::atStartOfDay).orElse(null); - } - catch (RuntimeException e) { - first.addSuppressed(e); - throw first; - } - } - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(timestampValue); - } - } - else if (TIMESTAMP_WITH_TIME_ZONE.equals(type)) { - // H2 supports TIMESTAMP WITH TIME ZONE via org.h2.api.TimestampWithTimeZone, but it represent only a fixed-offset TZ (not named) - // This means H2 is unsuitable for testing TIMESTAMP WITH TIME ZONE-bearing queries. Those need to be tested manually. - throw new UnsupportedOperationException(); - } - else if (UNKNOWN.equals(type)) { - Object objectValue = resultSet.getObject(i); - checkState(resultSet.wasNull(), "Expected a null value, but got %s", objectValue); - row.add(null); - } - else if (type instanceof DecimalType) { - DecimalType decimalType = (DecimalType) type; - BigDecimal decimalValue = resultSet.getBigDecimal(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(decimalValue - .setScale(decimalType.getScale(), BigDecimal.ROUND_HALF_UP) - .round(new MathContext(decimalType.getPrecision()))); - } - } - else if (type instanceof ArrayType) { - Array array = resultSet.getArray(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(newArrayList(mapArrayValues(((ArrayType) type), (Object[]) array.getArray()))); - } - } - else if (type instanceof RowType) { - Array array = resultSet.getArray(i); - if (resultSet.wasNull()) { - row.add(null); - } - else { - row.add(newArrayList(mapRowValues((RowType) type, (Object[]) array.getArray()))); - } - } - else { - throw new AssertionError("unhandled type: " + type); - } + row.add(getValue(types.get(i - 1), resultSet, i)); } return new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, row); } diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestSqlFunctions.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestSqlFunctions.java index 63819e1d55e63..f8e40c5d53a25 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestSqlFunctions.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestSqlFunctions.java @@ -15,8 +15,6 @@ import com.facebook.presto.Session; import com.facebook.presto.common.QualifiedObjectName; -import com.facebook.presto.common.type.NamedTypeSignature; -import com.facebook.presto.common.type.RowFieldName; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.common.type.TypeSignatureParameter; import com.facebook.presto.common.type.UserDefinedType; @@ -35,13 +33,9 @@ import org.testng.annotations.Test; import java.util.List; -import java.util.Optional; import static com.facebook.presto.common.type.BigintEnumType.LongEnumMap; import static com.facebook.presto.common.type.StandardTypes.BIGINT_ENUM; -import static com.facebook.presto.common.type.StandardTypes.ROW; -import static com.facebook.presto.common.type.StandardTypes.TINYINT; -import static com.facebook.presto.common.type.StandardTypes.VARCHAR; import static com.facebook.presto.common.type.StandardTypes.VARCHAR_ENUM; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.common.type.VarcharEnumType.VarcharEnumMap; @@ -72,13 +66,6 @@ public class TestSqlFunctions "CHINA", "中国", "भारत", "India"))))); - private static final UserDefinedType PERSON = new UserDefinedType(QualifiedObjectName.valueOf("testing.type.person"), new TypeSignature( - ROW, - TypeSignatureParameter.of(new NamedTypeSignature(Optional.of(new RowFieldName("first_name", false)), new TypeSignature(VARCHAR))), - TypeSignatureParameter.of(new NamedTypeSignature(Optional.of(new RowFieldName("last_name", false)), new TypeSignature(VARCHAR))), - TypeSignatureParameter.of(new NamedTypeSignature(Optional.of(new RowFieldName("age", false)), new TypeSignature(TINYINT))), - TypeSignatureParameter.of(new NamedTypeSignature(Optional.of(new RowFieldName("country", false)), new TypeSignature("testing.enum.country"))))); - protected TestSqlFunctions() { TestingThriftUdfServer.start(ImmutableMap.of("thrift.server.port", "7779")); @@ -110,7 +97,8 @@ protected QueryRunner createQueryRunner() queryRunner.createTestFunctionNamespace("example", "example"); queryRunner.getMetadata().getFunctionAndTypeManager().addUserDefinedType(MOOD_ENUM); queryRunner.getMetadata().getFunctionAndTypeManager().addUserDefinedType(COUNTRY_ENUM); - queryRunner.getMetadata().getFunctionAndTypeManager().addUserDefinedType(PERSON); + + queryRunner.execute("CREATE TYPE testing.type.person AS (first_name varchar, last_name varchar, age tinyint, country testing.enum.country)"); return queryRunner; } diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/TestUserDefinedTypes.java b/presto-tests/src/test/java/com/facebook/presto/tests/TestUserDefinedTypes.java index 924f7426f929b..ffe80aed4d8e4 100644 --- a/presto-tests/src/test/java/com/facebook/presto/tests/TestUserDefinedTypes.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/TestUserDefinedTypes.java @@ -35,13 +35,21 @@ protected QueryRunner createQueryRunner() } @Test - public void testCreateType() + public void testStructType() { - assertQueryFails("CREATE TYPE testing.type.num AS integer", "Creating distinct types is not yet supported"); assertQuerySucceeds("CREATE TYPE testing.type.pair AS (fst integer, snd integer)"); assertQuerySucceeds("CREATE TYPE testing.type.pair3 AS (fst testing.type.pair, snd integer)"); assertQuery("SELECT p.fst.fst FROM(SELECT CAST(ROW(CAST(ROW(1,2) AS testing.type.pair), 3) AS testing.type.pair3) AS p)", "SELECT 1"); assertQuerySucceeds("CREATE TYPE testing.type.pair3Alt AS (fst ROW(fst integer, snd integer), snd integer)"); assertQuery("SELECT p.fst.snd FROM(SELECT CAST(ROW(ROW(1,2), 3) AS testing.type.pair3Alt) AS p)", "SELECT 2"); } + + @Test + public void testDistinctType() + { + assertQuerySucceeds("CREATE TYPE testing.type.num AS integer"); + assertQuery("SELECT x FROM (SELECT CAST(4 as testing.type.num) AS x)", "SELECT 4"); + assertQuerySucceeds("CREATE TYPE testing.type.mypair AS (fst testing.type.num, snd integer)"); + assertQuery("SELECT p.fst FROM (SELECT CAST(ROW(CAST(4 AS testing.type.num),3) as testing.type.mypair) AS p)", "SELECT 4"); + } }