diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java index a5eee9d0db..a1a07ffd8c 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java @@ -85,7 +85,9 @@ import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.common.AttributesBuilder; import io.opentelemetry.api.metrics.Meter; +import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Deque; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; @@ -119,6 +121,7 @@ class SessionPool { private static final Logger logger = Logger.getLogger(SessionPool.class.getName()); private final TraceWrapper tracer; static final String WAIT_FOR_SESSION = "SessionPool.WaitForSession"; + static final String WAIT_FOR_MULTIPLEXED_SESSION = "SessionPool.WaitForMultiplexedSession"; /** * If the {@link SessionPoolOptions#getWaitForMinSessions()} duration is greater than zero, waits @@ -144,6 +147,26 @@ void maybeWaitOnMinSessions() { } } + void waitOnMultiplexedSession() { + if (options.getUseMultiplexedSession()) { + final long timeoutNanos = options.getWaitForMultiplexedSession().toNanos(); + if (timeoutNanos <= 0) { + return; + } + + try { + if (!multiplexedSessionsInitialized.await(timeoutNanos, TimeUnit.NANOSECONDS)) { + final long timeoutMillis = options.getWaitForMultiplexedSession().toMillis(); + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.DEADLINE_EXCEEDED, + "Timed out after waiting " + timeoutMillis + "ms for multiplexed session creation"); + } + } catch (InterruptedException e) { + throw SpannerExceptionFactory.propagateInterrupt(e); + } + } + } + private abstract static class CachedResultSetSupplier implements Supplier { private ResultSet cached; @@ -1154,6 +1177,11 @@ private PooledSessionFuture createPooledSessionFuture( return new PooledSessionFuture(future, span); } + private MultiplexedSessionFuture createMultiplexedSessionFuture( + ListenableFuture future, ISpan span) { + return new MultiplexedSessionFuture(future, span); + } + interface SessionFuture extends Session { /** @@ -1419,7 +1447,7 @@ PooledSession get(final boolean eligibleForLongRunning) { res.markBusy(span); span.addAnnotation("Using Session", "sessionId", res.getName()); synchronized (lock) { - incrementNumSessionsInUse(); + incrementNumSessionsInUse(false); checkedOutSessions.add(this); } res.eligibleForLongRunning = eligibleForLongRunning; @@ -2154,6 +2182,87 @@ public ApiFuture asyncClose() { } } + private final class MultiplexedSessionWaiterFuture + extends ForwardingListenableFuture { + private final SettableFuture waiter = SettableFuture.create(); + + @Override + protected ListenableFuture delegate() { + return waiter; + } + + private void put(MultiplexedSession session) { + waiter.set(session); + } + + private void put(SpannerException e) { + waiter.setException(e); + } + + @Override + public MultiplexedSession get() { + while (true) { + ISpan span = tracer.spanBuilder(WAIT_FOR_MULTIPLEXED_SESSION); + try (IScope waitScope = tracer.withSpan(span)) { + MultiplexedSession s = pollUninterruptedlyWithTimeout(options.getAcquireSessionTimeout()); + if (s == null) { + // Set the status to DEADLINE_EXCEEDED and retry. + numMultiplexedSessionWaiterTimeouts.incrementAndGet(); + tracer.getCurrentSpan().setStatus(ErrorCode.DEADLINE_EXCEEDED); + } else { + return s; + } + } catch (Exception e) { + if (e instanceof SpannerException + && ErrorCode.RESOURCE_EXHAUSTED.equals(((SpannerException) e).getErrorCode())) { + numMultiplexedSessionWaiterTimeouts.incrementAndGet(); + tracer.getCurrentSpan().setStatus(ErrorCode.RESOURCE_EXHAUSTED); + } + span.setStatus(e); + throw e; + } finally { + span.end(); + } + } + } + + /** + * Method which allows to obtain a multiplexed session after blocking for a configurable + * duration {@link SessionPoolOptions#getAcquireSessionTimeout()}. Note that this duration + * becomes obsolete in case we have set {@link + * SessionPoolOptions#getWaitForMultiplexedSession()}. Because {@link + * SessionPoolOptions#getWaitForMultiplexedSession()} will ensure that the multiplexed session + * is initialized and available during application start up. + */ + private MultiplexedSession pollUninterruptedlyWithTimeout(Duration acquireSessionTimeout) { + boolean interrupted = false; + try { + while (true) { + try { + return waiter.get(acquireSessionTimeout.toMillis(), TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + interrupted = true; + } catch (TimeoutException e) { + if (acquireSessionTimeout != null) { + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.RESOURCE_EXHAUSTED, + String.format( + "Timed out after waiting %s ms to acquire multiplexed session.", + acquireSessionTimeout.toMillis())); + } + return null; + } catch (ExecutionException e) { + throw SpannerExceptionFactory.newSpannerException(e.getCause()); + } + } + } finally { + if (interrupted) { + Thread.currentThread().interrupt(); + } + } + } + } + private final class WaiterFuture extends ForwardingListenableFuture { private static final long MAX_SESSION_WAIT_TIMEOUT = 240_000L; private final SettableFuture waiter = SettableFuture.create(); @@ -2333,6 +2442,7 @@ void maintainPool() { this.prevNumSessionsAcquired = SessionPool.this.numSessionsAcquired; } Instant currTime = clock.instant(); + maintainMultiplexedSession(currTime); removeIdleSessions(currTime); // Now go over all the remaining sessions and see if they need to be kept alive explicitly. keepAliveSessions(currTime); @@ -2501,6 +2611,51 @@ private void removeLongRunningSessions( } } } + + void maintainMultiplexedSession(Instant currentTime) { + try { + if (options.getUseMultiplexedSession()) { + Iterator iterator = multiplexedSessions.iterator(); + while (iterator.hasNext()) { + final MultiplexedSession session = iterator.next(); + final Duration durationFromCreationTime = + Duration.between(session.getDelegate().getCreateTime(), currentTime); + if (durationFromCreationTime.compareTo( + options.getMultiplexedSessionMaintenanceDuration()) + > 0 + && numMultiplexedSessionsBeingCreated == 0) { + logger.log( + Level.INFO, + String.format( + "Replacing Multiplexed Session => %s since it's created before maintenance " + + "window => %s days", + session.getName(), + options.getMultiplexedSessionMaintenanceDuration().toDays())); + createMultiplexedSessions(); + } + + // if there is are > 1 multiplexed sessions, this means there are some stale sessions + // since a multiplexed session is only created if a previous one is stale. + // remove all stale sessions. Note that there may be active transaction which are + // running + // on the stale sessions, hence we do not issue a DeleteSession RPC. + synchronized (lock) { + while (multiplexedSessions.size() > 1) { + final MultiplexedSession lastSession = multiplexedSessions.pollLast(); + logger.log( + Level.INFO, + String.format("Removed Multiplexed Session => %s", lastSession.getName())); + if (multiplexedSessionRemovedListener != null) { + multiplexedSessionRemovedListener.apply(lastSession); + } + } + } + } + } + } catch (final Throwable t) { + logger.log(Level.WARNING, "Failed to maintain multiplexed session", t); + } + } } enum Position { @@ -2574,9 +2729,16 @@ enum Position { @GuardedBy("lock") private final Queue waiters = new LinkedList<>(); + @GuardedBy("lock") + private final Queue multiplexedSessionWaiters = + new LinkedList<>(); + @GuardedBy("lock") private int numSessionsBeingCreated = 0; + @GuardedBy("lock") + private int numMultiplexedSessionsBeingCreated = 0; + @GuardedBy("lock") private int numSessionsInUse = 0; @@ -2605,6 +2767,12 @@ enum Position { private long numLeakedSessionsRemoved = 0; private AtomicLong numWaiterTimeouts = new AtomicLong(); + private AtomicLong numMultiplexedSessionWaiterTimeouts = new AtomicLong(); + + @GuardedBy("lock") + private final Deque multiplexedSessions; + + private final CountDownLatch multiplexedSessionsInitialized; @GuardedBy("lock") private final Set allSessions = new HashSet<>(); @@ -2615,9 +2783,13 @@ enum Position { private final SessionConsumer sessionConsumer = new SessionConsumerImpl(); + private final MultiplexedSessionConsumer multiplexedSessionConsumer = + new MultiplexedSessionConsumer(); + @VisibleForTesting Function idleSessionRemovedListener; @VisibleForTesting Function longRunningSessionRemovedListener; + @VisibleForTesting Function multiplexedSessionRemovedListener; private final CountDownLatch waitOnMinSessionsLatch; private final SessionReplacementHandler pooledSessionReplacementHandler = @@ -2749,6 +2921,8 @@ private SessionPool( this.initOpenTelemetryMetricsCollection(openTelemetry, attributes); this.waitOnMinSessionsLatch = options.getMinSessions() > 0 ? new CountDownLatch(1) : new CountDownLatch(0); + this.multiplexedSessions = new ArrayDeque<>(); + this.multiplexedSessionsInitialized = new CountDownLatch(1); } /** @@ -2855,12 +3029,19 @@ long getNumWaiterTimeouts() { return numWaiterTimeouts.get(); } + long getNumMultiplexedSessionWaiterTimeouts() { + return numMultiplexedSessionWaiterTimeouts.get(); + } + private void initPool() { synchronized (lock) { poolMaintainer.init(); if (options.getMinSessions() > 0) { createSessions(options.getMinSessions(), true); } + if (options.getUseMultiplexedSession()) { + createMultiplexedSessions(); + } } } @@ -2922,6 +3103,32 @@ boolean isValid() { } } + /** + * Returns a multiplexed session. We would always return the session which is at the front of the + * queue since this will be the most recently created session. + */ + SessionFuture getMultiplexedSessionWithFallback() throws SpannerException { + if (options.getUseMultiplexedSession()) { + ISpan span = tracer.getCurrentSpan(); + span.addAnnotation("Acquiring multiplexed session"); + MultiplexedSessionWaiterFuture waiter = null; + synchronized (lock) { + MultiplexedSession session = multiplexedSessions.peek(); + if (session != null) { + span.addAnnotation("Acquired multiplexed session", "sessionId", session.getName()); + incrementNumSessionsInUse(true); + } else { + span.addAnnotation("Multiplexed session un-available. Adding to waiter queue."); + waiter = new MultiplexedSessionWaiterFuture(); + multiplexedSessionWaiters.add(waiter); + } + return checkoutMultiplexedSession(span, session, waiter); + } + } else { + return getSession(); + } + } + /** * Returns a session to be used for requests to spanner. This method is always non-blocking and * returns a {@link PooledSessionFuture}. In case the pool is exhausted and {@link @@ -2969,6 +3176,26 @@ PooledSessionFuture getSession() throws SpannerException { } } + private MultiplexedSessionFuture checkoutMultiplexedSession( + final ISpan span, + final MultiplexedSession readySession, + MultiplexedSessionWaiterFuture waiter) { + ListenableFuture sessionFuture; + if (waiter != null) { + logger.log( + Level.FINE, + "No multiplexed session available. Blocking for one to become available/created"); + span.addAnnotation("Waiting for a multiplexed session to come available"); + sessionFuture = waiter; + } else { + SettableFuture fut = SettableFuture.create(); + fut.set(readySession); + sessionFuture = fut; + } + MultiplexedSessionFuture res = createMultiplexedSessionFuture(sessionFuture, span); + return res; + } + private PooledSessionFuture checkoutSession( final ISpan span, final PooledSession readySession, WaiterFuture waiter) { ListenableFuture sessionFuture; @@ -2988,12 +3215,16 @@ private PooledSessionFuture checkoutSession( return res; } - private void incrementNumSessionsInUse() { + private void incrementNumSessionsInUse(boolean isMultiplexed) { synchronized (lock) { - if (maxSessionsInUse < ++numSessionsInUse) { - maxSessionsInUse = numSessionsInUse; + if (!isMultiplexed) { + if (maxSessionsInUse < ++numSessionsInUse) { + maxSessionsInUse = numSessionsInUse; + } + numSessionsAcquired++; + } else { + numMultiplexedSessionsAcquired++; } - numSessionsAcquired++; } } @@ -3283,6 +3514,13 @@ int totalSessions() { } } + @VisibleForTesting + int totalMultiplexedSessions() { + synchronized (lock) { + return multiplexedSessions.size(); + } + } + private ApiFuture closeSessionAsync(final PooledSession sess) { ApiFuture res = sess.delegate.asyncClose(); res.addListener( @@ -3321,6 +3559,29 @@ private boolean canCreateSession() { } } + private void createMultiplexedSessions() { + logger.log(Level.FINE, String.format("Creating multiplexed sessions")); + synchronized (lock) { + try { + numMultiplexedSessionsBeingCreated++; + sessionClient.createMultiplexedSession(multiplexedSessionConsumer); + } catch (Throwable t) { + numMultiplexedSessionsBeingCreated--; + handleMultiplexedSessionsFailure(newSpannerException(t)); + } + } + } + + private void handleMultiplexedSessionsFailure(SpannerException e) { + // other errors for dialect detection or database not found error is not handled here. + // we are relying on handleCreateSessionsFailure method for its handling + synchronized (lock) { + while (!multiplexedSessionWaiters.isEmpty()) { + multiplexedSessionWaiters.poll().put(e); + } + } + } + private void createSessions(final int sessionCount, boolean distributeOverChannels) { logger.log(Level.FINE, String.format("Creating %d sessions", sessionCount)); synchronized (lock) { @@ -3343,6 +3604,27 @@ private void createSessions(final int sessionCount, boolean distributeOverChanne } } + class MultiplexedSessionConsumer implements SessionConsumer { + + @Override + public void onSessionReady(SessionImpl session) { + final MultiplexedSession multiplexedSession = new MultiplexedSession(session); + synchronized (lock) { + multiplexedSessions.addFirst(multiplexedSession); + multiplexedSessionsInitialized.countDown(); + numMultiplexedSessionsBeingCreated--; + while (!multiplexedSessionWaiters.isEmpty()) { + multiplexedSessionWaiters.poll().put(multiplexedSession); + } + } + } + + @Override + public void onSessionCreateFailure(Throwable t, int createFailureForSessionCount) { + handleMultiplexedSessionsFailure(newSpannerException(t)); + } + } + /** * {@link SessionConsumer} that receives the created sessions from a {@link SessionClient} and * releases these into the pool. The session pool only needs one instance of this, as all sessions diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index 69c8be9a70..7c3e77f348 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -279,6 +279,7 @@ public DatabaseClient getDatabaseClient(DatabaseId db) { labelValues, attributesBuilder.build()); pool.maybeWaitOnMinSessions(); + pool.waitOnMultiplexedSession(); DatabaseClientImpl dbClient = createDatabaseClient(clientId, pool); dbClients.put(db, dbClient); return dbClient; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java index cbfa4bbb60..743d21b587 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/BaseSessionPoolTest.java @@ -29,6 +29,7 @@ import com.google.cloud.spanner.Options.TransactionOption; import com.google.cloud.spanner.spi.v1.SpannerRpc.Option; import com.google.protobuf.Empty; +import com.google.protobuf.Timestamp; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; @@ -121,6 +122,44 @@ public CommitResponse writeWithOptions( return session; } + SessionImpl buildMockMultiplexedSession(ReadContext context, Timestamp creationTime) { + SpannerImpl spanner = mock(SpannerImpl.class); + Map options = new HashMap<>(); + final SessionImpl session = + new SessionImpl( + spanner, + "projects/dummy/instances/dummy/databases/dummy/sessions/session" + sessionIndex, + creationTime, + true, + options) { + @Override + public ReadContext singleUse(TimestampBound bound) { + // The below stubs are added so that we can mock keep-alive. + return context; + } + + @Override + public ApiFuture asyncClose() { + return ApiFutures.immediateFuture(Empty.getDefaultInstance()); + } + + @Override + public CommitResponse writeAtLeastOnceWithOptions( + Iterable mutations, TransactionOption... transactionOptions) + throws SpannerException { + return new CommitResponse(com.google.spanner.v1.CommitResponse.getDefaultInstance()); + } + + @Override + public CommitResponse writeWithOptions( + Iterable mutations, TransactionOption... options) throws SpannerException { + return new CommitResponse(com.google.spanner.v1.CommitResponse.getDefaultInstance()); + } + }; + sessionIndex++; + return session; + } + void runMaintenanceLoop(FakeClock clock, SessionPool pool, long numCycles) { for (int i = 0; i < numCycles; i++) { pool.poolMaintainer.maintainPool(); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionMaintainerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionMaintainerTest.java new file mode 100644 index 0000000000..c4fa65021c --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionMaintainerTest.java @@ -0,0 +1,231 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.initMocks; + +import com.google.cloud.Timestamp; +import com.google.cloud.spanner.SessionClient.SessionConsumer; +import com.google.cloud.spanner.SessionPool.MultiplexedSession; +import com.google.cloud.spanner.SessionPool.MultiplexedSessionConsumer; +import com.google.cloud.spanner.SessionPool.Position; +import com.google.cloud.spanner.SessionPool.SessionFuture; +import io.opencensus.trace.Tracing; +import io.opentelemetry.api.OpenTelemetry; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.threeten.bp.Duration; +import org.threeten.bp.Instant; + +@RunWith(JUnit4.class) +public class MultiplexedSessionMaintainerTest extends BaseSessionPoolTest { + private ExecutorService executor = Executors.newSingleThreadExecutor(); + private @Mock SpannerImpl client; + private @Mock SessionClient sessionClient; + private @Mock SpannerOptions spannerOptions; + private DatabaseId db = DatabaseId.of("projects/p/instances/i/databases/unused"); + private SessionPoolOptions options; + private FakeClock clock = new FakeClock(); + private List multiplexedSessionsRemoved = new ArrayList<>(); + + @Before + public void setUp() { + initMocks(this); + when(client.getOptions()).thenReturn(spannerOptions); + when(client.getSessionClient(db)).thenReturn(sessionClient); + when(sessionClient.getSpanner()).thenReturn(client); + when(spannerOptions.getNumChannels()).thenReturn(4); + when(spannerOptions.getDatabaseRole()).thenReturn("role"); + options = + SessionPoolOptions.newBuilder() + .setMinSessions(1) + .setMaxIdleSessions(1) + .setMaxSessions(5) + .setIncStep(1) + .setKeepAliveIntervalMinutes(2) + .setUseMultiplexedSession(true) + .build(); + multiplexedSessionsRemoved.clear(); + } + + @Test + public void + testMaintainMultiplexedSession_whenNewSessionCreated_assertThatStaleSessionIsRemoved() { + doAnswer( + invocation -> { + MultiplexedSessionConsumer consumer = + invocation.getArgument(0, MultiplexedSessionConsumer.class); + ReadContext mockContext = mock(ReadContext.class); + Timestamp timestamp = + Timestamp.ofTimeSecondsAndNanos( + Instant.ofEpochMilli(clock.currentTimeMillis.get()).getEpochSecond(), 0); + consumer.onSessionReady( + setupMockSession( + buildMockMultiplexedSession(mockContext, timestamp.toProto()), mockContext)); + return null; + }) + .when(sessionClient) + .createMultiplexedSession(any(SessionConsumer.class)); + SessionPool pool = createPool(); + + // Run one maintenance loop. + SessionFuture session1 = pool.getMultiplexedSessionWithFallback(); + runMaintenanceLoop(clock, pool, 1); + assertEquals(1, pool.totalMultiplexedSessions()); + assertTrue(multiplexedSessionsRemoved.isEmpty()); + + // Advance clock by 4 days + clock.currentTimeMillis.addAndGet(Duration.ofDays(4).toMillis()); + // Run one maintenance loop. the first session would not be stale yet since it has now existed + // for less than 7 days. + runMaintenanceLoop(clock, pool, 1); + assertEquals(1, pool.totalMultiplexedSessions()); + assertTrue(multiplexedSessionsRemoved.isEmpty()); + + // Advance clock by 5 days + clock.currentTimeMillis.addAndGet(Duration.ofDays(5).toMillis()); + + // Run second maintenance loop. the first session would now be stale since it has now existed + // for + // more than 9 days. + runMaintenanceLoop(clock, pool, 1); + + SessionFuture session2 = pool.getMultiplexedSessionWithFallback(); + assertNotEquals(session1.getName(), session2.getName()); + assertEquals(1, pool.totalMultiplexedSessions()); + assertEquals(1, multiplexedSessionsRemoved.size()); + assertTrue(multiplexedSessionsRemoved.contains(session1.get())); + + // Advance clock by 8 days + clock.currentTimeMillis.addAndGet(Duration.ofDays(8).toMillis()); + + // Run third maintenance loop. the second session would now be stale since it has now existed + // for + // more than 7 days + runMaintenanceLoop(clock, pool, 1); + + SessionFuture session3 = pool.getMultiplexedSessionWithFallback(); + assertNotEquals(session2.getName(), session3.getName()); + assertEquals(1, pool.totalMultiplexedSessions()); + assertEquals(2, multiplexedSessionsRemoved.size()); + assertTrue(multiplexedSessionsRemoved.contains(session2.get())); + } + + @Test + public void + testMaintainMultiplexedSession_whenSessionCreationFailed_assertThatStaleSessionIsNotRemoved() { + doAnswer( + invocation -> { + MultiplexedSessionConsumer consumer = + invocation.getArgument(0, MultiplexedSessionConsumer.class); + ReadContext mockContext = mock(ReadContext.class); + Timestamp timestamp = + Timestamp.ofTimeSecondsAndNanos( + Instant.ofEpochMilli(clock.currentTimeMillis.get()).getEpochSecond(), 0); + consumer.onSessionReady( + setupMockSession( + buildMockMultiplexedSession(mockContext, timestamp.toProto()), mockContext)); + return null; + }) + .when(sessionClient) + .createMultiplexedSession(any(SessionConsumer.class)); + SessionPool pool = createPool(); + SessionFuture session1 = pool.getMultiplexedSessionWithFallback(); + + doThrow(RuntimeException.class) + .when(sessionClient) + .createMultiplexedSession(any(SessionConsumer.class)); + // Advance clock by 8 days + clock.currentTimeMillis.addAndGet(Duration.ofDays(8).toMillis()); + + // Run one maintenance loop. the first session would now be stale, but since new session + // creation failed, then the stale session won't be removed. + runMaintenanceLoop(clock, pool, 1); + assertEquals(1, pool.totalMultiplexedSessions()); + assertTrue(multiplexedSessionsRemoved.isEmpty()); + + doAnswer( + invocation -> { + MultiplexedSessionConsumer consumer = + invocation.getArgument(0, MultiplexedSessionConsumer.class); + ReadContext mockContext = mock(ReadContext.class); + Timestamp timestamp = + Timestamp.ofTimeSecondsAndNanos( + Instant.ofEpochMilli(clock.currentTimeMillis.get()).getEpochSecond(), 0); + consumer.onSessionReady( + setupMockSession( + buildMockMultiplexedSession(mockContext, timestamp.toProto()), mockContext)); + return null; + }) + .when(sessionClient) + .createMultiplexedSession(any(SessionConsumer.class)); + assertEquals(session1.getName(), pool.getMultiplexedSessionWithFallback().getName()); + + // Run second maintenance loop. the first session would now be stale since it has now existed + // for + // more than 7 days. + runMaintenanceLoop(clock, pool, 1); + SessionFuture session2 = pool.getMultiplexedSessionWithFallback(); + + assertNotEquals(session1.getName(), session2.getName()); + assertEquals(1, pool.totalMultiplexedSessions()); + assertEquals(1, multiplexedSessionsRemoved.size()); + assertTrue(multiplexedSessionsRemoved.contains(session1.get())); + } + + private SessionImpl setupMockSession(final SessionImpl session, final ReadContext mockContext) { + final ResultSet mockResult = mock(ResultSet.class); + when(mockContext.executeQuery(any(Statement.class))).thenAnswer(invocation -> mockResult); + when(mockResult.next()).thenReturn(true); + return session; + } + + private SessionPool createPool() { + // Allow sessions to be added to the head of the pool in all cases in this test, as it is + // otherwise impossible to know which session exactly is getting pinged at what point in time. + SessionPool pool = + SessionPool.createPool( + options, + new TestExecutorFactory(), + client.getSessionClient(db), + clock, + Position.FIRST, + new TraceWrapper(Tracing.getTracer(), OpenTelemetry.noop().getTracer("")), + OpenTelemetry.noop()); + pool.multiplexedSessionRemovedListener = + input -> { + multiplexedSessionsRemoved.add(input); + return null; + }; + return pool; + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionPoolTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionPoolTest.java new file mode 100644 index 0000000000..a5f96335aa --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionPoolTest.java @@ -0,0 +1,159 @@ +package com.google.cloud.spanner; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.initMocks; + +import com.google.cloud.spanner.SessionPool.MultiplexedSessionConsumer; +import com.google.cloud.spanner.SessionPool.MultiplexedSessionFuture; +import com.google.cloud.spanner.SpannerImpl.ClosedException; +import io.opencensus.trace.Tracing; +import io.opentelemetry.api.OpenTelemetry; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.threeten.bp.Duration; + +/** + * Tests for {@link com.google.cloud.spanner.SessionPool.MultiplexedSession} component within the + * {@link SessionPool} class. + */ +public class MultiplexedSessionPoolTest extends BaseSessionPoolTest { + @Mock SpannerImpl client; + @Mock SessionClient sessionClient; + @Mock SpannerOptions spannerOptions; + private final DatabaseId db = DatabaseId.of("projects/p/instances/i/databases/unused"); + private final TraceWrapper tracer = + new TraceWrapper(Tracing.getTracer(), OpenTelemetry.noop().getTracer("")); + private final ExecutorService executor = Executors.newSingleThreadExecutor(); + + SessionPoolOptions options; + SessionPool pool; + + private SessionPool createPool() { + return SessionPool.createPool( + options, + new TestExecutorFactory(), + client.getSessionClient(db), + tracer, + OpenTelemetry.noop()); + } + + @Before + public void setUp() { + initMocks(this); + SpannerOptions.resetActiveTracingFramework(); + SpannerOptions.enableOpenTelemetryTraces(); + when(client.getOptions()).thenReturn(spannerOptions); + when(client.getSessionClient(db)).thenReturn(sessionClient); + when(sessionClient.getSpanner()).thenReturn(client); + when(spannerOptions.getNumChannels()).thenReturn(4); + when(spannerOptions.getDatabaseRole()).thenReturn("role"); + options = + SessionPoolOptions.newBuilder() + .setMinSessions(2) + .setMaxSessions(2) + .setUseMultiplexedSession(true) + .build(); + } + + @Test + public void testGetMultiplexedSession_whenClosedPool_assertSessionReturned() { + pool = createPool(); + assertTrue(pool.isValid()); + closePoolWithStacktrace(); + + // checking out a multiplexed session does not throw error even if pool is closed + MultiplexedSessionFuture multiplexedSessionFuture = + (MultiplexedSessionFuture) pool.getMultiplexedSessionWithFallback(); + assertNotNull(multiplexedSessionFuture); + + // checking out a regular session throws error. + IllegalStateException e = assertThrows(IllegalStateException.class, () -> pool.getSession()); + assertThat(e.getCause()).isInstanceOf(ClosedException.class); + StringWriter sw = new StringWriter(); + e.getCause().printStackTrace(new PrintWriter(sw)); + assertThat(sw.toString()).contains("closePoolWithStacktrace"); + } + + private void closePoolWithStacktrace() { + pool.closeAsync(new SpannerImpl.ClosedException()); + } + + @Test + public void sessionCreation() { + setupMockMultiplexedSessionCreation(); + pool = createPool(); + try (MultiplexedSessionFuture sessionFuture = + (MultiplexedSessionFuture) pool.getMultiplexedSessionWithFallback()) { + assertNotNull(sessionFuture); + } + } + + @Test + public void testSynchronousPoolInit_hasAtMostOneMultiplexedSession() { + setupMockMultiplexedSessionCreation(); + + pool = createPool(); + pool.waitOnMultiplexedSession(); + + assertEquals(1, pool.totalMultiplexedSessions()); + Session session1 = pool.getMultiplexedSessionWithFallback().get(); + Session session2 = pool.getMultiplexedSessionWithFallback().get(); + assertEquals(session1, session2); + + session2.close(); + session1.close(); + verify(sessionClient, times(1)).createMultiplexedSession(any(MultiplexedSessionConsumer.class)); + } + + @Test + public void testGetMultiplexedSession_whenSessionCreationFailed_assertErrorForWaiters() { + doAnswer( + invocation -> { + MultiplexedSessionConsumer consumer = + invocation.getArgument(0, MultiplexedSessionConsumer.class); + consumer.onSessionCreateFailure( + SpannerExceptionFactory.newSpannerException(ErrorCode.INTERNAL, ""), 1); + return null; + }) + .when(sessionClient) + .createMultiplexedSession(any(MultiplexedSessionConsumer.class)); + options = + options + .toBuilder() + .setMinSessions(2) + .setUseMultiplexedSession(true) + .setAcquireSessionTimeout( + Duration.ofMillis(50)) // block for a max of 100 ms for session to be available + .build(); + pool = createPool(); + SpannerException e = + assertThrows(SpannerException.class, () -> pool.getMultiplexedSessionWithFallback().get()); + assertEquals(ErrorCode.RESOURCE_EXHAUSTED, e.getErrorCode()); + } + + private void setupMockMultiplexedSessionCreation() { + doAnswer( + invocation -> { + MultiplexedSessionConsumer consumer = + invocation.getArgument(0, MultiplexedSessionConsumer.class); + consumer.onSessionReady(mockMultiplexedSession()); + return null; + }) + .when(sessionClient) + .createMultiplexedSession(any(MultiplexedSessionConsumer.class)); + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java index e1c295c14d..00b51b1329 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java @@ -137,6 +137,9 @@ public class SessionPoolTest extends BaseSessionPoolTest { private final ExecutorService executor = Executors.newSingleThreadExecutor(); @Parameter public int minSessions; + @Parameter(1) + public boolean useMultiplexed; + @Mock SpannerImpl client; @Mock SessionClient sessionClient; @Mock SpannerOptions spannerOptions; @@ -149,9 +152,14 @@ public class SessionPoolTest extends BaseSessionPoolTest { private final TraceWrapper tracer = new TraceWrapper(Tracing.getTracer(), OpenTelemetry.noop().getTracer("")); - @Parameters(name = "min sessions = {0}") + @Parameters(name = "min sessions = {0}, use multiplexed = {1}") public static Collection data() { - return Arrays.asList(new Object[][] {{0}, {1}}); + List params = new ArrayList<>(); + params.add(new Object[] {0, false}); + params.add(new Object[] {1, false}); + params.add(new Object[] {1, true}); + + return params; } private SessionPool createPool() { @@ -239,6 +247,7 @@ public void setUp() { .setMaxSessions(2) .setIncStep(1) .setBlockIfPoolExhausted() + .setUseMultiplexedSession(useMultiplexed) .build(); }