2020import java .sql .Connection ;
2121import java .sql .SQLException ;
2222import java .util .Arrays ;
23+ import java .util .HashSet ;
2324import java .util .Map ;
2425import java .util .Objects ;
25- import java .util .Optional ;
2626import java .util .Queue ;
2727import java .util .Set ;
2828import java .util .concurrent .ConcurrentHashMap ;
2929import java .util .concurrent .ConcurrentLinkedQueue ;
30+ import java .util .concurrent .Executor ;
3031import java .util .concurrent .ExecutorService ;
3132import java .util .concurrent .Executors ;
3233import java .util .logging .Logger ;
3536import software .amazon .jdbc .util .Messages ;
3637import software .amazon .jdbc .util .RdsUtils ;
3738import software .amazon .jdbc .util .StringUtils ;
39+ import software .amazon .jdbc .util .SynchronousExecutor ;
3840import software .amazon .jdbc .util .telemetry .TelemetryContext ;
3941import software .amazon .jdbc .util .telemetry .TelemetryFactory ;
4042import software .amazon .jdbc .util .telemetry .TelemetryTraceLevel ;
@@ -50,16 +52,17 @@ public class OpenedConnectionTracker {
5052 invalidateThread .setDaemon (true );
5153 return invalidateThread ;
5254 });
53- private static final ExecutorService abortConnectionExecutorService =
54- Executors .newCachedThreadPool (
55- r -> {
56- final Thread abortThread = new Thread (r );
57- abortThread .setDaemon (true );
58- return abortThread ;
59- });
55+ private static final Executor abortConnectionExecutor = new SynchronousExecutor ();
6056
6157 private static final Logger LOGGER = Logger .getLogger (OpenedConnectionTracker .class .getName ());
6258 private static final RdsUtils rdsUtils = new RdsUtils ();
59+
60+ private static final Set <String > safeToCheckClosedClasses = new HashSet <>(Arrays .asList (
61+ "HikariProxyConnection" ,
62+ "org.postgresql.jdbc.PgConnection" ,
63+ "com.mysql.cj.jdbc.ConnectionImpl" ,
64+ "org.mariadb.jdbc.Connection" ));
65+
6366 private final PluginService pluginService ;
6467
6568 public OpenedConnectionTracker (final PluginService pluginService ) {
@@ -72,6 +75,7 @@ public void populateOpenedConnectionQueue(final HostSpec hostSpec, final Connect
7275 // Check if the connection was established using an instance endpoint
7376 if (rdsUtils .isRdsInstance (hostSpec .getHost ())) {
7477 trackConnection (hostSpec .getHostAndPort (), conn );
78+ logOpenedConnections ();
7579 return ;
7680 }
7781
@@ -80,14 +84,17 @@ public void populateOpenedConnectionQueue(final HostSpec hostSpec, final Connect
8084 .max (String ::compareToIgnoreCase )
8185 .orElse (null );
8286
83- if (instanceEndpoint == null ) {
84- LOGGER .finest (
85- Messages .get ("OpenedConnectionTracker.unableToPopulateOpenedConnectionQueue" ,
86- new Object [] {hostSpec .getHost ()}));
87+ if (instanceEndpoint != null ) {
88+ trackConnection (instanceEndpoint , conn );
89+ logOpenedConnections ();
8790 return ;
8891 }
8992
90- trackConnection (instanceEndpoint , conn );
93+ // It seems there's no RDS instance host found. It might be a custom domain name. Let's track by all aliases
94+ for (String alias : aliases ) {
95+ trackConnection (alias , conn );
96+ }
97+ logOpenedConnections ();
9198 }
9299
93100 /**
@@ -100,28 +107,27 @@ public void invalidateAllConnections(final HostSpec hostSpec) {
100107 invalidateAllConnections (hostSpec .getAliases ().toArray (new String [] {}));
101108 }
102109
103- public void invalidateAllConnections (final String ... node ) {
110+ public void invalidateAllConnections (final String ... keys ) {
104111 TelemetryFactory telemetryFactory = this .pluginService .getTelemetryFactory ();
105112 TelemetryContext telemetryContext = telemetryFactory .openTelemetryContext (
106113 TELEMETRY_INVALIDATE_CONNECTIONS , TelemetryTraceLevel .NESTED );
107114
108115 try {
109- final Optional <String > instanceEndpoint = Arrays .stream (node )
110- .filter (x -> rdsUtils .isRdsInstance (rdsUtils .removePort (x )))
111- .findFirst ();
112- if (!instanceEndpoint .isPresent ()) {
113- return ;
116+ for (String key : keys ) {
117+ try {
118+ final Queue <WeakReference <Connection >> connectionQueue = openedConnections .get (key );
119+ logConnectionQueue (key , connectionQueue );
120+ invalidateConnections (connectionQueue );
121+ } catch (Exception ex ) {
122+ // ignore and continue
123+ }
114124 }
115- final Queue <WeakReference <Connection >> connectionQueue = openedConnections .get (instanceEndpoint .get ());
116- logConnectionQueue (instanceEndpoint .get (), connectionQueue );
117- invalidateConnections (openedConnections .get (instanceEndpoint .get ()));
118-
119125 } finally {
120126 telemetryContext .closeContext ();
121127 }
122128 }
123129
124- public void invalidateCurrentConnection (final HostSpec hostSpec , final Connection connection ) {
130+ public void removeConnectionTracking (final HostSpec hostSpec , final Connection connection ) {
125131 final String host = rdsUtils .isRdsInstance (hostSpec .getHost ())
126132 ? hostSpec .asAlias ()
127133 : hostSpec .getAliases ().stream ()
@@ -134,8 +140,11 @@ public void invalidateCurrentConnection(final HostSpec hostSpec, final Connectio
134140 }
135141
136142 final Queue <WeakReference <Connection >> connectionQueue = openedConnections .get (host );
137- logConnectionQueue (host , connectionQueue );
138- connectionQueue .removeIf (connectionWeakReference -> Objects .equals (connectionWeakReference .get (), connection ));
143+ if (connectionQueue != null ) {
144+ logConnectionQueue (host , connectionQueue );
145+ connectionQueue .removeIf (connectionWeakReference -> connectionWeakReference != null
146+ && Objects .equals (connectionWeakReference .get (), connection ));
147+ }
139148 }
140149
141150 private void trackConnection (final String instanceEndpoint , final Connection connection ) {
@@ -144,10 +153,12 @@ private void trackConnection(final String instanceEndpoint, final Connection con
144153 instanceEndpoint ,
145154 (k ) -> new ConcurrentLinkedQueue <>());
146155 connectionQueue .add (new WeakReference <>(connection ));
147- logOpenedConnections ();
148156 }
149157
150158 private void invalidateConnections (final Queue <WeakReference <Connection >> connectionQueue ) {
159+ if (connectionQueue == null || connectionQueue .isEmpty ()) {
160+ return ;
161+ }
151162 invalidateConnectionsExecutorService .submit (() -> {
152163 WeakReference <Connection > connReference ;
153164 while ((connReference = connectionQueue .poll ()) != null ) {
@@ -157,7 +168,7 @@ private void invalidateConnections(final Queue<WeakReference<Connection>> connec
157168 }
158169
159170 try {
160- conn .abort (abortConnectionExecutorService );
171+ conn .abort (abortConnectionExecutor );
161172 } catch (final SQLException e ) {
162173 // swallow this exception, current connection should be useless anyway.
163174 }
@@ -204,7 +215,10 @@ public void pruneNullConnections() {
204215 if (conn == null ) {
205216 return true ;
206217 }
207- if (conn .getClass ().getSimpleName ().equals ("HikariProxyConnection" )) {
218+ // The following classes do not check connection validity by calling a DB server
219+ // so it's safe to check whether connection is already closed.
220+ if (safeToCheckClosedClasses .contains (conn .getClass ().getSimpleName ())
221+ || safeToCheckClosedClasses .contains (conn .getClass ().getName ())) {
208222 try {
209223 return conn .isClosed ();
210224 } catch (SQLException ex ) {
0 commit comments