From 1598184f74599c6f0a5e053f167f90f8747250c1 Mon Sep 17 00:00:00 2001 From: Adriano Caloiaro Date: Sat, 11 Nov 2023 10:44:10 -0700 Subject: [PATCH] fix: Add deadline to pool connection acquisition --- backends/postgres/postgres_backend.go | 89 +++++++++++++++++----- backends/postgres/postgres_backend_test.go | 84 ++++++++++++++++++++ neoq.go | 1 + 3 files changed, 157 insertions(+), 17 deletions(-) diff --git a/backends/postgres/postgres_backend.go b/backends/postgres/postgres_backend.go index 4ee71be..6b4edc6 100644 --- a/backends/postgres/postgres_backend.go +++ b/backends/postgres/postgres_backend.go @@ -58,12 +58,15 @@ const ( type contextKey struct{} var ( - txCtxVarKey contextKey - shutdownJobID = "-1" // job ID announced when triggering a shutdown - shutdownAnnouncementAllowance = 100 // ms - ErrCnxString = errors.New("invalid connecton string: see documentation for valid connection strings") - ErrDuplicateJob = errors.New("duplicate job") - ErrNoTransactionInContext = errors.New("context does not have a Tx set") + // DefaultConnectionTimeout defines the default amount of time that Neoq waits for connections to become available. 
+ DefaultConnectionTimeout = 30 * time.Second + txCtxVarKey contextKey + shutdownJobID = "-1" // job ID announced when triggering a shutdown + shutdownAnnouncementAllowance = 100 // ms + ErrCnxString = errors.New("invalid connecton string: see documentation for valid connection strings") + ErrDuplicateJob = errors.New("duplicate job") + ErrNoTransactionInContext = errors.New("context does not have a Tx set") + ErrExceededConnectionPoolTimeout = errors.New("exceeded timeout acquiring a connection from the pool") ) // PgBackend is a Postgres-based Neoq backend @@ -72,8 +75,8 @@ type PgBackend struct { config *neoq.Config logger logging.Logger cron *cron.Cron - mu *sync.RWMutex // mutex to protect mutating state on a pgWorker pool *pgxpool.Pool + mu *sync.RWMutex // mutex to protect mutating state on a pgWorker futureJobs map[string]*jobs.Job // map of future job IDs to the corresponding job record handlers map[string]handler.Handler // a map of queue names to queue handlers cancelFuncs []context.CancelFunc // A collection of cancel functions to be called upon Shutdown() @@ -106,6 +109,7 @@ type PgBackend struct { func Backend(ctx context.Context, opts ...neoq.ConfigOption) (pb neoq.Neoq, err error) { cfg := neoq.NewConfig() cfg.IdleTransactionTimeout = neoq.DefaultIdleTxTimeout + cfg.PGConnectionTimeout = DefaultConnectionTimeout p := &PgBackend{ mu: &sync.RWMutex{}, @@ -187,6 +191,15 @@ func WithTransactionTimeout(txTimeout int) neoq.ConfigOption { } } +// WithConnectionTimeout sets the duration that Neoq waits for connections to become available to process and enqueue jobs +// +// Note: ConnectionTimeout does not affect how long neoq waits for connections to run schema migrations +func WithConnectionTimeout(timeout time.Duration) neoq.ConfigOption { + return func(c *neoq.Config) { + c.PGConnectionTimeout = timeout + } +} + // WithSynchronousCommit enables postgres parameter `synchronous_commit`. // // By default, neoq runs with synchronous_commit disabled. 
@@ -281,7 +294,7 @@ func (p *PgBackend) Enqueue(ctx context.Context, job *jobs.Job) (jobID string, e p.logger.Debug("enqueueing job payload", slog.String("queue", job.Queue), slog.Any("job_payload", job.Payload)) p.logger.Debug("acquiring new connection from connection pool", slog.String("queue", job.Queue)) - conn, err := p.pool.Acquire(ctx) + conn, err := p.acquire(ctx) if err != nil { err = fmt.Errorf("error acquiring connection: %w", err) return @@ -541,13 +554,18 @@ func (p *PgBackend) start(ctx context.Context, h handler.Handler) (err error) { return fmt.Errorf("%w: %s", handler.ErrNoHandlerForQueue, h.Queue) } - listenJobChan, ready := p.listen(ctx, h.Queue) // listen for 'new' jobs + listenJobChan, ready, errCh := p.listen(ctx, h.Queue) // listen for 'new' jobs defer close(ready) pendingJobsChan := p.pendingJobs(ctx, h.Queue) // process overdue jobs *at startup* // wait for the listener to connect and be ready to listen - <-ready + select { + case <-ready: + break + case err = <-errCh: + return + } // process all future jobs and retries go func() { p.scheduleFutureJobs(ctx, h.Queue) }() @@ -654,7 +672,7 @@ func (p *PgBackend) scheduleFutureJobs(ctx context.Context, queue string) { // // Announced jobs are executed by the first worker to respond to the announcement. 
func (p *PgBackend) announceJob(ctx context.Context, queue, jobID string) { - conn, err := p.pool.Acquire(ctx) + conn, err := p.acquire(ctx) if err != nil { return } @@ -684,7 +702,7 @@ func (p *PgBackend) announceJob(ctx context.Context, queue, jobID string) { func (p *PgBackend) pendingJobs(ctx context.Context, queue string) (jobsCh chan string) { jobsCh = make(chan string) - conn, err := p.pool.Acquire(ctx) + conn, err := p.acquire(ctx) if err != nil { p.logger.Error( "failed to acquire database connection to listen for pending queue items", @@ -716,7 +734,7 @@ func (p *PgBackend) pendingJobs(ctx context.Context, queue string) (jobsCh chan } }(ctx) - return + return jobsCh } // handleJob is the workhorse of Neoq @@ -726,7 +744,7 @@ func (p *PgBackend) pendingJobs(ctx context.Context, queue string) (jobsCh chan func (p *PgBackend) handleJob(ctx context.Context, jobID string, h handler.Handler) (err error) { var job *jobs.Job var tx pgx.Tx - conn, err := p.pool.Acquire(ctx) + conn, err := p.acquire(ctx) if err != nil { return } @@ -784,14 +802,16 @@ func (p *PgBackend) handleJob(ctx context.Context, jobID string, h handler.Handl // TODO: There is currently no handling of listener disconnects in PgBackend. // This will lead to jobs not getting processed until the worker is restarted. // Implement disconnect handling. 
-func (p *PgBackend) listen(ctx context.Context, queue string) (c chan string, ready chan bool) { +func (p *PgBackend) listen(ctx context.Context, queue string) (c chan string, ready chan bool, errCh chan error) { c = make(chan string, p.handlers[queue].Concurrency) ready = make(chan bool) + errCh = make(chan error) go func(ctx context.Context) { - conn, err := p.pool.Acquire(ctx) + conn, err := p.acquire(ctx) if err != nil { p.logger.Error("unable to acquire new listener connection", slog.String("queue", queue), slog.Any("error", err)) + errCh <- err return } defer p.release(ctx, conn, queue) @@ -801,6 +821,7 @@ func (p *PgBackend) listen(ctx context.Context, queue string) (c chan string, re if err != nil { err = fmt.Errorf("unable to configure listener connection: %w", err) p.logger.Error("unable to configure listener connection", slog.String("queue", queue), slog.Any("error", err)) + errCh <- err return } @@ -833,7 +854,7 @@ func (p *PgBackend) listen(ctx context.Context, queue string) (c chan string, re } }(ctx) - return c, ready + return c, ready, errCh } func (p *PgBackend) release(ctx context.Context, conn *pgxpool.Conn, queue string) { @@ -873,6 +894,40 @@ func (p *PgBackend) getPendingJobID(ctx context.Context, conn *pgxpool.Conn, que return } +// acquire acquires connections from the connection pool with a timeout +// +// the purpose of this function is to skirt pgxpool's default blocking behavior with connection acquisition preemption +func (p *PgBackend) acquire(ctx context.Context) (conn *pgxpool.Conn, err error) { + ctx, cancelFunc := context.WithDeadline(ctx, time.Now().Add(p.config.PGConnectionTimeout)) + defer cancelFunc() + + p.logger.Debug("acquiring connection with timeout", slog.Any("timeout", p.config.PGConnectionTimeout)) + + connCh := make(chan *pgxpool.Conn) + errCh := make(chan error) + + go func() { + c, err := p.pool.Acquire(ctx) + if err != nil { + errCh <- err + } + + connCh <- c + }() + + select { + case conn = <-connCh: + return conn, 
nil + case err := <-errCh: + return nil, err + case <-ctx.Done(): + p.logger.Error("exceeded timeout acquiring a connection from the pool", slog.Any("timeout", p.config.PGConnectionTimeout)) + cancelFunc() + err = ErrExceededConnectionPoolTimeout + return + } +} + // withJobContext creates a new context with the Job set func withJobContext(ctx context.Context, j *jobs.Job) context.Context { return context.WithValue(ctx, internal.JobCtxVarKey, j) diff --git a/backends/postgres/postgres_backend_test.go b/backends/postgres/postgres_backend_test.go index 57bc0e3..e2dc4c0 100644 --- a/backends/postgres/postgres_backend_test.go +++ b/backends/postgres/postgres_backend_test.go @@ -4,7 +4,9 @@ import ( "context" "errors" "fmt" + "log" "os" + "regexp" "strings" "sync" "sync/atomic" @@ -738,3 +740,85 @@ func TestBasicJobMultipleQueueWithError(t *testing.T) { t.Error("should be dead") } } + +// Test_ConnectionTimeout tests that Start fails with ErrExceededConnectionPoolTimeout when a connection cannot be acquired from the pool within the configured connection timeout +// https://github.com/acaloiaro/neoq/issues/98 +func Test_ConnectionTimeout(t *testing.T) { + connString, _ := prepareAndCleanupDB(t) + + const queue = "testing" + done := make(chan bool) + defer close(done) + + ctx := context.Background() + nq, err := neoq.New(ctx, + neoq.WithBackend(postgres.Backend), + postgres.WithConnectionString(connString), + postgres.WithConnectionTimeout(0*time.Second)) + if err != nil { + t.Fatal(err) + } + + h := handler.New(queue, func(_ context.Context) (err error) { + done <- true + return + }) + + go func() { + err = nq.Start(ctx, h) + done <- true + }() + + timeoutTimer := time.After(5 * time.Second) + select { + case <-timeoutTimer: + err = jobs.ErrJobTimeout + case <-done: + } + + if !errors.Is(err, postgres.ErrExceededConnectionPoolTimeout) { + t.Error(err) + } + + // Create an instance with a non-zero timeout, but only allow a pool size of 1 + // this will trigger a failure to acquire connections when the number of 
Start() calls exceeds 1 + nq, err = neoq.New(ctx, + neoq.WithBackend(postgres.Backend), + postgres.WithConnectionString(maxConnsDBUrl(1)), + postgres.WithConnectionTimeout(100*time.Millisecond)) + if err != nil { + t.Fatal(err) + return + } + + go func() { + err = nq.Start(ctx, h) + if err != nil { + return + } + + err = nq.Start(ctx, h) + done <- true + }() + + timeoutTimer = time.After(5 * time.Second) + select { + case <-timeoutTimer: + err = jobs.ErrJobTimeout + case <-done: + } + + log.Println("The error is", err) + if !errors.Is(err, postgres.ErrExceededConnectionPoolTimeout) { + t.Error(err) + } +} + +func maxConnsDBUrl(maxConns int) (dbURL string) { + dbURL = os.Getenv("TEST_DATABASE_URL") + r := regexp.MustCompile(`pool_max_conns=\d+`) + dbURL = string(r.ReplaceAll([]byte(dbURL), []byte(fmt.Sprintf("pool_max_conns=%d", maxConns)))) + + log.Println("URL", dbURL) + return +} diff --git a/neoq.go b/neoq.go index 3749506..f1022f1 100644 --- a/neoq.go +++ b/neoq.go @@ -38,6 +38,7 @@ type Config struct { ShutdownTimeout time.Duration // duration to wait for jobs to finish during shutdown SynchronousCommit bool // Postgres: Enable synchronous commits (increases durability, decreases performance) LogLevel logging.LogLevel // the log level of the default logger + PGConnectionTimeout time.Duration // the amount of time to wait for a connection to become available before timing out } // ConfigOption is a function that sets optional backend configuration