diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/BaseJdbcClient.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/BaseJdbcClient.java index 3920483b2449..2cbdddd05f7e 100644 --- a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/BaseJdbcClient.java +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/BaseJdbcClient.java @@ -53,7 +53,20 @@ import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.collect.Iterables.getOnlyElement; import static io.prestosql.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; -import static io.prestosql.plugin.jdbc.StandardReadMappings.jdbcTypeToPrestoType; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.booleanWriteFunction; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.charWriteFunction; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.dateWriteFunction; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.doubleWriteFunction; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.integerWriteFunction; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.jdbcTypeToPrestoType; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.longDecimalWriteFunction; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.realWriteFunction; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.shortDecimalWriteFunction; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.smallintWriteFunction; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.tinyintWriteFunction; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.varbinaryWriteFunction; +import static io.prestosql.plugin.jdbc.StandardColumnMappings.varcharWriteFunction; import static io.prestosql.spi.StandardErrorCode.NOT_FOUND; import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED; import static io.prestosql.spi.type.BigintType.BIGINT; @@ -77,16 +90,16 @@ public class BaseJdbcClient { private static final Logger log = Logger.get(BaseJdbcClient.class); - private static final Map SQL_TYPES = ImmutableMap.builder() - .put(BOOLEAN, "boolean") - .put(BIGINT, "bigint") - .put(INTEGER, "integer") - .put(SMALLINT, "smallint") - .put(TINYINT, "tinyint") - .put(DOUBLE, "double precision") - .put(REAL, "real") - .put(VARBINARY, "varbinary") - .put(DATE, "date") + private static final Map WRITE_MAPPINGS = ImmutableMap.builder() + .put(BOOLEAN, WriteMapping.booleanMapping("boolean", booleanWriteFunction())) + .put(BIGINT, WriteMapping.longMapping("bigint", bigintWriteFunction())) + .put(INTEGER, WriteMapping.longMapping("integer", integerWriteFunction())) + .put(SMALLINT, WriteMapping.longMapping("smallint", smallintWriteFunction())) + .put(TINYINT, WriteMapping.longMapping("tinyint", tinyintWriteFunction())) + .put(DOUBLE, WriteMapping.doubleMapping("double precision", doubleWriteFunction())) + .put(REAL, WriteMapping.longMapping("real", realWriteFunction())) + .put(VARBINARY, WriteMapping.sliceMapping("varbinary", varbinaryWriteFunction())) + .put(DATE, WriteMapping.longMapping("date", dateWriteFunction())) .build(); protected final String connectorId; @@ -197,7 +210,7 @@ public List getColumns(ConnectorSession session, JdbcTableHand resultSet.getString("TYPE_NAME"), resultSet.getInt("COLUMN_SIZE"), resultSet.getInt("DECIMAL_DIGITS")); - Optional columnMapping = toPrestoType(session, typeHandle); + Optional columnMapping = toPrestoType(session, typeHandle); // skip unsupported column types if (columnMapping.isPresent()) { String columnName = resultSet.getString("COLUMN_NAME"); @@ -217,7 +230,7 @@ public List getColumns(ConnectorSession session, JdbcTableHand } @Override - public Optional toPrestoType(ConnectorSession session, JdbcTypeHandle typeHandle) + public Optional toPrestoType(ConnectorSession session, JdbcTypeHandle typeHandle) { return jdbcTypeToPrestoType(typeHandle); } @@ -252,11 +265,12 @@ public Connection getConnection(JdbcSplit split) } @Override - public PreparedStatement buildSql(Connection connection, JdbcSplit split, List columnHandles) + public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, List columnHandles) throws SQLException { return new QueryBuilder(identifierQuote).buildSql( this, + session, connection, split.getCatalogName(), split.getSchemaName(), @@ -308,7 +322,8 @@ private JdbcOutputTableHandle beginWriteTable(ConnectorTableMetadata tableMetada } columnNames.add(columnName); columnTypes.add(column.getType()); - columnList.add(format("%s %s", quoted(columnName), toSqlType(column.getType()))); + // TODO in INSERT case, we should reuse original column type and, ideally, constraints (then JdbcPageSink must get writer from toPrestoType()) + columnList.add(format("%s %s", quoted(columnName), toWriteMapping(column.getType()).getDataType())); } String sql = format( @@ -457,28 +472,43 @@ protected void execute(Connection connection, String query) } } - protected String toSqlType(Type type) + @Override + public WriteMapping toWriteMapping(Type type) { if (isVarcharType(type)) { VarcharType varcharType = (VarcharType) type; + String dataType; if (varcharType.isUnbounded()) { - return "varchar"; + dataType = "varchar"; + } + else { + dataType = "varchar(" + varcharType.getBoundedLength() + ")"; } - return "varchar(" + varcharType.getBoundedLength() + ")"; + return WriteMapping.sliceMapping(dataType, varcharWriteFunction()); } if (type instanceof CharType) { - if (((CharType) type).getLength() == CharType.MAX_LENGTH) { - return "char"; + CharType charType = (CharType) type; + String dataType; + if (charType.getLength() == CharType.MAX_LENGTH) { + dataType = "char"; } - return "char(" + ((CharType) type).getLength() + ")"; + else { + dataType = "char(" + charType.getLength() + ")"; + } + return WriteMapping.sliceMapping(dataType, charWriteFunction(charType)); } if (type instanceof DecimalType) { - return format("decimal(%s, %s)", ((DecimalType) type).getPrecision(), ((DecimalType) type).getScale()); + DecimalType decimalType = (DecimalType) type; + String dataType = format("decimal(%s, %s)", decimalType.getPrecision(), decimalType.getScale()); + if (decimalType.isShort()) { + return WriteMapping.longMapping(dataType, shortDecimalWriteFunction(decimalType)); + } + return WriteMapping.sliceMapping(dataType, longDecimalWriteFunction(decimalType)); } - String sqlType = SQL_TYPES.get(type); - if (sqlType != null) { - return sqlType; + WriteMapping writeMapping = WRITE_MAPPINGS.get(type); + if (writeMapping != null) { + return writeMapping; } throw new PrestoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName()); } diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/BooleanWriteFunction.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/BooleanWriteFunction.java new file mode 100644 index 000000000000..a38ef5bd31f9 --- /dev/null +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/BooleanWriteFunction.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.plugin.jdbc; + +import java.sql.PreparedStatement; +import java.sql.SQLException; + +public interface BooleanWriteFunction + extends WriteFunction +{ + @Override + default Class getJavaType() + { + return boolean.class; + } + + void set(PreparedStatement statement, int index, boolean value) + throws SQLException; +} diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/ColumnMapping.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/ColumnMapping.java new file mode 100644 index 000000000000..a3057eaa2a1c --- /dev/null +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/ColumnMapping.java @@ -0,0 +1,119 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.plugin.jdbc; + +import io.prestosql.spi.predicate.Domain; +import io.prestosql.spi.type.Type; + +import java.util.function.UnaryOperator; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public final class ColumnMapping +{ + public static ColumnMapping booleanMapping(Type prestoType, BooleanReadFunction readFunction, BooleanWriteFunction writeFunction) + { + return booleanMapping(prestoType, readFunction, writeFunction, UnaryOperator.identity()); + } + + public static ColumnMapping booleanMapping(Type prestoType, BooleanReadFunction readFunction, BooleanWriteFunction writeFunction, UnaryOperator pushdownConverter) + { + return new ColumnMapping(prestoType, readFunction, writeFunction, pushdownConverter); + } + + public static ColumnMapping longMapping(Type prestoType, LongReadFunction readFunction, LongWriteFunction writeFunction) + { + return longMapping(prestoType, readFunction, writeFunction, UnaryOperator.identity()); + } + + public static ColumnMapping longMapping(Type prestoType, LongReadFunction readFunction, LongWriteFunction writeFunction, UnaryOperator pushdownConverter) + { + return new ColumnMapping(prestoType, readFunction, writeFunction, pushdownConverter); + } + + public static ColumnMapping doubleMapping(Type prestoType, DoubleReadFunction readFunction, DoubleWriteFunction writeFunction) + { + return doubleMapping(prestoType, readFunction, writeFunction, UnaryOperator.identity()); + } + + public static ColumnMapping doubleMapping(Type prestoType, DoubleReadFunction readFunction, DoubleWriteFunction writeFunction, UnaryOperator pushdownConverter) + { + return new ColumnMapping(prestoType, readFunction, writeFunction, pushdownConverter); + } + + public static ColumnMapping sliceMapping(Type prestoType, SliceReadFunction readFunction, SliceWriteFunction writeFunction) + { + return sliceMapping(prestoType, readFunction, writeFunction, UnaryOperator.identity()); + } + + public static ColumnMapping sliceMapping(Type prestoType, SliceReadFunction readFunction, SliceWriteFunction writeFunction, UnaryOperator pushdownConverter) + { + return new ColumnMapping(prestoType, readFunction, writeFunction, pushdownConverter); + } + + private final Type type; + private final ReadFunction readFunction; + private final WriteFunction writeFunction; + private final UnaryOperator pushdownConverter; + + private ColumnMapping(Type type, ReadFunction readFunction, WriteFunction writeFunction, UnaryOperator pushdownConverter) + { + this.type = requireNonNull(type, "type is null"); + this.readFunction = requireNonNull(readFunction, "readFunction is null"); + this.writeFunction = requireNonNull(writeFunction, "writeFunction is null"); + checkArgument( + type.getJavaType() == readFunction.getJavaType(), + "Presto type %s is not compatible with read function %s returning %s", + type, + readFunction, + readFunction.getJavaType()); + checkArgument( + type.getJavaType() == writeFunction.getJavaType(), + "Presto type %s is not compatible with write function %s accepting %s", + type, + writeFunction, + writeFunction.getJavaType()); + this.pushdownConverter = requireNonNull(pushdownConverter, "pushdownConverter is null"); + } + + public Type getType() + { + return type; + } + + public ReadFunction getReadFunction() + { + return readFunction; + } + + public WriteFunction getWriteFunction() + { + return writeFunction; + } + + public UnaryOperator getPushdownConverter() + { + return pushdownConverter; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("type", type) + .toString(); + } +} diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/DoubleWriteFunction.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/DoubleWriteFunction.java new file mode 100644 index 000000000000..c49b543d820a --- /dev/null +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/DoubleWriteFunction.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.plugin.jdbc; + +import java.sql.PreparedStatement; +import java.sql.SQLException; + +public interface DoubleWriteFunction + extends WriteFunction +{ + @Override + default Class getJavaType() + { + return double.class; + } + + void set(PreparedStatement statement, int index, double value) + throws SQLException; +} diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcClient.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcClient.java index 199f982fa568..2bdf31437940 100644 --- a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcClient.java +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcClient.java @@ -20,6 +20,7 @@ import io.prestosql.spi.connector.SchemaTableName; import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.statistics.TableStatistics; +import io.prestosql.spi.type.Type; import javax.annotation.Nullable; @@ -46,7 +47,9 @@ default boolean schemaExists(String schema) List getColumns(ConnectorSession session, JdbcTableHandle tableHandle); - Optional toPrestoType(ConnectorSession session, JdbcTypeHandle typeHandle); + Optional toPrestoType(ConnectorSession session, JdbcTypeHandle typeHandle); + + WriteMapping toWriteMapping(Type type); ConnectorSplitSource getSplits(JdbcTableLayoutHandle layoutHandle); @@ -59,7 +62,7 @@ default void abortReadConnection(Connection connection) // most drivers do not need this } - PreparedStatement buildSql(Connection connection, JdbcSplit split, List columnHandles) + PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, List columnHandles) throws SQLException; JdbcOutputTableHandle beginCreateTable(ConnectorTableMetadata tableMetadata); diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcPageSink.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcPageSink.java index beb5a7b9fe05..16993ff54106 100644 --- a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcPageSink.java +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcPageSink.java @@ -13,20 +13,16 @@ */ package io.prestosql.plugin.jdbc; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; -import com.google.common.primitives.Shorts; -import com.google.common.primitives.SignedBytes; import io.airlift.slice.Slice; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.Block; import io.prestosql.spi.connector.ConnectorPageSink; -import io.prestosql.spi.type.DecimalType; import io.prestosql.spi.type.Type; -import org.joda.time.DateTimeZone; import java.sql.Connection; -import java.sql.Date; import java.sql.PreparedStatement; import java.sql.SQLException; import java.sql.SQLNonTransientException; @@ -34,26 +30,12 @@ import java.util.List; import java.util.concurrent.CompletableFuture; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.prestosql.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; -import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED; -import static io.prestosql.spi.type.BigintType.BIGINT; -import static io.prestosql.spi.type.BooleanType.BOOLEAN; -import static io.prestosql.spi.type.Chars.isCharType; -import static io.prestosql.spi.type.DateType.DATE; -import static io.prestosql.spi.type.Decimals.readBigDecimal; -import static io.prestosql.spi.type.DoubleType.DOUBLE; -import static io.prestosql.spi.type.IntegerType.INTEGER; -import static io.prestosql.spi.type.RealType.REAL; -import static io.prestosql.spi.type.SmallintType.SMALLINT; -import static io.prestosql.spi.type.TinyintType.TINYINT; -import static io.prestosql.spi.type.VarbinaryType.VARBINARY; -import static io.prestosql.spi.type.Varchars.isVarcharType; -import static java.lang.Float.intBitsToFloat; -import static java.lang.Math.toIntExact; +import static java.lang.String.format; import static java.util.concurrent.CompletableFuture.completedFuture; -import static java.util.concurrent.TimeUnit.DAYS; -import static org.joda.time.chrono.ISOChronology.getInstanceUTC; public class JdbcPageSink implements ConnectorPageSink @@ -62,6 +44,7 @@ public class JdbcPageSink private final PreparedStatement statement; private final List columnTypes; + private final List columnWriters; private int batchSize; public JdbcPageSink(JdbcOutputTableHandle handle, JdbcClient jdbcClient) @@ -83,6 +66,19 @@ public JdbcPageSink(JdbcOutputTableHandle handle, JdbcClient jdbcClient) } columnTypes = handle.getColumnTypes(); + + columnWriters = columnTypes.stream() + .map(type -> { + WriteFunction writeFunction = jdbcClient.toWriteMapping(type).getWriteFunction(); + verify( + type.getJavaType() == writeFunction.getJavaType(), + "Presto type %s is not compatible with write function %s accepting %s", + type, + writeFunction, + writeFunction.getJavaType()); + return writeFunction; + }) + .collect(toImmutableList()); } @Override @@ -115,52 +111,30 @@ private void appendColumn(Page page, int position, int channel) throws SQLException { Block block = page.getBlock(channel); - int parameter = channel + 1; + int parameterIndex = channel + 1; if (block.isNull(position)) { - statement.setObject(parameter, null); + statement.setObject(parameterIndex, null); return; } Type type = columnTypes.get(channel); - if (BOOLEAN.equals(type)) { - statement.setBoolean(parameter, type.getBoolean(block, position)); - } - else if (BIGINT.equals(type)) { - statement.setLong(parameter, type.getLong(block, position)); - } - else if (INTEGER.equals(type)) { - statement.setInt(parameter, toIntExact(type.getLong(block, position))); - } - else if (SMALLINT.equals(type)) { - statement.setShort(parameter, Shorts.checkedCast(type.getLong(block, position))); - } - else if (TINYINT.equals(type)) { - statement.setByte(parameter, SignedBytes.checkedCast(type.getLong(block, position))); - } - else if (DOUBLE.equals(type)) { - statement.setDouble(parameter, type.getDouble(block, position)); - } - else if (REAL.equals(type)) { - statement.setFloat(parameter, intBitsToFloat(toIntExact(type.getLong(block, position)))); - } - else if (type instanceof DecimalType) { - statement.setBigDecimal(parameter, readBigDecimal((DecimalType) type, block, position)); + Class javaType = type.getJavaType(); + WriteFunction writeFunction = columnWriters.get(channel); + if (javaType == boolean.class) { + ((BooleanWriteFunction) writeFunction).set(statement, parameterIndex, type.getBoolean(block, position)); } - else if (isVarcharType(type) || isCharType(type)) { - statement.setString(parameter, type.getSlice(block, position).toStringUtf8()); + else if (javaType == long.class) { + ((LongWriteFunction) writeFunction).set(statement, parameterIndex, type.getLong(block, position)); } - else if (VARBINARY.equals(type)) { - statement.setBytes(parameter, type.getSlice(block, position).getBytes()); + else if (javaType == double.class) { + ((DoubleWriteFunction) writeFunction).set(statement, parameterIndex, type.getDouble(block, position)); } - else if (DATE.equals(type)) { - // convert to midnight in default time zone - long utcMillis = DAYS.toMillis(type.getLong(block, position)); - long localMillis = getInstanceUTC().getZone().getMillisKeepLocal(DateTimeZone.getDefault(), utcMillis); - statement.setDate(parameter, new Date(localMillis)); + else if (javaType == Slice.class) { + ((SliceWriteFunction) writeFunction).set(statement, parameterIndex, type.getSlice(block, position)); } else { - throw new PrestoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName()); + throw new VerifyException(format("Unexpected type %s with java type %s", type, javaType.getName())); } } diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcRecordCursor.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcRecordCursor.java index 748b5b261988..4ec5f338f945 100644 --- a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcRecordCursor.java +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/JdbcRecordCursor.java @@ -63,10 +63,10 @@ public JdbcRecordCursor(JdbcClient jdbcClient, ConnectorSession session, JdbcSpl sliceReadFunctions = new SliceReadFunction[columnHandles.size()]; for (int i = 0; i < this.columnHandles.length; i++) { - ReadMapping readMapping = jdbcClient.toPrestoType(session, columnHandles.get(i).getJdbcTypeHandle()) + ColumnMapping columnMapping = jdbcClient.toPrestoType(session, columnHandles.get(i).getJdbcTypeHandle()) .orElseThrow(() -> new VerifyException("Unsupported column type")); - Class javaType = readMapping.getType().getJavaType(); - ReadFunction readFunction = readMapping.getReadFunction(); + Class javaType = columnMapping.getType().getJavaType(); + ReadFunction readFunction = columnMapping.getReadFunction(); if (javaType == boolean.class) { booleanReadFunctions[i] = (BooleanReadFunction) readFunction; @@ -87,7 +87,7 @@ else if (javaType == Slice.class) { try { connection = jdbcClient.getConnection(split); - statement = jdbcClient.buildSql(connection, split, columnHandles); + statement = jdbcClient.buildSql(session, connection, split, columnHandles); log.debug("Executing: %s", statement.toString()); resultSet = statement.executeQuery(); } diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/LongWriteFunction.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/LongWriteFunction.java new file mode 100644 index 000000000000..c2d0a92d6426 --- /dev/null +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/LongWriteFunction.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.plugin.jdbc; + +import java.sql.PreparedStatement; +import java.sql.SQLException; + +public interface LongWriteFunction + extends WriteFunction +{ + @Override + default Class getJavaType() + { + return long.class; + } + + void set(PreparedStatement statement, int index, long value) + throws SQLException; +} diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/QueryBuilder.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/QueryBuilder.java index 884c94868ce5..64e9dc5cd8fe 100644 --- a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/QueryBuilder.java +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/QueryBuilder.java @@ -14,33 +14,19 @@ package io.prestosql.plugin.jdbc; import com.google.common.base.Joiner; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.Range; import io.prestosql.spi.predicate.TupleDomain; -import io.prestosql.spi.type.BigintType; -import io.prestosql.spi.type.BooleanType; -import io.prestosql.spi.type.CharType; -import io.prestosql.spi.type.DateType; -import io.prestosql.spi.type.DoubleType; -import io.prestosql.spi.type.IntegerType; -import io.prestosql.spi.type.RealType; -import io.prestosql.spi.type.SmallintType; -import io.prestosql.spi.type.TimeType; -import io.prestosql.spi.type.TimestampType; -import io.prestosql.spi.type.TinyintType; import io.prestosql.spi.type.Type; -import io.prestosql.spi.type.VarcharType; -import org.joda.time.DateTimeZone; import java.sql.Connection; -import java.sql.Date; import java.sql.PreparedStatement; import java.sql.SQLException; -import java.sql.Time; -import java.sql.Timestamp; import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -49,12 +35,10 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.collect.Iterables.getOnlyElement; -import static java.lang.Float.intBitsToFloat; +import static java.lang.String.format; import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; -import static java.util.concurrent.TimeUnit.DAYS; import static java.util.stream.Collectors.joining; -import static org.joda.time.DateTimeZone.UTC; public class QueryBuilder { @@ -67,11 +51,13 @@ public class QueryBuilder private static class TypeAndValue { private final Type type; + private final JdbcTypeHandle typeHandle; private final Object value; - public TypeAndValue(Type type, Object value) + public TypeAndValue(Type type, JdbcTypeHandle typeHandle, Object value) { this.type = requireNonNull(type, "type is null"); + this.typeHandle = requireNonNull(typeHandle, "typeHandle is null"); this.value = requireNonNull(value, "value is null"); } @@ -80,6 +66,11 @@ public Type getType() return type; } + public JdbcTypeHandle getTypeHandle() + { + return typeHandle; + } + public Object getValue() { return value; @@ -93,6 +84,7 @@ public QueryBuilder(String quote) public PreparedStatement buildSql( JdbcClient client, + ConnectorSession session, Connection connection, String catalog, String schema, @@ -126,7 +118,7 @@ public PreparedStatement buildSql( List accumulator = new ArrayList<>(); - List clauses = toConjuncts(columns, tupleDomain, accumulator); + List clauses = toConjuncts(client, session, columns, tupleDomain, accumulator); if (additionalPredicate.isPresent()) { clauses = ImmutableList.builder() .addAll(clauses) @@ -142,84 +134,54 @@ public PreparedStatement buildSql( for (int i = 0; i < accumulator.size(); i++) { TypeAndValue typeAndValue = accumulator.get(i); - if (typeAndValue.getType().equals(BigintType.BIGINT)) { - statement.setLong(i + 1, (long) typeAndValue.getValue()); - } - else if (typeAndValue.getType().equals(IntegerType.INTEGER)) { - statement.setInt(i + 1, ((Number) typeAndValue.getValue()).intValue()); - } - else if (typeAndValue.getType().equals(SmallintType.SMALLINT)) { - statement.setShort(i + 1, ((Number) typeAndValue.getValue()).shortValue()); - } - else if (typeAndValue.getType().equals(TinyintType.TINYINT)) { - statement.setByte(i + 1, ((Number) typeAndValue.getValue()).byteValue()); + int parameterIndex = i + 1; + Type type = typeAndValue.getType(); + WriteFunction writeFunction = client.toPrestoType(session, typeAndValue.getTypeHandle()) + .orElseThrow(() -> new VerifyException(format("Unsupported type %s with handle %s", type, typeAndValue.getTypeHandle()))) + .getWriteFunction(); + Class javaType = type.getJavaType(); + Object value = typeAndValue.getValue(); + if (javaType == boolean.class) { + ((BooleanWriteFunction) writeFunction).set(statement, parameterIndex, (boolean) value); } - else if (typeAndValue.getType().equals(DoubleType.DOUBLE)) { - statement.setDouble(i + 1, (double) typeAndValue.getValue()); + else if (javaType == long.class) { + ((LongWriteFunction) writeFunction).set(statement, parameterIndex, (long) value); } - else if (typeAndValue.getType().equals(RealType.REAL)) { - statement.setFloat(i + 1, intBitsToFloat(((Number) typeAndValue.getValue()).intValue())); + else if (javaType == double.class) { + ((DoubleWriteFunction) writeFunction).set(statement, parameterIndex, (double) value); } - else if (typeAndValue.getType().equals(BooleanType.BOOLEAN)) { - statement.setBoolean(i + 1, (boolean) typeAndValue.getValue()); - } - else if (typeAndValue.getType().equals(DateType.DATE)) { - long millis = DAYS.toMillis((long) typeAndValue.getValue()); - statement.setDate(i + 1, new Date(UTC.getMillisKeepLocal(DateTimeZone.getDefault(), millis))); - } - else if (typeAndValue.getType().equals(TimeType.TIME)) { - statement.setTime(i + 1, new Time((long) typeAndValue.getValue())); - } - else if (typeAndValue.getType().equals(TimestampType.TIMESTAMP)) { - statement.setTimestamp(i + 1, new Timestamp((long) typeAndValue.getValue())); - } - else if (typeAndValue.getType() instanceof VarcharType) { - statement.setString(i + 1, ((Slice) typeAndValue.getValue()).toStringUtf8()); - } - else if (typeAndValue.getType() instanceof CharType) { - statement.setString(i + 1, ((Slice) typeAndValue.getValue()).toStringUtf8()); + else if (javaType == Slice.class) { + ((SliceWriteFunction) writeFunction).set(statement, parameterIndex, (Slice) value); } else { - throw new UnsupportedOperationException("Can't handle type: " + typeAndValue.getType()); + throw new VerifyException(format("Unexpected type %s with java type %s", type, javaType.getName())); } } return statement; } - private static boolean isAcceptedType(Type type) + private static Domain pushDownDomain(JdbcClient client, ConnectorSession session, JdbcColumnHandle column, Domain domain) { - Type validType = requireNonNull(type, "type is null"); - return validType.equals(BigintType.BIGINT) || - validType.equals(TinyintType.TINYINT) || - validType.equals(SmallintType.SMALLINT) || - validType.equals(IntegerType.INTEGER) || - validType.equals(DoubleType.DOUBLE) || - validType.equals(RealType.REAL) || - validType.equals(BooleanType.BOOLEAN) || - validType.equals(DateType.DATE) || - validType.equals(TimeType.TIME) || - validType.equals(TimestampType.TIMESTAMP) || - validType instanceof VarcharType || - validType instanceof CharType; + return client.toPrestoType(session, column.getJdbcTypeHandle()) + .orElseThrow(() -> new IllegalStateException(format("Unsupported type %s with handle %s", column.getColumnType(), column.getJdbcTypeHandle()))) + .getPushdownConverter().apply(domain); } - private List toConjuncts(List columns, TupleDomain tupleDomain, List accumulator) + private List toConjuncts(JdbcClient client, ConnectorSession session, List columns, TupleDomain tupleDomain, List accumulator) { ImmutableList.Builder builder = ImmutableList.builder(); for (JdbcColumnHandle column : columns) { - Type type = column.getColumnType(); - if (isAcceptedType(type)) { - Domain domain = tupleDomain.getDomains().get().get(column); - if (domain != null) { - builder.add(toPredicate(column.getColumnName(), domain, type, accumulator)); - } + Domain domain = tupleDomain.getDomains().get().get(column); + if (domain != null) { + domain = pushDownDomain(client, session, column, domain); + builder.add(toPredicate(column.getColumnName(), domain, column, accumulator)); } } return builder.build(); } - private String toPredicate(String columnName, Domain domain, Type type, List accumulator) + private String toPredicate(String columnName, Domain domain, JdbcColumnHandle column, List accumulator) { checkArgument(domain.getType().isOrderable(), "Domain type must be orderable"); @@ -243,10 +205,10 @@ private String toPredicate(String columnName, Domain domain, Type type, List", range.getLow().getValue(), type, accumulator)); + rangeConjuncts.add(toPredicate(columnName, ">", range.getLow().getValue(), column, accumulator)); break; case EXACTLY: - rangeConjuncts.add(toPredicate(columnName, ">=", range.getLow().getValue(), type, accumulator)); + rangeConjuncts.add(toPredicate(columnName, ">=", range.getLow().getValue(), column, accumulator)); break; case BELOW: throw new IllegalArgumentException("Low marker should never use BELOW bound"); @@ -259,10 +221,10 @@ private String toPredicate(String columnName, Domain domain, Type type, List 1) { for (Object value : singleValues) { - bindValue(value, type, accumulator); + bindValue(value, column, accumulator); } String values = Joiner.on(",").join(nCopies(singleValues.size(), "?")); disjuncts.add(quote(columnName) + " IN (" + values + ")"); @@ -295,9 +257,9 @@ else if (singleValues.size() > 1) { return "(" + Joiner.on(" OR ").join(disjuncts) + ")"; } - private String toPredicate(String columnName, String operator, Object value, Type type, List accumulator) + private String toPredicate(String columnName, String operator, Object value, JdbcColumnHandle column, List accumulator) { - bindValue(value, type, accumulator); + bindValue(value, column, accumulator); return quote(columnName) + " " + operator + " ?"; } @@ -307,9 +269,9 @@ private String quote(String name) return quote + name + quote; } - private static void bindValue(Object value, Type type, List accumulator) + private static void bindValue(Object value, JdbcColumnHandle column, List accumulator) { - checkArgument(isAcceptedType(type), "Can't handle type: %s", type); - accumulator.add(new TypeAndValue(type, value)); + Type type = column.getColumnType(); + accumulator.add(new TypeAndValue(type, column.getJdbcTypeHandle(), value)); } } diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/ReadMapping.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/ReadMapping.java deleted file mode 100644 index fae3f926981b..000000000000 --- a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/ReadMapping.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.plugin.jdbc; - -import io.prestosql.spi.type.Type; - -import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.base.Preconditions.checkArgument; -import static java.util.Objects.requireNonNull; - -public final class ReadMapping -{ - public static ReadMapping booleanReadMapping(Type prestoType, BooleanReadFunction readFunction) - { - return new ReadMapping(prestoType, readFunction); - } - - public static ReadMapping longReadMapping(Type prestoType, LongReadFunction readFunction) - { - return new ReadMapping(prestoType, readFunction); - } - - public static ReadMapping doubleReadMapping(Type prestoType, DoubleReadFunction readFunction) - { - return new ReadMapping(prestoType, readFunction); - } - - public static ReadMapping sliceReadMapping(Type prestoType, SliceReadFunction readFunction) - { - return new ReadMapping(prestoType, readFunction); - } - - private final Type type; - private final ReadFunction readFunction; - - private ReadMapping(Type type, ReadFunction readFunction) - { - this.type = requireNonNull(type, "type is null"); - this.readFunction = requireNonNull(readFunction, "readFunction is null"); - checkArgument( - type.getJavaType() == readFunction.getJavaType(), - "Presto type %s is not compatible with read function %s returning %s", - type, - readFunction, - readFunction.getJavaType()); - } - - public Type getType() - { - return type; - } - - public ReadFunction getReadFunction() - { - return readFunction; - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("type", type) - .toString(); - } -} diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/SliceWriteFunction.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/SliceWriteFunction.java new file mode 100644 index 000000000000..8a9fe51a5b8d --- /dev/null +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/SliceWriteFunction.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.plugin.jdbc; + +import io.airlift.slice.Slice; + +import java.sql.PreparedStatement; +import java.sql.SQLException; + +public interface SliceWriteFunction + extends WriteFunction +{ + @Override + default Class getJavaType() + { + return Slice.class; + } + + void set(PreparedStatement statement, int index, Slice value) + throws SQLException; +} diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/StandardColumnMappings.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/StandardColumnMappings.java new file mode 100644 index 000000000000..cf3d123bac92 --- /dev/null +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/StandardColumnMappings.java @@ -0,0 +1,373 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.plugin.jdbc; + +import com.google.common.base.CharMatcher; +import com.google.common.primitives.Shorts; +import com.google.common.primitives.SignedBytes; +import io.prestosql.spi.predicate.Domain; +import io.prestosql.spi.type.CharType; +import io.prestosql.spi.type.DecimalType; +import io.prestosql.spi.type.Decimals; +import io.prestosql.spi.type.VarcharType; +import org.joda.time.DateTimeZone; +import org.joda.time.chrono.ISOChronology; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.MathContext; +import java.sql.Date; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.Time; +import java.sql.Timestamp; +import java.sql.Types; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.CharType.createCharType; +import static io.prestosql.spi.type.DateType.DATE; +import static io.prestosql.spi.type.DecimalType.createDecimalType; +import static io.prestosql.spi.type.Decimals.decodeUnscaledValue; +import static io.prestosql.spi.type.Decimals.encodeScaledValue; +import static io.prestosql.spi.type.Decimals.encodeShortScaledValue; +import static io.prestosql.spi.type.DoubleType.DOUBLE; +import static io.prestosql.spi.type.IntegerType.INTEGER; +import static io.prestosql.spi.type.RealType.REAL; +import static io.prestosql.spi.type.SmallintType.SMALLINT; +import static io.prestosql.spi.type.TimeType.TIME; +import static io.prestosql.spi.type.TimestampType.TIMESTAMP; +import static io.prestosql.spi.type.TinyintType.TINYINT; +import static io.prestosql.spi.type.VarbinaryType.VARBINARY; +import static io.prestosql.spi.type.VarcharType.createUnboundedVarcharType; +import static io.prestosql.spi.type.VarcharType.createVarcharType; +import static java.lang.Float.floatToRawIntBits; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.DAYS; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.joda.time.DateTimeZone.UTC; + +public final class StandardColumnMappings +{ + private StandardColumnMappings() {} + + private static final ISOChronology UTC_CHRONOLOGY = ISOChronology.getInstanceUTC(); + + public static ColumnMapping booleanColumnMapping() + { + return ColumnMapping.booleanMapping(BOOLEAN, ResultSet::getBoolean, booleanWriteFunction()); + } + + public static BooleanWriteFunction booleanWriteFunction() + { + return PreparedStatement::setBoolean; + } + + public static ColumnMapping tinyintColumnMapping() + { + return ColumnMapping.longMapping(TINYINT, ResultSet::getByte, tinyintWriteFunction()); + } + + public static LongWriteFunction tinyintWriteFunction() + { + return (statement, index, value) -> statement.setByte(index, SignedBytes.checkedCast(value)); + } + + public static ColumnMapping smallintColumnMapping() + { + return ColumnMapping.longMapping(SMALLINT, ResultSet::getShort, smallintWriteFunction()); + } + + public static LongWriteFunction smallintWriteFunction() + { + return (statement, index, value) -> statement.setShort(index, Shorts.checkedCast(value)); + } + + public static ColumnMapping integerColumnMapping() + { + return ColumnMapping.longMapping(INTEGER, ResultSet::getInt, integerWriteFunction()); + } + + public static LongWriteFunction integerWriteFunction() + { + return (statement, index, value) -> statement.setInt(index, toIntExact(value)); + } + + public static ColumnMapping bigintColumnMapping() + { + return ColumnMapping.longMapping(BIGINT, ResultSet::getLong, bigintWriteFunction()); + } + + public static LongWriteFunction bigintWriteFunction() + { + return PreparedStatement::setLong; + } + + public static ColumnMapping realColumnMapping() + { + return ColumnMapping.longMapping(REAL, (resultSet, columnIndex) -> floatToRawIntBits(resultSet.getFloat(columnIndex)), realWriteFunction()); + } + + public static LongWriteFunction realWriteFunction() + { + return (statement, index, value) -> statement.setFloat(index, intBitsToFloat(toIntExact(value))); + } + + public static ColumnMapping doubleColumnMapping() + { + return ColumnMapping.doubleMapping(DOUBLE, ResultSet::getDouble, doubleWriteFunction()); + } + + public static DoubleWriteFunction doubleWriteFunction() + { + return PreparedStatement::setDouble; + } + + public static ColumnMapping decimalColumnMapping(DecimalType decimalType) + { + // JDBC driver can return BigDecimal with lower scale than column's scale when there are trailing zeroes + int scale = decimalType.getScale(); + if (decimalType.isShort()) { + return ColumnMapping.longMapping( + decimalType, + (resultSet, columnIndex) -> encodeShortScaledValue(resultSet.getBigDecimal(columnIndex), scale), + shortDecimalWriteFunction(decimalType)); + } + return ColumnMapping.sliceMapping( + decimalType, + (resultSet, columnIndex) -> encodeScaledValue(resultSet.getBigDecimal(columnIndex), scale), + longDecimalWriteFunction(decimalType)); + } + + public static LongWriteFunction shortDecimalWriteFunction(DecimalType decimalType) + { + requireNonNull(decimalType, "decimalType is null"); + checkArgument(decimalType.isShort()); + return (statement, index, value) -> { + BigInteger unscaledValue = BigInteger.valueOf(value); + BigDecimal bigDecimal = new BigDecimal(unscaledValue, decimalType.getScale(), new MathContext(decimalType.getPrecision())); + statement.setBigDecimal(index, bigDecimal); + }; + } + + public static SliceWriteFunction longDecimalWriteFunction(DecimalType decimalType) + { + requireNonNull(decimalType, "decimalType is null"); + checkArgument(!decimalType.isShort()); + return (statement, index, value) -> { + BigInteger unscaledValue = decodeUnscaledValue(value); + BigDecimal bigDecimal = new BigDecimal(unscaledValue, decimalType.getScale(), new MathContext(decimalType.getPrecision())); + statement.setBigDecimal(index, bigDecimal); + }; + } + + public static ColumnMapping charColumnMapping(CharType charType) + { + requireNonNull(charType, "charType is null"); + return ColumnMapping.sliceMapping( + charType, + (resultSet, columnIndex) -> utf8Slice(CharMatcher.is(' ').trimTrailingFrom(resultSet.getString(columnIndex))), + charWriteFunction(charType)); + } + + public static SliceWriteFunction charWriteFunction(CharType charType) + { + requireNonNull(charType, "charType is null"); + return (statement, index, value) -> { + statement.setString(index, value.toStringUtf8()); + }; + } + + public static ColumnMapping varcharColumnMapping(VarcharType varcharType) + { + return ColumnMapping.sliceMapping(varcharType, (resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex)), varcharWriteFunction()); + } + + public static SliceWriteFunction varcharWriteFunction() + { + return (statement, index, value) -> statement.setString(index, value.toStringUtf8()); + } + + public static ColumnMapping varbinaryColumnMapping() + { + return ColumnMapping.sliceMapping( + VARBINARY, + (resultSet, columnIndex) -> wrappedBuffer(resultSet.getBytes(columnIndex)), + varbinaryWriteFunction(), + domain -> Domain.all(domain.getType())); + } + + public static SliceWriteFunction varbinaryWriteFunction() + { + return (statement, index, value) -> statement.setBytes(index, value.getBytes()); + } + + public static ColumnMapping dateColumnMapping() + { + return ColumnMapping.longMapping( + DATE, + (resultSet, columnIndex) -> { + /* + * JDBC returns a date using a timestamp at midnight in the JVM timezone, or earliest time after that if there was no midnight. + * This works correctly for all dates and zones except when the missing local times 'gap' is 24h. I.e. this fails when JVM time + * zone is Pacific/Apia and date to be returned is 2011-12-30. + * + * `return resultSet.getObject(columnIndex, LocalDate.class).toEpochDay()` avoids these problems but + * is currently known not to work with Redshift (old Postgres connector) and SQL Server. + */ + long localMillis = resultSet.getDate(columnIndex).getTime(); + // Convert it to a ~midnight in UTC. + long utcMillis = ISOChronology.getInstance().getZone().getMillisKeepLocal(UTC, localMillis); + // convert to days + return MILLISECONDS.toDays(utcMillis); + }, + dateWriteFunction()); + } + + public static LongWriteFunction dateWriteFunction() + { + return (statement, index, value) -> { + // convert to midnight in default time zone + long millis = DAYS.toMillis(value); + statement.setDate(index, new Date(UTC.getMillisKeepLocal(DateTimeZone.getDefault(), millis))); + }; + } + + public static ColumnMapping timeColumnMapping() + { + return ColumnMapping.longMapping( + TIME, + (resultSet, columnIndex) -> { + /* + * TODO `resultSet.getTime(columnIndex)` returns wrong value if JVM's zone had forward offset change during 1970-01-01 + * and the time value being retrieved was not present in local time (a 'gap'), e.g. time retrieved is 00:10:00 and JVM zone is America/Hermosillo + * The problem can be averted by using `resultSet.getObject(columnIndex, LocalTime.class)` -- but this is not universally supported by JDBC drivers. + */ + Time time = resultSet.getTime(columnIndex); + return UTC_CHRONOLOGY.millisOfDay().get(time.getTime()); + }, + timeWriteFunction()); + } + + public static LongWriteFunction timeWriteFunction() + { + return (statement, index, value) -> { + // Copied from `QueryBuilder.buildSql` + // TODO verify correctness, add tests and support non-legacy timestamp + statement.setTime(index, new Time(value)); + }; + } + + public static ColumnMapping timestampColumnMapping() + { + return ColumnMapping.longMapping( + TIMESTAMP, + (resultSet, columnIndex) -> { + /* + * TODO `resultSet.getTimestamp(columnIndex)` returns wrong value if JVM's zone had forward offset change and the local time + * corresponding to timestamp value being retrieved was not present (a 'gap'), this includes regular DST changes (e.g. Europe/Warsaw) + * and one-time policy changes (Asia/Kathmandu's shift by 15 minutes on January 1, 1986, 00:00:00). + * The problem can be averted by using `resultSet.getObject(columnIndex, LocalDateTime.class)` -- but this is not universally supported by JDBC drivers. + */ + Timestamp timestamp = resultSet.getTimestamp(columnIndex); + return timestamp.getTime(); + }, + timestampWriteFunction()); + } + + public static LongWriteFunction timestampWriteFunction() + { + return (statement, index, value) -> { + // Copied from `QueryBuilder.buildSql` + // TODO verify correctness, add tests and support non-legacy timestamp + statement.setTimestamp(index, new Timestamp(value)); + }; + } + + public static Optional jdbcTypeToPrestoType(JdbcTypeHandle type) + { + int columnSize = type.getColumnSize(); + switch (type.getJdbcType()) { + case Types.BIT: + case Types.BOOLEAN: + return Optional.of(booleanColumnMapping()); + + case Types.TINYINT: + return Optional.of(tinyintColumnMapping()); + + case Types.SMALLINT: + return Optional.of(smallintColumnMapping()); + + case Types.INTEGER: + return Optional.of(integerColumnMapping()); + + case Types.BIGINT: + return Optional.of(bigintColumnMapping()); + + case Types.REAL: + return Optional.of(realColumnMapping()); + + case Types.FLOAT: + case Types.DOUBLE: + return Optional.of(doubleColumnMapping()); + + case Types.NUMERIC: + case Types.DECIMAL: + int decimalDigits = type.getDecimalDigits(); + int precision = columnSize + max(-decimalDigits, 0); // Map decimal(p, -s) (negative scale) to decimal(p+s, 0). + if (precision > Decimals.MAX_PRECISION) { + return Optional.empty(); + } + return Optional.of(decimalColumnMapping(createDecimalType(precision, max(decimalDigits, 0)))); + + case Types.CHAR: + case Types.NCHAR: + // TODO this is wrong, we're going to construct malformed Slice representation if source > charLength + int charLength = min(columnSize, CharType.MAX_LENGTH); + return Optional.of(charColumnMapping(createCharType(charLength))); + + case Types.VARCHAR: + case Types.NVARCHAR: + case Types.LONGVARCHAR: + case Types.LONGNVARCHAR: + if (columnSize > VarcharType.MAX_LENGTH) { + return Optional.of(varcharColumnMapping(createUnboundedVarcharType())); + } + return Optional.of(varcharColumnMapping(createVarcharType(columnSize))); + + case Types.BINARY: + case Types.VARBINARY: + case Types.LONGVARBINARY: + return Optional.of(varbinaryColumnMapping()); + + case Types.DATE: + return Optional.of(dateColumnMapping()); + + case Types.TIME: + return Optional.of(timeColumnMapping()); + + case Types.TIMESTAMP: + return Optional.of(timestampColumnMapping()); + } + return Optional.empty(); + } +} diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/StandardReadMappings.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/StandardReadMappings.java deleted file mode 100644 index a3e5addfe096..000000000000 --- a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/StandardReadMappings.java +++ /dev/null @@ -1,237 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.plugin.jdbc; - -import com.google.common.base.CharMatcher; -import io.prestosql.spi.type.CharType; -import io.prestosql.spi.type.DecimalType; -import io.prestosql.spi.type.Decimals; -import io.prestosql.spi.type.VarcharType; -import org.joda.time.chrono.ISOChronology; - -import java.sql.ResultSet; -import java.sql.Time; -import java.sql.Timestamp; -import java.sql.Types; -import java.util.Optional; - -import static io.airlift.slice.Slices.utf8Slice; -import static io.airlift.slice.Slices.wrappedBuffer; -import static io.prestosql.plugin.jdbc.ReadMapping.longReadMapping; -import static io.prestosql.plugin.jdbc.ReadMapping.sliceReadMapping; -import static io.prestosql.spi.type.BigintType.BIGINT; -import static io.prestosql.spi.type.BooleanType.BOOLEAN; -import static io.prestosql.spi.type.CharType.createCharType; -import static io.prestosql.spi.type.DateType.DATE; -import static io.prestosql.spi.type.DecimalType.createDecimalType; -import static io.prestosql.spi.type.Decimals.encodeScaledValue; -import static io.prestosql.spi.type.Decimals.encodeShortScaledValue; -import static io.prestosql.spi.type.DoubleType.DOUBLE; -import static io.prestosql.spi.type.IntegerType.INTEGER; -import static io.prestosql.spi.type.RealType.REAL; -import static io.prestosql.spi.type.SmallintType.SMALLINT; -import static io.prestosql.spi.type.TimeType.TIME; -import static io.prestosql.spi.type.TimestampType.TIMESTAMP; -import static io.prestosql.spi.type.TinyintType.TINYINT; -import static io.prestosql.spi.type.VarbinaryType.VARBINARY; -import static io.prestosql.spi.type.VarcharType.createUnboundedVarcharType; -import static io.prestosql.spi.type.VarcharType.createVarcharType; -import static java.lang.Float.floatToRawIntBits; -import static java.lang.Math.max; -import static java.lang.Math.min; -import static java.util.Objects.requireNonNull; -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static org.joda.time.DateTimeZone.UTC; - -public final class StandardReadMappings -{ - private StandardReadMappings() {} - - private static final ISOChronology UTC_CHRONOLOGY = ISOChronology.getInstanceUTC(); - - public static ReadMapping booleanReadMapping() - { - return ReadMapping.booleanReadMapping(BOOLEAN, ResultSet::getBoolean); - } - - public static ReadMapping tinyintReadMapping() - { - return longReadMapping(TINYINT, ResultSet::getByte); - } - - public static ReadMapping smallintReadMapping() - { - return longReadMapping(SMALLINT, ResultSet::getShort); - } - - public static ReadMapping integerReadMapping() - { - return longReadMapping(INTEGER, ResultSet::getInt); - } - - public static ReadMapping bigintReadMapping() - { - return longReadMapping(BIGINT, ResultSet::getLong); - } - - public static ReadMapping realReadMapping() - { - return longReadMapping(REAL, (resultSet, columnIndex) -> floatToRawIntBits(resultSet.getFloat(columnIndex))); - } - - public static ReadMapping doubleReadMapping() - { - return ReadMapping.doubleReadMapping(DOUBLE, ResultSet::getDouble); - } - - public static ReadMapping decimalReadMapping(DecimalType decimalType) - { - // JDBC driver can return BigDecimal with lower scale than column's scale when there are trailing zeroes - int scale = decimalType.getScale(); - if (decimalType.isShort()) { - return longReadMapping(decimalType, (resultSet, columnIndex) -> encodeShortScaledValue(resultSet.getBigDecimal(columnIndex), scale)); - } - return sliceReadMapping(decimalType, (resultSet, columnIndex) -> encodeScaledValue(resultSet.getBigDecimal(columnIndex), scale)); - } - - public static ReadMapping charReadMapping(CharType charType) - { - requireNonNull(charType, "charType is null"); - return sliceReadMapping(charType, (resultSet, columnIndex) -> utf8Slice(CharMatcher.is(' ').trimTrailingFrom(resultSet.getString(columnIndex)))); - } - - public static ReadMapping varcharReadMapping(VarcharType varcharType) - { - return sliceReadMapping(varcharType, (resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex))); - } - - public static ReadMapping varbinaryReadMapping() - { - return sliceReadMapping(VARBINARY, (resultSet, columnIndex) -> wrappedBuffer(resultSet.getBytes(columnIndex))); - } - - public static ReadMapping dateReadMapping() - { - return longReadMapping(DATE, (resultSet, columnIndex) -> { - /* - * JDBC returns a date using a timestamp at midnight in the JVM timezone, or earliest time after that if there was no midnight. - * This works correctly for all dates and zones except when the missing local times 'gap' is 24h. I.e. this fails when JVM time - * zone is Pacific/Apia and date to be returned is 2011-12-30. - * - * `return resultSet.getObject(columnIndex, LocalDate.class).toEpochDay()` avoids these problems but - * is currently known not to work with Redshift (old Postgres connector) and SQL Server. - */ - long localMillis = resultSet.getDate(columnIndex).getTime(); - // Convert it to a ~midnight in UTC. - long utcMillis = ISOChronology.getInstance().getZone().getMillisKeepLocal(UTC, localMillis); - // convert to days - return MILLISECONDS.toDays(utcMillis); - }); - } - - public static ReadMapping timeReadMapping() - { - return longReadMapping(TIME, (resultSet, columnIndex) -> { - /* - * TODO `resultSet.getTime(columnIndex)` returns wrong value if JVM's zone had forward offset change during 1970-01-01 - * and the time value being retrieved was not present in local time (a 'gap'), e.g. time retrieved is 00:10:00 and JVM zone is America/Hermosillo - * The problem can be averted by using `resultSet.getObject(columnIndex, LocalTime.class)` -- but this is not universally supported by JDBC drivers. - */ - Time time = resultSet.getTime(columnIndex); - return UTC_CHRONOLOGY.millisOfDay().get(time.getTime()); - }); - } - - public static ReadMapping timestampReadMapping() - { - return longReadMapping(TIMESTAMP, (resultSet, columnIndex) -> { - /* - * TODO `resultSet.getTimestamp(columnIndex)` returns wrong value if JVM's zone had forward offset change and the local time - * corresponding to timestamp value being retrieved was not present (a 'gap'), this includes regular DST changes (e.g. Europe/Warsaw) - * and one-time policy changes (Asia/Kathmandu's shift by 15 minutes on January 1, 1986, 00:00:00). - * The problem can be averted by using `resultSet.getObject(columnIndex, LocalDateTime.class)` -- but this is not universally supported by JDBC drivers. - */ - Timestamp timestamp = resultSet.getTimestamp(columnIndex); - return timestamp.getTime(); - }); - } - - public static Optional jdbcTypeToPrestoType(JdbcTypeHandle type) - { - int columnSize = type.getColumnSize(); - switch (type.getJdbcType()) { - case Types.BIT: - case Types.BOOLEAN: - return Optional.of(booleanReadMapping()); - - case Types.TINYINT: - return Optional.of(tinyintReadMapping()); - - case Types.SMALLINT: - return Optional.of(smallintReadMapping()); - - case Types.INTEGER: - return Optional.of(integerReadMapping()); - - case Types.BIGINT: - return Optional.of(bigintReadMapping()); - - case Types.REAL: - return Optional.of(realReadMapping()); - - case Types.FLOAT: - case Types.DOUBLE: - return Optional.of(doubleReadMapping()); - - case Types.NUMERIC: - case Types.DECIMAL: - int decimalDigits = type.getDecimalDigits(); - int precision = columnSize + max(-decimalDigits, 0); // Map decimal(p, -s) (negative scale) to decimal(p+s, 0). - if (precision > Decimals.MAX_PRECISION) { - return Optional.empty(); - } - return Optional.of(decimalReadMapping(createDecimalType(precision, max(decimalDigits, 0)))); - - case Types.CHAR: - case Types.NCHAR: - // TODO this is wrong, we're going to construct malformed Slice representation if source > charLength - int charLength = min(columnSize, CharType.MAX_LENGTH); - return Optional.of(charReadMapping(createCharType(charLength))); - - case Types.VARCHAR: - case Types.NVARCHAR: - case Types.LONGVARCHAR: - case Types.LONGNVARCHAR: - if (columnSize > VarcharType.MAX_LENGTH) { - return Optional.of(varcharReadMapping(createUnboundedVarcharType())); - } - return Optional.of(varcharReadMapping(createVarcharType(columnSize))); - - case Types.BINARY: - case Types.VARBINARY: - case Types.LONGVARBINARY: - return Optional.of(varbinaryReadMapping()); - - case Types.DATE: - return Optional.of(dateReadMapping()); - - case Types.TIME: - return Optional.of(timeReadMapping()); - - case Types.TIMESTAMP: - return Optional.of(timestampReadMapping()); - } - return Optional.empty(); - } -} diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/WriteFunction.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/WriteFunction.java new file mode 100644 index 000000000000..ddca288fe6c0 --- /dev/null +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/WriteFunction.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.plugin.jdbc; + +public interface WriteFunction +{ + Class getJavaType(); +} diff --git a/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/WriteMapping.java b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/WriteMapping.java new file mode 100644 index 000000000000..56939cd21d09 --- /dev/null +++ b/presto-base-jdbc/src/main/java/io/prestosql/plugin/jdbc/WriteMapping.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.plugin.jdbc; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class WriteMapping +{ + public static WriteMapping booleanMapping(String dataType, BooleanWriteFunction writeFunction) + { + return new WriteMapping(dataType, writeFunction); + } + + public static WriteMapping longMapping(String dataType, LongWriteFunction writeFunction) + { + return new WriteMapping(dataType, writeFunction); + } + + public static WriteMapping doubleMapping(String dataType, DoubleWriteFunction writeFunction) + { + return new WriteMapping(dataType, writeFunction); + } + + public static WriteMapping sliceMapping(String dataType, SliceWriteFunction writeFunction) + { + return new WriteMapping(dataType, writeFunction); + } + + private final String dataType; + private final WriteFunction writeFunction; + + private WriteMapping(String dataType, WriteFunction writeFunction) + { + this.dataType = requireNonNull(dataType, "dataType is null"); + this.writeFunction = requireNonNull(writeFunction, "writeFunction is null"); + } + + /** + * Data type that should be used in the remote database when defining a column. + */ + public String getDataType() + { + return dataType; + } + + public WriteFunction getWriteFunction() + { + return writeFunction; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("dataType", dataType) + .toString(); + } +} diff --git a/presto-base-jdbc/src/test/java/io/prestosql/plugin/jdbc/TestJdbcQueryBuilder.java b/presto-base-jdbc/src/test/java/io/prestosql/plugin/jdbc/TestJdbcQueryBuilder.java index bfc5358a5c19..fc7dc9668285 100644 --- a/presto-base-jdbc/src/test/java/io/prestosql/plugin/jdbc/TestJdbcQueryBuilder.java +++ b/presto-base-jdbc/src/test/java/io/prestosql/plugin/jdbc/TestJdbcQueryBuilder.java @@ -64,6 +64,7 @@ import static io.prestosql.spi.type.TimestampType.TIMESTAMP; import static io.prestosql.spi.type.TinyintType.TINYINT; import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.testing.TestingConnectorSession.SESSION; import static java.lang.Float.floatToRawIntBits; import static java.lang.String.format; import static java.time.temporal.ChronoUnit.DAYS; @@ -197,7 +198,7 @@ public void testNormalBuildSql() .build()); Connection connection = database.getConnection(); - try (PreparedStatement preparedStatement = new QueryBuilder("\"").buildSql(jdbcClient, connection, "", "", "test_table", columns, tupleDomain, Optional.empty()); + try (PreparedStatement preparedStatement = new QueryBuilder("\"").buildSql(jdbcClient, SESSION, connection, "", "", "test_table", columns, tupleDomain, Optional.empty()); ResultSet resultSet = preparedStatement.executeQuery()) { ImmutableSet.Builder builder = ImmutableSet.builder(); while (resultSet.next()) { @@ -220,7 +221,7 @@ public void testBuildSqlWithFloat() false))); Connection connection = database.getConnection(); - try (PreparedStatement preparedStatement = new QueryBuilder("\"").buildSql(jdbcClient, connection, "", "", "test_table", columns, tupleDomain, Optional.empty()); + try (PreparedStatement preparedStatement = new QueryBuilder("\"").buildSql(jdbcClient, SESSION, connection, "", "", "test_table", columns, tupleDomain, Optional.empty()); ResultSet resultSet = preparedStatement.executeQuery()) { ImmutableSet.Builder longBuilder = ImmutableSet.builder(); ImmutableSet.Builder floatBuilder = ImmutableSet.builder(); @@ -246,7 +247,7 @@ public void testBuildSqlWithVarchar() false))); Connection connection = database.getConnection(); - try (PreparedStatement preparedStatement = new QueryBuilder("\"").buildSql(jdbcClient, connection, "", "", "test_table", columns, tupleDomain, Optional.empty()); + try (PreparedStatement preparedStatement = new QueryBuilder("\"").buildSql(jdbcClient, SESSION, connection, "", "", "test_table", columns, tupleDomain, Optional.empty()); ResultSet resultSet = preparedStatement.executeQuery()) { ImmutableSet.Builder builder = ImmutableSet.builder(); while (resultSet.next()) { @@ -274,7 +275,7 @@ public void testBuildSqlWithChar() false))); Connection connection = database.getConnection(); - try (PreparedStatement preparedStatement = new QueryBuilder("\"").buildSql(jdbcClient, connection, "", "", "test_table", columns, tupleDomain, Optional.empty()); + try (PreparedStatement preparedStatement = new QueryBuilder("\"").buildSql(jdbcClient, SESSION, connection, "", "", "test_table", columns, tupleDomain, Optional.empty()); ResultSet resultSet = preparedStatement.executeQuery()) { ImmutableSet.Builder builder = ImmutableSet.builder(); while (resultSet.next()) { @@ -307,7 +308,7 @@ public void testBuildSqlWithDateTime() false))); Connection connection = database.getConnection(); - try (PreparedStatement preparedStatement = new QueryBuilder("\"").buildSql(jdbcClient, connection, "", "", "test_table", columns, tupleDomain, Optional.empty()); + try (PreparedStatement preparedStatement = new QueryBuilder("\"").buildSql(jdbcClient, SESSION, connection, "", "", "test_table", columns, tupleDomain, Optional.empty()); ResultSet resultSet = preparedStatement.executeQuery()) { ImmutableSet.Builder dateBuilder = ImmutableSet.builder(); ImmutableSet.Builder