From e793bae17d48b9f76615bad8e12eedc2601bea3f Mon Sep 17 00:00:00 2001 From: sergiyvamz Date: Wed, 31 Jul 2024 17:06:51 -0700 Subject: [PATCH] fix issue with statement object cast (#1045) --- .../amazon/jdbc/util/WrapperUtils.java | 18 +++++++++-- .../amazon/jdbc/util/WrapperUtilsTest.java | 31 +++++++++++++++++-- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java index 0fd82b004..46eb1e29c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java @@ -284,13 +284,27 @@ public static T executeWithPlugins( return toProxy; } - Class wrapperClass = availableWrappers.get(resultClass); + Class effectiveResultClass = resultClass; + + if (resultClass == Statement.class) { + // Statement class is a special case since it has subclasses like PreparedStatement and CallableStatement. + // We need to choose the best result class based on actual toProxy object. + + // Order of the following if-statements is important! + if (toProxy instanceof CallableStatement) { + effectiveResultClass = CallableStatement.class; + } else if (toProxy instanceof PreparedStatement) { + effectiveResultClass = PreparedStatement.class; + } + } + + Class wrapperClass = availableWrappers.get(effectiveResultClass); if (wrapperClass != null) { return createInstance( wrapperClass, resultClass, - new Class[] {resultClass, ConnectionPluginManager.class}, + new Class[] {effectiveResultClass, ConnectionPluginManager.class}, toProxy, pluginManager); } diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/WrapperUtilsTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/WrapperUtilsTest.java index 9a4f7ba9d..38565ed7f 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/WrapperUtilsTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/WrapperUtilsTest.java @@ -16,6 +16,7 @@ package software.amazon.jdbc.util; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; @@ -26,12 +27,12 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import java.sql.CallableStatement; import java.sql.Connection; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; -import java.util.ArrayList; -import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.locks.ReentrantLock; import org.junit.jupiter.api.AfterEach; @@ -43,6 +44,9 @@ import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.wrapper.CallableStatementWrapper; +import software.amazon.jdbc.wrapper.PreparedStatementWrapper; +import software.amazon.jdbc.wrapper.StatementWrapper; public class WrapperUtilsTest { @@ -155,4 +159,27 @@ void getConnectionFromSqlObjectChecksStatementNotClosed() throws Exception { final Connection rsConn = WrapperUtils.getConnectionFromSqlObject(mockClosedStatement); assertNull(rsConn); } + + @Test + void testStatementWrapper() throws InstantiationException { + ConnectionPluginManager mockPluginManager = mock(ConnectionPluginManager.class); + + assertInstanceOf(StatementWrapper.class, + WrapperUtils.wrapWithProxyIfNeeded( + Statement.class, + mock(Statement.class), + mockPluginManager)); + + assertInstanceOf(PreparedStatementWrapper.class, + WrapperUtils.wrapWithProxyIfNeeded( + Statement.class, + mock(PreparedStatement.class), + mockPluginManager)); + + assertInstanceOf(CallableStatementWrapper.class, + WrapperUtils.wrapWithProxyIfNeeded( + Statement.class, + mock(CallableStatement.class), + mockPluginManager)); + } }