Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion server/analyzer/type_sanitizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package analyzer
import (
"context"
"strconv"
"strings"
"time"

"github.com/cockroachdb/errors"
Expand All @@ -39,10 +40,23 @@ import (
// to GMS types, so by taking care of all conversions here, we can ensure that Doltgres only needs to worry about its
// own types.
func TypeSanitizer(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
return pgtransform.NodeExprsWithOpaque(node, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
return pgtransform.NodeExprsWithNodeWithOpaque(node, func(n sql.Node, expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
// This can be updated if we find more expressions that return GMS types.
// These should eventually be replaced with Doltgres-equivalents over time, rendering this function unnecessary.
switch expr := expr.(type) {
case *expression.GetField:
switch n := n.(type) {
case *plan.Project, *plan.Filter, *plan.GroupBy:
child := n.Children()[0]
// Some dolt_ tables do not have doltgres types for their columns, so we convert them here
if rt, ok := child.(*plan.ResolvedTable); ok && strings.HasPrefix(rt.Name(), "dolt_") {
// This is a projection on a table, so we can safely convert the type
if _, ok := expr.Type().(*pgtypes.DoltgresType); !ok {
return pgexprs.NewGMSCast(expr), transform.NewTree, nil
}
}
}
return expr, transform.SameTree, nil
case *expression.Literal:
return typeSanitizerLiterals(ctx, expr)
case *expression.Not, *expression.And, *expression.Or, *expression.Like:
Expand Down
29 changes: 23 additions & 6 deletions server/expression/gms_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
package expression

import (
"strconv"
"encoding/json"
"math"
"time"

"github.com/cockroachdb/errors"
"github.com/dolthub/dolt/go/store/prolly/tree"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -124,8 +126,12 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
}
return newVal, nil
case query.Type_UINT64:
// Postgres doesn't have a Uint64 type, so we return an int64 with an error if the value is too high
if val, ok := val.(uint64); ok {
return decimal.NewFromString(strconv.FormatUint(val, 10))
if val > math.MaxInt64 {
return nil, errors.Errorf("uint64 value out of range: %v", val)
}
return int64(val), nil
}
return nil, errors.Errorf("GMSCast expected type `uint64`, got `%T`", val)
case query.Type_FLOAT32:
Expand Down Expand Up @@ -158,15 +164,26 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
if err != nil {
return nil, err
}
if _, ok := newVal.(string); !ok {
switch newVal := newVal.(type) {
case string:
return newVal, nil
case sql.StringWrapper:
return newVal.Unwrap(ctx)
default:
return nil, errors.Errorf("GMSCast expected type `string`, got `%T`", val)
}
return newVal, nil
case query.Type_JSON:
if val, ok := val.(types.JSONDocument); ok {
switch val := val.(type) {
case types.JSONDocument:
return val.JSONString()
case tree.IndexedJsonDocument:
return val.String(), nil
default:
// TODO: there are particular dolt tables (dolt_constraint_violations) that return json-marshallable structs
// that we need to handle here like this
bytes, err := json.Marshal(val)
return string(bytes), err
}
return nil, errors.Errorf("GMSCast expected type `JSONDocument`, got `%T`", val)
case query.Type_NULL_TYPE:
return nil, nil
case query.Type_GEOMETRY:
Expand Down
11 changes: 5 additions & 6 deletions testing/go/dolt_tables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ func TestUserSpaceDoltTables(t *testing.T) {
Expected: []sql.Row{{1}},
},
{
Skip: true, // this fails because the column type of this table is not a doltgres type, which IN requires
Query: `SELECT "dolt_branches"."name" FROM "dolt_branches" WHERE "dolt_branches"."name" IN ('main') ORDER BY "dolt_branches"."name" DESC LIMIT 21;`,
Expected: []sql.Row{{"main"}},
},
Expand Down Expand Up @@ -555,7 +554,7 @@ func TestUserSpaceDoltTables(t *testing.T) {
},
{
Query: `SELECT * FROM dolt_conflicts`,
Expected: []sql.Row{{"test", Numeric("1")}},
Expected: []sql.Row{{"test", 1}},
},
{
Query: `SELECT dolt.conflicts.table FROM dolt.conflicts`,
Expand Down Expand Up @@ -763,7 +762,7 @@ func TestUserSpaceDoltTables(t *testing.T) {
},
{
Query: `SELECT * FROM dolt_constraint_violations`,
Expected: []sql.Row{{"test", Numeric("2")}},
Expected: []sql.Row{{"test", 2}},
},
{
Query: `SELECT dolt.constraint_violations.table FROM dolt.constraint_violations`,
Expand Down Expand Up @@ -1744,8 +1743,8 @@ func TestUserSpaceDoltTables(t *testing.T) {
},
},
},
//TODO: turn on statistics
//{
// TODO: turn on statistics
// {
// Name: "dolt statistics",
// SetUpScript: []string{
// "CREATE TABLE horses (id int primary key, name varchar(10));",
Expand Down Expand Up @@ -1871,7 +1870,7 @@ func TestUserSpaceDoltTables(t *testing.T) {
// Expected: []sql.Row{{"horses", "horses_name_idx"}, {"horses", "primary"}},
// },
// },
//},
// },
{
Name: "dolt status",
SetUpScript: []string{
Expand Down