diff --git a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go index 124c7874fe01..230ca7aea98b 100644 --- a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go +++ b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go @@ -27,6 +27,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" bigqueryrestapi "google.golang.org/api/bigquery/v2" "google.golang.org/api/iterator" ) @@ -324,19 +325,19 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken return nil, fmt.Errorf("unable to read query results: %w", err) } for { - var row map[string]bigqueryapi.Value - err = it.Next(&row) + var val map[string]bigqueryapi.Value + err = it.Next(&val) if err == iterator.Done { break } if err != nil { return nil, fmt.Errorf("unable to iterate through query results: %w", err) } - vMap := make(map[string]any) - for key, value := range row { - vMap[key] = value + row := orderedmap.Row{} + for key, value := range val { + row.Add(key, value) } - out = append(out, vMap) + out = append(out, row) } // If the query returned any rows, return them directly. if len(out) > 0 { diff --git a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go index 38ba0fe2f7ba..77ab281f104e 100644 --- a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go +++ b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go @@ -25,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources/mssql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" ) const kind string = "mssql-execute-sql" @@ -152,11 +153,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken if scanErr != nil { return nil, fmt.Errorf("unable to parse row: %w", scanErr) } - vMap := make(map[string]any) + row := orderedmap.Row{} for i, name := range cols { - vMap[name] = rawValues[i] + row.Add(name, rawValues[i]) } - out = append(out, vMap) + out = append(out, row) } } diff --git a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go index 890cb0c9ed0d..81fc6ba74366 100644 --- a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go +++ b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go @@ -27,6 +27,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" ) const kind string = "mysql-execute-sql" @@ -159,20 +160,21 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken if err != nil { return nil, fmt.Errorf("unable to parse row: %w", err) } - vMap := make(map[string]any) + row := orderedmap.Row{} for i, name := range cols { val := rawValues[i] if val == nil { - vMap[name] = nil + row.Add(name, nil) continue } - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) + convertedValue, err := mysqlcommon.ConvertToType(colTypes[i], val) if err != nil { return nil, fmt.Errorf("errors encountered when converting values: %w", err) } + row.Add(name, convertedValue) } - out = append(out, vMap) + out = append(out, row) } if err := results.Err(); err != nil { diff --git a/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go b/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go index c43fac80db37..6f281e77e990 100644 --- a/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go +++ b/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go @@ -25,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/jackc/pgx/v5/pgxpool" ) @@ -142,11 +143,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken if err != nil { return nil, fmt.Errorf("unable to parse row: %w", err) } - vMap := make(map[string]any) + row := orderedmap.Row{} for i, f := range fields { - vMap[f.Name] = v[i] + row.Add(f.Name, v[i]) } - out = append(out, vMap) + out = append(out, row) } if err := results.Err(); err != nil { diff --git a/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go b/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go index 495e5023598d..891bded786ef 100644 --- a/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go +++ b/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go @@ -24,6 +24,7 @@ import ( spannerdb "github.com/googleapis/genai-toolbox/internal/sources/spanner" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "google.golang.org/api/iterator" ) @@ -131,12 +132,12 @@ func processRows(iter *spanner.RowIterator) ([]any, error) { return nil, fmt.Errorf("unable to parse row: %w", err) } - vMap := make(map[string]any) + rowMap := orderedmap.Row{} cols := row.ColumnNames() for i, c := range cols { - vMap[c] = row.ColumnValue(i) + rowMap.Add(c, row.ColumnValue(i)) } - out = append(out, vMap) + out = append(out, rowMap) } return out, nil } diff --git a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go index 87d53c5f486f..ff369338db66 100644 --- a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go +++ b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go @@ -25,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources/sqlite" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" ) const kind string = "sqlite-execute-sql" @@ -155,11 +156,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken if err != nil { return nil, fmt.Errorf("unable to parse row: %w", err) } - vMap := make(map[string]any) + row := orderedmap.Row{} for i, name := range cols { val := rawValues[i] if val == nil { - vMap[name] = nil + row.Add(name, nil) continue } @@ -167,13 +168,13 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken if jsonString, ok := val.(string); ok { var unmarshaledData any if json.Unmarshal([]byte(jsonString), &unmarshaledData) == nil { - vMap[name] = unmarshaledData + row.Add(name, unmarshaledData) continue } } - vMap[name] = val + row.Add(name, val) } - out = append(out, vMap) + out = append(out, row) } if err := results.Err(); err != nil { diff --git a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql_test.go b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql_test.go index fb09460b0245..b9cbe170a64e 100644 --- a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql_test.go +++ b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql_test.go @@ -26,6 +26,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" _ "modernc.org/sqlite" ) @@ -159,8 +160,20 @@ func TestTool_Invoke(t *testing.T) { }, }, want: []any{ - map[string]any{"id": int64(1), "name": "Alice", "age": int64(30)}, - map[string]any{"id": int64(2), "name": "Bob", "age": int64(25)}, + orderedmap.Row{ + Columns: []orderedmap.Column{ + {Name: "id", Value: int64(1)}, + {Name: "name", Value: "Alice"}, + {Name: "age", Value: int64(30)}, + }, + }, + orderedmap.Row{ + Columns: []orderedmap.Column{ + {Name: "id", Value: int64(2)}, + {Name: "name", Value: "Bob"}, + {Name: "age", Value: int64(25)}, + }, + }, }, wantErr: false, }, @@ -233,7 +246,13 @@ func TestTool_Invoke(t *testing.T) { }, }, want: []any{ - map[string]any{"id": int64(1), "null_col": nil, "blob_col": []byte{1, 2, 3}}, + orderedmap.Row{ + Columns: []orderedmap.Column{ + {Name: "id", Value: int64(1)}, + {Name: "null_col", Value: nil}, + {Name: "blob_col", Value: []byte{1, 2, 3}}, + }, + }, }, wantErr: false, }, @@ -264,8 +283,18 @@ func TestTool_Invoke(t *testing.T) { }, }, want: []any{ - map[string]any{"name": "Alice", "item": "Laptop"}, - map[string]any{"name": "Bob", "item": "Keyboard"}, + orderedmap.Row{ + Columns: []orderedmap.Column{ + {Name: "name", Value: "Alice"}, + {Name: "item", Value: "Laptop"}, + }, + }, + orderedmap.Row{ + Columns: []orderedmap.Column{ + {Name: "name", Value: "Bob"}, + {Name: "item", Value: "Keyboard"}, + }, + }, }, wantErr: false, }, @@ -292,7 +321,7 @@ func TestTool_Invoke(t *testing.T) { } if !isEqual { - t.Errorf("Tool.Invoke() = %v, want %v", got, tt.want) + t.Errorf("Tool.Invoke() = %+v, want %v", got, tt.want) } }) } diff --git a/internal/util/orderedmap/orderedmap.go b/internal/util/orderedmap/orderedmap.go new file mode 100644 index 000000000000..81edf6ec9c1a --- /dev/null +++ b/internal/util/orderedmap/orderedmap.go @@ -0,0 +1,62 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package orderedmap + +import ( + "bytes" + "encoding/json" +) + +// Column represents a single column in a row. +type Column struct { + Name string + Value any +} + +// Row represents a row of data with columns in a specific order. +type Row struct { + Columns []Column +} + +// Add adds a new column to the row. +func (r *Row) Add(name string, value any) { + r.Columns = append(r.Columns, Column{Name: name, Value: value}) +} + +// MarshalJSON implements the json.Marshaler interface for the Row struct. +// It marshals the row into a JSON object, preserving the order of the columns. +func (r Row) MarshalJSON() ([]byte, error) { + var buf bytes.Buffer + buf.WriteString("{") + for i, col := range r.Columns { + if i > 0 { + buf.WriteString(",") + } + // Marshal the key + key, err := json.Marshal(col.Name) + if err != nil { + return nil, err + } + buf.Write(key) + buf.WriteString(":") + // Marshal the value + val, err := json.Marshal(col.Value) + if err != nil { + return nil, err + } + buf.Write(val) + } + buf.WriteString("}") + return buf.Bytes(), nil +} diff --git a/internal/util/orderedmap/orderedmap_test.go b/internal/util/orderedmap/orderedmap_test.go new file mode 100644 index 000000000000..0ad3754d41b6 --- /dev/null +++ b/internal/util/orderedmap/orderedmap_test.go @@ -0,0 +1,83 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package orderedmap + +import ( + "encoding/json" + "testing" +) + +func TestRowMarshalJSON(t *testing.T) { + tests := []struct { + name string + row Row + want string + wantErr bool + }{ + { + name: "Simple row", + row: Row{ + Columns: []Column{ + {Name: "A", Value: 1}, + {Name: "B", Value: "two"}, + {Name: "C", Value: true}, + }, + }, + want: `{"A":1,"B":"two","C":true}`, + wantErr: false, + }, + { + name: "Row with different order", + row: Row{ + Columns: []Column{ + {Name: "C", Value: true}, + {Name: "A", Value: 1}, + {Name: "B", Value: "two"}, + }, + }, + want: `{"C":true,"A":1,"B":"two"}`, + wantErr: false, + }, + { + name: "Empty row", + row: Row{}, + want: `{}`, + wantErr: false, + }, + { + name: "Row with nil value", + row: Row{ + Columns: []Column{ + {Name: "A", Value: 1}, + {Name: "B", Value: nil}, + }, + }, + want: `{"A":1,"B":null}`, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.row) + if (err != nil) != tt.wantErr { + t.Errorf("Row.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if string(got) != tt.want { + t.Errorf("Row.MarshalJSON() = %s, want %s", string(got), tt.want) + } + }) + } +}