diff --git a/enginetest/join_op_tests.go b/enginetest/join_op_tests.go index 8f7e72f9cb..4764eff810 100644 --- a/enginetest/join_op_tests.go +++ b/enginetest/join_op_tests.go @@ -2298,6 +2298,24 @@ WHERE }, }, }, + { + // https://github.com/dolthub/dolt/issues/10311 + name: "join on different enum types", + setup: [][]string{ + { + "create table animals(e enum('rat','ox','tiger','dog') primary key);", + "create table pets(e enum('cat','dog','fish','rat'), foreign key (e) references animals(e));", + "insert into animals values('rat'), ('dog');", + "insert into pets values ('cat'), ('rat');", + }, + }, + tests: []JoinOpTests{ + { + Query: "select * from animals join pets on animals.e=pets.e;", + Expected: []sql.Row{{"rat", "rat"}}, + }, + }, + }, } var rangeJoinOpTests = []JoinOpTests{ diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 999dbb685e..3272d3df78 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -11416,6 +11416,46 @@ where }, }, }, + { + // https://github.com/dolthub/dolt/issues/10311 + Name: "enums with foreign keys and joins", + Dialect: "mysql", + SetUpScript: []string{ + "create table animals(e enum('rat','ox','tiger','dog') primary key);", + "create table pets(e enum('cat','dog','fish','rat'), foreign key (e) references animals(e));", + "insert into animals values('rat');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into pets values ('rat');", + // Error expected here because 'rat' has different underlying int values depending on the enum type + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "insert into pets values ('cat');", + // Query OK expected here because the underlying int values are the same + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + Query: "select * from animals join pets on animals.e=pets.e;", + // Empty set expected here because comparison uses the string values when enum types are different + Expected: []sql.Row{}, + }, + { + Query: "insert into animals values ('dog');", + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + Query: "insert into pets values ('rat');", + // 'rat' is now okay because it has the same underlying int value as 'dog' in the animals table + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + Query: "select * from animals join pets on animals.e=pets.e;", + Expected: []sql.Row{{"rat", "rat"}}, + }, + }, + }, { Skip: true, Name: "enums with foreign keys and cascade", diff --git a/sql/analyzer/indexed_joins.go b/sql/analyzer/indexed_joins.go index 8bf621c031..7dfb736bec 100644 --- a/sql/analyzer/indexed_joins.go +++ b/sql/analyzer/indexed_joins.go @@ -1147,7 +1147,7 @@ func addMergeJoins(ctx *sql.Context, m *memo.Memo) error { } // check that comparer is not non-decreasing - if !isWeaklyMonotonic(l) || !isWeaklyMonotonic(r) { + if !canMergeTypes(l.Type(), r.Type()) || !isWeaklyMonotonic(l) || !isWeaklyMonotonic(r) { continue } @@ -1491,6 +1491,19 @@ func makeIndexScan(ctx *sql.Context, statsProv sql.StatsProvider, tab plan.Table }, true, nil } +// canMerge checks the types of two columns to see if they can be merged into one another if sorted. +func canMergeTypes(t1, t2 sql.Type) bool { + // TODO: handle other types here. For example, Number and Text types likely can't be merged together. But we need to + // add more testing https://github.com/dolthub/dolt/issues/10316 + switch { + case types.IsEnum(t1): + if types.IsEnum(t2) { + return types.TypesEqual(t1, t2) + } + } + return true +} + // isWeaklyMonotonic is a weak test of whether an expression // will be strictly increasing as the value of column attribute // inputs increases. diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 446851310f..6c932c9894 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -278,7 +278,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) } if types.IsTime(leftType) || types.IsTime(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDatetime) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToDatetime) if err != nil { return nil, nil, nil, err } @@ -291,7 +291,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) } if types.IsBinaryType(leftType) || types.IsBinaryType(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToBinary) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToBinary) if err != nil { return nil, nil, nil, err } @@ -301,7 +301,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) if types.IsNumber(leftType) || types.IsNumber(rightType) { if types.IsDecimal(leftType) || types.IsDecimal(rightType) { //TODO: We need to set to the actual DECIMAL type - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDecimal) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToDecimal) if err != nil { return nil, nil, nil, err } @@ -314,7 +314,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) } if types.IsFloat(leftType) || types.IsFloat(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToDouble) if err != nil { return nil, nil, nil, err } @@ -323,7 +323,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) } if types.IsSigned(leftType) && types.IsSigned(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToSigned) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToSigned) if err != nil { return nil, nil, nil, err } @@ -332,7 +332,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) } if types.IsUnsigned(leftType) && types.IsUnsigned(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToUnsigned) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToUnsigned) if err != nil { return nil, nil, nil, err } @@ -340,7 +340,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) return l, r, types.Uint64, nil } - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToDouble) if err != nil { return nil, nil, nil, err } @@ -348,7 +348,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) return l, r, types.Float64, nil } - left, right, err := convertLeftAndRight(ctx, left, right, ConvertToChar) + left, right, err := c.convertLeftAndRight(ctx, left, right, ConvertToChar) if err != nil { return nil, nil, nil, err } @@ -356,17 +356,17 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) return left, right, types.LongText, nil } -func convertLeftAndRight(ctx *sql.Context, left, right interface{}, convertTo string) (interface{}, interface{}, error) { +func (c *comparison) convertLeftAndRight(ctx *sql.Context, left, right interface{}, convertTo string) (interface{}, interface{}, error) { typeLength := 0 if convertTo == ConvertToDatetime { typeLength = types.MaxDatetimePrecision } - l, err := convertValue(ctx, left, convertTo, nil, typeLength, 0) + l, err := convertValue(ctx, left, convertTo, c.Left().Type(), typeLength, 0) if err != nil { return nil, nil, err } - r, err := convertValue(ctx, right, convertTo, nil, typeLength, 0) + r, err := convertValue(ctx, right, convertTo, c.Right().Type(), typeLength, 0) if err != nil { return nil, nil, err } diff --git a/sql/expression/filter-range.go b/sql/expression/filter-range.go index 231e8043c8..2ba43bd8ae 100644 --- a/sql/expression/filter-range.go +++ b/sql/expression/filter-range.go @@ -40,62 +40,67 @@ func NewRangeFilterExpr(exprs []sql.Expression, ranges []sql.MySQLRange) (sql.Ex var rangeExpr sql.Expression for i, rce := range rang { var rangeColumnExpr sql.Expression + typ := exprs[i].Type().Promote() switch rce.Type() { // Both Empty and All may seem like strange inclusions, but if only one range is given we need some // expression to evaluate, otherwise our expression would be a nil expression which would panic. case sql.RangeType_Empty: - rangeColumnExpr = NewEquals(NewLiteral(1, types.Int8), NewLiteral(2, types.Int8)) + rangeColumnExpr = NewLiteral(false, types.Boolean) case sql.RangeType_All: - rangeColumnExpr = NewEquals(NewLiteral(1, types.Int8), NewLiteral(1, types.Int8)) + rangeColumnExpr = NewLiteral(true, types.Boolean) case sql.RangeType_EqualNull: rangeColumnExpr = DefaultExpressionFactory.NewIsNull(exprs[i]) case sql.RangeType_GreaterThan: if sql.MySQLRangeCutIsBinding(rce.LowerBound) { - rangeColumnExpr = NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())) + rangeColumnExpr = NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)) } else { rangeColumnExpr = DefaultExpressionFactory.NewIsNotNull(exprs[i]) } case sql.RangeType_GreaterOrEqual: - rangeColumnExpr = NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())) + rangeColumnExpr = NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)) case sql.RangeType_LessThanOrNull: rangeColumnExpr = JoinOr( - NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), + NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)), DefaultExpressionFactory.NewIsNull(exprs[i]), ) case sql.RangeType_LessOrEqualOrNull: rangeColumnExpr = JoinOr( - NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), + NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)), DefaultExpressionFactory.NewIsNull(exprs[i]), ) case sql.RangeType_ClosedClosed: - rangeColumnExpr = JoinAnd( - NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())), - NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), - ) + if rce.LowerBound == rce.UpperBound { + rangeColumnExpr = NewEquals(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)) + } else { + rangeColumnExpr = JoinAnd( + NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)), + NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)), + ) + } case sql.RangeType_OpenOpen: if sql.MySQLRangeCutIsBinding(rce.LowerBound) { rangeColumnExpr = JoinAnd( - NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())), - NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), + NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)), + NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)), ) } else { // Lower bound is (NULL, ...) - rangeColumnExpr = NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())) + rangeColumnExpr = NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)) } case sql.RangeType_OpenClosed: if sql.MySQLRangeCutIsBinding(rce.LowerBound) { rangeColumnExpr = JoinAnd( - NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())), - NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), + NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)), + NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)), ) } else { // Lower bound is (NULL, ...] - rangeColumnExpr = NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())) + rangeColumnExpr = NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)) } case sql.RangeType_ClosedOpen: rangeColumnExpr = JoinAnd( - NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())), - NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), + NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)), + NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)), ) } rangeExpr = JoinAnd(rangeExpr, rangeColumnExpr) diff --git a/sql/index_builder.go b/sql/index_builder.go index 484a794627..4579db6c44 100644 --- a/sql/index_builder.go +++ b/sql/index_builder.go @@ -331,10 +331,15 @@ func (b *MySQLIndexBuilder) GreaterThan(ctx *Context, colExpr string, keyType Ty // IsConvertibleKeyType checks if the key can be converted into the column type func IsConvertibleKeyType(colType Type, keyType Type) bool { - if IsStringType(colType) { + // TODO: check other types https://github.com/dolthub/dolt/issues/10316 + switch { + case IsEnumType(colType): + if IsEnumType(keyType) { + return colType.Equals(keyType) + } + case IsStringType(colType): return !(IsNumberType(keyType) || IsDecimalType(keyType)) } - // TODO: check other types return true } diff --git a/sql/type.go b/sql/type.go index 1d277b1c9e..c1c549ddba 100644 --- a/sql/type.go +++ b/sql/type.go @@ -280,6 +280,11 @@ type EnumType interface { Values() []string } +func IsEnumType(t Type) bool { + _, ok := t.(EnumType) + return ok +} + // DecimalType represents the DECIMAL type. // https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html // The type of the returned value is decimal.Decimal.