Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,34 @@ func (c *conn) getBatchReadOnlyTransactionOptions() BatchReadOnlyTransactionOpti
return BatchReadOnlyTransactionOptions{TimestampBound: c.ReadOnlyStaleness()}
}

// BeginReadOnlyTransaction is not part of the public API of the database/sql driver.
// It is exported for internal reasons, and may receive breaking changes without prior notice.
//
// BeginReadOnlyTransaction starts a new read-only transaction on this connection.
func (c *conn) BeginReadOnlyTransaction(ctx context.Context, options *ReadOnlyTransactionOptions) (driver.Tx, error) {
c.withTempReadOnlyTransactionOptions(options)
tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true})
if err != nil {
c.withTempReadOnlyTransactionOptions(nil)
return nil, err
}
return tx, nil
}

// BeginReadWriteTransaction is not part of the public API of the database/sql driver.
// It is exported for internal reasons, and may receive breaking changes without prior notice.
//
// BeginReadWriteTransaction starts a new read/write transaction on this connection.
func (c *conn) BeginReadWriteTransaction(ctx context.Context, options *ReadWriteTransactionOptions) (driver.Tx, error) {
c.withTempTransactionOptions(options)
tx, err := c.BeginTx(ctx, driver.TxOptions{})
if err != nil {
c.withTempTransactionOptions(nil)
return nil, err
}
return tx, nil
}

