Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions enginetest/join_op_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -2298,6 +2298,24 @@ WHERE
},
},
},
{
// https://github.com/dolthub/dolt/issues/10311
name: "join on different enum types",
setup: [][]string{
{
"create table animals(e enum('rat','ox','tiger','dog') primary key);",
"create table pets(e enum('cat','dog','fish','rat'), foreign key (e) references animals(e));",
"insert into animals values('rat'), ('dog');",
"insert into pets values ('cat'), ('rat');",
},
},
tests: []JoinOpTests{
{
Query: "select * from animals join pets on animals.e=pets.e;",
Expected: []sql.Row{{"rat", "rat"}},
},
},
},
}

var rangeJoinOpTests = []JoinOpTests{
Expand Down
40 changes: 40 additions & 0 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -11416,6 +11416,46 @@ where
},
},
},
{
// https://github.com/dolthub/dolt/issues/10311
Name: "enums with foreign keys and joins",
Dialect: "mysql",
SetUpScript: []string{
"create table animals(e enum('rat','ox','tiger','dog') primary key);",
"create table pets(e enum('cat','dog','fish','rat'), foreign key (e) references animals(e));",
"insert into animals values('rat');",
},
Assertions: []ScriptTestAssertion{
{
Query: "insert into pets values ('rat');",
// Error expected here because 'rat' has different underlying int values depending on the enum type
ExpectedErr: sql.ErrForeignKeyChildViolation,
},
{
Query: "insert into pets values ('cat');",
// Query OK expected here because the underlying int values are the same
Expected: []sql.Row{{types.NewOkResult(1)}},
},
{
Query: "select * from animals join pets on animals.e=pets.e;",
// Empty set expected here because comparison uses the string values when enum types are different
Expected: []sql.Row{},
},
{
Query: "insert into animals values ('dog');",
Expected: []sql.Row{{types.NewOkResult(1)}},
},
{
Query: "insert into pets values ('rat');",
// 'rat' is now okay because it has the same underlying int value as 'dog' in the animals table
Expected: []sql.Row{{types.NewOkResult(1)}},
},
{
Query: "select * from animals join pets on animals.e=pets.e;",
Expected: []sql.Row{{"rat", "rat"}},
},
},
},
{
Skip: true,
Name: "enums with foreign keys and cascade",
Expand Down
15 changes: 14 additions & 1 deletion sql/analyzer/indexed_joins.go
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,7 @@ func addMergeJoins(ctx *sql.Context, m *memo.Memo) error {
}

// check that comparer is not non-decreasing
if !isWeaklyMonotonic(l) || !isWeaklyMonotonic(r) {
if !canMergeTypes(l.Type(), r.Type()) || !isWeaklyMonotonic(l) || !isWeaklyMonotonic(r) {
continue
}

Expand Down Expand Up @@ -1491,6 +1491,19 @@ func makeIndexScan(ctx *sql.Context, statsProv sql.StatsProvider, tab plan.Table
}, true, nil
}

// canMerge checks the types of two columns to see if they can be merged into one another if sorted.
func canMergeTypes(t1, t2 sql.Type) bool {
// TODO: handle other types here. For example, Number and Text types likely can't be merged together. But we need to
// add more testing https://github.com/dolthub/dolt/issues/10316
switch {
case types.IsEnum(t1):
if types.IsEnum(t2) {
return types.TypesEqual(t1, t2)
}
}
return true
}

// isWeaklyMonotonic is a weak test of whether an expression
// will be strictly increasing as the value of column attribute
// inputs increases.
Expand Down
22 changes: 11 additions & 11 deletions sql/expression/comparison.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{})
}

if types.IsTime(leftType) || types.IsTime(rightType) {
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDatetime)
l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToDatetime)
if err != nil {
return nil, nil, nil, err
}
Expand All @@ -291,7 +291,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{})
}

if types.IsBinaryType(leftType) || types.IsBinaryType(rightType) {
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToBinary)
l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToBinary)
if err != nil {
return nil, nil, nil, err
}
Expand All @@ -301,7 +301,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{})
if types.IsNumber(leftType) || types.IsNumber(rightType) {
if types.IsDecimal(leftType) || types.IsDecimal(rightType) {
//TODO: We need to set to the actual DECIMAL type
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDecimal)
l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToDecimal)
if err != nil {
return nil, nil, nil, err
}
Expand All @@ -314,7 +314,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{})
}

if types.IsFloat(leftType) || types.IsFloat(rightType) {
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble)
l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToDouble)
if err != nil {
return nil, nil, nil, err
}
Expand All @@ -323,7 +323,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{})
}

if types.IsSigned(leftType) && types.IsSigned(rightType) {
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToSigned)
l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToSigned)
if err != nil {
return nil, nil, nil, err
}
Expand All @@ -332,41 +332,41 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{})
}

if types.IsUnsigned(leftType) && types.IsUnsigned(rightType) {
l, r, err := convertLeftAndRight(ctx, left, right, ConvertToUnsigned)
l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToUnsigned)
if err != nil {
return nil, nil, nil, err
}

return l, r, types.Uint64, nil
}

l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble)
l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToDouble)
if err != nil {
return nil, nil, nil, err
}

return l, r, types.Float64, nil
}

