diff --git a/pkg/ccl/sqlproxyccl/forwarder.go b/pkg/ccl/sqlproxyccl/forwarder.go
index 247498d19835..4aaff1ae7381 100644
--- a/pkg/ccl/sqlproxyccl/forwarder.go
+++ b/pkg/ccl/sqlproxyccl/forwarder.go
@@ -202,6 +202,8 @@ func (f *forwarder) Context() context.Context {
 //
 // Close implements the balancer.ConnectionHandle interface.
 func (f *forwarder) Close() {
+	// Cancelling the forwarder's context and connections will automatically
+	// cause the processors to exit and close themselves.
 	f.ctxCancel()
 
 	// Whenever Close is called while both of the processors are suspended, the
@@ -389,7 +391,10 @@ func makeLogicalClockFn() func() uint64 {
 // cancellation of dials.
 var aLongTimeAgo = timeutil.Unix(1, 0)
 
-var errProcessorResumed = errors.New("processor has already been resumed")
+var (
+	errProcessorResumed = errors.New("processor has already been resumed")
+	errProcessorClosed  = errors.New("processor has been closed")
+)
 
 // processor must always be constructed through newProcessor.
 type processor struct {
@@ -402,6 +407,7 @@ type processor struct {
 	mu struct {
 		syncutil.Mutex
 		cond       *sync.Cond
+		closed     bool
 		resumed    bool
 		inPeek     bool
 		suspendReq bool // Indicates that a suspend has been requested.
@@ -424,13 +430,15 @@ func newProcessor(logicalClockFn func() uint64, src, dst *interceptor.PGConn) *p
 
 // resume starts the processor and blocks during the processing. When the
 // processing has been terminated, this returns nil if the processor can be
-// resumed again in the future. If an error (except errProcessorResumed) was
-// returned, the processor should not be resumed again, and the forwarder should
-// be closed.
-func (p *processor) resume(ctx context.Context) error {
+// resumed again in the future. If an error is returned, the processor should
+// not be resumed again, and the forwarder must be closed.
+func (p *processor) resume(ctx context.Context) (retErr error) {
 	enterResume := func() error {
 		p.mu.Lock()
 		defer p.mu.Unlock()
+		if p.mu.closed {
+			return errProcessorClosed
+		}
 		if p.mu.resumed {
 			return errProcessorResumed
 		}
@@ -441,6 +449,10 @@ func (p *processor) resume(ctx context.Context) error {
 	exitResume := func() {
 		p.mu.Lock()
 		defer p.mu.Unlock()
+		// If there's an error, close the processor.
+		if retErr != nil {
+			p.mu.closed = true
+		}
 		p.mu.resumed = false
 		p.mu.cond.Broadcast()
 	}
@@ -495,6 +507,9 @@ func (p *processor) resume(ctx context.Context) error {
 	}
 
 	if err := enterResume(); err != nil {
+		if errors.Is(err, errProcessorResumed) {
+			return nil
+		}
 		return err
 	}
 	defer exitResume()
@@ -524,6 +539,9 @@ func (p *processor) waitResumed(ctx context.Context) error {
 		if ctx.Err() != nil {
 			return ctx.Err()
 		}
+		if p.mu.closed {
+			return errProcessorClosed
+		}
 		p.mu.cond.Wait()
 	}
 	return nil
 }
@@ -536,6 +554,11 @@
 	p.mu.Lock()
 	defer p.mu.Unlock()
 
+	// If the processor has been closed, it cannot be suspended at all.
+	if p.mu.closed {
+		return errProcessorClosed
+	}
+
 	defer func() {
 		if p.mu.suspendReq {
 			p.mu.suspendReq = false
diff --git a/pkg/ccl/sqlproxyccl/forwarder_test.go b/pkg/ccl/sqlproxyccl/forwarder_test.go
index a9e32f37e8f1..848bd46c9191 100644
--- a/pkg/ccl/sqlproxyccl/forwarder_test.go
+++ b/pkg/ccl/sqlproxyccl/forwarder_test.go
@@ -521,12 +521,15 @@ func TestSuspendResumeProcessor(t *testing.T) {
 			interceptor.NewPGConn(serverProxy),
 		)
 		require.EqualError(t, p.resume(ctx), context.Canceled.Error())
+		p.mu.Lock()
+		require.True(t, p.mu.closed)
+		p.mu.Unlock()
 
 		// Set resumed to true to simulate suspend loop.
 		p.mu.Lock()
 		p.mu.resumed = true
 		p.mu.Unlock()
-		require.EqualError(t, p.suspend(ctx), context.Canceled.Error())
+		require.EqualError(t, p.suspend(ctx), errProcessorClosed.Error())
 	})
 
 	t.Run("wait_for_resumed", func(t *testing.T) {
@@ -586,15 +589,15 @@ func TestSuspendResumeProcessor(t *testing.T) {
 			interceptor.NewPGConn(serverProxy),
 		)
 
-		// Ensure that everything will return a resumed error except 1.
+		// Ensure that two of the three resume calls return right away.
 		errCh := make(chan error, 2)
 		go func() { errCh <- p.resume(ctx) }()
 		go func() { errCh <- p.resume(ctx) }()
 		go func() { errCh <- p.resume(ctx) }()
 		err := <-errCh
-		require.EqualError(t, err, errProcessorResumed.Error())
+		require.NoError(t, err)
 		err = <-errCh
-		require.EqualError(t, err, errProcessorResumed.Error())
+		require.NoError(t, err)
 
 		// Suspend the last goroutine.
 		err = p.waitResumed(ctx)
@@ -604,7 +607,7 @@ func TestSuspendResumeProcessor(t *testing.T) {
 
 		// Validate suspension.
 		err = <-errCh
-		require.Nil(t, err)
+		require.NoError(t, err)
 		p.mu.Lock()
 		require.False(t, p.mu.resumed)
 		require.False(t, p.mu.inPeek)
@@ -694,10 +697,7 @@ func TestSuspendResumeProcessor(t *testing.T) {
 		// Wait until all resume calls except 1 have returned.
 		for i := 0; i < concurrency-1; i++ {
 			err := <-errResumeCh
-			// If error is not nil, it has to be an already resumed error.
-			if err != nil {
-				require.EqualError(t, err, errProcessorResumed.Error())
-			}
+			require.NoError(t, err)
 		}
 
 		// Wait until the last one returns. We can guarantee that this is for
diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go
index 4f13591a54db..31dba30c3c4d 100644
--- a/pkg/ccl/sqlproxyccl/proxy_handler.go
+++ b/pkg/ccl/sqlproxyccl/proxy_handler.go
@@ -407,7 +407,7 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
 	if err := f.run(fe.Conn, crdbConn); err != nil {
 		// Don't send to the client here for the same reason below.
 		handler.metrics.updateForError(err)
-		return err
+		return errors.Wrap(err, "running forwarder")
 	}
 
 	// Block until an error is received, or when the stopper starts quiescing,
diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go
index b83b8c11d9f4..6e9ac3ed6c32 100644
--- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go
+++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go
@@ -1325,6 +1325,76 @@ func TestConnectionMigration(t *testing.T) {
 	}, 10*time.Second, 100*time.Millisecond)
 }
 
+// TestCurConnCountMetric ensures that the CurConnCount metric is accurate.
+// Previously, a goroutine leak caused the CurConnCount metric to not be
+// decremented whenever connections were closed.
+func TestCurConnCountMetric(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	defer log.Scope(t).Close(t)
+
+	ctx := context.Background()
+
+	// Start KV server.
+	params, _ := tests.CreateTestServerParams()
+	s, _, _ := serverutils.StartServer(t, params)
+	defer s.Stopper().Stop(ctx)
+
+	// Start a single SQL pod.
+	tenantID := serverutils.TestTenantID()
+	tenants := startTestTenantPods(ctx, t, s, tenantID, 1)
+	defer func() {
+		for _, tenant := range tenants {
+			tenant.Stopper().Stop(ctx)
+		}
+	}()
+
+	// Register the SQL pod in the directory server.
+	tds := tenantdirsvr.NewTestStaticDirectoryServer(s.Stopper(), nil /* timeSource */)
+	tds.CreateTenant(tenantID, "tenant-cluster")
+	tds.AddPod(tenantID, &tenant.Pod{
+		TenantID:       tenantID.ToUint64(),
+		Addr:           tenants[0].SQLAddr(),
+		State:          tenant.RUNNING,
+		StateTimestamp: timeutil.Now(),
+	})
+	require.NoError(t, tds.Start(ctx))
+
+	opts := &ProxyOptions{SkipVerify: true, DisableConnectionRebalancing: true}
+	opts.testingKnobs.directoryServer = tds
+	proxy, addr := newSecureProxyServer(ctx, t, s.Stopper(), opts)
+	connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID)
+
+	// Open 500 connections to the SQL pod.
+	const numConns = 500
+	var wg sync.WaitGroup
+	wg.Add(numConns)
+	for i := 0; i < numConns; i++ {
+		go func() {
+			defer wg.Done()
+
+			// Open a new connection, ping it, and close it right away.
+			// Ignore all connection errors.
+			conn, err := pgx.Connect(ctx, connectionString)
+			if err != nil {
+				return
+			}
+			_ = conn.Ping(ctx)
+			_ = conn.Close(ctx)
+		}()
+	}
+	wg.Wait()
+
+	// Ensure that the CurConnCount metric gets decremented to 0 once all
+	// the connections are closed.
+	testutils.SucceedsSoon(t, func() error {
+		val := proxy.metrics.CurConnCount.Value()
+		if val == 0 {
+			return nil
+		}
+		return errors.Newf("expected CurConnCount=0, but got %d", val)
+	})
+}
+
 func TestClusterNameAndTenantFromParams(t *testing.T) {
 	defer leaktest.AfterTest(t)()
 	defer log.Scope(t).Close(t)
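
Not part of the patch: the standalone Go sketch below uses hypothetical names (proc, errResumed, errClosed, work) to distill the state machine that the forwarder.go changes implement, assuming the same two invariants as the patch: a concurrent resume is a harmless no-op that reports nil, while any real error latches the closed flag so that later resume and suspend calls fail permanently.

package main

import (
	"errors"
	"fmt"
	"sync"
)

var (
	errResumed = errors.New("already resumed") // analogous to errProcessorResumed
	errClosed  = errors.New("closed")          // analogous to errProcessorClosed
)

// proc is a toy stand-in for the processor: resumed marks a running resume
// loop, and closed latches permanently once a resume exits with an error.
type proc struct {
	mu      sync.Mutex
	resumed bool
	closed  bool
}

// resume runs work until it returns. A concurrent resume returns nil right
// away (mirroring how the patch maps errProcessorResumed to nil), and a
// resume on a closed proc reports errClosed.
func (p *proc) resume(work func() error) (retErr error) {
	enter := func() error {
		p.mu.Lock()
		defer p.mu.Unlock()
		if p.closed {
			return errClosed
		}
		if p.resumed {
			return errResumed
		}
		p.resumed = true
		return nil
	}
	exit := func() {
		p.mu.Lock()
		defer p.mu.Unlock()
		if retErr != nil {
			p.closed = true // latch: never resumable again
		}
		p.resumed = false
	}

	if err := enter(); err != nil {
		if errors.Is(err, errResumed) {
			return nil // a concurrent resume is a harmless no-op
		}
		return err
	}
	defer exit()
	return work()
}

// suspend fails permanently once the proc has been closed; the actual
// suspension logic (asking the resume loop to stop) is omitted here.
func (p *proc) suspend() error {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed {
		return errClosed
	}
	return nil
}

func main() {
	p := &proc{}
	fmt.Println(p.resume(func() error { return errors.New("boom") })) // boom
	fmt.Println(p.resume(func() error { return nil }))                // closed
	fmt.Println(p.suspend())                                          // closed
}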