Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,7 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Pro
host = IAM_HOST.getString(props);
}

int port = hostSpec.getPort();
if (!hostSpec.isPortSpecified()) {
if (StringUtils.isNullOrEmpty(IAM_DEFAULT_PORT.getString(props))) {
port = this.pluginService.getDialect().getDefaultPort();
} else {
port = IAM_DEFAULT_PORT.getInteger(props);
if (port <= 0) {
throw new IllegalArgumentException(
Messages.get(
"IamAuthConnectionPlugin.invalidPort",
new Object[] {port}));
}
}
}
int port = getPort(props, hostSpec);

final String iamRegion = IAM_REGION.getString(props);
final Region region = StringUtils.isNullOrEmpty(iamRegion)
Expand Down Expand Up @@ -245,6 +232,26 @@ public static void clearCache() {
tokenCache.clear();
}

private int getPort(Properties props, HostSpec hostSpec) {
if (!StringUtils.isNullOrEmpty(IAM_DEFAULT_PORT.getString(props))) {
int defaultPort = IAM_DEFAULT_PORT.getInteger(props);
if (defaultPort > 0) {
return defaultPort;
} else {
LOGGER.finest(
() -> Messages.get(
"IamAuthConnectionPlugin.invalidPort",
new Object[] {defaultPort}));
}
}

if (hostSpec.isPortSpecified()) {
return hostSpec.getPort();
} else {
return this.pluginService.getDialect().getDefaultPort();
}
}

private Region getRdsRegion(final String hostname) throws SQLException {

// Get Region
Expand Down
2 changes: 1 addition & 1 deletion wrapper/src/main/resources/messages.properties
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ HostSelector.noHostsMatchingRole=No hosts were found matching the requested ''{0
IamAuthConnectionPlugin.unsupportedHostname=Unsupported AWS hostname {0}. Amazon domain name in format *.AWS-Region.rds.amazonaws.com or *.rds.AWS-Region.amazonaws.com.cn is expected.
IamAuthConnectionPlugin.useCachedIamToken=Use cached IAM token = ''{0}''
IamAuthConnectionPlugin.generatedNewIamToken=Generated new IAM token = ''{0}''
IamAuthConnectionPlugin.invalidPort=Port number: {0} is not valid. Port number should be greater than zero.
IamAuthConnectionPlugin.invalidPort=Port number: {0} is not valid. Port number should be greater than zero. Falling back to default port.
IamAuthConnectionPlugin.unhandledException=Unhandled exception: ''{0}''
IamAuthConnectionPlugin.connectException=Error occurred while opening a connection: ''{0}''

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,31 @@ public void testMySqlConnectValidTokenInCache() throws SQLException {
}

@Test
public void testPostgresConnectWithInvalidPort() {
props.setProperty("iamDefaultPort", "0");
PluginService mockPluginService = Mockito.mock(PluginService.class);
final IamAuthConnectionPlugin targetPlugin = new IamAuthConnectionPlugin(mockPluginService);
public void testPostgresConnectWithInvalidPortFallbacksToHostPort() throws SQLException {
final String invalidIamDefaultPort = "0";
props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, invalidIamDefaultPort);

final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
() -> targetPlugin.connect(PG_DRIVER_PROTOCOL, PG_HOST_SPEC, props, true, mockLambda));
final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:"
+ PG_HOST_SPEC_WITH_PORT.getPort() + ":postgresqlUser";
IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort,
new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));

assertEquals("Port number: 0 is not valid. Port number should be greater than zero.", exception.getMessage());
testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT);
}

@Test
public void testPostgresConnectWithInvalidPortAndNoHostPortFallbacksToHostPort() throws SQLException {
final String invalidIamDefaultPort = "0";
props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, invalidIamDefaultPort);

when(mockDialect.getDefaultPort()).thenReturn(DEFAULT_PG_PORT);

final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:"
+ DEFAULT_PG_PORT + ":postgresqlUser";
IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort,
new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));

testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC);
}

@Test
Expand Down Expand Up @@ -150,6 +166,18 @@ public void testConnectWithSpecifiedPort() throws SQLException {
testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT);
}

@Test
public void testConnectWithSpecifiedIamDefaultPort() throws SQLException {
final String iamDefaultPort = "9999";
props.setProperty(IamAuthConnectionPlugin.IAM_DEFAULT_PORT.name, iamDefaultPort);
final String cacheKeyWithNewPort = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:"
+ iamDefaultPort + ":postgresqlUser";
IamAuthConnectionPlugin.tokenCache.put(cacheKeyWithNewPort,
new IamAuthConnectionPlugin.TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000)));

testTokenSetInProps(PG_DRIVER_PROTOCOL, PG_HOST_SPEC_WITH_PORT);
}

@Test
public void testConnectWithSpecifiedRegion() throws SQLException {
final String cacheKeyWithNewRegion =
Expand Down