diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index b6657abc91..102c39b801 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -5897,6 +5897,257 @@ where }, }, }, + { + Name: "varchar primary key", + SetUpScript: []string{ + "create table vt (v varchar(3) primary key);", + "insert into vt values ('abc'), ('def'), ('ghi');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select * from vt where v = 'def';", + Expected: []sql.Row{ + {"def"}, + }, + }, + { + Query: "select * from vt where v < 'def';", + Expected: []sql.Row{ + {"abc"}, + }, + }, + { + Query: "select * from vt where v > 'def';", + Expected: []sql.Row{ + {"ghi"}, + }, + }, + { + Query: "select * from vt where v <= 'def';", + Expected: []sql.Row{ + {"abc"}, + {"def"}, + }, + }, + { + Query: "select * from vt where v >= 'def';", + Expected: []sql.Row{ + {"def"}, + {"ghi"}, + }, + }, + + { + Query: "select * from vt where v = 'defdef';", + Expected: []sql.Row{}, + }, + { + Query: "select * from vt where v < 'defdef';", + Expected: []sql.Row{ + {"abc"}, + {"def"}, + }, + }, + { + Query: "select * from vt where v > 'defdef';", + Expected: []sql.Row{ + {"ghi"}, + }, + }, + { + Query: "select * from vt where v <= 'defdef';", + Expected: []sql.Row{ + {"abc"}, + {"def"}, + }, + }, + { + Query: "select * from vt where v >= 'defdef';", + Expected: []sql.Row{ + {"ghi"}, + }, + }, + + // MySQL behavior around null bytes is strange + { + Skip: true, + Query: `select * from vt where v = 'def\0\0';`, + Expected: []sql.Row{ + {"def"}, + }, + }, + { + Skip: true, + Query: `select * from vt where v < 'def\0\0';`, + Expected: []sql.Row{ + {"abc"}, + }, + }, + { + Query: `select * from vt where v > 'def\0\0';`, + Expected: []sql.Row{ + {"ghi"}, + }, + }, + { + Query: `select * from vt where v <= 'def\0\0';`, + Expected: []sql.Row{ + {"abc"}, + {"def"}, + }, + }, + { + Skip: true, + Query: `select * from vt where v >= 'def\0\0';`, + Expected: []sql.Row{ + {"def"}, + {"ghi"}, + }, + }, + + { + Query: "select * from vt where v = cast('def' as char(6));", + Expected: []sql.Row{ + {"def"}, + }, + }, + { + Query: "select * from vt where v < cast('def' as char(6));", + Expected: []sql.Row{ + {"abc"}, + }, + }, + { + Query: "select * from vt where v > cast('def' as char(6));", + Expected: []sql.Row{ + {"ghi"}, + }, + }, + { + Query: "select * from vt where v <= cast('def' as char(6));", + Expected: []sql.Row{ + {"abc"}, + {"def"}, + }, + }, + { + Query: "select * from vt where v >= cast('def' as char(6));", + Expected: []sql.Row{ + {"def"}, + {"ghi"}, + }, + }, + }, + }, + { + Name: "varbinary primary key", + SetUpScript: []string{ + "create table vt (v varbinary(3) primary key);", + "insert into vt values ('abc'), ('def'), ('ghi');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select cast(v as char(3)) from vt where v = 'def';", + Expected: []sql.Row{ + {"def"}, + }, + }, + { + Query: "select cast(v as char(3)) from vt where v < 'def';", + Expected: []sql.Row{ + {"abc"}, + }, + }, + { + Query: "select cast(v as char(3)) from vt where v > 'def';", + Expected: []sql.Row{ + {"ghi"}, + }, + }, + { + Query: "select cast(v as char(3)) from vt where v <= 'def';", + Expected: []sql.Row{ + {"abc"}, + {"def"}, + }, + }, + { + Query: "select cast(v as char(3)) from vt where v >= 'def';", + Expected: []sql.Row{ + {"def"}, + {"ghi"}, + }, + }, + + { + Query: "select cast(v as char(3)) from vt where v = 'defdef';", + Expected: []sql.Row{}, + }, + { + Query: "select cast(v as char(3)) from vt where v < 'defdef';", + Expected: []sql.Row{ + {"abc"}, + {"def"}, + }, + }, + { + Query: "select cast(v as char(3)) from vt where v > 'defdef';", + Expected: []sql.Row{ + {"ghi"}, + }, + }, + { + Query: "select cast(v as char(3)) from vt where v <= 'defdef';", + Expected: []sql.Row{ + {"abc"}, + {"def"}, + }, + }, + { + Query: "select cast(v as char(3)) from vt where v >= 'defdef';", + Expected: []sql.Row{ + {"ghi"}, + }, + }, + + // MySQL behavior around null bytes is strange + { + Skip: true, + Query: `select cast(v as char(3)) from vt where v = 'def\0\0';`, + Expected: []sql.Row{ + {"def"}, + }, + }, + { + Skip: true, + Query: `select cast(v as char(3)) from vt where v < 'def\0\0';`, + Expected: []sql.Row{ + {"abc"}, + }, + }, + { + Query: `select cast(v as char(3)) from vt where v > 'def\0\0';`, + Expected: []sql.Row{ + {"ghi"}, + }, + }, + { + Query: `select cast(v as char(3)) from vt where v <= 'def\0\0';`, + Expected: []sql.Row{ + {"abc"}, + {"def"}, + }, + }, + { + Skip: true, + Query: `select cast(v as char(3)) from vt where v >= 'def\0\0';`, + Expected: []sql.Row{ + {"def"}, + {"ghi"}, + }, + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ diff --git a/sql/index_builder.go b/sql/index_builder.go index 3254c08862..9e675fe614 100644 --- a/sql/index_builder.go +++ b/sql/index_builder.go @@ -44,8 +44,12 @@ func NewIndexBuilder(idx Index) *IndexBuilder { colExprTypes := make(map[string]Type) ranges := make(map[string][]RangeColumnExpr) for _, cet := range idx.ColumnExpressionTypes() { - colExprTypes[strings.ToLower(cet.Expression)] = cet.Type - ranges[strings.ToLower(cet.Expression)] = []RangeColumnExpr{AllRangeColumnExpr(cet.Type)} + typ := cet.Type + if _, ok := typ.(StringType); ok { + typ = typ.Promote() + } + colExprTypes[strings.ToLower(cet.Expression)] = typ + ranges[strings.ToLower(cet.Expression)] = []RangeColumnExpr{AllRangeColumnExpr(typ)} } return &IndexBuilder{ idx: idx, @@ -128,13 +132,14 @@ func (b *IndexBuilder) Equals(ctx *Context, colExpr string, keys ...interface{}) } } - res, _, err := typ.Convert(k) + var err error + k, _, err = typ.Convert(k) if err != nil { b.isInvalid = true b.err = err return b } - potentialRanges[i] = ClosedRangeColumnExpr(res, res, typ) + potentialRanges[i] = ClosedRangeColumnExpr(k, k, typ) } b.updateCol(ctx, colExpr, potentialRanges...) return b @@ -151,7 +156,6 @@ func (b *IndexBuilder) NotEquals(ctx *Context, colExpr string, key interface{}) b.err = ErrInvalidColExpr.New(colExpr, b.idx.ID()) return b } - // if converting from float to int results in rounding, then it's entire range (excluding nulls) f, c := floor(key), ceil(key) switch key.(type) { @@ -275,7 +279,6 @@ func (b *IndexBuilder) LessThan(ctx *Context, colExpr string, key interface{}) * if t, ok := typ.(NumberType); ok && !t.IsFloat() { key = ceil(key) } - key, _, err := typ.Convert(key) if err != nil { b.isInvalid = true @@ -371,7 +374,11 @@ func (b *IndexBuilder) Ranges(ctx *Context) RangeCollection { cets := b.idx.ColumnExpressionTypes() emptyRange := make(Range, len(cets)) for i, cet := range cets { - emptyRange[i] = EmptyRangeColumnExpr(cet.Type) + typ := cet.Type + if _, ok := typ.(StringType); ok { + typ = typ.Promote() + } + emptyRange[i] = EmptyRangeColumnExpr(typ) } return RangeCollection{emptyRange} } @@ -421,7 +428,7 @@ func (b *IndexBuilder) Ranges(ctx *Context) RangeCollection { cets := b.idx.ColumnExpressionTypes() emptyRange := make(Range, len(cets)) for i, cet := range cets { - emptyRange[i] = EmptyRangeColumnExpr(cet.Type) + emptyRange[i] = EmptyRangeColumnExpr(cet.Type.Promote()) } return RangeCollection{emptyRange} } diff --git a/sql/range.go b/sql/range.go index ab11dde4f6..a990b5352a 100644 --- a/sql/range.go +++ b/sql/range.go @@ -223,6 +223,9 @@ func (rang Range) TryMerge(otherRange Range) (Range, bool, error) { } } } + if indexToMerge == -1 { + return nil, false, fmt.Errorf("invalid index to merge") + } mergedLastExpr, ok, err := rang[indexToMerge].TryUnion(otherRange[indexToMerge]) if err != nil || !ok { return nil, false, err