Skip to content

Commit f331cf4

Browse files
committed
chore: add Execute function for SpannerLib
Adds an Execute function for SpannerLib that can be used to execute any type of SQL statement. The return type is always a Rows object. The Rows object is empty for DDL statements, it only contains ResultSetStats for DML statements without a THEN RETURN clause, and it contains actual row data for queries and DML statements with a THEN RETURN clause. The Execute function can also be used to execute client-side SQL statements, like BEGIN, COMMIT, SET, SHOW, etc.
1 parent e42a554 commit f331cf4

File tree

18 files changed

+1555
-8
lines changed

18 files changed

+1555
-8
lines changed

spannerlib/api/connection.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,14 @@ package api
1717
import (
1818
"context"
1919
"database/sql"
20+
"fmt"
21+
"strings"
2022
"sync"
2123
"sync/atomic"
24+
25+
"cloud.google.com/go/spanner"
26+
"cloud.google.com/go/spanner/apiv1/spannerpb"
27+
spannerdriver "github.com/googleapis/go-sql-spanner"
2228
)
2329

2430
// CloseConnection looks up the connection with the given poolId and connId and closes it.
@@ -36,6 +42,14 @@ func CloseConnection(ctx context.Context, poolId, connId int64) error {
3642
return conn.close(ctx)
3743
}
3844

45+
func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) {
46+
conn, err := findConnection(poolId, connId)
47+
if err != nil {
48+
return 0, err
49+
}
50+
return conn.Execute(ctx, executeSqlRequest)
51+
}
52+
3953
type Connection struct {
4054
// results contains the open query results for this connection.
4155
results *sync.Map
@@ -45,6 +59,11 @@ type Connection struct {
4559
backend *sql.Conn
4660
}
4761

62+
type queryExecutor interface {
63+
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
64+
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
65+
}
66+
4867
func (conn *Connection) close(ctx context.Context) error {
4968
conn.closeResults(ctx)
5069
err := conn.backend.Close()
@@ -60,3 +79,73 @@ func (conn *Connection) closeResults(ctx context.Context) {
6079
return true
6180
})
6281
}
82+
83+
func (conn *Connection) Execute(ctx context.Context, statement *spannerpb.ExecuteSqlRequest) (int64, error) {
84+
return execute(ctx, conn, conn.backend, statement)
85+
}
86+
87+
func execute(ctx context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) {
88+
params := extractParams(statement)
89+
it, err := executor.QueryContext(ctx, statement.Sql, params...)
90+
if err != nil {
91+
return 0, err
92+
}
93+
// The first result set should contain the metadata.
94+
if !it.Next() {
95+
return 0, fmt.Errorf("query returned no metadata")
96+
}
97+
metadata := &spannerpb.ResultSetMetadata{}
98+
if err := it.Scan(&metadata); err != nil {
99+
return 0, err
100+
}
101+
// Move to the next result set, which contains the normal data.
102+
if !it.NextResultSet() {
103+
return 0, fmt.Errorf("no results found after metadata")
104+
}
105+
id := conn.resultsIdx.Add(1)
106+
res := &rows{
107+
backend: it,
108+
metadata: metadata,
109+
}
110+
if len(metadata.RowType.Fields) == 0 {
111+
// No rows returned. Read the stats now.
112+
res.readStats(ctx)
113+
}
114+
conn.results.Store(id, res)
115+
return id, nil
116+
}
117+
118+
func extractParams(statement *spannerpb.ExecuteSqlRequest) []any {
119+
paramsLen := 1
120+
if statement.Params != nil {
121+
paramsLen = 1 + len(statement.Params.Fields)
122+
}
123+
params := make([]any, paramsLen)
124+
params = append(params, spannerdriver.ExecOptions{
125+
DecodeOption: spannerdriver.DecodeOptionProto,
126+
// TODO: Implement support for passing in stale query options
127+
// TimestampBound: extractTimestampBound(statement),
128+
ReturnResultSetMetadata: true,
129+
ReturnResultSetStats: true,
130+
DirectExecuteQuery: true,
131+
})
132+
if statement.Params != nil {
133+
if statement.ParamTypes == nil {
134+
statement.ParamTypes = make(map[string]*spannerpb.Type)
135+
}
136+
for param, value := range statement.Params.Fields {
137+
genericValue := spanner.GenericColumnValue{
138+
Value: value,
139+
Type: statement.ParamTypes[param],
140+
}
141+
if strings.HasPrefix(param, "_") {
142+
// Prefix the parameter name with a 'p' to work around the fact that database/sql does not allow
143+
// named arguments to start with anything else than a letter.
144+
params = append(params, sql.Named("p"+param, spannerdriver.SpannerNamedArg{NameInQuery: param, Value: genericValue}))
145+
} else {
146+
params = append(params, sql.Named(param, genericValue))
147+
}
148+
}
149+
}
150+
return params
151+
}