func (c *conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
Expand Down Expand Up @@ -1254,18 +1282,29 @@ func (c *conn) inReadWriteTransaction() bool {
return false
}

func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) {
// Commit is not part of the public API of the database/sql driver.
// It is exported for internal reasons, and may receive breaking changes without prior notice.
//
// Commit commits the current transaction on this connection.
func (c *conn) Commit(ctx context.Context) (*spanner.CommitResponse, error) {
if !c.inTransaction() {
return nil, status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
}
// TODO: Pass in context to the tx.Commit() function.
if err := c.tx.Commit(); err != nil {
return nil, err
}
return c.CommitResponse()

// This will return either the commit response or nil, depending on whether the transaction was a
// read/write transaction or a read-only transaction.
return propertyCommitResponse.GetValueOrDefault(c.state), nil
}

func (c *conn) rollback(ctx context.Context) error {
// Rollback is not part of the public API of the database/sql driver.
// It is exported for internal reasons, and may receive breaking changes without prior notice.
//
// Rollback rollbacks the current transaction on this connection.
func (c *conn) Rollback(ctx context.Context) error {
if !c.inTransaction() {
return status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
}
Expand Down
20 changes: 20 additions & 0 deletions conn_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,26 @@ func TestTwoTransactionsOnOneConn(t *testing.T) {
}
}

func TestTwoQueriesOnOneConn(t *testing.T) {
t.Parallel()

db, _, teardown := setupTestDBConnection(t)
defer teardown()
ctx := context.Background()

c, _ := db.Conn(ctx)
defer silentClose(c)

for range 2 {
r, err := c.QueryContext(context.Background(), testutil.SelectFooFromBar)
if err != nil {
t.Fatal(err)
}
_ = r.Next()
defer silentClose(r)
}
}

func TestExplicitBeginTx(t *testing.T) {
t.Parallel()

Expand Down
153 changes: 152 additions & 1 deletion spannerlib/api/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package api
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"strings"
"sync"
Expand All @@ -25,6 +26,9 @@ import (
"cloud.google.com/go/spanner"
"cloud.google.com/go/spanner/apiv1/spannerpb"
spannerdriver "github.com/googleapis/go-sql-spanner"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
)

// CloseConnection looks up the connection with the given poolId and connId and closes it.
Expand All @@ -42,6 +46,35 @@ func CloseConnection(ctx context.Context, poolId, connId int64) error {
return conn.close(ctx)
}

// BeginTransaction starts a new transaction on the given connection.
// A connection can have at most one transaction at any time. This function therefore returns an error if the
// connection has an active transaction.
func BeginTransaction(ctx context.Context, poolId, connId int64, txOpts *spannerpb.TransactionOptions) error {
conn, err := findConnection(poolId, connId)
if err != nil {
return err
}
return conn.BeginTransaction(ctx, txOpts)
}

// Commit commits the current transaction on the given connection.
func Commit(ctx context.Context, poolId, connId int64) (*spannerpb.CommitResponse, error) {
conn, err := findConnection(poolId, connId)
if err != nil {
return nil, err
}
return conn.commit(ctx)
}

// Rollback rollbacks the current transaction on the given connection.
func Rollback(ctx context.Context, poolId, connId int64) error {
conn, err := findConnection(poolId, connId)
if err != nil {
return err
}
return conn.rollback(ctx)
}

func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) {
conn, err := findConnection(poolId, connId)
if err != nil {
Expand All @@ -59,23 +92,141 @@ type Connection struct {
backend *sql.Conn
}

// spannerConn is an internal interface that contains the internal functions that are used by this API.
// It is implemented by the spannerdriver.conn struct.
type spannerConn interface {
BeginReadOnlyTransaction(ctx context.Context, options *spannerdriver.ReadOnlyTransactionOptions) (driver.Tx, error)
BeginReadWriteTransaction(ctx context.Context, options *spannerdriver.ReadWriteTransactionOptions) (driver.Tx, error)
Commit(ctx context.Context) (*spanner.CommitResponse, error)
Rollback(ctx context.Context) error
}

type queryExecutor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}

func (conn *Connection) close(ctx context.Context) error {
conn.closeResults(ctx)
// Rollback any open transactions on the connection.
_ = conn.rollback(ctx)

err := conn.backend.Close()
if err != nil {
return err
}
return nil
}

func (conn *Connection) BeginTransaction(ctx context.Context, txOpts *spannerpb.TransactionOptions) error {
var err error
if txOpts.GetReadOnly() != nil {
return conn.beginReadOnlyTransaction(ctx, convertToReadOnlyOpts(txOpts))
} else if txOpts.GetPartitionedDml() != nil {
err = spanner.ToSpannerError(status.Error(codes.InvalidArgument, "transaction type not supported"))
} else {
return conn.beginReadWriteTransaction(ctx, convertToReadWriteTransactionOptions(txOpts))
}
if err != nil {
return err
}
return nil
}

func (conn *Connection) beginReadOnlyTransaction(ctx context.Context, opts *spannerdriver.ReadOnlyTransactionOptions) error {
return conn.backend.Raw(func(driverConn any) (err error) {
sc, _ := driverConn.(spannerConn)
_, err = sc.BeginReadOnlyTransaction(ctx, opts)
return err
})
}

func (conn *Connection) beginReadWriteTransaction(ctx context.Context, opts *spannerdriver.ReadWriteTransactionOptions) error {
return conn.backend.Raw(func(driverConn any) (err error) {
sc, _ := driverConn.(spannerConn)
_, err = sc.BeginReadWriteTransaction(ctx, opts)
return err
})
}

func (conn *Connection) commit(ctx context.Context) (*spannerpb.CommitResponse, error) {
var response *spanner.CommitResponse
if err := conn.backend.Raw(func(driverConn any) (err error) {
spannerConn, _ := driverConn.(spannerConn)
response, err = spannerConn.Commit(ctx)
if err != nil {
return err
}
return nil
}); err != nil {
return nil, err
}

// The commit response is nil for read-only transactions.
if response == nil {
return nil, nil
}
// TODO: Include commit stats
return &spannerpb.CommitResponse{CommitTimestamp: timestamppb.New(response.CommitTs)}, nil
}

func (conn *Connection) rollback(ctx context.Context) error {
return conn.backend.Raw(func(driverConn any) (err error) {
spannerConn, _ := driverConn.(spannerConn)
return spannerConn.Rollback(ctx)
})
}

func convertToReadOnlyOpts(txOpts *spannerpb.TransactionOptions) *spannerdriver.ReadOnlyTransactionOptions {
return &spannerdriver.ReadOnlyTransactionOptions{
TimestampBound: convertTimestampBound(txOpts),
}
}

func convertTimestampBound(txOpts *spannerpb.TransactionOptions) spanner.TimestampBound {
ro := txOpts.GetReadOnly()
if ro.GetStrong() {
return spanner.StrongRead()
} else if ro.GetReadTimestamp() != nil {
return spanner.ReadTimestamp(ro.GetReadTimestamp().AsTime())
} else if ro.GetMinReadTimestamp() != nil {
return spanner.ReadTimestamp(ro.GetMinReadTimestamp().AsTime())
} else if ro.GetExactStaleness() != nil {
return spanner.ExactStaleness(ro.GetExactStaleness().AsDuration())
} else if ro.GetMaxStaleness() != nil {
return spanner.MaxStaleness(ro.GetMaxStaleness().AsDuration())
}
return spanner.TimestampBound{}
}

func convertToReadWriteTransactionOptions(txOpts *spannerpb.TransactionOptions) *spannerdriver.ReadWriteTransactionOptions {
readLockMode := spannerpb.TransactionOptions_ReadWrite_READ_LOCK_MODE_UNSPECIFIED
if txOpts.GetReadWrite() != nil {
readLockMode = txOpts.GetReadWrite().GetReadLockMode()
}
return &spannerdriver.ReadWriteTransactionOptions{
TransactionOptions: spanner.TransactionOptions{
IsolationLevel: txOpts.GetIsolationLevel(),
ReadLockMode: readLockMode,
},
}
}

func convertIsolationLevel(level spannerpb.TransactionOptions_IsolationLevel) sql.IsolationLevel {
switch level {
case spannerpb.TransactionOptions_SERIALIZABLE:
return sql.LevelSerializable
case spannerpb.TransactionOptions_REPEATABLE_READ:
return sql.LevelRepeatableRead
}
return sql.LevelDefault
}

func (conn *Connection) closeResults(ctx context.Context) {
conn.results.Range(func(key, value interface{}) bool {
// TODO: Implement
if r, ok := value.(*rows); ok {
_ = r.Close(ctx)
}
return true
})
}
Expand Down
Loading
Loading