diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 01a8523959..80a857409f 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -3646,7 +3646,7 @@ CREATE TABLE tab3 ( { Query: "select COLUMN_NAME, DATA_TYPE from INFORMATION_SCHEMA.COLUMNS where TABLE_NAME='c';", Expected: []sql.Row{ - {"coalesce(NULL,1)", "tinyint"}, + {"coalesce(NULL,1)", "int"}, }, }, }, diff --git a/sql/expression/function/coalesce.go b/sql/expression/function/coalesce.go index 88adbc1fac..cb8b7acba9 100644 --- a/sql/expression/function/coalesce.go +++ b/sql/expression/function/coalesce.go @@ -18,9 +18,9 @@ import ( "fmt" "strings" - "github.com/dolthub/go-mysql-server/sql/types" - "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" ) // Coalesce returns the first non-NULL value in the list, or NULL if there are no non-NULL values. @@ -53,17 +53,56 @@ func (c *Coalesce) Description() string { // Type implements the sql.Expression interface. // The return type of Type() is the aggregated type of the argument types. func (c *Coalesce) Type() sql.Type { + typ := types.Null for _, arg := range c.args { if arg == nil { continue } t := arg.Type() + // special case for signed and unsigned integers + if (types.IsSigned(typ) && types.IsUnsigned(t)) || (types.IsUnsigned(typ) && types.IsSigned(t)) { + typ = types.MustCreateDecimalType(20, 0) + continue + } + if t != nil && t != types.Null { - return t + convType := expression.GetConvertToType(typ, t) + switch convType { + case expression.ConvertToChar: + // Can't get any larger than this + return types.LongText + case expression.ConvertToDecimal: + if typ == types.Float64 || t == types.Float64 { + typ = types.Float64 + } else if types.IsDecimal(t) { + typ = t + } else if !types.IsDecimal(typ) { + typ = types.MustCreateDecimalType(10, 0) + } + case expression.ConvertToUnsigned: + if typ == types.Uint64 || t == types.Uint64 { + typ = types.Uint64 + } else { + typ = types.Uint32 + } + case expression.ConvertToSigned: + if typ == types.Int64 || t == types.Int64 { + typ = types.Int64 + } else { + typ = types.Int32 + } + case expression.ConvertToFloat: + if typ == types.Float64 || t == types.Float64 { + typ = types.Float64 + } else { + typ = types.Float32 + } + default: + } } } - return types.Null + return typ } // CollationCoercibility implements the interface sql.CollationCoercible. diff --git a/sql/expression/function/coalesce_test.go b/sql/expression/function/coalesce_test.go index d9efd8708d..56b7f66f2c 100644 --- a/sql/expression/function/coalesce_test.go +++ b/sql/expression/function/coalesce_test.go @@ -17,6 +17,7 @@ package function import ( "testing" + "github.com/shopspring/decimal" "github.com/stretchr/testify/require" "github.com/dolthub/go-mysql-server/sql" @@ -37,11 +38,102 @@ func TestCoalesce(t *testing.T) { typ sql.Type nullable bool }{ - {"coalesce(1, 2, 3)", []sql.Expression{expression.NewLiteral(1, types.Int32), expression.NewLiteral(2, types.Int32), expression.NewLiteral(3, types.Int32)}, 1, types.Int32, false}, - {"coalesce(NULL, NULL, 3)", []sql.Expression{nil, nil, expression.NewLiteral(3, types.Int32)}, 3, types.Int32, false}, - {"coalesce(NULL, NULL, '3')", []sql.Expression{nil, nil, expression.NewLiteral("3", types.LongText)}, "3", types.LongText, false}, - {"coalesce(NULL, '2', 3)", []sql.Expression{nil, expression.NewLiteral("2", types.LongText), expression.NewLiteral(3, types.Int32)}, "2", types.LongText, false}, - {"coalesce(NULL, NULL, NULL)", []sql.Expression{nil, nil, nil}, nil, types.Null, true}, + { + name: "coalesce(1, 2, 3)", + input: []sql.Expression{ + expression.NewLiteral(1, types.Int32), + expression.NewLiteral(2, types.Int32), + expression.NewLiteral(3, types.Int32), + }, + expected: 1, + typ: types.Int32, + nullable: false, + }, + { + name: "coalesce(NULL, NULL, 3)", + input: []sql.Expression{ + nil, + nil, + expression.NewLiteral(3, types.Int32), + }, + expected: 3, + typ: types.Int32, + nullable: false, + }, + { + name: "coalesce(NULL, NULL, '3')", + input: []sql.Expression{ + nil, + nil, + expression.NewLiteral("3", types.LongText), + }, + expected: "3", + typ: types.LongText, + nullable: false, + }, + { + name: "coalesce(NULL, '2', 3)", + input: []sql.Expression{ + nil, + expression.NewLiteral("2", types.LongText), + expression.NewLiteral(3, types.Int32), + }, + expected: "2", + typ: types.LongText, + nullable: false, + }, + { + name: "coalesce(NULL, NULL, NULL)", + input: []sql.Expression{ + nil, + nil, + nil, + }, + expected: nil, + typ: types.Null, + nullable: true, + }, + { + name: "coalesce(int(1), decimal(2.0), string('3'))", + input: []sql.Expression{ + expression.NewLiteral(1, types.Int32), + expression.NewLiteral(decimal.NewFromFloat(2.0), types.MustCreateDecimalType(10, 0)), + expression.NewLiteral("3", types.LongText), + }, + expected: 1, + typ: types.LongText, + nullable: false, + }, + { + name: "coalesce(signed(1), unsigned(2))", + input: []sql.Expression{ + expression.NewLiteral(1, types.Int32), + expression.NewLiteral(2, types.Uint32), + }, + expected: 1, + typ: types.MustCreateDecimalType(20, 0), + nullable: false, + }, + { + name: "coalesce(signed(1), unsigned(2))", + input: []sql.Expression{ + expression.NewLiteral(1, types.Int32), + expression.NewLiteral(2, types.Uint32), + }, + expected: 1, + typ: types.MustCreateDecimalType(20, 0), + nullable: false, + }, + { + name: "coalesce(decimal(1.0), float64(2.0))", + input: []sql.Expression{ + expression.NewLiteral(1, types.MustCreateDecimalType(10, 0)), + expression.NewLiteral(2, types.Float64), + }, + expected: 1, + typ: types.Float64, + nullable: false, + }, } for _, tt := range testCases {