diff --git a/enginetest/queries/integration_plans.go b/enginetest/queries/integration_plans.go index 97cd20f2e9..2450fd7d70 100644 --- a/enginetest/queries/integration_plans.go +++ b/enginetest/queries/integration_plans.go @@ -7148,7 +7148,7 @@ WHERE " │ │ └─ 0.5 (decimal(2,1))\n" + " │ └─ Eq\n" + " │ ├─ nrfj3.YHYLK:6\n" + - " │ └─ 0 (bigint)\n" + + " │ └─ 0 (tinyint)\n" + " │ THEN 1 (tinyint) ELSE 0 (tinyint) END), nrfj3.T4IBQ:0!null, nrfj3.ECUWU:1!null, nrfj3.GSTQA:2!null, nrfj3.B5OUF:3\n" + " ├─ group: nrfj3.T4IBQ:0!null, nrfj3.ECUWU:1!null, nrfj3.GSTQA:2!null\n" + " └─ SubqueryAlias\n" + @@ -8023,7 +8023,7 @@ WHERE " │ │ └─ 0.5 (decimal(2,1))\n" + " │ └─ Eq\n" + " │ ├─ nrfj3.YHYLK:6\n" + - " │ └─ 0 (bigint)\n" + + " │ └─ 0 (tinyint)\n" + " │ THEN 1 (tinyint) ELSE 0 (tinyint) END), nrfj3.T4IBQ:0!null, nrfj3.ECUWU:1!null, nrfj3.GSTQA:2!null, nrfj3.B5OUF:3\n" + " ├─ group: nrfj3.T4IBQ:0!null, nrfj3.ECUWU:1!null, nrfj3.GSTQA:2!null\n" + " └─ SubqueryAlias\n" + diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index f2ef16915c..3505ec58ad 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -6092,7 +6092,7 @@ SELECT * FROM cte WHERE d = 2;`, Query: `SELECT if(123 = 123, NULL, NULL = 1)`, Expected: []sql.Row{{nil}}, ExpectedColumns: []*sql.Column{ - {Name: "if(123 = 123, NULL, NULL = 1)", Type: types.Int64}, // TODO: this should be getting coerced to bool + {Name: "if(123 = 123, NULL, NULL = 1)", Type: types.Boolean}, }, }, { diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 3f06760c67..9ca10596ae 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -8712,6 +8712,26 @@ where }, }, }, + { + Name: "tinyint column does not restrict IF or IFNULL output", + // https://github.com/dolthub/dolt/issues/9321 + SetUpScript: []string{ + "create table t0 (c0 tinyint);", + "insert into t0 values (null);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select ifnull(t0.c0, 128) as ref0 from t0", + Expected: []sql.Row{ + {128}, + }, + }, + { + Query: "select if(t0.c0 = 1, t0.c0, 128) as ref0 from t0", + Expected: []sql.Row{{128}}, + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ diff --git a/server/handler_test.go b/server/handler_test.go index 969c1b408c..03bf918754 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -212,7 +212,7 @@ func TestHandlerOutput(t *testing.T) { }) require.NoError(t, err) require.Equal(t, 1, len(result.Rows)) - require.Equal(t, sqltypes.Int64, result.Rows[0][0].Type()) + require.Equal(t, sqltypes.Int16, result.Rows[0][0].Type()) require.Equal(t, []byte("456"), result.Rows[0][0].ToBytes()) }) } @@ -471,7 +471,8 @@ func TestHandlerComPrepareExecute(t *testing.T) { }, }, schema: []*query.Field{ - {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, + Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, }, expected: []sql.Row{ {0}, {1}, {2}, {3}, {4}, @@ -550,7 +551,8 @@ func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) { }, }, schema: []*query.Field{ - {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, + Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, }, expected: []sql.Row{ {0}, {1}, {2}, {3}, {4}, @@ -567,7 +569,8 @@ func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) { BindVars: nil, }, schema: []*query.Field{ - {Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16, + Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, }, expected: []sql.Row{ {1000}, @@ -584,7 +587,8 @@ func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) { BindVars: nil, }, schema: []*query.Field{ - {Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, + {Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16, + Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)}, }, expected: []sql.Row{ {-129}, diff --git a/sql/expression/case.go b/sql/expression/case.go index 30b6cf3e06..7c7df34ce0 100644 --- a/sql/expression/case.go +++ b/sql/expression/case.go @@ -43,71 +43,14 @@ func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression return &Case{expr, branches, elseExpr} } -// From the description of operator typing here: -// https://dev.mysql.com/doc/refman/8.0/en/flow-control-functions.html#operator_case -func combinedCaseBranchType(left, right sql.Type) sql.Type { - if left == types.Null { - return right - } - if right == types.Null { - return left - } - - // Our current implementation of StringType.Convert(enum), does not match MySQL's behavior. - // So, we make sure to return Enums in this particular case. - // More details: https://github.com/dolthub/dolt/issues/8598 - if types.IsEnum(left) && types.IsEnum(right) { - return right - } - if types.IsSet(left) && types.IsSet(right) { - return right - } - if types.IsTextOnly(left) && types.IsTextOnly(right) { - return types.LongText - } - if types.IsTextBlob(left) && types.IsTextBlob(right) { - return types.LongBlob - } - if types.IsTime(left) && types.IsTime(right) { - if left == right { - return left - } - return types.DatetimeMaxPrecision - } - if types.IsNumber(left) && types.IsNumber(right) { - if left == types.Float64 || right == types.Float64 { - return types.Float64 - } - if left == types.Float32 || right == types.Float32 { - return types.Float32 - } - if types.IsDecimal(left) || types.IsDecimal(right) { - return types.MustCreateDecimalType(65, 10) - } - if left == types.Uint64 && types.IsSigned(right) || - right == types.Uint64 && types.IsSigned(left) { - return types.MustCreateDecimalType(65, 10) - } - if !types.IsSigned(left) && !types.IsSigned(right) { - return types.Uint64 - } else { - return types.Int64 - } - } - if types.IsJSON(left) && types.IsJSON(right) { - return types.JSON - } - return types.LongText -} - // Type implements the sql.Expression interface. func (c *Case) Type() sql.Type { curr := types.Null for _, b := range c.Branches { - curr = combinedCaseBranchType(curr, b.Value.Type()) + curr = types.GeneralizeTypes(curr, b.Value.Type()) } if c.Else != nil { - curr = combinedCaseBranchType(curr, c.Else.Type()) + curr = types.GeneralizeTypes(curr, c.Else.Type()) } return curr } diff --git a/sql/expression/case_test.go b/sql/expression/case_test.go index 68033649e9..27b5afdacd 100644 --- a/sql/expression/case_test.go +++ b/sql/expression/case_test.go @@ -161,8 +161,8 @@ func TestCaseType(t *testing.T) { } } - decimalType := types.MustCreateDecimalType(65, 10) - + decimalType := types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale) + uint64DecimalType := types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, 0) testCases := []struct { name string c *Case @@ -175,13 +175,13 @@ func TestCaseType(t *testing.T) { }, { "unsigned promoted and unsigned", - caseExpr(NewLiteral(uint32(0), types.Uint32), NewLiteral(uint32(1), types.Uint32)), + caseExpr(NewLiteral(uint32(0), types.Uint32), NewLiteral(uint32(1), types.Uint64)), types.Uint64, }, { "signed promoted and signed", caseExpr(NewLiteral(int8(0), types.Int8), NewLiteral(int32(1), types.Int32)), - types.Int64, + types.Int32, }, { "int and float to float", @@ -216,7 +216,7 @@ func TestCaseType(t *testing.T) { { "uint64 and int8 to decimal", caseExpr(NewLiteral(uint64(10), types.Uint64), NewLiteral(int8(0), types.Int8)), - decimalType, + uint64DecimalType, }, { "int and text to text", diff --git a/sql/expression/function/if.go b/sql/expression/function/if.go index 55e24e5fdf..c019357f39 100644 --- a/sql/expression/function/if.go +++ b/sql/expression/function/if.go @@ -89,23 +89,15 @@ func (f *If) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } } - eval, _, err = f.Type().Convert(ctx, eval) + if ret, _, err := f.Type().Convert(ctx, eval); err == nil { + return ret, nil + } return eval, err } // Type implements the Expression interface. func (f *If) Type() sql.Type { - // if either type is string type, this should be a string type, regardless need to promote - typ1 := f.ifTrue.Type() - typ2 := f.ifFalse.Type() - if types.IsText(typ1) || types.IsText(typ2) { - return types.Text - } - - if typ1 == types.Null { - return typ2.Promote() - } - return typ1.Promote() + return types.GeneralizeTypes(f.ifTrue.Type(), f.ifFalse.Type()) } // CollationCoercibility implements the interface sql.CollationCoercible. diff --git a/sql/expression/function/ifnull.go b/sql/expression/function/ifnull.go index 9f5e4f8709..9e80a16337 100644 --- a/sql/expression/function/ifnull.go +++ b/sql/expression/function/ifnull.go @@ -52,30 +52,32 @@ func (f *IfNull) Description() string { // Eval implements the Expression interface. func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + t := f.Type() + left, err := f.LeftChild.Eval(ctx, row) if err != nil { return nil, err } if left != nil { - return left, nil + if ret, _, err := t.Convert(ctx, left); err == nil { + return ret, nil + } + return left, err } right, err := f.RightChild.Eval(ctx, row) if err != nil { return nil, err } - return right, nil + if ret, _, err := t.Convert(ctx, right); err == nil { + return ret, nil + } + return right, err } // Type implements the Expression interface. func (f *IfNull) Type() sql.Type { - if types.IsNull(f.LeftChild) { - if types.IsNull(f.RightChild) { - return types.Null - } - return f.RightChild.Type() - } - return f.LeftChild.Type() + return types.GeneralizeTypes(f.LeftChild.Type(), f.RightChild.Type()) } // CollationCoercibility implements the interface sql.CollationCoercible. diff --git a/sql/expression/function/ifnull_test.go b/sql/expression/function/ifnull_test.go index ed6acc3336..507b8e7bdd 100644 --- a/sql/expression/function/ifnull_test.go +++ b/sql/expression/function/ifnull_test.go @@ -26,25 +26,28 @@ import ( func TestIfNull(t *testing.T) { testCases := []struct { - expression interface{} - value interface{} - expected interface{} + expression interface{} + expressionType sql.Type + value interface{} + valueType sql.Type + expected interface{} + expectedType sql.Type }{ - {"foo", "bar", "foo"}, - {"foo", "foo", "foo"}, - {nil, "foo", "foo"}, - {"foo", nil, "foo"}, - {nil, nil, nil}, - {"", nil, ""}, + {"foo", types.LongText, "bar", types.LongText, "foo", types.LongText}, + {"foo", types.LongText, "foo", types.LongText, "foo", types.LongText}, + {nil, types.LongText, "foo", types.LongText, "foo", types.LongText}, + {"foo", types.LongText, nil, types.LongText, "foo", types.LongText}, + {nil, types.LongText, nil, types.LongText, nil, types.LongText}, + {"", types.LongText, nil, types.LongText, "", types.LongText}, + {nil, types.Int8, 128, types.Int64, int64(128), types.Int64}, } - f := NewIfNull( - expression.NewGetField(0, types.LongText, "expression", true), - expression.NewGetField(1, types.LongText, "value", true), - ) - require.Equal(t, types.LongText, f.Type()) - for _, tc := range testCases { + f := NewIfNull( + expression.NewGetField(0, tc.expressionType, "expression", true), + expression.NewGetField(1, tc.valueType, "value", true), + ) + require.Equal(t, tc.expectedType, f.Type()) v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.expression, tc.value)) require.NoError(t, err) require.Equal(t, tc.expected, v) diff --git a/sql/types/conversion.go b/sql/types/conversion.go index fc027f1f65..3503111d31 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -554,3 +554,168 @@ func TypesEqual(a, b sql.Type) bool { return a.Equals(b) } } + +// generalizeNumberTypes assumes both inputs return true for IsNumber +func generalizeNumberTypes(a, b sql.Type) sql.Type { + if IsFloat(a) || IsFloat(b) { + // TODO: handle cases where MySQL returns Float32 + return Float64 + } + + if IsDecimal(a) || IsDecimal(b) { + // TODO: match precision and scale to that of the decimal type, check if defines column + return MustCreateDecimalType(DecimalTypeMaxPrecision, DecimalTypeMaxScale) + } + + aIsSigned := IsSigned(a) + bIsSigned := IsSigned(b) + + if a == Uint64 || b == Uint64 { + if aIsSigned || bIsSigned { + return MustCreateDecimalType(DecimalTypeMaxPrecision, 0) + } + return Uint64 + } + + if a == Int64 || b == Int64 { + return Int64 + } + + if a == Uint32 || b == Uint32 { + if aIsSigned || bIsSigned { + return Int64 + } + return Uint32 + } + + if a == Int32 || b == Int32 { + return Int32 + } + + if a == Uint24 || b == Uint24 { + if aIsSigned || bIsSigned { + return Int32 + } + return Uint24 + } + + if a == Int24 || b == Int24 { + return Int24 + } + + if a == Uint16 || b == Uint16 { + if aIsSigned || bIsSigned { + return Int24 + } + return Uint16 + } + + if a == Int16 || b == Int16 { + return Int16 + } + + if a == Uint8 || b == Uint8 { + if aIsSigned || bIsSigned { + return Int16 + } + return Uint8 + } + + if a == Int8 || b == Int8 { + return Int8 + } + + if IsBoolean(a) && IsBoolean(b) { + return Boolean + } + + return Int64 +} + +// GeneralizeTypes returns the more "general" of two types as defined by +// https://dev.mysql.com/doc/refman/8.4/en/flow-control-functions.html +// TODO: Create and handle "Illegal mix of collations" error +func GeneralizeTypes(a, b sql.Type) sql.Type { + if a == Null { + return b + } + if b == Null { + return a + } + + if svt, ok := a.(sql.SystemVariableType); ok { + a = svt.UnderlyingType() + } + if svt, ok := a.(sql.SystemVariableType); ok { + b = svt.UnderlyingType() + } + + if IsJSON(a) && IsJSON(b) { + return JSON + } + + if IsGeometry(a) && IsGeometry(b) { + return a + } + + if IsEnum(a) && IsEnum(b) { + return a + } + + if IsSet(a) && IsSet(b) { + return a + } + + aIsTimespan := IsTimespan(a) + bIsTimespan := IsTimespan(b) + if aIsTimespan && bIsTimespan { + return Time + } + if (IsTime(a) || aIsTimespan) && (IsTime(b) || bIsTimespan) { + if IsDateType(a) && IsDateType(b) { + return Date + } + if IsTimestampType(a) && IsTimestampType(b) { + // TODO: match precision to max precision of the two timestamps + return TimestampMaxPrecision + } + // TODO: match precision to max precision of the two time types + return DatetimeMaxPrecision + } + + if IsBlobType(a) || IsBlobType(b) { + // TODO: match blob length to max of the blob lengths + return LongBlob + } + + aIsBit := IsBit(a) + bIsBit := IsBit(b) + if aIsBit && bIsBit { + // TODO: match max bits to max of max bits between a and b + return a.Promote() + } + if aIsBit { + a = Int64 + } + if bIsBit { + b = Int64 + } + + aIsYear := IsYear(a) + bIsYear := IsYear(b) + if aIsYear && bIsYear { + return a + } + if aIsYear { + a = Int32 + } + if bIsYear { + b = Int32 + } + + if IsNumber(a) && IsNumber(b) { + return generalizeNumberTypes(a, b) + } + // TODO: decide if we want to make this VarChar to match MySQL, match VarChar length to max of two types + return LongText +} diff --git a/sql/types/conversion_test.go b/sql/types/conversion_test.go index 35cb5f03d2..e0928e6814 100644 --- a/sql/types/conversion_test.go +++ b/sql/types/conversion_test.go @@ -119,7 +119,7 @@ func TestColumnTypeToType_Time(t *testing.T) { } func TestColumnCharTypes(t *testing.T) { - test := []struct { + tests := []struct { typ string len int64 exp sql.Type @@ -146,7 +146,7 @@ func TestColumnCharTypes(t *testing.T) { }, } - for _, test := range test { + for _, test := range tests { t.Run(fmt.Sprintf("%v %v", test.typ, test.exp), func(t *testing.T) { ct := &sqlparser.ColumnType{ Type: test.typ, @@ -158,3 +158,63 @@ func TestColumnCharTypes(t *testing.T) { }) } } + +func TestGeneralizeTypes(t *testing.T) { + decimalType := MustCreateDecimalType(DecimalTypeMaxPrecision, DecimalTypeMaxScale) + uint64DecimalType := MustCreateDecimalType(DecimalTypeMaxPrecision, 0) + + tests := []struct { + typeA sql.Type + typeB sql.Type + expected sql.Type + }{ + {Float64, Float32, Float64}, + {Float64, Int32, Float64}, + {Int24, Float32, Float64}, + {decimalType, Float64, Float64}, + {decimalType, Int32, decimalType}, + {Int64, decimalType, decimalType}, + {Uint64, Int32, uint64DecimalType}, + {Int24, Uint64, uint64DecimalType}, + {Uint64, Uint8, Uint64}, + {Uint24, Uint64, Uint64}, + {Int64, Uint32, Int64}, + {Int24, Int64, Int64}, + {Int8, Int64, Int64}, + {Uint32, Int24, Int64}, + {Uint24, Uint32, Uint32}, + {Int32, Int8, Int32}, + {Uint24, Int32, Int32}, + {Uint24, Int24, Int32}, + {Uint8, Uint24, Uint24}, + {Int24, Uint8, Int24}, + {Int8, Int24, Int24}, + {Int8, Uint16, Int24}, + {Uint16, Uint8, Uint16}, + {Int16, Int16, Int16}, + {Int8, Int16, Int16}, + {Uint8, Int8, Int16}, + {Uint8, Uint8, Uint8}, + {Int8, Int8, Int8}, + {Boolean, Int64, Int64}, + {Boolean, Boolean, Boolean}, + {Text, Text, LongText}, + {Text, Float64, LongText}, + {Int64, Text, LongText}, + {Int8, Null, Int8}, + {Time, Time, Time}, + {Time, Date, DatetimeMaxPrecision}, + {Date, Date, Date}, + {Date, Timestamp, DatetimeMaxPrecision}, + {Timestamp, Timestamp, TimestampMaxPrecision}, + {Timestamp, Datetime, DatetimeMaxPrecision}, + {Null, Int64, Int64}, + {Null, Null, Null}, + } + for _, test := range tests { + t.Run(fmt.Sprintf("%v %v %v", test.typeA, test.typeB, test.expected), func(t *testing.T) { + res := GeneralizeTypes(test.typeA, test.typeB) + assert.Equal(t, test.expected, res) + }) + } +}