diff --git a/wrapper/src/main/java/software/amazon/jdbc/Driver.java b/wrapper/src/main/java/software/amazon/jdbc/Driver.java index 175dbc830..66559bfc2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/Driver.java +++ b/wrapper/src/main/java/software/amazon/jdbc/Driver.java @@ -91,6 +91,8 @@ public Connection connect(final String url, final Properties info) throws SQLExc return null; } + LOGGER.finest("Opening connection to " + url); + final String driverUrl = url.replaceFirst(PROTOCOL_PREFIX, "jdbc:"); final java.sql.Driver driver = DriverManager.getDriver(driverUrl); diff --git a/wrapper/src/main/java/software/amazon/jdbc/HostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/HostListProvider.java index 9de22554e..919187a31 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/HostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/HostListProvider.java @@ -39,4 +39,6 @@ public interface HostListProvider { * determine the host role */ HostRole getHostRole(Connection connection) throws SQLException; + + HostSpec identifyConnection(Connection connection) throws SQLException; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/HostSpec.java b/wrapper/src/main/java/software/amazon/jdbc/HostSpec.java index dd9e8161d..1bfd90508 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/HostSpec.java +++ b/wrapper/src/main/java/software/amazon/jdbc/HostSpec.java @@ -38,6 +38,7 @@ public class HostSpec { protected Set aliases = ConcurrentHashMap.newKeySet(); protected Set allAliases = ConcurrentHashMap.newKeySet(); protected long weight; // Greater or equal 0. Lesser the weight, the healthier node. + protected String hostId; public HostSpec(final String host) { this.host = host; @@ -148,6 +149,12 @@ public void removeAlias(final String... alias) { }); } + public void resetAliases() { + this.aliases.clear(); + this.allAliases.clear(); + this.allAliases.add(this.asAlias()); + } + public String getUrl() { String url = isPortSpecified() ? host + ":" + port : host; if (!url.endsWith("/")) { @@ -156,6 +163,14 @@ public String getUrl() { return url; } + public String getHostId() { + return hostId; + } + + public void setHostId(String hostId) { + this.hostId = hostId; + } + public String asAlias() { return isPortSpecified() ? host + ":" + port : host; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java index 85a85e919..313492568 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java @@ -151,4 +151,8 @@ HostSpec getHostSpecByStrategy(HostRole role, String strategy) Dialect getDialect(); void updateDialect(final @NonNull Connection connection) throws SQLException; + + HostSpec identifyConnection(final Connection connection) throws SQLException; + + void fillAliases(final Connection connection, final HostSpec hostSpec) throws SQLException; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index ff5de8b7b..ef6ba5f9e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -17,13 +17,16 @@ package software.amazon.jdbc; import java.sql.Connection; +import java.sql.ResultSet; import java.sql.SQLException; +import java.sql.Statement; import java.util.ArrayList; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; import java.util.Properties; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -35,6 +38,7 @@ import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.dialect.DialectManager; import software.amazon.jdbc.dialect.DialectProvider; +import software.amazon.jdbc.dialect.TopologyAwareDatabaseCluster; import software.amazon.jdbc.exceptions.ExceptionManager; import software.amazon.jdbc.hostlistprovider.StaticHostListProvider; import software.amazon.jdbc.util.CacheMap; @@ -334,7 +338,7 @@ public HostListProvider getHostListProvider() { @Override public void refreshHostList() throws SQLException { final List updatedHostList = this.getHostListProvider().refresh(); - if (updatedHostList != null) { + if (!Objects.equals(updatedHostList, this.hosts)) { updateHostAvailability(updatedHostList); setNodeList(this.hosts, updatedHostList); } @@ -343,7 +347,7 @@ public void refreshHostList() throws SQLException { @Override public void refreshHostList(final Connection connection) throws SQLException { final List updatedHostList = this.getHostListProvider().refresh(connection); - if (updatedHostList != null) { + if (!Objects.equals(updatedHostList, this.hosts)) { updateHostAvailability(updatedHostList); setNodeList(this.hosts, updatedHostList); } @@ -489,4 +493,44 @@ public void updateDialect(final @NonNull Connection connection) throws SQLExcept connection); } + @Override + public HostSpec identifyConnection(Connection connection) throws SQLException { + if (!(this.getDialect() instanceof TopologyAwareDatabaseCluster)) { + return null; + } + + return this.hostListProvider.identifyConnection(connection); + } + + @Override + public void fillAliases(Connection connection, HostSpec hostSpec) throws SQLException { + if (hostSpec == null) { + return; + } + + if (!hostSpec.getAliases().isEmpty()) { + LOGGER.finest(() -> Messages.get("PluginServiceImpl.nonEmptyAliases", new Object[] {hostSpec.getAliases()})); + return; + } + + hostSpec.addAlias(hostSpec.asAlias()); + + // Add the host name and port, this host name is usually the internal IP address. + try (final Statement stmt = connection.createStatement()) { + try (final ResultSet rs = stmt.executeQuery(this.getDialect().getHostAliasQuery())) { + while (rs.next()) { + hostSpec.addAlias(rs.getString(1)); + } + } + } catch (final SQLException sqlException) { + // log and ignore + LOGGER.finest(() -> Messages.get("PluginServiceImpl.failedToRetrieveHostPort")); + } + + // Add the instance endpoint if the current connection is associated with a topology aware database cluster. + final HostSpec host = this.identifyConnection(connection); + if (host != null) { + hostSpec.addAlias(host.asAliases().toArray(new String[] {})); + } + } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java index 4f593e9d5..22913d084 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProvider.java @@ -27,6 +27,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map.Entry; +import java.util.Objects; import java.util.Properties; import java.util.Set; import java.util.UUID; @@ -85,7 +86,6 @@ public class AuroraHostListProvider implements DynamicHostListProvider { : TimeUnit.MILLISECONDS.toNanos(30000); private final long suggestedClusterIdRefreshRateNano = TimeUnit.MINUTES.toNanos(10); private List hostList = new ArrayList<>(); - private List lastReturnedHostList; private List initialHostList = new ArrayList<>(); private HostSpec initialHostSpec; @@ -273,7 +273,7 @@ private ClusterSuggestedResult getSuggestedClusterId(final String url) { for (final HostSpec host : hosts) { if (host.getUrl().equals(url)) { LOGGER.finest(() -> Messages.get("AuroraHostListProvider.suggestedClusterId", - new Object[]{key, url})); + new Object[] {key, url})); return new ClusterSuggestedResult(key, isPrimaryCluster); } } @@ -398,8 +398,12 @@ private HostSpec createHost(final ResultSet resultSet) throws SQLException { // Calculate weight based on node lag in time and CPU utilization. final long weight = Math.round(nodeLag) * 100L + Math.round(cpuUtilization); - hostName = hostName == null ? "?" : hostName; - final String endpoint = getHostEndpoint(hostName); + return createHost(hostName, isWriter, weight); + } + + private HostSpec createHost(String host, final boolean isWriter, final long weight) { + host = host == null ? "?" : host; + final String endpoint = getHostEndpoint(host); final int port = this.clusterInstanceTemplate.isPortSpecified() ? this.clusterInstanceTemplate.getPort() : this.initialHostSpec.getPort(); @@ -410,7 +414,8 @@ private HostSpec createHost(final ResultSet resultSet) throws SQLException { isWriter ? HostRole.WRITER : HostRole.READER, HostAvailability.AVAILABLE, weight); - hostSpec.addAlias(hostName); + hostSpec.addAlias(host); + hostSpec.setHostId(host); return hostSpec; } @@ -466,12 +471,7 @@ public List refresh(final Connection connection) throws SQLException { final FetchTopologyResult results = getTopology(currentConnection, false); LOGGER.finest(() -> Utils.logTopology(results.hosts)); - if (results.isCachedData && this.lastReturnedHostList == results.hosts) { - return null; // no topology update - } - this.hostList = results.hosts; - this.lastReturnedHostList = this.hostList; return Collections.unmodifiableList(hostList); } @@ -490,7 +490,6 @@ public List forceRefresh(final Connection connection) throws SQLExcept final FetchTopologyResult results = getTopology(currentConnection, true); LOGGER.finest(() -> Utils.logTopology(results.hosts)); this.hostList = results.hosts; - this.lastReturnedHostList = this.hostList; return Collections.unmodifiableList(this.hostList); } @@ -577,6 +576,7 @@ public FetchTopologyResult(final boolean isCachedData, final List host } static class ClusterSuggestedResult { + public String clusterId; public boolean isPrimaryClusterId; @@ -588,18 +588,10 @@ public ClusterSuggestedResult(final String clusterId, final boolean isPrimaryClu @Override public HostRole getHostRole(Connection conn) throws SQLException { - if (this.topologyAwareDialect == null) { - Dialect dialect = this.hostListProviderService.getDialect(); - if (!(dialect instanceof TopologyAwareDatabaseCluster)) { - throw new SQLException( - Messages.get("AuroraHostListProvider.invalidDialectForGetHostRole", - new Object[]{dialect})); - } - this.topologyAwareDialect = (TopologyAwareDatabaseCluster) this.hostListProviderService.getDialect(); - } - try (final Statement stmt = conn.createStatement(); - final ResultSet rs = stmt.executeQuery(this.topologyAwareDialect.getIsReaderQuery())) { + final ResultSet rs = stmt.executeQuery( + getTopologyAwareDialect("AuroraHostListProvider.invalidDialectForGetHostRole") + .getIsReaderQuery())) { if (rs.next()) { boolean isReader = rs.getBoolean(1); return isReader ? HostRole.READER : HostRole.WRITER; @@ -610,4 +602,45 @@ public HostRole getHostRole(Connection conn) throws SQLException { throw new SQLException(Messages.get("AuroraHostListProvider.errorGettingHostRole")); } + + @Override + public HostSpec identifyConnection(Connection connection) throws SQLException { + try (final Statement stmt = connection.createStatement(); + final ResultSet resultSet = stmt.executeQuery( + getTopologyAwareDialect("AuroraHostListProvider.invalidDialectForIdentifyConnection") + .getNodeIdQuery())) { + if (resultSet.next()) { + final String instanceName = resultSet.getString(1); + + final List topology = this.refresh(); + + if (topology == null) { + return null; + } + + return topology + .stream() + .filter(host -> Objects.equals(instanceName, host.getHostId())) + .findAny() + .orElse(null); + } + } catch (final SQLException e) { + throw new SQLException(Messages.get("AuroraHostListProvider.errorIdentifyConnection"), e); + } + + throw new SQLException(Messages.get("AuroraHostListProvider.errorIdentifyConnection")); + } + + private TopologyAwareDatabaseCluster getTopologyAwareDialect(String exceptionMessageIdentifier) throws SQLException { + if (this.topologyAwareDialect == null) { + Dialect dialect = this.hostListProviderService.getDialect(); + if (!(dialect instanceof TopologyAwareDatabaseCluster)) { + throw new SQLException( + Messages.get(exceptionMessageIdentifier, + new Object[] {dialect})); + } + this.topologyAwareDialect = (TopologyAwareDatabaseCluster) this.hostListProviderService.getDialect(); + } + return this.topologyAwareDialect; + } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java index acc8f99c1..aab082d7d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/ConnectionStringHostListProvider.java @@ -17,10 +17,13 @@ package software.amazon.jdbc.hostlistprovider; import java.sql.Connection; +import java.sql.ResultSet; import java.sql.SQLException; +import java.sql.Statement; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.AwsWrapperProperty; @@ -109,4 +112,30 @@ public List forceRefresh(final Connection connection) throws SQLExcept public HostRole getHostRole(Connection connection) { throw new UnsupportedOperationException("ConnectionStringHostListProvider does not support getHostRole"); } + + @Override + public HostSpec identifyConnection(Connection connection) throws SQLException { + try (final Statement stmt = connection.createStatement(); + final ResultSet resultSet = stmt.executeQuery(this.hostListProviderService.getDialect().getHostAliasQuery())) { + if (resultSet.next()) { + final String instance = resultSet.getString(1); + + final List topology = this.refresh(connection); + + if (topology == null) { + return null; + } + + return topology + .stream() + .filter(host -> Objects.equals(instance, host.getHostId())) + .findAny() + .orElse(null); + } + } catch (final SQLException e) { + throw new SQLException(Messages.get("ConnectionStringHostListProvider.errorIdentifyConnection"), e); + } + + throw new SQLException(Messages.get("ConnectionStringHostListProvider.errorIdentifyConnection")); + } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java index 921fd2bd3..647179750 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java @@ -17,24 +17,23 @@ package software.amazon.jdbc.plugin; import java.sql.Connection; -import java.sql.ResultSet; import java.sql.SQLException; -import java.sql.Statement; import java.util.Collections; import java.util.EnumSet; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Properties; import java.util.Set; +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.jdbc.HostRole; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.NodeChangeOptions; import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.dialect.TopologyAwareDatabaseCluster; -import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; import software.amazon.jdbc.plugin.failover.FailoverSQLException; +import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; -import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.SubscribedMethodHelper; public class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin { @@ -52,10 +51,10 @@ public class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin { }); private final PluginService pluginService; - private final Properties props; private final RdsUtils rdsHelper; - private String clusterInstanceTemplate; private final OpenedConnectionTracker tracker; + private HostSpec currentWriter = null; + private boolean needUpdateCurrentWriter = false; AuroraConnectionTrackerPlugin(final PluginService pluginService, final Properties props) { this(pluginService, props, new RdsUtils(), new OpenedConnectionTracker()); @@ -67,7 +66,6 @@ public class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin { final RdsUtils rdsUtils, final OpenedConnectionTracker tracker) { this.pluginService = pluginService; - this.props = props; this.rdsHelper = rdsUtils; this.tracker = tracker; } @@ -88,19 +86,17 @@ public Connection connectInternal( throws SQLException { final Connection conn = connectFunc.call(); - final HostSpec currentHostSpec = (this.pluginService.getCurrentHostSpec() != null) - ? this.pluginService.getCurrentHostSpec() - : hostSpec; if (conn != null) { - if (!rdsHelper.isRdsInstance(currentHostSpec.getHost())) { - currentHostSpec.addAlias(getInstanceEndpoint(conn, currentHostSpec)); + final RdsUrlType type = this.rdsHelper.identifyRdsType(hostSpec.getHost()); + if (type.isRdsCluster()) { + hostSpec.resetAliases(); + this.pluginService.fillAliases(conn, hostSpec); } + tracker.populateOpenedConnectionQueue(hostSpec, conn); + tracker.logOpenedConnections(); } - tracker.populateOpenedConnectionQueue(currentHostSpec, conn); - tracker.logOpenedConnections(); - return conn; } @@ -110,31 +106,29 @@ public Connection forceConnect(String driverProtocol, HostSpec hostSpec, Propert return connectInternal(hostSpec, forceConnectFunc); } - private String getInstanceEndpointPattern(final String url) { - if (StringUtils.isNullOrEmpty(this.clusterInstanceTemplate)) { - this.clusterInstanceTemplate = AuroraHostListProvider.CLUSTER_INSTANCE_HOST_PATTERN.getString(this.props) == null - ? rdsHelper.getRdsInstanceHostPattern(url) - : AuroraHostListProvider.CLUSTER_INSTANCE_HOST_PATTERN.getString(this.props); - } - - return this.clusterInstanceTemplate; - } - @Override public T execute(final Class resultClass, final Class exceptionClass, final Object methodInvokeOn, final String methodName, final JdbcCallable jdbcMethodFunc, final Object[] jdbcMethodArgs) throws E { - final HostSpec originalHost = this.pluginService.getCurrentHostSpec(); + + final HostSpec currentHostSpec = this.pluginService.getCurrentHostSpec(); + if (this.currentWriter == null || this.needUpdateCurrentWriter) { + this.currentWriter = this.getWriter(this.pluginService.getHosts()); + this.needUpdateCurrentWriter = false; + } + try { final T result = jdbcMethodFunc.call(); if ((methodName.equals(METHOD_CLOSE) || methodName.equals(METHOD_ABORT))) { - tracker.invalidateCurrentConnection(originalHost, this.pluginService.getCurrentConnection()); + tracker.invalidateCurrentConnection(currentHostSpec, this.pluginService.getCurrentConnection()); } return result; + } catch (final Exception e) { if (e instanceof FailoverSQLException) { - tracker.invalidateAllConnections(originalHost); + tracker.invalidateAllConnections(this.currentWriter); tracker.logOpenedConnections(); + this.needUpdateCurrentWriter = true; } throw e; } @@ -143,36 +137,22 @@ public T execute(final Class resultClass, final Clas @Override public void notifyNodeListChanged(final Map> changes) { for (final String node : changes.keySet()) { - if (isRoleChanged(changes.get(node))) { + final EnumSet nodeChanges = changes.get(node); + if (nodeChanges.contains(NodeChangeOptions.PROMOTED_TO_READER)) { tracker.invalidateAllConnections(node); } + if (nodeChanges.contains(NodeChangeOptions.PROMOTED_TO_WRITER)) { + this.needUpdateCurrentWriter = true; + } } } - private boolean isRoleChanged(final EnumSet changes) { - return changes.contains(NodeChangeOptions.PROMOTED_TO_WRITER) - || changes.contains(NodeChangeOptions.PROMOTED_TO_READER); - } - - public String getInstanceEndpoint(final Connection conn, final HostSpec host) { - String instanceName = "?"; - - if (!(this.pluginService.getDialect() instanceof TopologyAwareDatabaseCluster)) { - return instanceName; - } - final TopologyAwareDatabaseCluster topologyAwareDialect = - (TopologyAwareDatabaseCluster) this.pluginService.getDialect(); - - try (final Statement stmt = conn.createStatement(); - final ResultSet resultSet = stmt.executeQuery(topologyAwareDialect.getNodeIdQuery())) { - if (resultSet.next()) { - instanceName = resultSet.getString(1); + private HostSpec getWriter(final @NonNull List hosts) { + for (final HostSpec hostSpec : hosts) { + if (hostSpec.getRole() == HostRole.WRITER) { + return hostSpec; } - String instanceEndpoint = getInstanceEndpointPattern(host.getHost()); - instanceEndpoint = host.isPortSpecified() ? instanceEndpoint + ":" + host.getPort() : instanceEndpoint; - return instanceEndpoint.replace("?", instanceName); - } catch (final SQLException e) { - return instanceName; } + return null; } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/OpenedConnectionTracker.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/OpenedConnectionTracker.java index 194f49650..c0f69597b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/OpenedConnectionTracker.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/OpenedConnectionTracker.java @@ -60,7 +60,7 @@ public void populateOpenedConnectionQueue(final HostSpec hostSpec, final Connect final Set aliases = hostSpec.asAliases(); // Check if the connection was established using a cluster endpoint - final String host = hostSpec.getHost(); + final String host = hostSpec.asAlias(); if (rdsUtils.isRdsInstance(host)) { trackConnection(host, conn); return; @@ -82,6 +82,7 @@ public void populateOpenedConnectionQueue(final HostSpec hostSpec, final Connect * @param hostSpec The {@link HostSpec} object containing the url of the node. */ public void invalidateAllConnections(final HostSpec hostSpec) { + invalidateAllConnections(hostSpec.asAlias()); invalidateAllConnections(hostSpec.getAliases().toArray(new String[] {})); } @@ -97,7 +98,7 @@ public void invalidateAllConnections(final String... node) { public void invalidateCurrentConnection(final HostSpec hostSpec, final Connection connection) { final String host = rdsUtils.isRdsInstance(hostSpec.getHost()) - ? hostSpec.getHost() + ? hostSpec.asAlias() : hostSpec.getAliases().stream().filter(rdsUtils::isRdsInstance).findFirst().orElse(null); if (StringUtils.isNullOrEmpty(host)) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java index a01c22e75..0c6303fb8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java @@ -39,6 +39,8 @@ import software.amazon.jdbc.cleanup.CanReleaseResources; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.RdsUrlType; +import software.amazon.jdbc.util.RdsUtils; import software.amazon.jdbc.util.SubscribedMethodHelper; /** @@ -81,8 +83,9 @@ public class HostMonitoringConnectionPlugin extends AbstractConnectionPlugin protected @NonNull Properties properties; private final @NonNull Supplier monitorServiceSupplier; private final @NonNull PluginService pluginService; - private final @NonNull Set nodeKeys = ConcurrentHashMap.newKeySet(); // Shared with monitor thread private MonitorService monitorService; + private final RdsUtils rdsHelper; + private HostSpec monitoringHostSpec; /** * Initialize the node monitoring plugin. @@ -93,13 +96,14 @@ public class HostMonitoringConnectionPlugin extends AbstractConnectionPlugin */ public HostMonitoringConnectionPlugin( final @NonNull PluginService pluginService, final @NonNull Properties properties) { - this(pluginService, properties, () -> new MonitorServiceImpl(pluginService)); + this(pluginService, properties, () -> new MonitorServiceImpl(pluginService), new RdsUtils()); } HostMonitoringConnectionPlugin( final @NonNull PluginService pluginService, final @NonNull Properties properties, - final @NonNull Supplier monitorServiceSupplier) { + final @NonNull Supplier monitorServiceSupplier, + final RdsUtils rdsHelper) { if (pluginService == null) { throw new IllegalArgumentException("pluginService"); } @@ -112,6 +116,7 @@ public HostMonitoringConnectionPlugin( this.pluginService = pluginService; this.properties = properties; this.monitorServiceSupplier = monitorServiceSupplier; + this.rdsHelper = rdsHelper; } @Override @@ -156,14 +161,11 @@ public T execute( "HostMonitoringConnectionPlugin.activatedMonitoring", new Object[] {methodName})); - this.nodeKeys.clear(); - this.nodeKeys.addAll(this.pluginService.getCurrentHostSpec().asAliases()); - monitorContext = this.monitorService.startMonitoring( this.pluginService.getCurrentConnection(), // abort this connection if needed - this.nodeKeys, - this.pluginService.getCurrentHostSpec(), + this.getMonitoringHostSpec().asAliases(), + this.getMonitoringHostSpec(), this.properties, failureDetectionTimeMillis, failureDetectionIntervalMillis, @@ -177,7 +179,9 @@ public T execute( this.monitorService.stopMonitoring(monitorContext); if (monitorContext.isNodeUnhealthy()) { - this.pluginService.setAvailability(this.nodeKeys, HostAvailability.NOT_AVAILABLE); + this.pluginService.setAvailability( + this.getMonitoringHostSpec().asAliases(), + HostAvailability.NOT_AVAILABLE); final boolean isConnectionClosed; try { @@ -241,44 +245,18 @@ public void releaseResources() { this.monitorService = null; } - /** - * Generate a set of node keys representing the node to monitor. - * - * @param driverProtocol Driver protocol for provided connection - * @param connection the connection to a specific node. - * @param hostSpec host details to add node keys to - */ - private void generateHostAliases( - final @NonNull String driverProtocol, - final @NonNull Connection connection, - final @NonNull HostSpec hostSpec) { - - hostSpec.addAlias(hostSpec.asAlias()); - - try (final Statement stmt = connection.createStatement()) { - try (final ResultSet rs = stmt.executeQuery(this.pluginService.getDialect().getHostAliasQuery())) { - while (rs.next()) { - hostSpec.addAlias(rs.getString(1)); - } - } - } catch (final SQLException sqlException) { - // log and ignore - LOGGER.finest(() -> Messages.get("HostMonitoringConnectionPlugin.failedToRetrieveHostPort")); - } - } - @Override public OldConnectionSuggestedAction notifyConnectionChanged(final EnumSet changes) { - if (changes.contains(NodeChangeOptions.WENT_DOWN) || changes.contains(NodeChangeOptions.NODE_DELETED)) { - if (!this.nodeKeys.isEmpty()) { - this.monitorService.stopMonitoringForAllConnections(this.nodeKeys); + if (!this.getMonitoringHostSpec().asAliases().isEmpty()) { + this.monitorService.stopMonitoringForAllConnections(this.getMonitoringHostSpec().asAliases()); } - this.nodeKeys.clear(); - this.nodeKeys.addAll(this.pluginService.getCurrentHostSpec().getAliases()); } + // Reset monitoring HostSpec since the associated connection has changed. + this.monitoringHostSpec = null; + return OldConnectionSuggestedAction.NO_OPINION; } @@ -298,7 +276,11 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec, final Connection conn = connectFunc.call(); if (conn != null) { - generateHostAliases(driverProtocol, conn, hostSpec); + final RdsUrlType type = this.rdsHelper.identifyRdsType(hostSpec.getHost()); + if (type.isRdsCluster()) { + hostSpec.resetAliases(); + this.pluginService.fillAliases(conn, hostSpec); + } } return conn; @@ -314,4 +296,32 @@ public Connection forceConnect( throws SQLException { return connectInternal(driverProtocol, hostSpec, forceConnectFunc); } + + public HostSpec getMonitoringHostSpec() { + if (this.monitoringHostSpec == null) { + this.monitoringHostSpec = this.pluginService.getCurrentHostSpec(); + final RdsUrlType rdsUrlType = this.rdsHelper.identifyRdsType(monitoringHostSpec.getUrl()); + + try { + if (rdsUrlType.isRdsCluster()) { + LOGGER.finest("Monitoring HostSpec is associated with a cluster endpoint, " + + "plugin needs to identify the cluster connection."); + this.monitoringHostSpec = this.pluginService.identifyConnection(this.pluginService.getCurrentConnection()); + if (this.monitoringHostSpec == null) { + throw new RuntimeException(Messages.get( + "HostMonitoringConnectionPlugin.unableToIdentifyConnection", + new Object[] { + this.pluginService.getCurrentHostSpec().getHost(), + this.pluginService.getHostListProvider()})); + } + this.pluginService.fillAliases(this.pluginService.getCurrentConnection(), monitoringHostSpec); + } + } catch (SQLException e) { + // Log and throw. + LOGGER.finest(Messages.get("HostMonitoringConnectionPlugin.errorIdentifyingConnection", new Object[] {e})); + throw new RuntimeException(e); + } + } + return this.monitoringHostSpec; + } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/MonitorImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/MonitorImpl.java index bd182c3cc..b96dd6384 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/MonitorImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/MonitorImpl.java @@ -260,8 +260,10 @@ ConnectionStatus checkConnectionStatus(final long shortestFailureDetectionInterv monitoringConnProperties.remove(p); }); + LOGGER.finest(() -> "Opening a monitoring connection to " + this.hostSpec.getUrl()); startNano = this.getCurrentTimeNano(); this.monitoringConn = this.pluginService.forceConnect(this.hostSpec, monitoringConnProperties); + LOGGER.finest(() -> "Opened monitoring connection: " + this.monitoringConn); return new ConnectionStatus(true, this.getCurrentTimeNano() - startNano); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/MonitorServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/MonitorServiceImpl.java index b809e0087..7d1ee12e9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/MonitorServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/MonitorServiceImpl.java @@ -17,6 +17,7 @@ package software.amazon.jdbc.plugin.efm; import java.sql.Connection; +import java.sql.SQLException; import java.util.Collections; import java.util.Properties; import java.util.Set; @@ -42,6 +43,7 @@ public class MonitorServiceImpl implements MonitorService { "60000", "Interval in milliseconds for a monitor to be considered inactive and to be disposed."); + private final PluginService pluginService; private MonitorThreadContainer threadContainer; final MonitorInitializer monitorInitializer; @@ -50,6 +52,7 @@ public class MonitorServiceImpl implements MonitorService { public MonitorServiceImpl(final @NonNull PluginService pluginService) { this( + pluginService, (hostSpec, properties, monitorService) -> new MonitorImpl( pluginService, @@ -67,9 +70,11 @@ public MonitorServiceImpl(final @NonNull PluginService pluginService) { } MonitorServiceImpl( + final PluginService pluginService, final MonitorInitializer monitorInitializer, final ExecutorServiceInitializer executorServiceInitializer) { + this.pluginService = pluginService; this.monitorInitializer = monitorInitializer; this.threadContainer = MonitorThreadContainer.getInstance(executorServiceInitializer); } @@ -85,11 +90,9 @@ public MonitorConnectionContext startMonitoring( final int failureDetectionCount) { if (nodeKeys.isEmpty()) { - LOGGER.warning( - () -> Messages.get( - "MonitorServiceImpl.emptyAliasSet", - new Object[] {hostSpec})); - hostSpec.addAlias(hostSpec.asAlias()); + throw new IllegalArgumentException(Messages.get( + "MonitorServiceImpl.emptyAliasSet", + new Object[] {hostSpec})); } final Monitor monitor; diff --git a/wrapper/src/main/resources/messages.properties b/wrapper/src/main/resources/messages.properties index 3ddc0f682..cd886deb8 100644 --- a/wrapper/src/main/resources/messages.properties +++ b/wrapper/src/main/resources/messages.properties @@ -28,6 +28,8 @@ AuroraHostListProvider.invalidQuery=Error obtaining host list. Provided database AuroraHostListProvider.invalidDialect=Expecting a dialect that supports a cluster topology. AuroraHostListProvider.invalidDialectForGetHostRole=An Aurora dialect is required to analyze a host's role. The detected dialect was ''{0}'' AuroraHostListProvider.errorGettingHostRole=An error occurred while obtaining the connected host's role. This could occur if the connection is broken or if you are not connected to an Aurora database. +AuroraHostListProvider.errorIdentifyConnection=An error occurred while obtaining the connection's host ID. +AuroraHostListProvider.invalidDialectForIdentifyConnection=An Aurora dialect is required to analyze the instance associated with this connection. The detected dialect was ''{0}'' # AWS Credentials Manager AwsCredentialsManager.nullProvider=The configured AwsCredentialsProvider was null. If you have configured the AwsCredentialsManager to use a custom AwsCredentialsProviderHandler, please ensure the handler does not return null. @@ -77,6 +79,7 @@ ClusterAwareWriterFailoverHandler.alreadyWriter=Current reader connection is act # Connection String Host List Provider ConnectionStringHostListProvider.parsedListEmpty=Can''t parse connection string: ''{0}''. +ConnectionStringHostListProvider.errorIdentifyConnection=An error occurred while obtaining the connection's host ID. # Connection Plugin Manager ConnectionPluginManager.configurationProfileNotFound=Configuration profile ''{0}'' not found. @@ -139,8 +142,8 @@ Failover.noOperationsAfterConnectionClosed=No operations allowed after connectio HostMonitoringConnectionPlugin.activatedMonitoring=Executing method ''{0}'', monitoring is activated. HostMonitoringConnectionPlugin.monitoringDeactivated=Monitoring deactivated for method ''{0}''. HostMonitoringConnectionPlugin.unavailableNode=Node ''{0}'' is unavailable. -HostMonitoringConnectionPlugin.failedToRetrieveHostPort=Could not retrieve Host:Port for connection. -HostMonitoringConnectionPlugin.unsupportedDriverProtocol=Driver protocol ''{0}'' is not supported. +HostMonitoringConnectionPlugin.errorIdentifyingConnection=Error occurred while identifying connection: ''{0}''. +HostMonitoringConnectionPlugin.unableToIdentifyConnection=Unable to identify the given connection: ''{0}'', please ensure the correct host list provider is specified. The host list provider in use is: ''{1}''. # IAM Auth Connection Plugin IamAuthConnectionPlugin.unsupportedHostname=Unsupported AWS hostname {0}. Amazon domain name in format *.AWS-Region.rds.amazonaws.com or *.rds.AWS-Region.amazonaws.com.cn is expected. @@ -164,8 +167,9 @@ MonitorThreadContainer.emptyNodeKeys=Provided node keys are empty. MonitorImpl.contextNullWarning=Parameter 'context' should not be null. # Monitor Service Impl -MonitorServiceImpl.nullMonitorParam=Parameter monitor' should not be null. -MonitorServiceImpl.emptyAliasSet=Empty alias set passed for {0}. Set should not be empty. +MonitorServiceImpl.nullMonitorParam=Parameter 'monitor' should not be null. +MonitorServiceImpl.emptyAliasSet=Empty alias set passed for ''{0}''. Set should not be empty. +MonitorServiceImpl.errorPopulatingAliases=Error occurred while populating aliases: ''{0}''. # Plugin Service Impl PluginServiceImpl.hostListEmpty=Current host list is empty. @@ -173,6 +177,8 @@ PluginServiceImpl.releaseResources=Releasing resources. PluginServiceImpl.hostListException=Exception while getting a host list. PluginServiceImpl.hostAliasNotFound=Can''t find any host by the following aliases: ''{0}''. PluginServiceImpl.hostsChangelistEmpty=There are no changes in the hosts' availability. +PluginServiceImpl.failedToRetrieveHostPort=Could not retrieve Host:Port for connection. +PluginServiceImpl.nonEmptyAliases=fillAliases called when HostSpec already contains the following aliases: ''{0}''. Please reset HostSpec aliases before calling this method if you need to re-fill them. # Property Utils PropertyUtils.setMethodDoesNotExistOnTarget=Set method for property ''{0}'' does not exist on target ''{1}''. diff --git a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java index c6385c8bd..c2cedf963 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java @@ -16,12 +16,14 @@ package software.amazon.jdbc; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; @@ -30,7 +32,9 @@ import static org.mockito.Mockito.when; import java.sql.Connection; +import java.sql.ResultSet; import java.sql.SQLException; +import java.sql.Statement; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -40,15 +44,21 @@ import java.util.Map; import java.util.Properties; import java.util.Set; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.dialect.AuroraPgDialect; +import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.dialect.DialectManager; +import software.amazon.jdbc.dialect.MysqlDialect; import software.amazon.jdbc.exceptions.ExceptionManager; public class PluginServiceImplTests { @@ -63,6 +73,8 @@ public class PluginServiceImplTests { @Mock Connection oldConnection; @Mock HostListProvider hostListProvider; @Mock DialectManager dialectManager; + @Mock Statement statement; + @Mock ResultSet resultSet; @Captor ArgumentCaptor> argumentChanges; @Captor ArgumentCaptor>> argumentChangesMap; @@ -72,6 +84,8 @@ public class PluginServiceImplTests { void setUp() throws SQLException { closeable = MockitoAnnotations.openMocks(this); when(oldConnection.isClosed()).thenReturn(false); + when(newConnection.createStatement()).thenReturn(statement); + when(statement.executeQuery(any())).thenReturn(resultSet); PluginServiceImpl.hostAvailabilityExpiringCache.clear(); } @@ -512,13 +526,18 @@ void testRefreshHostList_withCachedHostAvailability() throws SQLException { new HostSpec("hostB", HostSpec.NO_PORT, HostRole.READER, HostAvailability.AVAILABLE), new HostSpec("hostC", HostSpec.NO_PORT, HostRole.READER, HostAvailability.AVAILABLE) ); + final List newHostSpecs2 = Arrays.asList( + new HostSpec("hostA", HostSpec.NO_PORT, HostRole.READER, HostAvailability.AVAILABLE), + new HostSpec("hostB", HostSpec.NO_PORT, HostRole.READER, HostAvailability.NOT_AVAILABLE), + new HostSpec("hostC", HostSpec.NO_PORT, HostRole.READER, HostAvailability.AVAILABLE) + ); final List expectedHostSpecs = Arrays.asList( new HostSpec("hostA", HostSpec.NO_PORT, HostRole.READER, HostAvailability.NOT_AVAILABLE), new HostSpec("hostB", HostSpec.NO_PORT, HostRole.READER, HostAvailability.NOT_AVAILABLE), new HostSpec("hostC", HostSpec.NO_PORT, HostRole.READER, HostAvailability.AVAILABLE)); final List expectedHostSpecs2 = Arrays.asList( new HostSpec("hostA", HostSpec.NO_PORT, HostRole.READER, HostAvailability.NOT_AVAILABLE), - new HostSpec("hostB", HostSpec.NO_PORT, HostRole.READER, HostAvailability.AVAILABLE), + new HostSpec("hostB", HostSpec.NO_PORT, HostRole.READER, HostAvailability.NOT_AVAILABLE), new HostSpec("hostC", HostSpec.NO_PORT, HostRole.READER, HostAvailability.AVAILABLE)); PluginServiceImpl.hostAvailabilityExpiringCache.put("hostA/", HostAvailability.NOT_AVAILABLE, @@ -526,7 +545,7 @@ void testRefreshHostList_withCachedHostAvailability() throws SQLException { PluginServiceImpl.hostAvailabilityExpiringCache.put("hostB/", HostAvailability.NOT_AVAILABLE, PluginServiceImpl.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); when(hostListProvider.refresh()).thenReturn(newHostSpecs); - when(hostListProvider.refresh(newConnection)).thenReturn(newHostSpecs); + when(hostListProvider.refresh(newConnection)).thenReturn(newHostSpecs2); PluginServiceImpl target = spy( new PluginServiceImpl( @@ -580,4 +599,73 @@ void testForceRefreshHostList_withCachedHostAvailability() throws SQLException { target.forceRefreshHostList(newConnection); assertEquals(expectedHostSpecs2, newHostSpecs); } + + @Test + void testIdentifyConnectionWithNoAliases() throws SQLException { + PluginServiceImpl target = spy( + new PluginServiceImpl( + pluginManager, new ExceptionManager(), PROPERTIES, URL, DRIVER_PROTOCOL, dialectManager)); + when(target.getHostListProvider()).thenReturn(hostListProvider); + + when(target.getDialect()).thenReturn(new MysqlDialect()); + assertNull(target.identifyConnection(newConnection)); + } + + @Test + void testIdentifyConnectionWithAliases() throws SQLException { + final HostSpec expected = new HostSpec("test"); + PluginServiceImpl target = spy( + new PluginServiceImpl( + pluginManager, new ExceptionManager(), PROPERTIES, URL, DRIVER_PROTOCOL, dialectManager)); + target.hostListProvider = hostListProvider; + when(target.getHostListProvider()).thenReturn(hostListProvider); + when(hostListProvider.identifyConnection(eq(newConnection))).thenReturn(expected); + + when(target.getDialect()).thenReturn(new AuroraPgDialect()); + final HostSpec actual = target.identifyConnection(newConnection); + verify(target, never()).getCurrentHostSpec(); + verify(hostListProvider).identifyConnection(newConnection); + assertEquals(expected, actual); + } + + @Test + void testFillAliasesNonEmptyAliases() throws SQLException { + final HostSpec oneAlias = new HostSpec("foo"); + oneAlias.addAlias(oneAlias.asAlias()); + + PluginServiceImpl target = spy( + new PluginServiceImpl( + pluginManager, new ExceptionManager(), PROPERTIES, URL, DRIVER_PROTOCOL, dialectManager)); + + assertEquals(1, oneAlias.getAliases().size()); + target.fillAliases(newConnection, oneAlias); + // Fill aliases should return directly and no additional aliases should be added. + assertEquals(1, oneAlias.getAliases().size()); + } + + @ParameterizedTest + @MethodSource("fillAliasesDialects") + void testFillAliasesWithInstanceEndpoint(Dialect dialect, String[] expectedInstanceAliases) throws SQLException { + final HostSpec empty = new HostSpec("foo"); + PluginServiceImpl target = spy( + new PluginServiceImpl( + pluginManager, new ExceptionManager(), PROPERTIES, URL, DRIVER_PROTOCOL, dialectManager)); + target.hostListProvider = hostListProvider; + when(target.getDialect()).thenReturn(dialect); + when(resultSet.next()).thenReturn(true, false); // Result set contains 1 row. + when(resultSet.getString(eq(1))).thenReturn("ip"); + when(hostListProvider.identifyConnection(eq(newConnection))).thenReturn(new HostSpec("instance")); + + target.fillAliases(newConnection, empty); + + final String[] aliases = empty.getAliases().toArray(new String[] {}); + assertArrayEquals(expectedInstanceAliases, aliases); + } + + private static Stream fillAliasesDialects() { + return Stream.of( + Arguments.of(new AuroraPgDialect(), new String[]{"instance", "foo", "ip"}), + Arguments.of(new MysqlDialect(), new String[]{"foo", "ip"}) + ); + } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProviderTest.java b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProviderTest.java index 74174db6c..86a513780 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProviderTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/hostlistprovider/AuroraHostListProviderTest.java @@ -92,6 +92,7 @@ void setUp() throws SQLException { when(mockConnection.createStatement()).thenReturn(mockStatement); when(mockStatement.executeQuery(queryCaptor.capture())).thenReturn(mockResultSet); when(mockHostListProviderService.getDialect()).thenReturn(mockTopologyAwareDialect); + when(((TopologyAwareDatabaseCluster) mockTopologyAwareDialect).getNodeIdQuery()).thenReturn("nodeIdQuery"); } @AfterEach @@ -105,7 +106,7 @@ private AuroraHostListProvider getAuroraHostListProvider(String protocol, AuroraHostListProvider provider = new AuroraHostListProvider( protocol, mockHostListProviderService, new Properties(), originalUrl); provider.init(); - //provider.clusterId = "cluster-id"; + // provider.clusterId = "cluster-id"; return provider; } @@ -378,7 +379,7 @@ void testTopologyCache_AcceptSuggestion() throws SQLException { List topologyProvider1 = provider1.refresh(Mockito.mock(Connection.class)); assertEquals(topologyClusterA, topologyProvider1); - //AuroraHostListProvider.logCache(); + // AuroraHostListProvider.logCache(); AuroraHostListProvider provider2 = Mockito.spy( getAuroraHostListProvider("jdbc:something://", mockHostListProviderService, @@ -397,7 +398,7 @@ void testTopologyCache_AcceptSuggestion() throws SQLException { assertEquals("cluster-a.cluster-xyz.us-east-2.rds.amazonaws.com/", AuroraHostListProvider.suggestedPrimaryClusterIdCache.get(provider1.clusterId)); - //AuroraHostListProvider.logCache(); + // AuroraHostListProvider.logCache(); topologyProvider1 = provider1.forceRefresh(Mockito.mock(Connection.class)); assertEquals(topologyClusterA, topologyProvider1); @@ -405,6 +406,73 @@ void testTopologyCache_AcceptSuggestion() throws SQLException { assertTrue(provider1.isPrimaryClusterId); assertTrue(provider2.isPrimaryClusterId); - //AuroraHostListProvider.logCache(); + // AuroraHostListProvider.logCache(); + } + + @Test + void testIdentifyConnectionWithInvalidNodeIdQuery() throws SQLException { + auroraHostListProvider = Mockito.spy(getAuroraHostListProvider( + "jdbc:someprotocol://", + mockHostListProviderService, + "jdbc:someprotocol://url")); + + when(mockResultSet.next()).thenReturn(false); + assertThrows(SQLException.class, () -> auroraHostListProvider.identifyConnection(mockConnection)); + + when(mockConnection.createStatement()).thenThrow(new SQLException("exception")); + assertThrows(SQLException.class, () -> auroraHostListProvider.identifyConnection(mockConnection)); + } + + @Test + void testIdentifyConnectionNullTopology() throws SQLException { + auroraHostListProvider = Mockito.spy(getAuroraHostListProvider( + "jdbc:someprotocol://", + mockHostListProviderService, + "jdbc:someprotocol://url")); + auroraHostListProvider.clusterInstanceTemplate = new HostSpec("?.pattern"); + + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); + when(auroraHostListProvider.refresh(eq(mockConnection))).thenReturn(null); + + assertNull(auroraHostListProvider.identifyConnection(mockConnection)); + } + + @Test + void testIdentifyConnectionHostNotInTopology() throws SQLException { + final List cachedTopology = Collections.singletonList( + new HostSpec("instance-a-1.xyz.us-east-2.rds.amazonaws.com", HostSpec.NO_PORT, HostRole.WRITER)); + + auroraHostListProvider = Mockito.spy(getAuroraHostListProvider( + "jdbc:someprotocol://", + mockHostListProviderService, + "jdbc:someprotocol://url")); + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getString(eq(1))).thenReturn("instance-1"); + when(auroraHostListProvider.refresh(eq(mockConnection))).thenReturn(cachedTopology); + + assertNull(auroraHostListProvider.identifyConnection(mockConnection)); + } + + @Test + void testIdentifyConnectionHostInTopology() throws SQLException { + final HostSpec expectedHost = new HostSpec( + "instance-a-1.xyz.us-east-2.rds.amazonaws.com", + HostSpec.NO_PORT, + HostRole.WRITER); + expectedHost.setHostId("instance-a-1"); + final List cachedTopology = Collections.singletonList(expectedHost); + + auroraHostListProvider = Mockito.spy(getAuroraHostListProvider( + "jdbc:someprotocol://", + mockHostListProviderService, + "jdbc:someprotocol://url")); + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getString(eq(1))).thenReturn("instance-a-1"); + when(auroraHostListProvider.refresh()).thenReturn(cachedTopology); + + final HostSpec actual = auroraHostListProvider.identifyConnection(mockConnection); + assertEquals("instance-a-1.xyz.us-east-2.rds.amazonaws.com", actual.getHost()); + assertEquals("instance-a-1", actual.getHostId()); } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java index bc61d1296..c853d2eae 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java @@ -19,7 +19,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.never; @@ -30,6 +29,7 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; +import java.util.Collections; import java.util.Properties; import java.util.Set; import java.util.stream.Stream; @@ -47,8 +47,8 @@ import software.amazon.jdbc.PluginService; import software.amazon.jdbc.dialect.Dialect; import software.amazon.jdbc.dialect.TopologyAwareDatabaseCluster; -import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; import software.amazon.jdbc.plugin.failover.FailoverSQLException; +import software.amazon.jdbc.util.RdsUrlType; import software.amazon.jdbc.util.RdsUtils; public class AuroraConnectionTrackerPluginTest { @@ -77,6 +77,7 @@ void setUp() throws SQLException { when(mockConnection.createStatement()).thenReturn(mockStatement); when(mockStatement.executeQuery(any(String.class))).thenReturn(mockResultSet); when(mockRdsUtils.getRdsInstanceHostPattern(any(String.class))).thenReturn("?"); + when(mockRdsUtils.identifyRdsType(any())).thenReturn(RdsUrlType.RDS_INSTANCE); when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); when(mockPluginService.getDialect()).thenReturn(mockTopologyAwareDialect); when(((TopologyAwareDatabaseCluster) mockTopologyAwareDialect).getNodeIdQuery()).thenReturn("any"); @@ -115,108 +116,11 @@ public void testTrackNewInstanceConnections( assertEquals(0, aliases.size()); } - @ParameterizedTest - @MethodSource("trackNewConnectionsParameters") - public void testTrackNewClusterConnections( - final String protocol, - final boolean isInitialConnection) throws SQLException { - final HostSpec hostSpec = new HostSpec("writerCluster"); - when(mockPluginService.getCurrentHostSpec()).thenReturn(hostSpec); - when(mockRdsUtils.isRdsInstance("writerCluster")).thenReturn(false); - when(mockResultSet.next()).thenReturn(true, false); // ResultSet should only have 1 row. - when(mockResultSet.getString(anyInt())).thenReturn("writerInstance"); - - final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( - mockPluginService, - EMPTY_PROPERTIES, - mockRdsUtils, - mockTracker); - - final Connection actualConnection = plugin.connect( - protocol, - hostSpec, - EMPTY_PROPERTIES, - isInitialConnection, - mockConnectionFunction); - - assertEquals(mockConnection, actualConnection); - verify(mockTracker).populateOpenedConnectionQueue(eq(hostSpec), eq(mockConnection)); - final Set aliases = hostSpec.getAliases(); - assertEquals(1, aliases.size()); - assertEquals("writerInstance", aliases.toArray()[0]); - } - - @ParameterizedTest - @MethodSource("trackNonRdsInstanceUrlParameters") - public void testTrackNewConnections_nonRdsInstanceUrl( - final String endpoint, - final boolean isInitialConnection, - final String expected) throws SQLException { - final Properties properties = new Properties(); - AuroraHostListProvider.CLUSTER_INSTANCE_HOST_PATTERN.set(properties, "?.pattern"); - final HostSpec hostSpec = new HostSpec(endpoint); - when(mockPluginService.getCurrentHostSpec()).thenReturn(hostSpec); - when(mockRdsUtils.isRdsInstance(endpoint)).thenReturn(false); - when(mockResultSet.next()).thenReturn(true); - when(mockResultSet.getString(anyInt())).thenReturn(endpoint); - - final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( - mockPluginService, - properties, - mockRdsUtils, - mockTracker); - - final Connection actualConnection = plugin.connect( - "protocol", - hostSpec, - properties, - isInitialConnection, - mockConnectionFunction); - - assertEquals(mockConnection, actualConnection); - verify(mockTracker).populateOpenedConnectionQueue(eq(hostSpec), eq(mockConnection)); - final Set aliases = hostSpec.getAliases(); - assertEquals(1, aliases.size()); - assertEquals(expected, aliases.toArray()[0]); - } - - @ParameterizedTest - @MethodSource("trackNonRdsInstanceUrlWithoutClusterHostInstancePatternParameters") - public void testTrackNewConnections_nonRdsInstanceUrl_withoutClusterInstanceHostPattern( - final String endpoint, - final boolean isInitialConnection, - final String expected) throws SQLException { - final HostSpec hostSpec = new HostSpec(endpoint); - when(mockPluginService.getCurrentHostSpec()).thenReturn(hostSpec); - when(mockRdsUtils.isRdsInstance(endpoint)).thenReturn(false); - when(mockResultSet.next()).thenReturn(true); - when(mockResultSet.getString(anyInt())).thenReturn("instance-1"); - - final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( - mockPluginService, - EMPTY_PROPERTIES, - mockRdsUtils, - mockTracker); - - final Connection actualConnection = plugin.connect( - "protocol", - hostSpec, - EMPTY_PROPERTIES, - isInitialConnection, - mockConnectionFunction); - - assertEquals(mockConnection, actualConnection); - verify(mockTracker).populateOpenedConnectionQueue(eq(hostSpec), eq(mockConnection)); - final Set aliases = hostSpec.getAliases(); - assertEquals(1, aliases.size()); - assertEquals(expected, aliases.toArray()[0]); - } - @Test public void testInvalidateOpenedConnections() throws SQLException { final FailoverSQLException expectedException = new FailoverSQLException("reason", "sqlstate"); final HostSpec originalHost = new HostSpec("host"); - when(mockPluginService.getCurrentHostSpec()).thenReturn(originalHost); + when(mockPluginService.getHosts()).thenReturn(Collections.singletonList(originalHost)); doThrow(expectedException).when(mockSqlFunction).call(); final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( @@ -271,20 +175,4 @@ private static Stream trackNewConnectionsParameters() { Arguments.of("otherProtocol", false) ); } - - private static Stream trackNonRdsInstanceUrlParameters() { - return Stream.of( - Arguments.of("custom.domain", true, "custom.domain.pattern"), - Arguments.of("instanceName", false, "instanceName.pattern"), - Arguments.of("8.8.8.8", true, "8.8.8.8.pattern") - ); - } - - private static Stream trackNonRdsInstanceUrlWithoutClusterHostInstancePatternParameters() { - return Stream.of( - Arguments.of("custom.domain", true, "instance-1"), - Arguments.of("instanceName", false, "instance-1"), - Arguments.of("8.8.8.8", true, "instance-1") - ); - } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java index 198a29f75..9f3086984 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java @@ -380,6 +380,16 @@ public Dialect getDialect() { } public void updateDialect(final @NonNull Connection connection) throws SQLException { } + + @Override + public HostSpec identifyConnection(Connection connection) throws SQLException { + return null; + } + + @Override + public void fillAliases(Connection connection, HostSpec hostSpec) throws SQLException { + + } } public static class TestConnection implements Connection { diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java index 7bd49d433..d0f1ccfc0 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPluginTest.java @@ -39,7 +39,6 @@ import java.util.Collections; import java.util.EnumSet; import java.util.HashSet; -import java.util.List; import java.util.Properties; import java.util.Set; import java.util.function.Supplier; @@ -61,9 +60,9 @@ import software.amazon.jdbc.OldConnectionSuggestedAction; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.dialect.Dialect; -import software.amazon.jdbc.dialect.MysqlDialect; -import software.amazon.jdbc.dialect.PgDialect; import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.RdsUrlType; +import software.amazon.jdbc.util.RdsUtils; class HostMonitoringConnectionPluginTest { @@ -82,7 +81,9 @@ class HostMonitoringConnectionPluginTest { @Captor ArgumentCaptor stringArgumentCaptor; Properties properties = new Properties(); @Mock HostSpec hostSpec; + @Mock HostSpec hostSpec2; @Mock Supplier supplier; + @Mock RdsUtils rdsUtils; @Mock MonitorConnectionContext context; @Mock MonitorService monitorService; @Mock JdbcCallable sqlFunction; @@ -137,8 +138,12 @@ void initDefaultMockReturns() throws Exception { when(hostSpec.getHost()).thenReturn("host"); when(hostSpec.getHost()).thenReturn("port"); when(hostSpec.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("host:port"))); + when(hostSpec2.getHost()).thenReturn("host"); + when(hostSpec2.getHost()).thenReturn("port"); + when(hostSpec2.getAliases()).thenReturn(new HashSet<>(Collections.singletonList("host:port"))); when(connection.createStatement()).thenReturn(statement); when(statement.executeQuery(any())).thenReturn(resultSet); + when(rdsUtils.identifyRdsType(any())).thenReturn(RdsUrlType.RDS_INSTANCE); properties.put("failureDetectionEnabled", Boolean.TRUE.toString()); properties.put("failureDetectionTime", String.valueOf(FAILURE_DETECTION_TIME)); @@ -147,7 +152,7 @@ void initDefaultMockReturns() throws Exception { } private void initializePlugin() { - plugin = new HostMonitoringConnectionPlugin(pluginService, properties, supplier); + plugin = new HostMonitoringConnectionPlugin(pluginService, properties, supplier, rdsUtils); } @ParameterizedTest @@ -269,40 +274,6 @@ void test_executeCleanUp_whenAbortConnection_throwsException() throws SQLExcepti verify(connection).close(); } - @ParameterizedTest - @MethodSource("getHostPortSQLParameters") - void test_connect_withNoAdditionalHostAlias(final String protocol, final String expectedSql) throws SQLException { - initializePlugin(); - - when(hostSpec.asAlias()).thenReturn("hostSpec alias"); - when(mockDialect.getHostAliasQuery()).thenReturn(expectedSql); - - plugin.connect(protocol, hostSpec, properties, true, () -> connection); - verify(hostSpec).addAlias("hostSpec alias"); - verify(statement).executeQuery(eq(expectedSql)); - } - - @ParameterizedTest - @MethodSource("getHostPortSQLParameters") - void test_connect_withHostAliases(final String protocol, final String expectedSql) throws SQLException { - initializePlugin(); - - when(hostSpec.asAlias()).thenReturn("hostSpec alias"); - when(mockDialect.getHostAliasQuery()).thenReturn(expectedSql); - - // ResultSet contains one row. - when(resultSet.next()).thenReturn(true, false); - when(resultSet.getString(eq(1))).thenReturn("second alias"); - - plugin.connect(protocol, hostSpec, properties, true, () -> connection); - verify(hostSpec, times(2)).addAlias(stringArgumentCaptor.capture()); - final List captures = stringArgumentCaptor.getAllValues(); - assertEquals(2, captures.size()); - assertEquals("hostSpec alias", captures.get(0)); - assertEquals("second alias", captures.get(1)); - verify(statement).executeQuery(eq(expectedSql)); - } - @Test void test_connect_exceptionRaisedDuringGenerateHostAliases() throws SQLException { initializePlugin(); @@ -326,16 +297,21 @@ void test_notifyConnectionChanged_nodeWentDown(final NodeChangeOptions option) t sqlFunction, EMPTY_ARGS); - final Set aliases = new HashSet<>(Arrays.asList("alias1", "alias2")); - when(hostSpec.getAliases()).thenReturn(aliases); - assertEquals(OldConnectionSuggestedAction.NO_OPINION, plugin.notifyConnectionChanged(EnumSet.of(option))); - - // NodeKeys should be empty at first - verify(monitorService, never()).stopMonitoringForAllConnections(any()); + final Set aliases1 = new HashSet<>(Arrays.asList("alias1", "alias2")); + final Set aliases2 = new HashSet<>(Arrays.asList("alias3", "alias4")); + when(hostSpec.asAliases()).thenReturn(aliases1); + when(hostSpec2.asAliases()).thenReturn(aliases2); + when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec); assertEquals(OldConnectionSuggestedAction.NO_OPINION, plugin.notifyConnectionChanged(EnumSet.of(option))); // NodeKeys should contain {"alias1", "alias2"} - verify(monitorService).stopMonitoringForAllConnections(aliases); + verify(monitorService).stopMonitoringForAllConnections(aliases1); + + when(pluginService.getCurrentHostSpec()).thenReturn(hostSpec2); + assertEquals(OldConnectionSuggestedAction.NO_OPINION, plugin.notifyConnectionChanged(EnumSet.of(option))); + // NotifyConnectionChanged should reset the monitoringHostSpec. + // NodeKeys should contain {"alias3", "alias4"} + verify(monitorService).stopMonitoringForAllConnections(aliases2); } @Test @@ -358,18 +334,6 @@ void test_releaseResources() throws SQLException { verify(monitorService).releaseResources(); } - static Stream getHostPortSQLParameters() { - final String MYSQL_RETRIEVE_HOST_PORT_SQL = new MysqlDialect().getHostAliasQuery(); - final String PG_RETRIEVE_HOST_PORT_SQL = new PgDialect().getHostAliasQuery(); - - return Stream.of( - Arguments.of("jdbc:mysql:", MYSQL_RETRIEVE_HOST_PORT_SQL), - Arguments.of("jdbc:mysql:someUrl", MYSQL_RETRIEVE_HOST_PORT_SQL), - Arguments.of("jdbc:postgresql:", PG_RETRIEVE_HOST_PORT_SQL), - Arguments.of("jdbc:postgresql:someUrl", PG_RETRIEVE_HOST_PORT_SQL) - ); - } - static Stream nodeChangeOptions() { return Stream.of( Arguments.of(NodeChangeOptions.WENT_DOWN), diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MonitorServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MonitorServiceImplTest.java index 9397b657e..3617d207e 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MonitorServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MonitorServiceImplTest.java @@ -42,6 +42,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.PluginService; class MonitorServiceImplTest { @@ -59,6 +60,7 @@ class MonitorServiceImplTest { @Mock private Future task; @Mock private HostSpec hostSpec; @Mock private JdbcConnection connection; + @Mock private PluginService pluginService; private Properties properties; private AutoCloseable closeable; @@ -79,7 +81,7 @@ void init() { doReturn(task).when(executorService).submit(any(Monitor.class)); - monitorService = new MonitorServiceImpl(monitorInitializer, executorServiceInitializer); + monitorService = new MonitorServiceImpl(pluginService, monitorInitializer, executorServiceInitializer); } @AfterEach diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MultiThreadedDefaultMonitorServiceTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MultiThreadedDefaultMonitorServiceTest.java index fc4cb4aed..af2a8c24b 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MultiThreadedDefaultMonitorServiceTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/MultiThreadedDefaultMonitorServiceTest.java @@ -54,6 +54,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.PluginService; /** * Multithreaded tests for {@link MultiThreadedDefaultMonitorServiceTest}. Repeats each testcase @@ -69,6 +70,7 @@ class MultiThreadedDefaultMonitorServiceTest { @Mock Monitor monitor; @Mock Properties properties; @Mock JdbcConnection connection; + @Mock PluginService pluginService; private final AtomicInteger counter = new AtomicInteger(0); private final AtomicInteger concurrentCounter = new AtomicInteger(0); @@ -399,7 +401,7 @@ private List generateContexts( private List generateServices(final int numServices) { final List services = new ArrayList<>(); for (int i = 0; i < numServices; i++) { - services.add(new MonitorServiceImpl(monitorInitializer, executorServiceInitializer)); + services.add(new MonitorServiceImpl(pluginService, monitorInitializer, executorServiceInitializer)); } return services; } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java index b937399e7..6467d7ecd 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java @@ -523,5 +523,10 @@ public List forceRefresh(Connection connection) { public HostRole getHostRole(Connection conn) { return HostRole.WRITER; } + + @Override + public HostSpec identifyConnection(Connection connection) throws SQLException { + return new HostSpec("foo"); + } } }