left, right, err := convertLeftAndRight(ctx, left, right, ConvertToChar)
left, right, err := c.convertLeftAndRight(ctx, left, right, ConvertToChar)
if err != nil {
return nil, nil, nil, err
}

return left, right, types.LongText, nil
}

func convertLeftAndRight(ctx *sql.Context, left, right interface{}, convertTo string) (interface{}, interface{}, error) {
func (c *comparison) convertLeftAndRight(ctx *sql.Context, left, right interface{}, convertTo string) (interface{}, interface{}, error) {
typeLength := 0
if convertTo == ConvertToDatetime {
typeLength = types.MaxDatetimePrecision
}
l, err := convertValue(ctx, left, convertTo, nil, typeLength, 0)
l, err := convertValue(ctx, left, convertTo, c.Left().Type(), typeLength, 0)
if err != nil {
return nil, nil, err
}

r, err := convertValue(ctx, right, convertTo, nil, typeLength, 0)
r, err := convertValue(ctx, right, convertTo, c.Right().Type(), typeLength, 0)
if err != nil {
return nil, nil, err
}
Expand Down
41 changes: 23 additions & 18 deletions sql/expression/filter-range.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,62 +40,67 @@ func NewRangeFilterExpr(exprs []sql.Expression, ranges []sql.MySQLRange) (sql.Ex
var rangeExpr sql.Expression
for i, rce := range rang {
var rangeColumnExpr sql.Expression
typ := exprs[i].Type().Promote()
switch rce.Type() {
// Both Empty and All may seem like strange inclusions, but if only one range is given we need some
// expression to evaluate, otherwise our expression would be a nil expression which would panic.
case sql.RangeType_Empty:
rangeColumnExpr = NewEquals(NewLiteral(1, types.Int8), NewLiteral(2, types.Int8))
rangeColumnExpr = NewLiteral(false, types.Boolean)
case sql.RangeType_All:
rangeColumnExpr = NewEquals(NewLiteral(1, types.Int8), NewLiteral(1, types.Int8))
rangeColumnExpr = NewLiteral(true, types.Boolean)
case sql.RangeType_EqualNull:
rangeColumnExpr = DefaultExpressionFactory.NewIsNull(exprs[i])
case sql.RangeType_GreaterThan:
if sql.MySQLRangeCutIsBinding(rce.LowerBound) {
rangeColumnExpr = NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote()))
rangeColumnExpr = NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ))
} else {
rangeColumnExpr = DefaultExpressionFactory.NewIsNotNull(exprs[i])
}
case sql.RangeType_GreaterOrEqual:
rangeColumnExpr = NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote()))
rangeColumnExpr = NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ))
case sql.RangeType_LessThanOrNull:
rangeColumnExpr = JoinOr(
NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())),
NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)),
DefaultExpressionFactory.NewIsNull(exprs[i]),
)
case sql.RangeType_LessOrEqualOrNull:
rangeColumnExpr = JoinOr(
NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())),
NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)),
DefaultExpressionFactory.NewIsNull(exprs[i]),
)
case sql.RangeType_ClosedClosed:
rangeColumnExpr = JoinAnd(
NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())),
NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())),
)
if rce.LowerBound == rce.UpperBound {
rangeColumnExpr = NewEquals(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ))
} else {
rangeColumnExpr = JoinAnd(
NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)),
NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)),
)
}
case sql.RangeType_OpenOpen:
if sql.MySQLRangeCutIsBinding(rce.LowerBound) {
rangeColumnExpr = JoinAnd(
NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())),
NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())),
NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)),
NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)),
)
} else {
// Lower bound is (NULL, ...)
rangeColumnExpr = NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote()))
rangeColumnExpr = NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ))
}
case sql.RangeType_OpenClosed:
if sql.MySQLRangeCutIsBinding(rce.LowerBound) {
rangeColumnExpr = JoinAnd(
NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())),
NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())),
NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)),
NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)),
)
} else {
// Lower bound is (NULL, ...]
rangeColumnExpr = NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote()))
rangeColumnExpr = NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ))
}
case sql.RangeType_ClosedOpen:
rangeColumnExpr = JoinAnd(
NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())),
NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())),
NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)),
NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)),
)
}
rangeExpr = JoinAnd(rangeExpr, rangeColumnExpr)
Expand Down
9 changes: 7 additions & 2 deletions sql/index_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,10 +331,15 @@ func (b *MySQLIndexBuilder) GreaterThan(ctx *Context, colExpr string, keyType Ty

// IsConvertibleKeyType checks if the key can be converted into the column type
func IsConvertibleKeyType(colType Type, keyType Type) bool {
if IsStringType(colType) {
// TODO: check other types https://github.com/dolthub/dolt/issues/10316
switch {
case IsEnumType(colType):
if IsEnumType(keyType) {
return colType.Equals(keyType)
}
case IsStringType(colType):
return !(IsNumberType(keyType) || IsDecimalType(keyType))
}
// TODO: check other types
return true
}

Expand Down
5 changes: 5 additions & 0 deletions sql/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,11 @@ type EnumType interface {
Values() []string
}

func IsEnumType(t Type) bool {
_, ok := t.(EnumType)
return ok
}

// DecimalType represents the DECIMAL type.
// https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html
// The type of the returned value is decimal.Decimal.
Expand Down