diff --git a/server/analyzer/type_sanitizer.go b/server/analyzer/type_sanitizer.go index 88f2034b88..ea93b3cbcb 100644 --- a/server/analyzer/type_sanitizer.go +++ b/server/analyzer/type_sanitizer.go @@ -17,6 +17,7 @@ package analyzer import ( "context" "strconv" + "strings" "time" "github.com/cockroachdb/errors" @@ -39,10 +40,23 @@ import ( // to GMS types, so by taking care of all conversions here, we can ensure that Doltgres only needs to worry about its // own types. func TypeSanitizer(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope *plan.Scope, selector analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { - return pgtransform.NodeExprsWithOpaque(node, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + return pgtransform.NodeExprsWithNodeWithOpaque(node, func(n sql.Node, expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { // This can be updated if we find more expressions that return GMS types. // These should eventually be replaced with Doltgres-equivalents over time, rendering this function unnecessary. switch expr := expr.(type) { + case *expression.GetField: + switch n := n.(type) { + case *plan.Project, *plan.Filter, *plan.GroupBy: + child := n.Children()[0] + // Some dolt_ tables do not have doltgres types for their columns, so we convert them here + if rt, ok := child.(*plan.ResolvedTable); ok && strings.HasPrefix(rt.Name(), "dolt_") { + // This is a projection on a table, so we can safely convert the type + if _, ok := expr.Type().(*pgtypes.DoltgresType); !ok { + return pgexprs.NewGMSCast(expr), transform.NewTree, nil + } + } + } + return expr, transform.SameTree, nil case *expression.Literal: return typeSanitizerLiterals(ctx, expr) case *expression.Not, *expression.And, *expression.Or, *expression.Like: diff --git a/server/expression/gms_cast.go b/server/expression/gms_cast.go index 373be83108..56d1964853 100644 --- a/server/expression/gms_cast.go +++ b/server/expression/gms_cast.go @@ -15,10 +15,12 @@ package expression import ( - "strconv" + "encoding/json" + "math" "time" "github.com/cockroachdb/errors" + "github.com/dolthub/dolt/go/store/prolly/tree" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/vitess/go/vt/proto/query" @@ -124,8 +126,12 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { } return newVal, nil case query.Type_UINT64: + // Postgres doesn't have a Uint64 type, so we return an int64 with an error if the value is too high if val, ok := val.(uint64); ok { - return decimal.NewFromString(strconv.FormatUint(val, 10)) + if val > math.MaxInt64 { + return nil, errors.Errorf("uint64 value out of range: %v", val) + } + return int64(val), nil } return nil, errors.Errorf("GMSCast expected type `uint64`, got `%T`", val) case query.Type_FLOAT32: @@ -158,15 +164,26 @@ func (c *GMSCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil { return nil, err } - if _, ok := newVal.(string); !ok { + switch newVal := newVal.(type) { + case string: + return newVal, nil + case sql.StringWrapper: + return newVal.Unwrap(ctx) + default: return nil, errors.Errorf("GMSCast expected type `string`, got `%T`", val) } - return newVal, nil case query.Type_JSON: - if val, ok := val.(types.JSONDocument); ok { + switch val := val.(type) { + case types.JSONDocument: return val.JSONString() + case tree.IndexedJsonDocument: + return val.String(), nil + default: + // TODO: there are particular dolt tables (dolt_constraint_violations) that return json-marshallable structs + // that we need to handle here like this + bytes, err := json.Marshal(val) + return string(bytes), err } - return nil, errors.Errorf("GMSCast expected type `JSONDocument`, got `%T`", val) case query.Type_NULL_TYPE: return nil, nil case query.Type_GEOMETRY: diff --git a/testing/go/dolt_tables_test.go b/testing/go/dolt_tables_test.go index 2e31599afb..5c1930446d 100755 --- a/testing/go/dolt_tables_test.go +++ b/testing/go/dolt_tables_test.go @@ -106,7 +106,6 @@ func TestUserSpaceDoltTables(t *testing.T) { Expected: []sql.Row{{1}}, }, { - Skip: true, // this fails because the column type of this table is not a doltgres type, which IN requires Query: `SELECT "dolt_branches"."name" FROM "dolt_branches" WHERE "dolt_branches"."name" IN ('main') ORDER BY "dolt_branches"."name" DESC LIMIT 21;`, Expected: []sql.Row{{"main"}}, }, @@ -555,7 +554,7 @@ func TestUserSpaceDoltTables(t *testing.T) { }, { Query: `SELECT * FROM dolt_conflicts`, - Expected: []sql.Row{{"test", Numeric("1")}}, + Expected: []sql.Row{{"test", 1}}, }, { Query: `SELECT dolt.conflicts.table FROM dolt.conflicts`, @@ -763,7 +762,7 @@ func TestUserSpaceDoltTables(t *testing.T) { }, { Query: `SELECT * FROM dolt_constraint_violations`, - Expected: []sql.Row{{"test", Numeric("2")}}, + Expected: []sql.Row{{"test", 2}}, }, { Query: `SELECT dolt.constraint_violations.table FROM dolt.constraint_violations`, @@ -1744,8 +1743,8 @@ func TestUserSpaceDoltTables(t *testing.T) { }, }, }, - //TODO: turn on statistics - //{ + // TODO: turn on statistics + // { // Name: "dolt statistics", // SetUpScript: []string{ // "CREATE TABLE horses (id int primary key, name varchar(10));", @@ -1871,7 +1870,7 @@ func TestUserSpaceDoltTables(t *testing.T) { // Expected: []sql.Row{{"horses", "horses_name_idx"}, {"horses", "primary"}}, // }, // }, - //}, + // }, { Name: "dolt status", SetUpScript: []string{