From d68fc12a22e730e1d0c73a5b3c77e31d0b63ff9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Mon, 15 Sep 2025 18:10:08 +0200 Subject: [PATCH 1/2] chore: add ExecuteBatch to SpannerLib Adds an ExecuteBatch function to SpannerLib that supports executing DML or DDL statements as a single batch. The function accepts an ExecuteBatchDml request for both types of batches. The type of batch that is actually being executed is determined based on the statements in the batch. Mixing DML and DDL in the same batch is not supported. Queries are also not supported in batches. --- conn.go | 8 + spannerlib/api/batch_test.go | 238 ++++++++++++++++++ spannerlib/api/connection.go | 128 ++++++++++ spannerlib/lib/connection.go | 16 ++ spannerlib/lib/connection_test.go | 42 ++++ spannerlib/shared/shared_lib.go | 12 + spannerlib/shared/shared_lib_test.go | 57 +++++ .../google/cloud/spannerlib/Connection.java | 21 ++ .../spannerlib/internal/SpannerLibrary.java | 6 + .../google/cloud/spannerlib/BatchTest.java | 152 +++++++++++ 10 files changed, 680 insertions(+) create mode 100644 spannerlib/api/batch_test.go create mode 100644 spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java diff --git a/conn.go b/conn.go index 9ee4efb1..cbe5ff33 100644 --- a/conn.go +++ b/conn.go @@ -222,6 +222,9 @@ type SpannerConn interface { // return the same Spanner client. UnderlyingClient() (client *spanner.Client, err error) + // DetectStatementType returns the type of SQL statement. + DetectStatementType(query string) parser.StatementType + // resetTransactionForRetry resets the current transaction after it has // been aborted by Spanner. Calling this function on a transaction that // has not been aborted is not supported and will cause an error to be @@ -286,6 +289,11 @@ func (c *conn) UnderlyingClient() (*spanner.Client, error) { return c.client, nil } +func (c *conn) DetectStatementType(query string) parser.StatementType { + info := c.parser.DetectStatementType(query) + return info.StatementType +} + func (c *conn) CommitTimestamp() (time.Time, error) { ts := propertyCommitTimestamp.GetValueOrDefault(c.state) if ts == nil { diff --git a/spannerlib/api/batch_test.go b/spannerlib/api/batch_test.go new file mode 100644 index 00000000..1ae9af14 --- /dev/null +++ b/spannerlib/api/batch_test.go @@ -0,0 +1,238 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "fmt" + "reflect" + "testing" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "cloud.google.com/go/spanner" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/emptypb" +) + +func TestExecuteDmlBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Execute a DML batch. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }} + resp, err := ExecuteBatch(ctx, poolId, connId, request) + if err != nil { + t.Fatalf("ExecuteBatch returned unexpected error: %v", err) + } + if g, w := len(resp.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w) + } + for i, result := range resp.ResultSets { + if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("%d: update count mismatch\n Got: %d\nWant: %d", i, g, w) + } + } + + requests := server.TestSpanner.DrainRequestsFromServer() + // There should be no ExecuteSql requests. + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 0; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + batchRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{})) + if g, w := len(batchRequests), 1; g != w { + t.Fatalf("Execute batch request count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteDdlBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + // Set up a result for a DDL statement on the mock server. + var expectedResponse = &emptypb.Empty{} + anyMsg, _ := anypb.New(expectedResponse) + server.TestDatabaseAdmin.SetResps([]proto.Message{ + &longrunningpb.Operation{ + Done: true, + Result: &longrunningpb.Operation_Response{Response: anyMsg}, + Name: "test-operation", + }, + }) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Execute a DDL batch. This also uses a DML batch request. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "create index my_index on my_table (value)"}, + }} + resp, err := ExecuteBatch(ctx, poolId, connId, request) + if err != nil { + t.Fatalf("ExecuteBatch returned unexpected error: %v", err) + } + // The response should contain an 'update count' per DDL statement. + if g, w := len(resp.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w) + } + // There is no update count for DDL statements. + for i, result := range resp.ResultSets { + emptyStats := &spannerpb.ResultSetStats{} + if g, w := result.Stats, emptyStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) { + t.Fatalf("%d: ResultSetStats mismatch\n Got: %v\nWant: %v", i, g, w) + } + } + + requests := server.TestSpanner.DrainRequestsFromServer() + // There should be no ExecuteSql requests. + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 0; g != w { + t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w) + } + // There should also be no ExecuteBatchDml requests. + batchDmlRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{})) + if g, w := len(batchDmlRequests), 0; g != w { + t.Fatalf("ExecuteBatchDmlRequest count mismatch\n Got: %v\nWant: %v", g, w) + } + + adminRequests := server.TestDatabaseAdmin.Reqs() + if g, w := len(adminRequests), 1; g != w { + t.Fatalf("admin request count mismatch\n Got: %v\nWant: %v", g, w) + } + ddlRequest := adminRequests[0].(*databasepb.UpdateDatabaseDdlRequest) + if g, w := len(ddlRequest.Statements), 2; g != w { + t.Fatalf("DDL statement count mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteMixedBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Try to execute a batch with mixed DML and DDL statements. This should fail. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "update my_table set value = 100 where true"}, + }} + _, err = ExecuteBatch(ctx, poolId, connId, request) + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestExecuteDdlBatchInTransaction(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + // Try to execute a DDL batch in a transaction. This should fail. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "create table my_table (id int64 primary key, value string(100))"}, + {Sql: "create index my_index on my_table (value)"}, + }} + _, err = ExecuteBatch(ctx, poolId, connId, request) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/api/connection.go b/spannerlib/api/connection.go index f5f7643a..1d852120 100644 --- a/spannerlib/api/connection.go +++ b/spannerlib/api/connection.go @@ -26,6 +26,7 @@ import ( "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/apiv1/spannerpb" spannerdriver "github.com/googleapis/go-sql-spanner" + "github.com/googleapis/go-sql-spanner/parser" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" @@ -83,6 +84,14 @@ func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spann return conn.Execute(ctx, executeSqlRequest) } +func ExecuteBatch(ctx context.Context, poolId, connId int64, statements *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.ExecuteBatch(ctx, statements.Statements) +} + type Connection struct { // results contains the open query results for this connection. results *sync.Map @@ -235,6 +244,10 @@ func (conn *Connection) Execute(ctx context.Context, statement *spannerpb.Execut return execute(ctx, conn, conn.backend, statement) } +func (conn *Connection) ExecuteBatch(ctx context.Context, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + return executeBatch(ctx, conn, conn.backend, statements) +} + func execute(ctx context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) { params := extractParams(statement) it, err := executor.QueryContext(ctx, statement.Sql, params...) @@ -266,6 +279,90 @@ func execute(ctx context.Context, conn *Connection, executor queryExecutor, stat return id, nil } +func executeBatch(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + // Determine the type of batch that should be executed based on the type of statements. + batchType, err := determineBatchType(conn, statements) + if err != nil { + return nil, err + } + switch batchType { + case parser.BatchTypeDml: + return executeBatchDml(ctx, conn, executor, statements) + case parser.BatchTypeDdl: + return executeBatchDdl(ctx, conn, executor, statements) + default: + return nil, status.Errorf(codes.InvalidArgument, "unsupported batch type: %v", batchType) + } +} + +func executeBatchDdl(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.StartBatchDDL() + }); err != nil { + return nil, err + } + for _, statement := range statements { + _, err := executor.ExecContext(ctx, statement.Sql) + if err != nil { + return nil, err + } + } + // TODO: Add support for getting the actual Batch DDL response. + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.RunBatch(ctx) + }); err != nil { + return nil, err + } + + response := spannerpb.ExecuteBatchDmlResponse{} + response.ResultSets = make([]*spannerpb.ResultSet, len(statements)) + for i := range statements { + response.ResultSets[i] = &spannerpb.ResultSet{Stats: &spannerpb.ResultSetStats{}} + } + return &response, nil +} + +func executeBatchDml(ctx context.Context, conn *Connection, executor queryExecutor, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + return spannerConn.StartBatchDML() + }); err != nil { + return nil, err + } + for _, statement := range statements { + request := &spannerpb.ExecuteSqlRequest{ + Sql: statement.Sql, + Params: statement.Params, + ParamTypes: statement.ParamTypes, + } + params := extractParams(request) + _, err := executor.ExecContext(ctx, statement.Sql, params...) + if err != nil { + return nil, err + } + } + var spannerResult spannerdriver.SpannerResult + if err := conn.backend.Raw(func(driverConn any) (err error) { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + spannerResult, err = spannerConn.RunDmlBatch(ctx) + return err + }); err != nil { + return nil, err + } + affected, err := spannerResult.BatchRowsAffected() + if err != nil { + return nil, err + } + response := spannerpb.ExecuteBatchDmlResponse{} + response.ResultSets = make([]*spannerpb.ResultSet, len(affected)) + for i, aff := range affected { + response.ResultSets[i] = &spannerpb.ResultSet{Stats: &spannerpb.ResultSetStats{RowCount: &spannerpb.ResultSetStats_RowCountExact{RowCountExact: aff}}} + } + return &response, nil +} + func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { paramsLen := 1 if statement.Params != nil { @@ -300,3 +397,34 @@ func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { } return params } + +func determineBatchType(conn *Connection, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (parser.BatchType, error) { + if len(statements) == 0 { + return parser.BatchTypeDdl, status.Errorf(codes.InvalidArgument, "cannot determine type of an empty batch") + } + var batchType parser.BatchType + if err := conn.backend.Raw(func(driverConn any) error { + spannerConn, _ := driverConn.(spannerdriver.SpannerConn) + firstStatementType := spannerConn.DetectStatementType(statements[0].Sql) + if firstStatementType == parser.StatementTypeDml { + batchType = parser.BatchTypeDml + } else if firstStatementType == parser.StatementTypeDdl { + batchType = parser.BatchTypeDdl + } else { + return status.Errorf(codes.InvalidArgument, "unsupported statement type for batching: %v", firstStatementType) + } + for i, statement := range statements { + if i > 0 { + tp := spannerConn.DetectStatementType(statement.Sql) + if tp != firstStatementType { + return status.Errorf(codes.InvalidArgument, "Batches may not contain different types of statements. The first statement is of type %v. The statement on position %d is of type %v.", firstStatementType, i, tp) + } + } + } + return nil + }); err != nil { + return parser.BatchTypeDdl, err + } + + return batchType, nil +} diff --git a/spannerlib/lib/connection.go b/spannerlib/lib/connection.go index d41b6a6e..72f32efb 100644 --- a/spannerlib/lib/connection.go +++ b/spannerlib/lib/connection.go @@ -89,3 +89,19 @@ func Execute(ctx context.Context, poolId, connId int64, executeSqlRequestBytes [ } return idMessage(id) } + +func ExecuteBatch(ctx context.Context, poolId, connId int64, statementsBytes []byte) *Message { + statements := spannerpb.ExecuteBatchDmlRequest{} + if err := proto.Unmarshal(statementsBytes, &statements); err != nil { + return errMessage(err) + } + response, err := api.ExecuteBatch(ctx, poolId, connId, &statements) + if err != nil { + return errMessage(err) + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} diff --git a/spannerlib/lib/connection_test.go b/spannerlib/lib/connection_test.go index 02ba453c..02945ae9 100644 --- a/spannerlib/lib/connection_test.go +++ b/spannerlib/lib/connection_test.go @@ -117,6 +117,48 @@ func TestExecute(t *testing.T) { } } +func TestExecuteBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }} + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + rowsMsg := ExecuteBatch(ctx, poolMsg.ObjectId, connMsg.ObjectId, requestBytes) + if g, w := rowsMsg.Code, int32(0); g != w { + t.Fatalf("ExecuteBatch result mismatch\n Got: %v\nWant: %v", g, w) + } + if rowsMsg.Length() == 0 { + t.Fatal("ExecuteBatch returned no data") + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + func TestBeginAndCommit(t *testing.T) { t.Parallel() diff --git a/spannerlib/shared/shared_lib.go b/spannerlib/shared/shared_lib.go index e9c9be6a..d68c1973 100644 --- a/spannerlib/shared/shared_lib.go +++ b/spannerlib/shared/shared_lib.go @@ -123,6 +123,18 @@ func Execute(poolId, connectionId int64, statement []byte) (int64, int32, int64, return pin(msg) } +// ExecuteBatch executes a batch of statements on the given connection. The statements must all be either DML or DDL +// statements. Mixing DML and DDL in a batch is not supported. Executing queries in a batch is also not supported. +// The batch will use the current transaction on the given connection, or execute as a single auto-commit statement +// if the connection does not have a transaction. +// +//export ExecuteBatch +func ExecuteBatch(poolId, connectionId int64, statements []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.ExecuteBatch(ctx, poolId, connectionId, statements) + return pin(msg) +} + // Metadata returns the metadata of a Rows object. // //export Metadata diff --git a/spannerlib/shared/shared_lib_test.go b/spannerlib/shared/shared_lib_test.go index 82d010bd..21650f6b 100644 --- a/spannerlib/shared/shared_lib_test.go +++ b/spannerlib/shared/shared_lib_test.go @@ -245,6 +245,63 @@ func TestExecute(t *testing.T) { } } +func TestExecuteBatch(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + request := &spannerpb.ExecuteBatchDmlRequest{ + Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }, + } + requestBytes, err := proto.Marshal(request) + if err != nil { + t.Fatal(err) + } + // ExecuteBatch returns a ExecuteBatchDml response. + mem, code, batchId, length, data := ExecuteBatch(poolId, connId, requestBytes) + verifyDataMessage(t, "ExecuteBatch", mem, code, batchId, length, data) + response := &spannerpb.ExecuteBatchDmlResponse{} + responseBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + if err := proto.Unmarshal(responseBytes, response); err != nil { + t.Fatal(err) + } + if g, w := len(response.ResultSets), 2; g != w { + t.Fatalf("num results mismatch\n Got: %v\nWant: %v", g, w) + } + for i, result := range response.ResultSets { + if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("%d: update count mismatch\n Got: %v\nWant: %v", i, g, w) + } + } + // Release the memory held by the response. + if g, w := Release(mem), int32(0); g != w { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", g, w) + } + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + func TestBeginAndCommitTransaction(t *testing.T) { t.Parallel() diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java index 90ac04f9..dc5aad96 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java @@ -22,6 +22,8 @@ import com.google.cloud.spannerlib.internal.WrappedGoBytes; import com.google.protobuf.InvalidProtocolBufferException; import com.google.spanner.v1.CommitResponse; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.TransactionOptions; import java.nio.ByteBuffer; @@ -84,6 +86,25 @@ public Rows execute(ExecuteSqlRequest request) { } } + /** + * Executes the given batch of DML or DDL statements on this connection. The statements must all + * be of the same type. + */ + public ExecuteBatchDmlResponse executeBatch(ExecuteBatchDmlRequest request) { + try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(request); + MessageHandler message = + getLibrary() + .execute( + library -> + library.ExecuteBatch( + pool.getId(), getId(), serializedRequest.getGoBytes()))) { + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return ExecuteBatchDmlResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + /** Closes this connection. Any active transaction on the connection is rolled back. */ @Override public void close() { diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java index 1ad3f14e..54143891 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java @@ -80,6 +80,12 @@ default MessageHandler execute(Function function) /** Executes a SQL statement on the given Connection. */ Message Execute(long poolId, long connectionId, GoBytes executeSqlRequest); + /** + * Executes a batch of DML or DDL statements on the given Connection. Returns an {@link + * com.google.spanner.v1.ExecuteBatchDmlResponse} for both DML and DDL batches. + */ + Message ExecuteBatch(long poolId, long connectionId, GoBytes executeBatchDmlRequest); + /** Returns the {@link com.google.spanner.v1.ResultSetMetadata} of the given Rows object. */ Message Metadata(long poolId, long connectionId, long rowsId); diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java new file mode 100644 index 00000000..9740d541 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/BatchTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spannerlib; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.common.collect.ImmutableMap; +import com.google.longrunning.Operation; +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.Any; +import com.google.protobuf.Empty; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest.Statement; +import com.google.spanner.v1.ExecuteBatchDmlResponse; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import java.util.List; +import org.junit.Test; + +public class BatchTest extends AbstractMockServerTest { + + @Test + public void testBatchDml() { + String insert = "insert into test (id, value) values (@id, @value)"; + mockSpanner.putStatementResult( + StatementResult.update( + com.google.cloud.spanner.Statement.newBuilder(insert) + .bind("id") + .to(1L) + .bind("value") + .to("One") + .build(), + 1L)); + mockSpanner.putStatementResult( + StatementResult.update( + com.google.cloud.spanner.Statement.newBuilder(insert) + .bind("id") + .to(2L) + .bind("value") + .to("Two") + .build(), + 1L)); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + ExecuteBatchDmlResponse response = + connection.executeBatch( + ExecuteBatchDmlRequest.newBuilder() + .addStatements( + Statement.newBuilder() + .setSql(insert) + .setParams( + Struct.newBuilder() + .putFields("id", Value.newBuilder().setStringValue("1").build()) + .putFields( + "value", Value.newBuilder().setStringValue("One").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "id", Type.newBuilder().setCode(TypeCode.INT64).build(), + "value", Type.newBuilder().setCode(TypeCode.STRING).build())) + .build()) + .addStatements( + Statement.newBuilder() + .setSql(insert) + .setParams( + Struct.newBuilder() + .putFields("id", Value.newBuilder().setStringValue("2").build()) + .putFields( + "value", Value.newBuilder().setStringValue("Two").build()) + .build()) + .putAllParamTypes( + ImmutableMap.of( + "id", Type.newBuilder().setCode(TypeCode.INT64).build(), + "value", Type.newBuilder().setCode(TypeCode.STRING).build())) + .build()) + .build()); + + assertEquals(2, response.getResultSetsCount()); + assertEquals(1L, response.getResultSets(0).getStats().getRowCountExact()); + assertEquals(1L, response.getResultSets(1).getStats().getRowCountExact()); + } + + assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)); + } + + @Test + public void testBatchDdl() { + // Set up a DDL response on the mock server. + mockDatabaseAdmin.addResponse( + Operation.newBuilder() + .setDone(true) + .setResponse(Any.pack(Empty.getDefaultInstance())) + .setMetadata(Any.pack(UpdateDatabaseDdlMetadata.getDefaultInstance())) + .build()); + + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + ExecuteBatchDmlResponse response = + connection.executeBatch( + ExecuteBatchDmlRequest.newBuilder() + .addStatements( + Statement.newBuilder() + .setSql("create table my_table (id int64 primary key, value string(max))") + .build()) + .addStatements( + Statement.newBuilder() + .setSql("create index my_index on my_table (value)") + .build()) + .build()); + + assertEquals(2, response.getResultSetsCount()); + assertFalse(response.getResultSets(0).getStats().hasRowCountExact()); + } + + List requests = mockDatabaseAdmin.getRequests(); + assertEquals(1, requests.size()); + UpdateDatabaseDdlRequest request = (UpdateDatabaseDdlRequest) requests.get(0); + assertEquals(2, request.getStatementsCount()); + } +} From 0cb54fe04e1bcb1f11ec6d68a765d6ccc25a15e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Sat, 20 Sep 2025 12:10:28 +0200 Subject: [PATCH 2/2] chore: add WriteMutations function for SpannerLib (#532) Adds a WriteMutations function for SpannerLib. This function can be used to write mutations to Spanner in two ways: 1. In a transaction: The mutations are buffered in the current read/write transaction. The returned message is empty. 2. Outside a transaction: The mutations are written to Spanner directly in a new read/write transaction. The returned message contains the CommitResponse. --- conn.go | 21 +++ spannerlib/api/connection.go | 45 +++++ spannerlib/api/connection_test.go | 157 ++++++++++++++++++ spannerlib/lib/connection.go | 24 +++ spannerlib/lib/connection_test.go | 60 +++++++ spannerlib/shared/shared_lib.go | 16 ++ spannerlib/shared/shared_lib_test.go | 69 ++++++++ .../google/cloud/spannerlib/Connection.java | 26 +++ .../spannerlib/internal/SpannerLibrary.java | 9 + .../cloud/spannerlib/ConnectionTest.java | 149 +++++++++++++++++ 10 files changed, 576 insertions(+) diff --git a/conn.go b/conn.go index cbe5ff33..698c8274 100644 --- a/conn.go +++ b/conn.go @@ -683,6 +683,27 @@ func sum(affected []int64) int64 { return sum } +// WriteMutations is not part of the public API of the database/sql driver. +// It is exported for internal reasons, and may receive breaking changes without prior notice. +// +// WriteMutations writes mutations using this connection. The mutations are either buffered in the current transaction, +// or written directly to Spanner using a new read/write transaction if the connection does not have a transaction. +// +// The function returns an error if the connection currently has a read-only transaction. +// +// The returned CommitResponse is nil if the connection currently has a transaction, as the mutations will only be +// applied to Spanner when the transaction commits. +func (c *conn) WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) { + if c.inTransaction() { + return nil, c.BufferWrite(ms) + } + ts, err := c.Apply(ctx, ms) + if err != nil { + return nil, err + } + return &spanner.CommitResponse{CommitTs: ts}, nil +} + func (c *conn) Apply(ctx context.Context, ms []*spanner.Mutation, opts ...spanner.ApplyOption) (commitTimestamp time.Time, err error) { if c.inTransaction() { return time.Time{}, spanner.ToSpannerError( diff --git a/spannerlib/api/connection.go b/spannerlib/api/connection.go index 1d852120..907212ec 100644 --- a/spannerlib/api/connection.go +++ b/spannerlib/api/connection.go @@ -47,6 +47,22 @@ func CloseConnection(ctx context.Context, poolId, connId int64) error { return conn.close(ctx) } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +func WriteMutations(ctx context.Context, poolId, connId int64, mutations *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) { + conn, err := findConnection(poolId, connId) + if err != nil { + return nil, err + } + return conn.writeMutations(ctx, mutations) +} + // BeginTransaction starts a new transaction on the given connection. // A connection can have at most one transaction at any time. This function therefore returns an error if the // connection has an active transaction. @@ -104,6 +120,7 @@ type Connection struct { // spannerConn is an internal interface that contains the internal functions that are used by this API. // It is implemented by the spannerdriver.conn struct. type spannerConn interface { + WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) BeginReadOnlyTransaction(ctx context.Context, options *spannerdriver.ReadOnlyTransactionOptions) (driver.Tx, error) BeginReadWriteTransaction(ctx context.Context, options *spannerdriver.ReadWriteTransactionOptions) (driver.Tx, error) Commit(ctx context.Context) (*spanner.CommitResponse, error) @@ -127,6 +144,34 @@ func (conn *Connection) close(ctx context.Context) error { return nil } +func (conn *Connection) writeMutations(ctx context.Context, mutation *spannerpb.BatchWriteRequest_MutationGroup) (*spannerpb.CommitResponse, error) { + mutations := make([]*spanner.Mutation, 0, len(mutation.Mutations)) + for _, m := range mutation.Mutations { + spannerMutation, err := spanner.WrapMutation(m) + if err != nil { + return nil, err + } + mutations = append(mutations, spannerMutation) + } + var commitResponse *spanner.CommitResponse + if err := conn.backend.Raw(func(driverConn any) (err error) { + sc, _ := driverConn.(spannerConn) + commitResponse, err = sc.WriteMutations(ctx, mutations) + return err + }); err != nil { + return nil, err + } + + // The commit response is nil if the connection is currently in a transaction. + if commitResponse == nil { + return nil, nil + } + response := spannerpb.CommitResponse{ + CommitTimestamp: timestamppb.New(commitResponse.CommitTs), + } + return &response, nil +} + func (conn *Connection) BeginTransaction(ctx context.Context, txOpts *spannerpb.TransactionOptions) error { var err error if txOpts.GetReadOnly() != nil { diff --git a/spannerlib/api/connection_test.go b/spannerlib/api/connection_test.go index 64d19ab3..b4625bb3 100644 --- a/spannerlib/api/connection_test.go +++ b/spannerlib/api/connection_test.go @@ -24,6 +24,7 @@ import ( "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/types/known/structpb" ) func TestCreateAndCloseConnection(t *testing.T) { @@ -143,3 +144,159 @@ func TestCloseConnectionTwice(t *testing.T) { t.Fatalf("ClosePool returned unexpected error: %v", err) } } + +func TestWriteMutations(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + {Operation: &spannerpb.Mutation_Update{Update: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("0"), structpb.NewStringValue("Zero")}}, + }, + }}}, + }} + resp, err := WriteMutations(ctx, poolId, connId, mutations) + if err != nil { + t.Fatalf("WriteMutations returned unexpected error: %v", err) + } + if resp.CommitTimestamp == nil { + t.Fatalf("CommitTimestamp is nil") + } + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequest := commitRequests[0].(*spannerpb.CommitRequest) + if g, w := len(commitRequest.Mutations), 2; g != w { + t.Fatalf("num mutations mismatch\n Got: %d\nWant: %d", g, w) + } + + // Write the same mutations in a transaction. + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + resp, err = WriteMutations(ctx, poolId, connId, mutations) + if err != nil { + t.Fatalf("WriteMutations returned unexpected error: %v", err) + } + if resp != nil { + t.Fatalf("WriteMutations returned unexpected response: %v", resp) + } + resp, err = Commit(ctx, poolId, connId) + if err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + if resp == nil { + t.Fatalf("Commit returned nil response") + } + if resp.CommitTimestamp == nil { + t.Fatalf("CommitTimestamp is nil") + } + requests = server.TestSpanner.DrainRequestsFromServer() + beginRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 1; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequest = commitRequests[0].(*spannerpb.CommitRequest) + if g, w := len(commitRequest.Mutations), 2; g != w { + t.Fatalf("num mutations mismatch\n Got: %d\nWant: %d", g, w) + } + + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} + +func TestWriteMutationsInReadOnlyTx(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Start a read-only transaction and try to write mutations to that transaction. That should return an error. + if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadOnly_{ReadOnly: &spannerpb.TransactionOptions_ReadOnly{}}, + }); err != nil { + t.Fatalf("BeginTransaction returned unexpected error: %v", err) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + }, + }}}, + }} + _, err = WriteMutations(ctx, poolId, connId, mutations) + if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("WriteMutations error code mismatch\n Got: %d\nWant: %d", g, w) + } + + // Committing the read-only transaction should not lead to any commits on Spanner. + _, err = Commit(ctx, poolId, connId) + if err != nil { + t.Fatalf("Commit returned unexpected error: %v", err) + } + requests := server.TestSpanner.DrainRequestsFromServer() + // There should also not be any BeginTransaction requests on Spanner, as the transaction was never really started + // by a query or other statement. + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %d\nWant: %d", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("num CommitRequests mismatch\n Got: %d\nWant: %d", g, w) + } + + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/lib/connection.go b/spannerlib/lib/connection.go index 72f32efb..75c6bc3a 100644 --- a/spannerlib/lib/connection.go +++ b/spannerlib/lib/connection.go @@ -34,6 +34,30 @@ func CloseConnection(ctx context.Context, poolId, connId int64) *Message { return &Message{} } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +func WriteMutations(ctx context.Context, poolId, connId int64, mutationBytes []byte) *Message { + mutations := spannerpb.BatchWriteRequest_MutationGroup{} + if err := proto.Unmarshal(mutationBytes, &mutations); err != nil { + return errMessage(err) + } + response, err := api.WriteMutations(ctx, poolId, connId, &mutations) + if err != nil { + return errMessage(err) + } + res, err := proto.Marshal(response) + if err != nil { + return errMessage(err) + } + return &Message{Res: res} +} + // BeginTransaction starts a new transaction on the given connection. A connection can have at most one active // transaction at any time. This function therefore returns an error if the connection has an active transaction. func BeginTransaction(ctx context.Context, poolId, connId int64, txOptsBytes []byte) *Message { diff --git a/spannerlib/lib/connection_test.go b/spannerlib/lib/connection_test.go index 02945ae9..193e3fd9 100644 --- a/spannerlib/lib/connection_test.go +++ b/spannerlib/lib/connection_test.go @@ -23,6 +23,7 @@ import ( "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" ) func TestCreateAndCloseConnection(t *testing.T) { @@ -262,3 +263,62 @@ func TestBeginAndRollback(t *testing.T) { t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) } } + +func TestWriteMutations(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolMsg := CreatePool(ctx, dsn) + if g, w := poolMsg.Code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + connMsg := CreateConnection(ctx, poolMsg.ObjectId) + if g, w := connMsg.Code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + }} + mutationBytes, err := proto.Marshal(mutations) + if err != nil { + t.Fatal(err) + } + mutationsMsg := WriteMutations(ctx, poolMsg.ObjectId, connMsg.ObjectId, mutationBytes) + if g, w := mutationsMsg.Code, int32(0); g != w { + t.Fatalf("WriteMutations result mismatch\n Got: %v\nWant: %v", g, w) + } + if mutationsMsg.Length() == 0 { + t.Fatal("WriteMutations returned no data") + } + + // Write mutations in a transaction. + mutationsMsg = BeginTransaction(ctx, poolMsg.ObjectId, connMsg.ObjectId, mutationBytes) + // The response should now be an empty message, as the mutations were only buffered in the transaction. + if g, w := mutationsMsg.Code, int32(0); g != w { + t.Fatalf("WriteMutations result mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := mutationsMsg.Length(), int32(0); g != w { + t.Fatalf("WriteMutations data length mismatch\n Got: %v\nWant: %v", g, w) + } + + closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + closeMsg = ClosePool(ctx, poolMsg.ObjectId) + if g, w := closeMsg.Code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} diff --git a/spannerlib/shared/shared_lib.go b/spannerlib/shared/shared_lib.go index d68c1973..8645f219 100644 --- a/spannerlib/shared/shared_lib.go +++ b/spannerlib/shared/shared_lib.go @@ -111,6 +111,22 @@ func CloseConnection(poolId, connId int64) (int64, int32, int64, int32, unsafe.P return pin(msg) } +// WriteMutations writes an array of mutations to Spanner. The mutations are buffered in +// the current read/write transaction if the connection currently has a read/write transaction. +// The mutations are applied to the database in a new read/write transaction that is automatically +// committed if the connection currently does not have a transaction. +// +// The function returns an error if the connection is currently in a read-only transaction. +// +// The mutationsBytes must be an encoded BatchWriteRequest_MutationGroup protobuf object. +// +//export WriteMutations +func WriteMutations(poolId, connectionId int64, mutationsBytes []byte) (int64, int32, int64, int32, unsafe.Pointer) { + ctx := context.Background() + msg := lib.WriteMutations(ctx, poolId, connectionId, mutationsBytes) + return pin(msg) +} + // Execute executes a SQL statement on the given connection. // The return type is an identifier for a Rows object. This identifier can be used to // call the functions Metadata and Next to get respectively the metadata of the result diff --git a/spannerlib/shared/shared_lib_test.go b/spannerlib/shared/shared_lib_test.go index 21650f6b..b6d63467 100644 --- a/spannerlib/shared/shared_lib_test.go +++ b/spannerlib/shared/shared_lib_test.go @@ -412,6 +412,75 @@ func TestBeginAndRollbackTransaction(t *testing.T) { } } +func TestWriteMutations(t *testing.T) { + t.Parallel() + + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + _, code, poolId, _, _ := CreatePool(dsn) + if g, w := code, int32(0); g != w { + t.Fatalf("CreatePool result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, connId, _, _ := CreateConnection(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("CreateConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + + mutations := &spannerpb.BatchWriteRequest_MutationGroup{Mutations: []*spannerpb.Mutation{ + {Operation: &spannerpb.Mutation_Insert{Insert: &spannerpb.Mutation_Write{ + Table: "my_table", + Columns: []string{"id", "value"}, + Values: []*structpb.ListValue{ + {Values: []*structpb.Value{structpb.NewStringValue("1"), structpb.NewStringValue("One")}}, + {Values: []*structpb.Value{structpb.NewStringValue("2"), structpb.NewStringValue("Two")}}, + {Values: []*structpb.Value{structpb.NewStringValue("3"), structpb.NewStringValue("Three")}}, + }, + }}}, + }} + mutationBytes, err := proto.Marshal(mutations) + if err != nil { + t.Fatal(err) + } + // WriteMutations returns a CommitResponse or nil, depending on whether the connection has an active transaction. + mem, code, id, length, data := WriteMutations(poolId, connId, mutationBytes) + verifyDataMessage(t, "WriteMutations", mem, code, id, length, data) + + response := &spannerpb.CommitResponse{} + responseBytes := reflect.SliceAt(reflect.TypeOf(byte(0)), data, int(length)).Bytes() + if err := proto.Unmarshal(responseBytes, response); err != nil { + t.Fatal(err) + } + if response.CommitTimestamp == nil { + t.Fatal("CommitTimestamp is nil") + } + // Release the memory held by the response. + if g, w := Release(mem), int32(0); g != w { + t.Fatalf("Release() result mismatch\n Got: %v\nWant: %v", g, w) + } + + // Start a transaction on the connection and write the mutations to that transaction. + txOpts := &spannerpb.TransactionOptions{} + txOptsBytes, err := proto.Marshal(txOpts) + _, code, _, _, _ = BeginTransaction(poolId, connId, txOptsBytes) + if g, w := code, int32(0); g != w { + t.Fatalf("BeginTransaction result mismatch\n Got: %v\nWant: %v", g, w) + } + mem, code, id, length, data = WriteMutations(poolId, connId, mutationBytes) + // The response should now be an empty message, as the mutations were buffered in the current transaction. + verifyEmptyMessage(t, "WriteMutations in tx", mem, code, id, length, data) + + _, code, _, _, _ = CloseConnection(poolId, connId) + if g, w := code, int32(0); g != w { + t.Fatalf("CloseConnection result mismatch\n Got: %v\nWant: %v", g, w) + } + _, code, _, _, _ = ClosePool(poolId) + if g, w := code, int32(0); g != w { + t.Fatalf("ClosePool result mismatch\n Got: %v\nWant: %v", g, w) + } +} + func verifyEmptyMessage(t *testing.T, name string, mem int64, code int32, id int64, length int32, res unsafe.Pointer) { if g, w := mem, int64(0); g != w { t.Fatalf("%s: mem ID mismatch\n Got: %v\nWant: %v", name, g, w) diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java index dc5aad96..18a9ca50 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/Connection.java @@ -21,6 +21,7 @@ import com.google.cloud.spannerlib.internal.MessageHandler; import com.google.cloud.spannerlib.internal.WrappedGoBytes; import com.google.protobuf.InvalidProtocolBufferException; +import com.google.spanner.v1.BatchWriteRequest.MutationGroup; import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.ExecuteBatchDmlRequest; import com.google.spanner.v1.ExecuteBatchDmlResponse; @@ -41,6 +42,31 @@ public Pool getPool() { return this.pool; } + /** + * Writes a group of mutations to Spanner. The mutations are buffered in the current read/write + * transaction if the connection has an active read/write transaction. Otherwise, the mutations + * are written directly to Spanner using a new read/write transaction. Returns a {@link + * CommitResponse} if the mutations were written directly to Spanner, and otherwise null if the + * mutations were buffered in the current transaction. + */ + public CommitResponse WriteMutations(MutationGroup mutations) { + try (WrappedGoBytes serializedRequest = WrappedGoBytes.serialize(mutations); + MessageHandler message = + getLibrary() + .execute( + library -> + library.WriteMutations( + pool.getId(), getId(), serializedRequest.getGoBytes()))) { + if (message.getLength() == 0) { + return null; + } + ByteBuffer buffer = message.getValue().getByteBuffer(0, message.getLength()); + return CommitResponse.parseFrom(buffer); + } catch (InvalidProtocolBufferException decodeException) { + throw new RuntimeException(decodeException); + } + } + /** Starts a transaction on this connection. */ public void beginTransaction(TransactionOptions options) { try (WrappedGoBytes serializedOptions = WrappedGoBytes.serialize(options)) { diff --git a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java index 54143891..7da5a013 100644 --- a/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java +++ b/spannerlib/wrappers/spannerlib-java/src/main/java/com/google/cloud/spannerlib/internal/SpannerLibrary.java @@ -65,6 +65,15 @@ default MessageHandler execute(Function function) /** Closes the given Connection. */ Message CloseConnection(long poolId, long connectionId); + /** + * Writes a group of mutations on Spanner. The mutations are buffered in the current read/write + * transaction if the connection has an active read/write transaction. Otherwise, the mutations + * are written directly to Spanner in a new read/write transaction. Returns a {@link + * com.google.spanner.v1.CommitResponse} if the mutations were written directly to Spanner, and an + * empty message if the mutations were only buffered in the current transaction. + */ + Message WriteMutations(long poolId, long connectionId, GoBytes mutations); + /** Starts a new transaction on the given Connection. */ Message BeginTransaction(long poolId, long connectionId, GoBytes transactionOptions); diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java index 5415b867..8efc22b5 100644 --- a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/ConnectionTest.java @@ -18,10 +18,25 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.rpc.Code; +import com.google.spanner.v1.BatchWriteRequest.MutationGroup; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.CreateSessionRequest; +import com.google.spanner.v1.Mutation; +import com.google.spanner.v1.Mutation.Write; +import com.google.spanner.v1.TransactionOptions; +import com.google.spanner.v1.TransactionOptions.ReadOnly; import org.junit.Test; public class ConnectionTest extends AbstractMockServerTest { @@ -53,4 +68,138 @@ public void testCreateTwoConnections() { } } } + + @Test + public void testWriteMutations() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + CommitResponse response = + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues( + Value.newBuilder().setStringValue("One").build()) + .build()) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("2").build()) + .addValues( + Value.newBuilder().setStringValue("Two").build()) + .build()) + .build()) + .build()) + .addMutations( + Mutation.newBuilder() + .setInsertOrUpdate( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("0").build()) + .addValues( + Value.newBuilder().setStringValue("Zero").build()) + .build()) + .build()) + .build()) + .build()); + assertNotNull(response); + assertNotNull(response.getCommitTimestamp()); + + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + CommitRequest request = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); + assertEquals(2, request.getMutationsCount()); + assertEquals(2, request.getMutations(0).getInsert().getValuesCount()); + assertEquals(1, request.getMutations(1).getInsertOrUpdate().getValuesCount()); + } + } + + @Test + public void testWriteMutationsInTransaction() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction(TransactionOptions.getDefaultInstance()); + CommitResponse response = + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("1").build()) + .addValues( + Value.newBuilder().setStringValue("One").build()) + .build()) + .build()) + .build()) + .build()); + // The mutations are only buffered in the current transaction, so there should be no response. + assertNull(response); + + // Committing the transaction should return a CommitResponse. + response = connection.commit(); + assertNotNull(response); + assertNotNull(response.getCommitTimestamp()); + + assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + CommitRequest request = mockSpanner.getRequestsOfType(CommitRequest.class).get(0); + assertEquals(1, request.getMutationsCount()); + assertEquals(1, request.getMutations(0).getInsert().getValuesCount()); + } + } + + @Test + public void testWriteMutationsInReadOnlyTransaction() { + String dsn = + String.format( + "localhost:%d/projects/p/instances/i/databases/d?usePlainText=true", getPort()); + try (Pool pool = Pool.createPool(dsn); + Connection connection = pool.createConnection()) { + connection.beginTransaction( + TransactionOptions.newBuilder().setReadOnly(ReadOnly.newBuilder().build()).build()); + SpannerLibException exception = + assertThrows( + SpannerLibException.class, + () -> + connection.WriteMutations( + MutationGroup.newBuilder() + .addMutations( + Mutation.newBuilder() + .setInsert( + Write.newBuilder() + .addAllColumns(ImmutableList.of("id", "value")) + .addValues( + ListValue.newBuilder() + .addValues( + Value.newBuilder() + .setStringValue("1") + .build()) + .addValues( + Value.newBuilder() + .setStringValue("One") + .build()) + .build()) + .build()) + .build()) + .build())); + assertEquals(Code.FAILED_PRECONDITION.getNumber(), exception.getStatus().getCode()); + } + } }