From d853236208140389ce8c831f436ce7d2436b6034 Mon Sep 17 00:00:00 2001 From: angelamayxie Date: Thu, 15 Jan 2026 10:09:13 -0800 Subject: [PATCH 1/7] use key type for memory table index filter range --- sql/expression/filter-range.go | 41 +++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/sql/expression/filter-range.go b/sql/expression/filter-range.go index 231e8043c8..2ba43bd8ae 100644 --- a/sql/expression/filter-range.go +++ b/sql/expression/filter-range.go @@ -40,62 +40,67 @@ func NewRangeFilterExpr(exprs []sql.Expression, ranges []sql.MySQLRange) (sql.Ex var rangeExpr sql.Expression for i, rce := range rang { var rangeColumnExpr sql.Expression + typ := exprs[i].Type().Promote() switch rce.Type() { // Both Empty and All may seem like strange inclusions, but if only one range is given we need some // expression to evaluate, otherwise our expression would be a nil expression which would panic. case sql.RangeType_Empty: - rangeColumnExpr = NewEquals(NewLiteral(1, types.Int8), NewLiteral(2, types.Int8)) + rangeColumnExpr = NewLiteral(false, types.Boolean) case sql.RangeType_All: - rangeColumnExpr = NewEquals(NewLiteral(1, types.Int8), NewLiteral(1, types.Int8)) + rangeColumnExpr = NewLiteral(true, types.Boolean) case sql.RangeType_EqualNull: rangeColumnExpr = DefaultExpressionFactory.NewIsNull(exprs[i]) case sql.RangeType_GreaterThan: if sql.MySQLRangeCutIsBinding(rce.LowerBound) { - rangeColumnExpr = NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())) + rangeColumnExpr = NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)) } else { rangeColumnExpr = DefaultExpressionFactory.NewIsNotNull(exprs[i]) } case sql.RangeType_GreaterOrEqual: - rangeColumnExpr = NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())) + rangeColumnExpr = NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)) case sql.RangeType_LessThanOrNull: rangeColumnExpr = JoinOr( - NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), + NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)), DefaultExpressionFactory.NewIsNull(exprs[i]), ) case sql.RangeType_LessOrEqualOrNull: rangeColumnExpr = JoinOr( - NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), + NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)), DefaultExpressionFactory.NewIsNull(exprs[i]), ) case sql.RangeType_ClosedClosed: - rangeColumnExpr = JoinAnd( - NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())), - NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), - ) + if rce.LowerBound == rce.UpperBound { + rangeColumnExpr = NewEquals(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)) + } else { + rangeColumnExpr = JoinAnd( + NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)), + NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)), + ) + } case sql.RangeType_OpenOpen: if sql.MySQLRangeCutIsBinding(rce.LowerBound) { rangeColumnExpr = JoinAnd( - NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())), - NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), + NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)), + NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)), ) } else { // Lower bound is (NULL, ...) - rangeColumnExpr = NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())) + rangeColumnExpr = NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)) } case sql.RangeType_OpenClosed: if sql.MySQLRangeCutIsBinding(rce.LowerBound) { rangeColumnExpr = JoinAnd( - NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())), - NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), + NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)), + NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)), ) } else { // Lower bound is (NULL, ...] - rangeColumnExpr = NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())) + rangeColumnExpr = NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)) } case sql.RangeType_ClosedOpen: rangeColumnExpr = JoinAnd( - NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote())), - NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())), + NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), typ)), + NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), typ)), ) } rangeExpr = JoinAnd(rangeExpr, rangeColumnExpr) From ca23de9ee14caec90d10099d3edfc61a2a11c428 Mon Sep 17 00:00:00 2001 From: angelamayxie Date: Thu, 15 Jan 2026 10:15:40 -0800 Subject: [PATCH 2/7] Use origin type during type conversion --- sql/expression/comparison.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 446851310f..6c932c9894 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -278,7 +278,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) } if types.IsTime(leftType) || types.IsTime(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDatetime) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToDatetime) if err != nil { return nil, nil, nil, err } @@ -291,7 +291,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) } if types.IsBinaryType(leftType) || types.IsBinaryType(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToBinary) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToBinary) if err != nil { return nil, nil, nil, err } @@ -301,7 +301,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) if types.IsNumber(leftType) || types.IsNumber(rightType) { if types.IsDecimal(leftType) || types.IsDecimal(rightType) { //TODO: We need to set to the actual DECIMAL type - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDecimal) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToDecimal) if err != nil { return nil, nil, nil, err } @@ -314,7 +314,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) } if types.IsFloat(leftType) || types.IsFloat(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToDouble) if err != nil { return nil, nil, nil, err } @@ -323,7 +323,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) } if types.IsSigned(leftType) && types.IsSigned(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToSigned) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToSigned) if err != nil { return nil, nil, nil, err } @@ -332,7 +332,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) } if types.IsUnsigned(leftType) && types.IsUnsigned(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToUnsigned) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToUnsigned) if err != nil { return nil, nil, nil, err } @@ -340,7 +340,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) return l, r, types.Uint64, nil } - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble) + l, r, err := c.convertLeftAndRight(ctx, left, right, ConvertToDouble) if err != nil { return nil, nil, nil, err } @@ -348,7 +348,7 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) return l, r, types.Float64, nil } - left, right, err := convertLeftAndRight(ctx, left, right, ConvertToChar) + left, right, err := c.convertLeftAndRight(ctx, left, right, ConvertToChar) if err != nil { return nil, nil, nil, err } @@ -356,17 +356,17 @@ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) return left, right, types.LongText, nil } -func convertLeftAndRight(ctx *sql.Context, left, right interface{}, convertTo string) (interface{}, interface{}, error) { +func (c *comparison) convertLeftAndRight(ctx *sql.Context, left, right interface{}, convertTo string) (interface{}, interface{}, error) { typeLength := 0 if convertTo == ConvertToDatetime { typeLength = types.MaxDatetimePrecision } - l, err := convertValue(ctx, left, convertTo, nil, typeLength, 0) + l, err := convertValue(ctx, left, convertTo, c.Left().Type(), typeLength, 0) if err != nil { return nil, nil, err } - r, err := convertValue(ctx, right, convertTo, nil, typeLength, 0) + r, err := convertValue(ctx, right, convertTo, c.Right().Type(), typeLength, 0) if err != nil { return nil, nil, err } From 20261637d4605a305e73897bf755f0b234e17a1a Mon Sep 17 00:00:00 2001 From: angelamayxie Date: Thu, 15 Jan 2026 10:53:38 -0800 Subject: [PATCH 3/7] add script_queries test (pulled from #3377) --- enginetest/queries/script_queries.go | 40 ++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 999dbb685e..3272d3df78 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -11416,6 +11416,46 @@ where }, }, }, + { + // https://github.com/dolthub/dolt/issues/10311 + Name: "enums with foreign keys and joins", + Dialect: "mysql", + SetUpScript: []string{ + "create table animals(e enum('rat','ox','tiger','dog') primary key);", + "create table pets(e enum('cat','dog','fish','rat'), foreign key (e) references animals(e));", + "insert into animals values('rat');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into pets values ('rat');", + // Error expected here because 'rat' has different underlying int values depending on the enum type + ExpectedErr: sql.ErrForeignKeyChildViolation, + }, + { + Query: "insert into pets values ('cat');", + // Query OK expected here because the underlying int values are the same + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + Query: "select * from animals join pets on animals.e=pets.e;", + // Empty set expected here because comparison uses the string values when enum types are different + Expected: []sql.Row{}, + }, + { + Query: "insert into animals values ('dog');", + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + Query: "insert into pets values ('rat');", + // 'rat' is now okay because it has the same underlying int value as 'dog' in the animals table + Expected: []sql.Row{{types.NewOkResult(1)}}, + }, + { + Query: "select * from animals join pets on animals.e=pets.e;", + Expected: []sql.Row{{"rat", "rat"}}, + }, + }, + }, { Skip: true, Name: "enums with foreign keys and cascade", From 9643f1bc99d1f9e1bae822914f4aa21708f8d6c3 Mon Sep 17 00:00:00 2001 From: angelamayxie Date: Thu, 15 Jan 2026 11:40:34 -0800 Subject: [PATCH 4/7] pull more files from #3377 --- enginetest/join_op_tests.go | 17 +++++++++++++++++ sql/index_builder.go | 5 ++++- sql/type.go | 5 +++++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/enginetest/join_op_tests.go b/enginetest/join_op_tests.go index 8f7e72f9cb..28c2cd9974 100644 --- a/enginetest/join_op_tests.go +++ b/enginetest/join_op_tests.go @@ -2298,6 +2298,23 @@ WHERE }, }, }, + { + name: "join on different enum types", + setup: [][]string{ + { + "create table animals(e enum('rat','ox','tiger','dog') primary key);", + "create table pets(e enum('cat','dog','fish','rat'), foreign key (e) references animals(e));", + "insert into animals values('rat'), ('dog');", + "insert into pets values ('cat'), ('rat');", + }, + }, + tests: []JoinOpTests{ + { + Query: "select * from animals join pets on animals.e=pets.e;", + Expected: []sql.Row{{"rat", "rat"}}, + }, + }, + }, } var rangeJoinOpTests = []JoinOpTests{ diff --git a/sql/index_builder.go b/sql/index_builder.go index 484a794627..3bba31e739 100644 --- a/sql/index_builder.go +++ b/sql/index_builder.go @@ -331,8 +331,11 @@ func (b *MySQLIndexBuilder) GreaterThan(ctx *Context, colExpr string, keyType Ty // IsConvertibleKeyType checks if the key can be converted into the column type func IsConvertibleKeyType(colType Type, keyType Type) bool { + if IsEnumType(colType) { + return !(IsStringType(keyType) || (IsEnumType(keyType) && !colType.Equals(keyType))) + } if IsStringType(colType) { - return !(IsNumberType(keyType) || IsDecimalType(keyType)) + return !(IsNumberType(keyType) || IsDecimalType(keyType) || IsEnumType(keyType)) } // TODO: check other types return true diff --git a/sql/type.go b/sql/type.go index 1d277b1c9e..c1c549ddba 100644 --- a/sql/type.go +++ b/sql/type.go @@ -280,6 +280,11 @@ type EnumType interface { Values() []string } +func IsEnumType(t Type) bool { + _, ok := t.(EnumType) + return ok +} + // DecimalType represents the DECIMAL type. // https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html // The type of the returned value is decimal.Decimal. From fc23d36bf491fb68a13d1b4598df430c36b4c854 Mon Sep 17 00:00:00 2001 From: angelamayxie Date: Thu, 15 Jan 2026 12:17:43 -0800 Subject: [PATCH 5/7] disable merge joins if types do not sort in the same order --- sql/analyzer/indexed_joins.go | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/analyzer/indexed_joins.go b/sql/analyzer/indexed_joins.go index 8bf621c031..1315fd2d9b 100644 --- a/sql/analyzer/indexed_joins.go +++ b/sql/analyzer/indexed_joins.go @@ -1147,7 +1147,7 @@ func addMergeJoins(ctx *sql.Context, m *memo.Memo) error { } // check that comparer is not non-decreasing - if !isWeaklyMonotonic(l) || !isWeaklyMonotonic(r) { + if !canMergeTypes(l.Type(), r.Type()) || !isWeaklyMonotonic(l) || !isWeaklyMonotonic(r) { continue } @@ -1491,6 +1491,23 @@ func makeIndexScan(ctx *sql.Context, statsProv sql.StatsProvider, tab plan.Table }, true, nil } +// canMerge checks the types of two expressions to see if they can be merged into one another if sorted. +func canMergeTypes(t1, t2 sql.Type) bool { + switch { + case types.IsNumber(t1): + return !types.IsText(t2) + case types.IsText(t1): + return !(types.IsNumber(t2) || types.IsEnum(t2)) + case types.IsEnum(t1): + if types.IsEnum(t2) { + return types.TypesEqual(t1, t2) + } + return !types.IsText(t2) + default: + return true + } +} + // isWeaklyMonotonic is a weak test of whether an expression // will be strictly increasing as the value of column attribute // inputs increases. From db9506c45d992cac21e68f9a5e38ad33509d5637 Mon Sep 17 00:00:00 2001 From: angelamayxie Date: Thu, 15 Jan 2026 12:58:46 -0800 Subject: [PATCH 6/7] remove adding other type checks (might be too strict and we need additional testing) --- sql/analyzer/indexed_joins.go | 12 ++++-------- sql/index_builder.go | 14 ++++++++------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/sql/analyzer/indexed_joins.go b/sql/analyzer/indexed_joins.go index 1315fd2d9b..7dfb736bec 100644 --- a/sql/analyzer/indexed_joins.go +++ b/sql/analyzer/indexed_joins.go @@ -1491,21 +1491,17 @@ func makeIndexScan(ctx *sql.Context, statsProv sql.StatsProvider, tab plan.Table }, true, nil } -// canMerge checks the types of two expressions to see if they can be merged into one another if sorted. +// canMerge checks the types of two columns to see if they can be merged into one another if sorted. func canMergeTypes(t1, t2 sql.Type) bool { + // TODO: handle other types here. For example, Number and Text types likely can't be merged together. But we need to + // add more testing https://github.com/dolthub/dolt/issues/10316 switch { - case types.IsNumber(t1): - return !types.IsText(t2) - case types.IsText(t1): - return !(types.IsNumber(t2) || types.IsEnum(t2)) case types.IsEnum(t1): if types.IsEnum(t2) { return types.TypesEqual(t1, t2) } - return !types.IsText(t2) - default: - return true } + return true } // isWeaklyMonotonic is a weak test of whether an expression diff --git a/sql/index_builder.go b/sql/index_builder.go index 3bba31e739..45592ef3a6 100644 --- a/sql/index_builder.go +++ b/sql/index_builder.go @@ -331,13 +331,15 @@ func (b *MySQLIndexBuilder) GreaterThan(ctx *Context, colExpr string, keyType Ty // IsConvertibleKeyType checks if the key can be converted into the column type func IsConvertibleKeyType(colType Type, keyType Type) bool { - if IsEnumType(colType) { - return !(IsStringType(keyType) || (IsEnumType(keyType) && !colType.Equals(keyType))) - } - if IsStringType(colType) { - return !(IsNumberType(keyType) || IsDecimalType(keyType) || IsEnumType(keyType)) + // TODO: check other types https://github.com/dolthub/dolt/issues/10316 + switch { + case IsEnumType(colType): + if IsEnumType(keyType) { + return !colType.Equals(keyType) + } + case IsStringType(colType): + return !(IsNumberType(keyType) || IsDecimalType(keyType)) } - // TODO: check other types return true } From a8db9a637b3c92380a1a0636834f7d692246561d Mon Sep 17 00:00:00 2001 From: angelamayxie Date: Thu, 15 Jan 2026 13:09:04 -0800 Subject: [PATCH 7/7] fix enum type check and add bug link to new join op test --- enginetest/join_op_tests.go | 1 + sql/index_builder.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/enginetest/join_op_tests.go b/enginetest/join_op_tests.go index 28c2cd9974..4764eff810 100644 --- a/enginetest/join_op_tests.go +++ b/enginetest/join_op_tests.go @@ -2299,6 +2299,7 @@ WHERE }, }, { + // https://github.com/dolthub/dolt/issues/10311 name: "join on different enum types", setup: [][]string{ { diff --git a/sql/index_builder.go b/sql/index_builder.go index 45592ef3a6..4579db6c44 100644 --- a/sql/index_builder.go +++ b/sql/index_builder.go @@ -335,7 +335,7 @@ func IsConvertibleKeyType(colType Type, keyType Type) bool { switch { case IsEnumType(colType): if IsEnumType(keyType) { - return !colType.Equals(keyType) + return colType.Equals(keyType) } case IsStringType(colType): return !(IsNumberType(keyType) || IsDecimalType(keyType))