Skip to content

Commit 5c4016d

Browse files
authored
Merge pull request go-gorm#5455 from longbridgeapp/feat-support-transaction-calllback
2 parents c74bc57 + 2cb4088 commit 5c4016d

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

callbacks.go

+31-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package gorm
22

33
import (
44
"context"
5+
"database/sql"
56
"errors"
67
"fmt"
78
"reflect"
@@ -15,12 +16,13 @@ import (
1516
func initializeCallbacks(db *DB) *callbacks {
1617
return &callbacks{
1718
processors: map[string]*processor{
18-
"create": {db: db},
19-
"query": {db: db},
20-
"update": {db: db},
21-
"delete": {db: db},
22-
"row": {db: db},
23-
"raw": {db: db},
19+
"create": {db: db},
20+
"query": {db: db},
21+
"update": {db: db},
22+
"delete": {db: db},
23+
"row": {db: db},
24+
"raw": {db: db},
25+
"transaction": {db: db},
2426
},
2527
}
2628
}
@@ -72,6 +74,29 @@ func (cs *callbacks) Raw() *processor {
7274
return cs.processors["raw"]
7375
}
7476

77+
func (cs *callbacks) Transaction() *processor {
78+
return cs.processors["transaction"]
79+
}
80+
81+
func (p *processor) Begin(tx *DB, opt *sql.TxOptions) *DB {
82+
var err error
83+
84+
switch beginner := tx.Statement.ConnPool.(type) {
85+
case TxBeginner:
86+
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
87+
case ConnPoolBeginner:
88+
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
89+
default:
90+
err = ErrInvalidTransaction
91+
}
92+
93+
if err != nil {
94+
_ = tx.AddError(err)
95+
}
96+
97+
return tx
98+
}
99+
75100
func (p *processor) Execute(db *DB) *DB {
76101
// call scopes
77102
for len(db.Statement.scopes) > 0 {

finisher_api.go

+1-15
Original file line numberDiff line numberDiff line change
@@ -619,27 +619,13 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
619619
// clone statement
620620
tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
621621
opt *sql.TxOptions
622-
err error
623622
)
624623

625624
if len(opts) > 0 {
626625
opt = opts[0]
627626
}
628627

629-
switch beginner := tx.Statement.ConnPool.(type) {
630-
case TxBeginner:
631-
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
632-
case ConnPoolBeginner:
633-
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
634-
default:
635-
err = ErrInvalidTransaction
636-
}
637-
638-
if err != nil {
639-
tx.AddError(err)
640-
}
641-
642-
return tx
628+
return tx.callbacks.Transaction().Begin(tx, opt)
643629
}
644630

645631
// Commit commit a transaction

0 commit comments

Comments
 (0)