Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [v0.6.0] - 2025-02-05

### Fixed

- Fixed race conditions in lock acquisition in the Postgres backend ([#108](https://github.com/microsoft/durabletask-go/pull/108)) - contributed by [@JonathanFejtek](https://github.com/JonathanFejtek)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JonathanFejtek the v0.6.0 tag has already been published. Can you add a new section above the v0.6.0 section called [vNext] and add your entry there?



### Added

- Add API to set custom status ([#81](https://github.com/microsoft/durabletask-go/pull/81)) - by [@famarting](https://github.com/famarting)
Expand Down
2 changes: 1 addition & 1 deletion backend/postgres/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Postgres Backend
### Testing
By default, the postgres tests are skipped. To run the tests, set the environment variable `POSTGRES_ENABLED` to `true` before running the tests and have a postgres server running on `localhost:5432` with a database named `postgres` and a user `postgres` with password `postgres`.
By default, the postgres tests are skipped. The postgres tests use `github.com/testcontainers/testcontainers-go/modules/postgres`, which requires a Docker-API-compatible container runtime.
199 changes: 116 additions & 83 deletions backend/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"os"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -44,7 +45,7 @@ func NewPostgresOptions(host string, port uint16, database string, user string,
if err != nil {
panic(fmt.Errorf("failed to parse the postgres connection string: %w", err))
}
conf.ConnConfig.Config.ConnectTimeout = 2 * time.Minute
conf.ConnConfig.ConnectTimeout = 2 * time.Minute
conf.MaxConnLifetime = 2 * time.Minute
conf.MaxConnIdleTime = 2 * time.Minute
conf.MaxConns = 1
Expand Down Expand Up @@ -136,7 +137,11 @@ func (be *postgresBackend) AbandonOrchestrationWorkItem(ctx context.Context, wi
if err != nil {
return err
}
defer tx.Rollback(ctx)
defer func() {
if err := tx.Rollback(ctx); err != nil {
be.logger.Error("AbandonOrchestrationWorkItem", "failed to rollback transaction", err)
}
}()

var visibleTime *time.Time = nil
if delay := wi.GetAbandonDelay(); delay > 0 {
Expand All @@ -155,10 +160,7 @@ func (be *postgresBackend) AbandonOrchestrationWorkItem(ctx context.Context, wi
return fmt.Errorf("failed to update NewEvents table: %w", err)
}

rowsAffected := dbResult.RowsAffected()
if err != nil {
return fmt.Errorf("failed get rows affected by UPDATE NewEvents statement: %w", err)
} else if rowsAffected == 0 {
if dbResult.RowsAffected() == 0 {
return backend.ErrWorkItemLockLost
}

Expand All @@ -168,15 +170,11 @@ func (be *postgresBackend) AbandonOrchestrationWorkItem(ctx context.Context, wi
string(wi.InstanceID),
wi.LockedBy,
)

if err != nil {
return fmt.Errorf("failed to update Instances table: %w", err)
}

rowsAffected = dbResult.RowsAffected()
if err != nil {
return fmt.Errorf("failed get rows affected by UPDATE Instances statement: %w", err)
} else if rowsAffected == 0 {
if dbResult.RowsAffected() == 0 {
return backend.ErrWorkItemLockLost
}

Expand All @@ -197,7 +195,11 @@ func (be *postgresBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi
if err != nil {
return err
}
defer tx.Rollback(ctx)
defer func() {
if err := tx.Rollback(ctx); err != nil {
be.logger.Error("CompleteOrchestrationWorkItem", "failed to rollback transaction", err)
}
}()

now := time.Now().UTC()

Expand Down Expand Up @@ -260,10 +262,7 @@ func (be *postgresBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi
return fmt.Errorf("failed to update Instances table: %w", err)
}

count := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get the number of rows affected by the Instance table update: %w", err)
} else if count == 0 {
if result.RowsAffected() == 0 {
return fmt.Errorf("instance '%s' no longer exists or was locked by a different worker", string(wi.InstanceID))
}

Expand All @@ -287,7 +286,7 @@ func (be *postgresBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi
}
query := builder.String()

args := make([]interface{}, 0, newHistoryCount*3)
args := make([]any, 0, newHistoryCount*3)
nextSequenceNumber := len(wi.State.OldEvents())
for _, e := range wi.State.NewEvents() {
eventPayload, err := backend.MarshalHistoryEvent(e)
Expand Down Expand Up @@ -318,7 +317,7 @@ func (be *postgresBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi
}
insertSql := builder.String()

sqlInsertArgs := make([]interface{}, 0, newActivityCount*2)
sqlInsertArgs := make([]any, 0, newActivityCount*2)
for _, e := range wi.State.PendingTasks() {
eventPayload, err := backend.MarshalHistoryEvent(e)
if err != nil {
Expand Down Expand Up @@ -401,17 +400,10 @@ func (be *postgresBackend) CompleteOrchestrationWorkItem(ctx context.Context, wi
return fmt.Errorf("failed to delete from NewEvents table: %w", err)
}

rowsAffected := dbResult.RowsAffected()
if err != nil {
return fmt.Errorf("failed get rows affected by delete statement: %w", err)
} else if rowsAffected == 0 {
if dbResult.RowsAffected() == 0 {
return backend.ErrWorkItemLockLost
}

if err != nil {
return fmt.Errorf("failed to delete from the NewEvents table: %w", err)
}

if err = tx.Commit(ctx); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
Expand All @@ -429,7 +421,11 @@ func (be *postgresBackend) CreateOrchestrationInstance(ctx context.Context, e *b
if err != nil {
return fmt.Errorf("failed to start transaction: %w", err)
}
defer tx.Rollback(ctx)
defer func() {
if err := tx.Rollback(ctx); err != nil {
be.logger.Error("CreateOrchestrationInstance", "failed to rollback transaction", err)
}
}()

var instanceID string
if instanceID, err = be.createOrchestrationInstanceInternal(ctx, e, tx, opts...); errors.Is(err, api.ErrIgnoreInstance) {
Expand Down Expand Up @@ -517,11 +513,7 @@ func insertOrIgnoreInstanceTableInternal(ctx context.Context, tx pgx.Tx, e *back
return -1, fmt.Errorf("failed to insert into Instances table: %w", err)
}

rows := res.RowsAffected()
if err != nil {
return -1, fmt.Errorf("failed to count the rows affected: %w", err)
}
return rows, nil
return res.RowsAffected(), nil
}

func (be *postgresBackend) handleInstanceExists(ctx context.Context, tx pgx.Tx, startEvent *protos.ExecutionStartedEvent, policy *protos.OrchestrationIdReusePolicy, e *backend.HistoryEvent) error {
Expand Down Expand Up @@ -572,12 +564,7 @@ func (be *postgresBackend) handleInstanceExists(ctx context.Context, tx pgx.Tx,
}

// isStatusMatch reports whether runtimeStatus is equal to any entry in
// statuses. A nil or empty statuses slice never matches.
func isStatusMatch(statuses []protos.OrchestrationStatus, runtimeStatus protos.OrchestrationStatus) bool {
	return slices.Contains(statuses, runtimeStatus)
}

func (be *postgresBackend) cleanupOrchestrationStateInternal(ctx context.Context, tx pgx.Tx, id api.InstanceID, requireCompleted bool) error {
Expand All @@ -596,11 +583,7 @@ func (be *postgresBackend) cleanupOrchestrationStateInternal(ctx context.Context
return fmt.Errorf("failed to delete from the Instances table: %w", err)
}

rowsAffected := dbResult.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected in Instances delete operation: %w", err)
}
if rowsAffected == 0 {
if dbResult.RowsAffected() == 0 {
return api.ErrNotCompleted
}
} else {
Expand Down Expand Up @@ -763,37 +746,57 @@ func (be *postgresBackend) GetOrchestrationWorkItem(ctx context.Context) (*backe
if err != nil {
return nil, err
}
defer tx.Rollback(ctx)
defer func() {
if err := tx.Rollback(ctx); err != nil {
be.logger.Error("GetOrchestrationWorkItem", "failed to rollback transaction", err)
}
}()

now := time.Now().UTC()
newLockExpiration := now.Add(be.options.OrchestrationLockTimeout)

// Place a lock on an orchestration instance that has new events that are ready to be executed.
row := tx.QueryRow(
// First, select and lock an instance with FOR UPDATE SKIP LOCKED to prevent race conditions
// This ensures only one worker can acquire the lock on a given instance
selectRow := tx.QueryRow(
ctx,
`UPDATE Instances SET LockedBy = $1, LockExpiration = $2
WHERE SequenceNumber = (
SELECT SequenceNumber FROM Instances I
WHERE (I.LockExpiration IS NULL OR I.LockExpiration < $3) AND EXISTS (
SELECT 1 FROM NewEvents E
WHERE E.InstanceID = I.InstanceID AND (E.VisibleTime IS NULL OR E.VisibleTime < $4)
)
LIMIT 1
) RETURNING InstanceID`,
be.workerName, // LockedBy for Instances table
newLockExpiration, // Updated LockExpiration for Instances table
now, // LockExpiration for Instances table
now, // VisibleTime for NewEvents table
`SELECT I.SequenceNumber, I.InstanceID FROM Instances I
WHERE (I.LockExpiration IS NULL OR I.LockExpiration < $1) AND EXISTS (
SELECT 1 FROM NewEvents E
WHERE E.InstanceID = I.InstanceID AND (E.VisibleTime IS NULL OR E.VisibleTime < $2)
)
ORDER BY I.SequenceNumber
LIMIT 1
FOR UPDATE SKIP LOCKED`,
now, // LockExpiration check
now, // VisibleTime check
)

var sequenceNumber int
var instanceID string
if err := row.Scan(&instanceID); err != nil {
if err := selectRow.Scan(&sequenceNumber, &instanceID); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
// No new events to process
return nil, backend.ErrNoWorkItems
}
return nil, fmt.Errorf("failed to select orchestration work-item: %w", err)
}

return nil, fmt.Errorf("failed to scan the orchestration work-item: %w", err)
// Now update the locked instance with our worker information
updateResult, err := tx.Exec(
ctx,
`UPDATE Instances SET LockedBy = $1, LockExpiration = $2
WHERE SequenceNumber = $3`,
be.workerName,
newLockExpiration,
sequenceNumber,
)
if err != nil {
return nil, fmt.Errorf("failed to lock orchestration work-item: %w", err)
}

if updateResult.RowsAffected() == 0 {
// This should not happen since we have the row locked, but check anyway
return nil, backend.ErrNoWorkItems
}

// TODO: Get all the unprocessed events associated with the locked instance
Expand Down Expand Up @@ -855,33 +858,61 @@ func (be *postgresBackend) GetActivityWorkItem(ctx context.Context) (*backend.Ac
return nil, err
}

// Begin transaction to hold the FOR UPDATE lock
tx, err := be.db.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
return nil, err
}
defer func() {
if err := tx.Rollback(ctx); err != nil {
be.logger.Error("GetActivityWorkItem", "failed to rollback transaction", err)
}
}()

now := time.Now().UTC()
newLockExpiration := now.Add(be.options.OrchestrationLockTimeout)
newLockExpiration := now.Add(be.options.ActivityLockTimeout)

row := be.db.QueryRow(
// First, select and lock a task with FOR UPDATE SKIP LOCKED to prevent race conditions
// The row lock is held by the transaction until commit
selectRow := tx.QueryRow(
ctx,
`UPDATE NewTasks SET LockedBy = $1, LockExpiration = $2, DequeueCount = DequeueCount + 1
WHERE SequenceNumber = (
SELECT SequenceNumber FROM NewTasks T
WHERE T.LockExpiration IS NULL OR T.LockExpiration < $3
LIMIT 1
) RETURNING SequenceNumber, InstanceID, EventPayload`,
be.workerName,
newLockExpiration,
`SELECT SequenceNumber, InstanceID, EventPayload FROM NewTasks T
WHERE T.LockExpiration IS NULL OR T.LockExpiration < $1
ORDER BY SequenceNumber
LIMIT 1
FOR UPDATE SKIP LOCKED`,
now,
)

var sequenceNumber int64
var instanceID string
var eventPayload []byte

if err := row.Scan(&sequenceNumber, &instanceID, &eventPayload); err != nil {
if err := selectRow.Scan(&sequenceNumber, &instanceID, &eventPayload); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
// No new activity tasks to process
return nil, backend.ErrNoWorkItems
}
return nil, fmt.Errorf("failed to select the activity work-item: %w", err)
}

return nil, fmt.Errorf("failed to scan the activity work-item: %w", err)
// Now update the locked task with our worker information
// The row is still locked by our transaction
_, err = tx.Exec(
ctx,
`UPDATE NewTasks SET LockedBy = $1, LockExpiration = $2, DequeueCount = DequeueCount + 1
WHERE SequenceNumber = $3`,
be.workerName,
newLockExpiration,
sequenceNumber,
)
if err != nil {
return nil, fmt.Errorf("failed to lock the activity work-item: %w", err)
}

// Commit the transaction, releasing the lock
if err = tx.Commit(ctx); err != nil {
return nil, fmt.Errorf("failed to commit activity work-item transaction: %w", err)
}

e, err := backend.UnmarshalHistoryEvent(eventPayload)
Expand All @@ -907,7 +938,11 @@ func (be *postgresBackend) CompleteActivityWorkItem(ctx context.Context, wi *bac
if err != nil {
return err
}
defer tx.Rollback(ctx)
defer func() {
if err := tx.Rollback(ctx); err != nil {
be.logger.Error("CompleteActivityWorkItem", "failed to rollback transaction", err)
}
}()

bytes, err := backend.MarshalHistoryEvent(wi.Result)
if err != nil {
Expand All @@ -924,10 +959,7 @@ func (be *postgresBackend) CompleteActivityWorkItem(ctx context.Context, wi *bac
return fmt.Errorf("failed to delete from NewTasks table: %w", err)
}

rowsAffected := dbResult.RowsAffected()
if err != nil {
return fmt.Errorf("failed get rows affected by delete statement: %w", err)
} else if rowsAffected == 0 {
if dbResult.RowsAffected() == 0 {
return backend.ErrWorkItemLockLost
}

Expand All @@ -953,10 +985,7 @@ func (be *postgresBackend) AbandonActivityWorkItem(ctx context.Context, wi *back
return fmt.Errorf("failed to update the NewTasks table for abandon: %w", err)
}

rowsAffected := dbResult.RowsAffected()
if err != nil {
return fmt.Errorf("failed get rows affected by update statement for abandon: %w", err)
} else if rowsAffected == 0 {
if dbResult.RowsAffected() == 0 {
return backend.ErrWorkItemLockLost
}

Expand All @@ -972,7 +1001,11 @@ func (be *postgresBackend) PurgeOrchestrationState(ctx context.Context, id api.I
if err != nil {
return err
}
defer tx.Rollback(ctx)
defer func() {
if err := tx.Rollback(ctx); err != nil {
be.logger.Error("PurgeOrchestrationState", "failed to rollback transaction", err)
}
}()

if err := be.cleanupOrchestrationStateInternal(ctx, tx, id, true); err != nil {
return err
Expand Down
Loading