diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/Parameter.java b/src/main/java/com/microsoft/sqlserver/jdbc/Parameter.java index b614ad5fe..2a9c757d0 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/Parameter.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/Parameter.java @@ -415,6 +415,10 @@ Object getValue(JDBCType jdbcType, // statement level), cryptoMeta would be null. return getterDTV.getValue(jdbcType, outScale, getterArgs, cal, typeInfo, cryptoMeta, tdsReader); } + + Object getSetterValue() { + return setterDTV.getSetterValue(); + } int getInt(TDSReader tdsReader) throws SQLServerException { Integer value = (Integer) getValue(JDBCType.INTEGER, null, null, tdsReader); diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkBatchInsertRecord.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkBatchInsertRecord.java new file mode 100644 index 000000000..3bd33db7d --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkBatchInsertRecord.java @@ -0,0 +1,423 @@ +/* + * Microsoft JDBC Driver for SQL Server + * + * Copyright(c) Microsoft Corporation All rights reserved. + * + * This program is made available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ + +package com.microsoft.sqlserver.jdbc; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.sql.Types; +import java.text.DecimalFormat; +import java.text.MessageFormat; +import java.time.OffsetDateTime; +import java.time.OffsetTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map.Entry; + +import com.microsoft.sqlserver.jdbc.SQLServerBulkCommon.ColumnMetadata; + +import java.util.Set; + +/** + * A simple implementation of the ISQLServerBulkRecord interface that can be used to read in the basic Java data types from an ArrayList of Parameters + * that were provided by pstmt/cstmt. + */ +public class SQLServerBulkBatchInsertRecord extends SQLServerBulkCommon implements ISQLServerBulkRecord { + + private List batchParam; + private int batchParamIndex = -1; + private List columnList; + private List valueList; + + /* + * Class name for logging. + */ + private static final String loggerClassName = "com.microsoft.sqlserver.jdbc.SQLServerBulkBatchInsertRecord"; + + /* + * Logger + */ + private static final java.util.logging.Logger loggerExternal = java.util.logging.Logger.getLogger(loggerClassName); + + public SQLServerBulkBatchInsertRecord(ArrayList batchParam, + ArrayList columnList, + ArrayList valueList, + String encoding) throws SQLServerException { + loggerExternal.entering(loggerClassName, "SQLServerBulkBatchInsertRecord", new Object[] {batchParam, encoding}); + + if (null == batchParam) { + throwInvalidArgument("batchParam"); + } + + if (null == valueList) { + throwInvalidArgument("valueList"); + } + + this.batchParam = batchParam; + this.columnList = columnList; + this.valueList = valueList; + columnMetadata = new HashMap<>(); + + loggerExternal.exiting(loggerClassName, "SQLServerBulkBatchInsertRecord"); + } + + public DateTimeFormatter getColumnDateTimeFormatter(int column) { + return columnMetadata.get(column).dateTimeFormatter; + } + + @Override + public Set getColumnOrdinals() { + return columnMetadata.keySet(); + } + + @Override + public String getColumnName(int column) { + return columnMetadata.get(column).columnName; + } + + @Override + public int getColumnType(int column) { + return columnMetadata.get(column).columnType; + } + + @Override + public int getPrecision(int column) { + return columnMetadata.get(column).precision; + } + + @Override + public int getScale(int column) { + return columnMetadata.get(column).scale; + } + + @Override + public boolean isAutoIncrement(int column) { + return false; + } + + private Object convertValue(ColumnMetadata cm, + Object data) throws SQLServerException { + switch (cm.columnType) { + case Types.INTEGER: { + // Formatter to remove the decimal part as SQL Server floors the decimal in integer types + DecimalFormat decimalFormatter = new DecimalFormat("#"); + decimalFormatter.setRoundingMode(RoundingMode.DOWN); + String formatedfInput = decimalFormatter.format(Double.parseDouble(data.toString())); + return Integer.valueOf(formatedfInput); + } + + case Types.TINYINT: + case Types.SMALLINT: { + // Formatter to remove the decimal part as SQL Server floors the decimal in integer types + DecimalFormat decimalFormatter = new DecimalFormat("#"); + decimalFormatter.setRoundingMode(RoundingMode.DOWN); + String formatedfInput = decimalFormatter.format(Double.parseDouble(data.toString())); + return Short.valueOf(formatedfInput); + } + + case Types.BIGINT: { + BigDecimal bd = new BigDecimal(data.toString().trim()); + try { + return bd.setScale(0, RoundingMode.DOWN).longValueExact(); + } + catch (ArithmeticException ex) { + String value = "'" + data + "'"; + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_errorConvertingValue")); + throw new SQLServerException(form.format(new Object[] {value, JDBCType.of(cm.columnType)}), null, 0, ex); + } + } + + case Types.DECIMAL: + case Types.NUMERIC: { + BigDecimal bd = new BigDecimal(data.toString().trim()); + return bd.setScale(cm.scale, RoundingMode.HALF_UP); + } + + case Types.BIT: { + // "true" => 1, "false" => 0 + // Any non-zero value (integer/double) => 1, 0/0.0 => 0 + try { + return (0 == Double.parseDouble(data.toString())) ? Boolean.FALSE : Boolean.TRUE; + } + catch (NumberFormatException e) { + return Boolean.parseBoolean(data.toString()); + } + } + + case Types.REAL: { + return Float.parseFloat(data.toString()); + } + + case Types.DOUBLE: { + return Double.parseDouble(data.toString()); + } + + case Types.BINARY: + case Types.VARBINARY: + case Types.LONGVARBINARY: + case Types.BLOB: { + // Strip off 0x if present. + String binData = data.toString().trim(); + if (binData.startsWith("0x") || binData.startsWith("0X")) { + return binData.substring(2); + } + else { + return binData; + } + } + + case java.sql.Types.TIME_WITH_TIMEZONE: + { + OffsetTime offsetTimeValue; + + // The per-column DateTimeFormatter gets priority. + if (null != cm.dateTimeFormatter) + offsetTimeValue = OffsetTime.parse(data.toString(), cm.dateTimeFormatter); + else if (timeFormatter != null) + offsetTimeValue = OffsetTime.parse(data.toString(), timeFormatter); + else + offsetTimeValue = OffsetTime.parse(data.toString()); + + return offsetTimeValue; + } + + case java.sql.Types.TIMESTAMP_WITH_TIMEZONE: + { + OffsetDateTime offsetDateTimeValue; + + // The per-column DateTimeFormatter gets priority. + if (null != cm.dateTimeFormatter) + offsetDateTimeValue = OffsetDateTime.parse(data.toString(), cm.dateTimeFormatter); + else if (dateTimeFormatter != null) + offsetDateTimeValue = OffsetDateTime.parse(data.toString(), dateTimeFormatter); + else + offsetDateTimeValue = OffsetDateTime.parse(data.toString()); + + return offsetDateTimeValue; + } + + case Types.NULL: { + return null; + } + + case Types.DATE: + case Types.CHAR: + case Types.NCHAR: + case Types.VARCHAR: + case Types.NVARCHAR: + case Types.LONGVARCHAR: + case Types.LONGNVARCHAR: + case Types.CLOB: + default: { + // The string is copied as is. + return data; + } + } + } + + private String removeSingleQuote(String s) { + int len = s.length(); + return (s.charAt(0) == '\'' && s.charAt(len - 1) == '\'') ? s.substring(1, len - 1) : s; + } + + @Override + public Object[] getRowData() throws SQLServerException { + Object[] data = new Object[columnMetadata.size()]; + int valueIndex = 0; + String valueData; + Object rowData; + int columnListIndex = 0; + + // check if the size of the list of values = size of the list of columns (which is optional) + if (null != columnList && columnList.size() != valueList.size()) { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_DataSchemaMismatch")); + Object[] msgArgs = {}; + throw new SQLServerException(form.format(msgArgs), SQLState.COL_NOT_FOUND, DriverError.NOT_SET, null); + } + + for (Entry pair : columnMetadata.entrySet()) { + int index = pair.getKey() - 1; + + // To explain what each variable represents: + // columnMetadata = map containing the ENTIRE list of columns in the table. + // columnList = the *optional* list of columns the user can provide. For example, the (c1, c3) part of this query: INSERT into t1 (c1, c3) values (?, ?) + // valueList = the *mandatory* list of columns the user needs provide. This is the (?, ?) part of the previous query. The size of this valueList will always equal the number of + // the entire columns in the table IF columnList has NOT been provided. If columnList HAS been provided, then this valueList may be smaller than the list of all columns (which is columnMetadata). + + // case when the user has not provided the optional list of column names. + if (null == columnList || columnList.size() == 0) { + valueData = valueList.get(index); + // if the user has provided a wildcard for this column, fetch the set value from the batchParam. + if (valueData.equalsIgnoreCase("?")) { + rowData = batchParam.get(batchParamIndex)[valueIndex++].getSetterValue(); + } + else if (valueData.equalsIgnoreCase("null")) { + rowData = null; + } + // if the user has provided a hardcoded value for this column, rowData is simply set to the hardcoded value. + else { + rowData = removeSingleQuote(valueData); + } + } + // case when the user has provided the optional list of column names. + else { + // columnListIndex is a separate counter we need to keep track of for each time we've processed a column + // that the user provided. + // for example, if the user provided an optional columnList of (c1, c3, c5, c7) in a table that has 8 columns (c1~c8), + // then the columnListIndex would increment only when we're dealing with the four columns inside columnMetadata. + // compare the list of the optional list of column names to the table's metadata, and match each other, so we assign the correct value to each column. + if (columnList.size() > columnListIndex && columnList.get(columnListIndex).equalsIgnoreCase(columnMetadata.get(index + 1).columnName)) { + valueData = valueList.get(columnListIndex); + if (valueData.equalsIgnoreCase("?")) { + rowData = batchParam.get(batchParamIndex)[valueIndex++].getSetterValue(); + } + else if (valueData.equalsIgnoreCase("null")) { + rowData = null; + } + else { + rowData = removeSingleQuote(valueData); + } + columnListIndex++; + } + else { + rowData = null; + } + } + + try { + if (null == rowData) { + data[index] = null; + continue; + } else if (0 == rowData.toString().length()) { + data[index] = ""; + continue; + } + data[index] = convertValue(pair.getValue(), rowData); + } + catch (IllegalArgumentException e) { + String value = "'" + rowData + "'"; + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_errorConvertingValue")); + throw new SQLServerException(form.format(new Object[] {value, JDBCType.of(pair.getValue().columnType)}), null, 0, e); + } + catch (ArrayIndexOutOfBoundsException e) { + throw new SQLServerException(SQLServerException.getErrString("R_DataSchemaMismatch"), e); + } + } + return data; + } + + @Override + void addColumnMetadataInternal(int positionInSource, + String name, + int jdbcType, + int precision, + int scale, + DateTimeFormatter dateTimeFormatter) throws SQLServerException { + loggerExternal.entering(loggerClassName, "addColumnMetadata", new Object[] {positionInSource, name, jdbcType, precision, scale}); + + String colName = ""; + + if (0 >= positionInSource) { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_invalidColumnOrdinal")); + Object[] msgArgs = {positionInSource}; + throw new SQLServerException(form.format(msgArgs), SQLState.COL_NOT_FOUND, DriverError.NOT_SET, null); + } + + if (null != name) + colName = name.trim(); + else if ((null != columnNames) && (columnNames.length >= positionInSource)) + colName = columnNames[positionInSource - 1]; + + if ((null != columnNames) && (positionInSource > columnNames.length)) { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_invalidColumn")); + Object[] msgArgs = {positionInSource}; + throw new SQLServerException(form.format(msgArgs), SQLState.COL_NOT_FOUND, DriverError.NOT_SET, null); + } + + checkDuplicateColumnName(positionInSource, name); + switch (jdbcType) { + /* + * SQL Server supports numerous string literal formats for temporal types, hence sending them as varchar with approximate + * precision(length) needed to send supported string literals. string literal formats supported by temporal types are available in MSDN + * page on data types. + */ + case java.sql.Types.DATE: + case java.sql.Types.TIME: + case java.sql.Types.TIMESTAMP: + case microsoft.sql.Types.DATETIMEOFFSET: + columnMetadata.put(positionInSource, new ColumnMetadata(colName, jdbcType, precision, scale, dateTimeFormatter)); + break; + + // Redirect SQLXML as LONGNVARCHAR + // SQLXML is not valid type in TDS + case java.sql.Types.SQLXML: + columnMetadata.put(positionInSource, new ColumnMetadata(colName, java.sql.Types.LONGNVARCHAR, precision, scale, dateTimeFormatter)); + break; + + // Redirecting Float as Double based on data type mapping + // https://msdn.microsoft.com/en-us/library/ms378878%28v=sql.110%29.aspx + case java.sql.Types.FLOAT: + columnMetadata.put(positionInSource, new ColumnMetadata(colName, java.sql.Types.DOUBLE, precision, scale, dateTimeFormatter)); + break; + + // redirecting BOOLEAN as BIT + case java.sql.Types.BOOLEAN: + columnMetadata.put(positionInSource, new ColumnMetadata(colName, java.sql.Types.BIT, precision, scale, dateTimeFormatter)); + break; + + default: + columnMetadata.put(positionInSource, new ColumnMetadata(colName, jdbcType, precision, scale, dateTimeFormatter)); + } + + loggerExternal.exiting(loggerClassName, "addColumnMetadata"); + } + + @Override + public void setTimestampWithTimezoneFormat(String dateTimeFormat) { + loggerExternal.entering(loggerClassName, "setTimestampWithTimezoneFormat", dateTimeFormat); + + super.setTimestampWithTimezoneFormat(dateTimeFormat); + + loggerExternal.exiting(loggerClassName, "setTimestampWithTimezoneFormat"); + } + + @Override + public void setTimestampWithTimezoneFormat(DateTimeFormatter dateTimeFormatter) { + loggerExternal.entering(loggerClassName, "setTimestampWithTimezoneFormat", new Object[] {dateTimeFormatter}); + + super.setTimestampWithTimezoneFormat(dateTimeFormatter); + + loggerExternal.exiting(loggerClassName, "setTimestampWithTimezoneFormat"); + } + + @Override + public void setTimeWithTimezoneFormat(String timeFormat) { + loggerExternal.entering(loggerClassName, "setTimeWithTimezoneFormat", timeFormat); + + super.setTimeWithTimezoneFormat(timeFormat); + + loggerExternal.exiting(loggerClassName, "setTimeWithTimezoneFormat"); + } + + @Override + public void setTimeWithTimezoneFormat(DateTimeFormatter dateTimeFormatter) { + loggerExternal.entering(loggerClassName, "setTimeWithTimezoneFormat", new Object[] {dateTimeFormatter}); + + super.setTimeWithTimezoneFormat(dateTimeFormatter); + + loggerExternal.exiting(loggerClassName, "setTimeWithTimezoneFormat"); + } + + @Override + public boolean next() throws SQLServerException { + batchParamIndex++; + return batchParamIndex < batchParam.size(); + } +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCSVFileRecord.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCSVFileRecord.java index cce49a416..019046ab6 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCSVFileRecord.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCSVFileRecord.java @@ -23,38 +23,17 @@ import java.time.OffsetTime; import java.time.format.DateTimeFormatter; import java.util.HashMap; -import java.util.Map; import java.util.Map.Entry; + +import com.microsoft.sqlserver.jdbc.SQLServerBulkCommon.ColumnMetadata; + import java.util.Set; /** * A simple implementation of the ISQLServerBulkRecord interface that can be used to read in the basic Java data types from a delimited file where * each line represents a row of data. */ -public class SQLServerBulkCSVFileRecord implements ISQLServerBulkRecord, java.lang.AutoCloseable { - /* - * Class to represent the column metadata - */ - private class ColumnMetadata { - String columnName; - int columnType; - int precision; - int scale; - DateTimeFormatter dateTimeFormatter = null; - - ColumnMetadata(String name, - int type, - int precision, - int scale, - DateTimeFormatter dateTimeFormatter) { - columnName = name; - columnType = type; - this.precision = precision; - this.scale = scale; - this.dateTimeFormatter = dateTimeFormatter; - } - } - +public class SQLServerBulkCSVFileRecord extends SQLServerBulkCommon implements ISQLServerBulkRecord, java.lang.AutoCloseable { /* * Resources associated with reading in the file */ @@ -62,12 +41,6 @@ private class ColumnMetadata { private InputStreamReader sr; private FileInputStream fis; - /* - * Metadata to represent the columns in the file. Each column should be mapped to its corresponding position within the file (from position 1 and - * onwards) - */ - private Map columnMetadata; - /* * Current line of data to parse. */ @@ -77,31 +50,16 @@ private class ColumnMetadata { * Delimiter to parse lines with. */ private final String delimiter; - - /* - * Contains all the column names if firstLineIsColumnNames is true - */ - private String[] columnNames = null; - - /* - * Contains the format that java.sql.Types.TIMESTAMP_WITH_TIMEZONE data should be read in as. - */ - private DateTimeFormatter dateTimeFormatter = null; - - /* - * Contains the format that java.sql.Types.TIME_WITH_TIMEZONE data should be read in as. - */ - private DateTimeFormatter timeFormatter = null; - - /* - * Class name for logging. - */ - private static final String loggerClassName = "com.microsoft.sqlserver.jdbc.SQLServerBulkCSVFileRecord"; - - /* - * Logger - */ - private static final java.util.logging.Logger loggerExternal = java.util.logging.Logger.getLogger(loggerClassName); + + /* + * Class name for logging. + */ + private static final String loggerClassName = "com.microsoft.sqlserver.jdbc.SQLServerBulkCSVFileRecord"; + + /* + * Logger + */ + private static final java.util.logging.Logger loggerExternal = java.util.logging.Logger.getLogger(loggerClassName); /** * Creates a simple reader to parse data from a delimited file with the given encoding. @@ -253,199 +211,6 @@ public SQLServerBulkCSVFileRecord(String fileToParse, this(fileToParse, null, ",", firstLineIsColumnNames); } - /** - * Adds metadata for the given column in the file. - * - * @param positionInFile - * Indicates which column the metadata is for. Columns start at 1. - * @param name - * Name for the column (optional if only using column ordinal in a mapping for SQLServerBulkCopy operation) - * @param jdbcType - * JDBC data type of the column - * @param precision - * Precision for the column (ignored for the appropriate data types) - * @param scale - * Scale for the column (ignored for the appropriate data types) - * @param dateTimeFormatter - * format to parse data that is sent - * @throws SQLServerException - * when an error occurs - */ - public void addColumnMetadata(int positionInFile, - String name, - int jdbcType, - int precision, - int scale, - DateTimeFormatter dateTimeFormatter) throws SQLServerException { - addColumnMetadataInternal(positionInFile, name, jdbcType, precision, scale, dateTimeFormatter); - } - - /** - * Adds metadata for the given column in the file. - * - * @param positionInFile - * Indicates which column the metadata is for. Columns start at 1. - * @param name - * Name for the column (optional if only using column ordinal in a mapping for SQLServerBulkCopy operation) - * @param jdbcType - * JDBC data type of the column - * @param precision - * Precision for the column (ignored for the appropriate data types) - * @param scale - * Scale for the column (ignored for the appropriate data types) - * @throws SQLServerException - * when an error occurs - */ - public void addColumnMetadata(int positionInFile, - String name, - int jdbcType, - int precision, - int scale) throws SQLServerException { - addColumnMetadataInternal(positionInFile, name, jdbcType, precision, scale, null); - } - - /** - * Adds metadata for the given column in the file. - * - * @param positionInFile - * Indicates which column the metadata is for. Columns start at 1. - * @param name - * Name for the column (optional if only using column ordinal in a mapping for SQLServerBulkCopy operation) - * @param jdbcType - * JDBC data type of the column - * @param precision - * Precision for the column (ignored for the appropriate data types) - * @param scale - * Scale for the column (ignored for the appropriate data types) - * @param dateTimeFormatter - * format to parse data that is sent - * @throws SQLServerException - * when an error occurs - */ - void addColumnMetadataInternal(int positionInFile, - String name, - int jdbcType, - int precision, - int scale, - DateTimeFormatter dateTimeFormatter) throws SQLServerException { - loggerExternal.entering(loggerClassName, "addColumnMetadata", new Object[] {positionInFile, name, jdbcType, precision, scale}); - - String colName = ""; - - if (0 >= positionInFile) { - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_invalidColumnOrdinal")); - Object[] msgArgs = {positionInFile}; - throw new SQLServerException(form.format(msgArgs), SQLState.COL_NOT_FOUND, DriverError.NOT_SET, null); - } - - if (null != name) - colName = name.trim(); - else if ((columnNames != null) && (columnNames.length >= positionInFile)) - colName = columnNames[positionInFile - 1]; - - if ((columnNames != null) && (positionInFile > columnNames.length)) { - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_invalidColumn")); - Object[] msgArgs = {positionInFile}; - throw new SQLServerException(form.format(msgArgs), SQLState.COL_NOT_FOUND, DriverError.NOT_SET, null); - } - - checkDuplicateColumnName(positionInFile, name); - switch (jdbcType) { - /* - * SQL Server supports numerous string literal formats for temporal types, hence sending them as varchar with approximate - * precision(length) needed to send supported string literals. string literal formats supported by temporal types are available in MSDN - * page on data types. - */ - case java.sql.Types.DATE: - case java.sql.Types.TIME: - case java.sql.Types.TIMESTAMP: - case microsoft.sql.Types.DATETIMEOFFSET: - // The precision is just a number long enough to hold all types of temporal data, doesn't need to be exact precision. - columnMetadata.put(positionInFile, new ColumnMetadata(colName, jdbcType, 50, scale, dateTimeFormatter)); - break; - - // Redirect SQLXML as LONGNVARCHAR - // SQLXML is not valid type in TDS - case java.sql.Types.SQLXML: - columnMetadata.put(positionInFile, new ColumnMetadata(colName, java.sql.Types.LONGNVARCHAR, precision, scale, dateTimeFormatter)); - break; - - // Redirecting Float as Double based on data type mapping - // https://msdn.microsoft.com/en-us/library/ms378878%28v=sql.110%29.aspx - case java.sql.Types.FLOAT: - columnMetadata.put(positionInFile, new ColumnMetadata(colName, java.sql.Types.DOUBLE, precision, scale, dateTimeFormatter)); - break; - - // redirecting BOOLEAN as BIT - case java.sql.Types.BOOLEAN: - columnMetadata.put(positionInFile, new ColumnMetadata(colName, java.sql.Types.BIT, precision, scale, dateTimeFormatter)); - break; - - default: - columnMetadata.put(positionInFile, new ColumnMetadata(colName, jdbcType, precision, scale, dateTimeFormatter)); - } - - loggerExternal.exiting(loggerClassName, "addColumnMetadata"); - } - - /** - * Set the format for reading in dates from the file. - * - * @param dateTimeFormat - * format to parse data sent as java.sql.Types.TIMESTAMP_WITH_TIMEZONE - */ - public void setTimestampWithTimezoneFormat(String dateTimeFormat) { - DriverJDBCVersion.checkSupportsJDBC42(); - loggerExternal.entering(loggerClassName, "setTimestampWithTimezoneFormat", dateTimeFormat); - - this.dateTimeFormatter = DateTimeFormatter.ofPattern(dateTimeFormat); - - loggerExternal.exiting(loggerClassName, "setTimestampWithTimezoneFormat"); - } - - /** - * Set the format for reading in dates from the file. - * - * @param dateTimeFormatter - * format to parse data sent as java.sql.Types.TIMESTAMP_WITH_TIMEZONE - */ - public void setTimestampWithTimezoneFormat(DateTimeFormatter dateTimeFormatter) { - loggerExternal.entering(loggerClassName, "setTimestampWithTimezoneFormat", new Object[] {dateTimeFormatter}); - - this.dateTimeFormatter = dateTimeFormatter; - - loggerExternal.exiting(loggerClassName, "setTimestampWithTimezoneFormat"); - } - - /** - * Set the format for reading in dates from the file. - * - * @param timeFormat - * format to parse data sent as java.sql.Types.TIME_WITH_TIMEZONE - */ - public void setTimeWithTimezoneFormat(String timeFormat) { - DriverJDBCVersion.checkSupportsJDBC42(); - loggerExternal.entering(loggerClassName, "setTimeWithTimezoneFormat", timeFormat); - - this.timeFormatter = DateTimeFormatter.ofPattern(timeFormat); - - loggerExternal.exiting(loggerClassName, "setTimeWithTimezoneFormat"); - } - - /** - * Set the format for reading in dates from the file. - * - * @param dateTimeFormatter - * format to parse data sent as java.sql.Types.TIME_WITH_TIMEZONE - */ - public void setTimeWithTimezoneFormat(DateTimeFormatter dateTimeFormatter) { - loggerExternal.entering(loggerClassName, "setTimeWithTimezoneFormat", new Object[] {dateTimeFormatter}); - - this.timeFormatter = dateTimeFormatter; - - loggerExternal.exiting(loggerClassName, "setTimeWithTimezoneFormat"); - } - /** * Releases any resources associated with the file reader. * @@ -538,7 +303,7 @@ public Object[] getRowData() throws SQLServerException { // Source header has more columns than current line read if (columnNames != null && (columnNames.length > data.length)) { - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_CSVDataSchemaMismatch")); + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_DataSchemaMismatch")); Object[] msgArgs = {}; throw new SQLServerException(form.format(msgArgs), SQLState.COL_NOT_FOUND, DriverError.NOT_SET, null); } @@ -557,6 +322,7 @@ public Object[] getRowData() throws SQLServerException { case Types.INTEGER: { // Formatter to remove the decimal part as SQL Server floors the decimal in integer types DecimalFormat decimalFormatter = new DecimalFormat("#"); + decimalFormatter.setRoundingMode(RoundingMode.DOWN); String formatedfInput = decimalFormatter.format(Double.parseDouble(data[pair.getKey() - 1])); dataRow[pair.getKey() - 1] = Integer.valueOf(formatedfInput); break; @@ -566,6 +332,7 @@ public Object[] getRowData() throws SQLServerException { case Types.SMALLINT: { // Formatter to remove the decimal part as SQL Server floors the decimal in integer types DecimalFormat decimalFormatter = new DecimalFormat("#"); + decimalFormatter.setRoundingMode(RoundingMode.DOWN); String formatedfInput = decimalFormatter.format(Double.parseDouble(data[pair.getKey() - 1])); dataRow[pair.getKey() - 1] = Short.valueOf(formatedfInput); break; @@ -632,9 +399,8 @@ public Object[] getRowData() throws SQLServerException { break; } - case 2013: // java.sql.Types.TIME_WITH_TIMEZONE + case java.sql.Types.TIME_WITH_TIMEZONE: { - DriverJDBCVersion.checkSupportsJDBC42(); OffsetTime offsetTimeValue; // The per-column DateTimeFormatter gets priority. @@ -649,9 +415,8 @@ else if (timeFormatter != null) break; } - case 2014: // java.sql.Types.TIMESTAMP_WITH_TIMEZONE + case java.sql.Types.TIMESTAMP_WITH_TIMEZONE: { - DriverJDBCVersion.checkSupportsJDBC42(); OffsetDateTime offsetDateTimeValue; // The per-column DateTimeFormatter gets priority. @@ -702,13 +467,115 @@ else if (dateTimeFormatter != null) MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_errorConvertingValue")); throw new SQLServerException(form.format(new Object[]{value, JDBCType.of(cm.columnType)}), null, 0, e); } catch (ArrayIndexOutOfBoundsException e) { - throw new SQLServerException(SQLServerException.getErrString("R_CSVDataSchemaMismatch"), e); + throw new SQLServerException(SQLServerException.getErrString("R_DataSchemaMismatch"), e); } } return dataRow; } } + + @Override + void addColumnMetadataInternal(int positionInSource, + String name, + int jdbcType, + int precision, + int scale, + DateTimeFormatter dateTimeFormatter) throws SQLServerException { + loggerExternal.entering(loggerClassName, "addColumnMetadata", new Object[] {positionInSource, name, jdbcType, precision, scale}); + + String colName = ""; + + if (0 >= positionInSource) { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_invalidColumnOrdinal")); + Object[] msgArgs = {positionInSource}; + throw new SQLServerException(form.format(msgArgs), SQLState.COL_NOT_FOUND, DriverError.NOT_SET, null); + } + + if (null != name) + colName = name.trim(); + else if ((null != columnNames) && (columnNames.length >= positionInSource)) + colName = columnNames[positionInSource - 1]; + + if ((null != columnNames) && (positionInSource > columnNames.length)) { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_invalidColumn")); + Object[] msgArgs = {positionInSource}; + throw new SQLServerException(form.format(msgArgs), SQLState.COL_NOT_FOUND, DriverError.NOT_SET, null); + } + + checkDuplicateColumnName(positionInSource, name); + switch (jdbcType) { + /* + * SQL Server supports numerous string literal formats for temporal types, hence sending them as varchar with approximate + * precision(length) needed to send supported string literals. string literal formats supported by temporal types are available in MSDN + * page on data types. + */ + case java.sql.Types.DATE: + case java.sql.Types.TIME: + case java.sql.Types.TIMESTAMP: + case microsoft.sql.Types.DATETIMEOFFSET: + columnMetadata.put(positionInSource, new ColumnMetadata(colName, jdbcType, 50, scale, dateTimeFormatter)); + break; + + // Redirect SQLXML as LONGNVARCHAR + // SQLXML is not valid type in TDS + case java.sql.Types.SQLXML: + columnMetadata.put(positionInSource, new ColumnMetadata(colName, java.sql.Types.LONGNVARCHAR, precision, scale, dateTimeFormatter)); + break; + + // Redirecting Float as Double based on data type mapping + // https://msdn.microsoft.com/en-us/library/ms378878%28v=sql.110%29.aspx + case java.sql.Types.FLOAT: + columnMetadata.put(positionInSource, new ColumnMetadata(colName, java.sql.Types.DOUBLE, precision, scale, dateTimeFormatter)); + break; + + // redirecting BOOLEAN as BIT + case java.sql.Types.BOOLEAN: + columnMetadata.put(positionInSource, new ColumnMetadata(colName, java.sql.Types.BIT, precision, scale, dateTimeFormatter)); + break; + + default: + columnMetadata.put(positionInSource, new ColumnMetadata(colName, jdbcType, precision, scale, dateTimeFormatter)); + } + + loggerExternal.exiting(loggerClassName, "addColumnMetadata"); + } + + @Override + public void setTimestampWithTimezoneFormat(String dateTimeFormat) { + loggerExternal.entering(loggerClassName, "setTimestampWithTimezoneFormat", dateTimeFormat); + + super.setTimestampWithTimezoneFormat(dateTimeFormat); + + loggerExternal.exiting(loggerClassName, "setTimestampWithTimezoneFormat"); + } + + @Override + public void setTimestampWithTimezoneFormat(DateTimeFormatter dateTimeFormatter) { + loggerExternal.entering(loggerClassName, "setTimestampWithTimezoneFormat", new Object[] {dateTimeFormatter}); + + super.setTimestampWithTimezoneFormat(dateTimeFormatter); + + loggerExternal.exiting(loggerClassName, "setTimestampWithTimezoneFormat"); + } + + @Override + public void setTimeWithTimezoneFormat(String timeFormat) { + loggerExternal.entering(loggerClassName, "setTimeWithTimezoneFormat", timeFormat); + + super.setTimeWithTimezoneFormat(timeFormat); + + loggerExternal.exiting(loggerClassName, "setTimeWithTimezoneFormat"); + } + + @Override + public void setTimeWithTimezoneFormat(DateTimeFormatter dateTimeFormatter) { + loggerExternal.entering(loggerClassName, "setTimeWithTimezoneFormat", new Object[] {dateTimeFormatter}); + + super.setTimeWithTimezoneFormat(dateTimeFormatter); + + loggerExternal.exiting(loggerClassName, "setTimeWithTimezoneFormat"); + } @Override public boolean next() throws SQLServerException { @@ -720,33 +587,4 @@ public boolean next() throws SQLServerException { } return (null != currentLine); } - - /* - * Helper method to throw a SQLServerExeption with the invalidArgument message and given argument. - */ - private void throwInvalidArgument(String argument) throws SQLServerException { - MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_invalidArgument")); - Object[] msgArgs = {argument}; - SQLServerException.makeFromDriverError(null, null, form.format(msgArgs), null, false); - } - - /* - * Method to throw a SQLServerExeption for duplicate column names - */ - private void checkDuplicateColumnName(int positionInFile, - String colName) throws SQLServerException { - - if (null != colName && colName.trim().length() != 0) { - for (Entry entry : columnMetadata.entrySet()) { - // duplicate check is not performed in case of same positionInFile value - if (null != entry && entry.getKey() != positionInFile) { - if (null != entry.getValue() && colName.trim().equalsIgnoreCase(entry.getValue().columnName)) { - throw new SQLServerException(SQLServerException.getErrString("R_BulkCSVDataDuplicateColumn"), null); - } - } - - } - } - - } } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCommon.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCommon.java new file mode 100644 index 000000000..383b4c793 --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCommon.java @@ -0,0 +1,205 @@ +/* + * Microsoft JDBC Driver for SQL Server + * + * Copyright(c) Microsoft Corporation All rights reserved. + * + * This program is made available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ + +package com.microsoft.sqlserver.jdbc; + +import java.text.MessageFormat; +import java.time.format.DateTimeFormatter; +import java.util.Map; +import java.util.Map.Entry; + +abstract class SQLServerBulkCommon { + + /* + * Class to represent the column metadata + */ + protected class ColumnMetadata { + String columnName; + int columnType; + int precision; + int scale; + DateTimeFormatter dateTimeFormatter = null; + + ColumnMetadata(String name, + int type, + int precision, + int scale, + DateTimeFormatter dateTimeFormatter) { + columnName = name; + columnType = type; + this.precision = precision; + this.scale = scale; + this.dateTimeFormatter = dateTimeFormatter; + } + } + + /* + * Contains all the column names if firstLineIsColumnNames is true + */ + protected String[] columnNames = null; + + /* + * Metadata to represent the columns in the batch/file. Each column should be mapped to its corresponding position within the parameter (from + * position 1 and onwards) + */ + protected Map columnMetadata; + + /* + * Contains the format that java.sql.Types.TIMESTAMP_WITH_TIMEZONE data should be read in as. + */ + protected DateTimeFormatter dateTimeFormatter = null; + + /* + * Contains the format that java.sql.Types.TIME_WITH_TIMEZONE data should be read in as. + */ + protected DateTimeFormatter timeFormatter = null; + + /** + * Adds metadata for the given column in the batch/file. + * + * @param positionInSource + * Indicates which column the metadata is for. Columns start at 1. + * @param name + * Name for the column (optional if only using column ordinal in a mapping for SQLServerBulkCopy operation) + * @param jdbcType + * JDBC data type of the column + * @param precision + * Precision for the column (ignored for the appropriate data types) + * @param scale + * Scale for the column (ignored for the appropriate data types) + * @param dateTimeFormatter + * format to parse data that is sent + * @throws SQLServerException + * when an error occurs + */ + public void addColumnMetadata(int positionInSource, + String name, + int jdbcType, + int precision, + int scale, + DateTimeFormatter dateTimeFormatter) throws SQLServerException { + addColumnMetadataInternal(positionInSource, name, jdbcType, precision, scale, dateTimeFormatter); + } + + /** + * Adds metadata for the given column in the batch/file. + * + * @param positionInSource + * Indicates which column the metadata is for. Columns start at 1. + * @param name + * Name for the column (optional if only using column ordinal in a mapping for SQLServerBulkCopy operation) + * @param jdbcType + * JDBC data type of the column + * @param precision + * Precision for the column (ignored for the appropriate data types) + * @param scale + * Scale for the column (ignored for the appropriate data types) + * @throws SQLServerException + * when an error occurs + */ + public void addColumnMetadata(int positionInSource, + String name, + int jdbcType, + int precision, + int scale) throws SQLServerException { + addColumnMetadataInternal(positionInSource, name, jdbcType, precision, scale, null); + } + + /** + * Adds metadata for the given column in the batch/file. + * + * @param positionInSource + * Indicates which column the metadata is for. Columns start at 1. + * @param name + * Name for the column (optional if only using column ordinal in a mapping for SQLServerBulkCopy operation) + * @param jdbcType + * JDBC data type of the column + * @param precision + * Precision for the column (ignored for the appropriate data types) + * @param scale + * Scale for the column (ignored for the appropriate data types) + * @param dateTimeFormatter + * format to parse data that is sent + * @throws SQLServerException + * when an error occurs + */ + void addColumnMetadataInternal(int positionInSource, + String name, + int jdbcType, + int precision, + int scale, + DateTimeFormatter dateTimeFormatter) throws SQLServerException { + } + + /** + * Set the format for reading in dates from the batch/file. + * + * @param dateTimeFormat + * format to parse data sent as java.sql.Types.TIMESTAMP_WITH_TIMEZONE + */ + public void setTimestampWithTimezoneFormat(String dateTimeFormat) { + this.dateTimeFormatter = DateTimeFormatter.ofPattern(dateTimeFormat); + } + + /** + * Set the format for reading in dates from the batch/file. + * + * @param dateTimeFormatter + * format to parse data sent as java.sql.Types.TIMESTAMP_WITH_TIMEZONE + */ + public void setTimestampWithTimezoneFormat(DateTimeFormatter dateTimeFormatter) { + this.dateTimeFormatter = dateTimeFormatter; + } + + /** + * Set the format for reading in dates from the batch/file. + * + * @param timeFormat + * format to parse data sent as java.sql.Types.TIME_WITH_TIMEZONE + */ + public void setTimeWithTimezoneFormat(String timeFormat) { + this.timeFormatter = DateTimeFormatter.ofPattern(timeFormat); + } + + /** + * Set the format for reading in dates from the batch/file. + * + * @param dateTimeFormatter + * format to parse data sent as java.sql.Types.TIME_WITH_TIMEZONE + */ + public void setTimeWithTimezoneFormat(DateTimeFormatter dateTimeFormatter) { + this.timeFormatter = dateTimeFormatter; + } + + /* + * Helper method to throw a SQLServerExeption with the invalidArgument message and given argument. + */ + protected void throwInvalidArgument(String argument) throws SQLServerException { + MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_invalidArgument")); + Object[] msgArgs = {argument}; + SQLServerException.makeFromDriverError(null, null, form.format(msgArgs), null, false); + } + + /* + * Method to throw a SQLServerExeption for duplicate column names + */ + protected void checkDuplicateColumnName(int positionInTable, + String colName) throws SQLServerException { + + if (null != colName && colName.trim().length() != 0) { + for (Entry entry : columnMetadata.entrySet()) { + // duplicate check is not performed in case of same positionInTable value + if (null != entry && entry.getKey() != positionInTable) { + if (null != entry.getValue() && colName.trim().equalsIgnoreCase(entry.getValue().columnName)) { + throw new SQLServerException(SQLServerException.getErrString("R_BulkDataDuplicateColumn"), null); + } + } + } + } + } +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java index c5ad97c9a..647d77f64 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java @@ -153,6 +153,11 @@ private class ColumnMapping { /* The CekTable for the destination table. */ private CekTable destCekTable = null; + /* Statement level encryption setting needed for querying against encrypted columns. */ + private SQLServerStatementColumnEncryptionSetting stmtColumnEncriptionSetting = SQLServerStatementColumnEncryptionSetting.UseConnectionSetting; + + private ResultSet destinationTableMetadata; + /* * Metadata for the destination table columns */ @@ -1490,7 +1495,12 @@ private String createInsertBulkCommand(TDSWriter tdsWriter) throws SQLServerExce if (null != destType && (destType.toLowerCase(Locale.ENGLISH).trim().startsWith("char") || destType.toLowerCase(Locale.ENGLISH).trim().startsWith("varchar"))) addCollate = " COLLATE " + columnCollation; } - bulkCmd.append("[" + colMapping.destinationColumnName + "] " + destType + addCollate + endColumn); + if (colMapping.destinationColumnName.contains("]")) { + String escapedColumnName = colMapping.destinationColumnName.replaceAll("]", "]]"); + bulkCmd.append("[" + escapedColumnName + "] " + destType + addCollate + endColumn); + } else { + bulkCmd.append("[" + colMapping.destinationColumnName + "] " + destType + addCollate + endColumn); + } } if (true == copyOptions.isCheckConstraints()) { @@ -1740,11 +1750,19 @@ private void getDestinationMetadata() throws SQLServerException { SQLServerResultSet rs = null; SQLServerResultSet rsMoreMetaData = null; - + SQLServerStatement stmt = null; + try { - // Get destination metadata - rs = ((SQLServerStatement) connection.createStatement()) - .executeQueryInternal("SET FMTONLY ON SELECT * FROM " + destinationTableName + " SET FMTONLY OFF "); + if (null != destinationTableMetadata) { + rs = (SQLServerResultSet) destinationTableMetadata; + } + else { + stmt = (SQLServerStatement) connection.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, + connection.getHoldability(), stmtColumnEncriptionSetting); + + // Get destination metadata + rs = stmt.executeQueryInternal("sp_executesql N'SET FMTONLY ON SELECT * FROM " + destinationTableName + " '"); + } destColumnCount = rs.getMetaData().getColumnCount(); destColumnMetadata = new HashMap<>(); @@ -1783,6 +1801,8 @@ private void getDestinationMetadata() throws SQLServerException { finally { if (null != rs) rs.close(); + if (null != stmt) + stmt.close(); if (null != rsMoreMetaData) rsMoreMetaData.close(); } @@ -3059,8 +3079,6 @@ protected Object getTemporalObjectFromCSVWithFormatter(String valueStrUntrimmed, int srcJdbcType, int srcColOrdinal, DateTimeFormatter dateTimeFormatter) throws SQLServerException { - DriverJDBCVersion.checkSupportsJDBC42(); - SQLServerBulkCopy42Helper.getTemporalObjectFromCSVWithFormatter(valueStrUntrimmed, srcJdbcType, srcColOrdinal, dateTimeFormatter, connection, this); @@ -3573,4 +3591,12 @@ private boolean writeBatchData(TDSWriter tdsWriter, } } } + + protected void setStmtColumnEncriptionSetting(SQLServerStatementColumnEncryptionSetting stmtColumnEncriptionSetting) { + this.stmtColumnEncriptionSetting = stmtColumnEncriptionSetting; + } + + protected void setDestinationTableMetadata(SQLServerResultSet rs) { + destinationTableMetadata = rs; + } } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index be9fa8f50..c632bb9be 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -119,6 +119,8 @@ public class SQLServerConnection implements ISQLServerConnection { private SqlFedAuthToken fedAuthToken = null; private String originalHostNameInCertificate = null; + + private Boolean isAzureDW = null; static class Sha1HashKey { private byte[] bytes; @@ -403,7 +405,7 @@ private enum State { ServerPortPlaceHolder getRoutingInfo() { return routingInfo; } - + // Permission targets private static final String callAbortPerm = "callAbort"; @@ -492,6 +494,27 @@ final int getSocketTimeoutMilliseconds() { return socketTimeoutMilliseconds; } + /** + * boolean value for deciding if the driver should use bulk copy API for batch inserts + */ + private boolean useBulkCopyForBatchInsert; + + /** + * Retrieves the useBulkCopyForBatchInsert value. + * @return flag for using Bulk Copy API for batch insert operations. + */ + public boolean getUseBulkCopyForBatchInsert() { + return useBulkCopyForBatchInsert; + } + + /** + * Specifies the flag for using Bulk Copy API for batch insert operations. + * @param useBulkCopyForBatchInsert boolean value for useBulkCopyForBatchInsert. + */ + public void setUseBulkCopyForBatchInsert(boolean useBulkCopyForBatchInsert) { + this.useBulkCopyForBatchInsert = useBulkCopyForBatchInsert; + } + boolean userSetTNIR = true; private boolean sendTimeAsDatetime = SQLServerDriverBooleanProperty.SEND_TIME_AS_DATETIME.getDefaultValue(); @@ -1207,7 +1230,7 @@ Connection connectInternal(Properties propsIn, activeConnectionProperties = (Properties) propsIn.clone(); pooledConnectionParent = pooledConnection; - + String hostNameInCertificate = activeConnectionProperties. getProperty(SQLServerDriverStringProperty.HOSTNAME_IN_CERTIFICATE.toString()); @@ -1224,7 +1247,7 @@ Connection connectInternal(Properties propsIn, activeConnectionProperties.setProperty(SQLServerDriverStringProperty.HOSTNAME_IN_CERTIFICATE.toString(), originalHostNameInCertificate); } - + String sPropKey; String sPropValue; @@ -1772,7 +1795,13 @@ else if (0 == requestedPacketSize) if (null != sPropValue) { setEnablePrepareOnFirstPreparedStatementCall(booleanPropertyOn(sPropKey, sPropValue)); } - + + sPropKey = SQLServerDriverBooleanProperty.USE_BULK_COPY_FOR_BATCH_INSERT.toString(); + sPropValue = activeConnectionProperties.getProperty(sPropKey); + if (null != sPropValue) { + useBulkCopyForBatchInsert = booleanPropertyOn(sPropKey, sPropValue); + } + sPropKey = SQLServerDriverStringProperty.SSL_PROTOCOL.toString(); sPropValue = activeConnectionProperties.getProperty(sPropKey); if (null == sPropValue) { @@ -3731,6 +3760,7 @@ final void processEnvChange(TDSReader tdsReader) throws SQLServerException { isRoutedInCurrentAttempt = true; routingInfo = new ServerPortPlaceHolder(routingServerName, routingPortNumber, null, integratedSecurity); + break; // Error on unrecognized, unused ENVCHANGES @@ -6040,6 +6070,33 @@ public void onEviction(Sha1HashKey key, PreparedStatementHandle handle) { } } + boolean isAzureDW() throws SQLServerException, SQLException { + if (null == isAzureDW) { + try (Statement stmt = this.createStatement(); ResultSet rs = stmt.executeQuery("SELECT CAST(SERVERPROPERTY('EngineEdition') as INT)");) + { + // SERVERPROPERTY('EngineEdition') can be used to determine whether the db server is SQL Azure. + // It should return 6 for SQL Azure DW. This is more reliable than @@version or serverproperty('edition'). + // Reference: http://msdn.microsoft.com/en-us/library/ee336261.aspx + // + // SERVERPROPERTY('EngineEdition') means + // Database Engine edition of the instance of SQL Server installed on the server. + // 1 = Personal or Desktop Engine (Not available for SQL Server.) + // 2 = Standard (This is returned for Standard and Workgroup.) + // 3 = Enterprise (This is returned for Enterprise, Enterprise Evaluation, and Developer.) + // 4 = Express (This is returned for Express, Express with Advanced Services, and Windows Embedded SQL.) + // 5 = SQL Azure + // 6 = SQL Azure DW + // Base data type: int + final int ENGINE_EDITION_FOR_SQL_AZURE_DW = 6; + rs.next(); + isAzureDW = (rs.getInt(1) == ENGINE_EDITION_FOR_SQL_AZURE_DW) ? true : false; + } + return isAzureDW; + } else { + return isAzureDW; + } + } + /** * @param st * Statement to add to openStatements diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java index 96d25f106..a01711a12 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java @@ -821,6 +821,26 @@ public int getSocketTimeout() { int defaultTimeOut = SQLServerDriverIntProperty.SOCKET_TIMEOUT.getDefaultValue(); return getIntProperty(connectionProps, SQLServerDriverIntProperty.SOCKET_TIMEOUT.toString(), defaultTimeOut); } + + /** + * Setting the use Bulk Copy API for Batch Insert + * + * @param useBulkCopyForBatchInsert indicates whether Bulk Copy API should be used for Batch Insert operations. + */ + public void setUseBulkCopyForBatchInsert(boolean useBulkCopyForBatchInsert) { + setBooleanProperty(connectionProps, SQLServerDriverBooleanProperty.USE_BULK_COPY_FOR_BATCH_INSERT.toString(), + useBulkCopyForBatchInsert); + } + + /** + * Getting the use Bulk Copy API for Batch Insert + * + * @return whether the driver should use Bulk Copy API for Batch Insert operations. + */ + public boolean getUseBulkCopyForBatchInsert() { + return getBooleanProperty(connectionProps, SQLServerDriverBooleanProperty.USE_BULK_COPY_FOR_BATCH_INSERT.toString(), + SQLServerDriverBooleanProperty.USE_BULK_COPY_FOR_BATCH_INSERT.getDefaultValue()); + } /** * Sets the login configuration file for Kerberos authentication. This diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java index 9f13dc3f4..5f67f93db 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java @@ -353,7 +353,8 @@ enum SQLServerDriverBooleanProperty TRUST_SERVER_CERTIFICATE ("trustServerCertificate", false), XOPEN_STATES ("xopenStates", false), FIPS ("fips", false), - ENABLE_PREPARE_ON_FIRST_PREPARED_STATEMENT("enablePrepareOnFirstPreparedStatementCall", SQLServerConnection.DEFAULT_ENABLE_PREPARE_ON_FIRST_PREPARED_STATEMENT_CALL); + ENABLE_PREPARE_ON_FIRST_PREPARED_STATEMENT("enablePrepareOnFirstPreparedStatementCall", SQLServerConnection.DEFAULT_ENABLE_PREPARE_ON_FIRST_PREPARED_STATEMENT_CALL), + USE_BULK_COPY_FOR_BATCH_INSERT ("useBulkCopyForBatchInsert", false); private final String name; private final boolean defaultValue; @@ -429,7 +430,8 @@ public final class SQLServerDriver implements java.sql.Driver { new SQLServerDriverPropertyInfo(SQLServerDriverIntProperty.STATEMENT_POOLING_CACHE_SIZE.toString(), Integer.toString(SQLServerDriverIntProperty.STATEMENT_POOLING_CACHE_SIZE.getDefaultValue()), false, null), new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.JAAS_CONFIG_NAME.toString(), SQLServerDriverStringProperty.JAAS_CONFIG_NAME.getDefaultValue(), false, null), new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.SSL_PROTOCOL.toString(), SQLServerDriverStringProperty.SSL_PROTOCOL.getDefaultValue(), false, new String[] {SSLProtocol.TLS.toString(), SSLProtocol.TLS_V10.toString(), SSLProtocol.TLS_V11.toString(), SSLProtocol.TLS_V12.toString()}), - new SQLServerDriverPropertyInfo(SQLServerDriverIntProperty.CANCEL_QUERY_TIMEOUT.toString(), Integer.toString(SQLServerDriverIntProperty.CANCEL_QUERY_TIMEOUT.getDefaultValue()), false, null), + new SQLServerDriverPropertyInfo(SQLServerDriverIntProperty.CANCEL_QUERY_TIMEOUT.toString(), Integer.toString(SQLServerDriverIntProperty.CANCEL_QUERY_TIMEOUT.getDefaultValue()), false, null), + new SQLServerDriverPropertyInfo(SQLServerDriverBooleanProperty.USE_BULK_COPY_FOR_BATCH_INSERT.toString(), Boolean.toString(SQLServerDriverBooleanProperty.USE_BULK_COPY_FOR_BATCH_INSERT.getDefaultValue()), false, TRUE_FALSE), }; // Properties that can only be set by using Properties. diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java index ac8e66fe3..ac6289f2f 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java @@ -106,6 +106,35 @@ public class SQLServerPreparedStatement extends SQLServerStatement implements IS private void setPreparedStatementHandle(int handle) { this.prepStmtHandle = handle; } + + /** + * boolean value for deciding if the driver should use bulk copy API for batch inserts + */ + private boolean useBulkCopyForBatchInsert; + + /** Gets the prepared statement's useBulkCopyForBatchInsert value. + * + * @return + * Per the description. + * @throws SQLServerException when an error occurs + */ + @SuppressWarnings("unused") + private boolean getUseBulkCopyForBatchInsert() throws SQLServerException { + checkClosed(); + return useBulkCopyForBatchInsert; + } + + /** Sets the prepared statement's useBulkCopyForBatchInsert value. + * + * @param useBulkCopyForBatchInsert + * the boolean value + * @throws SQLServerException when an error occurs + */ + @SuppressWarnings("unused") + private void setUseBulkCopyForBatchInsert(boolean useBulkCopyForBatchInsert) throws SQLServerException { + checkClosed(); + this.useBulkCopyForBatchInsert = useBulkCopyForBatchInsert; + } /** The server handle for this prepared statement. If a value {@literal <} 1 is returned no handle has been created. * @@ -150,6 +179,8 @@ private boolean resetPrepStmtHandle(boolean discardCurrentCacheItem) { */ private boolean encryptionMetadataIsRetrieved = false; + private String localUserSQL; + // Internal function used in tracing String getClassNameInternal() { return "SQLServerPreparedStatement"; @@ -205,6 +236,7 @@ String getClassNameInternal() { bReturnValueSyntax = parsedSQL.bReturnValueSyntax; userSQL = parsedSQL.processedSQL; initParams(parsedSQL.parameterCount); + useBulkCopyForBatchInsert = conn.getUseBulkCopyForBatchInsert(); } /** @@ -2446,8 +2478,97 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL } checkClosed(); discardLastExecutionResults(); - + int updateCounts[]; + + localUserSQL = userSQL; + + try { + if (isInsert(localUserSQL) && connection.isAzureDW() && (this.useBulkCopyForBatchInsert)) { + // From the JDBC spec, section 9.1.4 - Making Batch Updates: + // The CallableStatement.executeBatch method (inherited from PreparedStatement) will + // throw a BatchUpdateException if the stored procedure returns anything other than an + // update count or takes OUT or INOUT parameters. + // + // Non-update count results (e.g. ResultSets) are treated as individual batch errors + // when they are encountered in the response. + // + // OUT and INOUT parameter checking is done here, before executing the batch. If any + // OUT or INOUT are present, the entire batch fails. + for (Parameter[] paramValues : batchParamValues) { + for (Parameter paramValue : paramValues) { + if (paramValue.isOutput()) { + throw new BatchUpdateException(SQLServerException.getErrString("R_outParamsNotPermittedinBatch"), null, 0, null); + } + } + } + + if (batchParamValues == null) { + updateCounts = new int[0]; + loggerExternal.exiting(getClassNameLogging(), "executeBatch", updateCounts); + return updateCounts; + } + + String tableName = parseUserSQLForTableNameDW(false, false, false, false); + ArrayList columnList = parseUserSQLForColumnListDW(); + ArrayList valueList = parseUserSQLForValueListDW(false); + + String destinationTableName = tableName; + SQLServerStatement stmt = (SQLServerStatement) connection.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, + connection.getHoldability(), stmtColumnEncriptionSetting); + + // Get destination metadata + try (SQLServerResultSet rs = stmt + .executeQueryInternal("sp_executesql N'SET FMTONLY ON SELECT * FROM " + destinationTableName + " '");) { + + SQLServerBulkBatchInsertRecord batchRecord = new SQLServerBulkBatchInsertRecord(batchParamValues, columnList, valueList, null); + + for (int i = 1; i <= rs.getColumnCount(); i++) { + Column c = rs.getColumn(i); + CryptoMetadata cryptoMetadata = c.getCryptoMetadata(); + int jdbctype; + TypeInfo ti = c.getTypeInfo(); + if (null != cryptoMetadata) { + jdbctype = cryptoMetadata.getBaseTypeInfo().getSSType().getJDBCType().getIntValue(); + } + else { + jdbctype = ti.getSSType().getJDBCType().getIntValue(); + } + batchRecord.addColumnMetadata(i, c.getColumnName(), jdbctype, ti.getPrecision(), ti.getScale()); + } + + SQLServerBulkCopy bcOperation = new SQLServerBulkCopy(connection); + bcOperation.setDestinationTableName(tableName); + bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting()); + bcOperation.setDestinationTableMetadata(rs); + bcOperation.writeToServer((ISQLServerBulkRecord) batchRecord); + bcOperation.close(); + updateCounts = new int[batchParamValues.size()]; + for (int i = 0; i < batchParamValues.size(); ++i) { + updateCounts[i] = 1; + } + + batchParamValues = null; + loggerExternal.exiting(getClassNameLogging(), "executeBatch", updateCounts); + return updateCounts; + } + finally { + if (null != stmt) + stmt.close(); + } + } + } + catch (SQLException e) { + // throw a BatchUpdateException with the given error message, and return null for the updateCounts. + throw new BatchUpdateException(e.getMessage(), null, 0, null); + } + catch (IllegalArgumentException e) { + // If we fail with IllegalArgumentException, fall back to the original batch insert logic. + if (getStatementLogger().isLoggable(java.util.logging.Level.FINE)) { + getStatementLogger().fine("Parsing user's Batch Insert SQL Query failed: " + e.toString()); + getStatementLogger().fine("Falling back to the original implementation for Batch Insert."); + } + } if (batchParamValues == null) updateCounts = new int[0]; @@ -2505,6 +2626,95 @@ public long[] executeLargeBatch() throws SQLServerException, BatchUpdateExceptio discardLastExecutionResults(); long updateCounts[]; + + localUserSQL = userSQL; + + try { + if (isInsert(localUserSQL) && connection.isAzureDW() && (this.useBulkCopyForBatchInsert)) { + // From the JDBC spec, section 9.1.4 - Making Batch Updates: + // The CallableStatement.executeBatch method (inherited from PreparedStatement) will + // throw a BatchUpdateException if the stored procedure returns anything other than an + // update count or takes OUT or INOUT parameters. + // + // Non-update count results (e.g. ResultSets) are treated as individual batch errors + // when they are encountered in the response. + // + // OUT and INOUT parameter checking is done here, before executing the batch. If any + // OUT or INOUT are present, the entire batch fails. + for (Parameter[] paramValues : batchParamValues) { + for (Parameter paramValue : paramValues) { + if (paramValue.isOutput()) { + throw new BatchUpdateException(SQLServerException.getErrString("R_outParamsNotPermittedinBatch"), null, 0, null); + } + } + } + + if (batchParamValues == null) { + updateCounts = new long[0]; + loggerExternal.exiting(getClassNameLogging(), "executeLargeBatch", updateCounts); + return updateCounts; + } + + String tableName = parseUserSQLForTableNameDW(false, false, false, false); + ArrayList columnList = parseUserSQLForColumnListDW(); + ArrayList valueList = parseUserSQLForValueListDW(false); + + String destinationTableName = tableName; + SQLServerStatement stmt = (SQLServerStatement) connection.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, + connection.getHoldability(), stmtColumnEncriptionSetting); + + // Get destination metadata + try (SQLServerResultSet rs = stmt + .executeQueryInternal("sp_executesql N'SET FMTONLY ON SELECT * FROM " + destinationTableName + " '");) { + + SQLServerBulkBatchInsertRecord batchRecord = new SQLServerBulkBatchInsertRecord(batchParamValues, columnList, valueList, null); + + for (int i = 1; i <= rs.getColumnCount(); i++) { + Column c = rs.getColumn(i); + CryptoMetadata cryptoMetadata = c.getCryptoMetadata(); + int jdbctype; + TypeInfo ti = c.getTypeInfo(); + if (null != cryptoMetadata) { + jdbctype = cryptoMetadata.getBaseTypeInfo().getSSType().getJDBCType().getIntValue(); + } + else { + jdbctype = ti.getSSType().getJDBCType().getIntValue(); + } + batchRecord.addColumnMetadata(i, c.getColumnName(), jdbctype, ti.getPrecision(), ti.getScale()); + } + + SQLServerBulkCopy bcOperation = new SQLServerBulkCopy(connection); + bcOperation.setDestinationTableName(tableName); + bcOperation.setStmtColumnEncriptionSetting(this.getStmtColumnEncriptionSetting()); + bcOperation.setDestinationTableMetadata(rs); + bcOperation.writeToServer((ISQLServerBulkRecord) batchRecord); + bcOperation.close(); + updateCounts = new long[batchParamValues.size()]; + for (int i = 0; i < batchParamValues.size(); ++i) { + updateCounts[i] = 1; + } + + batchParamValues = null; + loggerExternal.exiting(getClassNameLogging(), "executeLargeBatch", updateCounts); + return updateCounts; + } + finally { + if (null != stmt) + stmt.close(); + } + } + } + catch (SQLException e) { + // throw a BatchUpdateException with the given error message, and return null for the updateCounts. + throw new BatchUpdateException(e.getMessage(), null, 0, null); + } + catch (IllegalArgumentException e) { + // If we fail with IllegalArgumentException, fall back to the original batch insert logic. + if (getStatementLogger().isLoggable(java.util.logging.Level.FINE)) { + getStatementLogger().fine("Parsing user's Batch Insert SQL Query failed: " + e.toString()); + getStatementLogger().fine("Falling back to the original implementation for Batch Insert."); + } + } if (batchParamValues == null) updateCounts = new long[0]; @@ -2548,6 +2758,315 @@ public long[] executeLargeBatch() throws SQLServerException, BatchUpdateExceptio loggerExternal.exiting(getClassNameLogging(), "executeLargeBatch", updateCounts); return updateCounts; } + + + private String parseUserSQLForTableNameDW(boolean hasInsertBeenFound, boolean hasIntoBeenFound, boolean hasTableBeenFound, + boolean isExpectingTableName) { + // As far as finding the table name goes, There are two cases: + // Insert into and Insert + // And there could be in-line comments (with /* and */) in between. + // This method assumes the localUserSQL string starts with "insert". + localUserSQL = localUserSQL.trim(); + if (checkAndRemoveComments()) { + return parseUserSQLForTableNameDW(hasInsertBeenFound, hasIntoBeenFound, hasTableBeenFound, isExpectingTableName); + } + + StringBuilder sb = new StringBuilder(); + + // If table has been found and the next character is not a . at this point, we've finished parsing the table name. + // This if statement is needed to handle the case where the user has something like: + // [dbo] . /* random comment */ [tableName] + if (hasTableBeenFound && !isExpectingTableName) { + if (localUserSQL.substring(0, 1).equalsIgnoreCase(".")) { + sb.append("."); + localUserSQL = localUserSQL.substring(1); + return sb.toString() + parseUserSQLForTableNameDW(true, true, true, true); + } else { + return ""; + } + } + + if (localUserSQL.substring(0, 6).equalsIgnoreCase("insert") && !hasInsertBeenFound) { + localUserSQL = localUserSQL.substring(6); + return parseUserSQLForTableNameDW(true, hasIntoBeenFound, hasTableBeenFound, isExpectingTableName); + } + + if (localUserSQL.substring(0, 4).equalsIgnoreCase("into") && !hasIntoBeenFound) { + // is it really "into"? + // if the "into" is followed by a blank space or /*, then yes. + if (Character.isWhitespace(localUserSQL.charAt(4)) || + (localUserSQL.charAt(4) == '/' && localUserSQL.charAt(5) == '*')) { + localUserSQL = localUserSQL.substring(4); + return parseUserSQLForTableNameDW(hasInsertBeenFound, true, hasTableBeenFound, isExpectingTableName); + } + + // otherwise, we found the token that either contains the databasename.tablename or tablename. + // Recursively handle this, but into has been found. (or rather, it's absent in the query - the "into" keyword is optional) + return parseUserSQLForTableNameDW(hasInsertBeenFound, true, hasTableBeenFound, isExpectingTableName); + } + + // At this point, the next token has to be the table name. + // It could be encapsulated in [], "", or have a database name preceding the table name. + // If it's encapsulated in [] or "", we need be more careful with parsing as anything could go into []/"". + // For ] or ", they can be escaped by ]] or "", watch out for this too. + if (localUserSQL.substring(0, 1).equalsIgnoreCase("[")) { + int tempint = localUserSQL.indexOf("]", 1); + + // keep checking if it's escaped + while (localUserSQL.charAt(tempint + 1) == ']') { + tempint = localUserSQL.indexOf("]", tempint + 2); + } + + // we've found a ] that is actually trying to close the square bracket. + // return tablename + potentially more that's part of the table name + sb.append(localUserSQL.substring(0, tempint + 1)); + localUserSQL = localUserSQL.substring(tempint + 1); + return sb.toString() + parseUserSQLForTableNameDW(true, true, true, false); + } + + // do the same for "" + if (localUserSQL.substring(0, 1).equalsIgnoreCase("\"")) { + int tempint = localUserSQL.indexOf("\"", 1); + + // keep checking if it's escaped + while (localUserSQL.charAt(tempint + 1) == '\"') { + tempint = localUserSQL.indexOf("\"", tempint + 2); + } + + // we've found a " that is actually trying to close the quote. + // return tablename + potentially more that's part of the table name + sb.append(localUserSQL.substring(0, tempint + 1)); + localUserSQL = localUserSQL.substring(tempint + 1); + return sb.toString() + parseUserSQLForTableNameDW(true, true, true, false); + } + + // At this point, the next chunk of string is the table name, without starting with [ or ". + while (localUserSQL.length() > 0) { + // Keep going until the end of the table name is signalled - either a ., whitespace, or comment is encountered + if (localUserSQL.charAt(0) == '.' || Character.isWhitespace(localUserSQL.charAt(0)) || checkAndRemoveComments()) { + return sb.toString() + parseUserSQLForTableNameDW(true, true, true, false); + } else { + sb.append(localUserSQL.charAt(0)); + localUserSQL = localUserSQL.substring(1); + } + } + + // It shouldn't come here. If we did, something is wrong. + throw new IllegalArgumentException("localUserSQL"); + } + + private ArrayList parseUserSQLForColumnListDW() { + localUserSQL = localUserSQL.trim(); + + // ignore all comments + if (checkAndRemoveComments()) { + return parseUserSQLForColumnListDW(); + } + + //check if optional column list was provided + // Columns can have the form of c1, [c1] or "c1". It can escape ] or " by ]] or "". + if (localUserSQL.substring(0, 1).equalsIgnoreCase("(")) { + localUserSQL = localUserSQL.substring(1); + return parseUserSQLForColumnListDWHelper(new ArrayList()); + } + return null; + } + + private ArrayList parseUserSQLForColumnListDWHelper(ArrayList listOfColumns) { + localUserSQL = localUserSQL.trim(); + + // ignore all comments + if (checkAndRemoveComments()) { + return parseUserSQLForColumnListDWHelper(listOfColumns); + } + + if (localUserSQL.charAt(0) == ')') { + localUserSQL = localUserSQL.substring(1); + return listOfColumns; + } + + if (localUserSQL.charAt(0) == ',') { + localUserSQL = localUserSQL.substring(1); + return parseUserSQLForColumnListDWHelper(listOfColumns); + } + + if (localUserSQL.charAt(0) == '[') { + int tempint = localUserSQL.indexOf("]", 1); + + // keep checking if it's escaped + while (localUserSQL.charAt(tempint + 1) == ']') { + localUserSQL = localUserSQL.substring(0, tempint) + localUserSQL.substring(tempint + 1); + tempint = localUserSQL.indexOf("]", tempint + 1); + } + + // we've found a ] that is actually trying to close the square bracket. + String tempstr = localUserSQL.substring(1, tempint); + localUserSQL = localUserSQL.substring(tempint + 1); + listOfColumns.add(tempstr); + return parseUserSQLForColumnListDWHelper(listOfColumns); + } + + if (localUserSQL.charAt(0) == '\"') { + int tempint = localUserSQL.indexOf("\"", 1); + + // keep checking if it's escaped + while (localUserSQL.charAt(tempint + 1) == '\"') { + localUserSQL = localUserSQL.substring(0, tempint) + localUserSQL.substring(tempint + 1); + tempint = localUserSQL.indexOf("\"", tempint + 1); + } + + // we've found a " that is actually trying to close the quote. + String tempstr = localUserSQL.substring(1, tempint); + localUserSQL = localUserSQL.substring(tempint + 1); + listOfColumns.add(tempstr); + return parseUserSQLForColumnListDWHelper(listOfColumns); + } + + // At this point, the next chunk of string is the column name, without starting with [ or ". + StringBuilder sb = new StringBuilder(); + while (localUserSQL.length() > 0) { + if (localUserSQL.charAt(0) == ',') { + localUserSQL = localUserSQL.substring(1); + listOfColumns.add(sb.toString()); + return parseUserSQLForColumnListDWHelper(listOfColumns); + } else if (localUserSQL.charAt(0) == ')'){ + localUserSQL = localUserSQL.substring(1); + listOfColumns.add(sb.toString()); + return listOfColumns; + } else if (checkAndRemoveComments()) { + localUserSQL = localUserSQL.trim(); + } else { + sb.append(localUserSQL.charAt(0)); + localUserSQL = localUserSQL.substring(1); + localUserSQL = localUserSQL.trim(); + } + } + + // It shouldn't come here. If we did, something is wrong. + throw new IllegalArgumentException("localUserSQL"); + } + + + private ArrayList parseUserSQLForValueListDW(boolean hasValuesBeenFound) { + localUserSQL = localUserSQL.trim(); + + // ignore all comments + if (checkAndRemoveComments()) { + return parseUserSQLForValueListDW(hasValuesBeenFound); + } + + if (!hasValuesBeenFound) { + // look for keyword "VALUES" + if (localUserSQL.substring(0, 6).equalsIgnoreCase("VALUES")) { + localUserSQL = localUserSQL.substring(6); + + localUserSQL = localUserSQL.trim(); + + // ignore all comments + if (checkAndRemoveComments()) { + return parseUserSQLForValueListDW(true); + } + + if (localUserSQL.substring(0, 1).equalsIgnoreCase("(")) { + localUserSQL = localUserSQL.substring(1); + return parseUserSQLForValueListDWHelper(new ArrayList()); + } + } + } else { + // ignore all comments + if (checkAndRemoveComments()) { + return parseUserSQLForValueListDW(hasValuesBeenFound); + } + + if (localUserSQL.substring(0, 1).equalsIgnoreCase("(")) { + localUserSQL = localUserSQL.substring(1); + return parseUserSQLForValueListDWHelper(new ArrayList()); + } + } + + // shouldn't come here, as the list of values is mandatory. + throw new IllegalArgumentException("localUserSQL"); + } + + private ArrayList parseUserSQLForValueListDWHelper(ArrayList listOfValues) { + localUserSQL = localUserSQL.trim(); + + // ignore all comments + if (checkAndRemoveComments()) { + return parseUserSQLForValueListDWHelper(listOfValues); + } + + if (localUserSQL.charAt(0) == ')') { + localUserSQL = localUserSQL.substring(1); + return listOfValues; + } + + if (localUserSQL.charAt(0) == ',') { + localUserSQL = localUserSQL.substring(1); + return parseUserSQLForValueListDWHelper(listOfValues); + } + + if (localUserSQL.charAt(0) == '\'') { + int tempint = localUserSQL.indexOf("\'", 1); + + // keep checking if it's escaped + while (localUserSQL.charAt(tempint + 1) == '\'') { + localUserSQL = localUserSQL.substring(0, tempint) + localUserSQL.substring(tempint + 1); + tempint = localUserSQL.indexOf("\'", tempint + 1); + } + + // we've found a ' that is actually trying to close the quote. + // Include 's around the string as well, so we can distinguish '?' and ? later on. + String tempstr = localUserSQL.substring(0, tempint + 1); + localUserSQL = localUserSQL.substring(tempint + 1); + listOfValues.add(tempstr); + return parseUserSQLForValueListDWHelper(listOfValues); + } + + // At this point, the next chunk of string is the value, without starting with ' (most likely a ?). + StringBuilder sb = new StringBuilder(); + while (localUserSQL.length() > 0) { + if (localUserSQL.charAt(0) == ',' || localUserSQL.charAt(0) == ')') { + if (localUserSQL.charAt(0) == ',') { + localUserSQL = localUserSQL.substring(1); + listOfValues.add(sb.toString()); + return parseUserSQLForValueListDWHelper(listOfValues); + } else { + localUserSQL = localUserSQL.substring(1); + listOfValues.add(sb.toString()); + return listOfValues; + } + } else if (checkAndRemoveComments()) { + localUserSQL = localUserSQL.trim(); + } else { + sb.append(localUserSQL.charAt(0)); + localUserSQL = localUserSQL.substring(1); + localUserSQL = localUserSQL.trim(); + } + } + + // It shouldn't come here. If we did, something is wrong. + throw new IllegalArgumentException("localUserSQL"); + } + + private boolean checkAndRemoveComments() { + if (null == localUserSQL || localUserSQL.length() < 2) { + return false; + } + + if (localUserSQL.substring(0, 2).equalsIgnoreCase("/*")) { + int temp = localUserSQL.indexOf("*/") + 2; + localUserSQL = localUserSQL.substring(temp); + return true; + } + + if (localUserSQL.substring(0, 2).equalsIgnoreCase("--")) { + int temp = localUserSQL.indexOf("\n") + 2; + localUserSQL = localUserSQL.substring(temp); + return true; + } + return false; + } private final class PrepStmtBatchExecCmd extends TDSCommand { private final SQLServerPreparedStatement stmt; diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java index c2ba11f8e..3c05548af 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java @@ -227,8 +227,8 @@ protected Object[][] getContents() { {"R_invalidTransactionOption", "UseInternalTransaction option cannot be set to TRUE when used with a Connection object."}, {"R_invalidNegativeArg", "The {0} argument cannot be negative."}, {"R_BulkColumnMappingsIsEmpty", "Cannot perform bulk copy operation if the only mapping is an identity column and KeepIdentity is set to false."}, - {"R_CSVDataSchemaMismatch", "Source data does not match source schema."}, - {"R_BulkCSVDataDuplicateColumn", "Duplicate column names are not allowed."}, + {"R_DataSchemaMismatch", "Source data does not match source schema."}, + {"R_BulkDataDuplicateColumn", "Duplicate column names are not allowed."}, {"R_invalidColumnOrdinal", "Column {0} is invalid. Column number should be greater than zero."}, {"R_unsupportedEncoding", "The encoding {0} is not supported."}, {"R_UnexpectedDescribeParamFormat", "Internal error. The format of the resultset returned by sp_describe_parameter_encryption is invalid. One of the resultsets is missing."}, @@ -393,6 +393,7 @@ protected Object[][] getContents() { {"R_invalidSSLProtocol", "SSL Protocol {0} label is not valid. Only TLS, TLSv1, TLSv1.1, and TLSv1.2 are supported."}, {"R_cancelQueryTimeoutPropertyDescription", "The number of seconds to wait to cancel sending a query timeout."}, {"R_invalidCancelQueryTimeout", "The cancel timeout value {0} is not valid."}, + {"R_useBulkCopyForBatchInsertPropertyDescription", "Whether the driver will use bulk copy API for batch insert operations"}, {"R_UnknownDataClsTokenNumber", "Unknown token for Data Classification."}, // From Server {"R_InvalidDataClsVersionNumber", "Invalid version number {0} for Data Classification."}, // From Server {"R_unknownUTF8SupportValue", "Unknown value for UTF8 support."}, diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java index 162120ae9..592bb8deb 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java @@ -154,6 +154,9 @@ String getClassNameLogging() { */ protected SQLServerStatementColumnEncryptionSetting stmtColumnEncriptionSetting = SQLServerStatementColumnEncryptionSetting.UseConnectionSetting; + protected SQLServerStatementColumnEncryptionSetting getStmtColumnEncriptionSetting() { + return stmtColumnEncriptionSetting; + } /** * ExecuteProperties encapsulates a subset of statement property values as they were set at execution time. */ @@ -965,18 +968,40 @@ final void resetForReexecute() throws SQLServerException { * * @param sql * The statment SQL. - * @return True is the statement is a select. + * @return True if the statement is a select. */ /* L0 */ final boolean isSelect(String sql) throws SQLServerException { checkClosed(); // Used to check just the first letter which would cause // "Set" commands to return true... String temp = sql.trim(); - char c = temp.charAt(0); - if (c != 's' && c != 'S') + if (null == sql || sql.length() < 6) { return false; + } return temp.substring(0, 6).equalsIgnoreCase("select"); } + + /** + * Determine if the SQL is a INSERT. + * + * @param sql + * The statment SQL. + * @return True if the statement is an insert. + */ + /* L0 */ final boolean isInsert(String sql) throws SQLServerException { + checkClosed(); + // Used to check just the first letter which would cause + // "Set" commands to return true... + String temp = sql.trim(); + if (null == sql || sql.length() < 6) { + return false; + } + if (temp.substring(0, 2).equalsIgnoreCase("/*")) { + int index = temp.indexOf("*/") + 2; + return isInsert(temp.substring(index)); + } + return temp.substring(0, 6).equalsIgnoreCase("insert"); + } /** * Replace a JDBC parameter marker with the parameter's string value diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/preparedStatement/BatchExecutionWithBulkCopyTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/preparedStatement/BatchExecutionWithBulkCopyTest.java new file mode 100644 index 000000000..c0ec92b58 --- /dev/null +++ b/src/test/java/com/microsoft/sqlserver/jdbc/preparedStatement/BatchExecutionWithBulkCopyTest.java @@ -0,0 +1,577 @@ +package com.microsoft.sqlserver.jdbc.preparedStatement; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.sql.Connection; +import java.sql.Date; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.sql.Timestamp; +import java.util.ArrayList; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.platform.runner.JUnitPlatform; +import org.junit.runner.RunWith; +import org.opentest4j.TestAbortedException; + +import com.microsoft.sqlserver.jdbc.SQLServerConnection; +import com.microsoft.sqlserver.jdbc.SQLServerPreparedStatement; +import com.microsoft.sqlserver.jdbc.SQLServerStatement; +import com.microsoft.sqlserver.testframework.AbstractTest; +import com.microsoft.sqlserver.testframework.Utils; + +@RunWith(JUnitPlatform.class) +public class BatchExecutionWithBulkCopyTest extends AbstractTest { + + static long UUID = System.currentTimeMillis();; + static String tableName = "BulkCopyParseTest" + UUID; + static String squareBracketTableName = "[peter]]]]test" + UUID + "]"; + static String doubleQuoteTableName = "\"peter\"\"\"\"test" + UUID + "\""; + + @Test + public void testIsInsert() throws Exception { + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + Statement stmt = (SQLServerStatement) connection.createStatement()) { + String valid1 = "INSERT INTO PeterTable values (1, 2)"; + String valid2 = " INSERT INTO PeterTable values (1, 2)"; + String valid3 = "/* asdf */ INSERT INTO PeterTable values (1, 2)"; + String invalid = "Select * from PEterTable"; + + Method method = stmt.getClass().getDeclaredMethod("isInsert", String.class); + method.setAccessible(true); + assertTrue((boolean) method.invoke(stmt, valid1)); + assertTrue((boolean) method.invoke(stmt, valid2)); + assertTrue((boolean) method.invoke(stmt, valid3)); + assertFalse((boolean) method.invoke(stmt, invalid)); + } + } + + @Test + public void testComments() throws Exception { + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + PreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement("");) { + String valid = "/* rando comment *//* rando comment */ INSERT /* rando comment */ INTO /* rando comment *//*rando comment*/ PeterTable /*rando comment */" + + " /* rando comment */values/* rando comment */ (1, 2)"; + + Field f1 = pstmt.getClass().getSuperclass().getDeclaredField("localUserSQL"); + f1.setAccessible(true); + f1.set(pstmt, valid); + + Method method = pstmt.getClass().getSuperclass().getDeclaredMethod("parseUserSQLForTableNameDW", boolean.class, boolean.class, boolean.class, boolean.class); + method.setAccessible(true); + + assertEquals((String) method.invoke(pstmt, false, false, false, false), "PeterTable"); + } + } + + @Test + public void testBrackets() throws Exception { + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + PreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement("");) { + String valid = "/* rando comment *//* rando comment */ INSERT /* rando comment */ INTO /* rando comment *//*rando comment*/ [Peter[]]Table] /*rando comment */" + + " /* rando comment */values/* rando comment */ (1, 2)"; + + Field f1 = pstmt.getClass().getSuperclass().getDeclaredField("localUserSQL"); + f1.setAccessible(true); + f1.set(pstmt, valid); + + Method method = pstmt.getClass().getSuperclass().getDeclaredMethod("parseUserSQLForTableNameDW", boolean.class, boolean.class, boolean.class, boolean.class); + method.setAccessible(true); + + assertEquals((String) method.invoke(pstmt, false, false, false, false), "[Peter[]]Table]"); + } + } + + @Test + public void testDoubleQuotes() throws Exception { + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + PreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement("");) { + String valid = "/* rando comment *//* rando comment */ INSERT /* rando comment */ INTO /* rando comment *//*rando comment*/ \"Peter\"\"\"\"Table\" /*rando comment */" + + " /* rando comment */values/* rando comment */ (1, 2)"; + + Field f1 = pstmt.getClass().getSuperclass().getDeclaredField("localUserSQL"); + f1.setAccessible(true); + f1.set(pstmt, valid); + + Method method = pstmt.getClass().getSuperclass().getDeclaredMethod("parseUserSQLForTableNameDW", boolean.class, boolean.class, boolean.class, boolean.class); + method.setAccessible(true); + + assertEquals((String) method.invoke(pstmt, false, false, false, false), "\"Peter\"\"\"\"Table\""); + } + } + + @Test + public void testAll() throws Exception { + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + PreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement("");) { + String valid = "/* rando comment *//* rando comment */ INSERT /* rando comment */ INTO /* rando comment *//*rando comment*/ \"Peter\"\"\"\"Table\" /*rando comment */" + + " /* rando comment */ (\"c1\"/* rando comment */, /* rando comment */[c2]/* rando comment */, /* rando comment */ /* rando comment */c3/* rando comment */, c4)" + + "values/* rando comment */ (/* rando comment */1/* rando comment */, /* rando comment */2/* rando comment */ , '?', ?)/* rando comment */"; + + Field f1 = pstmt.getClass().getSuperclass().getDeclaredField("localUserSQL"); + f1.setAccessible(true); + f1.set(pstmt, valid); + + Method method = pstmt.getClass().getSuperclass().getDeclaredMethod("parseUserSQLForTableNameDW", boolean.class, boolean.class, boolean.class, boolean.class); + method.setAccessible(true); + + assertEquals((String) method.invoke(pstmt, false, false, false, false), "\"Peter\"\"\"\"Table\""); + + method = pstmt.getClass().getSuperclass().getDeclaredMethod("parseUserSQLForColumnListDW"); + method.setAccessible(true); + + ArrayList columnList = (ArrayList) method.invoke(pstmt); + ArrayList columnListExpected = new ArrayList(); + columnListExpected.add("c1"); + columnListExpected.add("c2"); + columnListExpected.add("c3"); + columnListExpected.add("c4"); + + for (int i = 0; i < columnListExpected.size(); i++) { + assertEquals(columnList.get(i), columnListExpected.get(i)); + } + + method = pstmt.getClass().getSuperclass().getDeclaredMethod("parseUserSQLForValueListDW", boolean.class); + method.setAccessible(true); + + ArrayList valueList = (ArrayList) method.invoke(pstmt, false); + ArrayList valueListExpected = new ArrayList(); + valueListExpected.add("1"); + valueListExpected.add("2"); + valueListExpected.add("'?'"); + valueListExpected.add("?"); + + for (int i = 0; i < valueListExpected.size(); i++) { + assertEquals(valueList.get(i), valueListExpected.get(i)); + } + } + } + + @Test + public void testAllcolumns() throws Exception { + String valid = "INSERT INTO " + tableName + " values " + + "(" + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + ")"; + + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement(valid); + Statement stmt = (SQLServerStatement) connection.createStatement();) { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(connection, true); + + Timestamp myTimestamp = new Timestamp(114550L); + + Date d = new Date(114550L); + + pstmt.setInt(1, 1234); + pstmt.setBoolean(2, false); + pstmt.setString(3, "a"); + pstmt.setDate(4, d); + pstmt.setDateTime(5, myTimestamp); + pstmt.setFloat(6, (float) 123.45); + pstmt.setString(7, "b"); + pstmt.setString(8, "varc"); + pstmt.setString(9, "''"); + pstmt.addBatch(); + + pstmt.executeBatch(); + + ResultSet rs = stmt.executeQuery("SELECT * FROM " + tableName); + + Object[] expected = new Object[9]; + + expected[0] = 1234; + expected[1] = false; + expected[2] = "a"; + expected[3] = d; + expected[4] = myTimestamp; + expected[5] = 123.45; + expected[6] = "b"; + expected[7] = "varc"; + expected[8] = "''"; + + rs.next(); + for (int i=0; i < expected.length; i++) { + assertEquals(rs.getObject(i + 1).toString(), expected[i].toString()); + } + } + } + + @Test + public void testMixColumns() throws Exception { + String valid = "INSERT INTO " + tableName + " (c1, c3, c5, c8) values " + + "(" + + "?, " + + "?, " + + "?, " + + "?, " + + ")"; + + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement(valid); + Statement stmt = (SQLServerStatement) connection.createStatement();) { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(connection, true); + + Timestamp myTimestamp = new Timestamp(114550L); + + Date d = new Date(114550L); + + pstmt.setInt(1, 1234); + pstmt.setString(2, "a"); + pstmt.setDateTime(3, myTimestamp); + pstmt.setString(4, "varc"); + pstmt.addBatch(); + + pstmt.executeBatch(); + + ResultSet rs = stmt.executeQuery("SELECT * FROM " + tableName); + + Object[] expected = new Object[9]; + + expected[0] = 1234; + expected[1] = false; + expected[2] = "a"; + expected[3] = d; + expected[4] = myTimestamp; + expected[5] = 123.45; + expected[6] = "b"; + expected[7] = "varc"; + expected[8] = "varcmax"; + + rs.next(); + for (int i=0; i < expected.length; i++) { + if (null != rs.getObject(i + 1)) { + assertEquals(rs.getObject(i + 1).toString(), expected[i].toString()); + } + } + } + } + + @Test + public void testNullOrEmptyColumns() throws Exception { + String valid = "INSERT INTO " + tableName + " (c1, c2, c3, c4, c5, c6, c7) values " + + "(" + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + ")"; + + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement(valid); + Statement stmt = (SQLServerStatement) connection.createStatement();) { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(connection, true); + + pstmt.setInt(1, 1234); + pstmt.setBoolean(2, false); + pstmt.setString(3, null); + pstmt.setDate(4, null); + pstmt.setDateTime(5, null); + pstmt.setFloat(6, (float) 123.45); + pstmt.setString(7, ""); + pstmt.addBatch(); + + pstmt.executeBatch(); + + ResultSet rs = stmt.executeQuery("SELECT * FROM " + tableName); + + Object[] expected = new Object[9]; + + expected[0] = 1234; + expected[1] = false; + expected[2] = null; + expected[3] = null; + expected[4] = null; + expected[5] = 123.45; + expected[6] = " "; + + rs.next(); + for (int i=0; i < expected.length; i++) { + if (null != rs.getObject(i + 1)) { + assertEquals(rs.getObject(i + 1), expected[i]); + } + } + } + } + + @Test + public void testAllFilledColumns() throws Exception { + String valid = "INSERT INTO " + tableName + " values " + + "(" + + "1234, " + + "false, " + + "a, " + + "null, " + + "null, " + + "123.45, " + + "b, " + + "varc, " + + "sadf, " + + ")"; + + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement(valid); + Statement stmt = (SQLServerStatement) connection.createStatement();) { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(connection, true); + + Timestamp myTimestamp = new Timestamp(114550L); + + Date d = new Date(114550L); + + pstmt.addBatch(); + + pstmt.executeBatch(); + + ResultSet rs = stmt.executeQuery("SELECT * FROM " + tableName); + + Object[] expected = new Object[9]; + + expected[0] = 1234; + expected[1] = false; + expected[2] = "a"; + expected[3] = null; + expected[4] = null; + expected[5] = 123.45; + expected[6] = "b"; + expected[7] = "varc"; + expected[8] = "sadf"; + + rs.next(); + for (int i=0; i < expected.length; i++) { + assertEquals(rs.getObject(i + 1), expected[i]); + } + } + } + + @Test + public void testSquareBracketAgainstDB() throws Exception { + String valid = "insert into " + squareBracketTableName + " values (?)"; + + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement(valid); + Statement stmt = (SQLServerStatement) connection.createStatement();) { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(connection, true); + + Utils.dropTableIfExists(squareBracketTableName, stmt); + String createTable = "create table " + squareBracketTableName + " (c1 int)"; + stmt.execute(createTable); + + pstmt.setInt(1, 1); + pstmt.addBatch(); + + pstmt.executeBatch(); + + ResultSet rs = stmt.executeQuery("SELECT * FROM " + squareBracketTableName); + rs.next(); + + assertEquals(rs.getObject(1), 1); + } + } + + @Test + public void testDoubleQuoteAgainstDB() throws Exception { + String valid = "insert into " + doubleQuoteTableName + " values (?)"; + + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement(valid); + Statement stmt = (SQLServerStatement) connection.createStatement();) { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(connection, true); + + Utils.dropTableIfExists(doubleQuoteTableName, stmt); + String createTable = "create table " + doubleQuoteTableName + " (c1 int)"; + stmt.execute(createTable); + + pstmt.setInt(1, 1); + pstmt.addBatch(); + + pstmt.executeBatch(); + + ResultSet rs = stmt.executeQuery("SELECT * FROM " + doubleQuoteTableName); + rs.next(); + + assertEquals(rs.getObject(1), 1); + } + } + + @Test + public void testSchemaAgainstDB() throws Exception { + String schemaTableName = "\"dbo\" . /*some comment */ " + squareBracketTableName; + String valid = "insert into " + schemaTableName + " values (?)"; + + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement(valid); + Statement stmt = (SQLServerStatement) connection.createStatement();) { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(connection, true); + + Utils.dropTableIfExists("[dbo]." + squareBracketTableName, stmt); + + String createTable = "create table " + schemaTableName + " (c1 int)"; + stmt.execute(createTable); + + pstmt.setInt(1, 1); + pstmt.addBatch(); + + pstmt.executeBatch(); + + ResultSet rs = stmt.executeQuery("SELECT * FROM " + schemaTableName); + rs.next(); + + assertEquals(rs.getObject(1), 1); + } + } + + @Test + public void testColumnNameMixAgainstDB() throws Exception { + String valid = "insert into " + squareBracketTableName + " ([c]]]]1], [c]]]]2]) values (?, 1)"; + + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement(valid); + Statement stmt = (SQLServerStatement) connection.createStatement();) { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(connection, true); + + Utils.dropTableIfExists(squareBracketTableName, stmt); + String createTable = "create table " + squareBracketTableName + " ([c]]]]1] int, [c]]]]2] int)"; + stmt.execute(createTable); + + pstmt.setInt(1, 1); + pstmt.addBatch(); + + pstmt.executeBatch(); + + ResultSet rs = stmt.executeQuery("SELECT * FROM " + squareBracketTableName); + rs.next(); + + assertEquals(rs.getObject(1), 1); + } + } + + @Test + public void testAlColumnsLargeBatch() throws Exception { + String valid = "INSERT INTO " + tableName + " values " + + "(" + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + "?, " + + ")"; + + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;"); + SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) connection.prepareStatement(valid); + Statement stmt = (SQLServerStatement) connection.createStatement();) { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(connection, true); + + Timestamp myTimestamp = new Timestamp(114550L); + + Date d = new Date(114550L); + + pstmt.setInt(1, 1234); + pstmt.setBoolean(2, false); + pstmt.setString(3, "a"); + pstmt.setDate(4, d); + pstmt.setDateTime(5, myTimestamp); + pstmt.setFloat(6, (float) 123.45); + pstmt.setString(7, "b"); + pstmt.setString(8, "varc"); + pstmt.setString(9, "''"); + pstmt.addBatch(); + + pstmt.executeLargeBatch(); + + ResultSet rs = stmt.executeQuery("SELECT * FROM " + tableName); + + Object[] expected = new Object[9]; + + expected[0] = 1234; + expected[1] = false; + expected[2] = "a"; + expected[3] = d; + expected[4] = myTimestamp; + expected[5] = 123.45; + expected[6] = "b"; + expected[7] = "varc"; + expected[8] = "''"; + + rs.next(); + for (int i=0; i < expected.length; i++) { + assertEquals(rs.getObject(i + 1).toString(), expected[i].toString()); + } + } + } + + @BeforeEach + public void testSetup() throws TestAbortedException, Exception { + try (Connection connection = DriverManager.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;")) { + try (Statement stmt = (SQLServerStatement) connection.createStatement()) { + Utils.dropTableIfExists(tableName, stmt); + String sql1 = "create table " + tableName + " " + + "(" + + "c1 int DEFAULT 1234, " + + "c2 bit, " + + "c3 char DEFAULT NULL, " + + "c4 date, " + + "c5 datetime2, " + + "c6 float, " + + "c7 nchar, " + + "c8 varchar(20), " + + "c9 varchar(max)" + + ")"; + + stmt.execute(sql1); + } + } + } + + @AfterAll + public static void terminateVariation() throws SQLException { + try (Connection connection = DriverManager.getConnection(connectionString)) { + try (Statement stmt = (SQLServerStatement) connection.createStatement()) { + Utils.dropTableIfExists(tableName, stmt); + Utils.dropTableIfExists(squareBracketTableName, stmt); + Utils.dropTableIfExists(doubleQuoteTableName, stmt); + } + } + } +} diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/preparedStatement/RegressionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/preparedStatement/RegressionTest.java index 880ed7ef7..2596042f2 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/preparedStatement/RegressionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/preparedStatement/RegressionTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.fail; +import java.lang.reflect.Field; import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; @@ -23,6 +24,8 @@ import org.junit.jupiter.api.Test; import org.junit.platform.runner.JUnitPlatform; import org.junit.runner.RunWith; + +import com.microsoft.sqlserver.jdbc.SQLServerConnection; import com.microsoft.sqlserver.testframework.AbstractTest; import com.microsoft.sqlserver.testframework.Utils; import com.microsoft.sqlserver.jdbc.TestResource; @@ -211,109 +214,123 @@ public void grantTest() throws SQLException { * @throws SQLException */ @Test - public void batchWithLargeStringTest() throws SQLException { - Statement stmt = con.createStatement(); - PreparedStatement pstmt = null; - ResultSet rs = null; - Utils.dropTableIfExists("TEST_TABLE", stmt); - - con.setAutoCommit(false); - - // create a table with two columns - boolean createPrimaryKey = false; - try { - stmt.execute("if object_id('TEST_TABLE', 'U') is not null\ndrop table TEST_TABLE;"); - if (createPrimaryKey) { - stmt.execute("create table TEST_TABLE ( ID int, DATA nvarchar(max), primary key (ID) );"); - } - else { - stmt.execute("create table TEST_TABLE ( ID int, DATA nvarchar(max) );"); + public void batchWithLargeStringTest() throws Exception { + batchWithLargeStringTestInternal("BatchInsert"); + } + + @Test + public void batchWithLargeStringTestUseBulkCopyAPI() throws Exception { + batchWithLargeStringTestInternal("BulkCopy"); + } + + private void batchWithLargeStringTestInternal(String mode) throws Exception { + try (Connection con = DriverManager.getConnection(connectionString);) { + if (mode.equalsIgnoreCase("bulkcopy")) { + modifyConnectionForBulkCopyAPI((SQLServerConnection) con); } - } - catch (Exception e) { - fail(e.toString()); - } - - con.commit(); - - // build a String with 4001 characters - StringBuilder stringBuilder = new StringBuilder(); - for (int i = 0; i < 4001; i++) { - stringBuilder.append('c'); - } - String largeString = stringBuilder.toString(); - String[] values = {"a", "b", largeString, "d", "e"}; - // insert five rows into the table; use a batch for each row - try { - pstmt = con.prepareStatement("insert into TEST_TABLE values (?,?)"); - // 0,a - pstmt.setInt(1, 0); - pstmt.setNString(2, values[0]); - pstmt.addBatch(); + Statement stmt = con.createStatement(); + PreparedStatement pstmt = null; + ResultSet rs = null; + Utils.dropTableIfExists("TEST_TABLE", stmt); - // 1,b - pstmt.setInt(1, 1); - pstmt.setNString(2, values[1]); - pstmt.addBatch(); + con.setAutoCommit(false); - // 2,ccc... - pstmt.setInt(1, 2); - pstmt.setNString(2, values[2]); - pstmt.addBatch(); + // create a table with two columns + boolean createPrimaryKey = false; + try { + stmt.execute("if object_id('TEST_TABLE', 'U') is not null\ndrop table TEST_TABLE;"); + if (createPrimaryKey) { + stmt.execute("create table TEST_TABLE ( ID int, DATA nvarchar(max), primary key (ID) );"); + } + else { + stmt.execute("create table TEST_TABLE ( ID int, DATA nvarchar(max) );"); + } + } + catch (Exception e) { + fail(e.toString()); + } - // 3,d - pstmt.setInt(1, 3); - pstmt.setNString(2, values[3]); - pstmt.addBatch(); + con.commit(); - // 4,e - pstmt.setInt(1, 4); - pstmt.setNString(2, values[4]); - pstmt.addBatch(); + // build a String with 4001 characters + StringBuilder stringBuilder = new StringBuilder(); + for (int i = 0; i < 4001; i++) { + stringBuilder.append('c'); + } + String largeString = stringBuilder.toString(); - pstmt.executeBatch(); - } - catch (Exception e) { - fail(e.toString()); - } - connection.commit(); + String[] values = {"a", "b", largeString, "d", "e"}; + // insert five rows into the table; use a batch for each row + try { + pstmt = con.prepareStatement("insert into TEST_TABLE values (?,?)"); + // 0,a + pstmt.setInt(1, 0); + pstmt.setNString(2, values[0]); + pstmt.addBatch(); + + // 1,b + pstmt.setInt(1, 1); + pstmt.setNString(2, values[1]); + pstmt.addBatch(); + + // 2,ccc... + pstmt.setInt(1, 2); + pstmt.setNString(2, values[2]); + pstmt.addBatch(); + + // 3,d + pstmt.setInt(1, 3); + pstmt.setNString(2, values[3]); + pstmt.addBatch(); + + // 4,e + pstmt.setInt(1, 4); + pstmt.setNString(2, values[4]); + pstmt.addBatch(); + + pstmt.executeBatch(); + } + catch (Exception e) { + fail(e.toString()); + } + con.commit(); - // check the data in the table - Map selectedValues = new LinkedHashMap<>(); - int id = 0; - try { - pstmt = con.prepareStatement("select * from TEST_TABLE;"); + // check the data in the table + Map selectedValues = new LinkedHashMap<>(); + int id = 0; try { - rs = pstmt.executeQuery(); - int i = 0; - while (rs.next()) { - id = rs.getInt(1); - String data = rs.getNString(2); - if (selectedValues.containsKey(id)) { - fail("Found duplicate id: " + id + " ,actual values is : " + values[i++] + " data is: " + data); + pstmt = con.prepareStatement("select * from TEST_TABLE;"); + try { + rs = pstmt.executeQuery(); + int i = 0; + while (rs.next()) { + id = rs.getInt(1); + String data = rs.getNString(2); + if (selectedValues.containsKey(id)) { + fail("Found duplicate id: " + id + " ,actual values is : " + values[i++] + " data is: " + data); + } + selectedValues.put(id, data); + } + } + finally { + if (null != rs) { + rs.close(); } - selectedValues.put(id, data); } } finally { - if (null != rs) { - rs.close(); + Utils.dropTableIfExists("TEST_TABLE", stmt); + if (null != pstmt) { + pstmt.close(); + } + if (null != stmt) { + stmt.close(); } } } - finally { - Utils.dropTableIfExists("TEST_TABLE", stmt); - if (null != pstmt) { - pstmt.close(); - } - if (null != stmt) { - stmt.close(); - } - } - } - + /** * Test with large string and tests with more batch queries * @@ -434,4 +451,11 @@ public static void cleanup() throws SQLException { } + private void modifyConnectionForBulkCopyAPI(SQLServerConnection con) throws Exception { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(con, true); + + con.setUseBulkCopyForBatchInsert(true); + } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecuteWithErrorsTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecuteWithErrorsTest.java index 1297446ca..e3659adfc 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecuteWithErrorsTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecuteWithErrorsTest.java @@ -14,6 +14,7 @@ import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeTrue; +import java.lang.reflect.Field; import java.sql.BatchUpdateException; import java.sql.Connection; import java.sql.DriverManager; @@ -28,6 +29,7 @@ import org.junit.platform.runner.JUnitPlatform; import org.junit.runner.RunWith; +import com.microsoft.sqlserver.jdbc.SQLServerConnection; import com.microsoft.sqlserver.jdbc.TestResource; import com.microsoft.sqlserver.testframework.AbstractSQLGenerator; import com.microsoft.sqlserver.testframework.AbstractTest; @@ -53,12 +55,38 @@ public class BatchExecuteWithErrorsTest extends AbstractTest { /** * Batch test - * - * @throws SQLException + * @throws Exception */ @Test @DisplayName("Batch Test") - public void Repro47239() throws SQLException { + public void Repro47239() throws Exception { + Repro47239Internal("BatchInsert"); + } + + @Test + @DisplayName("Batch Test using bulk copy API") + public void Repro47239UseBulkCopyAPI() throws Exception { + Repro47239Internal("BulkCopy"); + } + + /** + * Tests large methods, supported in 42 + * + * @throws Exception + */ + @Test + @DisplayName("Regression test for using 'large' methods") + public void Repro47239large() throws Exception { + Repro47239largeInternal("BatchInsert"); + } + + @Test + @DisplayName("Regression test for using 'large' methods using bulk copy API") + public void Repro47239largeUseBulkCopyAPI() throws Exception { + Repro47239largeInternal("BulkCopy"); + } + + private void Repro47239Internal(String mode) throws Exception { String tableN = RandomUtil.getIdentifier("t_Repro47239"); final String tableName = AbstractSQLGenerator.escapeIdentifier(tableN); final String insertStmt = "INSERT INTO " + tableName + " VALUES (999, 'HELLO', '4/12/1994')"; @@ -101,181 +129,179 @@ public void Repro47239() throws SQLException { catch (ClassNotFoundException e1) { fail(e1.toString()); } - Connection conn = DriverManager.getConnection(connectionString); - Statement stmt = conn.createStatement(); - try { - stmt.executeUpdate("drop table " + tableName); - } - catch (Exception ignored) { - } - stmt.executeUpdate( - "create table " + tableName + " (c1_int int, c2_varchar varchar(20), c3_date datetime, c4_int int identity(1,1) primary key)"); - - // Regular Statement batch update - expectedUpdateCounts = new int[] {1, -2, 1, -2, 1, -2}; - Statement batchStmt = conn.createStatement(); - batchStmt.addBatch(insertStmt); - batchStmt.addBatch(warning); - batchStmt.addBatch(insertStmt); - batchStmt.addBatch(warning); - batchStmt.addBatch(insertStmt); - batchStmt.addBatch(warning); - try { - actualUpdateCounts = batchStmt.executeBatch(); - actualExceptionText = ""; - } - catch (BatchUpdateException bue) { - actualUpdateCounts = bue.getUpdateCounts(); - actualExceptionText = bue.getMessage(); - if (log.isLoggable(Level.FINE)) { - log.fine("BatchUpdateException occurred. Message:" + actualExceptionText); - } - } - finally { - batchStmt.close(); - } - if (log.isLoggable(Level.FINE)) { - log.fine("UpdateCounts:"); - } - for (int updateCount : actualUpdateCounts) { - log.fine("" + updateCount + ","); - } - log.fine(""); - assertTrue(Arrays.equals(actualUpdateCounts, expectedUpdateCounts), TestResource.getResource("R_testInterleaved")); - - expectedUpdateCounts = new int[] {-3, 1, 1, 1}; - stmt.addBatch(error); - stmt.addBatch(insertStmt); - stmt.addBatch(insertStmt); - stmt.addBatch(insertStmt); - try { - actualUpdateCounts = stmt.executeBatch(); - actualExceptionText = ""; - } - catch (BatchUpdateException bue) { - actualUpdateCounts = bue.getUpdateCounts(); - actualExceptionText = bue.getMessage(); - } - log.fine("UpdateCounts:"); - for (int updateCount : actualUpdateCounts) { - log.fine("" + updateCount + ","); - } - log.fine(""); - assertTrue(Arrays.equals(actualUpdateCounts, expectedUpdateCounts), TestResource.getResource("R_errorFollowInserts")); - // 50280 - expectedUpdateCounts = new int[] {1, -3}; - stmt.addBatch(insertStmt); - stmt.addBatch(error16); - try { - actualUpdateCounts = stmt.executeBatch(); - actualExceptionText = ""; - } - catch (BatchUpdateException bue) { - actualUpdateCounts = bue.getUpdateCounts(); - actualExceptionText = bue.getMessage(); - } - for (int updateCount : actualUpdateCounts) { - log.fine("" + updateCount + ","); - } - log.fine(""); - assertTrue(Arrays.equals(actualUpdateCounts, expectedUpdateCounts), TestResource.getResource("R_errorFollow50280")); - - // Test "soft" errors - conn.setAutoCommit(false); - stmt.addBatch(select); - stmt.addBatch(insertStmt); - stmt.addBatch(select); - stmt.addBatch(insertStmt); - try { - stmt.executeBatch(); - // Soft error test: executeBatch unexpectedly succeeded - assertEquals(true, false, TestResource.getResource("R_shouldThrowException")); - } - catch (BatchUpdateException bue) { - assertEquals("A result set was generated for update.", bue.getMessage(), TestResource.getResource("R_unexpectedExceptionContent")); - assertEquals(Arrays.equals(bue.getUpdateCounts(), new int[] {-3, 1, -3, 1}), true, - TestResource.getResource("R_incorrectUpdateCount")); - } - conn.rollback(); - - // Defect 128801: Rollback (with conversion error) should throw SQLException - stmt.addBatch(dateConversionError); - stmt.addBatch(insertStmt); - stmt.addBatch(insertStmt); - stmt.addBatch(insertStmt); - try { - stmt.executeBatch(); - } - catch (BatchUpdateException bue) { - assertThat(bue.getMessage(), containsString(TestResource.getResource("R_syntaxErrorDateConvert"))); - // CTestLog.CompareStartsWith(bue.getMessage(), "Syntax error converting date", "Transaction rollback with conversion error threw wrong - // BatchUpdateException"); - } - catch (SQLException e) { - assertThat(e.getMessage(), containsString(TestResource.getResource("R_dateConvertError"))); - // CTestLog.CompareStartsWith(e.getMessage(), "Conversion failed when converting date", "Transaction rollback with conversion error threw - // wrong SQLException"); - } - - conn.setAutoCommit(true); - - // On SQL Azure, raising FATAL error by RAISERROR() is not supported and there is no way to - // cut the current connection by a statement inside a SQL batch. - // Details: Although one can simulate a fatal error (that cuts the connections) by dropping the database, - // this simulation cannot be written entirely in TSQL (because it needs a new connection), - // and thus it cannot be put into a TSQL batch and it is useless here. - // So we have to skip the last scenario of this test case, i.e. "Test Severe (connection-closing) errors" - // It is worthwhile to still execute the first 5 test scenarios of this test case, in order to have best test coverage. - if (!DBConnection.isSqlAzure(conn)) { - // Test Severe (connection-closing) errors - stmt.addBatch(error); - stmt.addBatch(insertStmt); - stmt.addBatch(warning); - // TODO Removed until ResultSet refactoring task (45832) is complete. - // stmt.addBatch(select); // error: select not permitted in batch - stmt.addBatch(insertStmt); - stmt.addBatch(severe); - stmt.addBatch(insertStmt); - stmt.addBatch(insertStmt); - try { - stmt.executeBatch(); - // Test fatal errors batch execution succeeded (should have failed) - assertEquals(false, true, TestResource.getResource("R_shouldThrowException")); + try (Connection conn = DriverManager.getConnection(connectionString)) { + if (mode.equalsIgnoreCase("bulkcopy")) { + modifyConnectionForBulkCopyAPI((SQLServerConnection) conn); } - catch (BatchUpdateException bue) { - // Test fatal errors returned BatchUpdateException rather than SQLException - assertEquals(false, true, TestResource.getResource("R_unexpectedException") + bue.getMessage()); + try (Statement stmt = conn.createStatement()) { + + try { + Utils.dropTableIfExists(tableName, stmt); + } + catch (Exception ignored) { + } + stmt.executeUpdate( + "create table " + tableName + " (c1_int int, c2_varchar varchar(20), c3_date datetime, c4_int int identity(1,1) primary key)"); + + // Regular Statement batch update + expectedUpdateCounts = new int[] {1, -2, 1, -2, 1, -2}; + Statement batchStmt = conn.createStatement(); + batchStmt.addBatch(insertStmt); + batchStmt.addBatch(warning); + batchStmt.addBatch(insertStmt); + batchStmt.addBatch(warning); + batchStmt.addBatch(insertStmt); + batchStmt.addBatch(warning); + try { + actualUpdateCounts = batchStmt.executeBatch(); + actualExceptionText = ""; + } + catch (BatchUpdateException bue) { + actualUpdateCounts = bue.getUpdateCounts(); + actualExceptionText = bue.getMessage(); + if (log.isLoggable(Level.FINE)) { + log.fine("BatchUpdateException occurred. Message:" + actualExceptionText); + } + } + finally { + batchStmt.close(); + } + if (log.isLoggable(Level.FINE)) { + log.fine("UpdateCounts:"); + } + for (int updateCount : actualUpdateCounts) { + log.fine("" + updateCount + ","); + } + log.fine(""); + assertTrue(Arrays.equals(actualUpdateCounts, expectedUpdateCounts), TestResource.getResource("R_testInterleaved")); + + expectedUpdateCounts = new int[] {-3, 1, 1, 1}; + stmt.addBatch(error); + stmt.addBatch(insertStmt); + stmt.addBatch(insertStmt); + stmt.addBatch(insertStmt); + try { + actualUpdateCounts = stmt.executeBatch(); + actualExceptionText = ""; + } + catch (BatchUpdateException bue) { + actualUpdateCounts = bue.getUpdateCounts(); + actualExceptionText = bue.getMessage(); + } + log.fine("UpdateCounts:"); + for (int updateCount : actualUpdateCounts) { + log.fine("" + updateCount + ","); + } + log.fine(""); + assertTrue(Arrays.equals(actualUpdateCounts, expectedUpdateCounts), TestResource.getResource("R_errorFollowInserts")); + // 50280 + expectedUpdateCounts = new int[] {1, -3}; + stmt.addBatch(insertStmt); + stmt.addBatch(error16); + try { + actualUpdateCounts = stmt.executeBatch(); + actualExceptionText = ""; + } + catch (BatchUpdateException bue) { + actualUpdateCounts = bue.getUpdateCounts(); + actualExceptionText = bue.getMessage(); + } + for (int updateCount : actualUpdateCounts) { + log.fine("" + updateCount + ","); + } + log.fine(""); + assertTrue(Arrays.equals(actualUpdateCounts, expectedUpdateCounts), TestResource.getResource("R_errorFollow50280")); + + // Test "soft" errors + conn.setAutoCommit(false); + stmt.addBatch(select); + stmt.addBatch(insertStmt); + stmt.addBatch(select); + stmt.addBatch(insertStmt); + try { + stmt.executeBatch(); + // Soft error test: executeBatch unexpectedly succeeded + assertEquals(true, false, TestResource.getResource("R_shouldThrowException")); + } + catch (BatchUpdateException bue) { + assertEquals("A result set was generated for update.", bue.getMessage(), TestResource.getResource("R_unexpectedExceptionContent")); + assertEquals(Arrays.equals(bue.getUpdateCounts(), new int[] {-3, 1, -3, 1}), true, + TestResource.getResource("R_incorrectUpdateCount")); + } + conn.rollback(); + + // Defect 128801: Rollback (with conversion error) should throw SQLException + stmt.addBatch(dateConversionError); + stmt.addBatch(insertStmt); + stmt.addBatch(insertStmt); + stmt.addBatch(insertStmt); + try { + stmt.executeBatch(); + } + catch (BatchUpdateException bue) { + assertThat(bue.getMessage(), containsString(TestResource.getResource("R_syntaxErrorDateConvert"))); + // CTestLog.CompareStartsWith(bue.getMessage(), "Syntax error converting date", "Transaction rollback with conversion error threw wrong + // BatchUpdateException"); + } + catch (SQLException e) { + assertThat(e.getMessage(), containsString(TestResource.getResource("R_dateConvertError"))); + // CTestLog.CompareStartsWith(e.getMessage(), "Conversion failed when converting date", "Transaction rollback with conversion error threw + // wrong SQLException"); + } - } - catch (SQLException e) { - actualExceptionText = e.getMessage(); + conn.setAutoCommit(true); + + // On SQL Azure, raising FATAL error by RAISERROR() is not supported and there is no way to + // cut the current connection by a statement inside a SQL batch. + // Details: Although one can simulate a fatal error (that cuts the connections) by dropping the database, + // this simulation cannot be written entirely in TSQL (because it needs a new connection), + // and thus it cannot be put into a TSQL batch and it is useless here. + // So we have to skip the last scenario of this test case, i.e. "Test Severe (connection-closing) errors" + // It is worthwhile to still execute the first 5 test scenarios of this test case, in order to have best test coverage. + if (!DBConnection.isSqlAzure(conn)) { + // Test Severe (connection-closing) errors + stmt.addBatch(error); + stmt.addBatch(insertStmt); + stmt.addBatch(warning); + // TODO Removed until ResultSet refactoring task (45832) is complete. + // stmt.addBatch(select); // error: select not permitted in batch + stmt.addBatch(insertStmt); + stmt.addBatch(severe); + stmt.addBatch(insertStmt); + stmt.addBatch(insertStmt); + try { + stmt.executeBatch(); + // Test fatal errors batch execution succeeded (should have failed) + assertEquals(false, true, TestResource.getResource("R_shouldThrowException")); + } + catch (BatchUpdateException bue) { + // Test fatal errors returned BatchUpdateException rather than SQLException + assertEquals(false, true, TestResource.getResource("R_unexpectedException") + bue.getMessage()); + + } + catch (SQLException e) { + actualExceptionText = e.getMessage(); + + if (actualExceptionText.endsWith("reset")) { + assertTrue(actualExceptionText.equalsIgnoreCase("Connection reset"), TestResource.getResource("R_unexpectedExceptionContent") + ": " + actualExceptionText); + } + else { + assertTrue(actualExceptionText.equalsIgnoreCase("raiserror level 20"), TestResource.getResource("R_unexpectedExceptionContent") + ": " + actualExceptionText); + } + } + } - if (actualExceptionText.endsWith("reset")) { - assertTrue(actualExceptionText.equalsIgnoreCase("Connection reset"), TestResource.getResource("R_unexpectedExceptionContent") + ": " + actualExceptionText); + try { + stmt.executeUpdate("drop table " + tableName); } - else { - assertTrue(actualExceptionText.equalsIgnoreCase("raiserror level 20"), TestResource.getResource("R_unexpectedExceptionContent") + ": " + actualExceptionText); + catch (Exception ignored) { } } - } - try { - stmt.executeUpdate("drop table " + tableName); } - catch (Exception ignored) { - } - stmt.close(); - conn.close(); } - - /** - * Tests large methods, supported in 42 - * - * @throws Exception - */ - @Test - @DisplayName("Regression test for using 'large' methods") - public void Repro47239large() throws Exception { + + private void Repro47239largeInternal(String mode) throws Exception { assumeTrue("JDBC42".equals(Utils.getConfiguredProperty("JDBC_Version")), TestResource.getResource("R_incompatJDBC")); // the DBConnection for detecting whether the server is SQL Azure or SQL Server. @@ -309,168 +335,180 @@ public void Repro47239large() throws Exception { // SQL Server 2005 driver Class.forName("com.microsoft.sqlserver.jdbc.SQLServerDriver"); - Connection conn = DriverManager.getConnection(connectionString); - Statement stmt = conn.createStatement(); - - try { - stmt.executeLargeUpdate("drop table " + tableName); - } - catch (Exception ignored) { - } - try { - stmt.executeLargeUpdate( - "create table " + tableName + " (c1_int int, c2_varchar varchar(20), c3_date datetime, c4_int int identity(1,1) primary key)"); - } - catch (Exception ignored) { - } - // Regular Statement batch update - expectedUpdateCounts = new long[] {1, -2, 1, -2, 1, -2}; - Statement batchStmt = conn.createStatement(); - batchStmt.addBatch(insertStmt); - batchStmt.addBatch(warning); - batchStmt.addBatch(insertStmt); - batchStmt.addBatch(warning); - batchStmt.addBatch(insertStmt); - batchStmt.addBatch(warning); - try { - actualUpdateCounts = batchStmt.executeLargeBatch(); - actualExceptionText = ""; - } - catch (BatchUpdateException bue) { - actualUpdateCounts = bue.getLargeUpdateCounts(); - actualExceptionText = bue.getMessage(); - log.fine("BatchUpdateException occurred. Message:" + actualExceptionText); - } - finally { - batchStmt.close(); - } - log.fine("UpdateCounts:"); - for (long updateCount : actualUpdateCounts) { - log.fine("" + updateCount + ","); - } - log.fine(""); - assertTrue(Arrays.equals(actualUpdateCounts, expectedUpdateCounts), TestResource.getResource("R_testInterleaved")); - - expectedUpdateCounts = new long[] {-3, 1, 1, 1}; - stmt.addBatch(error); - stmt.addBatch(insertStmt); - stmt.addBatch(insertStmt); - stmt.addBatch(insertStmt); - try { - actualUpdateCounts = stmt.executeLargeBatch(); - actualExceptionText = ""; - } - catch (BatchUpdateException bue) { - actualUpdateCounts = bue.getLargeUpdateCounts(); - actualExceptionText = bue.getMessage(); - } - log.fine("UpdateCounts:"); - for (long updateCount : actualUpdateCounts) { - log.fine("" + updateCount + ","); - } - log.fine(""); - assertTrue(Arrays.equals(actualUpdateCounts, expectedUpdateCounts), TestResource.getResource("R_errorFollowInserts")); - - // 50280 - expectedUpdateCounts = new long[] {1, -3}; - stmt.addBatch(insertStmt); - stmt.addBatch(error16); - try { - actualUpdateCounts = stmt.executeLargeBatch(); - actualExceptionText = ""; - } - catch (BatchUpdateException bue) { - actualUpdateCounts = bue.getLargeUpdateCounts(); - actualExceptionText = bue.getMessage(); - } - for (long updateCount : actualUpdateCounts) { - log.fine("" + updateCount + ","); - } - log.fine(""); - assertTrue(Arrays.equals(actualUpdateCounts, expectedUpdateCounts), TestResource.getResource("R_errorFollow50280")); - - // Test "soft" errors - conn.setAutoCommit(false); - stmt.addBatch(select); - stmt.addBatch(insertStmt); - stmt.addBatch(select); - stmt.addBatch(insertStmt); - try { - stmt.executeLargeBatch(); - // Soft error test: executeLargeBatch unexpectedly succeeded - assertEquals(false, true, TestResource.getResource("R_shouldThrowException")); - } - catch (BatchUpdateException bue) { - // Soft error test: wrong error message in BatchUpdateException - assertEquals("A result set was generated for update.", bue.getMessage(), TestResource.getResource("R_unexpectedExceptionContent")); - // Soft error test: wrong update counts in BatchUpdateException - assertEquals(Arrays.equals(bue.getLargeUpdateCounts(), new long[] {-3, 1, -3, 1}), true, - TestResource.getResource("R_incorrectUpdateCount")); - } - conn.rollback(); - - // Defect 128801: Rollback (with conversion error) should throw SQLException - stmt.addBatch(dateConversionError); - stmt.addBatch(insertStmt); - stmt.addBatch(insertStmt); - stmt.addBatch(insertStmt); - try { - stmt.executeLargeBatch(); - } - catch (BatchUpdateException bue) { - assertThat(bue.getMessage(), containsString(TestResource.getResource("R_syntaxErrorDateConvert"))); - } - catch (SQLException e) { - assertThat(e.getMessage(), containsString(TestResource.getResource("R_dateConvertError"))); - } - - conn.setAutoCommit(true); - - // On SQL Azure, raising FATAL error by RAISERROR() is not supported and there is no way to - // cut the current connection by a statement inside a SQL batch. - // Details: Although one can simulate a fatal error (that cuts the connections) by dropping the database, - // this simulation cannot be written entirely in TSQL (because it needs a new connection), - // and thus it cannot be put into a TSQL batch and it is useless here. - // So we have to skip the last scenario of this test case, i.e. "Test Severe (connection-closing) errors" - // It is worthwhile to still execute the first 5 test scenarios of this test case, in order to have best test coverage. - if (!DBConnection.isSqlAzure(DriverManager.getConnection(connectionString))) { - // Test Severe (connection-closing) errors - stmt.addBatch(error); - stmt.addBatch(insertStmt); - stmt.addBatch(warning); - - stmt.addBatch(insertStmt); - stmt.addBatch(severe); - stmt.addBatch(insertStmt); - stmt.addBatch(insertStmt); - try { - stmt.executeLargeBatch(); - // Test fatal errors batch execution succeeded (should have failed) - assertEquals(false, true, TestResource.getResource("R_shouldThrowException")); - } - catch (BatchUpdateException bue) { - // Test fatal errors returned BatchUpdateException rather than SQLException - assertEquals(false, true, TestResource.getResource("R_unexpectedException") + bue.getMessage()); + + try (Connection conn = DriverManager.getConnection(connectionString)) { + if (mode.equalsIgnoreCase("bulkcopy")) { + modifyConnectionForBulkCopyAPI((SQLServerConnection) conn); } - catch (SQLException e) { - actualExceptionText = e.getMessage(); + try (Statement stmt = conn.createStatement()) { + + try { + Utils.dropTableIfExists(tableName, stmt); + } + catch (Exception ignored) { + } + try { + stmt.executeLargeUpdate( + "create table " + tableName + " (c1_int int, c2_varchar varchar(20), c3_date datetime, c4_int int identity(1,1) primary key)"); + } + catch (Exception ignored) { + } + // Regular Statement batch update + expectedUpdateCounts = new long[] {1, -2, 1, -2, 1, -2}; + Statement batchStmt = conn.createStatement(); + batchStmt.addBatch(insertStmt); + batchStmt.addBatch(warning); + batchStmt.addBatch(insertStmt); + batchStmt.addBatch(warning); + batchStmt.addBatch(insertStmt); + batchStmt.addBatch(warning); + try { + actualUpdateCounts = batchStmt.executeLargeBatch(); + actualExceptionText = ""; + } + catch (BatchUpdateException bue) { + actualUpdateCounts = bue.getLargeUpdateCounts(); + actualExceptionText = bue.getMessage(); + log.fine("BatchUpdateException occurred. Message:" + actualExceptionText); + } + finally { + batchStmt.close(); + } + log.fine("UpdateCounts:"); + for (long updateCount : actualUpdateCounts) { + log.fine("" + updateCount + ","); + } + log.fine(""); + assertTrue(Arrays.equals(actualUpdateCounts, expectedUpdateCounts), TestResource.getResource("R_testInterleaved")); + + expectedUpdateCounts = new long[] {-3, 1, 1, 1}; + stmt.addBatch(error); + stmt.addBatch(insertStmt); + stmt.addBatch(insertStmt); + stmt.addBatch(insertStmt); + try { + actualUpdateCounts = stmt.executeLargeBatch(); + actualExceptionText = ""; + } + catch (BatchUpdateException bue) { + actualUpdateCounts = bue.getLargeUpdateCounts(); + actualExceptionText = bue.getMessage(); + } + log.fine("UpdateCounts:"); + for (long updateCount : actualUpdateCounts) { + log.fine("" + updateCount + ","); + } + log.fine(""); + assertTrue(Arrays.equals(actualUpdateCounts, expectedUpdateCounts), TestResource.getResource("R_errorFollowInserts")); + + // 50280 + expectedUpdateCounts = new long[] {1, -3}; + stmt.addBatch(insertStmt); + stmt.addBatch(error16); + try { + actualUpdateCounts = stmt.executeLargeBatch(); + actualExceptionText = ""; + } + catch (BatchUpdateException bue) { + actualUpdateCounts = bue.getLargeUpdateCounts(); + actualExceptionText = bue.getMessage(); + } + for (long updateCount : actualUpdateCounts) { + log.fine("" + updateCount + ","); + } + log.fine(""); + assertTrue(Arrays.equals(actualUpdateCounts, expectedUpdateCounts), TestResource.getResource("R_errorFollow50280")); + + // Test "soft" errors + conn.setAutoCommit(false); + stmt.addBatch(select); + stmt.addBatch(insertStmt); + stmt.addBatch(select); + stmt.addBatch(insertStmt); + try { + stmt.executeLargeBatch(); + // Soft error test: executeLargeBatch unexpectedly succeeded + assertEquals(false, true, TestResource.getResource("R_shouldThrowException")); + } + catch (BatchUpdateException bue) { + // Soft error test: wrong error message in BatchUpdateException + assertEquals("A result set was generated for update.", bue.getMessage(), TestResource.getResource("R_unexpectedExceptionContent")); + // Soft error test: wrong update counts in BatchUpdateException + assertEquals(Arrays.equals(bue.getLargeUpdateCounts(), new long[] {-3, 1, -3, 1}), true, + TestResource.getResource("R_incorrectUpdateCount")); + } + conn.rollback(); + + // Defect 128801: Rollback (with conversion error) should throw SQLException + stmt.addBatch(dateConversionError); + stmt.addBatch(insertStmt); + stmt.addBatch(insertStmt); + stmt.addBatch(insertStmt); + try { + stmt.executeLargeBatch(); + } + catch (BatchUpdateException bue) { + assertThat(bue.getMessage(), containsString(TestResource.getResource("R_syntaxErrorDateConvert"))); + } + catch (SQLException e) { + assertThat(e.getMessage(), containsString(TestResource.getResource("R_dateConvertError"))); + } - if (actualExceptionText.endsWith("reset")) { - assertTrue(actualExceptionText.equalsIgnoreCase("Connection reset"), TestResource.getResource("R_unexpectedExceptionContent") + ": " + actualExceptionText); + conn.setAutoCommit(true); + + // On SQL Azure, raising FATAL error by RAISERROR() is not supported and there is no way to + // cut the current connection by a statement inside a SQL batch. + // Details: Although one can simulate a fatal error (that cuts the connections) by dropping the database, + // this simulation cannot be written entirely in TSQL (because it needs a new connection), + // and thus it cannot be put into a TSQL batch and it is useless here. + // So we have to skip the last scenario of this test case, i.e. "Test Severe (connection-closing) errors" + // It is worthwhile to still execute the first 5 test scenarios of this test case, in order to have best test coverage. + if (!DBConnection.isSqlAzure(DriverManager.getConnection(connectionString))) { + // Test Severe (connection-closing) errors + stmt.addBatch(error); + stmt.addBatch(insertStmt); + stmt.addBatch(warning); + + stmt.addBatch(insertStmt); + stmt.addBatch(severe); + stmt.addBatch(insertStmt); + stmt.addBatch(insertStmt); + try { + stmt.executeLargeBatch(); + // Test fatal errors batch execution succeeded (should have failed) + assertEquals(false, true, TestResource.getResource("R_shouldThrowException")); + } + catch (BatchUpdateException bue) { + // Test fatal errors returned BatchUpdateException rather than SQLException + assertEquals(false, true, TestResource.getResource("R_unexpectedException") + bue.getMessage()); + } + catch (SQLException e) { + actualExceptionText = e.getMessage(); + + if (actualExceptionText.endsWith("reset")) { + assertTrue(actualExceptionText.equalsIgnoreCase("Connection reset"), TestResource.getResource("R_unexpectedExceptionContent") + ": " + actualExceptionText); + } + else { + assertTrue(actualExceptionText.equalsIgnoreCase("raiserror level 20"), TestResource.getResource("R_unexpectedExceptionContent") + ": " + actualExceptionText); + + } + } } - else { - assertTrue(actualExceptionText.equalsIgnoreCase("raiserror level 20"), TestResource.getResource("R_unexpectedExceptionContent") + ": " + actualExceptionText); + try { + stmt.executeLargeUpdate("drop table " + tableName); + } + catch (Exception ignored) { } } } - - try { - stmt.executeLargeUpdate("drop table " + tableName); - } - catch (Exception ignored) { - } - stmt.close(); - conn.close(); + } + + private void modifyConnectionForBulkCopyAPI(SQLServerConnection con) throws Exception { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(con, true); + + con.setUseBulkCopyForBatchInsert(true); } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java index c9de04a67..a4e7004be 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/BatchExecutionTest.java @@ -11,7 +11,7 @@ import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeTrue; -import java.sql.BatchUpdateException; +import java.lang.reflect.Field; import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; @@ -26,6 +26,7 @@ import org.junit.runner.RunWith; import org.opentest4j.TestAbortedException; +import com.microsoft.sqlserver.jdbc.SQLServerConnection; import com.microsoft.sqlserver.jdbc.SQLServerStatement; import com.microsoft.sqlserver.jdbc.TestResource; import com.microsoft.sqlserver.testframework.AbstractTest; @@ -54,6 +55,8 @@ public class BatchExecutionTest extends AbstractTest { public void testBatchExceptionAEOn() throws Exception { testAddBatch1(); testExecuteBatch1(); + testAddBatch1UseBulkCopyAPI(); + testExecuteBatch1UseBulkCopyAPI(); } /** @@ -61,55 +64,11 @@ public void testBatchExceptionAEOn() throws Exception { * array of Integer values of length 3 */ public void testAddBatch1() { - int i = 0; - int retValue[] = {0, 0, 0}; - try { - String sPrepStmt = "update ctstable2 set PRICE=PRICE*20 where TYPE_ID=?"; - pstmt = connection.prepareStatement(sPrepStmt); - pstmt.setInt(1, 2); - pstmt.addBatch(); - - pstmt.setInt(1, 3); - pstmt.addBatch(); - - pstmt.setInt(1, 4); - pstmt.addBatch(); - - int[] updateCount = pstmt.executeBatch(); - int updateCountlen = updateCount.length; - - assertTrue(updateCountlen == 3, TestResource.getResource("R_addBatchFailed") + ": " + TestResource.getResource("R_incorrectUpdateCount")); - - String sPrepStmt1 = "select count(*) from ctstable2 where TYPE_ID=?"; - - pstmt1 = connection.prepareStatement(sPrepStmt1); - - // 2 is the number that is set First for Type Id in Prepared Statement - for (int n = 2; n <= 4; n++) { - pstmt1.setInt(1, n); - rs = pstmt1.executeQuery(); - rs.next(); - retValue[i++] = rs.getInt(1); - } - - pstmt1.close(); - - for (int j = 0; j < updateCount.length; j++) { + testAddBatch1Internal("BatchInsert"); + } - if (updateCount[j] != retValue[j] && updateCount[j] != Statement.SUCCESS_NO_INFO) { - fail(TestResource.getResource("R_incorrectUpdateCount")); - } - } - } - catch (BatchUpdateException b) { - fail(TestResource.getResource("R_addBatchFailed") + ": " + b.getMessage()); - } - catch (SQLException sqle) { - fail(TestResource.getResource("R_addBatchFailed") + ": " + sqle.getMessage()); - } - catch (Exception e) { - fail(TestResource.getResource("R_addBatchFailed") + ": " + e.getMessage()); - } + public void testAddBatch1UseBulkCopyAPI() { + testAddBatch1Internal("BulkCopy"); } /** @@ -117,12 +76,24 @@ public void testAddBatch1() { * an array of Integer values of length 3. */ public void testExecuteBatch1() { + testExecuteBatch1Internal("BatchInsert"); + } + + public void testExecuteBatch1UseBulkCopyAPI() { + testExecuteBatch1Internal("BulkCopy"); + } + + private void testExecuteBatch1Internal(String mode) { int i = 0; int retValue[] = {0, 0, 0}; int updateCountlen = 0; - try { + try (Connection connection = DriverManager.getConnection(connectionString + ";columnEncryptionSetting=Enabled;");){ String sPrepStmt = "update ctstable2 set PRICE=PRICE*20 where TYPE_ID=?"; + if (mode.equalsIgnoreCase("bulkcopy")) { + modifyConnectionForBulkCopyAPI((SQLServerConnection) connection); + } + pstmt = connection.prepareStatement(sPrepStmt); pstmt.setInt(1, 1); pstmt.addBatch(); @@ -157,12 +128,6 @@ public void testExecuteBatch1() { } } } - catch (BatchUpdateException b) { - fail(TestResource.getResource("R_executeBatchFailed") + ": " + b.getMessage()); - } - catch (SQLException sqle) { - fail(TestResource.getResource("R_executeBatchFailed") + ": " + sqle.getMessage()); - } catch (Exception e) { fail(TestResource.getResource("R_executeBatchFailed") + ": " + e.getMessage()); } @@ -189,6 +154,65 @@ private static void createTable() throws SQLException { stmt.execute(sqlin1); } + + private void testAddBatch1Internal(String mode) { + int i = 0; + int retValue[] = {0, 0, 0}; + try (Connection connection = DriverManager.getConnection(connectionString + ";columnEncryptionSetting=Enabled;");){ + String sPrepStmt = "update ctstable2 set PRICE=PRICE*20 where TYPE_ID=?"; + + if (mode.equalsIgnoreCase("bulkcopy")) { + modifyConnectionForBulkCopyAPI((SQLServerConnection) connection); + } + + pstmt = connection.prepareStatement(sPrepStmt); + pstmt.setInt(1, 2); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.addBatch(); + + int[] updateCount = pstmt.executeBatch(); + int updateCountlen = updateCount.length; + + assertTrue(updateCountlen == 3, TestResource.getResource("R_addBatchFailed") + ": " + TestResource.getResource("R_incorrectUpdateCount")); + + String sPrepStmt1 = "select count(*) from ctstable2 where TYPE_ID=?"; + + pstmt1 = connection.prepareStatement(sPrepStmt1); + + // 2 is the number that is set First for Type Id in Prepared Statement + for (int n = 2; n <= 4; n++) { + pstmt1.setInt(1, n); + rs = pstmt1.executeQuery(); + rs.next(); + retValue[i++] = rs.getInt(1); + } + + pstmt1.close(); + + for (int j = 0; j < updateCount.length; j++) { + + if (updateCount[j] != retValue[j] && updateCount[j] != Statement.SUCCESS_NO_INFO) { + fail(TestResource.getResource("R_incorrectUpdateCount")); + } + } + } + catch (Exception e) { + fail(TestResource.getResource("R_addBatchFailed") + ": " + e.getMessage()); + } + } + + private void modifyConnectionForBulkCopyAPI(SQLServerConnection con) throws Exception { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(con, true); + + con.setUseBulkCopyForBatchInsert(true); + } @BeforeAll public static void testSetup() throws TestAbortedException, Exception { diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/PreparedStatementTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/PreparedStatementTest.java index 9e886d441..78b2777b0 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/PreparedStatementTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/PreparedStatementTest.java @@ -14,6 +14,7 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.fail; +import java.lang.reflect.Field; import java.sql.BatchUpdateException; import java.sql.DriverManager; import java.sql.ResultSet; @@ -133,168 +134,25 @@ public void testBatchedUnprepare() throws SQLException { */ @Test @Tag("slow") - public void testStatementPooling() throws SQLException { - // Test % handle re-use - try (SQLServerConnection con = (SQLServerConnection)DriverManager.getConnection(connectionString)) { - String query = String.format("/*statementpoolingtest_re-use_%s*/SELECT TOP(1) * FROM sys.tables;", UUID.randomUUID().toString()); - - con.setStatementPoolingCacheSize(10); - - boolean[] prepOnFirstCalls = {false, true}; - - for(boolean prepOnFirstCall : prepOnFirstCalls) { - - con.setEnablePrepareOnFirstPreparedStatementCall(prepOnFirstCall); - - int[] queryCounts = {10, 20, 30, 40}; - for(int queryCount : queryCounts) { - String[] queries = new String[queryCount]; - for(int i = 0; i < queries.length; ++i) { - queries[i] = String.format("%s--%s--%s--%s", query, i, queryCount, prepOnFirstCall); - } - - int testsWithHandleReuse = 0; - final int testCount = 500; - for(int i = 0; i < testCount; ++i) { - Random random = new Random(); - int queryNumber = random.nextInt(queries.length); - try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) con.prepareStatement(queries[queryNumber])) { - pstmt.execute(); - - // Grab handle-reuse before it would be populated if initially created. - if(0 < pstmt.getPreparedStatementHandle()) - testsWithHandleReuse++; - - pstmt.getMoreResults(); // Make sure handle is updated. - } - } - System.out.println(String.format("Prep on first call: %s Query count:%s: %s of %s (%s)", prepOnFirstCall, queryCount, testsWithHandleReuse, testCount, (double)testsWithHandleReuse/(double)testCount)); - } - } - } - - try (SQLServerConnection con = (SQLServerConnection)DriverManager.getConnection(connectionString)) { - - // Test behvaior with statement pooling. - con.setStatementPoolingCacheSize(10); - this.executeSQL(con, - "IF NOT EXISTS (SELECT * FROM sys.messages WHERE message_id = 99586) EXEC sp_addmessage 99586, 16, 'Prepared handle GAH!';"); - // Test with missing handle failures (fake). - this.executeSQL(con, "CREATE TABLE #update1 (col INT);INSERT #update1 VALUES (1);"); - this.executeSQL(con, - "CREATE PROC #updateProc1 AS UPDATE #update1 SET col += 1; IF EXISTS (SELECT * FROM #update1 WHERE col % 5 = 0) RAISERROR(99586,16,1);"); - try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) con.prepareStatement("#updateProc1")) { - for (int i = 0; i < 100; ++i) { - try { - assertSame(1, pstmt.executeUpdate()); - } - catch (SQLException e) { - // Error "Prepared handle GAH" is expected to happen. But it should not terminate the execution with RAISERROR. - // Since the original "Could not find prepared statement with handle" error does not terminate the execution after it. - if (!e.getMessage().contains("Prepared handle GAH")) { - throw e; - } - } - } - } - - // test updated value, should be 1 + 100 = 101 - // although executeUpdate() throws exception, update operation should be executed successfully. - try (ResultSet rs = con.createStatement().executeQuery("select * from #update1")) { - rs.next(); - assertSame(101, rs.getInt(1)); - } - - // Test batching with missing handle failures (fake). - this.executeSQL(con, - "IF NOT EXISTS (SELECT * FROM sys.messages WHERE message_id = 99586) EXEC sp_addmessage 99586, 16, 'Prepared handle GAH!';"); - this.executeSQL(con, "CREATE TABLE #update2 (col INT);INSERT #update2 VALUES (1);"); - this.executeSQL(con, - "CREATE PROC #updateProc2 AS UPDATE #update2 SET col += 1; IF EXISTS (SELECT * FROM #update2 WHERE col % 5 = 0) RAISERROR(99586,16,1);"); - try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) con.prepareStatement("#updateProc2")) { - for (int i = 0; i < 100; ++i) { - pstmt.addBatch(); - } - - int[] updateCounts = null; - try { - updateCounts = pstmt.executeBatch(); - } - catch (BatchUpdateException e) { - // Error "Prepared handle GAH" is expected to happen. But it should not terminate the execution with RAISERROR. - // Since the original "Could not find prepared statement with handle" error does not terminate the execution after it. - if (!e.getMessage().contains("Prepared handle GAH")) { - throw e; - } - } - - // since executeBatch() throws exception, it does not return anthing. So updateCounts is still null. - assertSame(null, updateCounts); - - // test updated value, should be 1 + 100 = 101 - // although executeBatch() throws exception, update operation should be executed successfully. - try (ResultSet rs = con.createStatement().executeQuery("select * from #update2")) { - rs.next(); - assertSame(101, rs.getInt(1)); - } - } - } - - try (SQLServerConnection con = (SQLServerConnection)DriverManager.getConnection(connectionString)) { - // Test behvaior with statement pooling. - con.setDisableStatementPooling(false); - con.setStatementPoolingCacheSize(10); - - String lookupUniqueifier = UUID.randomUUID().toString(); - String query = String.format("/*statementpoolingtest_%s*/SELECT * FROM sys.tables;", lookupUniqueifier); - - // Execute statement first, should create cache entry WITHOUT handle (since sp_executesql was used). - try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement)con.prepareStatement(query)) { - pstmt.execute(); // sp_executesql - pstmt.getMoreResults(); // Make sure handle is updated. - - assertSame(0, pstmt.getPreparedStatementHandle()); - } - - // Execute statement again, should now create handle. - int handle = 0; - try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement)con.prepareStatement(query)) { - pstmt.execute(); // sp_prepexec - pstmt.getMoreResults(); // Make sure handle is updated. - - handle = pstmt.getPreparedStatementHandle(); - assertNotSame(0, handle); - } - - // Execute statement again and verify same handle was used. - try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement)con.prepareStatement(query)) { - pstmt.execute(); // sp_execute - pstmt.getMoreResults(); // Make sure handle is updated. - - assertNotSame(0, pstmt.getPreparedStatementHandle()); - assertSame(handle, pstmt.getPreparedStatementHandle()); - } - - // Execute new statement with different SQL text and verify it does NOT get same handle (should now fall back to using sp_executesql). - SQLServerPreparedStatement outer = null; - try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement)con.prepareStatement(query + ";")) { - outer = pstmt; - pstmt.execute(); // sp_executesql - pstmt.getMoreResults(); // Make sure handle is updated. - - assertSame(0, pstmt.getPreparedStatementHandle()); - assertNotSame(handle, pstmt.getPreparedStatementHandle()); - } - try { - System.out.println(outer.getPreparedStatementHandle()); - fail(TestResource.getResource("R_invalidGetPreparedStatementHandle")); - } - catch(Exception e) { - // Good! - } - } + public void testStatementPooling() throws Exception { + testStatementPoolingInternal("batchInsert"); } + /** + * Test handling of statement pooling for prepared statements. + * + * @throws SQLException + * @throws SecurityException + * @throws NoSuchFieldException + * @throws IllegalAccessException + * @throws IllegalArgumentException + */ + @Test + @Tag("slow") + public void testStatementPoolingUseBulkCopyAPI() throws Exception { + testStatementPoolingInternal("BulkCopy"); + } + /** * Test handling of eviction from statement pooling for prepared statements. * @@ -569,4 +427,182 @@ public void testStatementPoolingPreparedStatementExecAndUnprepareConfig() throws assertSame(0, con.getDiscardedServerPreparedStatementCount()); } } + + private void testStatementPoolingInternal(String mode) throws Exception { + // Test % handle re-use + try (SQLServerConnection con = (SQLServerConnection)DriverManager.getConnection(connectionString)) { + if (mode.equalsIgnoreCase("bulkcopy")) { + modifyConnectionForBulkCopyAPI(con); + } + String query = String.format("/*statementpoolingtest_re-use_%s*/SELECT TOP(1) * FROM sys.tables;", UUID.randomUUID().toString()); + + con.setStatementPoolingCacheSize(10); + + boolean[] prepOnFirstCalls = {false, true}; + + for(boolean prepOnFirstCall : prepOnFirstCalls) { + + con.setEnablePrepareOnFirstPreparedStatementCall(prepOnFirstCall); + + int[] queryCounts = {10, 20, 30, 40}; + for(int queryCount : queryCounts) { + String[] queries = new String[queryCount]; + for(int i = 0; i < queries.length; ++i) { + queries[i] = String.format("%s--%s--%s--%s", query, i, queryCount, prepOnFirstCall); + } + + int testsWithHandleReuse = 0; + final int testCount = 500; + for(int i = 0; i < testCount; ++i) { + Random random = new Random(); + int queryNumber = random.nextInt(queries.length); + try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) con.prepareStatement(queries[queryNumber])) { + pstmt.execute(); + + // Grab handle-reuse before it would be populated if initially created. + if(0 < pstmt.getPreparedStatementHandle()) + testsWithHandleReuse++; + + pstmt.getMoreResults(); // Make sure handle is updated. + } + } + System.out.println(String.format("Prep on first call: %s Query count:%s: %s of %s (%s)", prepOnFirstCall, queryCount, testsWithHandleReuse, testCount, (double)testsWithHandleReuse/(double)testCount)); + } + } + } + + try (SQLServerConnection con = (SQLServerConnection)DriverManager.getConnection(connectionString)) { + if (mode.equalsIgnoreCase("bulkcopy")) { + modifyConnectionForBulkCopyAPI(con); + } + // Test behvaior with statement pooling. + con.setStatementPoolingCacheSize(10); + this.executeSQL(con, + "IF NOT EXISTS (SELECT * FROM sys.messages WHERE message_id = 99586) EXEC sp_addmessage 99586, 16, 'Prepared handle GAH!';"); + // Test with missing handle failures (fake). + this.executeSQL(con, "CREATE TABLE #update1 (col INT);INSERT #update1 VALUES (1);"); + this.executeSQL(con, + "CREATE PROC #updateProc1 AS UPDATE #update1 SET col += 1; IF EXISTS (SELECT * FROM #update1 WHERE col % 5 = 0) RAISERROR(99586,16,1);"); + try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) con.prepareStatement("#updateProc1")) { + for (int i = 0; i < 100; ++i) { + try { + assertSame(1, pstmt.executeUpdate()); + } + catch (SQLException e) { + // Error "Prepared handle GAH" is expected to happen. But it should not terminate the execution with RAISERROR. + // Since the original "Could not find prepared statement with handle" error does not terminate the execution after it. + if (!e.getMessage().contains("Prepared handle GAH")) { + throw e; + } + } + } + } + + // test updated value, should be 1 + 100 = 101 + // although executeUpdate() throws exception, update operation should be executed successfully. + try (ResultSet rs = con.createStatement().executeQuery("select * from #update1")) { + rs.next(); + assertSame(101, rs.getInt(1)); + } + + // Test batching with missing handle failures (fake). + this.executeSQL(con, + "IF NOT EXISTS (SELECT * FROM sys.messages WHERE message_id = 99586) EXEC sp_addmessage 99586, 16, 'Prepared handle GAH!';"); + this.executeSQL(con, "CREATE TABLE #update2 (col INT);INSERT #update2 VALUES (1);"); + this.executeSQL(con, + "CREATE PROC #updateProc2 AS UPDATE #update2 SET col += 1; IF EXISTS (SELECT * FROM #update2 WHERE col % 5 = 0) RAISERROR(99586,16,1);"); + try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) con.prepareStatement("#updateProc2")) { + for (int i = 0; i < 100; ++i) { + pstmt.addBatch(); + } + + int[] updateCounts = null; + try { + updateCounts = pstmt.executeBatch(); + } + catch (BatchUpdateException e) { + // Error "Prepared handle GAH" is expected to happen. But it should not terminate the execution with RAISERROR. + // Since the original "Could not find prepared statement with handle" error does not terminate the execution after it. + if (!e.getMessage().contains("Prepared handle GAH")) { + throw e; + } + } + + // since executeBatch() throws exception, it does not return anthing. So updateCounts is still null. + assertSame(null, updateCounts); + + // test updated value, should be 1 + 100 = 101 + // although executeBatch() throws exception, update operation should be executed successfully. + try (ResultSet rs = con.createStatement().executeQuery("select * from #update2")) { + rs.next(); + assertSame(101, rs.getInt(1)); + } + } + } + + try (SQLServerConnection con = (SQLServerConnection)DriverManager.getConnection(connectionString)) { + if (mode.equalsIgnoreCase("bulkcopy")) { + modifyConnectionForBulkCopyAPI(con); + } + // Test behvaior with statement pooling. + con.setDisableStatementPooling(false); + con.setStatementPoolingCacheSize(10); + + String lookupUniqueifier = UUID.randomUUID().toString(); + String query = String.format("/*statementpoolingtest_%s*/SELECT * FROM sys.tables;", lookupUniqueifier); + + // Execute statement first, should create cache entry WITHOUT handle (since sp_executesql was used). + try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement)con.prepareStatement(query)) { + pstmt.execute(); // sp_executesql + pstmt.getMoreResults(); // Make sure handle is updated. + + assertSame(0, pstmt.getPreparedStatementHandle()); + } + + // Execute statement again, should now create handle. + int handle = 0; + try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement)con.prepareStatement(query)) { + pstmt.execute(); // sp_prepexec + pstmt.getMoreResults(); // Make sure handle is updated. + + handle = pstmt.getPreparedStatementHandle(); + assertNotSame(0, handle); + } + + // Execute statement again and verify same handle was used. + try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement)con.prepareStatement(query)) { + pstmt.execute(); // sp_execute + pstmt.getMoreResults(); // Make sure handle is updated. + + assertNotSame(0, pstmt.getPreparedStatementHandle()); + assertSame(handle, pstmt.getPreparedStatementHandle()); + } + + // Execute new statement with different SQL text and verify it does NOT get same handle (should now fall back to using sp_executesql). + SQLServerPreparedStatement outer = null; + try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement)con.prepareStatement(query + ";")) { + outer = pstmt; + pstmt.execute(); // sp_executesql + pstmt.getMoreResults(); // Make sure handle is updated. + + assertSame(0, pstmt.getPreparedStatementHandle()); + assertNotSame(handle, pstmt.getPreparedStatementHandle()); + } + try { + System.out.println(outer.getPreparedStatementHandle()); + fail(TestResource.getResource("R_invalidGetPreparedStatementHandle")); + } + catch(Exception e) { + // Good! + } + } + } + + private void modifyConnectionForBulkCopyAPI(SQLServerConnection con) throws Exception { + Field f1 = SQLServerConnection.class.getDeclaredField("isAzureDW"); + f1.setAccessible(true); + f1.set(con, true); + + con.setUseBulkCopyForBatchInsert(true); + } }