Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,9 @@ type SpannerConn interface {
// return the same Spanner client.
UnderlyingClient() (client *spanner.Client, err error)

// DetectStatementType returns the type of SQL statement.
DetectStatementType(query string) parser.StatementType

// resetTransactionForRetry resets the current transaction after it has
// been aborted by Spanner. Calling this function on a transaction that
// has not been aborted is not supported and will cause an error to be
Expand Down Expand Up @@ -286,6 +289,11 @@ func (c *conn) UnderlyingClient() (*spanner.Client, error) {
return c.client, nil
}

func (c *conn) DetectStatementType(query string) parser.StatementType {
info := c.parser.DetectStatementType(query)
return info.StatementType
}

func (c *conn) CommitTimestamp() (time.Time, error) {
ts := propertyCommitTimestamp.GetValueOrDefault(c.state)
if ts == nil {
Expand Down
238 changes: 238 additions & 0 deletions spannerlib/api/batch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package api

import (
"context"
"fmt"
"reflect"
"testing"

"cloud.google.com/go/longrunning/autogen/longrunningpb"
"cloud.google.com/go/spanner"
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
"cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/googleapis/go-sql-spanner/testutil"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/emptypb"
)

func TestExecuteDmlBatch(t *testing.T) {
t.Parallel()

ctx := context.Background()
server, teardown := setupMockServer(t)
defer teardown()
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)

poolId, err := CreatePool(ctx, dsn)
if err != nil {
t.Fatalf("CreatePool returned unexpected error: %v", err)
}
connId, err := CreateConnection(ctx, poolId)
if err != nil {
t.Fatalf("CreateConnection returned unexpected error: %v", err)
}

// Execute a DML batch.
request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{
{Sql: testutil.UpdateBarSetFoo},
{Sql: testutil.UpdateBarSetFoo},
}}
resp, err := ExecuteBatch(ctx, poolId, connId, request)
if err != nil {
t.Fatalf("ExecuteBatch returned unexpected error: %v", err)
}
if g, w := len(resp.ResultSets), 2; g != w {
t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w)
}
for i, result := range resp.ResultSets {
if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w {
t.Fatalf("%d: update count mismatch\n Got: %d\nWant: %d", i, g, w)
}
}

requests := server.TestSpanner.DrainRequestsFromServer()
// There should be no ExecuteSql requests.
executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
if g, w := len(executeRequests), 0; g != w {
t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w)
}
batchRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{}))
if g, w := len(batchRequests), 1; g != w {
t.Fatalf("Execute batch request count mismatch\n Got: %v\nWant: %v", g, w)
}

if err := CloseConnection(ctx, poolId, connId); err != nil {
t.Fatalf("CloseConnection returned unexpected error: %v", err)
}
if err := ClosePool(ctx, poolId); err != nil {
t.Fatalf("ClosePool returned unexpected error: %v", err)
}
}

func TestExecuteDdlBatch(t *testing.T) {
t.Parallel()

ctx := context.Background()
server, teardown := setupMockServer(t)
defer teardown()
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)
// Set up a result for a DDL statement on the mock server.
var expectedResponse = &emptypb.Empty{}
anyMsg, _ := anypb.New(expectedResponse)
server.TestDatabaseAdmin.SetResps([]proto.Message{
&longrunningpb.Operation{
Done: true,
Result: &longrunningpb.Operation_Response{Response: anyMsg},
Name: "test-operation",
},
})

poolId, err := CreatePool(ctx, dsn)
if err != nil {
t.Fatalf("CreatePool returned unexpected error: %v", err)
}
connId, err := CreateConnection(ctx, poolId)
if err != nil {
t.Fatalf("CreateConnection returned unexpected error: %v", err)
}

