Skip to content

Commit e46467d

Browse files
authored
Merge pull request #18 from knocknote/feature/add_condition_before_setting_global_callback
Add condition before setting global callback
2 parents 27424a2 + 7953a6f commit e46467d

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

database/sql/tx.go

+12-8
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,18 @@ func (proxy *Tx) begin(conn *connection.DBConnection) {
122122
return
123123
}
124124
tx := conn.Begin(proxy.ctx, proxy.opts)
125-
proxy.BeforeCommitCallback(func(writeQueries []*QueryLog) error {
126-
return errors.WithStack(globalBeforeCommitCallback(proxy, writeQueries))
127-
})
128-
proxy.AfterCommitCallback(func() error {
129-
return errors.WithStack(globalAfterCommitSuccessCallback(proxy))
130-
}, func(isCritical bool, failureQueries []*QueryLog) error {
131-
return errors.WithStack(globalAfterCommitFailureCallback(proxy, isCritical, failureQueries))
132-
})
125+
if proxy.beforeCommitCallback == nil {
126+
proxy.BeforeCommitCallback(func(writeQueries []*QueryLog) error {
127+
return errors.WithStack(globalBeforeCommitCallback(proxy, writeQueries))
128+
})
129+
}
130+
if proxy.afterCommitSuccessCallback == nil && proxy.afterCommitFailureCallback == nil {
131+
proxy.AfterCommitCallback(func() error {
132+
return errors.WithStack(globalAfterCommitSuccessCallback(proxy))
133+
}, func(isCritical bool, failureQueries []*QueryLog) error {
134+
return errors.WithStack(globalAfterCommitFailureCallback(proxy, isCritical, failureQueries))
135+
})
136+
}
133137
proxy.tx = tx
134138
}
135139

transaction_test.go

+31
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,37 @@ func TestCommitCallbackForTx(t *testing.T) {
320320
}
321321
}
322322

323+
func TestCommitCallbackForTxSetCallbackBeforeQueryExec(t *testing.T) {
324+
db, err := sql.Open("", "")
325+
if err != nil {
326+
t.Fatalf("%+v\n", err)
327+
}
328+
tx, err := db.Begin()
329+
if err != nil {
330+
t.Fatalf("%+v\n", err)
331+
}
332+
isInvokedBeforeCommitCallback := false
333+
tx.BeforeCommitCallback(func(writeQueries []*sql.QueryLog) error {
334+
isInvokedBeforeCommitCallback = true
335+
return nil
336+
})
337+
isInvokedAfterCommitCallback := false
338+
tx.AfterCommitCallback(func() error {
339+
isInvokedAfterCommitCallback = true
340+
return nil
341+
}, func(isCriticalError bool, failureQueries []*sql.QueryLog) error {
342+
return nil
343+
})
344+
insertRecords(tx, t)
345+
checkErr(t, tx.Commit())
346+
if !isInvokedBeforeCommitCallback {
347+
t.Fatal("cannot invoke callback for before commit")
348+
}
349+
if !isInvokedAfterCommitCallback {
350+
t.Fatal("cannot invoke callback for after commit")
351+
}
352+
}
353+
323354
func testIsAlreadyCommittedQueryLog(t *testing.T, queryLog *sql.QueryLog) {
324355
initializeTables(t)
325356
db, err := sql.Open("", "")

0 commit comments

Comments
 (0)