Skip to content

Commit

Permalink
Merge pull request #619 from thatguystone/master
Browse files Browse the repository at this point in the history
Add support for sql.TxOptions
  • Loading branch information
maddyblue authored Jun 3, 2017
2 parents 91f10e4 + 23323c9 commit 8837942
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 7 deletions.
6 changes: 5 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,13 +506,17 @@ func (cn *conn) checkIsInTransaction(intxn bool) {
}

func (cn *conn) Begin() (_ driver.Tx, err error) {
return cn.begin("")
}

func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
if cn.bad {
return nil, driver.ErrBadConn
}
defer cn.errRecover(&err)

cn.checkIsInTransaction(false)
_, commandTag, err := cn.simpleExec("BEGIN")
_, commandTag, err := cn.simpleExec("BEGIN" + mode)
if err != nil {
return nil, err
}
Expand Down
28 changes: 23 additions & 5 deletions conn_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ package pq

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
"io/ioutil"
)
Expand Down Expand Up @@ -44,13 +45,30 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam

// Implement the "ConnBeginTx" interface
func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if opts.Isolation != 0 {
return nil, errors.New("isolation levels not supported")
var mode string

switch sql.IsolationLevel(opts.Isolation) {
case sql.LevelDefault:
// Don't touch mode: use the server's default
case sql.LevelReadUncommitted:
mode = " ISOLATION LEVEL READ UNCOMMITTED"
case sql.LevelReadCommitted:
mode = " ISOLATION LEVEL READ COMMITTED"
case sql.LevelRepeatableRead:
mode = " ISOLATION LEVEL REPEATABLE READ"
case sql.LevelSerializable:
mode = " ISOLATION LEVEL SERIALIZABLE"
default:
return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation)
}

if opts.ReadOnly {
return nil, errors.New("read-only transactions not supported")
mode += " READ ONLY"
} else {
mode += " READ WRITE"
}
tx, err := cn.Begin()

tx, err := cn.begin(mode)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ func TestPgpass(t *testing.T) {
rows, err := txn.Query("SELECT USER")
if err != nil {
txn.Rollback()
rows.Close()
if expected != "fail" {
t.Fatalf(reason, err)
}
} else {
rows.Close()
if expected != "ok" {
t.Fatalf(reason, err)
}
Expand Down
78 changes: 78 additions & 0 deletions go18_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"database/sql"
"runtime"
"strings"
"testing"
"time"
)
Expand Down Expand Up @@ -239,3 +240,80 @@ func TestContextCancelBegin(t *testing.T) {
}
}
}

func TestTxOptions(t *testing.T) {
db := openTestConn(t)
defer db.Close()
ctx := context.Background()

tests := []struct {
level sql.IsolationLevel
isolation string
}{
{
level: sql.LevelDefault,
isolation: "",
},
{
level: sql.LevelReadUncommitted,
isolation: "read uncommitted",
},
{
level: sql.LevelReadCommitted,
isolation: "read committed",
},
{
level: sql.LevelRepeatableRead,
isolation: "repeatable read",
},
{
level: sql.LevelSerializable,
isolation: "serializable",
},
}

for _, test := range tests {
for _, ro := range []bool{true, false} {
tx, err := db.BeginTx(ctx, &sql.TxOptions{
Isolation: test.level,
ReadOnly: ro,
})
if err != nil {
t.Fatal(err)
}

var isolation string
err = tx.QueryRow("select current_setting('transaction_isolation')").Scan(&isolation)
if err != nil {
t.Fatal(err)
}

if test.isolation != "" && isolation != test.isolation {
t.Errorf("wrong isolation level: %s != %s", isolation, test.isolation)
}

var isRO string
err = tx.QueryRow("select current_setting('transaction_read_only')").Scan(&isRO)
if err != nil {
t.Fatal(err)
}

if ro != (isRO == "on") {
t.Errorf("read/[write,only] not set: %t != %s for level %s",
ro, isRO, test.isolation)
}

tx.Rollback()
}
}

_, err := db.BeginTx(ctx, &sql.TxOptions{
Isolation: sql.LevelLinearizable,
})
if err == nil {
t.Fatal("expected LevelLinearizable to fail")
}
if !strings.Contains(err.Error(), "isolation level not supported") {
t.Errorf("Expected error to mention isolation level, got %q", err)
}
}

0 comments on commit 8837942

Please sign in to comment.