// Execute a DDL batch. This also uses a DML batch request.
request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{
{Sql: "create table my_table (id int64 primary key, value string(100))"},
{Sql: "create index my_index on my_table (value)"},
}}
resp, err := ExecuteBatch(ctx, poolId, connId, request)
if err != nil {
t.Fatalf("ExecuteBatch returned unexpected error: %v", err)
}
// The response should contain an 'update count' per DDL statement.
if g, w := len(resp.ResultSets), 2; g != w {
t.Fatalf("num results mismatch\n Got: %d\nWant: %d", g, w)
}
// There is no update count for DDL statements.
for i, result := range resp.ResultSets {
emptyStats := &spannerpb.ResultSetStats{}
if g, w := result.Stats, emptyStats; !cmp.Equal(g, w, cmpopts.IgnoreUnexported(spannerpb.ResultSetStats{})) {
t.Fatalf("%d: ResultSetStats mismatch\n Got: %v\nWant: %v", i, g, w)
}
}

requests := server.TestSpanner.DrainRequestsFromServer()
// There should be no ExecuteSql requests.
executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
if g, w := len(executeRequests), 0; g != w {
t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w)
}
// There should also be no ExecuteBatchDml requests.
batchDmlRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteBatchDmlRequest{}))
if g, w := len(batchDmlRequests), 0; g != w {
t.Fatalf("ExecuteBatchDmlRequest count mismatch\n Got: %v\nWant: %v", g, w)
}

adminRequests := server.TestDatabaseAdmin.Reqs()
if g, w := len(adminRequests), 1; g != w {
t.Fatalf("admin request count mismatch\n Got: %v\nWant: %v", g, w)
}
ddlRequest := adminRequests[0].(*databasepb.UpdateDatabaseDdlRequest)
if g, w := len(ddlRequest.Statements), 2; g != w {
t.Fatalf("DDL statement count mismatch\n Got: %v\nWant: %v", g, w)
}

if err := CloseConnection(ctx, poolId, connId); err != nil {
t.Fatalf("CloseConnection returned unexpected error: %v", err)
}
if err := ClosePool(ctx, poolId); err != nil {
t.Fatalf("ClosePool returned unexpected error: %v", err)
}
}

func TestExecuteMixedBatch(t *testing.T) {
t.Parallel()

ctx := context.Background()
server, teardown := setupMockServer(t)
defer teardown()
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)

poolId, err := CreatePool(ctx, dsn)
if err != nil {
t.Fatalf("CreatePool returned unexpected error: %v", err)
}
connId, err := CreateConnection(ctx, poolId)
if err != nil {
t.Fatalf("CreateConnection returned unexpected error: %v", err)
}

// Try to execute a batch with mixed DML and DDL statements. This should fail.
request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{
{Sql: "create table my_table (id int64 primary key, value string(100))"},
{Sql: "update my_table set value = 100 where true"},
}}
_, err = ExecuteBatch(ctx, poolId, connId, request)
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
}

if err := CloseConnection(ctx, poolId, connId); err != nil {
t.Fatalf("CloseConnection returned unexpected error: %v", err)
}
if err := ClosePool(ctx, poolId); err != nil {
t.Fatalf("ClosePool returned unexpected error: %v", err)
}
}

func TestExecuteDdlBatchInTransaction(t *testing.T) {
t.Parallel()

ctx := context.Background()
server, teardown := setupMockServer(t)
defer teardown()
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)

poolId, err := CreatePool(ctx, dsn)
if err != nil {
t.Fatalf("CreatePool returned unexpected error: %v", err)
}
connId, err := CreateConnection(ctx, poolId)
if err != nil {
t.Fatalf("CreateConnection returned unexpected error: %v", err)
}
if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil {
t.Fatalf("BeginTransaction returned unexpected error: %v", err)
}

// Try to execute a DDL batch in a transaction. This should fail.
request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{
{Sql: "create table my_table (id int64 primary key, value string(100))"},
{Sql: "create index my_index on my_table (value)"},
}}
_, err = ExecuteBatch(ctx, poolId, connId, request)
if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w {
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
}

if err := CloseConnection(ctx, poolId, connId); err != nil {
t.Fatalf("CloseConnection returned unexpected error: %v", err)
}
if err := ClosePool(ctx, poolId); err != nil {
t.Fatalf("ClosePool returned unexpected error: %v", err)
}
}
Loading
Loading