Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions enginetest/queries/integration_plans.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -6092,7 +6092,7 @@ SELECT * FROM cte WHERE d = 2;`,
Query: `SELECT if(123 = 123, NULL, NULL = 1)`,
Expected: []sql.Row{{nil}},
ExpectedColumns: []*sql.Column{
{Name: "if(123 = 123, NULL, NULL = 1)", Type: types.Int64}, // TODO: this should be getting coerced to bool
{Name: "if(123 = 123, NULL, NULL = 1)", Type: types.Boolean},
},
},
{
Expand Down
20 changes: 20 additions & 0 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8712,6 +8712,26 @@ where
},
},
},
{
Name: "tinyint column does not restrict IF or IFNULL output",
// https://github.com/dolthub/dolt/issues/9321
SetUpScript: []string{
"create table t0 (c0 tinyint);",
"insert into t0 values (null);",
},
Assertions: []ScriptTestAssertion{
{
Query: "select ifnull(t0.c0, 128) as ref0 from t0",
Expected: []sql.Row{
{128},
},
},
{
Query: "select if(t0.c0 = 1, t0.c0, 128) as ref0 from t0",
Expected: []sql.Row{{128}},
},
},
},
}

var SpatialScriptTests = []ScriptTest{
Expand Down
14 changes: 9 additions & 5 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func TestHandlerOutput(t *testing.T) {
})
require.NoError(t, err)
require.Equal(t, 1, len(result.Rows))
require.Equal(t, sqltypes.Int64, result.Rows[0][0].Type())
require.Equal(t, sqltypes.Int16, result.Rows[0][0].Type())
require.Equal(t, []byte("456"), result.Rows[0][0].ToBytes())
})
}
Expand Down Expand Up @@ -471,7 +471,8 @@ func TestHandlerComPrepareExecute(t *testing.T) {
},
},
schema: []*query.Field{
{Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
{Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32,
Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
},
expected: []sql.Row{
{0}, {1}, {2}, {3}, {4},
Expand Down Expand Up @@ -550,7 +551,8 @@ func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) {
},
},
schema: []*query.Field{
{Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
{Name: "c1", OrgName: "c1", Table: "test", OrgTable: "test", Database: "test", Type: query.Type_INT32,
Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 11, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
},
expected: []sql.Row{
{0}, {1}, {2}, {3}, {4},
Expand All @@ -567,7 +569,8 @@ func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) {
BindVars: nil,
},
schema: []*query.Field{
{Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
{Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16,
Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
},
expected: []sql.Row{
{1000},
Expand All @@ -584,7 +587,8 @@ func TestHandlerComPrepareExecuteWithPreparedDisabled(t *testing.T) {
BindVars: nil,
},
schema: []*query.Field{
{Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16, Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
{Name: "a", OrgName: "a", Table: "", OrgTable: "", Database: "", Type: query.Type_INT16,
Charset: uint32(sql.CharacterSet_utf8mb4), ColumnLength: 6, Flags: uint32(query.MySqlFlag_NOT_NULL_FLAG)},
},
expected: []sql.Row{
{-129},
Expand Down
61 changes: 2 additions & 59 deletions sql/expression/case.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,71 +43,14 @@ func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression
return &Case{expr, branches, elseExpr}
}

// From the description of operator typing here:
// https://dev.mysql.com/doc/refman/8.0/en/flow-control-functions.html#operator_case
func combinedCaseBranchType(left, right sql.Type) sql.Type {
if left == types.Null {
return right
}
if right == types.Null {
return left
}

// Our current implementation of StringType.Convert(enum), does not match MySQL's behavior.
// So, we make sure to return Enums in this particular case.
// More details: https://github.com/dolthub/dolt/issues/8598
if types.IsEnum(left) && types.IsEnum(right) {
return right
}
if types.IsSet(left) && types.IsSet(right) {
return right
}
if types.IsTextOnly(left) && types.IsTextOnly(right) {
return types.LongText
}
if types.IsTextBlob(left) && types.IsTextBlob(right) {
return types.LongBlob
}
if types.IsTime(left) && types.IsTime(right) {
if left == right {
return left
}
return types.DatetimeMaxPrecision
}
if types.IsNumber(left) && types.IsNumber(right) {
if left == types.Float64 || right == types.Float64 {
return types.Float64
}
if left == types.Float32 || right == types.Float32 {
return types.Float32
}
if types.IsDecimal(left) || types.IsDecimal(right) {
return types.MustCreateDecimalType(65, 10)
}
if left == types.Uint64 && types.IsSigned(right) ||
right == types.Uint64 && types.IsSigned(left) {
return types.MustCreateDecimalType(65, 10)
}
if !types.IsSigned(left) && !types.IsSigned(right) {
return types.Uint64
} else {
return types.Int64
}
}
if types.IsJSON(left) && types.IsJSON(right) {
return types.JSON
}
return types.LongText
}

// Type implements the sql.Expression interface.
func (c *Case) Type() sql.Type {
curr := types.Null
for _, b := range c.Branches {
curr = combinedCaseBranchType(curr, b.Value.Type())
curr = types.GeneralizeTypes(curr, b.Value.Type())
}
if c.Else != nil {
curr = combinedCaseBranchType(curr, c.Else.Type())
curr = types.GeneralizeTypes(curr, c.Else.Type())
}
return curr
}
Expand Down
10 changes: 5 additions & 5 deletions sql/expression/case_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ func TestCaseType(t *testing.T) {
}
}

decimalType := types.MustCreateDecimalType(65, 10)

decimalType := types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale)
uint64DecimalType := types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, 0)
testCases := []struct {
name string
c *Case
Expand All @@ -175,13 +175,13 @@ func TestCaseType(t *testing.T) {
},
{
"unsigned promoted and unsigned",
caseExpr(NewLiteral(uint32(0), types.Uint32), NewLiteral(uint32(1), types.Uint32)),
caseExpr(NewLiteral(uint32(0), types.Uint32), NewLiteral(uint32(1), types.Uint64)),
types.Uint64,
},
{
"signed promoted and signed",
caseExpr(NewLiteral(int8(0), types.Int8), NewLiteral(int32(1), types.Int32)),
types.Int64,
types.Int32,
},
{
"int and float to float",
Expand Down Expand Up @@ -216,7 +216,7 @@ func TestCaseType(t *testing.T) {
{
"uint64 and int8 to decimal",
caseExpr(NewLiteral(uint64(10), types.Uint64), NewLiteral(int8(0), types.Int8)),
decimalType,
uint64DecimalType,
},
{
"int and text to text",
Expand Down
16 changes: 4 additions & 12 deletions sql/expression/function/if.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,23 +89,15 @@ func (f *If) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return nil, err
}
}
eval, _, err = f.Type().Convert(ctx, eval)
if ret, _, err := f.Type().Convert(ctx, eval); err == nil {
return ret, nil
}
return eval, err
}

// Type implements the Expression interface.
func (f *If) Type() sql.Type {
// if either type is string type, this should be a string type, regardless need to promote
typ1 := f.ifTrue.Type()
typ2 := f.ifFalse.Type()
if types.IsText(typ1) || types.IsText(typ2) {
return types.Text
}

if typ1 == types.Null {
return typ2.Promote()
}
return typ1.Promote()
return types.GeneralizeTypes(f.ifTrue.Type(), f.ifFalse.Type())
}

// CollationCoercibility implements the interface sql.CollationCoercible.
Expand Down
20 changes: 11 additions & 9 deletions sql/expression/function/ifnull.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,32 @@ func (f *IfNull) Description() string {

// Eval implements the Expression interface.
func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
t := f.Type()

left, err := f.LeftChild.Eval(ctx, row)
if err != nil {
return nil, err
}
if left != nil {
return left, nil
if ret, _, err := t.Convert(ctx, left); err == nil {
return ret, nil
}
return left, err
}

right, err := f.RightChild.Eval(ctx, row)
if err != nil {
return nil, err
}
return right, nil
if ret, _, err := t.Convert(ctx, right); err == nil {
return ret, nil
}
return right, err
}

// Type implements the Expression interface.
func (f *IfNull) Type() sql.Type {
if types.IsNull(f.LeftChild) {
if types.IsNull(f.RightChild) {
return types.Null
}
return f.RightChild.Type()
}
return f.LeftChild.Type()
return types.GeneralizeTypes(f.LeftChild.Type(), f.RightChild.Type())
}

// CollationCoercibility implements the interface sql.CollationCoercible.
Expand Down
33 changes: 18 additions & 15 deletions sql/expression/function/ifnull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,28 @@ import (

func TestIfNull(t *testing.T) {
testCases := []struct {
expression interface{}
value interface{}
expected interface{}
expression interface{}
expressionType sql.Type
value interface{}
valueType sql.Type
expected interface{}
expectedType sql.Type
}{
{"foo", "bar", "foo"},
{"foo", "foo", "foo"},
{nil, "foo", "foo"},
{"foo", nil, "foo"},
{nil, nil, nil},
{"", nil, ""},
{"foo", types.LongText, "bar", types.LongText, "foo", types.LongText},
{"foo", types.LongText, "foo", types.LongText, "foo", types.LongText},
{nil, types.LongText, "foo", types.LongText, "foo", types.LongText},
{"foo", types.LongText, nil, types.LongText, "foo", types.LongText},
{nil, types.LongText, nil, types.LongText, nil, types.LongText},
{"", types.LongText, nil, types.LongText, "", types.LongText},
{nil, types.Int8, 128, types.Int64, int64(128), types.Int64},
}

f := NewIfNull(
expression.NewGetField(0, types.LongText, "expression", true),
expression.NewGetField(1, types.LongText, "value", true),
)
require.Equal(t, types.LongText, f.Type())

for _, tc := range testCases {
f := NewIfNull(
expression.NewGetField(0, tc.expressionType, "expression", true),
expression.NewGetField(1, tc.valueType, "value", true),
)
require.Equal(t, tc.expectedType, f.Type())
v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.expression, tc.value))
require.NoError(t, err)
require.Equal(t, tc.expected, v)
Expand Down
Loading