diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java index 027cc5eae0..cdb5c1f570 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java @@ -52,6 +52,7 @@ import java.util.Set; import java.util.SimpleTimeZone; import java.util.TimeZone; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -3016,6 +3017,10 @@ void setDataLoggable(boolean value) { dataIsLoggable = value; } + SharedTimer getSharedTimer() { + return con.getSharedTimer(); + } + private TDSCommand command = null; // TDS message type (Query, RPC, DTC, etc.) sent at the beginning @@ -6236,7 +6241,7 @@ final class TDSReaderMark { final class TDSReader { private final static Logger logger = Logger.getLogger("com.microsoft.sqlserver.jdbc.internals.TDS.Reader"); final private String traceID; - private TimeoutCommand timeoutCommand; + private ScheduledFuture timeout; final public String toString() { return traceID; @@ -6390,9 +6395,8 @@ synchronized final boolean readPacket() throws SQLServerException { // terminate the connection. if ((command.getCancelQueryTimeoutSeconds() > 0 && command.getQueryTimeoutSeconds() > 0)) { // if a timeout is configured with this object, add it to the timeout poller - int timeout = command.getCancelQueryTimeoutSeconds() + command.getQueryTimeoutSeconds(); - this.timeoutCommand = new TdsTimeoutCommand(timeout, this.command, this.con); - TimeoutPoller.getTimeoutPoller().addTimeoutCommand(this.timeoutCommand); + int seconds = command.getCancelQueryTimeoutSeconds() + command.getQueryTimeoutSeconds(); + this.timeout = con.getSharedTimer().schedule(new TDSTimeoutTask(command, con), seconds); } } // First, read the packet header. @@ -6413,8 +6417,9 @@ synchronized final boolean readPacket() throws SQLServerException { } // if execution was subject to timeout then stop timing - if (this.timeoutCommand != null) { - TimeoutPoller.getTimeoutPoller().remove(this.timeoutCommand); + if (this.timeout != null) { + this.timeout.cancel(false); + this.timeout = null; } // Header size is a 2 byte unsigned short integer in big-endian order. int packetLength = Util.readUnsignedShortBigEndian(newPacket.header, TDS.PACKET_HEADER_MESSAGE_LENGTH); @@ -7003,42 +7008,6 @@ final void trySetSensitivityClassification(SensitivityClassification sensitivity } -/** - * The tds default implementation of a timeout command - */ -class TdsTimeoutCommand extends TimeoutCommand { - public TdsTimeoutCommand(int timeout, TDSCommand command, SQLServerConnection sqlServerConnection) { - super(timeout, command, sqlServerConnection); - } - - public void interrupt() { - TDSCommand command = getCommand(); - SQLServerConnection sqlServerConnection = getSqlServerConnection(); - try { - // If TCP Connection to server is silently dropped, exceeding the query timeout - // on the same connection does - // not throw SQLTimeoutException - // The application stops responding instead until SocketTimeoutException is - // thrown. In this case, we must - // manually terminate the connection. - if (null == command && null != sqlServerConnection) { - sqlServerConnection.terminate(SQLServerException.DRIVER_ERROR_IO_FAILED, - SQLServerException.getErrString("R_connectionIsClosed")); - } else { - // If the timer wasn't canceled before it ran out of - // time then interrupt the registered command. - command.interrupt(SQLServerException.getErrString("R_queryTimedOut")); - } - } catch (SQLServerException e) { - // Unfortunately, there's nothing we can do if we - // fail to time out the request. There is no way - // to report back what happened. - assert null != command; - command.log(Level.FINE, "Command could not be timed out. Reason: " + e.getMessage()); - } - } -} - /** * TDSCommand encapsulates an interruptable TDS conversation. * @@ -7160,7 +7129,7 @@ protected void setProcessedResponse(boolean processedResponse) { private volatile boolean readingResponse; private int queryTimeoutSeconds; private int cancelQueryTimeoutSeconds; - private TdsTimeoutCommand timeoutCommand; + private ScheduledFuture timeout; protected int getQueryTimeoutSeconds() { return this.queryTimeoutSeconds; @@ -7576,8 +7545,8 @@ final TDSReader startResponse(boolean isAdaptive) throws SQLServerException { // If command execution is subject to timeout then start timing until // the server returns the first response packet. if (queryTimeoutSeconds > 0) { - this.timeoutCommand = new TdsTimeoutCommand(queryTimeoutSeconds, this, null); - TimeoutPoller.getTimeoutPoller().addTimeoutCommand(this.timeoutCommand); + SQLServerConnection conn = tdsReader != null ? tdsReader.getConnection() : null; + this.timeout = tdsWriter.getSharedTimer().schedule(new TDSTimeoutTask(this, conn), queryTimeoutSeconds); } if (logger.isLoggable(Level.FINEST)) @@ -7600,8 +7569,9 @@ final TDSReader startResponse(boolean isAdaptive) throws SQLServerException { } finally { // If command execution was subject to timeout then stop timing as soon // as the server returns the first response packet or errors out. - if (this.timeoutCommand != null) { - TimeoutPoller.getTimeoutPoller().remove(this.timeoutCommand); + if (this.timeout != null) { + this.timeout.cancel(false); + this.timeout = null; } } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java index 2df1e090a5..3e6243029b 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopy.java @@ -42,6 +42,7 @@ import java.util.SimpleTimeZone; import java.util.TimeZone; import java.util.UUID; +import java.util.concurrent.ScheduledFuture; import java.util.logging.Level; import javax.sql.RowSet; @@ -246,31 +247,7 @@ class BulkColumnMetaData { */ private int srcColumnCount; - /** - * Timeout for the bulk copy command - */ - private final class BulkTimeoutCommand extends TimeoutCommand { - public BulkTimeoutCommand(int timeout, TDSCommand command, SQLServerConnection sqlServerConnection) { - super(timeout, command, sqlServerConnection); - } - - @Override - public void interrupt() { - TDSCommand command = getCommand(); - // If the timer wasn't canceled before it ran out of - // time then interrupt the registered command. - try { - command.interrupt(SQLServerException.getErrString("R_queryTimedOut")); - } catch (SQLServerException e) { - // Unfortunately, there's nothing we can do if we - // fail to time out the request. There is no way - // to report back what happened. - command.log(Level.FINE, "Command could not be timed out. Reason: " + e.getMessage()); - } - } - } - - private BulkTimeoutCommand timeoutCommand; + private ScheduledFuture timeout; /** * The maximum temporal precision we can send when using varchar(precision) in bulkcommand, to send a @@ -646,16 +623,14 @@ private void sendBulkLoadBCP() throws SQLServerException { final class InsertBulk extends TDSCommand { InsertBulk() { super("InsertBulk", 0, 0); - int timeoutSeconds = copyOptions.getBulkCopyTimeout(); - timeoutCommand = timeoutSeconds > 0 ? new BulkTimeoutCommand(timeoutSeconds, this, null) : null; } final boolean doExecute() throws SQLServerException { - if (null != timeoutCommand) { - if (logger.isLoggable(Level.FINEST)) - logger.finest(this.toString() + ": Starting bulk timer..."); - - TimeoutPoller.getTimeoutPoller().addTimeoutCommand(timeoutCommand); + int timeoutSeconds = copyOptions.getBulkCopyTimeout(); + if (timeoutSeconds > 0) { + connection.checkClosed(); + timeout = connection.getSharedTimer().schedule(new TDSTimeoutTask(this, connection), + timeoutSeconds); } // doInsertBulk inserts the rows in one batch. It returns true if there are more rows in @@ -671,21 +646,27 @@ final boolean doExecute() throws SQLServerException { } // Check whether it is a timeout exception. - if (rootCause instanceof SQLException) { - checkForTimeoutException((SQLException) rootCause, timeoutCommand); + if (rootCause instanceof SQLException && timeout != null && timeout.isDone()) { + SQLException sqlEx = (SQLException) rootCause; + if (sqlEx.getSQLState() != null + && sqlEx.getSQLState().equals(SQLState.STATEMENT_CANCELED.getSQLStateCode())) { + // If SQLServerBulkCopy is managing the transaction, a rollback is needed. + if (copyOptions.isUseInternalTransaction()) { + connection.rollback(); + } + throw new SQLServerException(SQLServerException.getErrString("R_queryTimedOut"), + SQLState.STATEMENT_CANCELED, DriverError.NOT_SET, sqlEx); + } } // It is not a timeout exception. Re-throw. throw topLevelException; } - if (null != timeoutCommand) { - if (logger.isLoggable(Level.FINEST)) - logger.finest(this.toString() + ": Stopping bulk timer..."); - - TimeoutPoller.getTimeoutPoller().remove(timeoutCommand); + if (timeout != null) { + timeout.cancel(true); + timeout = null; } - return true; } } @@ -1145,22 +1126,6 @@ private void writeColumnMetaData(TDSWriter tdsWriter) throws SQLServerException } } - /** - * Helper method that throws a timeout exception if the cause of the exception was that the query was cancelled - */ - private void checkForTimeoutException(SQLException e, BulkTimeoutCommand timeoutCommand) throws SQLServerException { - if ((null != e.getSQLState()) && (e.getSQLState().equals(SQLState.STATEMENT_CANCELED.getSQLStateCode())) - && timeoutCommand.canTimeout()) { - // If SQLServerBulkCopy is managing the transaction, a rollback is needed. - if (copyOptions.isUseInternalTransaction()) { - connection.rollback(); - } - - throw new SQLServerException(SQLServerException.getErrString("R_queryTimedOut"), - SQLState.STATEMENT_CANCELED, DriverError.NOT_SET, e); - } - } - /** * Validates whether the source JDBC types are compatible with the destination table data types. We need to do this * only once for the whole bulk copy session. diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index 0807ae17dc..3116c83343 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -141,6 +141,23 @@ public class SQLServerConnection implements ISQLServerConnection, java.io.Serial private Boolean isAzureDW = null; + private SharedTimer sharedTimer; + + /** + * Return an existing cached SharedTimer associated with this Connection or create a new one. + * + * The SharedTimer will be released when the Connection is closed. + */ + SharedTimer getSharedTimer() { + if (state == State.Closed) { + throw new IllegalStateException("Connection is closed"); + } + if (sharedTimer == null) { + this.sharedTimer = SharedTimer.getTimer(); + } + return this.sharedTimer; + } + static class CityHash128Key implements java.io.Serializable { /** @@ -3174,6 +3191,11 @@ public void close() throws SQLServerException { // with the connection. setState(State.Closed); + if (sharedTimer != null) { + sharedTimer.removeRef(); + sharedTimer = null; + } + // Close the TDS channel. When the channel is closed, the server automatically // rolls back any pending transactions and closes associated resources like // prepared handles. diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SharedTimer.java b/src/main/java/com/microsoft/sqlserver/jdbc/SharedTimer.java new file mode 100644 index 0000000000..57cc70dfa3 --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SharedTimer.java @@ -0,0 +1,94 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ +package com.microsoft.sqlserver.jdbc; + +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + + +class SharedTimer { + static final String CORE_THREAD_PREFIX = "mssql-jdbc-shared-timer-core-"; + private static final AtomicLong CORE_THREAD_COUNTER = new AtomicLong(); + private static SharedTimer instance; + + /** + * Unique ID of this SharedTimer + */ + private final long id = CORE_THREAD_COUNTER.getAndIncrement(); + /** + * Number of outstanding references to this SharedTimer + */ + private int refCount = 0; + private ScheduledThreadPoolExecutor executor; + + private SharedTimer() { + executor = new ScheduledThreadPoolExecutor(1, task -> new Thread(task, CORE_THREAD_PREFIX + id)); + executor.setRemoveOnCancelPolicy(true); + } + + public long getId() { + return id; + } + + /** + * @return Whether there is an instance of the SharedTimer currently allocated. + */ + static synchronized boolean isRunning() { + return instance != null; + } + + /** + * Remove a reference to this SharedTimer. + * + * If the reference count reaches zero then the underlying executor will be shutdown so that its thread stops. + */ + public synchronized void removeRef() { + if (refCount <= 0) { + throw new IllegalStateException("removeRef() called more than actual references"); + } + refCount -= 1; + if (refCount == 0) { + // Removed last reference so perform cleanup + executor.shutdownNow(); + executor = null; + instance = null; + } + } + + /** + * Retrieve a reference to existing SharedTimer or create a new one. + * + * The SharedTimer's reference count will be incremented to account for the new reference. + * + * When the caller is finished with the SharedTimer it must be released via {@link#removeRef} + */ + public static synchronized SharedTimer getTimer() { + if (instance == null) { + // No shared object exists so create a new one + instance = new SharedTimer(); + } + instance.refCount += 1; + return instance; + } + + /** + * Schedule a task to execute in the future using this SharedTimer's internal executor. + */ + public ScheduledFuture schedule(TDSTimeoutTask task, long delaySeconds) { + return schedule(task, delaySeconds, TimeUnit.SECONDS); + } + + /** + * Schedule a task to execute in the future using this SharedTimer's internal executor. + */ + public ScheduledFuture schedule(TDSTimeoutTask task, long delay, TimeUnit unit) { + if (executor == null) { + throw new IllegalStateException("Cannot schedule tasks after shutdown"); + } + return executor.schedule(task, delay, unit); + } +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/TDSTimeoutTask.java b/src/main/java/com/microsoft/sqlserver/jdbc/TDSTimeoutTask.java new file mode 100644 index 0000000000..062f18a32c --- /dev/null +++ b/src/main/java/com/microsoft/sqlserver/jdbc/TDSTimeoutTask.java @@ -0,0 +1,61 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ +package com.microsoft.sqlserver.jdbc; + +import java.util.UUID; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; + + +/** + * The TDS default implementation of a timeout command + */ +class TDSTimeoutTask implements Runnable { + private static final AtomicLong COUNTER = new AtomicLong(0); + + private final UUID connectionId; + private final TDSCommand command; + private final SQLServerConnection sqlServerConnection; + + public TDSTimeoutTask(TDSCommand command, SQLServerConnection sqlServerConnection) { + this.connectionId = sqlServerConnection == null ? null : sqlServerConnection.getClientConIdInternal(); + this.command = command; + this.sqlServerConnection = sqlServerConnection; + } + + @Override + public final void run() { + // Create a new thread to run the interrupt to ensure that blocking operations performed + // by the interrupt do not hang the primary timer thread. + String name = "mssql-timeout-task-" + COUNTER.incrementAndGet() + "-" + connectionId; + Thread thread = new Thread(this::interrupt, name); + thread.setDaemon(true); + thread.start(); + } + + protected void interrupt() { + try { + // If TCP Connection to server is silently dropped, exceeding the query timeout + // on the same connection does not throw SQLTimeoutException + // The application stops responding instead until SocketTimeoutException is + // thrown. In this case, we must manually terminate the connection. + if (null == command) { + if (null != sqlServerConnection) { + sqlServerConnection.terminate(SQLServerException.DRIVER_ERROR_IO_FAILED, + SQLServerException.getErrString("R_connectionIsClosed")); + } + } else { + // If the timer wasn't canceled before it ran out of + // time then interrupt the registered command. + command.interrupt(SQLServerException.getErrString("R_queryTimedOut")); + } + } catch (SQLServerException e) { + // Unfortunately, there's nothing we can do if we fail to time out the request. There + // is no way to report back what happened. + assert null != command; + command.log(Level.WARNING, "Command could not be timed out. Reason: " + e.getMessage()); + } + } +} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/TimeoutCommand.java b/src/main/java/com/microsoft/sqlserver/jdbc/TimeoutCommand.java deleted file mode 100644 index 65b1f68b26..0000000000 --- a/src/main/java/com/microsoft/sqlserver/jdbc/TimeoutCommand.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made - * available under the terms of the MIT License. See the LICENSE file in the project root for more information. - */ - -package com.microsoft.sqlserver.jdbc; - -/** - * Abstract implementation of a command that can be timed out using the {@link TimeoutPoller} - */ -abstract class TimeoutCommand { - private final long startTime; - private final int timeout; - private final T command; - private final SQLServerConnection sqlServerConnection; - - TimeoutCommand(int timeout, T command, SQLServerConnection sqlServerConnection) { - this.timeout = timeout; - this.command = command; - this.sqlServerConnection = sqlServerConnection; - this.startTime = System.currentTimeMillis(); - } - - public boolean canTimeout() { - long currentTime = System.currentTimeMillis(); - return ((currentTime - startTime) / 1000) >= timeout; - } - - public T getCommand() { - return command; - } - - public SQLServerConnection getSqlServerConnection() { - return sqlServerConnection; - } - - /** - * The implementation for interrupting this timeout command - */ - public abstract void interrupt(); -} diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/TimeoutPoller.java b/src/main/java/com/microsoft/sqlserver/jdbc/TimeoutPoller.java deleted file mode 100644 index 6c53d4d744..0000000000 --- a/src/main/java/com/microsoft/sqlserver/jdbc/TimeoutPoller.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made - * available under the terms of the MIT License. See the LICENSE file in the project root for more information. - */ - -package com.microsoft.sqlserver.jdbc; - -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.logging.Level; -import java.util.logging.Logger; - - -/** - * Thread that runs in the background while the mssql driver is used that can timeout TDSCommands Checks all registered - * commands every second to see if they can be interrupted - */ -final class TimeoutPoller implements Runnable { - private List> timeoutCommands = new ArrayList<>(); - final static Logger logger = Logger.getLogger("com.microsoft.sqlserver.jdbc.TimeoutPoller"); - private static volatile TimeoutPoller timeoutPoller = null; - - static TimeoutPoller getTimeoutPoller() { - if (timeoutPoller == null) { - synchronized (TimeoutPoller.class) { - if (timeoutPoller == null) { - // initialize the timeout poller thread once - timeoutPoller = new TimeoutPoller(); - // start the timeout polling thread - Thread pollerThread = new Thread(timeoutPoller, "mssql-jdbc-TimeoutPoller"); - pollerThread.setDaemon(true); - pollerThread.start(); - } - } - } - return timeoutPoller; - } - - void addTimeoutCommand(TimeoutCommand timeoutCommand) { - synchronized (timeoutCommands) { - timeoutCommands.add(timeoutCommand); - } - } - - void remove(TimeoutCommand timeoutCommand) { - synchronized (timeoutCommands) { - timeoutCommands.remove(timeoutCommand); - } - } - - private TimeoutPoller() {} - - public void run() { - try { - // Poll every second checking for commands that have timed out and need - // interruption - while (true) { - synchronized (timeoutCommands) { - Iterator> timeoutCommandIterator = timeoutCommands.iterator(); - while (timeoutCommandIterator.hasNext()) { - TimeoutCommand timeoutCommand = timeoutCommandIterator.next(); - try { - if (timeoutCommand.canTimeout()) { - try { - timeoutCommand.interrupt(); - } finally { - timeoutCommandIterator.remove(); - } - } - } catch (Exception e) { - logger.log(Level.WARNING, "Could not timeout command", e); - } - } - } - Thread.sleep(1000); - } - } catch (Exception e) { - logger.log(Level.SEVERE, "Error processing timeout commands", e); - } - } -} diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/TimeoutTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/TimeoutTest.java new file mode 100644 index 0000000000..2bb6c7cab9 --- /dev/null +++ b/src/test/java/com/microsoft/sqlserver/jdbc/TimeoutTest.java @@ -0,0 +1,194 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ + +package com.microsoft.sqlserver.jdbc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLTimeoutException; +import java.util.Set; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.platform.runner.JUnitPlatform; +import org.junit.runner.RunWith; + +import com.microsoft.sqlserver.testframework.AbstractTest; + + +@RunWith(JUnitPlatform.class) +public class TimeoutTest extends AbstractTest { + private static final int TIMEOUT_SECONDS = 2; + private static final String WAIT_FOR_ONE_MINUTE_SQL = "WAITFOR DELAY '00:01:00'"; + + @BeforeAll + public static void beforeAll() throws SQLException, InterruptedException { + if (connection != null) { + connection.close(); + connection = null; + } + waitForSharedTimerThreadToStop(); + } + + @Before + public void before() throws InterruptedException { + waitForSharedTimerThreadToStop(); + } + + @After + public void after() throws InterruptedException { + waitForSharedTimerThreadToStop(); + } + + @Test + public void testBasicQueryTimeout() { + assertThrows(SQLTimeoutException.class, () -> { + runQuery(WAIT_FOR_ONE_MINUTE_SQL, TIMEOUT_SECONDS); + }); + } + + @Test + public void testQueryTimeoutValid() { + long start = System.currentTimeMillis(); + assertThrows(SQLTimeoutException.class, () -> { + runQuery(WAIT_FOR_ONE_MINUTE_SQL, TIMEOUT_SECONDS); + }); + long elapsedSeconds = (System.currentTimeMillis() - start) / 1000; + Assert.assertTrue("Query duration must be at least timeout amount, elapsed=" + elapsedSeconds, + elapsedSeconds >= TIMEOUT_SECONDS); + } + + @Test + public void testZeroTimeoutShouldNotStartTimerThread() throws SQLException { + try (Connection conn = getConnection()) { + // Connection is open but we have not used a timeout so it should be running + assertSharedTimerNotRunning(); + runQuery(conn, "SELECT 1", 0); + // Our statement does not have a timeout so the timer should not be started yet + assertSharedTimerNotRunning(); + } + } + + @Test + public void testNoTimeoutShouldNotStartTimerThread() throws SQLException { + try (Connection conn = getConnection()) { + // Connection is open but we have not used a timeout so it should not be running + assertSharedTimerNotRunning(); + runQuery(conn, "SELECT 1", 0); + // Ran a query but our statement does not have a timeout so the timer should not be running + assertSharedTimerNotRunning(); + } + } + + @Test + public void testPositiveTimeoutShouldStartTimerThread() throws SQLException { + try (Connection conn = getConnection()) { + // Connection is open but we have not used a timeout so it should not be running + assertSharedTimerNotRunning(); + runQuery(conn, "SELECT 1", TIMEOUT_SECONDS); + // Ran a query with a timeout so the thread should continue running + assertSharedTimerIsRunning(); + } + } + + @Test + public void testNestedTimeoutShouldKeepTimerThreadRunning() throws SQLException { + try (Connection conn = getConnection()) { + // Connection is open but we have not used a timeout so it should not be running + assertSharedTimerNotRunning(); + runQuery(conn, "SELECT 1", TIMEOUT_SECONDS); + // Ran a query with a timeout so the thread should continue running + assertSharedTimerIsRunning(); + + // Open a new connection + try (Connection otherConn = getConnection()) { + assertSharedTimerIsRunning(); + runQuery(otherConn, "SELECT 1", TIMEOUT_SECONDS); + assertSharedTimerIsRunning(); + } + + // Timer should still be running because our original connection is still open + assertSharedTimerIsRunning(); + } + } + + private static Connection getConnection() throws SQLException { + return DriverManager.getConnection(connectionString); + } + + private static void runQuery(String query, int timeout) throws SQLException { + try (Connection conn = getConnection()) { + runQuery(conn, query, timeout); + } + } + + private static void runQuery(Connection conn, String query, int timeout) throws SQLException { + try (PreparedStatement stmt = conn.prepareStatement(query)) { + if (timeout > 0) { + stmt.setQueryTimeout(timeout); + } + try (ResultSet rs = stmt.executeQuery()) {} + } + } + + @Test + public void testSameSharedTimerRetrieved() { + SharedTimer timer = SharedTimer.getTimer(); + try { + SharedTimer otherTimer = SharedTimer.getTimer(); + try { + assertEquals("The same SharedTimer should be returned", timer.getId(), otherTimer.getId()); + } finally { + otherTimer.removeRef(); + } + } finally { + timer.removeRef(); + } + } + + private static boolean isSharedTimerThreadRunning() { + Set threadSet = Thread.getAllStackTraces().keySet(); + for (Thread thread : threadSet) { + if (thread.getName().startsWith(SharedTimer.CORE_THREAD_PREFIX)) { + return true; + } + } + return false; + } + + private static void waitForSharedTimerThreadToStop() throws InterruptedException { + long started = System.currentTimeMillis(); + long MAX_WAIT_FOR_STOP_SECONDS = 10; + while (isSharedTimerThreadRunning()) { + long elapsed = System.currentTimeMillis() - started; + if (elapsed > MAX_WAIT_FOR_STOP_SECONDS * 1000) { + fail("SharedTimer thread did not stop within " + MAX_WAIT_FOR_STOP_SECONDS + " seconds"); + } + // Sleep a bit and try again + Thread.sleep(100); + } + assertSharedTimerNotRunning(); + } + + private static void assertSharedTimerNotRunning() { + assertFalse("SharedTimer should not be running", isSharedTimerThreadRunning()); + } + + private static void assertSharedTimerIsRunning() { + assertTrue("SharedTimer should be running", isSharedTimerThreadRunning()); + } +} diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/timeouts/TimeoutTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/timeouts/TimeoutTest.java deleted file mode 100644 index a61e097b17..0000000000 --- a/src/test/java/com/microsoft/sqlserver/jdbc/timeouts/TimeoutTest.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made - * available under the terms of the MIT License. See the LICENSE file in the project root for more information. - */ - -package com.microsoft.sqlserver.jdbc.timeouts; - -import java.sql.Connection; -import java.sql.DriverManager; -import java.sql.PreparedStatement; -import java.sql.SQLException; -import java.sql.SQLTimeoutException; - -import org.junit.Assert; -import org.junit.jupiter.api.Test; -import org.junit.platform.runner.JUnitPlatform; -import org.junit.runner.RunWith; - -import com.microsoft.sqlserver.testframework.AbstractTest; - - -@RunWith(JUnitPlatform.class) -public class TimeoutTest extends AbstractTest { - @Test - public void testBasicQueryTimeout() { - boolean exceptionThrown = false; - try { - // wait 1 minute and timeout after 10 seconds - Assert.assertTrue("Select succeeded", runQuery("WAITFOR DELAY '00:01'", 10)); - } catch (SQLException e) { - exceptionThrown = true; - Assert.assertTrue("Timeout exception not thrown", e.getClass().equals(SQLTimeoutException.class)); - } - Assert.assertTrue("A SQLTimeoutException was expected", exceptionThrown); - } - - @Test - public void testQueryTimeoutValid() { - boolean exceptionThrown = false; - int timeoutInSeconds = 10; - long start = System.currentTimeMillis(); - try { - // wait 1 minute and timeout after 10 seconds - Assert.assertTrue("Select succeeded", runQuery("WAITFOR DELAY '00:01'", timeoutInSeconds)); - } catch (SQLException e) { - int secondsElapsed = (int) ((System.currentTimeMillis() - start) / 1000); - Assert.assertTrue("Query did not timeout expected, elapsedTime=" + secondsElapsed, - secondsElapsed >= timeoutInSeconds); - exceptionThrown = true; - Assert.assertTrue("Timeout exception not thrown", e.getClass().equals(SQLTimeoutException.class)); - } - Assert.assertTrue("A SQLTimeoutException was expected", exceptionThrown); - } - - private boolean runQuery(String query, int timeout) throws SQLException { - try (Connection con = DriverManager.getConnection(connectionString); - PreparedStatement preparedStatement = con.prepareStatement(query)) { - // set provided timeout - preparedStatement.setQueryTimeout(timeout); - return preparedStatement.execute(); - } - } -}