diff --git a/optgen/cmd/support/agg_gen.go b/optgen/cmd/support/agg_gen.go index aba85bbce4..4ab2c37c00 100644 --- a/optgen/cmd/support/agg_gen.go +++ b/optgen/cmd/support/agg_gen.go @@ -120,7 +120,7 @@ func (g *AggGen) genAggStringer(define AggDef) { fmt.Fprintf(g.w, " pr.WriteChildren(children...)\n") fmt.Fprintf(g.w, " return pr.String()\n") fmt.Fprintf(g.w, " }\n") - fmt.Fprintf(g.w, " return fmt.Sprintf(\"%s(%%s)\", a.Child)\n", strings.ToUpper(sqlName)) + fmt.Fprintf(g.w, " return \"%s(\" + a.Child.String() + \")\"\n", strings.ToUpper(sqlName)) fmt.Fprintf(g.w, "}\n\n") fmt.Fprintf(g.w, "func (a *%s) DebugString() string {\n", define.Name) diff --git a/optgen/cmd/support/agg_gen_test.go b/optgen/cmd/support/agg_gen_test.go index fcb511d431..ed99bd7b42 100644 --- a/optgen/cmd/support/agg_gen_test.go +++ b/optgen/cmd/support/agg_gen_test.go @@ -64,7 +64,7 @@ func TestAggGen(t *testing.T) { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("TEST(%s)", a.Child) + return "TEST(" + a.Child.String() + ")" } func (a *Test) DebugString() string { diff --git a/sql/analyzer/apply_indexes_from_outer_scope.go b/sql/analyzer/apply_indexes_from_outer_scope.go index 371142d8da..31551a86ba 100644 --- a/sql/analyzer/apply_indexes_from_outer_scope.go +++ b/sql/analyzer/apply_indexes_from_outer_scope.go @@ -392,53 +392,21 @@ func extractJoinColumnExpr(e sql.Expression) (leftCol *joinColExpr, rightCol *jo } } -func containsColumns(e sql.Expression) bool { - var result bool - sql.Inspect(e, func(e sql.Expression) bool { - _, ok1 := e.(*expression.GetField) - _, ok2 := e.(*expression.UnresolvedColumn) - if ok1 || ok2 { - result = true - return false - } - return true - }) - return result -} - -func containsSubquery(e sql.Expression) bool { - var result bool - sql.Inspect(e, func(e sql.Expression) bool { - if _, ok := e.(*plan.Subquery); ok { - result = true - return false - } - return true - }) - return result -} - +// isEvaluable determines if sql.Expression has/contains columns, subqueries, bindvars, or procedure params. +// Those expressions are NOT evaluable. func isEvaluable(e sql.Expression) bool { - return !containsColumns(e) && !containsSubquery(e) && !containsBindvars(e) && !containsProcedureParam(e) -} - -func containsBindvars(e sql.Expression) bool { - var result bool + var hasUnevaluable bool sql.Inspect(e, func(e sql.Expression) bool { - if _, ok := e.(*expression.BindVar); ok { - result = true + switch e.(type) { + case *expression.GetField, *expression.UnresolvedColumn, + *plan.Subquery, + *expression.BindVar, + *expression.ProcedureParam: + hasUnevaluable = true return false + default: + return true } - return true }) - return result -} - -func containsProcedureParam(e sql.Expression) bool { - var result bool - sql.Inspect(e, func(e sql.Expression) bool { - _, result = e.(*expression.ProcedureParam) - return !result - }) - return result + return !hasUnevaluable } diff --git a/sql/analyzer/costed_index_scan.go b/sql/analyzer/costed_index_scan.go index f27e6002a7..1beddeed7d 100644 --- a/sql/analyzer/costed_index_scan.go +++ b/sql/analyzer/costed_index_scan.go @@ -756,6 +756,7 @@ func (c *indexCoster) flattenAnd(e *expression.And, and *iScanAnd) (sql.FastIntS invalid.Add(int(newOr.Id())) } else { and.orChildren = append(and.orChildren, newOr) + and.cnt++ if imp { imprecise.Add(int(newOr.id)) } @@ -1243,6 +1244,7 @@ func (a *iScanAnd) newLeaf(l *iScanLeaf) { a.leafChildren = make(map[string][]*iScanLeaf) } a.leafChildren[strings.ToLower(l.gf.Name())] = append(a.leafChildren[strings.ToLower(l.gf.Name())], l) + a.cnt++ } // leaves returns a list of this nodes leaf filters, sorted by id @@ -1260,14 +1262,6 @@ func (a *iScanAnd) leaves() []*iScanLeaf { } func (a *iScanAnd) childCnt() int { - if a.cnt > 0 { - return a.cnt - } - cnt := len(a.orChildren) - for _, leaves := range a.leafChildren { - cnt += len(leaves) - } - a.cnt = cnt return a.cnt } diff --git a/sql/expression/function/aggregation/common.go b/sql/expression/function/aggregation/common.go index 164d85cf2e..f9c88af7f0 100644 --- a/sql/expression/function/aggregation/common.go +++ b/sql/expression/function/aggregation/common.go @@ -15,8 +15,6 @@ package aggregation import ( - "fmt" - "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -59,7 +57,7 @@ func (a *unaryAggBase) Window() *sql.WindowDefinition { } func (a *unaryAggBase) String() string { - return fmt.Sprintf("%s(%s)", a.functionName, a.Child) + return a.functionName + "(" + a.Child.String() + ")" } func (a *unaryAggBase) Type() sql.Type { diff --git a/sql/expression/function/aggregation/unary_aggs.og.go b/sql/expression/function/aggregation/unary_aggs.og.go index a5094cc975..700ad0e47d 100644 --- a/sql/expression/function/aggregation/unary_aggs.og.go +++ b/sql/expression/function/aggregation/unary_aggs.og.go @@ -45,7 +45,7 @@ func (a *AnyValue) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("ANYVALUE(%s)", a.Child) + return "ANYVALUE(" + a.Child.String() + ")" } func (a *AnyValue) DebugString() string { @@ -124,7 +124,7 @@ func (a *Avg) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("AVG(%s)", a.Child) + return "AVG(" + a.Child.String() + ")" } func (a *Avg) DebugString() string { @@ -203,7 +203,7 @@ func (a *BitAnd) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("BITAND(%s)", a.Child) + return "BITAND(" + a.Child.String() + ")" } func (a *BitAnd) DebugString() string { @@ -282,7 +282,7 @@ func (a *BitOr) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("BITOR(%s)", a.Child) + return "BITOR(" + a.Child.String() + ")" } func (a *BitOr) DebugString() string { @@ -361,7 +361,7 @@ func (a *BitXor) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("BITXOR(%s)", a.Child) + return "BITXOR(" + a.Child.String() + ")" } func (a *BitXor) DebugString() string { @@ -440,7 +440,7 @@ func (a *Count) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("COUNT(%s)", a.Child) + return "COUNT(" + a.Child.String() + ")" } func (a *Count) DebugString() string { @@ -519,7 +519,7 @@ func (a *First) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("FIRST(%s)", a.Child) + return "FIRST(" + a.Child.String() + ")" } func (a *First) DebugString() string { @@ -598,7 +598,7 @@ func (a *JsonArray) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("JSON_ARRAYAGG(%s)", a.Child) + return "JSON_ARRAYAGG(" + a.Child.String() + ")" } func (a *JsonArray) DebugString() string { @@ -677,7 +677,7 @@ func (a *Last) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("LAST(%s)", a.Child) + return "LAST(" + a.Child.String() + ")" } func (a *Last) DebugString() string { @@ -756,7 +756,7 @@ func (a *Max) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("MAX(%s)", a.Child) + return "MAX(" + a.Child.String() + ")" } func (a *Max) DebugString() string { @@ -835,7 +835,7 @@ func (a *Min) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("MIN(%s)", a.Child) + return "MIN(" + a.Child.String() + ")" } func (a *Min) DebugString() string { @@ -914,7 +914,7 @@ func (a *Sum) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("SUM(%s)", a.Child) + return "SUM(" + a.Child.String() + ")" } func (a *Sum) DebugString() string { @@ -993,7 +993,7 @@ func (a *StdDevPop) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("STDDEVPOP(%s)", a.Child) + return "STDDEVPOP(" + a.Child.String() + ")" } func (a *StdDevPop) DebugString() string { @@ -1072,7 +1072,7 @@ func (a *StdDevSamp) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("STDDEVSAMP(%s)", a.Child) + return "STDDEVSAMP(" + a.Child.String() + ")" } func (a *StdDevSamp) DebugString() string { @@ -1151,7 +1151,7 @@ func (a *VarPop) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("VARPOP(%s)", a.Child) + return "VARPOP(" + a.Child.String() + ")" } func (a *VarPop) DebugString() string { @@ -1230,7 +1230,7 @@ func (a *VarSamp) String() string { pr.WriteChildren(children...) return pr.String() } - return fmt.Sprintf("VARSAMP(%s)", a.Child) + return "VARSAMP(" + a.Child.String() + ")" } func (a *VarSamp) DebugString() string { diff --git a/sql/index_builder.go b/sql/index_builder.go index 25acd2ae26..484a794627 100644 --- a/sql/index_builder.go +++ b/sql/index_builder.go @@ -201,15 +201,19 @@ func (b *MySQLIndexBuilder) In(ctx *Context, colExpr string, keyTypes []Type, ke for i, k := range keys { // if converting from float to int results in rounding, then it's empty range if t, ok := colTyp.(NumberType); ok && t.IsNumericType() && !t.IsFloat() { - f, c := floor(k), ceil(k) - switch k.(type) { - case float32, float64: - if f != c { + switch k := k.(type) { + case float32: + if float32(int64(k)) != k { + potentialRanges[i] = EmptyRangeColumnExpr(colTyp) + continue + } + case float64: + if float64(int64(k)) != k { potentialRanges[i] = EmptyRangeColumnExpr(colTyp) continue } case decimal.Decimal: - if !f.(decimal.Decimal).Equals(c.(decimal.Decimal)) { + if !k.Equal(decimal.NewFromInt(k.IntPart())) { potentialRanges[i] = EmptyRangeColumnExpr(colTyp) continue } @@ -246,15 +250,19 @@ func (b *MySQLIndexBuilder) NotEquals(ctx *Context, colExpr string, keyType Type return b } // if converting from float to int results in rounding, then it's entire range (excluding nulls) - f, c := floor(key), ceil(key) - switch key.(type) { - case float32, float64: - if f != c { + switch k := key.(type) { + case float32: + if float32(int64(k)) != k { + b.updateCol(ctx, colExpr, NotNullRangeColumnExpr(colTyp)) + return b + } + case float64: + if float64(int64(k)) != k { b.updateCol(ctx, colExpr, NotNullRangeColumnExpr(colTyp)) return b } case decimal.Decimal: - if !f.(decimal.Decimal).Equals(c.(decimal.Decimal)) { + if !k.Equal(decimal.NewFromInt(k.IntPart())) { b.updateCol(ctx, colExpr, NotNullRangeColumnExpr(colTyp)) return b } diff --git a/sql/planbuilder/orderby.go b/sql/planbuilder/orderby.go index 85c87d7f79..92c2e6925b 100644 --- a/sql/planbuilder/orderby.go +++ b/sql/planbuilder/orderby.go @@ -187,7 +187,7 @@ func (b *Builder) normalizeValArg(e *ast.SQLVal) (sql.Expression, bool) { func (b *Builder) normalizeIntVal(e *ast.SQLVal) (any, bool) { if e.Type == ast.IntVal { - lit := b.convertInt(string(e.Val), 10) + lit := b.convertInt(e.Val, 10) return lit.Value(), true } else if replace, ok := b.normalizeValArg(e); ok { if lit, ok := replace.(*expression.Literal); ok && types.IsNumber(lit.Type()) { diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 1cd17b8e39..72c7e7d68e 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -926,32 +926,59 @@ func (b *Builder) intervalExprToExpression(inScope *scope, e *ast.IntervalExpr) // Convert an integer, represented by the specified string in the specified // base, to its smallest representation possible, out of: // int8, uint8, int16, uint16, int32, uint32, int64 and uint64 -func (b *Builder) convertInt(value string, base int) *expression.Literal { - if i8, err := strconv.ParseInt(value, base, 8); err == nil { - return expression.NewLiteral(int8(i8), types.Int8) - } - if ui8, err := strconv.ParseUint(value, base, 8); err == nil { - return expression.NewLiteral(uint8(ui8), types.Uint8) - } - if i16, err := strconv.ParseInt(value, base, 16); err == nil { - return expression.NewLiteral(int16(i16), types.Int16) - } - if ui16, err := strconv.ParseUint(value, base, 16); err == nil { - return expression.NewLiteral(uint16(ui16), types.Uint16) - } - if i32, err := strconv.ParseInt(value, base, 32); err == nil { - return expression.NewLiteral(int32(i32), types.Int32) - } - if ui32, err := strconv.ParseUint(value, base, 32); err == nil { - return expression.NewLiteral(uint32(ui32), types.Uint32) - } - if i64, err := strconv.ParseInt(value, base, 64); err == nil { - return expression.NewLiteral(int64(i64), types.Int64) +func (b *Builder) convertInt(value []byte, base int) *expression.Literal { + // For performance reasons, this smallest int representation possible for value. + // If zero-ing out (subtracting) the largest representation of the respective integer type results in values + // left over, then the value must not fit within that integer type. + valStr := encodings.BytesToString(value) + if i64, err := strconv.ParseInt(valStr, base, 64); err == nil { + if uint64(i64)&0x8000_0000_0000_0000 != 0 { + if uint64(^i64)&0xFFFF_FFFF_FFFF_FF80 == 0 { + return expression.NewLiteral(int8(i64), types.Int8) + } + if uint64(^i64)&0xFFFF_FFFF_FFFF_8000 == 0 { + return expression.NewLiteral(int16(i64), types.Int16) + } + if uint64(^i64)&0xFFFF_FFFF_8000_0000 == 0 { + return expression.NewLiteral(int32(i64), types.Int32) + } + return expression.NewLiteral(i64, types.Int64) + } + if uint64(i64)&0xFFFF_FFFF_FFFF_FF80 == 0 { + return expression.NewLiteral(int8(i64), types.Int8) + } + if uint64(i64)&0xFFFF_FFFF_FFFF_FF00 == 0 { + return expression.NewLiteral(uint8(i64), types.Uint8) + } + if uint64(i64)&0xFFFF_FFFF_FFFF_8000 == 0 { + return expression.NewLiteral(int16(i64), types.Int16) + } + if uint64(i64)&0xFFFF_FFFF_FFFF_0000 == 0 { + return expression.NewLiteral(uint16(i64), types.Uint16) + } + if uint64(i64)&0xFFFF_FFFF_8000_0000 == 0 { + return expression.NewLiteral(int32(i64), types.Int32) + } + if uint64(i64)&0xFFFF_FFFF_0000_0000 == 0 { + return expression.NewLiteral(uint32(i64), types.Uint32) + } + return expression.NewLiteral(i64, types.Int64) } - if ui64, err := strconv.ParseUint(value, base, 64); err == nil { - return expression.NewLiteral(uint64(ui64), types.Uint64) + + if ui64, err := strconv.ParseUint(valStr, base, 64); err == nil { + if ui64&0xFFFF_FFFF_FFFF_FF00 == 0 { + return expression.NewLiteral(uint8(ui64), types.Uint8) + } + if ui64&0xFFFF_FFFF_FFFF_0000 == 0 { + return expression.NewLiteral(uint16(ui64), types.Uint16) + } + if ui64&0xFFFF_0000_0000_0000 == 0 { + return expression.NewLiteral(uint32(ui64), types.Uint32) + } + return expression.NewLiteral(ui64, types.Uint64) } - if decimal, _, err := types.InternalDecimalType.Convert(b.ctx, value); err == nil { + + if decimal, _, err := types.InternalDecimalType.Convert(b.ctx, valStr); err == nil { return expression.NewLiteral(decimal, types.InternalDecimalType) } @@ -964,7 +991,7 @@ func (b *Builder) ConvertVal(v *ast.SQLVal) sql.Expression { case ast.StrVal: return expression.NewLiteral(string(v.Val), types.CreateLongText(b.ctx.GetCollation())) case ast.IntVal: - return b.convertInt(string(v.Val), 10) + return b.convertInt(v.Val, 10) case ast.FloatVal: // any float value is parsed as decimal except when the value has scientific notation ogVal := strings.ToLower(string(v.Val)) @@ -990,7 +1017,7 @@ func (b *Builder) ConvertVal(v *ast.SQLVal) sql.Expression { return expression.NewLiteral(dVal, dt) } else { // if the value is not float type - this should not happen - return b.convertInt(string(v.Val), 10) + return b.convertInt(v.Val, 10) } case ast.HexNum: // TODO: binary collation? diff --git a/sql/transform/expr.go b/sql/transform/expr.go index 73e0ca07e5..d7f6b96c94 100644 --- a/sql/transform/expr.go +++ b/sql/transform/expr.go @@ -96,19 +96,19 @@ func Exprs(e []sql.Expression, f ExprFunc) ([]sql.Expression, TreeIdentity, erro return newExprs, NewTree, nil } -var stopInspect = errors.New("stop") - // InspectExpr traverses the given expression tree from the bottom up, breaking if // stop = true. Returns a bool indicating whether traversal was interrupted. -func InspectExpr(node sql.Expression, f func(sql.Expression) bool) bool { - _, _, err := Expr(node, func(e sql.Expression) (sql.Expression, TreeIdentity, error) { - ok := f(e) - if ok { - return nil, SameTree, stopInspect +func InspectExpr(expr sql.Expression, f func(sql.Expression) bool) bool { + children := expr.Children() + for _, child := range children { + if InspectExpr(child, f) { + return true } - return e, SameTree, nil - }) - return errors.Is(err, stopInspect) + } + if f(expr) { + return true + } + return false } // InspectUp traverses the given node tree from the bottom up, breaking if