diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/IamAuthConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/IamAuthConnectionPlugin.java index 8ca84b0b1..33fdf864d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/IamAuthConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/IamAuthConnectionPlugin.java @@ -107,20 +107,7 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Pro host = IAM_HOST.getString(props); } - int port = hostSpec.getPort(); - if (!hostSpec.isPortSpecified()) { - if (StringUtils.isNullOrEmpty(IAM_DEFAULT_PORT.getString(props))) { - port = this.pluginService.getDialect().getDefaultPort(); - } else { - port = IAM_DEFAULT_PORT.getInteger(props); - if (port <= 0) { - throw new IllegalArgumentException( - Messages.get( - "IamAuthConnectionPlugin.invalidPort", - new Object[] {port})); - } - } - } + int port = getPort(props, hostSpec); final String iamRegion = IAM_REGION.getString(props); final Region region = StringUtils.isNullOrEmpty(iamRegion) @@ -245,6 +232,26 @@ public static void clearCache() { tokenCache.clear(); } + private int getPort(Properties props, HostSpec hostSpec) { + if (!StringUtils.isNullOrEmpty(IAM_DEFAULT_PORT.getString(props))) { + int defaultPort = IAM_DEFAULT_PORT.getInteger(props); + if (defaultPort > 0) { + return defaultPort; + } else { + LOGGER.finest( + () -> Messages.get( + "IamAuthConnectionPlugin.invalidPort", + new Object[] {defaultPort})); + } + } + + if (hostSpec.isPortSpecified()) { + return hostSpec.getPort(); + } else { + return this.pluginService.getDialect().getDefaultPort(); + } + } + private Region getRdsRegion(final String hostname) throws SQLException { // Get Region diff --git a/wrapper/src/main/resources/messages.properties b/wrapper/src/main/resources/messages.properties index 4a002248f..57bc397de 100644 --- a/wrapper/src/main/resources/messages.properties +++ b/wrapper/src/main/resources/messages.properties @@ -162,7 +162,7 @@ HostSelector.noHostsMatchingRole=No hosts were found matching the requested ''{0 IamAuthConnectionPlugin.unsupportedHostname=Unsupported AWS hostname {0}. Amazon domain name in format *.AWS-Region.rds.amazonaws.com or *.rds.AWS-Region.amazonaws.com.cn is expected. IamAuthConnectionPlugin.useCachedIamToken=Use cached IAM token = ''{0}'' IamAuthConnectionPlugin.generatedNewIamToken=Generated new IAM token = ''{0}'' -IamAuthConnectionPlugin.invalidPort=Port number: {0} is not valid. Port number should be greater than zero. +IamAuthConnectionPlugin.invalidPort=Port number: {0} is not valid. Port number should be greater than zero. Falling back to default port. IamAuthConnectionPlugin.unhandledException=Unhandled exception: ''{0}'' IamAuthConnectionPlugin.connectException=Error occurred while opening a connection: ''{0}'' diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/IamAuthConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/IamAuthConnectionPluginTest.java index 879d6dcc3..3cb62ce4e 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/IamAuthConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/IamAuthConnectionPluginTest.java @@ -113,15 +113,31 @@ public void testMySqlConnectValidTokenInCache() throws SQLException { } @Test - public void testPostgresConnectWithInvalidPort() { - props.setProperty("iamDefaultPort", "0"); - PluginService mockPluginService = Mockito.mock(PluginService.class); - final IamAuthConnectionPlugin targetPlugin = new IamAuthConnectionPlugin(mockPluginService); + public void testPostgresConnectWithInvalidPortFallbacksToHostPort() throws SQLException { + final String invalidIamDefaultPort = "0"; + props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, invalidIamDefaultPort); - final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, - () -> targetPlugin.connect(PG_DRIVER_PROTOCOL, PG_HOST_SPEC, props, true, mockLambda)); + final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + + PG_HOST_SPEC_WITH_PORT.getPort() + ":postgresqlUser"; + IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort, + new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); - assertEquals("Port number: 0 is not valid. Port number should be greater than zero.", exception.getMessage()); + testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); + } + + @Test + public void testPostgresConnectWithInvalidPortAndNoHostPortFallbacksToHostPort() throws SQLException { + final String invalidIamDefaultPort = "0"; + props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, invalidIamDefaultPort); + + when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT); + + final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + + DEFAULT_PG_PORT + ":postgresqlUser"; + IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort, + new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + + testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC); } @Test @@ -150,6 +166,18 @@ public void testConnectWithSpecifiedPort() throws SQLException { testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); } + @Test + public void testConnectWithSpecifiedIamDefaultPort() throws SQLException { + final String iamDefaultPort = "9999"; + props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, iamDefaultPort); + final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + + iamDefaultPort + ":postgresqlUser"; + IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort, + new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000))); + + testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT); + } + @Test public void testConnectWithSpecifiedRegion() throws SQLException { final String cacheKeyWithNewRegion =