Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ The wrapper driver currently uses [Hikari](https://github.com/brettwooldridge/Hi
- username
- password

You can optionally pass in a `HikariPoolMapping` function as a second parameter to the `HikariPooledConnectionProvider`. Internally, the connection pools used by the plugin are maintained as a map from instance URLs to connection pools. If you would like to define a different key system, you should pass in a `HikariPoolMapping` function defining this logic. This is helpful, for example, when you would like to create multiple Connection objects to the same instance with different users. In this scenario, you should pass in a `HikariPoolMapping` that incorporates the instance URL and the username from the `Properties` object into the map key.
You can optionally pass in a `HikariPoolMapping` function as a second parameter to the `HikariPooledConnectionProvider`. This allows you to decide when new connection pools should be created by defining what is included in the pool map key. A new pool will be created each time a new connection is requested with a unique key. By default, a new pool will be created for each unique instance-user combination. If you would like to define a different key system, you should pass in a `HikariPoolMapping` function defining this logic. Note that the user will always be automatically included in the key for security reasons. Please see [ReadWriteSplittingPostgresExample.java](../../../examples/AWSDriverExample/src/main/java/software/amazon/ReadWriteSplittingPostgresExample.java) for an example of how to configure the pool map key.

2. Call `ConnectionProviderManager.setConnectionProvider`, passing in the `HikariPooledConnectionProvider` you created in step 1.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public static void main(String[] args) throws SQLException {
* Optional: configure read-write splitting to use internal connection pools (the getPoolKey
* parameter is optional, see UsingTheReadWriteSplittingPlugin.md for more info).
*/
// props.setProperty("somePropertyValue", "1"); // used in getPoolKey
// final HikariPooledConnectionProvider connProvider =
// new HikariPooledConnectionProvider(
// ReadWriteSplittingPostgresExample::getHikariConfig,
Expand Down Expand Up @@ -163,9 +164,10 @@ private static HikariConfig getHikariConfig(HostSpec hostSpec, Properties props)
// This method is an optional parameter to `ConnectionProviderManager.setConnectionProvider`.
// It can be omitted if you do not require it.
private static String getPoolKey(HostSpec hostSpec, Properties props) {
// Include the user in the connection pool key so that a new connection pool will be opened for
// each instance-user combination.
final String user = props.getProperty(PropertyDefinition.USER.name);
return hostSpec.getUrl() + user;
// Include the URL and somePropertyValue in the connection pool key so that a new connection
// pool will be opened for each different instance-user-somePropertyValue combination.
// (Note that the user will automatically be added to the key).
final String somePropertyValue = props.getProperty("somePropertyValue");
return hostSpec.getUrl() + somePropertyValue;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,11 @@ public Connection connect(
@NonNull Properties props)
throws SQLException {
final HikariDataSource ds = databasePools.computeIfAbsent(
poolMapping.getKey(hostSpec, props),
getPoolKey(hostSpec, props),
url -> createHikariDataSource(protocol, hostSpec, props)
);

ds.setPassword(props.getProperty(PropertyDefinition.PASSWORD.name));
Connection conn = ds.getConnection();
int count = 0;
while (conn != null && count++ < retries && !conn.isValid(3)) {
Expand All @@ -140,6 +141,15 @@ public Connection connect(
return null;
}

// The pool key should always be retrieved using this method, because the username
// must always be included to avoid sharing privileged connections with other users.
private String getPoolKey(HostSpec hostSpec, Properties props) {
final StringBuilder sb = new StringBuilder();
sb.append(poolMapping.getKey(hostSpec, props))
.append(props.getProperty(PropertyDefinition.USER.name));
return sb.toString();
}

@Override
public void releaseResources() {
databasePools.forEach((String url, HikariDataSource ds) -> ds.close());
Expand Down Expand Up @@ -168,7 +178,11 @@ protected void configurePool(
}

final StringJoiner propsJoiner = new StringJoiner("&");
connectionProps.forEach((k, v) -> propsJoiner.add(k + "=" + v));
connectionProps.forEach((k, v) -> {
if (!PropertyDefinition.PASSWORD.name.equals(k) && !PropertyDefinition.USER.name.equals(k)) {
propsJoiner.add(k + "=" + v);
}
});
urlBuilder.append("?").append(propsJoiner);

config.setJdbcUrl(urlBuilder.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import static org.junit.jupiter.api.Assertions.fail;

import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.pool.HikariPool;
import integration.refactored.DatabaseEngine;
import integration.refactored.DatabaseEngineDeployment;
import integration.refactored.DriverHelper;
import integration.refactored.TestEnvironmentFeatures;
import integration.refactored.TestEnvironmentInfo;
import integration.refactored.TestInstanceInfo;
import integration.refactored.container.ConnectionStringHelper;
import integration.refactored.container.ProxyHelper;
Expand Down Expand Up @@ -597,7 +599,7 @@ public void test_pooledConnection_reuseCachedConnection() throws SQLException {
protected static HikariConfig getHikariConfig(HostSpec hostSpec, Properties props) {
final HikariConfig config = new HikariConfig();
config.setMaximumPoolSize(1);
config.setInitializationFailTimeout(75000);
config.setInitializationFailTimeout(20000);
return config;
}

Expand Down Expand Up @@ -758,4 +760,66 @@ public void test_pooledConnection_failoverInTransaction()
ConnectionProviderManager.resetProvider();
}
}

@TestTemplate
public void test_pooledConnection_differentUsers() throws SQLException {
Properties privilegedUserProps = getProps();

Properties privilegedUserWithWrongPasswordProps = getProps();
privilegedUserWithWrongPasswordProps.setProperty(PropertyDefinition.PASSWORD.name, "bogus_password");

Properties limitedUserProps = getProps();
String limitedUserName = "limited_user";
String limitedUserPassword = "limited_user";
String limitedUserNewDb = "limited_user_db";
limitedUserProps.setProperty(PropertyDefinition.USER.name, limitedUserName);
limitedUserProps.setProperty(PropertyDefinition.PASSWORD.name, limitedUserPassword);

Properties wrongUserRightPasswordProps = getProps();
wrongUserRightPasswordProps.setProperty(PropertyDefinition.USER.name, "bogus_user");

final HikariPooledConnectionProvider provider =
new HikariPooledConnectionProvider(ReadWriteSplittingTests::getHikariConfig);
ConnectionProviderManager.setConnectionProvider(provider);

try {
try (Connection conn = DriverManager.getConnection(ConnectionStringHelper.getWrapperUrl(),
privilegedUserProps); Statement stmt = conn.createStatement()) {
stmt.execute("DROP USER IF EXISTS " + limitedUserName);
auroraUtil.createUser(conn, limitedUserName, limitedUserPassword);
TestEnvironmentInfo info = TestEnvironment.getCurrent().getInfo();
DatabaseEngine engine = info.getRequest().getDatabaseEngine();
if (DatabaseEngine.MYSQL.equals(engine)) {
String db = info.getDatabaseInfo().getDefaultDbName();
// MySQL needs this extra command to allow the limited user access to the database
stmt.execute("GRANT ALL PRIVILEGES ON " + db + ".* to " + limitedUserName);
}
}

try (final Connection conn = DriverManager.getConnection(
ConnectionStringHelper.getWrapperUrl(), limitedUserProps);
Statement stmt = conn.createStatement()) {
assertThrows(SQLException.class,
() -> stmt.execute("CREATE DATABASE " + limitedUserNewDb));
}

assertThrows(
HikariPool.PoolInitializationException.class, () -> {
try (final Connection conn = DriverManager.getConnection(
ConnectionStringHelper.getWrapperUrl(), wrongUserRightPasswordProps)) {
// Do nothing (close connection automatically)
}
});
} finally {
ConnectionProviderManager.releaseResources();
ConnectionProviderManager.resetProvider();

try (Connection conn = DriverManager.getConnection(ConnectionStringHelper.getWrapperUrl(),
privilegedUserProps);
Statement stmt = conn.createStatement()) {
stmt.execute("DROP DATABASE IF EXISTS " + limitedUserNewDb);
stmt.execute("DROP USER IF EXISTS " + limitedUserName);
}
}
}
}
20 changes: 19 additions & 1 deletion wrapper/src/test/java/integration/util/AuroraTestUtility.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
import software.amazon.awssdk.services.rds.model.Filter;
import software.amazon.awssdk.services.rds.model.Tag;
import software.amazon.awssdk.services.rds.waiters.RdsWaiter;
import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException;
import software.amazon.jdbc.util.StringUtils;

/**
Expand Down Expand Up @@ -762,6 +761,25 @@ public String executeInstanceIdQuery(DatabaseEngine databaseEngine, Statement st
return null;
}

public void createUser(Connection conn, String username, String password) throws SQLException {
DatabaseEngine engine = TestEnvironment.getCurrent().getInfo().getRequest().getDatabaseEngine();
String dropUserSql = getCreateUserSql(engine, username, password);
try (Statement stmt = conn.createStatement()) {
stmt.execute(dropUserSql);
}
}

protected String getCreateUserSql(DatabaseEngine engine, String username, String password) {
switch (engine) {
case MYSQL:
return "CREATE USER " + username + " identified by '" + password + "'";
case PG:
return "CREATE USER " + username + " with password '" + password + "'";
default:
throw new UnsupportedOperationException(engine.toString());
}
}

public void addAuroraAwsIamUser(
DatabaseEngine databaseEngine,
String connectionUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,18 @@ void tearDown() throws Exception {
@Test
void testConnectWithDefaultMapping() throws SQLException {
when(mockHostSpec.getUrl()).thenReturn("url");
final Set<String> expected = new HashSet<>(Collections.singletonList("url"));
final Set<String> expected = new HashSet<>(Collections.singletonList("urlusername"));

final HikariPooledConnectionProvider provider =
spy(new HikariPooledConnectionProvider((hostSpec, properties) -> mockConfig));

doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any());

Properties props = new Properties();
props.setProperty(PropertyDefinition.USER.name, "username");
props.setProperty(PropertyDefinition.PASSWORD.name, "password");
try (Connection conn = provider.connect(
"protocol", mockDialect, mockHostSpec, emptyProperties)) {
"protocol", mockDialect, mockHostSpec, props)) {
assertEquals(mockConnection, conn);
assertEquals(1, provider.getHostCount());
final Set<String> hosts = provider.getHosts();
Expand All @@ -83,16 +86,19 @@ void testConnectWithDefaultMapping() throws SQLException {
@Test
void testConnectWithCustomMapping() throws SQLException {
when(mockHostSpec.getUrl()).thenReturn("url");
final Set<String> expected = new HashSet<>(Collections.singletonList("url+someUniqueKey"));
final Set<String> expected = new HashSet<>(Collections.singletonList("url+someUniqueKeyusername"));

final HikariPooledConnectionProvider provider = spy(new HikariPooledConnectionProvider(
(hostSpec, properties) -> mockConfig,
(hostSpec, properties) -> hostSpec.getUrl() + "+someUniqueKey"));

doReturn(mockDataSource).when(provider).createHikariDataSource(any(), any(), any());

Properties props = new Properties();
props.setProperty(PropertyDefinition.USER.name, "username");
props.setProperty(PropertyDefinition.PASSWORD.name, "password");
try (Connection conn = provider.connect(
"protocol", mockDialect, mockHostSpec, emptyProperties)) {
"protocol", mockDialect, mockHostSpec, props)) {
assertEquals(mockConnection, conn);
assertEquals(1, provider.getHostCount());
final Set<String> hosts = provider.getHosts();
Expand Down