spannerlib/api/pool.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package api
1717
import (
1818
"context"
1919
"database/sql"
20+
"fmt"
2021
"sync"
2122
"sync/atomic"
2223

@@ -131,3 +132,16 @@ func findConnection(poolId, connId int64) (*Connection, error) {
131132
conn := c.(*Connection)
132133
return conn, nil
133134
}
135+
136+
func findRows(poolId, connId, rowsId int64) (*rows, error) {
137+
conn, err := findConnection(poolId, connId)
138+
if err != nil {
139+
return nil, err
140+
}
141+
r, ok := conn.results.Load(rowsId)
142+
if !ok {
143+
return nil, fmt.Errorf("rows %v not found", rowsId)
144+
}
145+
res := r.(*rows)
146+
return res, nil
147+
}

spannerlib/api/rows.go

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package api
16+
17+
import (
18+
"context"
19+
"database/sql"
20+
"errors"
21+
22+
"cloud.google.com/go/spanner"
23+
"cloud.google.com/go/spanner/apiv1/spannerpb"
24+
"google.golang.org/grpc/codes"
25+
"google.golang.org/grpc/status"
26+
"google.golang.org/protobuf/proto"
27+
"google.golang.org/protobuf/types/known/structpb"
28+
)
29+
30+
type EncodeRowOption int32
31+
32+
const (
33+
EncodeRowOptionProto EncodeRowOption = iota
34+
)
35+
36+
// Metadata returns the ResultSetMetadata of the given rows.
37+
// This function can be called for any type of statement (queries, DML, DDL).
38+
func Metadata(_ context.Context, poolId, connId, rowsId int64) (*spannerpb.ResultSetMetadata, error) {
39+
res, err := findRows(poolId, connId, rowsId)
40+
if err != nil {
41+
return nil, err
42+
}
43+
return res.Metadata()
44+
}
45+
46+
// ResultSetStats returns the result statistics of the given rows.
47+
// This function can only be called once all data in the rows have been fetched.
48+
// The stats are empty for queries and DDL statements.
49+
func ResultSetStats(ctx context.Context, poolId, connId, rowsId int64) (*spannerpb.ResultSetStats, error) {
50+
res, err := findRows(poolId, connId, rowsId)
51+
if err != nil {
52+
return nil, err
53+
}
54+
return res.ResultSetStats(ctx)
55+
}
56+
57+
// NextEncoded returns the next row data in encoded form.
58+
// Using NextEncoded instead of Next can be more efficient for large result sets,
59+
// as it allows the library to re-use the encoding buffer.
60+
// TODO: Add an encoder function as input argument, instead of hardcoding protobuf encoding here.
61+
func NextEncoded(ctx context.Context, poolId, connId, rowsId int64) ([]byte, error) {
62+
_, bytes, err := next(ctx, poolId, connId, rowsId, true)
63+
if err != nil {
64+
return nil, err
65+
}
66+
return bytes, nil
67+
}
68+
69+
// Next returns the next row as a protobuf ListValue.
70+
func Next(ctx context.Context, poolId, connId, rowsId int64) (*structpb.ListValue, error) {
71+
values, _, err := next(ctx, poolId, connId, rowsId, false)
72+
if err != nil {
73+
return nil, err
74+
}
75+
return values, nil
76+
}
77+
78+
// next returns the next row of data.
79+
// The row is returned as a protobuf ListValue if marshalResult==false.
80+
// The row is returned as a byte slice if marshalResult==true.
81+
// TODO: Add generics to the function and add input arguments for encoding instead of hardcoding it.
82+
func next(ctx context.Context, poolId, connId, rowsId int64, marshalResult bool) (*structpb.ListValue, []byte, error) {
83+
rows, err := findRows(poolId, connId, rowsId)
84+
if err != nil {
85+
return nil, nil, err
86+
}
87+
values, err := rows.Next(ctx)
88+
if err != nil {
89+
return nil, nil, err
90+
}
91+
if !marshalResult || values == nil {
92+
return values, nil, nil
93+
}
94+
95+
rows.marshalBuffer, err = proto.MarshalOptions{}.MarshalAppend(rows.marshalBuffer[:0], rows.values)
96+
if err != nil {
97+
return nil, nil, err
98+
}
99+
return values, rows.marshalBuffer, nil
100+
}
101+
102+
// CloseRows closes the given rows. Callers must always call this to clean up any resources
103+
// that are held by the underlying cursor.
104+
func CloseRows(ctx context.Context, poolId, connId, rowsId int64) error {
105+
conn, err := findConnection(poolId, connId)
106+
if err != nil {
107+
return err
108+
}
109+
r, ok := conn.results.LoadAndDelete(rowsId)
110+
if !ok {
111+
return nil
112+
}
113+
res := r.(*rows)
114+
return res.Close(ctx)
115+
}
116+
117+
type rows struct {
118+
backend *sql.Rows
119+
metadata *spannerpb.ResultSetMetadata
120+
stats *spannerpb.ResultSetStats
121+
done bool
122+
123+
buffer []any
124+
values *structpb.ListValue
125+
marshalBuffer []byte
126+
}
127+
128+
func (rows *rows) Close(ctx context.Context) error {
129+
err := rows.backend.Close()
130+
if err != nil {
131+
return err
132+
}
133+
return nil
134+
}
135+
136+
func (rows *rows) Metadata() (*spannerpb.ResultSetMetadata, error) {
137+
return rows.metadata, nil
138+
}
139+
140+
func (rows *rows) ResultSetStats(ctx context.Context) (*spannerpb.ResultSetStats, error) {
141+
if rows.stats == nil {
142+
rows.readStats(ctx)
143+
}
144+
return rows.stats, nil
145+
}
146+
147+
type genericValue struct {
148+
v *structpb.Value
149+
}
150+
151+
func (gv *genericValue) Scan(src any) error {
152+
if v, ok := src.(spanner.GenericColumnValue); ok {
153+
gv.v = v.Value
154+
return nil
155+
}
156+
return errors.New("cannot convert value to generic column value")
157+
}
158+
159+
func (rows *rows) Next(ctx context.Context) (*structpb.ListValue, error) {
160+
// No columns means no rows, so just return nil to indicate that there are no (more) rows.
161+
if len(rows.metadata.RowType.Fields) == 0 || rows.done {
162+
return nil, nil
163+
}
164+
if rows.stats != nil {
165+
return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "cannot read more data after returning stats"))
166+
}
167+
ok := rows.backend.Next()
168+
if !ok {
169+
rows.done = true
170+
// No more rows. Read stats and return nil.
171+
rows.readStats(ctx)
172+
// nil indicates no more rows.
173+
return nil, nil
174+
}
175+
176+
if rows.buffer == nil {
177+
rows.buffer = make([]any, len(rows.metadata.RowType.Fields))
178+
for i := range rows.buffer {
179+
rows.buffer[i] = &genericValue{}
180+
}
181+
rows.values = &structpb.ListValue{
182+
Values: make([]*structpb.Value, len(rows.buffer)),
183+
}
184+
rows.marshalBuffer = make([]byte, 0)
185+
}
186+
if err := rows.backend.Scan(rows.buffer...); err != nil {
187+
return nil, err
188+
}
189+
for i := range rows.buffer {
190+
rows.values.Values[i] = rows.buffer[i].(*genericValue).v
191+
}
192+
return rows.values, nil
193+
}
194+
195+
func (rows *rows) readStats(ctx context.Context) {
196+
rows.stats = &spannerpb.ResultSetStats{}
197+
if !rows.backend.NextResultSet() {
198+
return
199+
}
200+
if rows.backend.Next() {
201+
_ = rows.backend.Scan(&rows.stats)
202+
}
203+
}

0 commit comments

Comments
 (0)