Skip to content

Commit

Permalink
Add ability to skip rollback due to an error (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
maranqz authored Jun 16, 2023
1 parent a590ac6 commit 98ce514
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 3 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/)
and this project adheres to [Semantic Versioning](http://semver.org/).

## [1.3.0] - 2023-06-16

### Added

- Ability to skip rollback due to an error

### Other

- Bumped library versions

## [1.2.2] - 2023-05-20

### Other
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ Compatibility beyond that is not guaranteed.

**For nested transactions with different transaction managers**, you need to use [ChainedMW](trm/manager/chain.go) ([docs](https://pkg.go.dev/github.com/github.com/avito-tech/go-transaction-manager)).

**To skip a transaction rollback due to an error, use [ErrSkip](blob/main/trm/manager.go#L20) or [Skippable](blob/main/trm/manager.go#L24)**

### Explanation of the approach ([English](https://www.youtube.com/watch?v=aRsea6FFAyA), [Russian](https://habr.com/ru/companies/avito/articles/727168/))

### Examples with an ideal *repository* and nested transactions.
Expand Down
47 changes: 46 additions & 1 deletion trm/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@ package trm

//go:generate mockgen -source=$GOFILE -destination=mock/$GOFILE -package=mock

import "context"
import (
"context"
"errors"

"go.uber.org/multierr"
)

// Manager manages a transaction from Begin to Commit or Rollback.
type Manager interface {
Expand All @@ -11,3 +16,43 @@ type Manager interface {
// DoWithSettings processes a transaction inside a closure with custom trm.Settings.
DoWithSettings(context.Context, Settings, func(ctx context.Context) error) error
}

// ErrSkip marks error to skip rollback for transaction because of inside error.
var ErrSkip = errors.New("skippable")

// Skippable marks error as ErrSkip.
func Skippable(err error) error {
if err == nil {
return nil
}

return multierr.Append(err, ErrSkip)
}

// UnSkippable removes ErrSkip from error.
func UnSkippable(err error) error {
if err == nil || !IsSkippable(err) {
return err
}

ee := multierr.Errors(err)
res := make([]error, 0, len(ee))

for _, e := range ee {
//nolint:errorlint,goerr113
if e != ErrSkip {
res = append(res, e)
}
}

return multierr.Combine(res...)
}

// IsSkippable checks that the error is ErrSkip.
func IsSkippable(err error) bool {
if err == nil {
return false
}

return errors.Is(err, ErrSkip)
}
13 changes: 11 additions & 2 deletions trm/manager/closer.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func newTxCommit(tr trm.Transaction, l logger, c context.CancelFunc) Closer {
}).close
}

//nolint:funlen
func (c *trCloser) close(ctx context.Context, p interface{}, errInProcessTr *error) error {
defer c.cancel()

Expand All @@ -42,6 +43,7 @@ func (c *trCloser) close(ctx context.Context, p interface{}, errInProcessTr *err
}

hasError := *errInProcessTr != nil
isErrSkippable := hasError && trm.IsSkippable(*errInProcessTr)
// TODO not sure that context errors should be propagated.
isCtxCanceled := errors.Is(*errInProcessTr, context.Canceled)
isCtxDeadlineExceeded := errors.Is(*errInProcessTr, context.DeadlineExceeded)
Expand Down Expand Up @@ -73,7 +75,7 @@ func (c *trCloser) close(ctx context.Context, p interface{}, errInProcessTr *err
return trm.ErrAlreadyClosed
}

if hasError {
if hasError && !isErrSkippable {
if errRollback := c.tr.Rollback(ctx); errRollback != nil {
return multierr.Combine(*errInProcessTr, trm.ErrRollback, errRollback)
}
Expand All @@ -82,7 +84,14 @@ func (c *trCloser) close(ctx context.Context, p interface{}, errInProcessTr *err
}

if err := c.tr.Commit(ctx); err != nil {
return multierr.Combine(trm.ErrCommit, err)
var errUnSkipped error
if isErrSkippable {
errUnSkipped = trm.UnSkippable(*errInProcessTr)
}

return multierr.Combine(trm.ErrCommit, err, errUnSkipped)
} else if isErrSkippable {
return *errInProcessTr
}

return nil
Expand Down
53 changes: 53 additions & 0 deletions trm/manager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,59 @@ func Test_transactionManager_Do_Error(t *testing.T) {
assert.ErrorIs(t, err, trm.ErrRollback)
},
},
"skip_rollback_with_error": {
args: defaultArgs,
fields: func(t *testing.T, ctrl *gomock.Controller, a args) fields {
return fields{
factory: func(ctx context.Context, _ trm.Settings) (context.Context, trm.Transaction, error) {
tx := mock.NewMockTransaction(ctrl)

tx.EXPECT().
IsActive().
Return(true)
tx.EXPECT().
Commit(gomock.Any())

return ctx, tx, nil
},
settings: a.settings,
log: mock_log.NewMocklogger(ctrl),
}
},
ret: trm.Skippable(testErr),
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, testErr) &&
assert.True(t, trm.IsSkippable(err))
},
},
"skip_rollback_with_commit_error": {
args: defaultArgs,
fields: func(t *testing.T, ctrl *gomock.Controller, a args) fields {
return fields{
factory: func(ctx context.Context, _ trm.Settings) (context.Context, trm.Transaction, error) {
tx := mock.NewMockTransaction(ctrl)

tx.EXPECT().
IsActive().
Return(true)
tx.EXPECT().
Commit(gomock.Any()).
Return(testCommitErr)

return ctx, tx, nil
},
settings: a.settings,
log: mock_log.NewMocklogger(ctrl),
}
},
ret: trm.Skippable(testErr),
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, testErr) &&
assert.ErrorIs(t, err, testCommitErr) &&
assert.ErrorIs(t, err, trm.ErrCommit) &&
assert.False(t, trm.IsSkippable(err))
},
},
//nolint:dupl
"commit_error": {
args: defaultArgs,
Expand Down
135 changes: 135 additions & 0 deletions trm/manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package trm

import (
"errors"
"testing"

"github.com/stretchr/testify/assert"
)

var errTest = errors.New("test")

func TestIsSkippable(t *testing.T) {
t.Parallel()

type args struct {
err error
}

tests := map[string]struct {
args args
want bool
}{
"skippable": {
args: args{
err: Skippable(Skippable(errTest)),
},
want: true,
},
"unSkippable": {
args: args{
err: errTest,
},
want: false,
},
"nil": {
args: args{
err: nil,
},
want: false,
},
}
for name, tt := range tests {
tt := tt
t.Run(name, func(t *testing.T) {
t.Parallel()

got := IsSkippable(tt.args.err)

assert.Equal(t, tt.want, got)
})
}
}

func TestSkippable(t *testing.T) {
t.Parallel()

type args struct {
err error
}

tests := map[string]struct {
args args
wantErr assert.ErrorAssertionFunc
}{
"skippable": {
args: args{
err: Skippable(Skippable(errTest)),
},
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.ErrorIs(t, err, ErrSkip) &&
assert.ErrorIs(t, err, errTest)
},
},
"nil": {
args: args{
err: Skippable(Skippable(nil)),
},
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.Nil(t, err)
},
},
}
for name, tt := range tests {
tt := tt
t.Run(name, func(t *testing.T) {
t.Parallel()

err := Skippable(tt.args.err)

tt.wantErr(t, err)
})
}
}

func TestUnSkippable(t *testing.T) {
t.Parallel()

type args struct {
err error
}

tests := map[string]struct {
args args
wantErr assert.ErrorAssertionFunc
}{
"unSkippable": {
args: args{
err: UnSkippable(UnSkippable(
Skippable(Skippable(errTest)))),
},
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.NotErrorIs(t, err, ErrSkip) &&
assert.ErrorIs(t, err, errTest)
},
},
"nil": {
args: args{
err: UnSkippable(UnSkippable(nil)),
},
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.Nil(t, err)
},
},
}
for name, tt := range tests {
tt := tt
t.Run(name, func(t *testing.T) {
t.Parallel()

err := UnSkippable(tt.args.err)

tt.wantErr(t, err)
})
}
}

0 comments on commit 98ce514

Please sign in to comment.