From 40b7423aff55a366d6894765a13e121f7a936c54 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Sun, 5 Apr 2020 21:37:03 -0700 Subject: [PATCH 01/23] Added SingleRow Engine Definition Signed-off-by: Saif Alharthi --- go/vt/vtgate/engine/singlerow.go | 81 ++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 go/vt/vtgate/engine/singlerow.go diff --git a/go/vt/vtgate/engine/singlerow.go b/go/vt/vtgate/engine/singlerow.go new file mode 100644 index 00000000000..16862a5d3d3 --- /dev/null +++ b/go/vt/vtgate/engine/singlerow.go @@ -0,0 +1,81 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/query" +) + +var _ Primitive = (*SingleRow)(nil) + +// SingleRow defines an empty result +type SingleRow struct { + noInputs +} + +// RouteType returns a description of the query routing type used by the primitive +func (s SingleRow) RouteType() string { + return "" +} + +// GetKeyspaceName specifies the Keyspace that this primitive routes to. +func (s SingleRow) GetKeyspaceName() string { + return "" +} + +// GetTableName specifies the table that this primitive routes to. +func (s SingleRow) GetTableName() string { + return "" +} + +// Execute performs a non-streaming exec. +func (s SingleRow) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (*sqltypes.Result, error) { + result := sqltypes.Result{ + Fields: nil, + RowsAffected: 0, + InsertID: 0, + Rows: [][]sqltypes.Value{ + {}, + }, + } + return &result, nil +} + +// StreamExecute performs a streaming exec. +func (s SingleRow) StreamExecute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { + result := sqltypes.Result{ + Fields: nil, + RowsAffected: 0, + InsertID: 0, + Rows: [][]sqltypes.Value{ + {}, + }, + } + return callback(&result) +} + +// GetFields fetches the field info. +func (s SingleRow) GetFields(vcursor VCursor, bindVars map[string]*query.BindVariable) (*sqltypes.Result, error) { + return &sqltypes.Result{}, nil +} + +func (s SingleRow) description() PrimitiveDescription { + return PrimitiveDescription{ + OperatorType: "SingleRow", + } +} From 3464ea646561a4842423b57cc3ee497f0a617efb Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Sun, 5 Apr 2020 23:13:54 -0700 Subject: [PATCH 02/23] Added Projection plan and Primitive Signed-off-by: Saif Alharthi --- go/vt/vtgate/engine/projection.go | 51 ++++++++++++ go/vt/vtgate/planbuilder/select.go | 32 ++++++++ .../planbuilder/testdata/from_cases.txt | 26 +++--- .../planbuilder/testdata/select_cases.txt | 51 +++++++++--- go/vt/vtgate/planbuilder/vtgate_execution.go | 79 +++++++++++++++++++ 5 files changed, 221 insertions(+), 18 deletions(-) create mode 100644 go/vt/vtgate/engine/projection.go create mode 100644 go/vt/vtgate/planbuilder/vtgate_execution.go diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go new file mode 100644 index 00000000000..80cac6a7929 --- /dev/null +++ b/go/vt/vtgate/engine/projection.go @@ -0,0 +1,51 @@ +package engine + +import ( + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/sqlparser" +) + +var _ Primitive = (*Projection)(nil) + +type Projection struct { + Exprs []*sqlparser.AliasedExpr + Input Primitive +} + +func (p *Projection) RouteType() string { + return p.Input.RouteType() +} + +func (p *Projection) GetKeyspaceName() string { + return p.Input.GetKeyspaceName() +} + +func (p *Projection) GetTableName() string { + return p.Input.GetTableName() +} + +func (p *Projection) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (*sqltypes.Result, error) { + panic("implement me") +} + +func (p *Projection) StreamExecute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { + panic("implement me") +} + +func (p *Projection) GetFields(vcursor VCursor, bindVars map[string]*query.BindVariable) (*sqltypes.Result, error) { + panic("implement me") +} + +func (p *Projection) Inputs() []Primitive { + return []Primitive{p.Input} +} + +func (p *Projection) description() PrimitiveDescription { + return PrimitiveDescription{ + OperatorType: "Projection", + Other: map[string]interface{}{ + "Expressions": p.Exprs, + }, + } +} diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 98a38757fb6..42670cf0e6d 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -73,6 +73,16 @@ func buildSelectPlan(stmt sqlparser.Statement, vschema ContextVSchema) (engine.P // pushed into a route, then a primitive is created on top of any // of the above trees to make it discard unwanted rows. func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab) error { + if checkForDual(sel) && outer == nil { + exprs := make([]*sqlparser.AliasedExpr, len(sel.SelectExprs)) + for i, e := range sel.SelectExprs { + exprs[i] = e.(*sqlparser.AliasedExpr) + } + pb.bldr = &vtgateExecution{ + exprs, + } + return nil + } if err := pb.processTableExprs(sel.From); err != nil { return err } @@ -121,6 +131,28 @@ func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab) return nil } +func checkForDual(sel *sqlparser.Select) bool { + if len(sel.From) == 1 { + if from, ok := sel.From[0].(*sqlparser.AliasedTableExpr); ok { + if tableName, ok := from.Expr.(sqlparser.TableName); ok { + if tableName.Name.String() == "dual" && tableName.Qualifier.IsEmpty() { + for _, expr := range sel.SelectExprs { + e, ok := expr.(*sqlparser.AliasedExpr) + if !ok { + return false + } + if _, ok := e.Expr.(*sqlparser.SQLVal); !ok { + return false + } + } + return true + } + } + } + } + return false +} + // pushFilter identifies the target route for the specified bool expr, // pushes it down, and updates the route info if the new constraint improves // the primitive. This function can push to a WHERE or HAVING clause. diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.txt b/go/vt/vtgate/planbuilder/testdata/from_cases.txt index 8ec63f3ea51..bd26660a646 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.txt @@ -2049,15 +2049,23 @@ "QueryType": "SELECT", "Original": "select last_insert_id()", "Instructions": { - "OperatorType": "Route", - "Variant": "SelectReference", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "FieldQuery": "select :__lastInsertId as `last_insert_id()` from dual where 1 != 1", - "Query": "select :__lastInsertId as `last_insert_id()` from dual", - "Table": "dual" + "OperatorType": "Projection", + "Variant": "", + "Expressions": [ + { + "Expr": { + "Type": 5, + "Val": "Ol9fbGFzdEluc2VydElk" + }, + "As": "last_insert_id()" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow", + "Variant": "" + } + ] } } diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index 1573f501f99..d0889f442a3 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -411,15 +411,23 @@ "QueryType": "SELECT", "Original": "select database() from dual", "Instructions": { - "OperatorType": "Route", - "Variant": "SelectReference", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "FieldQuery": "select :__vtdbname as `database()` from dual where 1 != 1", - "Query": "select :__vtdbname as `database()` from dual", - "Table": "dual" + "OperatorType": "Projection", + "Variant": "", + "Expressions": [ + { + "Expr": { + "Type": 5, + "Val": "Ol9fdnRkYm5hbWU=" + }, + "As": "database()" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow", + "Variant": "" + } + ] } } @@ -1350,3 +1358,28 @@ "Table": "unsharded" } } + +# testing SingleRow Projection +"select 42" +{ + "Original": "select 42", + "Instructions": { + "OperatorType": "Projection", + "Variant": "", + "Expressions": [ + { + "Expr": { + "Type": 1, + "Val": "NDI=" + }, + "As": "" + } + ], + "Inputs": [ + { + "OperatorType": "SingleRow", + "Variant": "" + } + ] + } +} diff --git a/go/vt/vtgate/planbuilder/vtgate_execution.go b/go/vt/vtgate/planbuilder/vtgate_execution.go new file mode 100644 index 00000000000..9c98ef6382c --- /dev/null +++ b/go/vt/vtgate/planbuilder/vtgate_execution.go @@ -0,0 +1,79 @@ +package planbuilder + +import ( + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/engine" +) + +var _ builder = (*vtgateExecution)(nil) + +type vtgateExecution struct { + Exprs []*sqlparser.AliasedExpr +} + +func (v *vtgateExecution) Order() int { + panic("implement me") +} + +func (v *vtgateExecution) ResultColumns() []*resultColumn { + panic("implement me") +} + +func (v *vtgateExecution) Reorder(int) { + panic("implement me") +} + +func (v *vtgateExecution) First() builder { + panic("implement me") +} + +func (v *vtgateExecution) PushFilter(pb *primitiveBuilder, filter sqlparser.Expr, whereType string, origin builder) error { + panic("implement me") +} + +func (v *vtgateExecution) PushSelect(pb *primitiveBuilder, expr *sqlparser.AliasedExpr, origin builder) (rc *resultColumn, colNumber int, err error) { + panic("implement me") +} + +func (v *vtgateExecution) MakeDistinct() error { + panic("implement me") +} + +func (v *vtgateExecution) PushGroupBy(sqlparser.GroupBy) error { + panic("implement me") +} + +func (v *vtgateExecution) PushOrderBy(sqlparser.OrderBy) (builder, error) { + panic("implement me") +} + +func (v *vtgateExecution) SetUpperLimit(count *sqlparser.SQLVal) { + panic("implement me") +} + +func (v *vtgateExecution) PushMisc(sel *sqlparser.Select) { + panic("implement me") +} + +func (v *vtgateExecution) Wireup(bldr builder, jt *jointab) error { + return nil +} + +func (v *vtgateExecution) SupplyVar(from, to int, col *sqlparser.ColName, varname string) { + panic("implement me") +} + +func (v *vtgateExecution) SupplyCol(col *sqlparser.ColName) (rc *resultColumn, colNumber int) { + panic("implement me") +} + +func (v *vtgateExecution) SupplyWeightString(colNumber int) (weightcolNumber int, err error) { + panic("implement me") +} + +func (v *vtgateExecution) Primitive() engine.Primitive { + return &engine.Projection{ + Exprs: v.Exprs, + Input: &engine.SingleRow{}, + } +} From b56994aa8586b9f0f08cc942b11ddb49ae02cc4b Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Tue, 7 Apr 2020 23:10:22 -0700 Subject: [PATCH 03/23] Added ability to evaluate numerical and expressions with bind variables Signed-off-by: Saif Alharthi --- go/sqltypes/arithmetic.go | 1 + go/sqltypes/expressions.go | 112 +++++++++++++++++++ go/sqltypes/expressions_test.go | 87 ++++++++++++++ go/vt/sqlparser/ast_funcs.go | 21 ---- go/vt/sqlparser/ast_test.go | 42 ------- go/vt/sqlparser/expression_converter.go | 14 +++ go/vt/vtgate/engine/projection.go | 55 +++++++-- go/vt/vtgate/engine/projection_test.go | 30 +++++ go/vt/vtgate/engine/singlerow.go | 14 +-- go/vt/vtgate/executor_select_test.go | 16 +-- go/vt/vtgate/planbuilder/select.go | 6 +- go/vt/vtgate/planbuilder/vtgate_execution.go | 3 +- 12 files changed, 313 insertions(+), 88 deletions(-) create mode 100644 go/sqltypes/expressions.go create mode 100644 go/sqltypes/expressions_test.go create mode 100644 go/vt/sqlparser/expression_converter.go create mode 100644 go/vt/vtgate/engine/projection_test.go diff --git a/go/sqltypes/arithmetic.go b/go/sqltypes/arithmetic.go index 6149a58c2c4..88e695ec05d 100644 --- a/go/sqltypes/arithmetic.go +++ b/go/sqltypes/arithmetic.go @@ -34,6 +34,7 @@ type numeric struct { ival int64 uval uint64 fval float64 + err error } var zeroBytes = []byte("0") diff --git a/go/sqltypes/expressions.go b/go/sqltypes/expressions.go new file mode 100644 index 00000000000..57df54149b3 --- /dev/null +++ b/go/sqltypes/expressions.go @@ -0,0 +1,112 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqltypes + +import ( + "strconv" + + querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) + +type ExpressionEnv struct { + BindVars map[string]*querypb.BindVariable + Row []Value +} + +type EvalResult = numeric + +func (e EvalResult) Value() Value { + return castFromNumeric(e, e.typ) +} + +type Expr interface { + Evaluate(env ExpressionEnv) (EvalResult, error) +} + +var _ Expr = (*SQLVal)(nil) + +type SQLVal struct { + Type ValType + Val []byte +} + +func (s *SQLVal) Evaluate(env ExpressionEnv) (EvalResult, error) { + switch s.Type { + case IntVal: + ival, err := strconv.ParseInt(string(s.Val), 10, 64) + if err != nil { + ival = 0 + } + return numeric{typ: Int64, ival: ival}, nil + case ValArg: + val := env.BindVars[string(s.Val[1:])] + ival, err := strconv.ParseInt(string(val.Value), 10, 64) + if err != nil { + ival = 0 + } + return numeric{typ: Int64, ival: ival}, nil + } + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "not yet implemented") +} + +type ValType int + +const ( + StrVal = ValType(iota) + IntVal + FloatVal + HexNum + HexVal + ValArg + BitVal +) + +var _ Expr = (*add)(nil) + +type add struct { + l, r Expr +} + +type subtract struct { + l, r Expr +} + +func (a *add) Evaluate(env ExpressionEnv) (EvalResult, error) { + lVal, err := a.l.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + rVal, err := a.r.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + return addNumericWithError(lVal, rVal) +} + +func (s *subtract) Evaluate(env ExpressionEnv) (EvalResult, error) { + lVal, err := s.l.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + rVal, err := s.r.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + return subtractNumericWithError(lVal, rVal) +} diff --git a/go/sqltypes/expressions_test.go b/go/sqltypes/expressions_test.go new file mode 100644 index 00000000000..87842ef6f4d --- /dev/null +++ b/go/sqltypes/expressions_test.go @@ -0,0 +1,87 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqltypes + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +func TestEvaluate(t *testing.T) { + type testCase struct { + name string + e Expr + bindvars map[string]*querypb.BindVariable + expected Value + } + + tests := []testCase{ + { + name: "42", + e: i("42"), + expected: NewInt64(42), + }, + { + name: "40+2", + e: &add{i("40"), i("2")}, + expected: NewInt64(42), + }, + { + name: "40-2", + e: &subtract{i("40"), i("2")}, + expected: NewInt64(38), + }, + { + name: "Bind Variable", + e: b(":exp"), + expected: NewInt64(66), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + env := ExpressionEnv{ + BindVars: map[string]*querypb.BindVariable{ + ":exp": Int64BindVariable(66), + }, + Row: nil, + } + r, err := test.e.Evaluate(env) + require.NoError(t, err) + result := castFromNumeric(r, r.typ) + assert.Equal(t, test.expected, result, "expected %s but got %s", test.expected.String(), result.String()) + }) + } + +} + +func i(in string) *SQLVal { + return &SQLVal{ + Type: IntVal, + Val: []byte(in), + } +} + +func b(in string) *SQLVal { + return &SQLVal{ + Type: ValArg, + Val: []byte(in), + } +} diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index dada62f3125..024ebf7a67e 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -24,10 +24,7 @@ import ( "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/vterrors" - querypb "vitess.io/vitess/go/vt/proto/query" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) // Walk calls visit on every node. @@ -425,24 +422,6 @@ func (node *ComparisonExpr) IsImpossible() bool { return false } -// ExprFromValue converts the given Value into an Expr or returns an error. -func ExprFromValue(value sqltypes.Value) (Expr, error) { - // The type checks here follow the rules defined in sqltypes/types.go. - switch { - case value.Type() == sqltypes.Null: - return &NullVal{}, nil - case value.IsIntegral(): - return NewIntVal(value.ToBytes()), nil - case value.IsFloat() || value.Type() == sqltypes.Decimal: - return NewFloatVal(value.ToBytes()), nil - case value.IsQuoted(): - return NewStrVal(value.ToBytes()), nil - default: - // We cannot support sqltypes.Expression, or any other invalid type. - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot convert value %v to AST", value) - } -} - // NewStrVal builds a new StrVal. func NewStrVal(in []byte) *SQLVal { return &SQLVal{Type: StrVal, Val: in} diff --git a/go/vt/sqlparser/ast_test.go b/go/vt/sqlparser/ast_test.go index f229c0fa4ec..82d680567b8 100644 --- a/go/vt/sqlparser/ast_test.go +++ b/go/vt/sqlparser/ast_test.go @@ -25,7 +25,6 @@ import ( "unsafe" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/sqltypes" ) func TestAppend(t *testing.T) { @@ -562,47 +561,6 @@ func TestReplaceExpr(t *testing.T) { } } -func TestExprFromValue(t *testing.T) { - tcases := []struct { - in sqltypes.Value - out SQLNode - err string - }{{ - in: sqltypes.NULL, - out: &NullVal{}, - }, { - in: sqltypes.NewInt64(1), - out: NewIntVal([]byte("1")), - }, { - in: sqltypes.NewFloat64(1.1), - out: NewFloatVal([]byte("1.1")), - }, { - in: sqltypes.MakeTrusted(sqltypes.Decimal, []byte("1.1")), - out: NewFloatVal([]byte("1.1")), - }, { - in: sqltypes.NewVarChar("aa"), - out: NewStrVal([]byte("aa")), - }, { - in: sqltypes.MakeTrusted(sqltypes.Expression, []byte("rand()")), - err: "cannot convert value EXPRESSION(rand()) to AST", - }} - for _, tcase := range tcases { - got, err := ExprFromValue(tcase.in) - if tcase.err != "" { - if err == nil || err.Error() != tcase.err { - t.Errorf("ExprFromValue(%v) err: %v, want %s", tcase.in, err, tcase.err) - } - continue - } - if err != nil { - t.Error(err) - } - if got, want := got, tcase.out; !reflect.DeepEqual(got, want) { - t.Errorf("ExprFromValue(%v): %v, want %s", tcase.in, got, want) - } - } -} - func TestColNameEqual(t *testing.T) { var c1, c2 *ColName if c1.Equal(c2) { diff --git a/go/vt/sqlparser/expression_converter.go b/go/vt/sqlparser/expression_converter.go new file mode 100644 index 00000000000..02288f393e9 --- /dev/null +++ b/go/vt/sqlparser/expression_converter.go @@ -0,0 +1,14 @@ +package sqlparser + +import "vitess.io/vitess/go/sqltypes" + +func Convert(e Expr) sqltypes.Expr { + switch node := e.(type) { + case *SQLVal: + return &sqltypes.SQLVal{ + Type: sqltypes.ValType(node.Type), // eeeek! not a clean way of doing it + Val: node.Val, + } + } + return nil +} diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index 80cac6a7929..f60fec777f9 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -2,14 +2,14 @@ package engine import ( "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/proto/query" - "vitess.io/vitess/go/vt/sqlparser" + querypb "vitess.io/vitess/go/vt/proto/query" ) var _ Primitive = (*Projection)(nil) type Projection struct { - Exprs []*sqlparser.AliasedExpr + Cols []string + Exprs []sqltypes.Expr Input Primitive } @@ -25,16 +25,55 @@ func (p *Projection) GetTableName() string { return p.Input.GetTableName() } -func (p *Projection) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (*sqltypes.Result, error) { - panic("implement me") +func (p *Projection) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + result, err := p.Input.Execute(vcursor, bindVars, wantfields) + if err != nil { + return nil, err + } + + env := sqltypes.ExpressionEnv{ + BindVars: bindVars, + } + + if wantfields { + p.addFields(result) + } + var rows [][]sqltypes.Value + for _, row := range result.Rows { + env.Row = row + for _, exp := range p.Exprs { + result, err := exp.Evaluate(env) + if err != nil { + return nil, err + } + row = append(row, result.Value()) + } + rows = append(rows, row) + } + result.Rows = rows + return result, nil } -func (p *Projection) StreamExecute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { +func (p *Projection) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { panic("implement me") } -func (p *Projection) GetFields(vcursor VCursor, bindVars map[string]*query.BindVariable) (*sqltypes.Result, error) { - panic("implement me") +func (p *Projection) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + qr, err := p.Input.GetFields(vcursor, bindVars) + if err != nil { + return nil, err + } + p.addFields(qr) + return qr, nil +} + +func (p *Projection) addFields(qr *sqltypes.Result) { + for _, col := range p.Cols { + qr.Fields = append(qr.Fields, &querypb.Field{ + Name: col, + Type: querypb.Type_INT64, + }) + } } func (p *Projection) Inputs() []Primitive { diff --git a/go/vt/vtgate/engine/projection_test.go b/go/vt/vtgate/engine/projection_test.go new file mode 100644 index 00000000000..9f008eaa38c --- /dev/null +++ b/go/vt/vtgate/engine/projection_test.go @@ -0,0 +1,30 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "testing" +) + +func TestEvaluate(t *testing.T) { + //statement, _ := sqlparser.Parse("select 42") + //sel := statement.(*sqlparser.Select) + //exp := sel.SelectExprs[0].(*sqlparser.AliasedExpr) + // + //result := Evaluate(exp.Expr, map[string]*query.BindVariable{}, []sqltypes.Value{}) + //fmt.Println(result) +} diff --git a/go/vt/vtgate/engine/singlerow.go b/go/vt/vtgate/engine/singlerow.go index 16862a5d3d3..8d54f25ec6a 100644 --- a/go/vt/vtgate/engine/singlerow.go +++ b/go/vt/vtgate/engine/singlerow.go @@ -29,22 +29,22 @@ type SingleRow struct { } // RouteType returns a description of the query routing type used by the primitive -func (s SingleRow) RouteType() string { +func (s *SingleRow) RouteType() string { return "" } // GetKeyspaceName specifies the Keyspace that this primitive routes to. -func (s SingleRow) GetKeyspaceName() string { +func (s *SingleRow) GetKeyspaceName() string { return "" } // GetTableName specifies the table that this primitive routes to. -func (s SingleRow) GetTableName() string { +func (s *SingleRow) GetTableName() string { return "" } // Execute performs a non-streaming exec. -func (s SingleRow) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (*sqltypes.Result, error) { +func (s *SingleRow) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (*sqltypes.Result, error) { result := sqltypes.Result{ Fields: nil, RowsAffected: 0, @@ -57,7 +57,7 @@ func (s SingleRow) Execute(vcursor VCursor, bindVars map[string]*query.BindVaria } // StreamExecute performs a streaming exec. -func (s SingleRow) StreamExecute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { +func (s *SingleRow) StreamExecute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { result := sqltypes.Result{ Fields: nil, RowsAffected: 0, @@ -70,11 +70,11 @@ func (s SingleRow) StreamExecute(vcursor VCursor, bindVars map[string]*query.Bin } // GetFields fetches the field info. -func (s SingleRow) GetFields(vcursor VCursor, bindVars map[string]*query.BindVariable) (*sqltypes.Result, error) { +func (s *SingleRow) GetFields(vcursor VCursor, bindVars map[string]*query.BindVariable) (*sqltypes.Result, error) { return &sqltypes.Result{}, nil } -func (s SingleRow) description() PrimitiveDescription { +func (s *SingleRow) description() PrimitiveDescription { return PrimitiveDescription{ OperatorType: "SingleRow", } diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 86b4ffdaf59..56fafe2670e 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -228,14 +228,16 @@ func TestSelectLastInsertId(t *testing.T) { defer QueryLogger.Unsubscribe(logChan) sql := "select last_insert_id()" - _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + masterSession.LastInsertId = 42 + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + wantResult := &sqltypes.Result{ + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt64(42), + }}, + } require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__lastInsertId as `last_insert_id()` from dual", - BindVariables: map[string]*querypb.BindVariable{"__lastInsertId": sqltypes.Uint64BindVariable(52)}, - }} - - assert.Equal(t, wantQueries, sbc1.Queries) + utils.MustMatch(t, result, wantResult, "Mismatch") + assert.Empty(t, sbc1.Queries) } func TestSelectUserDefindVariable(t *testing.T) { diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 42670cf0e6d..96b0960efae 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -20,6 +20,8 @@ import ( "errors" "fmt" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" ) @@ -74,9 +76,9 @@ func buildSelectPlan(stmt sqlparser.Statement, vschema ContextVSchema) (engine.P // of the above trees to make it discard unwanted rows. func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab) error { if checkForDual(sel) && outer == nil { - exprs := make([]*sqlparser.AliasedExpr, len(sel.SelectExprs)) + exprs := make([]sqltypes.Expr, len(sel.SelectExprs)) for i, e := range sel.SelectExprs { - exprs[i] = e.(*sqlparser.AliasedExpr) + exprs[i] = sqlparser.Convert(e.(*sqlparser.AliasedExpr).Expr) } pb.bldr = &vtgateExecution{ exprs, diff --git a/go/vt/vtgate/planbuilder/vtgate_execution.go b/go/vt/vtgate/planbuilder/vtgate_execution.go index 9c98ef6382c..b92c2a6c7d7 100644 --- a/go/vt/vtgate/planbuilder/vtgate_execution.go +++ b/go/vt/vtgate/planbuilder/vtgate_execution.go @@ -1,6 +1,7 @@ package planbuilder import ( + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" ) @@ -8,7 +9,7 @@ import ( var _ builder = (*vtgateExecution)(nil) type vtgateExecution struct { - Exprs []*sqlparser.AliasedExpr + Exprs []sqltypes.Expr } func (v *vtgateExecution) Order() int { From 2f05ccc47e8dbc7fb7ff1e78bfbc08cab2f0c590 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Wed, 8 Apr 2020 22:24:43 -0700 Subject: [PATCH 04/23] Support Expression evalution for BindVariables and literals Signed-off-by: Saif Alharthi --- go/sqltypes/expressions.go | 113 ++++++++++++------- go/sqltypes/expressions_test.go | 32 ++++-- go/vt/sqlparser/expression_converter.go | 8 +- go/vt/vtgate/engine/projection.go | 1 + go/vt/vtgate/planbuilder/select.go | 3 + go/vt/vtgate/planbuilder/vtgate_execution.go | 2 + 6 files changed, 102 insertions(+), 57 deletions(-) diff --git a/go/sqltypes/expressions.go b/go/sqltypes/expressions.go index 57df54149b3..f9032ecf04d 100644 --- a/go/sqltypes/expressions.go +++ b/go/sqltypes/expressions.go @@ -39,74 +39,103 @@ type Expr interface { Evaluate(env ExpressionEnv) (EvalResult, error) } -var _ Expr = (*SQLVal)(nil) +var _ Expr = (*LiteralInt)(nil) +var _ Expr = (*BindVariable)(nil) -type SQLVal struct { - Type ValType - Val []byte +type LiteralInt struct { + Val []byte } -func (s *SQLVal) Evaluate(env ExpressionEnv) (EvalResult, error) { - switch s.Type { - case IntVal: - ival, err := strconv.ParseInt(string(s.Val), 10, 64) - if err != nil { - ival = 0 - } - return numeric{typ: Int64, ival: ival}, nil - case ValArg: - val := env.BindVars[string(s.Val[1:])] - ival, err := strconv.ParseInt(string(val.Value), 10, 64) - if err != nil { - ival = 0 - } - return numeric{typ: Int64, ival: ival}, nil +func (l *LiteralInt) Evaluate(env ExpressionEnv) (EvalResult, error) { + ival, err := strconv.ParseInt(string(l.Val), 10, 64) + if err != nil { + ival = 0 } - return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "not yet implemented") + return numeric{typ: Int64, ival: ival}, nil } -type ValType int - -const ( - StrVal = ValType(iota) - IntVal - FloatVal - HexNum - HexVal - ValArg - BitVal -) +type BindVariable struct { + Key string +} -var _ Expr = (*add)(nil) +func (b *BindVariable) Evaluate(env ExpressionEnv) (EvalResult, error) { + val, ok := env.BindVars[b.Key] + if !ok { + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Bind variable not found") + } + ival, err := strconv.ParseInt(string(val.Value), 10, 64) + if err != nil { + ival = 0 + } + return numeric{typ: Int64, ival: ival}, nil -type add struct { - l, r Expr } -type subtract struct { - l, r Expr +var _ Expr = (*Addition)(nil) +var _ Expr = (*Subtraction)(nil) +var _ Expr = (*Multiplication)(nil) +var _ Expr = (*Division)(nil) + +type Addition struct { + Left, Right Expr } -func (a *add) Evaluate(env ExpressionEnv) (EvalResult, error) { - lVal, err := a.l.Evaluate(env) +func (a *Addition) Evaluate(env ExpressionEnv) (EvalResult, error) { + lVal, err := a.Left.Evaluate(env) if err != nil { return EvalResult{}, err } - rVal, err := a.r.Evaluate(env) + rVal, err := a.Right.Evaluate(env) if err != nil { return EvalResult{}, err } return addNumericWithError(lVal, rVal) } -func (s *subtract) Evaluate(env ExpressionEnv) (EvalResult, error) { - lVal, err := s.l.Evaluate(env) +type Subtraction struct { + Left, Right Expr +} + +func (s *Subtraction) Evaluate(env ExpressionEnv) (EvalResult, error) { + lVal, err := s.Left.Evaluate(env) if err != nil { return EvalResult{}, err } - rVal, err := s.r.Evaluate(env) + rVal, err := s.Right.Evaluate(env) if err != nil { return EvalResult{}, err } return subtractNumericWithError(lVal, rVal) } + +type Multiplication struct { + Left, Right Expr +} + +func (m *Multiplication) Evaluate(env ExpressionEnv) (EvalResult, error) { + lVal, err := m.Left.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + rVal, err := m.Right.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + return multiplyNumericWithError(lVal, rVal) +} + +type Division struct { + Left, Right Expr +} + +func (d *Division) Evaluate(env ExpressionEnv) (EvalResult, error) { + lVal, err := d.Left.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + rVal, err := d.Right.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + return divideNumericWithError(lVal, rVal) +} diff --git a/go/sqltypes/expressions_test.go b/go/sqltypes/expressions_test.go index 87842ef6f4d..f571da05299 100644 --- a/go/sqltypes/expressions_test.go +++ b/go/sqltypes/expressions_test.go @@ -40,17 +40,27 @@ func TestEvaluate(t *testing.T) { }, { name: "40+2", - e: &add{i("40"), i("2")}, + e: &Addition{i("40"), i("2")}, expected: NewInt64(42), }, { name: "40-2", - e: &subtract{i("40"), i("2")}, + e: &Subtraction{i("40"), i("2")}, expected: NewInt64(38), }, + { + name: "40*2", + e: &Multiplication{i("40"), i("2")}, + expected: NewInt64(80), + }, + { + name: "40/2", + e: &Division{i("40"), i("2")}, + expected: NewFloat64(20), + }, { name: "Bind Variable", - e: b(":exp"), + e: b("exp"), expected: NewInt64(66), }, } @@ -59,7 +69,7 @@ func TestEvaluate(t *testing.T) { t.Run(test.name, func(t *testing.T) { env := ExpressionEnv{ BindVars: map[string]*querypb.BindVariable{ - ":exp": Int64BindVariable(66), + "exp": Int64BindVariable(66), }, Row: nil, } @@ -72,16 +82,14 @@ func TestEvaluate(t *testing.T) { } -func i(in string) *SQLVal { - return &SQLVal{ - Type: IntVal, - Val: []byte(in), +func i(in string) Expr { + return &LiteralInt{ + Val: []byte(in), } } -func b(in string) *SQLVal { - return &SQLVal{ - Type: ValArg, - Val: []byte(in), +func b(in string) Expr { + return &BindVariable{ + Key: in, } } diff --git a/go/vt/sqlparser/expression_converter.go b/go/vt/sqlparser/expression_converter.go index 02288f393e9..c4e7e5d8e2f 100644 --- a/go/vt/sqlparser/expression_converter.go +++ b/go/vt/sqlparser/expression_converter.go @@ -5,9 +5,11 @@ import "vitess.io/vitess/go/sqltypes" func Convert(e Expr) sqltypes.Expr { switch node := e.(type) { case *SQLVal: - return &sqltypes.SQLVal{ - Type: sqltypes.ValType(node.Type), // eeeek! not a clean way of doing it - Val: node.Val, + switch node.Type { + case IntVal: + return &sqltypes.LiteralInt{Val: node.Val} + case ValArg: + return &sqltypes.BindVariable{Key: string(node.Val[1:])} } } return nil diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index f60fec777f9..29e08f177ff 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -85,6 +85,7 @@ func (p *Projection) description() PrimitiveDescription { OperatorType: "Projection", Other: map[string]interface{}{ "Expressions": p.Exprs, + "Columns": p.Cols, }, } } diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 96b0960efae..189f46fef5b 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -77,11 +77,14 @@ func buildSelectPlan(stmt sqlparser.Statement, vschema ContextVSchema) (engine.P func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab) error { if checkForDual(sel) && outer == nil { exprs := make([]sqltypes.Expr, len(sel.SelectExprs)) + cols := make([]string, len(sel.SelectExprs)) for i, e := range sel.SelectExprs { exprs[i] = sqlparser.Convert(e.(*sqlparser.AliasedExpr).Expr) + cols[i] = e.(*sqlparser.AliasedExpr).As.String() } pb.bldr = &vtgateExecution{ exprs, + cols, } return nil } diff --git a/go/vt/vtgate/planbuilder/vtgate_execution.go b/go/vt/vtgate/planbuilder/vtgate_execution.go index b92c2a6c7d7..5785c2eaa89 100644 --- a/go/vt/vtgate/planbuilder/vtgate_execution.go +++ b/go/vt/vtgate/planbuilder/vtgate_execution.go @@ -10,6 +10,7 @@ var _ builder = (*vtgateExecution)(nil) type vtgateExecution struct { Exprs []sqltypes.Expr + Cols []string } func (v *vtgateExecution) Order() int { @@ -75,6 +76,7 @@ func (v *vtgateExecution) SupplyWeightString(colNumber int) (weightcolNumber int func (v *vtgateExecution) Primitive() engine.Primitive { return &engine.Projection{ Exprs: v.Exprs, + Cols: v.Cols, Input: &engine.SingleRow{}, } } From 8a546b8f54034742d933558f823025ff3274eae4 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 9 Apr 2020 12:01:56 +0200 Subject: [PATCH 05/23] Refactoring of expressions Signed-off-by: Andres Taylor --- go/sqltypes/arithmetic.go | 3 +- go/sqltypes/expressions.go | 144 +++++++++++++++----------------- go/sqltypes/expressions_test.go | 33 ++++++-- 3 files changed, 92 insertions(+), 88 deletions(-) diff --git a/go/sqltypes/arithmetic.go b/go/sqltypes/arithmetic.go index 88e695ec05d..5a314cbe474 100644 --- a/go/sqltypes/arithmetic.go +++ b/go/sqltypes/arithmetic.go @@ -34,7 +34,6 @@ type numeric struct { ival int64 uval uint64 fval float64 - err error } var zeroBytes = []byte("0") @@ -110,7 +109,7 @@ func Multiply(v1, v2 Value) (Value, error) { return castFromNumeric(lresult, lresult.typ), nil } -// Float Division for MySQL. Replicates behavior of "/" operator +// Divide (Float) for MySQL. Replicates behavior of "/" operator func Divide(v1, v2 Value) (Value, error) { if v1.IsNull() || v2.IsNull() { return NULL, nil diff --git a/go/sqltypes/expressions.go b/go/sqltypes/expressions.go index f9032ecf04d..f49683f86d2 100644 --- a/go/sqltypes/expressions.go +++ b/go/sqltypes/expressions.go @@ -24,28 +24,70 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) -type ExpressionEnv struct { - BindVars map[string]*querypb.BindVariable - Row []Value -} +type ( + + //ExpressionEnv contains the environment that the expression + //evaluates in, such as the current row and bindvars + ExpressionEnv struct { + BindVars map[string]*querypb.BindVariable + Row []Value + } + + // We use this type alias so we don't have to expose the private struct numeric + EvalResult = numeric + + // Expr is the interface that all evaluating expressions must implement + Expr interface { + Evaluate(env ExpressionEnv) (EvalResult, error) + } + + //BinaryExpr allows binary expressions to not have to evaluate child expressions - this is done by the BinaryOp + BinaryExpr interface { + Evaluate(left, right EvalResult) (EvalResult, error) + } + + // Expressions + LiteralInt struct{ Val []byte } + BindVariable struct{ Key string } + BinaryOp struct { + Expr BinaryExpr + Left, Right Expr + } -type EvalResult = numeric + // Binary ops + Addition struct{} + Subtraction struct{} + Multiplication struct{} + Division struct{} +) +//Value allows for retrieval of the value we need func (e EvalResult) Value() Value { return castFromNumeric(e, e.typ) } -type Expr interface { - Evaluate(env ExpressionEnv) (EvalResult, error) -} - var _ Expr = (*LiteralInt)(nil) var _ Expr = (*BindVariable)(nil) - -type LiteralInt struct { - Val []byte +var _ Expr = (*BinaryOp)(nil) +var _ BinaryExpr = (*Addition)(nil) +var _ BinaryExpr = (*Subtraction)(nil) +var _ BinaryExpr = (*Multiplication)(nil) +var _ BinaryExpr = (*Division)(nil) + +//Evaluate implements the Expr interface +func (b BinaryOp) Evaluate(env ExpressionEnv) (EvalResult, error) { + lVal, err := b.Left.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + rVal, err := b.Right.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + return b.Expr.Evaluate(lVal, rVal) } +//Evaluate implements the Expr interface func (l *LiteralInt) Evaluate(env ExpressionEnv) (EvalResult, error) { ival, err := strconv.ParseInt(string(l.Val), 10, 64) if err != nil { @@ -54,10 +96,7 @@ func (l *LiteralInt) Evaluate(env ExpressionEnv) (EvalResult, error) { return numeric{typ: Int64, ival: ival}, nil } -type BindVariable struct { - Key string -} - +//Evaluate implements the Expr interface func (b *BindVariable) Evaluate(env ExpressionEnv) (EvalResult, error) { val, ok := env.BindVars[b.Key] if !ok { @@ -71,71 +110,22 @@ func (b *BindVariable) Evaluate(env ExpressionEnv) (EvalResult, error) { } -var _ Expr = (*Addition)(nil) -var _ Expr = (*Subtraction)(nil) -var _ Expr = (*Multiplication)(nil) -var _ Expr = (*Division)(nil) - -type Addition struct { - Left, Right Expr +//Evaluate implements the BinaryOp interface +func (a *Addition) Evaluate(left, right EvalResult) (EvalResult, error) { + return addNumericWithError(left, right) } -func (a *Addition) Evaluate(env ExpressionEnv) (EvalResult, error) { - lVal, err := a.Left.Evaluate(env) - if err != nil { - return EvalResult{}, err - } - rVal, err := a.Right.Evaluate(env) - if err != nil { - return EvalResult{}, err - } - return addNumericWithError(lVal, rVal) -} - -type Subtraction struct { - Left, Right Expr -} - -func (s *Subtraction) Evaluate(env ExpressionEnv) (EvalResult, error) { - lVal, err := s.Left.Evaluate(env) - if err != nil { - return EvalResult{}, err - } - rVal, err := s.Right.Evaluate(env) - if err != nil { - return EvalResult{}, err - } - return subtractNumericWithError(lVal, rVal) +//Evaluate implements the BinaryOp interface +func (s *Subtraction) Evaluate(left, right EvalResult) (EvalResult, error) { + return subtractNumericWithError(left, right) } -type Multiplication struct { - Left, Right Expr +//Evaluate implements the BinaryOp interface +func (m *Multiplication) Evaluate(left, right EvalResult) (EvalResult, error) { + return multiplyNumericWithError(left, right) } -func (m *Multiplication) Evaluate(env ExpressionEnv) (EvalResult, error) { - lVal, err := m.Left.Evaluate(env) - if err != nil { - return EvalResult{}, err - } - rVal, err := m.Right.Evaluate(env) - if err != nil { - return EvalResult{}, err - } - return multiplyNumericWithError(lVal, rVal) -} - -type Division struct { - Left, Right Expr -} - -func (d *Division) Evaluate(env ExpressionEnv) (EvalResult, error) { - lVal, err := d.Left.Evaluate(env) - if err != nil { - return EvalResult{}, err - } - rVal, err := d.Right.Evaluate(env) - if err != nil { - return EvalResult{}, err - } - return divideNumericWithError(lVal, rVal) +//Evaluate implements the BinaryOp interface +func (d *Division) Evaluate(left, right EvalResult) (EvalResult, error) { + return divideNumericWithError(left, right) } diff --git a/go/sqltypes/expressions_test.go b/go/sqltypes/expressions_test.go index f571da05299..e8ec262a215 100644 --- a/go/sqltypes/expressions_test.go +++ b/go/sqltypes/expressions_test.go @@ -28,7 +28,6 @@ func TestEvaluate(t *testing.T) { type testCase struct { name string e Expr - bindvars map[string]*querypb.BindVariable expected Value } @@ -39,23 +38,39 @@ func TestEvaluate(t *testing.T) { expected: NewInt64(42), }, { - name: "40+2", - e: &Addition{i("40"), i("2")}, + name: "40+2", + e: BinaryOp{ + Expr: &Addition{}, + Left: i("40"), + Right: i("2"), + }, expected: NewInt64(42), }, { - name: "40-2", - e: &Subtraction{i("40"), i("2")}, + name: "40-2", + e: BinaryOp{ + Expr: &Subtraction{}, + Left: i("40"), + Right: i("2"), + }, expected: NewInt64(38), }, { - name: "40*2", - e: &Multiplication{i("40"), i("2")}, + name: "40*2", + e: BinaryOp{ + Expr: &Multiplication{}, + Left: i("40"), + Right: i("2"), + }, expected: NewInt64(80), }, { - name: "40/2", - e: &Division{i("40"), i("2")}, + name: "40/2", + e: BinaryOp{ + Expr: &Division{}, + Left: i("40"), + Right: i("2"), + }, expected: NewFloat64(20), }, { From b4f8ddc3863dd0c6451898b35fab001b37beb66a Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 9 Apr 2020 12:28:12 +0200 Subject: [PATCH 06/23] Cleanup of test Signed-off-by: Andres Taylor --- go/sqltypes/expressions.go | 2 +- go/sqltypes/expressions_test.go | 93 +------------------------ go/vt/sqlparser/expression_converter.go | 37 ++++++++++ go/vt/sqlparser/expressions_test.go | 83 ++++++++++++++++++++++ 4 files changed, 122 insertions(+), 93 deletions(-) create mode 100644 go/vt/sqlparser/expressions_test.go diff --git a/go/sqltypes/expressions.go b/go/sqltypes/expressions.go index f49683f86d2..672783ee260 100644 --- a/go/sqltypes/expressions.go +++ b/go/sqltypes/expressions.go @@ -75,7 +75,7 @@ var _ BinaryExpr = (*Multiplication)(nil) var _ BinaryExpr = (*Division)(nil) //Evaluate implements the Expr interface -func (b BinaryOp) Evaluate(env ExpressionEnv) (EvalResult, error) { +func (b *BinaryOp) Evaluate(env ExpressionEnv) (EvalResult, error) { lVal, err := b.Left.Evaluate(env) if err != nil { return EvalResult{}, err diff --git a/go/sqltypes/expressions_test.go b/go/sqltypes/expressions_test.go index e8ec262a215..9d393eb4b1b 100644 --- a/go/sqltypes/expressions_test.go +++ b/go/sqltypes/expressions_test.go @@ -16,95 +16,4 @@ limitations under the License. package sqltypes -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - querypb "vitess.io/vitess/go/vt/proto/query" -) - -func TestEvaluate(t *testing.T) { - type testCase struct { - name string - e Expr - expected Value - } - - tests := []testCase{ - { - name: "42", - e: i("42"), - expected: NewInt64(42), - }, - { - name: "40+2", - e: BinaryOp{ - Expr: &Addition{}, - Left: i("40"), - Right: i("2"), - }, - expected: NewInt64(42), - }, - { - name: "40-2", - e: BinaryOp{ - Expr: &Subtraction{}, - Left: i("40"), - Right: i("2"), - }, - expected: NewInt64(38), - }, - { - name: "40*2", - e: BinaryOp{ - Expr: &Multiplication{}, - Left: i("40"), - Right: i("2"), - }, - expected: NewInt64(80), - }, - { - name: "40/2", - e: BinaryOp{ - Expr: &Division{}, - Left: i("40"), - Right: i("2"), - }, - expected: NewFloat64(20), - }, - { - name: "Bind Variable", - e: b("exp"), - expected: NewInt64(66), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - env := ExpressionEnv{ - BindVars: map[string]*querypb.BindVariable{ - "exp": Int64BindVariable(66), - }, - Row: nil, - } - r, err := test.e.Evaluate(env) - require.NoError(t, err) - result := castFromNumeric(r, r.typ) - assert.Equal(t, test.expected, result, "expected %s but got %s", test.expected.String(), result.String()) - }) - } - -} - -func i(in string) Expr { - return &LiteralInt{ - Val: []byte(in), - } -} - -func b(in string) Expr { - return &BindVariable{ - Key: in, - } -} +// these tests live in go/sqltypes/expressions_test.go diff --git a/go/vt/sqlparser/expression_converter.go b/go/vt/sqlparser/expression_converter.go index c4e7e5d8e2f..d04a62c69ac 100644 --- a/go/vt/sqlparser/expression_converter.go +++ b/go/vt/sqlparser/expression_converter.go @@ -1,7 +1,24 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package sqlparser import "vitess.io/vitess/go/sqltypes" +//Convert converts between AST expressions and executable expressions func Convert(e Expr) sqltypes.Expr { switch node := e.(type) { case *SQLVal: @@ -11,6 +28,26 @@ func Convert(e Expr) sqltypes.Expr { case ValArg: return &sqltypes.BindVariable{Key: string(node.Val[1:])} } + case *BinaryExpr: + var op sqltypes.BinaryExpr + switch node.Operator { + case PlusStr: + op = &sqltypes.Addition{} + case MinusStr: + op = &sqltypes.Subtraction{} + case MultStr: + op = &sqltypes.Multiplication{} + case DivStr: + op = &sqltypes.Division{} + default: + return nil + } + return &sqltypes.BinaryOp{ + Expr: op, + Left: Convert(node.Left), + Right: Convert(node.Right), + } + } return nil } diff --git a/go/vt/sqlparser/expressions_test.go b/go/vt/sqlparser/expressions_test.go new file mode 100644 index 00000000000..134f0a7c341 --- /dev/null +++ b/go/vt/sqlparser/expressions_test.go @@ -0,0 +1,83 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlparser + +import ( + "testing" + + "vitess.io/vitess/go/sqltypes" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +/* +These tests should in theory live in the sqltypes package but they live here so we can +exercise both expression conversion and evaluation in the same test file +*/ + +func TestEvaluate(t *testing.T) { + type testCase struct { + expression string + expected sqltypes.Value + } + + tests := []testCase{{ + expression: "42", + expected: sqltypes.NewInt64(42), + }, { + expression: "40+2", + expected: sqltypes.NewInt64(42), + }, { + expression: "40-2", + expected: sqltypes.NewInt64(38), + }, { + expression: "40*2", + expected: sqltypes.NewInt64(80), + }, { + expression: "40/2", + expected: sqltypes.NewFloat64(20), + }, { + expression: ":exp", + expected: sqltypes.NewInt64(66), + }} + + for _, test := range tests { + t.Run(test.expression, func(t *testing.T) { + // Given + stmt, err := Parse("select " + test.expression) + require.NoError(t, err) + astExpr := stmt.(*Select).SelectExprs[0].(*AliasedExpr).Expr + sqltypesExpr := Convert(astExpr) + require.NotNil(t, sqltypesExpr) + env := sqltypes.ExpressionEnv{ + BindVars: map[string]*querypb.BindVariable{ + "exp": sqltypes.Int64BindVariable(66), + }, + Row: nil, + } + + // When + r, err := sqltypesExpr.Evaluate(env) + + // Then + require.NoError(t, err) + assert.Equal(t, test.expected, r.Value(), "expected %s", test.expected.String()) + }) + } +} From 8e858ed720d3636642c4acd367a19943e8ba4058 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Thu, 9 Apr 2020 20:10:00 -0700 Subject: [PATCH 07/23] Fix up some tests for last_insert_ID, databaste(), and variables Signed-off-by: Saif Alharthi --- go/vt/vtgate/executor_select_test.go | 64 +++++++++++-------- go/vt/vtgate/plan_executor_select_test.go | 77 ++++++++++++++--------- 2 files changed, 85 insertions(+), 56 deletions(-) diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 56fafe2670e..3d728f589dc 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -222,7 +222,7 @@ func TestStreamBuffering(t *testing.T) { func TestSelectLastInsertId(t *testing.T) { masterSession.LastInsertId = 52 - executor, sbc1, _, _ := createExecutorEnv() + executor, _, _, _ := createExecutorEnv() executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) @@ -231,35 +231,41 @@ func TestSelectLastInsertId(t *testing.T) { masterSession.LastInsertId = 42 result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "last_insert_id()", Type: sqltypes.Int64}, + }, Rows: [][]sqltypes.Value{{ sqltypes.NewInt64(42), }}, } require.NoError(t, err) utils.MustMatch(t, result, wantResult, "Mismatch") - assert.Empty(t, sbc1.Queries) } func TestSelectUserDefindVariable(t *testing.T) { - executor, sbc1, _, _ := createExecutorEnv() + executor, _, _, _ := createExecutorEnv() executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) sql := "select @foo" masterSession = &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo"}, []interface{}{"bar"})} - _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__vtudvfoo as `@foo` from dual", - BindVariables: map[string]*querypb.BindVariable{"__vtudvfoo": sqltypes.StringBindVariable("bar")}, - }} - - assert.Equal(t, wantQueries, sbc1.Queries) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "@foo", Type: sqltypes.Int64}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt64(0), + }}, + } + require.NoError(t, err) + utils.MustMatch(t, result, wantResult, "Mismatch") } func TestFoundRows(t *testing.T) { - executor, sbc1, _, _ := createExecutorEnv() + executor, _, _, _ := createExecutorEnv() executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) @@ -269,14 +275,17 @@ func TestFoundRows(t *testing.T) { require.NoError(t, err) sql := "select found_rows()" - _, err = executorExec(executor, sql, map[string]*querypb.BindVariable{}) - require.NoError(t, err) - expected := &querypb.BoundQuery{ - Sql: "select :__vtfrows as `found_rows()` from dual", - BindVariables: map[string]*querypb.BindVariable{"__vtfrows": sqltypes.Uint64BindVariable(1)}, + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "found_rows()", Type: sqltypes.Int64}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt64(0), + }}, } - - assert.Equal(t, expected, sbc1.Queries[1]) + require.NoError(t, err) + utils.MustMatch(t, result, wantResult, "Mismatch") } func TestSelectLastInsertIdInUnion(t *testing.T) { @@ -363,26 +372,29 @@ func TestLastInsertIDInSubQueryExpression(t *testing.T) { } func TestSelectDatabase(t *testing.T) { - executor, sbc1, _, _ := createExecutorEnv() + executor, _, _, _ := createExecutorEnv() executor.normalize = true sql := "select database()" newSession := *masterSession session := NewSafeSession(&newSession) session.TargetString = "TestExecutor@master" - _, err := executor.Execute( + result, err := executor.Execute( context.Background(), "TestExecute", session, sql, map[string]*querypb.BindVariable{}) - + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "database()", Type: sqltypes.Int64}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt64(0), + }}, + } require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__vtdbname as `database()` from dual", - BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestExecutor")}, - }} + utils.MustMatch(t, result, wantResult, "Mismatch") - assert.Equal(t, wantQueries, sbc1.Queries) } func TestSelectBindvars(t *testing.T) { diff --git a/go/vt/vtgate/plan_executor_select_test.go b/go/vt/vtgate/plan_executor_select_test.go index 5d234e5da07..402f8b4c5de 100644 --- a/go/vt/vtgate/plan_executor_select_test.go +++ b/go/vt/vtgate/plan_executor_select_test.go @@ -22,6 +22,8 @@ import ( "strings" "testing" + "vitess.io/vitess/go/test/utils" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -221,42 +223,50 @@ func TestPlanStreamBuffering(t *testing.T) { func TestPlanSelectLastInsertId(t *testing.T) { masterSession.LastInsertId = 52 - executor, sbc1, _, _ := createExecutorEnvUsing(planAllTheThings) + executor, _, _, _ := createExecutorEnvUsing(planAllTheThings) executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) sql := "select last_insert_id()" - _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "last_insert_id()", Type: sqltypes.Int64}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt64(52), + }}, + } require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__lastInsertId as `last_insert_id()` from dual", - BindVariables: map[string]*querypb.BindVariable{"__lastInsertId": sqltypes.Uint64BindVariable(52)}, - }} + utils.MustMatch(t, result, wantResult, "Mismatch") - assert.Equal(t, wantQueries, sbc1.Queries) } func TestPlanSelectUserDefindVariable(t *testing.T) { - executor, sbc1, _, _ := createExecutorEnvUsing(planAllTheThings) + executor, _, _, _ := createExecutorEnvUsing(planAllTheThings) executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) sql := "select @foo" masterSession = &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo"}, []interface{}{"bar"})} - _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__vtudvfoo as `@foo` from dual", - BindVariables: map[string]*querypb.BindVariable{"__vtudvfoo": sqltypes.StringBindVariable("bar")}, - }} - - assert.Equal(t, wantQueries, sbc1.Queries) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "@foo", Type: sqltypes.Int64}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt64(0), + }}, + } + require.NoError(t, err) + utils.MustMatch(t, result, wantResult, "Mismatch") } func TestPlanFoundRows(t *testing.T) { - executor, sbc1, _, _ := createExecutorEnvUsing(planAllTheThings) + executor, _, _, _ := createExecutorEnvUsing(planAllTheThings) executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) @@ -266,14 +276,18 @@ func TestPlanFoundRows(t *testing.T) { require.NoError(t, err) sql := "select found_rows()" - _, err = executorExec(executor, sql, map[string]*querypb.BindVariable{}) - require.NoError(t, err) - expected := &querypb.BoundQuery{ - Sql: "select :__vtfrows as `found_rows()` from dual", - BindVariables: map[string]*querypb.BindVariable{"__vtfrows": sqltypes.Uint64BindVariable(1)}, + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "found_rows()", Type: sqltypes.Int64}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt64(0), + }}, } + require.NoError(t, err) + utils.MustMatch(t, result, wantResult, "Mismatch") - assert.Equal(t, expected, sbc1.Queries[1]) } func TestPlanSelectLastInsertIdInUnion(t *testing.T) { @@ -360,26 +374,29 @@ func TestPlanLastInsertIDInSubQueryExpression(t *testing.T) { } func TestPlanSelectDatabase(t *testing.T) { - executor, sbc1, _, _ := createExecutorEnvUsing(planAllTheThings) + executor, _, _, _ := createExecutorEnvUsing(planAllTheThings) executor.normalize = true sql := "select database()" newSession := *masterSession session := NewSafeSession(&newSession) session.TargetString = "TestExecutor@master" - _, err := executor.Execute( + result, err := executor.Execute( context.Background(), "TestExecute", session, sql, map[string]*querypb.BindVariable{}) - + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "database()", Type: sqltypes.Int64}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt64(0), + }}, + } require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__vtdbname as `database()` from dual", - BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestExecutor")}, - }} + utils.MustMatch(t, result, wantResult, "Mismatch") - assert.Equal(t, wantQueries, sbc1.Queries) } func TestPlanSelectBindvars(t *testing.T) { From 3f1fe3b0c8bac8e56bfb5cf60cd8d325ff549425 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Thu, 9 Apr 2020 22:00:34 -0700 Subject: [PATCH 08/23] Fixed expression evaluation for Union and virtual tables Signed-off-by: Saif Alharthi --- go/vt/vtgate/planbuilder/expr.go | 4 ++-- go/vt/vtgate/planbuilder/from.go | 2 +- go/vt/vtgate/planbuilder/select.go | 6 +++--- go/vt/vtgate/planbuilder/union.go | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/go/vt/vtgate/planbuilder/expr.go b/go/vt/vtgate/planbuilder/expr.go index e7a51664e6b..e0640be45bf 100644 --- a/go/vt/vtgate/planbuilder/expr.go +++ b/go/vt/vtgate/planbuilder/expr.go @@ -104,7 +104,7 @@ func (pb *primitiveBuilder) findOrigin(expr sqlparser.Expr) (pullouts []*pullout spb := newPrimitiveBuilder(pb.vschema, pb.jt) switch stmt := node.Select.(type) { case *sqlparser.Select: - if err := spb.processSelect(stmt, pb.st); err != nil { + if err := spb.processSelect(stmt, pb.st, false); err != nil { return false, err } case *sqlparser.Union: @@ -230,7 +230,7 @@ func (pb *primitiveBuilder) finalizeUnshardedDMLSubqueries(nodes ...sqlparser.SQ return true, nil } spb := newPrimitiveBuilder(pb.vschema, pb.jt) - if err := spb.processSelect(nodeType, pb.st); err != nil { + if err := spb.processSelect(nodeType, pb.st, false); err != nil { samePlan = false return false, err } diff --git a/go/vt/vtgate/planbuilder/from.go b/go/vt/vtgate/planbuilder/from.go index 772d224f13a..b7013d88947 100644 --- a/go/vt/vtgate/planbuilder/from.go +++ b/go/vt/vtgate/planbuilder/from.go @@ -97,7 +97,7 @@ func (pb *primitiveBuilder) processAliasedTable(tableExpr *sqlparser.AliasedTabl spb := newPrimitiveBuilder(pb.vschema, pb.jt) switch stmt := expr.Select.(type) { case *sqlparser.Select: - if err := spb.processSelect(stmt, nil); err != nil { + if err := spb.processSelect(stmt, nil, false); err != nil { return err } case *sqlparser.Union: diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 189f46fef5b..d59f463c229 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -30,7 +30,7 @@ import ( func buildSelectPlan(stmt sqlparser.Statement, vschema ContextVSchema) (engine.Primitive, error) { sel := stmt.(*sqlparser.Select) pb := newPrimitiveBuilder(vschema, newJointab(sqlparser.GetBindvars(sel))) - if err := pb.processSelect(sel, nil); err != nil { + if err := pb.processSelect(sel, nil, true); err != nil { return nil, err } if err := pb.bldr.Wireup(pb.bldr, pb.jt); err != nil { @@ -74,8 +74,8 @@ func buildSelectPlan(stmt sqlparser.Statement, vschema ContextVSchema) (engine.P // The LIMIT clause is the last construct of a query. If it cannot be // pushed into a route, then a primitive is created on top of any // of the above trees to make it discard unwanted rows. -func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab) error { - if checkForDual(sel) && outer == nil { +func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab, canRunLocally bool) error { + if checkForDual(sel) && outer == nil && canRunLocally { exprs := make([]sqltypes.Expr, len(sel.SelectExprs)) cols := make([]string, len(sel.SelectExprs)) for i, e := range sel.SelectExprs { diff --git a/go/vt/vtgate/planbuilder/union.go b/go/vt/vtgate/planbuilder/union.go index 71fb9e4182d..a6f7de1b472 100644 --- a/go/vt/vtgate/planbuilder/union.go +++ b/go/vt/vtgate/planbuilder/union.go @@ -62,7 +62,7 @@ func (pb *primitiveBuilder) processPart(part sqlparser.SelectStatement, outer *s case *sqlparser.Union: return pb.processUnion(part, outer) case *sqlparser.Select: - return pb.processSelect(part, outer) + return pb.processSelect(part, outer, false) case *sqlparser.ParenSelect: return pb.processPart(part.Select, outer) } From 1b692d22c611ba55d41be727c9df0ce68c67d283 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Thu, 9 Apr 2020 22:31:57 -0700 Subject: [PATCH 09/23] Fix plan tests and modify descriptions Signed-off-by: Saif Alharthi --- go/sqltypes/expressions.go | 30 +++++++++++++++++++ go/vt/vtgate/engine/projection.go | 6 +++- .../planbuilder/testdata/from_cases.txt | 11 +++---- .../planbuilder/testdata/select_cases.txt | 23 ++++++-------- 4 files changed, 48 insertions(+), 22 deletions(-) diff --git a/go/sqltypes/expressions.go b/go/sqltypes/expressions.go index 672783ee260..ae2f0293b3b 100644 --- a/go/sqltypes/expressions.go +++ b/go/sqltypes/expressions.go @@ -39,11 +39,13 @@ type ( // Expr is the interface that all evaluating expressions must implement Expr interface { Evaluate(env ExpressionEnv) (EvalResult, error) + String() string } //BinaryExpr allows binary expressions to not have to evaluate child expressions - this is done by the BinaryOp BinaryExpr interface { Evaluate(left, right EvalResult) (EvalResult, error) + String() string } // Expressions @@ -61,6 +63,34 @@ type ( Division struct{} ) +func (d *Division) String() string { + return "/" +} + +func (m *Multiplication) String() string { + return "*" +} + +func (s *Subtraction) String() string { + return "-" +} + +func (a *Addition) String() string { + return "+" +} + +func (b *BinaryOp) String() string { + return b.Left.String() + " " + b.Expr.String() + " " + b.Right.String() +} + +func (b *BindVariable) String() string { + return ":" + b.Key +} + +func (l *LiteralInt) String() string { + return string(l.Val) +} + //Value allows for retrieval of the value we need func (e EvalResult) Value() Value { return castFromNumeric(e, e.typ) diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index 29e08f177ff..8b713be8d7f 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -81,10 +81,14 @@ func (p *Projection) Inputs() []Primitive { } func (p *Projection) description() PrimitiveDescription { + var exprs []string + for _, e := range p.Exprs { + exprs = append(exprs, e.String()) + } return PrimitiveDescription{ OperatorType: "Projection", Other: map[string]interface{}{ - "Expressions": p.Exprs, + "Expressions": exprs, "Columns": p.Cols, }, } diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.txt b/go/vt/vtgate/planbuilder/testdata/from_cases.txt index bd26660a646..be353439399 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.txt @@ -2051,14 +2051,11 @@ "Instructions": { "OperatorType": "Projection", "Variant": "", + "Columns": [ + "last_insert_id()" + ], "Expressions": [ - { - "Expr": { - "Type": 5, - "Val": "Ol9fbGFzdEluc2VydElk" - }, - "As": "last_insert_id()" - } + ":__lastInsertId" ], "Inputs": [ { diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index d0889f442a3..17afad5a7a5 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -413,14 +413,11 @@ "Instructions": { "OperatorType": "Projection", "Variant": "", + "Columns": [ + "database()" + ], "Expressions": [ - { - "Expr": { - "Type": 5, - "Val": "Ol9fdnRkYm5hbWU=" - }, - "As": "database()" - } + ":__vtdbname" ], "Inputs": [ { @@ -1362,18 +1359,16 @@ # testing SingleRow Projection "select 42" { + "QueryType": "SELECT", "Original": "select 42", "Instructions": { "OperatorType": "Projection", "Variant": "", + "Columns": [ + "" + ], "Expressions": [ - { - "Expr": { - "Type": 1, - "Val": "NDI=" - }, - "As": "" - } + "42" ], "Inputs": [ { From 230e34e0b5075c082361abb747c27e2479083873 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Sat, 11 Apr 2020 01:39:12 -0700 Subject: [PATCH 10/23] Supported evalution of column types Signed-off-by: Saif Alharthi --- go/sqltypes/arithmetic.go | 163 ++++++------ go/sqltypes/arithmetic_test.go | 254 +++++++++---------- go/sqltypes/expressions.go | 82 +++++- go/sqltypes/expressions_test.go | 2 +- go/vt/sqlparser/expressions_test.go | 14 +- go/vt/vtgate/endtoend/last_insert_id_test.go | 4 +- go/vt/vtgate/engine/projection.go | 11 +- go/vt/vtgate/planbuilder/select.go | 6 +- 8 files changed, 305 insertions(+), 231 deletions(-) diff --git a/go/sqltypes/arithmetic.go b/go/sqltypes/arithmetic.go index 5a314cbe474..4e94333ba53 100644 --- a/go/sqltypes/arithmetic.go +++ b/go/sqltypes/arithmetic.go @@ -29,13 +29,6 @@ import ( // numeric represents a numeric value extracted from // a Value, used for arithmetic operations. -type numeric struct { - typ querypb.Type - ival int64 - uval uint64 - fval float64 -} - var zeroBytes = []byte("0") // Add adds two values together @@ -360,68 +353,68 @@ func ToNative(v Value) (interface{}, error) { } // newNumeric parses a value and produces an Int64, Uint64 or Float64. -func newNumeric(v Value) (numeric, error) { +func newNumeric(v Value) (evalResult, error) { str := v.ToString() switch { case v.IsSigned(): ival, err := strconv.ParseInt(str, 10, 64) if err != nil { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return numeric{ival: ival, typ: Int64}, nil + return evalResult{ival: ival, typ: Int64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(str, 10, 64) if err != nil { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return numeric{uval: uval, typ: Uint64}, nil + return evalResult{uval: uval, typ: Uint64}, nil case v.IsFloat(): fval, err := strconv.ParseFloat(str, 64) if err != nil { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return numeric{fval: fval, typ: Float64}, nil + return evalResult{fval: fval, typ: Float64}, nil } // For other types, do best effort. if ival, err := strconv.ParseInt(str, 10, 64); err == nil { - return numeric{ival: ival, typ: Int64}, nil + return evalResult{ival: ival, typ: Int64}, nil } if fval, err := strconv.ParseFloat(str, 64); err == nil { - return numeric{fval: fval, typ: Float64}, nil + return evalResult{fval: fval, typ: Float64}, nil } - return numeric{ival: 0, typ: Int64}, nil + return evalResult{ival: 0, typ: Int64}, nil } // newIntegralNumeric parses a value and produces an Int64 or Uint64. -func newIntegralNumeric(v Value) (numeric, error) { +func newIntegralNumeric(v Value) (evalResult, error) { str := v.ToString() switch { case v.IsSigned(): ival, err := strconv.ParseInt(str, 10, 64) if err != nil { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return numeric{ival: ival, typ: Int64}, nil + return evalResult{ival: ival, typ: Int64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(str, 10, 64) if err != nil { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return numeric{uval: uval, typ: Uint64}, nil + return evalResult{uval: uval, typ: Uint64}, nil } // For other types, do best effort. if ival, err := strconv.ParseInt(str, 10, 64); err == nil { - return numeric{ival: ival, typ: Int64}, nil + return evalResult{ival: ival, typ: Int64}, nil } if uval, err := strconv.ParseUint(str, 10, 64); err == nil { - return numeric{uval: uval, typ: Uint64}, nil + return evalResult{uval: uval, typ: Uint64}, nil } - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str) } -func addNumeric(v1, v2 numeric) numeric { +func addNumeric(v1, v2 evalResult) evalResult { v1, v2 = prioritize(v1, v2) switch v1.typ { case Int64: @@ -439,7 +432,7 @@ func addNumeric(v1, v2 numeric) numeric { panic("unreachable") } -func addNumericWithError(v1, v2 numeric) (numeric, error) { +func addNumericWithError(v1, v2 evalResult) (evalResult, error) { v1, v2 = prioritize(v1, v2) switch v1.typ { case Int64: @@ -457,7 +450,7 @@ func addNumericWithError(v1, v2 numeric) (numeric, error) { panic("unreachable") } -func subtractNumericWithError(v1, v2 numeric) (numeric, error) { +func subtractNumericWithError(v1, v2 evalResult) (evalResult, error) { switch v1.typ { case Int64: switch v2.typ { @@ -483,7 +476,7 @@ func subtractNumericWithError(v1, v2 numeric) (numeric, error) { panic("unreachable") } -func multiplyNumericWithError(v1, v2 numeric) (numeric, error) { +func multiplyNumericWithError(v1, v2 evalResult) (evalResult, error) { v1, v2 = prioritize(v1, v2) switch v1.typ { case Int64: @@ -501,7 +494,7 @@ func multiplyNumericWithError(v1, v2 numeric) (numeric, error) { panic("unreachable") } -func divideNumericWithError(v1, v2 numeric) (numeric, error) { +func divideNumericWithError(v1, v2 evalResult) (evalResult, error) { switch v1.typ { case Int64: return floatDivideAnyWithError(float64(v1.ival), v2) @@ -517,7 +510,7 @@ func divideNumericWithError(v1, v2 numeric) (numeric, error) { // prioritize reorders the input parameters // to be Float64, Uint64, Int64. -func prioritize(v1, v2 numeric) (altv1, altv2 numeric) { +func prioritize(v1, v2 evalResult) (altv1, altv2 evalResult) { switch v1.typ { case Int64: if v2.typ == Uint64 || v2.typ == Float64 { @@ -531,7 +524,7 @@ func prioritize(v1, v2 numeric) (altv1, altv2 numeric) { return v1, v2 } -func intPlusInt(v1, v2 int64) numeric { +func intPlusInt(v1, v2 int64) evalResult { result := v1 + v2 if v1 > 0 && v2 > 0 && result < 0 { goto overflow @@ -539,61 +532,61 @@ func intPlusInt(v1, v2 int64) numeric { if v1 < 0 && v2 < 0 && result > 0 { goto overflow } - return numeric{typ: Int64, ival: result} + return evalResult{typ: Int64, ival: result} overflow: - return numeric{typ: Float64, fval: float64(v1) + float64(v2)} + return evalResult{typ: Float64, fval: float64(v1) + float64(v2)} } -func intPlusIntWithError(v1, v2 int64) (numeric, error) { +func intPlusIntWithError(v1, v2 int64) (evalResult, error) { result := v1 + v2 if (result > v1) != (v2 > 0) { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2) } - return numeric{typ: Int64, ival: result}, nil + return evalResult{typ: Int64, ival: result}, nil } -func intMinusIntWithError(v1, v2 int64) (numeric, error) { +func intMinusIntWithError(v1, v2 int64) (evalResult, error) { result := v1 - v2 if (result < v1) != (v2 > 0) { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v - %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v - %v", v1, v2) } - return numeric{typ: Int64, ival: result}, nil + return evalResult{typ: Int64, ival: result}, nil } -func intTimesIntWithError(v1, v2 int64) (numeric, error) { +func intTimesIntWithError(v1, v2 int64) (evalResult, error) { result := v1 * v2 if v1 != 0 && result/v1 != v2 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v * %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v * %v", v1, v2) } - return numeric{typ: Int64, ival: result}, nil + return evalResult{typ: Int64, ival: result}, nil } -func intMinusUintWithError(v1 int64, v2 uint64) (numeric, error) { +func intMinusUintWithError(v1 int64, v2 uint64) (evalResult, error) { if v1 < 0 || v1 < int64(v2) { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) } return uintMinusUintWithError(uint64(v1), v2) } -func uintPlusInt(v1 uint64, v2 int64) numeric { +func uintPlusInt(v1 uint64, v2 int64) evalResult { return uintPlusUint(v1, uint64(v2)) } -func uintPlusIntWithError(v1 uint64, v2 int64) (numeric, error) { +func uintPlusIntWithError(v1 uint64, v2 int64) (evalResult, error) { if v2 < 0 && v1 < uint64(v2) { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) } // convert to int -> uint is because for numeric operators (such as + or -) // where one of the operands is an unsigned integer, the result is unsigned by default. return uintPlusUintWithError(v1, uint64(v2)) } -func uintMinusIntWithError(v1 uint64, v2 int64) (numeric, error) { +func uintMinusIntWithError(v1 uint64, v2 int64) (evalResult, error) { if int64(v1) < v2 && v2 > 0 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) } // uint - (- int) = uint + int if v2 < 0 { @@ -602,77 +595,77 @@ func uintMinusIntWithError(v1 uint64, v2 int64) (numeric, error) { return uintMinusUintWithError(v1, uint64(v2)) } -func uintTimesIntWithError(v1 uint64, v2 int64) (numeric, error) { +func uintTimesIntWithError(v1 uint64, v2 int64) (evalResult, error) { if v2 < 0 || int64(v1) < 0 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) } return uintTimesUintWithError(v1, uint64(v2)) } -func uintPlusUint(v1, v2 uint64) numeric { +func uintPlusUint(v1, v2 uint64) evalResult { result := v1 + v2 if result < v2 { - return numeric{typ: Float64, fval: float64(v1) + float64(v2)} + return evalResult{typ: Float64, fval: float64(v1) + float64(v2)} } - return numeric{typ: Uint64, uval: result} + return evalResult{typ: Uint64, uval: result} } -func uintPlusUintWithError(v1, v2 uint64) (numeric, error) { +func uintPlusUintWithError(v1, v2 uint64) (evalResult, error) { result := v1 + v2 if result < v2 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) } - return numeric{typ: Uint64, uval: result}, nil + return evalResult{typ: Uint64, uval: result}, nil } -func uintMinusUintWithError(v1, v2 uint64) (numeric, error) { +func uintMinusUintWithError(v1, v2 uint64) (evalResult, error) { result := v1 - v2 if v2 > v1 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) } - return numeric{typ: Uint64, uval: result}, nil + return evalResult{typ: Uint64, uval: result}, nil } -func uintTimesUintWithError(v1, v2 uint64) (numeric, error) { +func uintTimesUintWithError(v1, v2 uint64) (evalResult, error) { result := v1 * v2 if result < v2 || result < v1 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) } - return numeric{typ: Uint64, uval: result}, nil + return evalResult{typ: Uint64, uval: result}, nil } -func floatPlusAny(v1 float64, v2 numeric) numeric { +func floatPlusAny(v1 float64, v2 evalResult) evalResult { switch v2.typ { case Int64: v2.fval = float64(v2.ival) case Uint64: v2.fval = float64(v2.uval) } - return numeric{typ: Float64, fval: v1 + v2.fval} + return evalResult{typ: Float64, fval: v1 + v2.fval} } -func floatMinusAny(v1 float64, v2 numeric) numeric { +func floatMinusAny(v1 float64, v2 evalResult) evalResult { switch v2.typ { case Int64: v2.fval = float64(v2.ival) case Uint64: v2.fval = float64(v2.uval) } - return numeric{typ: Float64, fval: v1 - v2.fval} + return evalResult{typ: Float64, fval: v1 - v2.fval} } -func floatTimesAny(v1 float64, v2 numeric) numeric { +func floatTimesAny(v1 float64, v2 evalResult) evalResult { switch v2.typ { case Int64: v2.fval = float64(v2.ival) case Uint64: v2.fval = float64(v2.uval) } - return numeric{typ: Float64, fval: v1 * v2.fval} + return evalResult{typ: Float64, fval: v1 * v2.fval} } -func floatDivideAnyWithError(v1 float64, v2 numeric) (numeric, error) { +func floatDivideAnyWithError(v1 float64, v2 evalResult) (evalResult, error) { switch v2.typ { case Int64: v2.fval = float64(v2.ival) @@ -681,26 +674,26 @@ func floatDivideAnyWithError(v1 float64, v2 numeric) (numeric, error) { } result := v1 / v2.fval divisorLessThanOne := v2.fval < 1 - resultMismatch := (v2.fval*result != v1) + resultMismatch := v2.fval*result != v1 if divisorLessThanOne && resultMismatch { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in %v / %v", v1, v2.fval) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in %v / %v", v1, v2.fval) } - return numeric{typ: Float64, fval: v1 / v2.fval}, nil + return evalResult{typ: Float64, fval: v1 / v2.fval}, nil } -func anyMinusFloat(v1 numeric, v2 float64) numeric { +func anyMinusFloat(v1 evalResult, v2 float64) evalResult { switch v1.typ { case Int64: v1.fval = float64(v1.ival) case Uint64: v1.fval = float64(v1.uval) } - return numeric{typ: Float64, fval: v1.fval - v2} + return evalResult{typ: Float64, fval: v1.fval - v2} } -func castFromNumeric(v numeric, resultType querypb.Type) Value { +func castFromNumeric(v evalResult, resultType querypb.Type) Value { switch { case IsSigned(resultType): switch v.typ { @@ -735,11 +728,13 @@ func castFromNumeric(v numeric, resultType querypb.Type) Value { } return MakeTrusted(resultType, strconv.AppendFloat(nil, v.fval, format, -1, 64)) } + case resultType == VarChar: + return MakeTrusted(resultType, []byte(v.str)) } return NULL } -func compareNumeric(v1, v2 numeric) int { +func compareNumeric(v1, v2 evalResult) int { // Equalize the types. switch v1.typ { case Int64: @@ -748,9 +743,9 @@ func compareNumeric(v1, v2 numeric) int { if v1.ival < 0 { return -1 } - v1 = numeric{typ: Uint64, uval: uint64(v1.ival)} + v1 = evalResult{typ: Uint64, uval: uint64(v1.ival)} case Float64: - v1 = numeric{typ: Float64, fval: float64(v1.ival)} + v1 = evalResult{typ: Float64, fval: float64(v1.ival)} } case Uint64: switch v2.typ { @@ -758,16 +753,16 @@ func compareNumeric(v1, v2 numeric) int { if v2.ival < 0 { return 1 } - v2 = numeric{typ: Uint64, uval: uint64(v2.ival)} + v2 = evalResult{typ: Uint64, uval: uint64(v2.ival)} case Float64: - v1 = numeric{typ: Float64, fval: float64(v1.uval)} + v1 = evalResult{typ: Float64, fval: float64(v1.uval)} } case Float64: switch v2.typ { case Int64: - v2 = numeric{typ: Float64, fval: float64(v2.ival)} + v2 = evalResult{typ: Float64, fval: float64(v2.ival)} case Uint64: - v2 = numeric{typ: Float64, fval: float64(v2.uval)} + v2 = evalResult{typ: Float64, fval: float64(v2.uval)} } } diff --git a/go/sqltypes/arithmetic_test.go b/go/sqltypes/arithmetic_test.go index d874ea97efd..50d8cfde79d 100644 --- a/go/sqltypes/arithmetic_test.go +++ b/go/sqltypes/arithmetic_test.go @@ -932,25 +932,25 @@ func TestToNative(t *testing.T) { func TestNewNumeric(t *testing.T) { tcases := []struct { v Value - out numeric + out evalResult err error }{{ v: NewInt64(1), - out: numeric{typ: Int64, ival: 1}, + out: evalResult{typ: Int64, ival: 1}, }, { v: NewUint64(1), - out: numeric{typ: Uint64, uval: 1}, + out: evalResult{typ: Uint64, uval: 1}, }, { v: NewFloat64(1), - out: numeric{typ: Float64, fval: 1}, + out: evalResult{typ: Float64, fval: 1}, }, { // For non-number type, Int64 is the default. v: TestValue(VarChar, "1"), - out: numeric{typ: Int64, ival: 1}, + out: evalResult{typ: Int64, ival: 1}, }, { // If Int64 can't work, we use Float64. v: TestValue(VarChar, "1.2"), - out: numeric{typ: Float64, fval: 1.2}, + out: evalResult{typ: Float64, fval: 1.2}, }, { // Only valid Int64 allowed if type is Int64. v: TestValue(Int64, "1.2"), @@ -965,7 +965,7 @@ func TestNewNumeric(t *testing.T) { err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseFloat: parsing \"abcd\": invalid syntax"), }, { v: TestValue(VarChar, "abcd"), - out: numeric{typ: Float64, fval: 0}, + out: evalResult{typ: Float64, fval: 0}, }} for _, tcase := range tcases { got, err := newNumeric(tcase.v) @@ -985,25 +985,25 @@ func TestNewNumeric(t *testing.T) { func TestNewIntegralNumeric(t *testing.T) { tcases := []struct { v Value - out numeric + out evalResult err error }{{ v: NewInt64(1), - out: numeric{typ: Int64, ival: 1}, + out: evalResult{typ: Int64, ival: 1}, }, { v: NewUint64(1), - out: numeric{typ: Uint64, uval: 1}, + out: evalResult{typ: Uint64, uval: 1}, }, { v: NewFloat64(1), - out: numeric{typ: Int64, ival: 1}, + out: evalResult{typ: Int64, ival: 1}, }, { // For non-number type, Int64 is the default. v: TestValue(VarChar, "1"), - out: numeric{typ: Int64, ival: 1}, + out: evalResult{typ: Int64, ival: 1}, }, { // If Int64 can't work, we use Uint64. v: TestValue(VarChar, "18446744073709551615"), - out: numeric{typ: Uint64, uval: 18446744073709551615}, + out: evalResult{typ: Uint64, uval: 18446744073709551615}, }, { // Only valid Int64 allowed if type is Int64. v: TestValue(Int64, "1.2"), @@ -1033,52 +1033,52 @@ func TestNewIntegralNumeric(t *testing.T) { func TestAddNumeric(t *testing.T) { tcases := []struct { - v1, v2 numeric - out numeric + v1, v2 evalResult + out evalResult err error }{{ - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Int64, ival: 2}, - out: numeric{typ: Int64, ival: 3}, + v1: evalResult{typ: Int64, ival: 1}, + v2: evalResult{typ: Int64, ival: 2}, + out: evalResult{typ: Int64, ival: 3}, }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Uint64, uval: 2}, - out: numeric{typ: Uint64, uval: 3}, + v1: evalResult{typ: Int64, ival: 1}, + v2: evalResult{typ: Uint64, uval: 2}, + out: evalResult{typ: Uint64, uval: 3}, }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Float64, fval: 2}, - out: numeric{typ: Float64, fval: 3}, + v1: evalResult{typ: Int64, ival: 1}, + v2: evalResult{typ: Float64, fval: 2}, + out: evalResult{typ: Float64, fval: 3}, }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Uint64, uval: 2}, - out: numeric{typ: Uint64, uval: 3}, + v1: evalResult{typ: Uint64, uval: 1}, + v2: evalResult{typ: Uint64, uval: 2}, + out: evalResult{typ: Uint64, uval: 3}, }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Float64, fval: 2}, - out: numeric{typ: Float64, fval: 3}, + v1: evalResult{typ: Uint64, uval: 1}, + v2: evalResult{typ: Float64, fval: 2}, + out: evalResult{typ: Float64, fval: 3}, }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Float64, fval: 2}, - out: numeric{typ: Float64, fval: 3}, + v1: evalResult{typ: Float64, fval: 1}, + v2: evalResult{typ: Float64, fval: 2}, + out: evalResult{typ: Float64, fval: 3}, }, { // Int64 overflow. - v1: numeric{typ: Int64, ival: 9223372036854775807}, - v2: numeric{typ: Int64, ival: 2}, - out: numeric{typ: Float64, fval: 9223372036854775809}, + v1: evalResult{typ: Int64, ival: 9223372036854775807}, + v2: evalResult{typ: Int64, ival: 2}, + out: evalResult{typ: Float64, fval: 9223372036854775809}, }, { // Int64 underflow. - v1: numeric{typ: Int64, ival: -9223372036854775807}, - v2: numeric{typ: Int64, ival: -2}, - out: numeric{typ: Float64, fval: -9223372036854775809}, + v1: evalResult{typ: Int64, ival: -9223372036854775807}, + v2: evalResult{typ: Int64, ival: -2}, + out: evalResult{typ: Float64, fval: -9223372036854775809}, }, { - v1: numeric{typ: Int64, ival: -1}, - v2: numeric{typ: Uint64, uval: 2}, - out: numeric{typ: Float64, fval: 18446744073709551617}, + v1: evalResult{typ: Int64, ival: -1}, + v2: evalResult{typ: Uint64, uval: 2}, + out: evalResult{typ: Float64, fval: 18446744073709551617}, }, { // Uint64 overflow. - v1: numeric{typ: Uint64, uval: 18446744073709551615}, - v2: numeric{typ: Uint64, uval: 2}, - out: numeric{typ: Float64, fval: 18446744073709551617}, + v1: evalResult{typ: Uint64, uval: 18446744073709551615}, + v2: evalResult{typ: Uint64, uval: 2}, + out: evalResult{typ: Float64, fval: 18446744073709551617}, }} for _, tcase := range tcases { got := addNumeric(tcase.v1, tcase.v2) @@ -1090,13 +1090,13 @@ func TestAddNumeric(t *testing.T) { } func TestPrioritize(t *testing.T) { - ival := numeric{typ: Int64} - uval := numeric{typ: Uint64} - fval := numeric{typ: Float64} + ival := evalResult{typ: Int64} + uval := evalResult{typ: Uint64} + fval := evalResult{typ: Float64} tcases := []struct { - v1, v2 numeric - out1, out2 numeric + v1, v2 evalResult + out1, out2 evalResult }{{ v1: ival, v2: uval, @@ -1139,62 +1139,62 @@ func TestPrioritize(t *testing.T) { func TestCastFromNumeric(t *testing.T) { tcases := []struct { typ querypb.Type - v numeric + v evalResult out Value err error }{{ typ: Int64, - v: numeric{typ: Int64, ival: 1}, + v: evalResult{typ: Int64, ival: 1}, out: NewInt64(1), }, { typ: Int64, - v: numeric{typ: Uint64, uval: 1}, + v: evalResult{typ: Uint64, uval: 1}, out: NewInt64(1), }, { typ: Int64, - v: numeric{typ: Float64, fval: 1.2e-16}, + v: evalResult{typ: Float64, fval: 1.2e-16}, out: NewInt64(0), }, { typ: Uint64, - v: numeric{typ: Int64, ival: 1}, + v: evalResult{typ: Int64, ival: 1}, out: NewUint64(1), }, { typ: Uint64, - v: numeric{typ: Uint64, uval: 1}, + v: evalResult{typ: Uint64, uval: 1}, out: NewUint64(1), }, { typ: Uint64, - v: numeric{typ: Float64, fval: 1.2e-16}, + v: evalResult{typ: Float64, fval: 1.2e-16}, out: NewUint64(0), }, { typ: Float64, - v: numeric{typ: Int64, ival: 1}, + v: evalResult{typ: Int64, ival: 1}, out: TestValue(Float64, "1"), }, { typ: Float64, - v: numeric{typ: Uint64, uval: 1}, + v: evalResult{typ: Uint64, uval: 1}, out: TestValue(Float64, "1"), }, { typ: Float64, - v: numeric{typ: Float64, fval: 1.2e-16}, + v: evalResult{typ: Float64, fval: 1.2e-16}, out: TestValue(Float64, "1.2e-16"), }, { typ: Decimal, - v: numeric{typ: Int64, ival: 1}, + v: evalResult{typ: Int64, ival: 1}, out: TestValue(Decimal, "1"), }, { typ: Decimal, - v: numeric{typ: Uint64, uval: 1}, + v: evalResult{typ: Uint64, uval: 1}, out: TestValue(Decimal, "1"), }, { // For float, we should not use scientific notation. typ: Decimal, - v: numeric{typ: Float64, fval: 1.2e-16}, + v: evalResult{typ: Float64, fval: 1.2e-16}, out: TestValue(Decimal, "0.00000000000000012"), }, { typ: VarBinary, - v: numeric{typ: Int64, ival: 1}, - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion to non-numeric: VARBINARY"), + v: evalResult{typ: Int64, ival: 1}, + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion to non-evalResult: VARBINARY"), }} for _, tcase := range tcases { got := castFromNumeric(tcase.v, tcase.typ) @@ -1207,125 +1207,125 @@ func TestCastFromNumeric(t *testing.T) { func TestCompareNumeric(t *testing.T) { tcases := []struct { - v1, v2 numeric + v1, v2 evalResult out int }{{ - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Int64, ival: 1}, + v1: evalResult{typ: Int64, ival: 1}, + v2: evalResult{typ: Int64, ival: 1}, out: 0, }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Int64, ival: 2}, + v1: evalResult{typ: Int64, ival: 1}, + v2: evalResult{typ: Int64, ival: 2}, out: -1, }, { - v1: numeric{typ: Int64, ival: 2}, - v2: numeric{typ: Int64, ival: 1}, + v1: evalResult{typ: Int64, ival: 2}, + v2: evalResult{typ: Int64, ival: 1}, out: 1, }, { // Special case. - v1: numeric{typ: Int64, ival: -1}, - v2: numeric{typ: Uint64, uval: 1}, + v1: evalResult{typ: Int64, ival: -1}, + v2: evalResult{typ: Uint64, uval: 1}, out: -1, }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Uint64, uval: 1}, + v1: evalResult{typ: Int64, ival: 1}, + v2: evalResult{typ: Uint64, uval: 1}, out: 0, }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Uint64, uval: 2}, + v1: evalResult{typ: Int64, ival: 1}, + v2: evalResult{typ: Uint64, uval: 2}, out: -1, }, { - v1: numeric{typ: Int64, ival: 2}, - v2: numeric{typ: Uint64, uval: 1}, + v1: evalResult{typ: Int64, ival: 2}, + v2: evalResult{typ: Uint64, uval: 1}, out: 1, }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Float64, fval: 1}, + v1: evalResult{typ: Int64, ival: 1}, + v2: evalResult{typ: Float64, fval: 1}, out: 0, }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Float64, fval: 2}, + v1: evalResult{typ: Int64, ival: 1}, + v2: evalResult{typ: Float64, fval: 2}, out: -1, }, { - v1: numeric{typ: Int64, ival: 2}, - v2: numeric{typ: Float64, fval: 1}, + v1: evalResult{typ: Int64, ival: 2}, + v2: evalResult{typ: Float64, fval: 1}, out: 1, }, { // Special case. - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Int64, ival: -1}, + v1: evalResult{typ: Uint64, uval: 1}, + v2: evalResult{typ: Int64, ival: -1}, out: 1, }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Int64, ival: 1}, + v1: evalResult{typ: Uint64, uval: 1}, + v2: evalResult{typ: Int64, ival: 1}, out: 0, }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Int64, ival: 2}, + v1: evalResult{typ: Uint64, uval: 1}, + v2: evalResult{typ: Int64, ival: 2}, out: -1, }, { - v1: numeric{typ: Uint64, uval: 2}, - v2: numeric{typ: Int64, ival: 1}, + v1: evalResult{typ: Uint64, uval: 2}, + v2: evalResult{typ: Int64, ival: 1}, out: 1, }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Uint64, uval: 1}, + v1: evalResult{typ: Uint64, uval: 1}, + v2: evalResult{typ: Uint64, uval: 1}, out: 0, }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Uint64, uval: 2}, + v1: evalResult{typ: Uint64, uval: 1}, + v2: evalResult{typ: Uint64, uval: 2}, out: -1, }, { - v1: numeric{typ: Uint64, uval: 2}, - v2: numeric{typ: Uint64, uval: 1}, + v1: evalResult{typ: Uint64, uval: 2}, + v2: evalResult{typ: Uint64, uval: 1}, out: 1, }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Float64, fval: 1}, + v1: evalResult{typ: Uint64, uval: 1}, + v2: evalResult{typ: Float64, fval: 1}, out: 0, }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Float64, fval: 2}, + v1: evalResult{typ: Uint64, uval: 1}, + v2: evalResult{typ: Float64, fval: 2}, out: -1, }, { - v1: numeric{typ: Uint64, uval: 2}, - v2: numeric{typ: Float64, fval: 1}, + v1: evalResult{typ: Uint64, uval: 2}, + v2: evalResult{typ: Float64, fval: 1}, out: 1, }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Int64, ival: 1}, + v1: evalResult{typ: Float64, fval: 1}, + v2: evalResult{typ: Int64, ival: 1}, out: 0, }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Int64, ival: 2}, + v1: evalResult{typ: Float64, fval: 1}, + v2: evalResult{typ: Int64, ival: 2}, out: -1, }, { - v1: numeric{typ: Float64, fval: 2}, - v2: numeric{typ: Int64, ival: 1}, + v1: evalResult{typ: Float64, fval: 2}, + v2: evalResult{typ: Int64, ival: 1}, out: 1, }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Uint64, uval: 1}, + v1: evalResult{typ: Float64, fval: 1}, + v2: evalResult{typ: Uint64, uval: 1}, out: 0, }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Uint64, uval: 2}, + v1: evalResult{typ: Float64, fval: 1}, + v2: evalResult{typ: Uint64, uval: 2}, out: -1, }, { - v1: numeric{typ: Float64, fval: 2}, - v2: numeric{typ: Uint64, uval: 1}, + v1: evalResult{typ: Float64, fval: 2}, + v2: evalResult{typ: Uint64, uval: 1}, out: 1, }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Float64, fval: 1}, + v1: evalResult{typ: Float64, fval: 1}, + v2: evalResult{typ: Float64, fval: 1}, out: 0, }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Float64, fval: 2}, + v1: evalResult{typ: Float64, fval: 1}, + v2: evalResult{typ: Float64, fval: 2}, out: -1, }, { - v1: numeric{typ: Float64, fval: 2}, - v2: numeric{typ: Float64, fval: 1}, + v1: evalResult{typ: Float64, fval: 2}, + v2: evalResult{typ: Float64, fval: 1}, out: 1, }} for _, tcase := range tcases { @@ -1499,8 +1499,8 @@ func BenchmarkAddGoInterface(b *testing.B) { } func BenchmarkAddGoNonInterface(b *testing.B) { - v1 := numeric{typ: Int64, ival: 1} - v2 := numeric{typ: Int64, ival: 12} + v1 := evalResult{typ: Int64, ival: 1} + v2 := evalResult{typ: Int64, ival: 12} for i := 0; i < b.N; i++ { if v1.typ != Int64 { b.Error("type assertion failed") @@ -1508,7 +1508,7 @@ func BenchmarkAddGoNonInterface(b *testing.B) { if v2.typ != Int64 { b.Error("type assertion failed") } - v1 = numeric{typ: Int64, ival: v1.ival + v2.ival} + v1 = evalResult{typ: Int64, ival: v1.ival + v2.ival} } } diff --git a/go/sqltypes/expressions.go b/go/sqltypes/expressions.go index ae2f0293b3b..87eb6ebe3e8 100644 --- a/go/sqltypes/expressions.go +++ b/go/sqltypes/expressions.go @@ -25,7 +25,13 @@ import ( ) type ( - + evalResult struct { + typ querypb.Type + ival int64 + uval uint64 + fval float64 + str string + } //ExpressionEnv contains the environment that the expression //evaluates in, such as the current row and bindvars ExpressionEnv struct { @@ -34,12 +40,13 @@ type ( } // We use this type alias so we don't have to expose the private struct numeric - EvalResult = numeric + EvalResult = evalResult // Expr is the interface that all evaluating expressions must implement Expr interface { Evaluate(env ExpressionEnv) (EvalResult, error) String() string + Type(env ExpressionEnv) querypb.Type } //BinaryExpr allows binary expressions to not have to evaluate child expressions - this is done by the BinaryOp @@ -63,6 +70,44 @@ type ( Division struct{} ) +func prioritizeTypes(ltype, rtype querypb.Type) (querypb.Type, querypb.Type) { + switch ltype { + case Int64: + if rtype == Uint64 || rtype == Float64 { + return rtype, ltype + } + case Uint64: + if rtype == Float64 { + return rtype, ltype + } + } + return ltype, rtype +} + +func (b *BinaryOp) Type(env ExpressionEnv) querypb.Type { + ltype := b.Left.Type(env) + rtype := b.Right.Type(env) + ltype, rtype = prioritizeTypes(ltype, rtype) + switch ltype { + case Int64: + return Int64 + case Uint64: + return Uint64 + case Float64: + return Float64 + } + panic("unreachable") +} + +func (b *BindVariable) Type(env ExpressionEnv) querypb.Type { + e := env.BindVars + return e[b.Key].Type +} + +func (l *LiteralInt) Type(env ExpressionEnv) querypb.Type { + return Int64 +} + func (d *Division) String() string { return "/" } @@ -123,7 +168,7 @@ func (l *LiteralInt) Evaluate(env ExpressionEnv) (EvalResult, error) { if err != nil { ival = 0 } - return numeric{typ: Int64, ival: ival}, nil + return evalResult{typ: Int64, ival: ival}, nil } //Evaluate implements the Expr interface @@ -132,12 +177,33 @@ func (b *BindVariable) Evaluate(env ExpressionEnv) (EvalResult, error) { if !ok { return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Bind variable not found") } - ival, err := strconv.ParseInt(string(val.Value), 10, 64) - if err != nil { - ival = 0 - } - return numeric{typ: Int64, ival: ival}, nil + return evaluateByType(val) +} +func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { + switch val.Type { + case Int64: + ival, err := strconv.ParseInt(string(val.Value), 10, 64) + if err != nil { + ival = 0 + } + return evalResult{typ: Int64, ival: ival}, nil + case Uint64: + uval, err := strconv.ParseUint(string(val.Value), 10, 64) + if err != nil { + uval = 0 + } + return evalResult{typ: Uint64, uval: uval}, nil + case Float64: + fval, err := strconv.ParseFloat(string(val.Value), 64) + if err != nil { + fval = 0 + } + return evalResult{typ: Float64, fval: fval}, nil + case VarChar: + return evalResult{typ: VarChar, str: string(val.Value)}, nil + } + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Type is not supported") } //Evaluate implements the BinaryOp interface diff --git a/go/sqltypes/expressions_test.go b/go/sqltypes/expressions_test.go index 9d393eb4b1b..486e8d5f166 100644 --- a/go/sqltypes/expressions_test.go +++ b/go/sqltypes/expressions_test.go @@ -16,4 +16,4 @@ limitations under the License. package sqltypes -// these tests live in go/sqltypes/expressions_test.go +// these tests live in go/sqlparser/expressions_test.go diff --git a/go/vt/sqlparser/expressions_test.go b/go/vt/sqlparser/expressions_test.go index 134f0a7c341..3e65b3ca76f 100644 --- a/go/vt/sqlparser/expressions_test.go +++ b/go/vt/sqlparser/expressions_test.go @@ -55,6 +55,15 @@ func TestEvaluate(t *testing.T) { }, { expression: ":exp", expected: sqltypes.NewInt64(66), + }, { + expression: ":uint64_bind_variable", + expected: sqltypes.NewUint64(22), + }, { + expression: ":string_bind_variable", + expected: sqltypes.NewVarChar("bar"), + }, { + expression: ":float_bind_variable", + expected: sqltypes.NewFloat64(2.2), }} for _, test := range tests { @@ -67,7 +76,10 @@ func TestEvaluate(t *testing.T) { require.NotNil(t, sqltypesExpr) env := sqltypes.ExpressionEnv{ BindVars: map[string]*querypb.BindVariable{ - "exp": sqltypes.Int64BindVariable(66), + "exp": sqltypes.Int64BindVariable(66), + "string_bind_variable": sqltypes.StringBindVariable("bar"), + "uint64_bind_variable": sqltypes.Uint64BindVariable(22), + "float_bind_variable": sqltypes.Float64BindVariable(2.2), }, Row: nil, } diff --git a/go/vt/vtgate/endtoend/last_insert_id_test.go b/go/vt/vtgate/endtoend/last_insert_id_test.go index 81b90d4fe76..b7ac4a70452 100644 --- a/go/vt/vtgate/endtoend/last_insert_id_test.go +++ b/go/vt/vtgate/endtoend/last_insert_id_test.go @@ -44,7 +44,7 @@ func TestLastInsertId(t *testing.T) { // even without a transaction, we should get the last inserted id back qr = exec(t, conn, "select last_insert_id()") got := fmt.Sprintf("%v", qr.Rows) - want := fmt.Sprintf("[[INT64(%d)]]", oldLastID+1) + want := fmt.Sprintf("[[UINT64(%d)]]", oldLastID+1) if diff := cmp.Diff(want, got); diff != "" { t.Error(diff) @@ -67,7 +67,7 @@ func TestLastInsertIdWithRollback(t *testing.T) { exec(t, conn, "insert into t1_last_insert_id(id1) values(42)") qr = exec(t, conn, "select last_insert_id()") got := fmt.Sprintf("%v", qr.Rows) - want := fmt.Sprintf("[[INT64(%d)]]", oldLastID+1) + want := fmt.Sprintf("[[UINT64(%d)]]", oldLastID+1) if diff := cmp.Diff(want, got); diff != "" { t.Error(diff) diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index 8b713be8d7f..e2af8acd230 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -36,7 +36,7 @@ func (p *Projection) Execute(vcursor VCursor, bindVars map[string]*querypb.BindV } if wantfields { - p.addFields(result) + p.addFields(result, bindVars) } var rows [][]sqltypes.Value for _, row := range result.Rows { @@ -63,15 +63,16 @@ func (p *Projection) GetFields(vcursor VCursor, bindVars map[string]*querypb.Bin if err != nil { return nil, err } - p.addFields(qr) + p.addFields(qr, bindVars) return qr, nil } -func (p *Projection) addFields(qr *sqltypes.Result) { - for _, col := range p.Cols { +func (p *Projection) addFields(qr *sqltypes.Result, bindVars map[string]*querypb.BindVariable) { + env := sqltypes.ExpressionEnv{BindVars: bindVars} + for i, col := range p.Cols { qr.Fields = append(qr.Fields, &querypb.Field{ Name: col, - Type: querypb.Type_INT64, + Type: p.Exprs[i].Type(env), }) } } diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index d59f463c229..21941c8e84a 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -21,7 +21,6 @@ import ( "fmt" "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" ) @@ -79,8 +78,9 @@ func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab, exprs := make([]sqltypes.Expr, len(sel.SelectExprs)) cols := make([]string, len(sel.SelectExprs)) for i, e := range sel.SelectExprs { - exprs[i] = sqlparser.Convert(e.(*sqlparser.AliasedExpr).Expr) - cols[i] = e.(*sqlparser.AliasedExpr).As.String() + expr := e.(*sqlparser.AliasedExpr) + exprs[i] = sqlparser.Convert(expr.Expr) + cols[i] = expr.As.String() } pb.bldr = &vtgateExecution{ exprs, From 26017802e110c4a23407a36bfa4d38166471ca89 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Sat, 11 Apr 2020 11:14:18 +0200 Subject: [PATCH 11/23] Handle type logic for binary operators Signed-off-by: Andres Taylor --- go/sqltypes/expressions.go | 45 ++++++++++++----- go/sqltypes/expressions_test.go | 87 ++++++++++++++++++++++++++++++++- 2 files changed, 119 insertions(+), 13 deletions(-) diff --git a/go/sqltypes/expressions.go b/go/sqltypes/expressions.go index 87eb6ebe3e8..5368e734bf8 100644 --- a/go/sqltypes/expressions.go +++ b/go/sqltypes/expressions.go @@ -45,13 +45,14 @@ type ( // Expr is the interface that all evaluating expressions must implement Expr interface { Evaluate(env ExpressionEnv) (EvalResult, error) - String() string Type(env ExpressionEnv) querypb.Type + String() string } //BinaryExpr allows binary expressions to not have to evaluate child expressions - this is done by the BinaryOp BinaryExpr interface { Evaluate(left, right EvalResult) (EvalResult, error) + Type(left, right querypb.Type) querypb.Type String() string } @@ -88,15 +89,7 @@ func (b *BinaryOp) Type(env ExpressionEnv) querypb.Type { ltype := b.Left.Type(env) rtype := b.Right.Type(env) ltype, rtype = prioritizeTypes(ltype, rtype) - switch ltype { - case Int64: - return Int64 - case Uint64: - return Uint64 - case Float64: - return Float64 - } - panic("unreachable") + return b.Expr.Type(ltype, rtype) } func (b *BindVariable) Type(env ExpressionEnv) querypb.Type { @@ -104,7 +97,7 @@ func (b *BindVariable) Type(env ExpressionEnv) querypb.Type { return e[b.Key].Type } -func (l *LiteralInt) Type(env ExpressionEnv) querypb.Type { +func (l *LiteralInt) Type(_ ExpressionEnv) querypb.Type { return Int64 } @@ -163,7 +156,7 @@ func (b *BinaryOp) Evaluate(env ExpressionEnv) (EvalResult, error) { } //Evaluate implements the Expr interface -func (l *LiteralInt) Evaluate(env ExpressionEnv) (EvalResult, error) { +func (l *LiteralInt) Evaluate(ExpressionEnv) (EvalResult, error) { ival, err := strconv.ParseInt(string(l.Val), 10, 64) if err != nil { ival = 0 @@ -225,3 +218,31 @@ func (m *Multiplication) Evaluate(left, right EvalResult) (EvalResult, error) { func (d *Division) Evaluate(left, right EvalResult) (EvalResult, error) { return divideNumericWithError(left, right) } + +func (a *Addition) Type(left, _ querypb.Type) querypb.Type { + return addAndSubtractType(left) +} + +func addAndSubtractType(left querypb.Type) querypb.Type { + switch left { + case Int64: + return Int64 + case Uint64: + return Uint64 + case Float64: + return Float64 + } + panic("oops") +} + +func (m *Multiplication) Type(left, _ querypb.Type) querypb.Type { + return addAndSubtractType(left) +} + +func (d *Division) Type(querypb.Type, querypb.Type) querypb.Type { + return Float64 +} + +func (s *Subtraction) Type(left, _ querypb.Type) querypb.Type { + return addAndSubtractType(left) +} diff --git a/go/sqltypes/expressions_test.go b/go/sqltypes/expressions_test.go index 486e8d5f166..651d16eff20 100644 --- a/go/sqltypes/expressions_test.go +++ b/go/sqltypes/expressions_test.go @@ -16,4 +16,89 @@ limitations under the License. package sqltypes -// these tests live in go/sqlparser/expressions_test.go +import ( + "fmt" + "reflect" + "testing" + + "github.com/magiconair/properties/assert" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +// more tests in go/sqlparser/expressions_test.go + +func TestBinaryOpTypes(t *testing.T) { + type testcase struct { + l, r, e querypb.Type + } + type ops struct { + op BinaryExpr + testcases []testcase + } + + tests := []ops{ + { + op: &Addition{}, + testcases: []testcase{ + {Int64, Int64, Int64}, + {Uint64, Int64, Uint64}, + {Float64, Int64, Float64}, + {Int64, Uint64, Int64}, + {Uint64, Uint64, Uint64}, + {Float64, Uint64, Float64}, + {Int64, Float64, Int64}, + {Uint64, Float64, Uint64}, + {Float64, Float64, Float64}, + }, + }, { + op: &Subtraction{}, + testcases: []testcase{ + {Int64, Int64, Int64}, + {Uint64, Int64, Uint64}, + {Float64, Int64, Float64}, + {Int64, Uint64, Int64}, + {Uint64, Uint64, Uint64}, + {Float64, Uint64, Float64}, + {Int64, Float64, Int64}, + {Uint64, Float64, Uint64}, + {Float64, Float64, Float64}, + }, + }, { + op: &Multiplication{}, + testcases: []testcase{ + {Int64, Int64, Int64}, + {Uint64, Int64, Uint64}, + {Float64, Int64, Float64}, + {Int64, Uint64, Int64}, + {Uint64, Uint64, Uint64}, + {Float64, Uint64, Float64}, + {Int64, Float64, Int64}, + {Uint64, Float64, Uint64}, + {Float64, Float64, Float64}, + }, + }, { + op: &Division{}, + testcases: []testcase{ + {Int64, Int64, Float64}, + {Uint64, Int64, Float64}, + {Float64, Int64, Float64}, + {Int64, Uint64, Float64}, + {Uint64, Uint64, Float64}, + {Float64, Uint64, Float64}, + {Int64, Float64, Float64}, + {Uint64, Float64, Float64}, + {Float64, Float64, Float64}, + }, + }, + } + + for _, op := range tests { + for _, tc := range op.testcases { + name := fmt.Sprintf("%s %s %s", tc.l.String(), reflect.TypeOf(op.op).String(), tc.r.String()) + t.Run(name, func(t *testing.T) { + result := op.op.Type(tc.l, tc.r) + assert.Equal(t, tc.e, result) + }) + } + } +} From 46c228ef8cad6486c79e2475d4c4804609ed4e7c Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Sat, 11 Apr 2020 11:24:23 +0200 Subject: [PATCH 12/23] Update test assertions Signed-off-by: Andres Taylor --- go/vt/vtgate/executor_select_test.go | 16 ++++++++-------- go/vt/vtgate/plan_executor_select_test.go | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 3d728f589dc..f5a10fd4b0a 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -232,10 +232,10 @@ func TestSelectLastInsertId(t *testing.T) { result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "last_insert_id()", Type: sqltypes.Int64}, + {Name: "last_insert_id()", Type: sqltypes.Uint64}, }, Rows: [][]sqltypes.Value{{ - sqltypes.NewInt64(42), + sqltypes.NewUint64(42), }}, } require.NoError(t, err) @@ -254,10 +254,10 @@ func TestSelectUserDefindVariable(t *testing.T) { require.NoError(t, err) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "@foo", Type: sqltypes.Int64}, + {Name: "@foo", Type: sqltypes.VarChar}, }, Rows: [][]sqltypes.Value{{ - sqltypes.NewInt64(0), + sqltypes.NewVarChar("bar"), }}, } require.NoError(t, err) @@ -278,10 +278,10 @@ func TestFoundRows(t *testing.T) { result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "found_rows()", Type: sqltypes.Int64}, + {Name: "found_rows()", Type: sqltypes.Uint64}, }, Rows: [][]sqltypes.Value{{ - sqltypes.NewInt64(0), + sqltypes.NewUint64(0), }}, } require.NoError(t, err) @@ -386,10 +386,10 @@ func TestSelectDatabase(t *testing.T) { map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "database()", Type: sqltypes.Int64}, + {Name: "database()", Type: sqltypes.VarChar}, }, Rows: [][]sqltypes.Value{{ - sqltypes.NewInt64(0), + sqltypes.NewVarChar("TestExecutor"), }}, } require.NoError(t, err) diff --git a/go/vt/vtgate/plan_executor_select_test.go b/go/vt/vtgate/plan_executor_select_test.go index 402f8b4c5de..25a1118cfb9 100644 --- a/go/vt/vtgate/plan_executor_select_test.go +++ b/go/vt/vtgate/plan_executor_select_test.go @@ -232,10 +232,10 @@ func TestPlanSelectLastInsertId(t *testing.T) { result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "last_insert_id()", Type: sqltypes.Int64}, + {Name: "last_insert_id()", Type: sqltypes.Uint64}, }, Rows: [][]sqltypes.Value{{ - sqltypes.NewInt64(52), + sqltypes.NewUint64(52), }}, } require.NoError(t, err) @@ -255,10 +255,10 @@ func TestPlanSelectUserDefindVariable(t *testing.T) { require.NoError(t, err) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "@foo", Type: sqltypes.Int64}, + {Name: "@foo", Type: sqltypes.VarChar}, }, Rows: [][]sqltypes.Value{{ - sqltypes.NewInt64(0), + sqltypes.NewVarChar("bar"), }}, } require.NoError(t, err) @@ -279,10 +279,10 @@ func TestPlanFoundRows(t *testing.T) { result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "found_rows()", Type: sqltypes.Int64}, + {Name: "found_rows()", Type: sqltypes.Uint64}, }, Rows: [][]sqltypes.Value{{ - sqltypes.NewInt64(0), + sqltypes.NewUint64(0), }}, } require.NoError(t, err) @@ -388,10 +388,10 @@ func TestPlanSelectDatabase(t *testing.T) { map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "database()", Type: sqltypes.Int64}, + {Name: "database()", Type: sqltypes.VarChar}, }, Rows: [][]sqltypes.Value{{ - sqltypes.NewInt64(0), + sqltypes.NewVarChar("TestExecutor"), }}, } require.NoError(t, err) From cfcb9046c5c52ce241a5fee8069619b7029b5f6e Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Sun, 12 Apr 2020 09:51:49 +0200 Subject: [PATCH 13/23] Code clean up Signed-off-by: Andres Taylor --- go/sqltypes/expressions.go | 199 +++++++++++++++++--------------- go/sqltypes/expressions_test.go | 2 +- 2 files changed, 108 insertions(+), 93 deletions(-) diff --git a/go/sqltypes/expressions.go b/go/sqltypes/expressions.go index 5368e734bf8..32a9afc9734 100644 --- a/go/sqltypes/expressions.go +++ b/go/sqltypes/expressions.go @@ -39,7 +39,7 @@ type ( Row []Value } - // We use this type alias so we don't have to expose the private struct numeric + // EvalResult is used so we don't have to expose all parts of the private struct EvalResult = evalResult // Expr is the interface that all evaluating expressions must implement @@ -52,7 +52,7 @@ type ( //BinaryExpr allows binary expressions to not have to evaluate child expressions - this is done by the BinaryOp BinaryExpr interface { Evaluate(left, right EvalResult) (EvalResult, error) - Type(left, right querypb.Type) querypb.Type + Type(left querypb.Type) querypb.Type String() string } @@ -71,106 +71,157 @@ type ( Division struct{} ) -func prioritizeTypes(ltype, rtype querypb.Type) (querypb.Type, querypb.Type) { - switch ltype { - case Int64: - if rtype == Uint64 || rtype == Float64 { - return rtype, ltype - } - case Uint64: - if rtype == Float64 { - return rtype, ltype - } +//Value allows for retrieval of the value we expose for public consumption +func (e EvalResult) Value() Value { + return castFromNumeric(e, e.typ) +} + +var _ Expr = (*LiteralInt)(nil) +var _ Expr = (*BindVariable)(nil) +var _ Expr = (*BinaryOp)(nil) + +var _ BinaryExpr = (*Addition)(nil) +var _ BinaryExpr = (*Subtraction)(nil) +var _ BinaryExpr = (*Multiplication)(nil) +var _ BinaryExpr = (*Division)(nil) + +//Evaluate implements the Expr interface +func (b *BinaryOp) Evaluate(env ExpressionEnv) (EvalResult, error) { + lVal, err := b.Left.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + rVal, err := b.Right.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + return b.Expr.Evaluate(lVal, rVal) +} + +//Evaluate implements the Expr interface +func (l *LiteralInt) Evaluate(ExpressionEnv) (EvalResult, error) { + ival, err := strconv.ParseInt(string(l.Val), 10, 64) + if err != nil { + ival = 0 + } + return evalResult{typ: Int64, ival: ival}, nil +} + +//Evaluate implements the Expr interface +func (b *BindVariable) Evaluate(env ExpressionEnv) (EvalResult, error) { + val, ok := env.BindVars[b.Key] + if !ok { + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Bind variable not found") } - return ltype, rtype + return evaluateByType(val) +} + +//Evaluate implements the BinaryOp interface +func (a *Addition) Evaluate(left, right EvalResult) (EvalResult, error) { + return addNumericWithError(left, right) +} + +//Evaluate implements the BinaryOp interface +func (s *Subtraction) Evaluate(left, right EvalResult) (EvalResult, error) { + return subtractNumericWithError(left, right) +} + +//Evaluate implements the BinaryOp interface +func (m *Multiplication) Evaluate(left, right EvalResult) (EvalResult, error) { + return multiplyNumericWithError(left, right) +} + +//Evaluate implements the BinaryOp interface +func (d *Division) Evaluate(left, right EvalResult) (EvalResult, error) { + return divideNumericWithError(left, right) +} + +//Type implements the BinaryExpr interface +func (a *Addition) Type(left querypb.Type) querypb.Type { + return addAndSubtractType(left) +} + +//Type implements the BinaryExpr interface +func (m *Multiplication) Type(left querypb.Type) querypb.Type { + return addAndSubtractType(left) +} + +//Type implements the BinaryExpr interface +func (d *Division) Type(querypb.Type) querypb.Type { + return Float64 +} + +//Type implements the BinaryExpr interface +func (s *Subtraction) Type(left querypb.Type) querypb.Type { + return addAndSubtractType(left) } +//Type implements the Expr interface func (b *BinaryOp) Type(env ExpressionEnv) querypb.Type { ltype := b.Left.Type(env) rtype := b.Right.Type(env) - ltype, rtype = prioritizeTypes(ltype, rtype) - return b.Expr.Type(ltype, rtype) + typ := mergeNumericalTypes(ltype, rtype) + return b.Expr.Type(typ) } +//Type implements the Expr interface func (b *BindVariable) Type(env ExpressionEnv) querypb.Type { e := env.BindVars return e[b.Key].Type } +//Type implements the Expr interface func (l *LiteralInt) Type(_ ExpressionEnv) querypb.Type { return Int64 } +//String implements the BinaryExpr interface func (d *Division) String() string { return "/" } +//String implements the BinaryExpr interface func (m *Multiplication) String() string { return "*" } +//String implements the BinaryExpr interface func (s *Subtraction) String() string { return "-" } +//String implements the BinaryExpr interface func (a *Addition) String() string { return "+" } +//String implements the Expr interface func (b *BinaryOp) String() string { return b.Left.String() + " " + b.Expr.String() + " " + b.Right.String() } +//String implements the Expr interface func (b *BindVariable) String() string { return ":" + b.Key } +//String implements the Expr interface func (l *LiteralInt) String() string { return string(l.Val) } -//Value allows for retrieval of the value we need -func (e EvalResult) Value() Value { - return castFromNumeric(e, e.typ) -} - -var _ Expr = (*LiteralInt)(nil) -var _ Expr = (*BindVariable)(nil) -var _ Expr = (*BinaryOp)(nil) -var _ BinaryExpr = (*Addition)(nil) -var _ BinaryExpr = (*Subtraction)(nil) -var _ BinaryExpr = (*Multiplication)(nil) -var _ BinaryExpr = (*Division)(nil) - -//Evaluate implements the Expr interface -func (b *BinaryOp) Evaluate(env ExpressionEnv) (EvalResult, error) { - lVal, err := b.Left.Evaluate(env) - if err != nil { - return EvalResult{}, err - } - rVal, err := b.Right.Evaluate(env) - if err != nil { - return EvalResult{}, err - } - return b.Expr.Evaluate(lVal, rVal) -} - -//Evaluate implements the Expr interface -func (l *LiteralInt) Evaluate(ExpressionEnv) (EvalResult, error) { - ival, err := strconv.ParseInt(string(l.Val), 10, 64) - if err != nil { - ival = 0 - } - return evalResult{typ: Int64, ival: ival}, nil -} - -//Evaluate implements the Expr interface -func (b *BindVariable) Evaluate(env ExpressionEnv) (EvalResult, error) { - val, ok := env.BindVars[b.Key] - if !ok { - return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Bind variable not found") +func mergeNumericalTypes(ltype, rtype querypb.Type) querypb.Type { + switch ltype { + case Int64: + if rtype == Uint64 || rtype == Float64 { + return rtype + } + case Uint64: + if rtype == Float64 { + return rtype + } } - return evaluateByType(val) + return ltype } func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { @@ -199,30 +250,6 @@ func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Type is not supported") } -//Evaluate implements the BinaryOp interface -func (a *Addition) Evaluate(left, right EvalResult) (EvalResult, error) { - return addNumericWithError(left, right) -} - -//Evaluate implements the BinaryOp interface -func (s *Subtraction) Evaluate(left, right EvalResult) (EvalResult, error) { - return subtractNumericWithError(left, right) -} - -//Evaluate implements the BinaryOp interface -func (m *Multiplication) Evaluate(left, right EvalResult) (EvalResult, error) { - return multiplyNumericWithError(left, right) -} - -//Evaluate implements the BinaryOp interface -func (d *Division) Evaluate(left, right EvalResult) (EvalResult, error) { - return divideNumericWithError(left, right) -} - -func (a *Addition) Type(left, _ querypb.Type) querypb.Type { - return addAndSubtractType(left) -} - func addAndSubtractType(left querypb.Type) querypb.Type { switch left { case Int64: @@ -234,15 +261,3 @@ func addAndSubtractType(left querypb.Type) querypb.Type { } panic("oops") } - -func (m *Multiplication) Type(left, _ querypb.Type) querypb.Type { - return addAndSubtractType(left) -} - -func (d *Division) Type(querypb.Type, querypb.Type) querypb.Type { - return Float64 -} - -func (s *Subtraction) Type(left, _ querypb.Type) querypb.Type { - return addAndSubtractType(left) -} diff --git a/go/sqltypes/expressions_test.go b/go/sqltypes/expressions_test.go index 651d16eff20..6163ca32886 100644 --- a/go/sqltypes/expressions_test.go +++ b/go/sqltypes/expressions_test.go @@ -96,7 +96,7 @@ func TestBinaryOpTypes(t *testing.T) { for _, tc := range op.testcases { name := fmt.Sprintf("%s %s %s", tc.l.String(), reflect.TypeOf(op.op).String(), tc.r.String()) t.Run(name, func(t *testing.T) { - result := op.op.Type(tc.l, tc.r) + result := op.op.Type(tc.l) assert.Equal(t, tc.e, result) }) } From 95f73096e791538a1644d790ac035f75090a1c9c Mon Sep 17 00:00:00 2001 From: Rohit Nayak Date: Tue, 14 Apr 2020 10:29:49 +0200 Subject: [PATCH 14/23] Add expression evaluation fallback mechanism Signed-off-by: Rohit Nayak --- go.mod | 1 + go/sqltypes/arithmetic.go | 2 +- go/sqltypes/expressions.go | 2 ++ go/vt/sqlparser/expression_converter.go | 32 ++++++++++++----- go/vt/sqlparser/expressions_test.go | 5 +-- go/vt/vtgate/engine/projection.go | 1 + go/vt/vtgate/engine/singlerow.go | 1 + go/vt/vtgate/executor_select_test.go | 8 ++--- go/vt/vtgate/plan_executor_select_test.go | 8 ++--- go/vt/vtgate/planbuilder/select.go | 42 ++++++++++++++++------- 10 files changed, 69 insertions(+), 33 deletions(-) diff --git a/go.mod b/go.mod index dd26e4c38d9..5f10ba6ce7c 100644 --- a/go.mod +++ b/go.mod @@ -98,6 +98,7 @@ require ( github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/krishicks/yaml-patch v0.0.10 github.com/lyft/protoc-gen-validate v0.0.0-20180911180927-64fcb82c878e // indirect + github.com/magiconair/properties v1.8.1 github.com/mattn/go-isatty v0.0.11 // indirect github.com/minio/minio-go v0.0.0-20190131015406-c8a261de75c1 github.com/mitchellh/copystructure v0.0.0-20160804032330-cdac8253d00f // indirect diff --git a/go/sqltypes/arithmetic.go b/go/sqltypes/arithmetic.go index 4e94333ba53..90f4d8f8415 100644 --- a/go/sqltypes/arithmetic.go +++ b/go/sqltypes/arithmetic.go @@ -728,7 +728,7 @@ func castFromNumeric(v evalResult, resultType querypb.Type) Value { } return MakeTrusted(resultType, strconv.AppendFloat(nil, v.fval, format, -1, 64)) } - case resultType == VarChar: + case resultType == VarChar || resultType == VarBinary: return MakeTrusted(resultType, []byte(v.str)) } return NULL diff --git a/go/sqltypes/expressions.go b/go/sqltypes/expressions.go index 32a9afc9734..6793e3a1343 100644 --- a/go/sqltypes/expressions.go +++ b/go/sqltypes/expressions.go @@ -246,6 +246,8 @@ func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { return evalResult{typ: Float64, fval: fval}, nil case VarChar: return evalResult{typ: VarChar, str: string(val.Value)}, nil + case VarBinary: + return evalResult{typ: VarBinary, str: string(val.Value)}, nil } return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Type is not supported") } diff --git a/go/vt/sqlparser/expression_converter.go b/go/vt/sqlparser/expression_converter.go index d04a62c69ac..dcf118bfe3c 100644 --- a/go/vt/sqlparser/expression_converter.go +++ b/go/vt/sqlparser/expression_converter.go @@ -16,17 +16,23 @@ limitations under the License. package sqlparser -import "vitess.io/vitess/go/sqltypes" +import ( + "fmt" + + "vitess.io/vitess/go/sqltypes" +) + +var ExprNotSupported = fmt.Errorf("Expr Not Supported") //Convert converts between AST expressions and executable expressions -func Convert(e Expr) sqltypes.Expr { +func Convert(e Expr) (sqltypes.Expr, error) { switch node := e.(type) { case *SQLVal: switch node.Type { case IntVal: - return &sqltypes.LiteralInt{Val: node.Val} + return &sqltypes.LiteralInt{Val: node.Val}, nil case ValArg: - return &sqltypes.BindVariable{Key: string(node.Val[1:])} + return &sqltypes.BindVariable{Key: string(node.Val[1:])}, nil } case *BinaryExpr: var op sqltypes.BinaryExpr @@ -40,14 +46,22 @@ func Convert(e Expr) sqltypes.Expr { case DivStr: op = &sqltypes.Division{} default: - return nil + return nil, ExprNotSupported + } + left, err := Convert(node.Left) + if err != nil { + return nil, err + } + right, err := Convert(node.Right) + if err != nil { + return nil, err } return &sqltypes.BinaryOp{ Expr: op, - Left: Convert(node.Left), - Right: Convert(node.Right), - } + Left: left, + Right: right, + }, nil } - return nil + return nil, ExprNotSupported } diff --git a/go/vt/sqlparser/expressions_test.go b/go/vt/sqlparser/expressions_test.go index 3e65b3ca76f..ebe30a8b750 100644 --- a/go/vt/sqlparser/expressions_test.go +++ b/go/vt/sqlparser/expressions_test.go @@ -60,7 +60,7 @@ func TestEvaluate(t *testing.T) { expected: sqltypes.NewUint64(22), }, { expression: ":string_bind_variable", - expected: sqltypes.NewVarChar("bar"), + expected: sqltypes.NewVarBinary("bar"), }, { expression: ":float_bind_variable", expected: sqltypes.NewFloat64(2.2), @@ -72,7 +72,8 @@ func TestEvaluate(t *testing.T) { stmt, err := Parse("select " + test.expression) require.NoError(t, err) astExpr := stmt.(*Select).SelectExprs[0].(*AliasedExpr).Expr - sqltypesExpr := Convert(astExpr) + sqltypesExpr, err := Convert(astExpr) + require.Nil(t, err) require.NotNil(t, sqltypesExpr) env := sqltypes.ExpressionEnv{ BindVars: map[string]*querypb.BindVariable{ diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index e2af8acd230..35db5a94420 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -11,6 +11,7 @@ type Projection struct { Cols []string Exprs []sqltypes.Expr Input Primitive + noTxNeeded } func (p *Projection) RouteType() string { diff --git a/go/vt/vtgate/engine/singlerow.go b/go/vt/vtgate/engine/singlerow.go index 8d54f25ec6a..d7c2e1fefe4 100644 --- a/go/vt/vtgate/engine/singlerow.go +++ b/go/vt/vtgate/engine/singlerow.go @@ -26,6 +26,7 @@ var _ Primitive = (*SingleRow)(nil) // SingleRow defines an empty result type SingleRow struct { noInputs + noTxNeeded } // RouteType returns a description of the query routing type used by the primitive diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index f5a10fd4b0a..3e5eb02f5ce 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -254,10 +254,10 @@ func TestSelectUserDefindVariable(t *testing.T) { require.NoError(t, err) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "@foo", Type: sqltypes.VarChar}, + {Name: "@foo", Type: sqltypes.VarBinary}, }, Rows: [][]sqltypes.Value{{ - sqltypes.NewVarChar("bar"), + sqltypes.NewVarBinary("bar"), }}, } require.NoError(t, err) @@ -386,10 +386,10 @@ func TestSelectDatabase(t *testing.T) { map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "database()", Type: sqltypes.VarChar}, + {Name: "database()", Type: sqltypes.VarBinary}, }, Rows: [][]sqltypes.Value{{ - sqltypes.NewVarChar("TestExecutor"), + sqltypes.NewVarBinary("TestExecutor"), }}, } require.NoError(t, err) diff --git a/go/vt/vtgate/plan_executor_select_test.go b/go/vt/vtgate/plan_executor_select_test.go index 25a1118cfb9..e3d7ac5bb3d 100644 --- a/go/vt/vtgate/plan_executor_select_test.go +++ b/go/vt/vtgate/plan_executor_select_test.go @@ -255,10 +255,10 @@ func TestPlanSelectUserDefindVariable(t *testing.T) { require.NoError(t, err) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "@foo", Type: sqltypes.VarChar}, + {Name: "@foo", Type: sqltypes.VarBinary}, }, Rows: [][]sqltypes.Value{{ - sqltypes.NewVarChar("bar"), + sqltypes.NewVarBinary("bar"), }}, } require.NoError(t, err) @@ -388,10 +388,10 @@ func TestPlanSelectDatabase(t *testing.T) { map[string]*querypb.BindVariable{}) wantResult := &sqltypes.Result{ Fields: []*querypb.Field{ - {Name: "database()", Type: sqltypes.VarChar}, + {Name: "database()", Type: sqltypes.VarBinary}, }, Rows: [][]sqltypes.Value{{ - sqltypes.NewVarChar("TestExecutor"), + sqltypes.NewVarBinary("TestExecutor"), }}, } require.NoError(t, err) diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 21941c8e84a..65da3ce426e 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -21,6 +21,7 @@ import ( "fmt" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" ) @@ -74,19 +75,12 @@ func buildSelectPlan(stmt sqlparser.Statement, vschema ContextVSchema) (engine.P // pushed into a route, then a primitive is created on top of any // of the above trees to make it discard unwanted rows. func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab, canRunLocally bool) error { - if checkForDual(sel) && outer == nil && canRunLocally { - exprs := make([]sqltypes.Expr, len(sel.SelectExprs)) - cols := make([]string, len(sel.SelectExprs)) - for i, e := range sel.SelectExprs { - expr := e.(*sqlparser.AliasedExpr) - exprs[i] = sqlparser.Convert(expr.Expr) - cols[i] = expr.As.String() - } - pb.bldr = &vtgateExecution{ - exprs, - cols, - } - return nil + err, done := pb.runAtVtgate(sel, outer, canRunLocally) + if err == sqlparser.ExprNotSupported { + log.Warningf("Expression not supported at vtgate level") + } + if done { + return err } if err := pb.processTableExprs(sel.From); err != nil { return err @@ -136,6 +130,28 @@ func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab, return nil } +func (pb *primitiveBuilder) runAtVtgate(sel *sqlparser.Select, outer *symtab, canRunLocally bool) (error, bool) { + if checkForDual(sel) && outer == nil && canRunLocally { + var err error + exprs := make([]sqltypes.Expr, len(sel.SelectExprs)) + cols := make([]string, len(sel.SelectExprs)) + for i, e := range sel.SelectExprs { + expr := e.(*sqlparser.AliasedExpr) + exprs[i], err = sqlparser.Convert(expr.Expr) + if err != nil { + return err, false + } + cols[i] = expr.As.String() + } + pb.bldr = &vtgateExecution{ + exprs, + cols, + } + return nil, true + } + return nil, false +} + func checkForDual(sel *sqlparser.Select) bool { if len(sel.From) == 1 { if from, ok := sel.From[0].(*sqlparser.AliasedTableExpr); ok { From 085f7ff838bf0e0e62c7bb614fa78878195a60aa Mon Sep 17 00:00:00 2001 From: Rohit Nayak Date: Tue, 14 Apr 2020 10:32:54 +0200 Subject: [PATCH 15/23] Removed unused test Signed-off-by: Rohit Nayak --- go/vt/vtgate/engine/projection_test.go | 30 -------------------------- 1 file changed, 30 deletions(-) delete mode 100644 go/vt/vtgate/engine/projection_test.go diff --git a/go/vt/vtgate/engine/projection_test.go b/go/vt/vtgate/engine/projection_test.go deleted file mode 100644 index 9f008eaa38c..00000000000 --- a/go/vt/vtgate/engine/projection_test.go +++ /dev/null @@ -1,30 +0,0 @@ -/* -Copyright 2020 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package engine - -import ( - "testing" -) - -func TestEvaluate(t *testing.T) { - //statement, _ := sqlparser.Parse("select 42") - //sel := statement.(*sqlparser.Select) - //exp := sel.SelectExprs[0].(*sqlparser.AliasedExpr) - // - //result := Evaluate(exp.Expr, map[string]*query.BindVariable{}, []sqltypes.Value{}) - //fmt.Println(result) -} From 09e5e4e1df4e4b1ab1fd4458b22b6093c9f72a7c Mon Sep 17 00:00:00 2001 From: Rohit Nayak Date: Tue, 14 Apr 2020 11:08:22 +0200 Subject: [PATCH 16/23] Extract vtgate execution from primitive builder Signed-off-by: Rohit Nayak --- go/vt/vtgate/planbuilder/expr.go | 4 +- go/vt/vtgate/planbuilder/from.go | 2 +- go/vt/vtgate/planbuilder/select.go | 34 ++++---- go/vt/vtgate/planbuilder/union.go | 2 +- go/vt/vtgate/planbuilder/vtgate_execution.go | 82 -------------------- 5 files changed, 20 insertions(+), 104 deletions(-) delete mode 100644 go/vt/vtgate/planbuilder/vtgate_execution.go diff --git a/go/vt/vtgate/planbuilder/expr.go b/go/vt/vtgate/planbuilder/expr.go index e0640be45bf..e7a51664e6b 100644 --- a/go/vt/vtgate/planbuilder/expr.go +++ b/go/vt/vtgate/planbuilder/expr.go @@ -104,7 +104,7 @@ func (pb *primitiveBuilder) findOrigin(expr sqlparser.Expr) (pullouts []*pullout spb := newPrimitiveBuilder(pb.vschema, pb.jt) switch stmt := node.Select.(type) { case *sqlparser.Select: - if err := spb.processSelect(stmt, pb.st, false); err != nil { + if err := spb.processSelect(stmt, pb.st); err != nil { return false, err } case *sqlparser.Union: @@ -230,7 +230,7 @@ func (pb *primitiveBuilder) finalizeUnshardedDMLSubqueries(nodes ...sqlparser.SQ return true, nil } spb := newPrimitiveBuilder(pb.vschema, pb.jt) - if err := spb.processSelect(nodeType, pb.st, false); err != nil { + if err := spb.processSelect(nodeType, pb.st); err != nil { samePlan = false return false, err } diff --git a/go/vt/vtgate/planbuilder/from.go b/go/vt/vtgate/planbuilder/from.go index b7013d88947..772d224f13a 100644 --- a/go/vt/vtgate/planbuilder/from.go +++ b/go/vt/vtgate/planbuilder/from.go @@ -97,7 +97,7 @@ func (pb *primitiveBuilder) processAliasedTable(tableExpr *sqlparser.AliasedTabl spb := newPrimitiveBuilder(pb.vschema, pb.jt) switch stmt := expr.Select.(type) { case *sqlparser.Select: - if err := spb.processSelect(stmt, nil, false); err != nil { + if err := spb.processSelect(stmt, nil); err != nil { return err } case *sqlparser.Union: diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 65da3ce426e..6cbdf7f29e1 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -21,7 +21,6 @@ import ( "fmt" "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" ) @@ -29,8 +28,14 @@ import ( // buildSelectPlan is the new function to build a Select plan. func buildSelectPlan(stmt sqlparser.Statement, vschema ContextVSchema) (engine.Primitive, error) { sel := stmt.(*sqlparser.Select) + + p := tryAtVtgate(sel) + if p != nil { + return p, nil + } + pb := newPrimitiveBuilder(vschema, newJointab(sqlparser.GetBindvars(sel))) - if err := pb.processSelect(sel, nil, true); err != nil { + if err := pb.processSelect(sel, nil); err != nil { return nil, err } if err := pb.bldr.Wireup(pb.bldr, pb.jt); err != nil { @@ -74,14 +79,7 @@ func buildSelectPlan(stmt sqlparser.Statement, vschema ContextVSchema) (engine.P // The LIMIT clause is the last construct of a query. If it cannot be // pushed into a route, then a primitive is created on top of any // of the above trees to make it discard unwanted rows. -func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab, canRunLocally bool) error { - err, done := pb.runAtVtgate(sel, outer, canRunLocally) - if err == sqlparser.ExprNotSupported { - log.Warningf("Expression not supported at vtgate level") - } - if done { - return err - } +func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab) error { if err := pb.processTableExprs(sel.From); err != nil { return err } @@ -130,8 +128,8 @@ func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab, return nil } -func (pb *primitiveBuilder) runAtVtgate(sel *sqlparser.Select, outer *symtab, canRunLocally bool) (error, bool) { - if checkForDual(sel) && outer == nil && canRunLocally { +func tryAtVtgate(sel *sqlparser.Select) engine.Primitive { + if checkForDual(sel) { var err error exprs := make([]sqltypes.Expr, len(sel.SelectExprs)) cols := make([]string, len(sel.SelectExprs)) @@ -139,17 +137,17 @@ func (pb *primitiveBuilder) runAtVtgate(sel *sqlparser.Select, outer *symtab, ca expr := e.(*sqlparser.AliasedExpr) exprs[i], err = sqlparser.Convert(expr.Expr) if err != nil { - return err, false + return nil } cols[i] = expr.As.String() } - pb.bldr = &vtgateExecution{ - exprs, - cols, + return &engine.Projection{ + Exprs: exprs, + Cols: cols, + Input: &engine.SingleRow{}, } - return nil, true } - return nil, false + return nil } func checkForDual(sel *sqlparser.Select) bool { diff --git a/go/vt/vtgate/planbuilder/union.go b/go/vt/vtgate/planbuilder/union.go index a6f7de1b472..71fb9e4182d 100644 --- a/go/vt/vtgate/planbuilder/union.go +++ b/go/vt/vtgate/planbuilder/union.go @@ -62,7 +62,7 @@ func (pb *primitiveBuilder) processPart(part sqlparser.SelectStatement, outer *s case *sqlparser.Union: return pb.processUnion(part, outer) case *sqlparser.Select: - return pb.processSelect(part, outer, false) + return pb.processSelect(part, outer) case *sqlparser.ParenSelect: return pb.processPart(part.Select, outer) } diff --git a/go/vt/vtgate/planbuilder/vtgate_execution.go b/go/vt/vtgate/planbuilder/vtgate_execution.go deleted file mode 100644 index 5785c2eaa89..00000000000 --- a/go/vt/vtgate/planbuilder/vtgate_execution.go +++ /dev/null @@ -1,82 +0,0 @@ -package planbuilder - -import ( - "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/vtgate/engine" -) - -var _ builder = (*vtgateExecution)(nil) - -type vtgateExecution struct { - Exprs []sqltypes.Expr - Cols []string -} - -func (v *vtgateExecution) Order() int { - panic("implement me") -} - -func (v *vtgateExecution) ResultColumns() []*resultColumn { - panic("implement me") -} - -func (v *vtgateExecution) Reorder(int) { - panic("implement me") -} - -func (v *vtgateExecution) First() builder { - panic("implement me") -} - -func (v *vtgateExecution) PushFilter(pb *primitiveBuilder, filter sqlparser.Expr, whereType string, origin builder) error { - panic("implement me") -} - -func (v *vtgateExecution) PushSelect(pb *primitiveBuilder, expr *sqlparser.AliasedExpr, origin builder) (rc *resultColumn, colNumber int, err error) { - panic("implement me") -} - -func (v *vtgateExecution) MakeDistinct() error { - panic("implement me") -} - -func (v *vtgateExecution) PushGroupBy(sqlparser.GroupBy) error { - panic("implement me") -} - -func (v *vtgateExecution) PushOrderBy(sqlparser.OrderBy) (builder, error) { - panic("implement me") -} - -func (v *vtgateExecution) SetUpperLimit(count *sqlparser.SQLVal) { - panic("implement me") -} - -func (v *vtgateExecution) PushMisc(sel *sqlparser.Select) { - panic("implement me") -} - -func (v *vtgateExecution) Wireup(bldr builder, jt *jointab) error { - return nil -} - -func (v *vtgateExecution) SupplyVar(from, to int, col *sqlparser.ColName, varname string) { - panic("implement me") -} - -func (v *vtgateExecution) SupplyCol(col *sqlparser.ColName) (rc *resultColumn, colNumber int) { - panic("implement me") -} - -func (v *vtgateExecution) SupplyWeightString(colNumber int) (weightcolNumber int, err error) { - panic("implement me") -} - -func (v *vtgateExecution) Primitive() engine.Primitive { - return &engine.Projection{ - Exprs: v.Exprs, - Cols: v.Cols, - Input: &engine.SingleRow{}, - } -} From ffddcc52962b8bbe5a25be4aaf2ea410a7732cee Mon Sep 17 00:00:00 2001 From: Rohit Nayak Date: Tue, 14 Apr 2020 13:44:12 +0200 Subject: [PATCH 17/23] Support for literal float Signed-off-by: Rohit Nayak --- go/sqltypes/arithmetic_test.go | 4 ---- go/sqltypes/expressions.go | 22 +++++++++++++++++++++ go/vt/sqlparser/expression_converter.go | 2 ++ go/vt/sqlparser/expressions_test.go | 3 +++ go/vt/vtgate/endtoend/database_func_test.go | 2 +- 5 files changed, 28 insertions(+), 5 deletions(-) diff --git a/go/sqltypes/arithmetic_test.go b/go/sqltypes/arithmetic_test.go index 50d8cfde79d..bb4ad5d0d90 100644 --- a/go/sqltypes/arithmetic_test.go +++ b/go/sqltypes/arithmetic_test.go @@ -1191,10 +1191,6 @@ func TestCastFromNumeric(t *testing.T) { typ: Decimal, v: evalResult{typ: Float64, fval: 1.2e-16}, out: TestValue(Decimal, "0.00000000000000012"), - }, { - typ: VarBinary, - v: evalResult{typ: Int64, ival: 1}, - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion to non-evalResult: VARBINARY"), }} for _, tcase := range tcases { got := castFromNumeric(tcase.v, tcase.typ) diff --git a/go/sqltypes/expressions.go b/go/sqltypes/expressions.go index 6793e3a1343..00871c10718 100644 --- a/go/sqltypes/expressions.go +++ b/go/sqltypes/expressions.go @@ -58,6 +58,7 @@ type ( // Expressions LiteralInt struct{ Val []byte } + LiteralFloat struct{ Val EvalResult } BindVariable struct{ Key string } BinaryOp struct { Expr BinaryExpr @@ -76,7 +77,16 @@ func (e EvalResult) Value() Value { return castFromNumeric(e, e.typ) } +func NewLiteralFloat(val []byte) (Expr, error) { + fval, err := strconv.ParseFloat(string(val), 64) + if err != nil { + return nil, err + } + return &LiteralFloat{evalResult{typ: Float64, fval: fval}}, nil +} + var _ Expr = (*LiteralInt)(nil) +var _ Expr = (*LiteralFloat)(nil) var _ Expr = (*BindVariable)(nil) var _ Expr = (*BinaryOp)(nil) @@ -107,6 +117,10 @@ func (l *LiteralInt) Evaluate(ExpressionEnv) (EvalResult, error) { return evalResult{typ: Int64, ival: ival}, nil } +func (l *LiteralFloat) Evaluate(env ExpressionEnv) (EvalResult, error) { + return l.Val, nil +} + //Evaluate implements the Expr interface func (b *BindVariable) Evaluate(env ExpressionEnv) (EvalResult, error) { val, ok := env.BindVars[b.Key] @@ -175,6 +189,10 @@ func (l *LiteralInt) Type(_ ExpressionEnv) querypb.Type { return Int64 } +func (l *LiteralFloat) Type(env ExpressionEnv) querypb.Type { + return Float64 +} + //String implements the BinaryExpr interface func (d *Division) String() string { return "/" @@ -210,6 +228,10 @@ func (l *LiteralInt) String() string { return string(l.Val) } +func (l *LiteralFloat) String() string { + return l.Val.Value().String() +} + func mergeNumericalTypes(ltype, rtype querypb.Type) querypb.Type { switch ltype { case Int64: diff --git a/go/vt/sqlparser/expression_converter.go b/go/vt/sqlparser/expression_converter.go index dcf118bfe3c..d54cf1d1294 100644 --- a/go/vt/sqlparser/expression_converter.go +++ b/go/vt/sqlparser/expression_converter.go @@ -31,6 +31,8 @@ func Convert(e Expr) (sqltypes.Expr, error) { switch node.Type { case IntVal: return &sqltypes.LiteralInt{Val: node.Val}, nil + case FloatVal: + return sqltypes.NewLiteralFloat(node.Val) case ValArg: return &sqltypes.BindVariable{Key: string(node.Val[1:])}, nil } diff --git a/go/vt/sqlparser/expressions_test.go b/go/vt/sqlparser/expressions_test.go index ebe30a8b750..ca0940c8545 100644 --- a/go/vt/sqlparser/expressions_test.go +++ b/go/vt/sqlparser/expressions_test.go @@ -40,6 +40,9 @@ func TestEvaluate(t *testing.T) { tests := []testCase{{ expression: "42", expected: sqltypes.NewInt64(42), + }, { + expression: "42.42", + expected: sqltypes.NewFloat64(42.42), }, { expression: "40+2", expected: sqltypes.NewInt64(42), diff --git a/go/vt/vtgate/endtoend/database_func_test.go b/go/vt/vtgate/endtoend/database_func_test.go index e0cb0805f59..3a3b27c65f8 100644 --- a/go/vt/vtgate/endtoend/database_func_test.go +++ b/go/vt/vtgate/endtoend/database_func_test.go @@ -34,7 +34,7 @@ func TestDatabaseFunc(t *testing.T) { exec(t, conn, "use ks") qr := exec(t, conn, "select database()") - if got, want := fmt.Sprintf("%v", qr.Rows), `[[VARCHAR("ks")]]`; got != want { + if got, want := fmt.Sprintf("%v", qr.Rows), `[[VARBINARY("ks")]]`; got != want { t.Errorf("select:\n%v want\n%v", got, want) } } From 5efc63a35e9a60ddc1c9b26fa515f24bf0c3ea02 Mon Sep 17 00:00:00 2001 From: Rohit Nayak Date: Tue, 14 Apr 2020 14:28:09 +0200 Subject: [PATCH 18/23] Compute literal int during plan evaluation Signed-off-by: Rohit Nayak --- go/sqltypes/expressions.go | 18 +++++++++++------- go/vt/sqlparser/expression_converter.go | 2 +- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/go/sqltypes/expressions.go b/go/sqltypes/expressions.go index 00871c10718..147afe298d6 100644 --- a/go/sqltypes/expressions.go +++ b/go/sqltypes/expressions.go @@ -57,7 +57,7 @@ type ( } // Expressions - LiteralInt struct{ Val []byte } + LiteralInt struct{ Val EvalResult } LiteralFloat struct{ Val EvalResult } BindVariable struct{ Key string } BinaryOp struct { @@ -77,6 +77,14 @@ func (e EvalResult) Value() Value { return castFromNumeric(e, e.typ) } +func NewLiteralInt(val []byte) (Expr, error) { + ival, err := strconv.ParseInt(string(val), 10, 64) + if err != nil { + return nil, err + } + return &LiteralFloat{evalResult{typ: Int64, ival: ival}}, nil +} + func NewLiteralFloat(val []byte) (Expr, error) { fval, err := strconv.ParseFloat(string(val), 64) if err != nil { @@ -110,11 +118,7 @@ func (b *BinaryOp) Evaluate(env ExpressionEnv) (EvalResult, error) { //Evaluate implements the Expr interface func (l *LiteralInt) Evaluate(ExpressionEnv) (EvalResult, error) { - ival, err := strconv.ParseInt(string(l.Val), 10, 64) - if err != nil { - ival = 0 - } - return evalResult{typ: Int64, ival: ival}, nil + return l.Val, nil } func (l *LiteralFloat) Evaluate(env ExpressionEnv) (EvalResult, error) { @@ -225,7 +229,7 @@ func (b *BindVariable) String() string { //String implements the Expr interface func (l *LiteralInt) String() string { - return string(l.Val) + return l.Val.Value().String() } func (l *LiteralFloat) String() string { diff --git a/go/vt/sqlparser/expression_converter.go b/go/vt/sqlparser/expression_converter.go index d54cf1d1294..10c200fbb1a 100644 --- a/go/vt/sqlparser/expression_converter.go +++ b/go/vt/sqlparser/expression_converter.go @@ -30,7 +30,7 @@ func Convert(e Expr) (sqltypes.Expr, error) { case *SQLVal: switch node.Type { case IntVal: - return &sqltypes.LiteralInt{Val: node.Val}, nil + return sqltypes.NewLiteralInt(node.Val) case FloatVal: return sqltypes.NewLiteralFloat(node.Val) case ValArg: From f0711f1ba7a897339f9eb3c8b8bf30de7bdf3777 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Tue, 14 Apr 2020 09:54:45 -0700 Subject: [PATCH 19/23] Fix SingleRow's select cases test Signed-off-by: Saif Alharthi --- go/vt/vtgate/planbuilder/testdata/select_cases.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index 17afad5a7a5..c4bc3f25327 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -1368,7 +1368,7 @@ "" ], "Expressions": [ - "42" + "INT64(42)" ], "Inputs": [ { From faae89abcbe06c931d33ea50fcdd1a33bd63238a Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Tue, 14 Apr 2020 22:22:43 -0700 Subject: [PATCH 20/23] Add StreamExecute for projection and add unreachable comment to arithmetic Addition and Subtraction type check Signed-off-by: Saif Alharthi --- go/sqltypes/expressions.go | 2 +- go/vt/vtgate/engine/projection.go | 27 ++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/go/sqltypes/expressions.go b/go/sqltypes/expressions.go index 147afe298d6..4f3fb6c735e 100644 --- a/go/sqltypes/expressions.go +++ b/go/sqltypes/expressions.go @@ -287,5 +287,5 @@ func addAndSubtractType(left querypb.Type) querypb.Type { case Float64: return Float64 } - panic("oops") + panic("unreachable") } diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index 35db5a94420..9257a9dc02d 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -56,7 +56,32 @@ func (p *Projection) Execute(vcursor VCursor, bindVars map[string]*querypb.BindV } func (p *Projection) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { - panic("implement me") + result, err := p.Input.Execute(vcursor, bindVars, wantields) + if err != nil { + return err + } + + env := sqltypes.ExpressionEnv{ + BindVars: bindVars, + } + + if wantields { + p.addFields(result, bindVars) + } + var rows [][]sqltypes.Value + for _, row := range result.Rows { + env.Row = row + for _, exp := range p.Exprs { + result, err := exp.Evaluate(env) + if err != nil { + return err + } + row = append(row, result.Value()) + } + rows = append(rows, row) + } + result.Rows = rows + return callback(result) } func (p *Projection) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { From d391f0a3549e877bf98caf01f9e668d43a2b55dc Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Thu, 16 Apr 2020 21:32:40 -0700 Subject: [PATCH 21/23] Move evaluation expression to its own package Signed-off-by: Saif Alharthi --- go/mysql/endtoend/replication_test.go | 4 +- go/sqltypes/arithmetic_test.go | 1517 ---------------- go/sqltypes/expressions_test.go | 104 -- go/sqltypes/type.go | 4 +- go/sqltypes/type_test.go | 2 +- go/test/endtoend/messaging/message_test.go | 6 +- .../verticalsplit/vertical_split_test.go | 4 +- go/vt/binlog/binlogplayer/binlog_player.go | 6 +- go/vt/binlog/keyspace_id_resolver.go | 4 +- go/vt/mysqlctl/replication.go | 5 +- go/vt/mysqlctl/schema.go | 9 +- go/vt/schemamanager/schemaswap/schema_swap.go | 8 +- go/vt/sqlparser/expression_converter.go | 22 +- go/vt/sqlparser/expressions_test.go | 4 +- go/vt/vitessdriver/convert.go | 4 +- go/vt/vtgate/endtoend/last_insert_id_test.go | 8 +- go/vt/vtgate/engine/insert.go | 4 +- go/vt/vtgate/engine/limit.go | 6 +- go/vt/vtgate/engine/memory_sort.go | 6 +- go/vt/vtgate/engine/merge_sort.go | 4 +- go/vt/vtgate/engine/ordered_aggregate.go | 18 +- go/vt/vtgate/engine/projection.go | 9 +- go/vt/vtgate/engine/route.go | 4 +- go/vt/vtgate/engine/vindex_func.go | 4 +- .../vtgate/evalengine}/arithmetic.go | 349 ++-- go/vt/vtgate/evalengine/arithmetic_test.go | 1519 +++++++++++++++++ .../vtgate/evalengine}/expressions.go | 62 +- go/vt/vtgate/evalengine/expressions_test.go | 106 ++ go/vt/vtgate/planbuilder/select.go | 5 +- go/vt/vtgate/vindexes/consistent_lookup.go | 4 +- go/vt/vtgate/vindexes/hash.go | 6 +- go/vt/vtgate/vindexes/lookup_hash.go | 6 +- .../vindexes/lookup_unicodeloosemd5_hash.go | 6 +- go/vt/vtgate/vindexes/numeric.go | 6 +- go/vt/vtgate/vindexes/numeric_static_map.go | 6 +- go/vt/vtgate/vindexes/region_experimental.go | 6 +- go/vt/vtgate/vindexes/region_json.go | 4 +- go/vt/vtgate/vindexes/reverse_bits.go | 6 +- go/vt/vttablet/heartbeat/reader.go | 4 +- .../tabletmanager/vreplication/engine.go | 4 +- .../tabletmanager/vreplication/vreplicator.go | 4 +- .../tabletserver/messager/message_manager.go | 10 +- .../messager/message_manager_test.go | 4 +- go/vt/vttablet/tabletserver/query_executor.go | 8 +- go/vt/vttablet/tabletserver/rules/rules.go | 6 +- go/vt/vttablet/tabletserver/schema/engine.go | 7 +- go/vt/vttablet/tabletserver/twopc.go | 18 +- .../tabletserver/vstreamer/planbuilder.go | 4 +- go/vt/worker/chunk.go | 6 +- go/vt/worker/diff_utils.go | 6 +- go/vt/worker/key_resolver.go | 3 +- go/vt/worker/split_clone_flaky_test.go | 4 +- go/vt/wrangler/materializer.go | 4 +- go/vt/wrangler/stream_migrater.go | 4 +- go/vt/wrangler/traffic_switcher.go | 4 +- go/vt/wrangler/vdiff.go | 4 +- 56 files changed, 2021 insertions(+), 1940 deletions(-) delete mode 100644 go/sqltypes/arithmetic_test.go delete mode 100644 go/sqltypes/expressions_test.go rename go/{sqltypes => vt/vtgate/evalengine}/arithmetic.go (66%) create mode 100644 go/vt/vtgate/evalengine/arithmetic_test.go rename go/{sqltypes => vt/vtgate/evalengine}/expressions.go (86%) create mode 100644 go/vt/vtgate/evalengine/expressions_test.go diff --git a/go/mysql/endtoend/replication_test.go b/go/mysql/endtoend/replication_test.go index df5debcc3d3..11a7a3c3927 100644 --- a/go/mysql/endtoend/replication_test.go +++ b/go/mysql/endtoend/replication_test.go @@ -25,6 +25,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" @@ -68,7 +70,7 @@ func connectForReplication(t *testing.T, rbr bool) (*mysql.Conn, mysql.BinlogFor t.Fatalf("SHOW MASTER STATUS returned unexpected result: %v", result) } file := result.Rows[0][0].ToString() - position, err := sqltypes.ToUint64(result.Rows[0][1]) + position, err := evalengine.ToUint64(result.Rows[0][1]) if err != nil { t.Fatalf("SHOW MASTER STATUS returned invalid position: %v", result.Rows[0][1]) } diff --git a/go/sqltypes/arithmetic_test.go b/go/sqltypes/arithmetic_test.go deleted file mode 100644 index bb4ad5d0d90..00000000000 --- a/go/sqltypes/arithmetic_test.go +++ /dev/null @@ -1,1517 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package sqltypes - -import ( - "encoding/binary" - "fmt" - "math" - "reflect" - "strconv" - "testing" - - querypb "vitess.io/vitess/go/vt/proto/query" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" -) - -func TestDivide(t *testing.T) { - tcases := []struct { - v1, v2 Value - out Value - err error - }{{ - //All Nulls - v1: NULL, - v2: NULL, - out: NULL, - }, { - // First value null. - v1: NULL, - v2: NewInt32(1), - out: NULL, - }, { - // Second value null. - v1: NewInt32(1), - v2: NULL, - out: NULL, - }, { - // Second arg 0 - v1: NewInt32(5), - v2: NewInt32(0), - out: NULL, - }, { - // Both arguments zero - v1: NewInt32(0), - v2: NewInt32(0), - out: NULL, - }, { - // case with negative value - v1: NewInt64(-1), - v2: NewInt64(-2), - out: NewFloat64(0.5000), - }, { - // float64 division by zero - v1: NewFloat64(2), - v2: NewFloat64(0), - out: NULL, - }, { - // Lower bound for int64 - v1: NewInt64(math.MinInt64), - v2: NewInt64(1), - out: NewFloat64(math.MinInt64), - }, { - // upper bound for uint64 - v1: NewUint64(math.MaxUint64), - v2: NewUint64(1), - out: NewFloat64(math.MaxUint64), - }, { - // testing for error in types - v1: TestValue(Int64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for error in types - v1: NewInt64(2), - v2: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for uint/int - v1: NewUint64(4), - v2: NewInt64(5), - out: NewFloat64(0.8), - }, { - // testing for uint/uint - v1: NewUint64(1), - v2: NewUint64(2), - out: NewFloat64(0.5), - }, { - // testing for float64/int64 - v1: TestValue(Float64, "1.2"), - v2: NewInt64(-2), - out: NewFloat64(-0.6), - }, { - // testing for float64/uint64 - v1: TestValue(Float64, "1.2"), - v2: NewUint64(2), - out: NewFloat64(0.6), - }, { - // testing for overflow of float64 - v1: NewFloat64(math.MaxFloat64), - v2: NewFloat64(0.5), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in 1.7976931348623157e+308 / 0.5"), - }} - - for _, tcase := range tcases { - got, err := Divide(tcase.v1, tcase.v2) - - if !vterrors.Equals(err, tcase.err) { - t.Errorf("%v %v %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err)) - t.Errorf("Divide(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Divide(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } - -} - -func TestMultiply(t *testing.T) { - tcases := []struct { - v1, v2 Value - out Value - err error - }{{ - //All Nulls - v1: NULL, - v2: NULL, - out: NULL, - }, { - // First value null. - v1: NewInt32(1), - v2: NULL, - out: NULL, - }, { - // Second value null. - v1: NULL, - v2: NewInt32(1), - out: NULL, - }, { - // case with negative value - v1: NewInt64(-1), - v2: NewInt64(-2), - out: NewInt64(2), - }, { - // testing for int64 overflow with min negative value - v1: NewInt64(math.MinInt64), - v2: NewInt64(1), - out: NewInt64(math.MinInt64), - }, { - // testing for error in types - v1: TestValue(Int64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for error in types - v1: NewInt64(2), - v2: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for uint*int - v1: NewUint64(4), - v2: NewInt64(5), - out: NewUint64(20), - }, { - // testing for uint*uint - v1: NewUint64(1), - v2: NewUint64(2), - out: NewUint64(2), - }, { - // testing for float64*int64 - v1: TestValue(Float64, "1.2"), - v2: NewInt64(-2), - out: NewFloat64(-2.4), - }, { - // testing for float64*uint64 - v1: TestValue(Float64, "1.2"), - v2: NewUint64(2), - out: NewFloat64(2.4), - }, { - // testing for overflow of int64 - v1: NewInt64(math.MaxInt64), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in 9223372036854775807 * 2"), - }, { - // testing for underflow of uint64*max.uint64 - v1: NewInt64(2), - v2: NewUint64(math.MaxUint64), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 * 2"), - }, { - v1: NewUint64(math.MaxUint64), - v2: NewUint64(1), - out: NewUint64(math.MaxUint64), - }, { - //Checking whether maxInt value can be passed as uint value - v1: NewUint64(math.MaxInt64), - v2: NewInt64(3), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 9223372036854775807 * 3"), - }} - - for _, tcase := range tcases { - - got, err := Multiply(tcase.v1, tcase.v2) - - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Multiply(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Multiply(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } - -} - -func TestSubtract(t *testing.T) { - tcases := []struct { - v1, v2 Value - out Value - err error - }{{ - // All Nulls - v1: NULL, - v2: NULL, - out: NULL, - }, { - // First value null. - v1: NewInt32(1), - v2: NULL, - out: NULL, - }, { - // Second value null. - v1: NULL, - v2: NewInt32(1), - out: NULL, - }, { - // case with negative value - v1: NewInt64(-1), - v2: NewInt64(-2), - out: NewInt64(1), - }, { - // testing for int64 overflow with min negative value - v1: NewInt64(math.MinInt64), - v2: NewInt64(1), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 - 1"), - }, { - v1: NewUint64(4), - v2: NewInt64(5), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 4 - 5"), - }, { - // testing uint - int - v1: NewUint64(7), - v2: NewInt64(5), - out: NewUint64(2), - }, { - v1: NewUint64(math.MaxUint64), - v2: NewInt64(0), - out: NewUint64(math.MaxUint64), - }, { - // testing for int64 overflow - v1: NewInt64(math.MinInt64), - v2: NewUint64(0), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in -9223372036854775808 - 0"), - }, { - v1: TestValue(VarChar, "c"), - v2: NewInt64(1), - out: NewInt64(-1), - }, { - v1: NewUint64(1), - v2: TestValue(VarChar, "c"), - out: NewUint64(1), - }, { - // testing for error for parsing float value to uint64 - v1: TestValue(Uint64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), - }, { - // testing for error for parsing float value to uint64 - v1: NewUint64(2), - v2: TestValue(Uint64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), - }, { - // uint64 - uint64 - v1: NewUint64(8), - v2: NewUint64(4), - out: NewUint64(4), - }, { - // testing for float subtraction: float - int - v1: NewFloat64(1.2), - v2: NewInt64(2), - out: NewFloat64(-0.8), - }, { - // testing for float subtraction: float - uint - v1: NewFloat64(1.2), - v2: NewUint64(2), - out: NewFloat64(-0.8), - }, { - v1: NewInt64(-1), - v2: NewUint64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in -1 - 2"), - }, { - v1: NewInt64(2), - v2: NewUint64(1), - out: NewUint64(1), - }, { - // testing int64 - float64 method - v1: NewInt64(-2), - v2: NewFloat64(1.0), - out: NewFloat64(-3.0), - }, { - // testing uint64 - float64 method - v1: NewUint64(1), - v2: NewFloat64(-2.0), - out: NewFloat64(3.0), - }, { - // testing uint - int to return uintplusint - v1: NewUint64(1), - v2: NewInt64(-2), - out: NewUint64(3), - }, { - // testing for float - float - v1: NewFloat64(1.2), - v2: NewFloat64(3.2), - out: NewFloat64(-2), - }, { - // testing uint - uint if v2 > v1 - v1: NewUint64(2), - v2: NewUint64(4), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 2 - 4"), - }, { - // testing uint - (- int) - v1: NewUint64(1), - v2: NewInt64(-2), - out: NewUint64(3), - }} - - for _, tcase := range tcases { - - got, err := Subtract(tcase.v1, tcase.v2) - - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Subtract(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Subtract(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } - -} - -func TestAdd(t *testing.T) { - tcases := []struct { - v1, v2 Value - out Value - err error - }{{ - // All Nulls - v1: NULL, - v2: NULL, - out: NULL, - }, { - // First value null. - v1: NewInt32(1), - v2: NULL, - out: NULL, - }, { - // Second value null. - v1: NULL, - v2: NewInt32(1), - out: NULL, - }, { - // case with negatives - v1: NewInt64(-1), - v2: NewInt64(-2), - out: NewInt64(-3), - }, { - // testing for overflow int64, result will be unsigned int - v1: NewInt64(math.MaxInt64), - v2: NewUint64(2), - out: NewUint64(9223372036854775809), - }, { - v1: NewInt64(-2), - v2: NewUint64(1), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 1 + -2"), - }, { - v1: NewInt64(math.MaxInt64), - v2: NewInt64(-2), - out: NewInt64(9223372036854775805), - }, { - // Normal case - v1: NewUint64(1), - v2: NewUint64(2), - out: NewUint64(3), - }, { - // testing for overflow uint64 - v1: NewUint64(math.MaxUint64), - v2: NewUint64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), - }, { - // int64 underflow - v1: NewInt64(math.MinInt64), - v2: NewInt64(-2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 + -2"), - }, { - // checking int64 max value can be returned - v1: NewInt64(math.MaxInt64), - v2: NewUint64(0), - out: NewUint64(9223372036854775807), - }, { - // testing whether uint64 max value can be returned - v1: NewUint64(math.MaxUint64), - v2: NewInt64(0), - out: NewUint64(math.MaxUint64), - }, { - v1: NewUint64(math.MaxInt64), - v2: NewInt64(1), - out: NewUint64(9223372036854775808), - }, { - v1: NewUint64(1), - v2: TestValue(VarChar, "c"), - out: NewUint64(1), - }, { - v1: NewUint64(1), - v2: TestValue(VarChar, "1.2"), - out: NewFloat64(2.2), - }, { - v1: TestValue(Int64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - v1: NewInt64(2), - v2: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for uint64 overflow with max uint64 + int value - v1: NewUint64(math.MaxUint64), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), - }} - - for _, tcase := range tcases { - - got, err := Add(tcase.v1, tcase.v2) - - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Add(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Add(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } - -} - -func TestNullsafeAdd(t *testing.T) { - tcases := []struct { - v1, v2 Value - out Value - err error - }{{ - // All nulls. - v1: NULL, - v2: NULL, - out: NewInt64(0), - }, { - // First value null. - v1: NewInt32(1), - v2: NULL, - out: NewInt64(1), - }, { - // Second value null. - v1: NULL, - v2: NewInt32(1), - out: NewInt64(1), - }, { - // Normal case. - v1: NewInt64(1), - v2: NewInt64(2), - out: NewInt64(3), - }, { - // Make sure underlying error is returned for LHS. - v1: TestValue(Int64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // Make sure underlying error is returned for RHS. - v1: NewInt64(2), - v2: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // Make sure underlying error is returned while adding. - v1: NewInt64(-1), - v2: NewUint64(2), - out: NewInt64(-9223372036854775808), - }, { - // Make sure underlying error is returned while converting. - v1: NewFloat64(1), - v2: NewFloat64(2), - out: NewInt64(3), - }} - for _, tcase := range tcases { - got := NullsafeAdd(tcase.v1, tcase.v2, Int64) - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("NullsafeAdd(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } -} - -func TestNullsafeCompare(t *testing.T) { - tcases := []struct { - v1, v2 Value - out int - err error - }{{ - // All nulls. - v1: NULL, - v2: NULL, - out: 0, - }, { - // LHS null. - v1: NULL, - v2: NewInt64(1), - out: -1, - }, { - // RHS null. - v1: NewInt64(1), - v2: NULL, - out: 1, - }, { - // LHS Text - v1: TestValue(VarChar, "abcd"), - v2: TestValue(VarChar, "abcd"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), - }, { - // Make sure underlying error is returned for LHS. - v1: TestValue(Int64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // Make sure underlying error is returned for RHS. - v1: NewInt64(2), - v2: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // Numeric equal. - v1: NewInt64(1), - v2: NewUint64(1), - out: 0, - }, { - // Numeric unequal. - v1: NewInt64(1), - v2: NewUint64(2), - out: -1, - }, { - // Non-numeric equal - v1: TestValue(VarBinary, "abcd"), - v2: TestValue(Binary, "abcd"), - out: 0, - }, { - // Non-numeric unequal - v1: TestValue(VarBinary, "abcd"), - v2: TestValue(Binary, "bcde"), - out: -1, - }, { - // Date/Time types - v1: TestValue(Datetime, "1000-01-01 00:00:00"), - v2: TestValue(Binary, "1000-01-01 00:00:00"), - out: 0, - }, { - // Date/Time types - v1: TestValue(Datetime, "2000-01-01 00:00:00"), - v2: TestValue(Binary, "1000-01-01 00:00:00"), - out: 1, - }, { - // Date/Time types - v1: TestValue(Datetime, "1000-01-01 00:00:00"), - v2: TestValue(Binary, "2000-01-01 00:00:00"), - out: -1, - }} - for _, tcase := range tcases { - got, err := NullsafeCompare(tcase.v1, tcase.v2) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if got != tcase.out { - t.Errorf("NullsafeCompare(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), got, tcase.out) - } - } -} - -func TestCast(t *testing.T) { - tcases := []struct { - typ querypb.Type - v Value - out Value - err error - }{{ - typ: VarChar, - v: NULL, - out: NULL, - }, { - typ: VarChar, - v: TestValue(VarChar, "exact types"), - out: TestValue(VarChar, "exact types"), - }, { - typ: Int64, - v: TestValue(Int32, "32"), - out: TestValue(Int64, "32"), - }, { - typ: Int24, - v: TestValue(Uint64, "64"), - out: TestValue(Int24, "64"), - }, { - typ: Int24, - v: TestValue(VarChar, "bad int"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseInt: parsing "bad int": invalid syntax`), - }, { - typ: Uint64, - v: TestValue(Uint32, "32"), - out: TestValue(Uint64, "32"), - }, { - typ: Uint24, - v: TestValue(Int64, "64"), - out: TestValue(Uint24, "64"), - }, { - typ: Uint24, - v: TestValue(Int64, "-1"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseUint: parsing "-1": invalid syntax`), - }, { - typ: Float64, - v: TestValue(Int64, "64"), - out: TestValue(Float64, "64"), - }, { - typ: Float32, - v: TestValue(Float64, "64"), - out: TestValue(Float32, "64"), - }, { - typ: Float32, - v: TestValue(Decimal, "1.24"), - out: TestValue(Float32, "1.24"), - }, { - typ: Float64, - v: TestValue(VarChar, "1.25"), - out: TestValue(Float64, "1.25"), - }, { - typ: Float64, - v: TestValue(VarChar, "bad float"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseFloat: parsing "bad float": invalid syntax`), - }, { - typ: VarChar, - v: TestValue(Int64, "64"), - out: TestValue(VarChar, "64"), - }, { - typ: VarBinary, - v: TestValue(Float64, "64"), - out: TestValue(VarBinary, "64"), - }, { - typ: VarBinary, - v: TestValue(Decimal, "1.24"), - out: TestValue(VarBinary, "1.24"), - }, { - typ: VarBinary, - v: TestValue(VarChar, "1.25"), - out: TestValue(VarBinary, "1.25"), - }, { - typ: VarChar, - v: TestValue(VarBinary, "valid string"), - out: TestValue(VarChar, "valid string"), - }, { - typ: VarChar, - v: TestValue(Expression, "bad string"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "EXPRESSION(bad string) cannot be cast to VARCHAR"), - }} - for _, tcase := range tcases { - got, err := Cast(tcase.v, tcase.typ) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Cast(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Cast(%v): %v, want %v", tcase.v, got, tcase.out) - } - } -} - -func TestToUint64(t *testing.T) { - tcases := []struct { - v Value - out uint64 - err error - }{{ - v: TestValue(VarChar, "abcd"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), - }, { - v: NewInt64(-1), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "negative number cannot be converted to unsigned: -1"), - }, { - v: NewInt64(1), - out: 1, - }, { - v: NewUint64(1), - out: 1, - }} - for _, tcase := range tcases { - got, err := ToUint64(tcase.v) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("ToUint64(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if got != tcase.out { - t.Errorf("ToUint64(%v): %v, want %v", tcase.v, got, tcase.out) - } - } -} - -func TestToInt64(t *testing.T) { - tcases := []struct { - v Value - out int64 - err error - }{{ - v: TestValue(VarChar, "abcd"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), - }, { - v: NewUint64(18446744073709551615), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unsigned number overflows int64 value: 18446744073709551615"), - }, { - v: NewInt64(1), - out: 1, - }, { - v: NewUint64(1), - out: 1, - }} - for _, tcase := range tcases { - got, err := ToInt64(tcase.v) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("ToInt64(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if got != tcase.out { - t.Errorf("ToInt64(%v): %v, want %v", tcase.v, got, tcase.out) - } - } -} - -func TestToFloat64(t *testing.T) { - tcases := []struct { - v Value - out float64 - err error - }{{ - v: TestValue(VarChar, "abcd"), - out: 0, - }, { - v: NewInt64(1), - out: 1, - }, { - v: NewUint64(1), - out: 1, - }, { - v: NewFloat64(1.2), - out: 1.2, - }, { - v: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }} - for _, tcase := range tcases { - got, err := ToFloat64(tcase.v) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("ToFloat64(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if got != tcase.out { - t.Errorf("ToFloat64(%v): %v, want %v", tcase.v, got, tcase.out) - } - } -} - -func TestToNative(t *testing.T) { - testcases := []struct { - in Value - out interface{} - }{{ - in: NULL, - out: nil, - }, { - in: TestValue(Int8, "1"), - out: int64(1), - }, { - in: TestValue(Int16, "1"), - out: int64(1), - }, { - in: TestValue(Int24, "1"), - out: int64(1), - }, { - in: TestValue(Int32, "1"), - out: int64(1), - }, { - in: TestValue(Int64, "1"), - out: int64(1), - }, { - in: TestValue(Uint8, "1"), - out: uint64(1), - }, { - in: TestValue(Uint16, "1"), - out: uint64(1), - }, { - in: TestValue(Uint24, "1"), - out: uint64(1), - }, { - in: TestValue(Uint32, "1"), - out: uint64(1), - }, { - in: TestValue(Uint64, "1"), - out: uint64(1), - }, { - in: TestValue(Float32, "1"), - out: float64(1), - }, { - in: TestValue(Float64, "1"), - out: float64(1), - }, { - in: TestValue(Timestamp, "2012-02-24 23:19:43"), - out: []byte("2012-02-24 23:19:43"), - }, { - in: TestValue(Date, "2012-02-24"), - out: []byte("2012-02-24"), - }, { - in: TestValue(Time, "23:19:43"), - out: []byte("23:19:43"), - }, { - in: TestValue(Datetime, "2012-02-24 23:19:43"), - out: []byte("2012-02-24 23:19:43"), - }, { - in: TestValue(Year, "1"), - out: uint64(1), - }, { - in: TestValue(Decimal, "1"), - out: []byte("1"), - }, { - in: TestValue(Text, "a"), - out: []byte("a"), - }, { - in: TestValue(Blob, "a"), - out: []byte("a"), - }, { - in: TestValue(VarChar, "a"), - out: []byte("a"), - }, { - in: TestValue(VarBinary, "a"), - out: []byte("a"), - }, { - in: TestValue(Char, "a"), - out: []byte("a"), - }, { - in: TestValue(Binary, "a"), - out: []byte("a"), - }, { - in: TestValue(Bit, "1"), - out: []byte("1"), - }, { - in: TestValue(Enum, "a"), - out: []byte("a"), - }, { - in: TestValue(Set, "a"), - out: []byte("a"), - }} - for _, tcase := range testcases { - v, err := ToNative(tcase.in) - if err != nil { - t.Error(err) - } - if !reflect.DeepEqual(v, tcase.out) { - t.Errorf("%v.ToNative = %#v, want %#v", tcase.in, v, tcase.out) - } - } - - // Test Expression failure. - _, err := ToNative(TestValue(Expression, "aa")) - want := vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "EXPRESSION(aa) cannot be converted to a go type") - if !vterrors.Equals(err, want) { - t.Errorf("ToNative(EXPRESSION): %v, want %v", vterrors.Print(err), vterrors.Print(want)) - } -} - -func TestNewNumeric(t *testing.T) { - tcases := []struct { - v Value - out evalResult - err error - }{{ - v: NewInt64(1), - out: evalResult{typ: Int64, ival: 1}, - }, { - v: NewUint64(1), - out: evalResult{typ: Uint64, uval: 1}, - }, { - v: NewFloat64(1), - out: evalResult{typ: Float64, fval: 1}, - }, { - // For non-number type, Int64 is the default. - v: TestValue(VarChar, "1"), - out: evalResult{typ: Int64, ival: 1}, - }, { - // If Int64 can't work, we use Float64. - v: TestValue(VarChar, "1.2"), - out: evalResult{typ: Float64, fval: 1.2}, - }, { - // Only valid Int64 allowed if type is Int64. - v: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // Only valid Uint64 allowed if type is Uint64. - v: TestValue(Uint64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), - }, { - // Only valid Float64 allowed if type is Float64. - v: TestValue(Float64, "abcd"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseFloat: parsing \"abcd\": invalid syntax"), - }, { - v: TestValue(VarChar, "abcd"), - out: evalResult{typ: Float64, fval: 0}, - }} - for _, tcase := range tcases { - got, err := newNumeric(tcase.v) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("newNumeric(%s) error: %v, want %v", printValue(tcase.v), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err == nil { - continue - } - - if got != tcase.out { - t.Errorf("newNumeric(%s): %v, want %v", printValue(tcase.v), got, tcase.out) - } - } -} - -func TestNewIntegralNumeric(t *testing.T) { - tcases := []struct { - v Value - out evalResult - err error - }{{ - v: NewInt64(1), - out: evalResult{typ: Int64, ival: 1}, - }, { - v: NewUint64(1), - out: evalResult{typ: Uint64, uval: 1}, - }, { - v: NewFloat64(1), - out: evalResult{typ: Int64, ival: 1}, - }, { - // For non-number type, Int64 is the default. - v: TestValue(VarChar, "1"), - out: evalResult{typ: Int64, ival: 1}, - }, { - // If Int64 can't work, we use Uint64. - v: TestValue(VarChar, "18446744073709551615"), - out: evalResult{typ: Uint64, uval: 18446744073709551615}, - }, { - // Only valid Int64 allowed if type is Int64. - v: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // Only valid Uint64 allowed if type is Uint64. - v: TestValue(Uint64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), - }, { - v: TestValue(VarChar, "abcd"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), - }} - for _, tcase := range tcases { - got, err := newIntegralNumeric(tcase.v) - if err != nil && !vterrors.Equals(err, tcase.err) { - t.Errorf("newIntegralNumeric(%s) error: %v, want %v", printValue(tcase.v), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err == nil { - continue - } - - if got != tcase.out { - t.Errorf("newIntegralNumeric(%s): %v, want %v", printValue(tcase.v), got, tcase.out) - } - } -} - -func TestAddNumeric(t *testing.T) { - tcases := []struct { - v1, v2 evalResult - out evalResult - err error - }{{ - v1: evalResult{typ: Int64, ival: 1}, - v2: evalResult{typ: Int64, ival: 2}, - out: evalResult{typ: Int64, ival: 3}, - }, { - v1: evalResult{typ: Int64, ival: 1}, - v2: evalResult{typ: Uint64, uval: 2}, - out: evalResult{typ: Uint64, uval: 3}, - }, { - v1: evalResult{typ: Int64, ival: 1}, - v2: evalResult{typ: Float64, fval: 2}, - out: evalResult{typ: Float64, fval: 3}, - }, { - v1: evalResult{typ: Uint64, uval: 1}, - v2: evalResult{typ: Uint64, uval: 2}, - out: evalResult{typ: Uint64, uval: 3}, - }, { - v1: evalResult{typ: Uint64, uval: 1}, - v2: evalResult{typ: Float64, fval: 2}, - out: evalResult{typ: Float64, fval: 3}, - }, { - v1: evalResult{typ: Float64, fval: 1}, - v2: evalResult{typ: Float64, fval: 2}, - out: evalResult{typ: Float64, fval: 3}, - }, { - // Int64 overflow. - v1: evalResult{typ: Int64, ival: 9223372036854775807}, - v2: evalResult{typ: Int64, ival: 2}, - out: evalResult{typ: Float64, fval: 9223372036854775809}, - }, { - // Int64 underflow. - v1: evalResult{typ: Int64, ival: -9223372036854775807}, - v2: evalResult{typ: Int64, ival: -2}, - out: evalResult{typ: Float64, fval: -9223372036854775809}, - }, { - v1: evalResult{typ: Int64, ival: -1}, - v2: evalResult{typ: Uint64, uval: 2}, - out: evalResult{typ: Float64, fval: 18446744073709551617}, - }, { - // Uint64 overflow. - v1: evalResult{typ: Uint64, uval: 18446744073709551615}, - v2: evalResult{typ: Uint64, uval: 2}, - out: evalResult{typ: Float64, fval: 18446744073709551617}, - }} - for _, tcase := range tcases { - got := addNumeric(tcase.v1, tcase.v2) - - if got != tcase.out { - t.Errorf("addNumeric(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) - } - } -} - -func TestPrioritize(t *testing.T) { - ival := evalResult{typ: Int64} - uval := evalResult{typ: Uint64} - fval := evalResult{typ: Float64} - - tcases := []struct { - v1, v2 evalResult - out1, out2 evalResult - }{{ - v1: ival, - v2: uval, - out1: uval, - out2: ival, - }, { - v1: ival, - v2: fval, - out1: fval, - out2: ival, - }, { - v1: uval, - v2: ival, - out1: uval, - out2: ival, - }, { - v1: uval, - v2: fval, - out1: fval, - out2: uval, - }, { - v1: fval, - v2: ival, - out1: fval, - out2: ival, - }, { - v1: fval, - v2: uval, - out1: fval, - out2: uval, - }} - for _, tcase := range tcases { - got1, got2 := prioritize(tcase.v1, tcase.v2) - if got1 != tcase.out1 || got2 != tcase.out2 { - t.Errorf("prioritize(%v, %v): (%v, %v) , want (%v, %v)", tcase.v1.typ, tcase.v2.typ, got1.typ, got2.typ, tcase.out1.typ, tcase.out2.typ) - } - } -} - -func TestCastFromNumeric(t *testing.T) { - tcases := []struct { - typ querypb.Type - v evalResult - out Value - err error - }{{ - typ: Int64, - v: evalResult{typ: Int64, ival: 1}, - out: NewInt64(1), - }, { - typ: Int64, - v: evalResult{typ: Uint64, uval: 1}, - out: NewInt64(1), - }, { - typ: Int64, - v: evalResult{typ: Float64, fval: 1.2e-16}, - out: NewInt64(0), - }, { - typ: Uint64, - v: evalResult{typ: Int64, ival: 1}, - out: NewUint64(1), - }, { - typ: Uint64, - v: evalResult{typ: Uint64, uval: 1}, - out: NewUint64(1), - }, { - typ: Uint64, - v: evalResult{typ: Float64, fval: 1.2e-16}, - out: NewUint64(0), - }, { - typ: Float64, - v: evalResult{typ: Int64, ival: 1}, - out: TestValue(Float64, "1"), - }, { - typ: Float64, - v: evalResult{typ: Uint64, uval: 1}, - out: TestValue(Float64, "1"), - }, { - typ: Float64, - v: evalResult{typ: Float64, fval: 1.2e-16}, - out: TestValue(Float64, "1.2e-16"), - }, { - typ: Decimal, - v: evalResult{typ: Int64, ival: 1}, - out: TestValue(Decimal, "1"), - }, { - typ: Decimal, - v: evalResult{typ: Uint64, uval: 1}, - out: TestValue(Decimal, "1"), - }, { - // For float, we should not use scientific notation. - typ: Decimal, - v: evalResult{typ: Float64, fval: 1.2e-16}, - out: TestValue(Decimal, "0.00000000000000012"), - }} - for _, tcase := range tcases { - got := castFromNumeric(tcase.v, tcase.typ) - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("castFromNumeric(%v, %v): %v, want %v", tcase.v, tcase.typ, printValue(got), printValue(tcase.out)) - } - } -} - -func TestCompareNumeric(t *testing.T) { - tcases := []struct { - v1, v2 evalResult - out int - }{{ - v1: evalResult{typ: Int64, ival: 1}, - v2: evalResult{typ: Int64, ival: 1}, - out: 0, - }, { - v1: evalResult{typ: Int64, ival: 1}, - v2: evalResult{typ: Int64, ival: 2}, - out: -1, - }, { - v1: evalResult{typ: Int64, ival: 2}, - v2: evalResult{typ: Int64, ival: 1}, - out: 1, - }, { - // Special case. - v1: evalResult{typ: Int64, ival: -1}, - v2: evalResult{typ: Uint64, uval: 1}, - out: -1, - }, { - v1: evalResult{typ: Int64, ival: 1}, - v2: evalResult{typ: Uint64, uval: 1}, - out: 0, - }, { - v1: evalResult{typ: Int64, ival: 1}, - v2: evalResult{typ: Uint64, uval: 2}, - out: -1, - }, { - v1: evalResult{typ: Int64, ival: 2}, - v2: evalResult{typ: Uint64, uval: 1}, - out: 1, - }, { - v1: evalResult{typ: Int64, ival: 1}, - v2: evalResult{typ: Float64, fval: 1}, - out: 0, - }, { - v1: evalResult{typ: Int64, ival: 1}, - v2: evalResult{typ: Float64, fval: 2}, - out: -1, - }, { - v1: evalResult{typ: Int64, ival: 2}, - v2: evalResult{typ: Float64, fval: 1}, - out: 1, - }, { - // Special case. - v1: evalResult{typ: Uint64, uval: 1}, - v2: evalResult{typ: Int64, ival: -1}, - out: 1, - }, { - v1: evalResult{typ: Uint64, uval: 1}, - v2: evalResult{typ: Int64, ival: 1}, - out: 0, - }, { - v1: evalResult{typ: Uint64, uval: 1}, - v2: evalResult{typ: Int64, ival: 2}, - out: -1, - }, { - v1: evalResult{typ: Uint64, uval: 2}, - v2: evalResult{typ: Int64, ival: 1}, - out: 1, - }, { - v1: evalResult{typ: Uint64, uval: 1}, - v2: evalResult{typ: Uint64, uval: 1}, - out: 0, - }, { - v1: evalResult{typ: Uint64, uval: 1}, - v2: evalResult{typ: Uint64, uval: 2}, - out: -1, - }, { - v1: evalResult{typ: Uint64, uval: 2}, - v2: evalResult{typ: Uint64, uval: 1}, - out: 1, - }, { - v1: evalResult{typ: Uint64, uval: 1}, - v2: evalResult{typ: Float64, fval: 1}, - out: 0, - }, { - v1: evalResult{typ: Uint64, uval: 1}, - v2: evalResult{typ: Float64, fval: 2}, - out: -1, - }, { - v1: evalResult{typ: Uint64, uval: 2}, - v2: evalResult{typ: Float64, fval: 1}, - out: 1, - }, { - v1: evalResult{typ: Float64, fval: 1}, - v2: evalResult{typ: Int64, ival: 1}, - out: 0, - }, { - v1: evalResult{typ: Float64, fval: 1}, - v2: evalResult{typ: Int64, ival: 2}, - out: -1, - }, { - v1: evalResult{typ: Float64, fval: 2}, - v2: evalResult{typ: Int64, ival: 1}, - out: 1, - }, { - v1: evalResult{typ: Float64, fval: 1}, - v2: evalResult{typ: Uint64, uval: 1}, - out: 0, - }, { - v1: evalResult{typ: Float64, fval: 1}, - v2: evalResult{typ: Uint64, uval: 2}, - out: -1, - }, { - v1: evalResult{typ: Float64, fval: 2}, - v2: evalResult{typ: Uint64, uval: 1}, - out: 1, - }, { - v1: evalResult{typ: Float64, fval: 1}, - v2: evalResult{typ: Float64, fval: 1}, - out: 0, - }, { - v1: evalResult{typ: Float64, fval: 1}, - v2: evalResult{typ: Float64, fval: 2}, - out: -1, - }, { - v1: evalResult{typ: Float64, fval: 2}, - v2: evalResult{typ: Float64, fval: 1}, - out: 1, - }} - for _, tcase := range tcases { - got := compareNumeric(tcase.v1, tcase.v2) - if got != tcase.out { - t.Errorf("equalNumeric(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) - } - } -} - -func TestMin(t *testing.T) { - tcases := []struct { - v1, v2 Value - min Value - err error - }{{ - v1: NULL, - v2: NULL, - min: NULL, - }, { - v1: NewInt64(1), - v2: NULL, - min: NewInt64(1), - }, { - v1: NULL, - v2: NewInt64(1), - min: NewInt64(1), - }, { - v1: NewInt64(1), - v2: NewInt64(2), - min: NewInt64(1), - }, { - v1: NewInt64(2), - v2: NewInt64(1), - min: NewInt64(1), - }, { - v1: NewInt64(1), - v2: NewInt64(1), - min: NewInt64(1), - }, { - v1: TestValue(VarChar, "aa"), - v2: TestValue(VarChar, "aa"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), - }} - for _, tcase := range tcases { - v, err := Min(tcase.v1, tcase.v2) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Min error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(v, tcase.min) { - t.Errorf("Min(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.min) - } - } -} - -func TestMax(t *testing.T) { - tcases := []struct { - v1, v2 Value - max Value - err error - }{{ - v1: NULL, - v2: NULL, - max: NULL, - }, { - v1: NewInt64(1), - v2: NULL, - max: NewInt64(1), - }, { - v1: NULL, - v2: NewInt64(1), - max: NewInt64(1), - }, { - v1: NewInt64(1), - v2: NewInt64(2), - max: NewInt64(2), - }, { - v1: NewInt64(2), - v2: NewInt64(1), - max: NewInt64(2), - }, { - v1: NewInt64(1), - v2: NewInt64(1), - max: NewInt64(1), - }, { - v1: TestValue(VarChar, "aa"), - v2: TestValue(VarChar, "aa"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), - }} - for _, tcase := range tcases { - v, err := Max(tcase.v1, tcase.v2) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Max error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(v, tcase.max) { - t.Errorf("Max(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.max) - } - } -} - -func printValue(v Value) string { - return fmt.Sprintf("%v:%q", v.typ, v.val) -} - -// These benchmarks show that using existing ASCII representations -// for numbers is about 6x slower than using native representations. -// However, 229ns is still a negligible time compared to the cost of -// other operations. The additional complexity of introducing native -// types is currently not worth it. So, we'll stay with the existing -// ASCII representation for now. Using interfaces is more expensive -// than native representation of values. This is probably because -// interfaces also allocate memory, and also perform type assertions. -// Actual benchmark is based on NoNative. So, the numbers are similar. -// Date: 6/4/17 -// Version: go1.8 -// BenchmarkAddActual-8 10000000 263 ns/op -// BenchmarkAddNoNative-8 10000000 228 ns/op -// BenchmarkAddNative-8 50000000 40.0 ns/op -// BenchmarkAddGoInterface-8 30000000 52.4 ns/op -// BenchmarkAddGoNonInterface-8 2000000000 1.00 ns/op -// BenchmarkAddGo-8 2000000000 1.00 ns/op -func BenchmarkAddActual(b *testing.B) { - v1 := MakeTrusted(Int64, []byte("1")) - v2 := MakeTrusted(Int64, []byte("12")) - for i := 0; i < b.N; i++ { - v1 = NullsafeAdd(v1, v2, Int64) - } -} - -func BenchmarkAddNoNative(b *testing.B) { - v1 := MakeTrusted(Int64, []byte("1")) - v2 := MakeTrusted(Int64, []byte("12")) - for i := 0; i < b.N; i++ { - iv1, _ := ToInt64(v1) - iv2, _ := ToInt64(v2) - v1 = MakeTrusted(Int64, strconv.AppendInt(nil, iv1+iv2, 10)) - } -} - -func BenchmarkAddNative(b *testing.B) { - v1 := makeNativeInt64(1) - v2 := makeNativeInt64(12) - for i := 0; i < b.N; i++ { - iv1 := int64(binary.BigEndian.Uint64(v1.Raw())) - iv2 := int64(binary.BigEndian.Uint64(v2.Raw())) - v1 = makeNativeInt64(iv1 + iv2) - } -} - -func makeNativeInt64(v int64) Value { - buf := make([]byte, 8) - binary.BigEndian.PutUint64(buf, uint64(v)) - return MakeTrusted(Int64, buf) -} - -func BenchmarkAddGoInterface(b *testing.B) { - var v1, v2 interface{} - v1 = int64(1) - v2 = int64(2) - for i := 0; i < b.N; i++ { - v1 = v1.(int64) + v2.(int64) - } -} - -func BenchmarkAddGoNonInterface(b *testing.B) { - v1 := evalResult{typ: Int64, ival: 1} - v2 := evalResult{typ: Int64, ival: 12} - for i := 0; i < b.N; i++ { - if v1.typ != Int64 { - b.Error("type assertion failed") - } - if v2.typ != Int64 { - b.Error("type assertion failed") - } - v1 = evalResult{typ: Int64, ival: v1.ival + v2.ival} - } -} - -func BenchmarkAddGo(b *testing.B) { - v1 := int64(1) - v2 := int64(2) - for i := 0; i < b.N; i++ { - v1 += v2 - } -} diff --git a/go/sqltypes/expressions_test.go b/go/sqltypes/expressions_test.go deleted file mode 100644 index 6163ca32886..00000000000 --- a/go/sqltypes/expressions_test.go +++ /dev/null @@ -1,104 +0,0 @@ -/* -Copyright 2020 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package sqltypes - -import ( - "fmt" - "reflect" - "testing" - - "github.com/magiconair/properties/assert" - querypb "vitess.io/vitess/go/vt/proto/query" -) - -// more tests in go/sqlparser/expressions_test.go - -func TestBinaryOpTypes(t *testing.T) { - type testcase struct { - l, r, e querypb.Type - } - type ops struct { - op BinaryExpr - testcases []testcase - } - - tests := []ops{ - { - op: &Addition{}, - testcases: []testcase{ - {Int64, Int64, Int64}, - {Uint64, Int64, Uint64}, - {Float64, Int64, Float64}, - {Int64, Uint64, Int64}, - {Uint64, Uint64, Uint64}, - {Float64, Uint64, Float64}, - {Int64, Float64, Int64}, - {Uint64, Float64, Uint64}, - {Float64, Float64, Float64}, - }, - }, { - op: &Subtraction{}, - testcases: []testcase{ - {Int64, Int64, Int64}, - {Uint64, Int64, Uint64}, - {Float64, Int64, Float64}, - {Int64, Uint64, Int64}, - {Uint64, Uint64, Uint64}, - {Float64, Uint64, Float64}, - {Int64, Float64, Int64}, - {Uint64, Float64, Uint64}, - {Float64, Float64, Float64}, - }, - }, { - op: &Multiplication{}, - testcases: []testcase{ - {Int64, Int64, Int64}, - {Uint64, Int64, Uint64}, - {Float64, Int64, Float64}, - {Int64, Uint64, Int64}, - {Uint64, Uint64, Uint64}, - {Float64, Uint64, Float64}, - {Int64, Float64, Int64}, - {Uint64, Float64, Uint64}, - {Float64, Float64, Float64}, - }, - }, { - op: &Division{}, - testcases: []testcase{ - {Int64, Int64, Float64}, - {Uint64, Int64, Float64}, - {Float64, Int64, Float64}, - {Int64, Uint64, Float64}, - {Uint64, Uint64, Float64}, - {Float64, Uint64, Float64}, - {Int64, Float64, Float64}, - {Uint64, Float64, Float64}, - {Float64, Float64, Float64}, - }, - }, - } - - for _, op := range tests { - for _, tc := range op.testcases { - name := fmt.Sprintf("%s %s %s", tc.l.String(), reflect.TypeOf(op.op).String(), tc.r.String()) - t.Run(name, func(t *testing.T) { - result := op.op.Type(tc.l) - assert.Equal(t, tc.e, result) - }) - } - } -} diff --git a/go/sqltypes/type.go b/go/sqltypes/type.go index 79fc5e7b294..adc763cc23e 100644 --- a/go/sqltypes/type.go +++ b/go/sqltypes/type.go @@ -81,8 +81,8 @@ func IsBinary(t querypb.Type) bool { return int(t)&flagIsBinary == flagIsBinary } -// isNumber returns true if the type is any type of number. -func isNumber(t querypb.Type) bool { +// IsNumber returns true if the type is any type of number. +func IsNumber(t querypb.Type) bool { return IsIntegral(t) || IsFloat(t) || t == Decimal } diff --git a/go/sqltypes/type_test.go b/go/sqltypes/type_test.go index 660d3c87bcd..efaeb726121 100644 --- a/go/sqltypes/type_test.go +++ b/go/sqltypes/type_test.go @@ -247,7 +247,7 @@ func TestIsFunctions(t *testing.T) { if !IsBinary(Binary) { t.Error("Char: !IsBinary, must be true") } - if !isNumber(Int64) { + if !IsNumber(Int64) { t.Error("Int64: !isNumber, must be true") } } diff --git a/go/test/endtoend/messaging/message_test.go b/go/test/endtoend/messaging/message_test.go index 89574b6ae0d..8de8c808186 100644 --- a/go/test/endtoend/messaging/message_test.go +++ b/go/test/endtoend/messaging/message_test.go @@ -26,6 +26,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql" @@ -226,8 +228,8 @@ func getTimeEpoch(qr *sqltypes.Result) (int64, int64) { if len(qr.Rows) != 1 { return 0, 0 } - t, _ := sqltypes.ToInt64(qr.Rows[0][0]) - e, _ := sqltypes.ToInt64(qr.Rows[0][1]) + t, _ := evalengine.ToInt64(qr.Rows[0][0]) + e, _ := evalengine.ToInt64(qr.Rows[0][1]) return t, e } diff --git a/go/test/endtoend/sharding/verticalsplit/vertical_split_test.go b/go/test/endtoend/sharding/verticalsplit/vertical_split_test.go index e8624315b77..64e8e1f3705 100644 --- a/go/test/endtoend/sharding/verticalsplit/vertical_split_test.go +++ b/go/test/endtoend/sharding/verticalsplit/vertical_split_test.go @@ -27,6 +27,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/json2" "vitess.io/vitess/go/vt/vtgate/vtgateconn" @@ -546,7 +548,7 @@ func checkValues(t *testing.T, tablet *cluster.Vttablet, keyspace string, dbname assert.Equal(t, count, len(qr.Rows), fmt.Sprintf("got wrong number of rows: %d != %d", len(qr.Rows), count)) i := 0 for i < count { - result, _ := sqltypes.ToInt64(qr.Rows[i][0]) + result, _ := evalengine.ToInt64(qr.Rows[i][0]) assert.Equal(t, int64(first+i), result, fmt.Sprintf("got wrong number of rows: %d != %d", len(qr.Rows), first+i)) assert.Contains(t, qr.Rows[i][1].String(), fmt.Sprintf("value %d", first+i), fmt.Sprintf("invalid msg[%d]: 'value %d' != '%s'", i, first+i, qr.Rows[i][1].String())) i++ diff --git a/go/vt/binlog/binlogplayer/binlog_player.go b/go/vt/binlog/binlogplayer/binlog_player.go index d662129cd17..3e3b47d8c96 100644 --- a/go/vt/binlog/binlogplayer/binlog_player.go +++ b/go/vt/binlog/binlogplayer/binlog_player.go @@ -27,6 +27,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/golang/protobuf/proto" "golang.org/x/net/context" @@ -542,11 +544,11 @@ func ReadVRSettings(dbClient DBClient, uid uint32) (VRSettings, error) { } vrRow := qr.Rows[0] - maxTPS, err := sqltypes.ToInt64(vrRow[2]) + maxTPS, err := evalengine.ToInt64(vrRow[2]) if err != nil { return VRSettings{}, fmt.Errorf("failed to parse max_tps column: %v", err) } - maxReplicationLag, err := sqltypes.ToInt64(vrRow[3]) + maxReplicationLag, err := evalengine.ToInt64(vrRow[3]) if err != nil { return VRSettings{}, fmt.Errorf("failed to parse max_replication_lag column: %v", err) } diff --git a/go/vt/binlog/keyspace_id_resolver.go b/go/vt/binlog/keyspace_id_resolver.go index 1b45e7fed03..ab010eab5d9 100644 --- a/go/vt/binlog/keyspace_id_resolver.go +++ b/go/vt/binlog/keyspace_id_resolver.go @@ -21,6 +21,8 @@ import ( "fmt" "strings" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/sqltypes" @@ -103,7 +105,7 @@ func (r *keyspaceIDResolverFactoryV2) keyspaceID(v sqltypes.Value) ([]byte, erro case topodatapb.KeyspaceIdType_BYTES: return v.ToBytes(), nil case topodatapb.KeyspaceIdType_UINT64: - i, err := sqltypes.ToUint64(v) + i, err := evalengine.ToUint64(v) if err != nil { return nil, fmt.Errorf("non numerical value: %v", err) } diff --git a/go/vt/mysqlctl/replication.go b/go/vt/mysqlctl/replication.go index 4bcec1d8408..5b3e5d3627d 100644 --- a/go/vt/mysqlctl/replication.go +++ b/go/vt/mysqlctl/replication.go @@ -27,11 +27,12 @@ import ( "strings" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/netutil" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/hook" "vitess.io/vitess/go/vt/log" ) @@ -145,7 +146,7 @@ func (mysqld *Mysqld) GetMysqlPort() (int32, error) { if len(qr.Rows) != 1 { return 0, errors.New("no port variable in mysql") } - utemp, err := sqltypes.ToUint64(qr.Rows[0][1]) + utemp, err := evalengine.ToUint64(qr.Rows[0][1]) if err != nil { return 0, err } diff --git a/go/vt/mysqlctl/schema.go b/go/vt/mysqlctl/schema.go index 18706cb36ed..cde232c9a1c 100644 --- a/go/vt/mysqlctl/schema.go +++ b/go/vt/mysqlctl/schema.go @@ -21,10 +21,11 @@ import ( "regexp" "strings" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/sqlescape" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/mysqlctl/tmutils" @@ -84,7 +85,7 @@ func (mysqld *Mysqld) GetSchema(dbName string, tables, excludeTables []string, i var dataLength uint64 if !row[2].IsNull() { // dataLength is NULL for views, then we use 0 - dataLength, err = sqltypes.ToUint64(row[2]) + dataLength, err = evalengine.ToUint64(row[2]) if err != nil { return nil, err } @@ -93,7 +94,7 @@ func (mysqld *Mysqld) GetSchema(dbName string, tables, excludeTables []string, i // get row count var rowCount uint64 if !row[3].IsNull() { - rowCount, err = sqltypes.ToUint64(row[3]) + rowCount, err = evalengine.ToUint64(row[3]) if err != nil { return nil, err } @@ -214,7 +215,7 @@ func (mysqld *Mysqld) GetPrimaryKeyColumns(dbName, table string) ([]string, erro } // check the Seq_in_index is always increasing - seqInIndex, err := sqltypes.ToInt64(row[seqInIndexIndex]) + seqInIndex, err := evalengine.ToInt64(row[seqInIndexIndex]) if err != nil { return nil, err } diff --git a/go/vt/schemamanager/schemaswap/schema_swap.go b/go/vt/schemamanager/schemaswap/schema_swap.go index ded32dc9be3..14e7161d919 100644 --- a/go/vt/schemamanager/schemaswap/schema_swap.go +++ b/go/vt/schemamanager/schemaswap/schema_swap.go @@ -27,6 +27,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/sqltypes" @@ -614,14 +616,14 @@ func (shardSwap *shardSchemaSwap) readShardMetadata(metadata *shardSwapMetadata, for _, row := range queryResult.Rows { switch row[0].ToString() { case lastStartedMetadataName: - swapID, err := sqltypes.ToUint64(row[1]) + swapID, err := evalengine.ToUint64(row[1]) if err != nil { log.Warningf("Could not parse value of last started schema swap id %v, ignoring the value: %v", row[1], err) } else { metadata.lastStartedSwap = swapID } case lastFinishedMetadataName: - swapID, err := sqltypes.ToUint64(row[1]) + swapID, err := evalengine.ToUint64(row[1]) if err != nil { log.Warningf("Could not parse value of last finished schema swap id %v, ignoring the value: %v", row[1], err) } else { @@ -909,7 +911,7 @@ func (shardSwap *shardSchemaSwap) isSwapApplied(tablet *topodatapb.Tablet) (bool // No such row means we need to apply the swap. return false, nil } - swapID, err := sqltypes.ToUint64(swapIDResult.Rows[0][0]) + swapID, err := evalengine.ToUint64(swapIDResult.Rows[0][0]) if err != nil { return false, err } diff --git a/go/vt/sqlparser/expression_converter.go b/go/vt/sqlparser/expression_converter.go index 10c200fbb1a..ce7763e470c 100644 --- a/go/vt/sqlparser/expression_converter.go +++ b/go/vt/sqlparser/expression_converter.go @@ -19,34 +19,34 @@ package sqlparser import ( "fmt" - "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) var ExprNotSupported = fmt.Errorf("Expr Not Supported") //Convert converts between AST expressions and executable expressions -func Convert(e Expr) (sqltypes.Expr, error) { +func Convert(e Expr) (evalengine.Expr, error) { switch node := e.(type) { case *SQLVal: switch node.Type { case IntVal: - return sqltypes.NewLiteralInt(node.Val) + return evalengine.NewLiteralInt(node.Val) case FloatVal: - return sqltypes.NewLiteralFloat(node.Val) + return evalengine.NewLiteralFloat(node.Val) case ValArg: - return &sqltypes.BindVariable{Key: string(node.Val[1:])}, nil + return &evalengine.BindVariable{Key: string(node.Val[1:])}, nil } case *BinaryExpr: - var op sqltypes.BinaryExpr + var op evalengine.BinaryExpr switch node.Operator { case PlusStr: - op = &sqltypes.Addition{} + op = &evalengine.Addition{} case MinusStr: - op = &sqltypes.Subtraction{} + op = &evalengine.Subtraction{} case MultStr: - op = &sqltypes.Multiplication{} + op = &evalengine.Multiplication{} case DivStr: - op = &sqltypes.Division{} + op = &evalengine.Division{} default: return nil, ExprNotSupported } @@ -58,7 +58,7 @@ func Convert(e Expr) (sqltypes.Expr, error) { if err != nil { return nil, err } - return &sqltypes.BinaryOp{ + return &evalengine.BinaryOp{ Expr: op, Left: left, Right: right, diff --git a/go/vt/sqlparser/expressions_test.go b/go/vt/sqlparser/expressions_test.go index ca0940c8545..f0264ada165 100644 --- a/go/vt/sqlparser/expressions_test.go +++ b/go/vt/sqlparser/expressions_test.go @@ -19,6 +19,8 @@ package sqlparser import ( "testing" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "github.com/stretchr/testify/assert" @@ -78,7 +80,7 @@ func TestEvaluate(t *testing.T) { sqltypesExpr, err := Convert(astExpr) require.Nil(t, err) require.NotNil(t, sqltypesExpr) - env := sqltypes.ExpressionEnv{ + env := evalengine.ExpressionEnv{ BindVars: map[string]*querypb.BindVariable{ "exp": sqltypes.Int64BindVariable(66), "string_bind_variable": sqltypes.StringBindVariable("bar"), diff --git a/go/vt/vitessdriver/convert.go b/go/vt/vitessdriver/convert.go index 16d1b61134d..703ab076b0d 100644 --- a/go/vt/vitessdriver/convert.go +++ b/go/vt/vitessdriver/convert.go @@ -21,6 +21,8 @@ import ( "fmt" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -36,7 +38,7 @@ func (cv *converter) ToNative(v sqltypes.Value) (interface{}, error) { case sqltypes.Date: return DateToNative(v, cv.location) } - return sqltypes.ToNative(v) + return evalengine.ToNative(v) } func (cv *converter) BuildBindVariable(v interface{}) (*querypb.BindVariable, error) { diff --git a/go/vt/vtgate/endtoend/last_insert_id_test.go b/go/vt/vtgate/endtoend/last_insert_id_test.go index b7ac4a70452..60694090637 100644 --- a/go/vt/vtgate/endtoend/last_insert_id_test.go +++ b/go/vt/vtgate/endtoend/last_insert_id_test.go @@ -21,10 +21,10 @@ import ( "fmt" "testing" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/mysql" ) @@ -36,7 +36,7 @@ func TestLastInsertId(t *testing.T) { // figure out the last inserted id before we run change anything qr := exec(t, conn, "select max(id) from t1_last_insert_id") - oldLastID, err := sqltypes.ToUint64(qr.Rows[0][0]) + oldLastID, err := evalengine.ToUint64(qr.Rows[0][0]) require.NoError(t, err) exec(t, conn, "insert into t1_last_insert_id(id1) values(42)") @@ -59,7 +59,7 @@ func TestLastInsertIdWithRollback(t *testing.T) { // figure out the last inserted id before we run our tests qr := exec(t, conn, "select max(id) from t1_last_insert_id") - oldLastID, err := sqltypes.ToUint64(qr.Rows[0][0]) + oldLastID, err := evalengine.ToUint64(qr.Rows[0][0]) require.NoError(t, err) // add row inside explicit transaction diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go index 6a9db408cae..7e9d7af07e5 100644 --- a/go/vt/vtgate/engine/insert.go +++ b/go/vt/vtgate/engine/insert.go @@ -23,6 +23,8 @@ import ( "strings" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/sqltypes" @@ -299,7 +301,7 @@ func (ins *Insert) processGenerate(vcursor VCursor, bindVars map[string]*querypb } // If no rows are returned, it's an internal error, and the code // must panic, which will be caught and reported. - insertID, err = sqltypes.ToInt64(qr.Rows[0][0]) + insertID, err = evalengine.ToInt64(qr.Rows[0][0]) if err != nil { return 0, err } diff --git a/go/vt/vtgate/engine/limit.go b/go/vt/vtgate/engine/limit.go index 68b7b4b327c..c43b6a4b28f 100644 --- a/go/vt/vtgate/engine/limit.go +++ b/go/vt/vtgate/engine/limit.go @@ -20,6 +20,8 @@ import ( "fmt" "io" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -157,7 +159,7 @@ func (l *Limit) fetchCount(bindVars map[string]*querypb.BindVariable) (int, erro if err != nil { return 0, err } - num, err := sqltypes.ToUint64(resolved) + num, err := evalengine.ToUint64(resolved) if err != nil { return 0, err } @@ -176,7 +178,7 @@ func (l *Limit) fetchOffset(bindVars map[string]*querypb.BindVariable) (int, err if err != nil { return 0, err } - num, err := sqltypes.ToUint64(resolved) + num, err := evalengine.ToUint64(resolved) if err != nil { return 0, err } diff --git a/go/vt/vtgate/engine/memory_sort.go b/go/vt/vtgate/engine/memory_sort.go index 4db9b6a1448..c243a31d607 100644 --- a/go/vt/vtgate/engine/memory_sort.go +++ b/go/vt/vtgate/engine/memory_sort.go @@ -24,6 +24,8 @@ import ( "sort" "strings" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -161,7 +163,7 @@ func (ms *MemorySort) fetchCount(bindVars map[string]*querypb.BindVariable) (int if resolved.IsNull() { return math.MaxInt64, nil } - num, err := sqltypes.ToUint64(resolved) + num, err := evalengine.ToUint64(resolved) if err != nil { return 0, err } @@ -230,7 +232,7 @@ func (sh *sortHeap) Less(i, j int) bool { if sh.err != nil { return true } - cmp, err := sqltypes.NullsafeCompare(sh.rows[i][order.Col], sh.rows[j][order.Col]) + cmp, err := evalengine.NullsafeCompare(sh.rows[i][order.Col], sh.rows[j][order.Col]) if err != nil { sh.err = err return true diff --git a/go/vt/vtgate/engine/merge_sort.go b/go/vt/vtgate/engine/merge_sort.go index 17ae57c1c69..aea1cc90f11 100644 --- a/go/vt/vtgate/engine/merge_sort.go +++ b/go/vt/vtgate/engine/merge_sort.go @@ -20,6 +20,8 @@ import ( "container/heap" "io" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/sqltypes" @@ -250,7 +252,7 @@ func (sh *scatterHeap) Less(i, j int) bool { if sh.err != nil { return true } - cmp, err := sqltypes.NullsafeCompare(sh.rows[i].row[order.Col], sh.rows[j].row[order.Col]) + cmp, err := evalengine.NullsafeCompare(sh.rows[i].row[order.Col], sh.rows[j].row[order.Col]) if err != nil { sh.err = err return true diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index 74b1667e734..bee3cc52de5 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -20,6 +20,8 @@ import ( "fmt" "strconv" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -298,7 +300,7 @@ func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes. case AggregateSumDistinct: curDistinct = row[aggr.Col] var err error - newRow[aggr.Col], err = sqltypes.Cast(row[aggr.Col], opcodeType[aggr.Opcode]) + newRow[aggr.Col], err = evalengine.Cast(row[aggr.Col], opcodeType[aggr.Opcode]) if err != nil { newRow[aggr.Col] = sumZero } @@ -328,7 +330,7 @@ func (oa *OrderedAggregate) NeedsTransaction() bool { func (oa *OrderedAggregate) keysEqual(row1, row2 []sqltypes.Value) (bool, error) { for _, key := range oa.Keys { - cmp, err := sqltypes.NullsafeCompare(row1[key], row2[key]) + cmp, err := evalengine.NullsafeCompare(row1[key], row2[key]) if err != nil { return false, err } @@ -346,7 +348,7 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes if row2[aggr.Col].IsNull() { continue } - cmp, err := sqltypes.NullsafeCompare(curDistinct, row2[aggr.Col]) + cmp, err := evalengine.NullsafeCompare(curDistinct, row2[aggr.Col]) if err != nil { return nil, sqltypes.NULL, err } @@ -358,15 +360,15 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes var err error switch aggr.Opcode { case AggregateCount, AggregateSum: - result[aggr.Col] = sqltypes.NullsafeAdd(row1[aggr.Col], row2[aggr.Col], fields[aggr.Col].Type) + result[aggr.Col] = evalengine.NullsafeAdd(row1[aggr.Col], row2[aggr.Col], fields[aggr.Col].Type) case AggregateMin: - result[aggr.Col], err = sqltypes.Min(row1[aggr.Col], row2[aggr.Col]) + result[aggr.Col], err = evalengine.Min(row1[aggr.Col], row2[aggr.Col]) case AggregateMax: - result[aggr.Col], err = sqltypes.Max(row1[aggr.Col], row2[aggr.Col]) + result[aggr.Col], err = evalengine.Max(row1[aggr.Col], row2[aggr.Col]) case AggregateCountDistinct: - result[aggr.Col] = sqltypes.NullsafeAdd(row1[aggr.Col], countOne, opcodeType[aggr.Opcode]) + result[aggr.Col] = evalengine.NullsafeAdd(row1[aggr.Col], countOne, opcodeType[aggr.Opcode]) case AggregateSumDistinct: - result[aggr.Col] = sqltypes.NullsafeAdd(row1[aggr.Col], row2[aggr.Col], opcodeType[aggr.Opcode]) + result[aggr.Col] = evalengine.NullsafeAdd(row1[aggr.Col], row2[aggr.Col], opcodeType[aggr.Opcode]) default: return nil, sqltypes.NULL, fmt.Errorf("BUG: Unexpected opcode: %v", aggr.Opcode) } diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index 9257a9dc02d..0e0f9487adc 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -3,13 +3,14 @@ package engine import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) var _ Primitive = (*Projection)(nil) type Projection struct { Cols []string - Exprs []sqltypes.Expr + Exprs []evalengine.Expr Input Primitive noTxNeeded } @@ -32,7 +33,7 @@ func (p *Projection) Execute(vcursor VCursor, bindVars map[string]*querypb.BindV return nil, err } - env := sqltypes.ExpressionEnv{ + env := evalengine.ExpressionEnv{ BindVars: bindVars, } @@ -61,7 +62,7 @@ func (p *Projection) StreamExecute(vcursor VCursor, bindVars map[string]*querypb return err } - env := sqltypes.ExpressionEnv{ + env := evalengine.ExpressionEnv{ BindVars: bindVars, } @@ -94,7 +95,7 @@ func (p *Projection) GetFields(vcursor VCursor, bindVars map[string]*querypb.Bin } func (p *Projection) addFields(qr *sqltypes.Result, bindVars map[string]*querypb.BindVariable) { - env := sqltypes.ExpressionEnv{BindVars: bindVars} + env := evalengine.ExpressionEnv{BindVars: bindVars} for i, col := range p.Cols { qr.Fields = append(qr.Fields, &querypb.Field{ Name: col, diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index 9fff1e5afeb..c772f8f76dc 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -23,6 +23,8 @@ import ( "strconv" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/stats" @@ -439,7 +441,7 @@ func (route *Route) sort(in *sqltypes.Result) (*sqltypes.Result, error) { return true } var cmp int - cmp, err = sqltypes.NullsafeCompare(out.Rows[i][order.Col], out.Rows[j][order.Col]) + cmp, err = evalengine.NullsafeCompare(out.Rows[i][order.Col], out.Rows[j][order.Col]) if err != nil { return true } diff --git a/go/vt/vtgate/engine/vindex_func.go b/go/vt/vtgate/engine/vindex_func.go index 0e4e516c3b8..c30d0bcd80d 100644 --- a/go/vt/vtgate/engine/vindex_func.go +++ b/go/vt/vtgate/engine/vindex_func.go @@ -19,6 +19,8 @@ package engine import ( "encoding/json" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/vtgate/vindexes" @@ -109,7 +111,7 @@ func (vf *VindexFunc) mapVindex(vcursor VCursor, bindVars map[string]*querypb.Bi if err != nil { return nil, err } - vkey, err := sqltypes.Cast(k, sqltypes.VarBinary) + vkey, err := evalengine.Cast(k, sqltypes.VarBinary) if err != nil { return nil, err } diff --git a/go/sqltypes/arithmetic.go b/go/vt/vtgate/evalengine/arithmetic.go similarity index 66% rename from go/sqltypes/arithmetic.go rename to go/vt/vtgate/evalengine/arithmetic.go index 90f4d8f8415..97e9e5f321e 100644 --- a/go/sqltypes/arithmetic.go +++ b/go/vt/vtgate/evalengine/arithmetic.go @@ -14,11 +14,12 @@ See the License for the specific language governing permissions and limitations under the License. */ -package sqltypes +package evalengine import ( "bytes" "fmt" + "vitess.io/vitess/go/sqltypes" "strconv" @@ -33,101 +34,101 @@ var zeroBytes = []byte("0") // Add adds two values together // if v1 or v2 is null, then it returns null -func Add(v1, v2 Value) (Value, error) { +func Add(v1, v2 sqltypes.Value) (sqltypes.Value, error) { if v1.IsNull() || v2.IsNull() { - return NULL, nil + return sqltypes.NULL, nil } lv1, err := newNumeric(v1) if err != nil { - return NULL, err + return sqltypes.NULL, err } lv2, err := newNumeric(v2) if err != nil { - return NULL, err + return sqltypes.NULL, err } lresult, err := addNumericWithError(lv1, lv2) if err != nil { - return NULL, err + return sqltypes.NULL, err } return castFromNumeric(lresult, lresult.typ), nil } // Subtract takes two values and subtracts them -func Subtract(v1, v2 Value) (Value, error) { +func Subtract(v1, v2 sqltypes.Value) (sqltypes.Value, error) { if v1.IsNull() || v2.IsNull() { - return NULL, nil + return sqltypes.NULL, nil } lv1, err := newNumeric(v1) if err != nil { - return NULL, err + return sqltypes.NULL, err } lv2, err := newNumeric(v2) if err != nil { - return NULL, err + return sqltypes.NULL, err } lresult, err := subtractNumericWithError(lv1, lv2) if err != nil { - return NULL, err + return sqltypes.NULL, err } return castFromNumeric(lresult, lresult.typ), nil } // Multiply takes two values and multiplies it together -func Multiply(v1, v2 Value) (Value, error) { +func Multiply(v1, v2 sqltypes.Value) (sqltypes.Value, error) { if v1.IsNull() || v2.IsNull() { - return NULL, nil + return sqltypes.NULL, nil } lv1, err := newNumeric(v1) if err != nil { - return NULL, err + return sqltypes.NULL, err } lv2, err := newNumeric(v2) if err != nil { - return NULL, err + return sqltypes.NULL, err } lresult, err := multiplyNumericWithError(lv1, lv2) if err != nil { - return NULL, err + return sqltypes.NULL, err } return castFromNumeric(lresult, lresult.typ), nil } // Divide (Float) for MySQL. Replicates behavior of "/" operator -func Divide(v1, v2 Value) (Value, error) { +func Divide(v1, v2 sqltypes.Value) (sqltypes.Value, error) { if v1.IsNull() || v2.IsNull() { - return NULL, nil + return sqltypes.NULL, nil } lv2AsFloat, err := ToFloat64(v2) divisorIsZero := lv2AsFloat == 0 if divisorIsZero || err != nil { - return NULL, err + return sqltypes.NULL, err } lv1, err := newNumeric(v1) if err != nil { - return NULL, err + return sqltypes.NULL, err } lv2, err := newNumeric(v2) if err != nil { - return NULL, err + return sqltypes.NULL, err } lresult, err := divideNumericWithError(lv1, lv2) if err != nil { - return NULL, err + return sqltypes.NULL, err } return castFromNumeric(lresult, lresult.typ), nil @@ -144,21 +145,21 @@ func Divide(v1, v2 Value) (Value, error) { // addition, if one of the input types was Decimal, then // a Decimal is built. Otherwise, the final type of the // result is preserved. -func NullsafeAdd(v1, v2 Value, resultType querypb.Type) Value { +func NullsafeAdd(v1, v2 sqltypes.Value, resultType querypb.Type) sqltypes.Value { if v1.IsNull() { - v1 = MakeTrusted(resultType, zeroBytes) + v1 = sqltypes.MakeTrusted(resultType, zeroBytes) } if v2.IsNull() { - v2 = MakeTrusted(resultType, zeroBytes) + v2 = sqltypes.MakeTrusted(resultType, zeroBytes) } lv1, err := newNumeric(v1) if err != nil { - return NULL + return sqltypes.NULL } lv2, err := newNumeric(v2) if err != nil { - return NULL + return sqltypes.NULL } lresult := addNumeric(lv1, lv2) @@ -170,7 +171,7 @@ func NullsafeAdd(v1, v2 Value, resultType querypb.Type) Value { // numeric, then a numeric comparison is performed after // necessary conversions. If none are numeric, then it's // a simple binary comparison. Uncomparable values return an error. -func NullsafeCompare(v1, v2 Value) (int, error) { +func NullsafeCompare(v1, v2 sqltypes.Value) (int, error) { // Based on the categorization defined for the types, // we're going to allow comparison of the following: // Null, isNumber, IsBinary. This will exclude IsQuoted @@ -184,7 +185,7 @@ func NullsafeCompare(v1, v2 Value) (int, error) { if v2.IsNull() { return 1, nil } - if isNumber(v1.Type()) || isNumber(v2.Type()) { + if sqltypes.IsNumber(v1.Type()) || sqltypes.IsNumber(v2.Type()) { lv1, err := newNumeric(v1) if err != nil { return 0, err @@ -202,12 +203,12 @@ func NullsafeCompare(v1, v2 Value) (int, error) { } // isByteComparable returns true if the type is binary or date/time. -func isByteComparable(v Value) bool { +func isByteComparable(v sqltypes.Value) bool { if v.IsBinary() { return true } switch v.Type() { - case Timestamp, Date, Time, Datetime: + case sqltypes.Timestamp, sqltypes.Date, sqltypes.Time, sqltypes.Datetime: return true } return false @@ -216,18 +217,18 @@ func isByteComparable(v Value) bool { // Min returns the minimum of v1 and v2. If one of the // values is NULL, it returns the other value. If both // are NULL, it returns NULL. -func Min(v1, v2 Value) (Value, error) { +func Min(v1, v2 sqltypes.Value) (sqltypes.Value, error) { return minmax(v1, v2, true) } // Max returns the maximum of v1 and v2. If one of the // values is NULL, it returns the other value. If both // are NULL, it returns NULL. -func Max(v1, v2 Value) (Value, error) { +func Max(v1, v2 sqltypes.Value) (sqltypes.Value, error) { return minmax(v1, v2, false) } -func minmax(v1, v2 Value, min bool) (Value, error) { +func minmax(v1, v2 sqltypes.Value, min bool) (sqltypes.Value, error) { if v1.IsNull() { return v2, nil } @@ -237,7 +238,7 @@ func minmax(v1, v2 Value, min bool) (Value, error) { n, err := NullsafeCompare(v1, v2) if err != nil { - return NULL, err + return sqltypes.NULL, err } // XNOR construct. See tests. @@ -249,61 +250,61 @@ func minmax(v1, v2 Value, min bool) (Value, error) { } // Cast converts a Value to the target type. -func Cast(v Value, typ querypb.Type) (Value, error) { +func Cast(v sqltypes.Value, typ querypb.Type) (sqltypes.Value, error) { if v.Type() == typ || v.IsNull() { return v, nil } - if IsSigned(typ) && v.IsSigned() { - return MakeTrusted(typ, v.ToBytes()), nil + if sqltypes.IsSigned(typ) && v.IsSigned() { + return sqltypes.MakeTrusted(typ, v.ToBytes()), nil } - if IsUnsigned(typ) && v.IsUnsigned() { - return MakeTrusted(typ, v.ToBytes()), nil + if sqltypes.IsUnsigned(typ) && v.IsUnsigned() { + return sqltypes.MakeTrusted(typ, v.ToBytes()), nil } - if (IsFloat(typ) || typ == Decimal) && (v.IsIntegral() || v.IsFloat() || v.Type() == Decimal) { - return MakeTrusted(typ, v.ToBytes()), nil + if (sqltypes.IsFloat(typ) || typ == sqltypes.Decimal) && (v.IsIntegral() || v.IsFloat() || v.Type() == sqltypes.Decimal) { + return sqltypes.MakeTrusted(typ, v.ToBytes()), nil } - if IsQuoted(typ) && (v.IsIntegral() || v.IsFloat() || v.Type() == Decimal || v.IsQuoted()) { - return MakeTrusted(typ, v.ToBytes()), nil + if sqltypes.IsQuoted(typ) && (v.IsIntegral() || v.IsFloat() || v.Type() == sqltypes.Decimal || v.IsQuoted()) { + return sqltypes.MakeTrusted(typ, v.ToBytes()), nil } // Explicitly disallow Expression. - if v.Type() == Expression { - return NULL, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v cannot be cast to %v", v, typ) + if v.Type() == sqltypes.Expression { + return sqltypes.NULL, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v cannot be cast to %v", v, typ) } // If the above fast-paths were not possible, // go through full validation. - return NewValue(typ, v.ToBytes()) + return sqltypes.NewValue(typ, v.ToBytes()) } // ToUint64 converts Value to uint64. -func ToUint64(v Value) (uint64, error) { +func ToUint64(v sqltypes.Value) (uint64, error) { num, err := newIntegralNumeric(v) if err != nil { return 0, err } switch num.typ { - case Int64: + case sqltypes.Int64: if num.ival < 0 { return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "negative number cannot be converted to unsigned: %d", num.ival) } return uint64(num.ival), nil - case Uint64: + case sqltypes.Uint64: return num.uval, nil } panic("unreachable") } // ToInt64 converts Value to int64. -func ToInt64(v Value) (int64, error) { +func ToInt64(v sqltypes.Value) (int64, error) { num, err := newIntegralNumeric(v) if err != nil { return 0, err } switch num.typ { - case Int64: + case sqltypes.Int64: return num.ival, nil - case Uint64: + case sqltypes.Uint64: ival := int64(num.uval) if ival < 0 { return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsigned number overflows int64 value: %d", num.uval) @@ -314,17 +315,17 @@ func ToInt64(v Value) (int64, error) { } // ToFloat64 converts Value to float64. -func ToFloat64(v Value) (float64, error) { +func ToFloat64(v sqltypes.Value) (float64, error) { num, err := newNumeric(v) if err != nil { return 0, err } switch num.typ { - case Int64: + case sqltypes.Int64: return float64(num.ival), nil - case Uint64: + case sqltypes.Uint64: return float64(num.uval), nil - case Float64: + case sqltypes.Float64: return num.fval, nil } panic("unreachable") @@ -332,11 +333,11 @@ func ToFloat64(v Value) (float64, error) { // ToNative converts Value to a native go type. // Decimal is returned as []byte. -func ToNative(v Value) (interface{}, error) { +func ToNative(v sqltypes.Value) (interface{}, error) { var out interface{} var err error switch { - case v.Type() == Null: + case v.Type() == sqltypes.Null: // no-op case v.IsSigned(): return ToInt64(v) @@ -344,16 +345,16 @@ func ToNative(v Value) (interface{}, error) { return ToUint64(v) case v.IsFloat(): return ToFloat64(v) - case v.IsQuoted() || v.Type() == Bit || v.Type() == Decimal: - out = v.val - case v.Type() == Expression: + case v.IsQuoted() || v.Type() == sqltypes.Bit || v.Type() == sqltypes.Decimal: + out = v.ToBytes() + case v.Type() == sqltypes.Expression: err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v cannot be converted to a go type", v) } return out, err } // newNumeric parses a value and produces an Int64, Uint64 or Float64. -func newNumeric(v Value) (evalResult, error) { +func newNumeric(v sqltypes.Value) (evalResult, error) { str := v.ToString() switch { case v.IsSigned(): @@ -361,33 +362,33 @@ func newNumeric(v Value) (evalResult, error) { if err != nil { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return evalResult{ival: ival, typ: Int64}, nil + return evalResult{ival: ival, typ: sqltypes.Int64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(str, 10, 64) if err != nil { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return evalResult{uval: uval, typ: Uint64}, nil + return evalResult{uval: uval, typ: sqltypes.Uint64}, nil case v.IsFloat(): fval, err := strconv.ParseFloat(str, 64) if err != nil { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return evalResult{fval: fval, typ: Float64}, nil + return evalResult{fval: fval, typ: sqltypes.Float64}, nil } // For other types, do best effort. if ival, err := strconv.ParseInt(str, 10, 64); err == nil { - return evalResult{ival: ival, typ: Int64}, nil + return evalResult{ival: ival, typ: sqltypes.Int64}, nil } if fval, err := strconv.ParseFloat(str, 64); err == nil { - return evalResult{fval: fval, typ: Float64}, nil + return evalResult{fval: fval, typ: sqltypes.Float64}, nil } - return evalResult{ival: 0, typ: Int64}, nil + return evalResult{ival: 0, typ: sqltypes.Int64}, nil } // newIntegralNumeric parses a value and produces an Int64 or Uint64. -func newIntegralNumeric(v Value) (evalResult, error) { +func newIntegralNumeric(v sqltypes.Value) (evalResult, error) { str := v.ToString() switch { case v.IsSigned(): @@ -395,21 +396,21 @@ func newIntegralNumeric(v Value) (evalResult, error) { if err != nil { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return evalResult{ival: ival, typ: Int64}, nil + return evalResult{ival: ival, typ: sqltypes.Int64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(str, 10, 64) if err != nil { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return evalResult{uval: uval, typ: Uint64}, nil + return evalResult{uval: uval, typ: sqltypes.Uint64}, nil } // For other types, do best effort. if ival, err := strconv.ParseInt(str, 10, 64); err == nil { - return evalResult{ival: ival, typ: Int64}, nil + return evalResult{ival: ival, typ: sqltypes.Int64}, nil } if uval, err := strconv.ParseUint(str, 10, 64); err == nil { - return evalResult{uval: uval, typ: Uint64}, nil + return evalResult{uval: uval, typ: sqltypes.Uint64}, nil } return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str) } @@ -417,16 +418,16 @@ func newIntegralNumeric(v Value) (evalResult, error) { func addNumeric(v1, v2 evalResult) evalResult { v1, v2 = prioritize(v1, v2) switch v1.typ { - case Int64: + case sqltypes.Int64: return intPlusInt(v1.ival, v2.ival) - case Uint64: + case sqltypes.Uint64: switch v2.typ { - case Int64: + case sqltypes.Int64: return uintPlusInt(v1.uval, v2.ival) - case Uint64: + case sqltypes.Uint64: return uintPlusUint(v1.uval, v2.uval) } - case Float64: + case sqltypes.Float64: return floatPlusAny(v1.fval, v2) } panic("unreachable") @@ -435,16 +436,16 @@ func addNumeric(v1, v2 evalResult) evalResult { func addNumericWithError(v1, v2 evalResult) (evalResult, error) { v1, v2 = prioritize(v1, v2) switch v1.typ { - case Int64: + case sqltypes.Int64: return intPlusIntWithError(v1.ival, v2.ival) - case Uint64: + case sqltypes.Uint64: switch v2.typ { - case Int64: + case sqltypes.Int64: return uintPlusIntWithError(v1.uval, v2.ival) - case Uint64: + case sqltypes.Uint64: return uintPlusUintWithError(v1.uval, v2.uval) } - case Float64: + case sqltypes.Float64: return floatPlusAny(v1.fval, v2), nil } panic("unreachable") @@ -452,25 +453,25 @@ func addNumericWithError(v1, v2 evalResult) (evalResult, error) { func subtractNumericWithError(v1, v2 evalResult) (evalResult, error) { switch v1.typ { - case Int64: + case sqltypes.Int64: switch v2.typ { - case Int64: + case sqltypes.Int64: return intMinusIntWithError(v1.ival, v2.ival) - case Uint64: + case sqltypes.Uint64: return intMinusUintWithError(v1.ival, v2.uval) - case Float64: + case sqltypes.Float64: return anyMinusFloat(v1, v2.fval), nil } - case Uint64: + case sqltypes.Uint64: switch v2.typ { - case Int64: + case sqltypes.Int64: return uintMinusIntWithError(v1.uval, v2.ival) - case Uint64: + case sqltypes.Uint64: return uintMinusUintWithError(v1.uval, v2.uval) - case Float64: + case sqltypes.Float64: return anyMinusFloat(v1, v2.fval), nil } - case Float64: + case sqltypes.Float64: return floatMinusAny(v1.fval, v2), nil } panic("unreachable") @@ -479,16 +480,16 @@ func subtractNumericWithError(v1, v2 evalResult) (evalResult, error) { func multiplyNumericWithError(v1, v2 evalResult) (evalResult, error) { v1, v2 = prioritize(v1, v2) switch v1.typ { - case Int64: + case sqltypes.Int64: return intTimesIntWithError(v1.ival, v2.ival) - case Uint64: + case sqltypes.Uint64: switch v2.typ { - case Int64: + case sqltypes.Int64: return uintTimesIntWithError(v1.uval, v2.ival) - case Uint64: + case sqltypes.Uint64: return uintTimesUintWithError(v1.uval, v2.uval) } - case Float64: + case sqltypes.Float64: return floatTimesAny(v1.fval, v2), nil } panic("unreachable") @@ -496,13 +497,13 @@ func multiplyNumericWithError(v1, v2 evalResult) (evalResult, error) { func divideNumericWithError(v1, v2 evalResult) (evalResult, error) { switch v1.typ { - case Int64: + case sqltypes.Int64: return floatDivideAnyWithError(float64(v1.ival), v2) - case Uint64: + case sqltypes.Uint64: return floatDivideAnyWithError(float64(v1.uval), v2) - case Float64: + case sqltypes.Float64: return floatDivideAnyWithError(v1.fval, v2) } panic("unreachable") @@ -512,12 +513,12 @@ func divideNumericWithError(v1, v2 evalResult) (evalResult, error) { // to be Float64, Uint64, Int64. func prioritize(v1, v2 evalResult) (altv1, altv2 evalResult) { switch v1.typ { - case Int64: - if v2.typ == Uint64 || v2.typ == Float64 { + case sqltypes.Int64: + if v2.typ == sqltypes.Uint64 || v2.typ == sqltypes.Float64 { return v2, v1 } - case Uint64: - if v2.typ == Float64 { + case sqltypes.Uint64: + if v2.typ == sqltypes.Float64 { return v2, v1 } } @@ -532,10 +533,10 @@ func intPlusInt(v1, v2 int64) evalResult { if v1 < 0 && v2 < 0 && result > 0 { goto overflow } - return evalResult{typ: Int64, ival: result} + return evalResult{typ: sqltypes.Int64, ival: result} overflow: - return evalResult{typ: Float64, fval: float64(v1) + float64(v2)} + return evalResult{typ: sqltypes.Float64, fval: float64(v1) + float64(v2)} } func intPlusIntWithError(v1, v2 int64) (evalResult, error) { @@ -543,7 +544,7 @@ func intPlusIntWithError(v1, v2 int64) (evalResult, error) { if (result > v1) != (v2 > 0) { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2) } - return evalResult{typ: Int64, ival: result}, nil + return evalResult{typ: sqltypes.Int64, ival: result}, nil } func intMinusIntWithError(v1, v2 int64) (evalResult, error) { @@ -552,7 +553,7 @@ func intMinusIntWithError(v1, v2 int64) (evalResult, error) { if (result < v1) != (v2 > 0) { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v - %v", v1, v2) } - return evalResult{typ: Int64, ival: result}, nil + return evalResult{typ: sqltypes.Int64, ival: result}, nil } func intTimesIntWithError(v1, v2 int64) (evalResult, error) { @@ -560,7 +561,7 @@ func intTimesIntWithError(v1, v2 int64) (evalResult, error) { if v1 != 0 && result/v1 != v2 { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v * %v", v1, v2) } - return evalResult{typ: Int64, ival: result}, nil + return evalResult{typ: sqltypes.Int64, ival: result}, nil } @@ -605,9 +606,9 @@ func uintTimesIntWithError(v1 uint64, v2 int64) (evalResult, error) { func uintPlusUint(v1, v2 uint64) evalResult { result := v1 + v2 if result < v2 { - return evalResult{typ: Float64, fval: float64(v1) + float64(v2)} + return evalResult{typ: sqltypes.Float64, fval: float64(v1) + float64(v2)} } - return evalResult{typ: Uint64, uval: result} + return evalResult{typ: sqltypes.Uint64, uval: result} } func uintPlusUintWithError(v1, v2 uint64) (evalResult, error) { @@ -615,7 +616,7 @@ func uintPlusUintWithError(v1, v2 uint64) (evalResult, error) { if result < v2 { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) } - return evalResult{typ: Uint64, uval: result}, nil + return evalResult{typ: sqltypes.Uint64, uval: result}, nil } func uintMinusUintWithError(v1, v2 uint64) (evalResult, error) { @@ -624,7 +625,7 @@ func uintMinusUintWithError(v1, v2 uint64) (evalResult, error) { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) } - return evalResult{typ: Uint64, uval: result}, nil + return evalResult{typ: sqltypes.Uint64, uval: result}, nil } func uintTimesUintWithError(v1, v2 uint64) (evalResult, error) { @@ -632,44 +633,44 @@ func uintTimesUintWithError(v1, v2 uint64) (evalResult, error) { if result < v2 || result < v1 { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) } - return evalResult{typ: Uint64, uval: result}, nil + return evalResult{typ: sqltypes.Uint64, uval: result}, nil } func floatPlusAny(v1 float64, v2 evalResult) evalResult { switch v2.typ { - case Int64: + case sqltypes.Int64: v2.fval = float64(v2.ival) - case Uint64: + case sqltypes.Uint64: v2.fval = float64(v2.uval) } - return evalResult{typ: Float64, fval: v1 + v2.fval} + return evalResult{typ: sqltypes.Float64, fval: v1 + v2.fval} } func floatMinusAny(v1 float64, v2 evalResult) evalResult { switch v2.typ { - case Int64: + case sqltypes.Int64: v2.fval = float64(v2.ival) - case Uint64: + case sqltypes.Uint64: v2.fval = float64(v2.uval) } - return evalResult{typ: Float64, fval: v1 - v2.fval} + return evalResult{typ: sqltypes.Float64, fval: v1 - v2.fval} } func floatTimesAny(v1 float64, v2 evalResult) evalResult { switch v2.typ { - case Int64: + case sqltypes.Int64: v2.fval = float64(v2.ival) - case Uint64: + case sqltypes.Uint64: v2.fval = float64(v2.uval) } - return evalResult{typ: Float64, fval: v1 * v2.fval} + return evalResult{typ: sqltypes.Float64, fval: v1 * v2.fval} } func floatDivideAnyWithError(v1 float64, v2 evalResult) (evalResult, error) { switch v2.typ { - case Int64: + case sqltypes.Int64: v2.fval = float64(v2.ival) - case Uint64: + case sqltypes.Uint64: v2.fval = float64(v2.uval) } result := v1 / v2.fval @@ -680,109 +681,109 @@ func floatDivideAnyWithError(v1 float64, v2 evalResult) (evalResult, error) { return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in %v / %v", v1, v2.fval) } - return evalResult{typ: Float64, fval: v1 / v2.fval}, nil + return evalResult{typ: sqltypes.Float64, fval: v1 / v2.fval}, nil } func anyMinusFloat(v1 evalResult, v2 float64) evalResult { switch v1.typ { - case Int64: + case sqltypes.Int64: v1.fval = float64(v1.ival) - case Uint64: + case sqltypes.Uint64: v1.fval = float64(v1.uval) } - return evalResult{typ: Float64, fval: v1.fval - v2} + return evalResult{typ: sqltypes.Float64, fval: v1.fval - v2} } -func castFromNumeric(v evalResult, resultType querypb.Type) Value { +func castFromNumeric(v evalResult, resultType querypb.Type) sqltypes.Value { switch { - case IsSigned(resultType): + case sqltypes.IsSigned(resultType): switch v.typ { - case Int64: - return MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)) - case Uint64: - return MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.uval), 10)) - case Float64: - return MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.fval), 10)) + case sqltypes.Int64: + return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)) + case sqltypes.Uint64: + return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.uval), 10)) + case sqltypes.Float64: + return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.fval), 10)) } - case IsUnsigned(resultType): + case sqltypes.IsUnsigned(resultType): switch v.typ { - case Uint64: - return MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)) - case Int64: - return MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.ival), 10)) - case Float64: - return MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.fval), 10)) + case sqltypes.Uint64: + return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)) + case sqltypes.Int64: + return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.ival), 10)) + case sqltypes.Float64: + return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.fval), 10)) } - case IsFloat(resultType) || resultType == Decimal: + case sqltypes.IsFloat(resultType) || resultType == sqltypes.Decimal: switch v.typ { - case Int64: - return MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)) - case Uint64: - return MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)) - case Float64: + case sqltypes.Int64: + return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)) + case sqltypes.Uint64: + return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)) + case sqltypes.Float64: format := byte('g') - if resultType == Decimal { + if resultType == sqltypes.Decimal { format = 'f' } - return MakeTrusted(resultType, strconv.AppendFloat(nil, v.fval, format, -1, 64)) + return sqltypes.MakeTrusted(resultType, strconv.AppendFloat(nil, v.fval, format, -1, 64)) } - case resultType == VarChar || resultType == VarBinary: - return MakeTrusted(resultType, []byte(v.str)) + case resultType == sqltypes.VarChar || resultType == sqltypes.VarBinary: + return sqltypes.MakeTrusted(resultType, []byte(v.str)) } - return NULL + return sqltypes.NULL } func compareNumeric(v1, v2 evalResult) int { // Equalize the types. switch v1.typ { - case Int64: + case sqltypes.Int64: switch v2.typ { - case Uint64: + case sqltypes.Uint64: if v1.ival < 0 { return -1 } - v1 = evalResult{typ: Uint64, uval: uint64(v1.ival)} - case Float64: - v1 = evalResult{typ: Float64, fval: float64(v1.ival)} + v1 = evalResult{typ: sqltypes.Uint64, uval: uint64(v1.ival)} + case sqltypes.Float64: + v1 = evalResult{typ: sqltypes.Float64, fval: float64(v1.ival)} } - case Uint64: + case sqltypes.Uint64: switch v2.typ { - case Int64: + case sqltypes.Int64: if v2.ival < 0 { return 1 } - v2 = evalResult{typ: Uint64, uval: uint64(v2.ival)} - case Float64: - v1 = evalResult{typ: Float64, fval: float64(v1.uval)} + v2 = evalResult{typ: sqltypes.Uint64, uval: uint64(v2.ival)} + case sqltypes.Float64: + v1 = evalResult{typ: sqltypes.Float64, fval: float64(v1.uval)} } - case Float64: + case sqltypes.Float64: switch v2.typ { - case Int64: - v2 = evalResult{typ: Float64, fval: float64(v2.ival)} - case Uint64: - v2 = evalResult{typ: Float64, fval: float64(v2.uval)} + case sqltypes.Int64: + v2 = evalResult{typ: sqltypes.Float64, fval: float64(v2.ival)} + case sqltypes.Uint64: + v2 = evalResult{typ: sqltypes.Float64, fval: float64(v2.uval)} } } // Both values are of the same type. switch v1.typ { - case Int64: + case sqltypes.Int64: switch { case v1.ival == v2.ival: return 0 case v1.ival < v2.ival: return -1 } - case Uint64: + case sqltypes.Uint64: switch { case v1.uval == v2.uval: return 0 case v1.uval < v2.uval: return -1 } - case Float64: + case sqltypes.Float64: switch { case v1.fval == v2.fval: return 0 diff --git a/go/vt/vtgate/evalengine/arithmetic_test.go b/go/vt/vtgate/evalengine/arithmetic_test.go new file mode 100644 index 00000000000..781e5d5de82 --- /dev/null +++ b/go/vt/vtgate/evalengine/arithmetic_test.go @@ -0,0 +1,1519 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "encoding/binary" + "fmt" + "math" + "reflect" + "strconv" + "testing" + + "vitess.io/vitess/go/sqltypes" + + querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) + +func TestDivide(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + out sqltypes.Value + err error + }{{ + //All Nulls + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // First value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NULL, + }, { + // Second value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // Second arg 0 + v1: sqltypes.NewInt32(5), + v2: sqltypes.NewInt32(0), + out: sqltypes.NULL, + }, { + // Both arguments zero + v1: sqltypes.NewInt32(0), + v2: sqltypes.NewInt32(0), + out: sqltypes.NULL, + }, { + // case with negative value + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewFloat64(0.5000), + }, { + // float64 division by zero + v1: sqltypes.NewFloat64(2), + v2: sqltypes.NewFloat64(0), + out: sqltypes.NULL, + }, { + // Lower bound for int64 + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewInt64(1), + out: sqltypes.NewFloat64(math.MinInt64), + }, { + // upper bound for uint64 + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewUint64(1), + out: sqltypes.NewFloat64(math.MaxUint64), + }, { + // testing for error in types + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for error in types + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for uint/int + v1: sqltypes.NewUint64(4), + v2: sqltypes.NewInt64(5), + out: sqltypes.NewFloat64(0.8), + }, { + // testing for uint/uint + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewFloat64(0.5), + }, { + // testing for float64/int64 + v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewFloat64(-0.6), + }, { + // testing for float64/uint64 + v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewFloat64(0.6), + }, { + // testing for overflow of float64 + v1: sqltypes.NewFloat64(math.MaxFloat64), + v2: sqltypes.NewFloat64(0.5), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in 1.7976931348623157e+308 / 0.5"), + }} + + for _, tcase := range tcases { + got, err := Divide(tcase.v1, tcase.v2) + + if !vterrors.Equals(err, tcase.err) { + t.Errorf("%v %v %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err)) + t.Errorf("Divide(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Divide(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } + +} + +func TestMultiply(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + out sqltypes.Value + err error + }{{ + //All Nulls + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // First value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // Second value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NULL, + }, { + // case with negative value + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewInt64(2), + }, { + // testing for int64 overflow with min negative value + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewInt64(1), + out: sqltypes.NewInt64(math.MinInt64), + }, { + // testing for error in types + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for error in types + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for uint*int + v1: sqltypes.NewUint64(4), + v2: sqltypes.NewInt64(5), + out: sqltypes.NewUint64(20), + }, { + // testing for uint*uint + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewUint64(2), + }, { + // testing for float64*int64 + v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewFloat64(-2.4), + }, { + // testing for float64*uint64 + v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewFloat64(2.4), + }, { + // testing for overflow of int64 + v1: sqltypes.NewInt64(math.MaxInt64), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in 9223372036854775807 * 2"), + }, { + // testing for underflow of uint64*max.uint64 + v1: sqltypes.NewInt64(2), + v2: sqltypes.NewUint64(math.MaxUint64), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 * 2"), + }, { + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewUint64(1), + out: sqltypes.NewUint64(math.MaxUint64), + }, { + //Checking whether maxInt value can be passed as uint value + v1: sqltypes.NewUint64(math.MaxInt64), + v2: sqltypes.NewInt64(3), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 9223372036854775807 * 3"), + }} + + for _, tcase := range tcases { + + got, err := Multiply(tcase.v1, tcase.v2) + + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Multiply(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Multiply(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } + +} + +func TestSubtract(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + out sqltypes.Value + err error + }{{ + // All Nulls + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // First value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // Second value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NULL, + }, { + // case with negative value + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewInt64(1), + }, { + // testing for int64 overflow with min negative value + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewInt64(1), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 - 1"), + }, { + v1: sqltypes.NewUint64(4), + v2: sqltypes.NewInt64(5), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 4 - 5"), + }, { + // testing uint - int + v1: sqltypes.NewUint64(7), + v2: sqltypes.NewInt64(5), + out: sqltypes.NewUint64(2), + }, { + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewInt64(0), + out: sqltypes.NewUint64(math.MaxUint64), + }, { + // testing for int64 overflow + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewUint64(0), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in -9223372036854775808 - 0"), + }, { + v1: sqltypes.TestValue(querypb.Type_VARCHAR, "c"), + v2: sqltypes.NewInt64(1), + out: sqltypes.NewInt64(-1), + }, { + v1: sqltypes.NewUint64(1), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "c"), + out: sqltypes.NewUint64(1), + }, { + // testing for error for parsing float value to uint64 + v1: sqltypes.TestValue(querypb.Type_UINT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), + }, { + // testing for error for parsing float value to uint64 + v1: sqltypes.NewUint64(2), + v2: sqltypes.TestValue(querypb.Type_UINT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), + }, { + // uint64 - uint64 + v1: sqltypes.NewUint64(8), + v2: sqltypes.NewUint64(4), + out: sqltypes.NewUint64(4), + }, { + // testing for float subtraction: float - int + v1: sqltypes.NewFloat64(1.2), + v2: sqltypes.NewInt64(2), + out: sqltypes.NewFloat64(-0.8), + }, { + // testing for float subtraction: float - uint + v1: sqltypes.NewFloat64(1.2), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewFloat64(-0.8), + }, { + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewUint64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in -1 - 2"), + }, { + v1: sqltypes.NewInt64(2), + v2: sqltypes.NewUint64(1), + out: sqltypes.NewUint64(1), + }, { + // testing int64 - float64 method + v1: sqltypes.NewInt64(-2), + v2: sqltypes.NewFloat64(1.0), + out: sqltypes.NewFloat64(-3.0), + }, { + // testing uint64 - float64 method + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewFloat64(-2.0), + out: sqltypes.NewFloat64(3.0), + }, { + // testing uint - int to return uintplusint + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewUint64(3), + }, { + // testing for float - float + v1: sqltypes.NewFloat64(1.2), + v2: sqltypes.NewFloat64(3.2), + out: sqltypes.NewFloat64(-2), + }, { + // testing uint - uint if v2 > v1 + v1: sqltypes.NewUint64(2), + v2: sqltypes.NewUint64(4), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 2 - 4"), + }, { + // testing uint - (- int) + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewUint64(3), + }} + + for _, tcase := range tcases { + + got, err := Subtract(tcase.v1, tcase.v2) + + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Subtract(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Subtract(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } + +} + +func TestAdd(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + out sqltypes.Value + err error + }{{ + // All Nulls + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // First value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // Second value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NULL, + }, { + // case with negatives + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewInt64(-3), + }, { + // testing for overflow int64, result will be unsigned int + v1: sqltypes.NewInt64(math.MaxInt64), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewUint64(9223372036854775809), + }, { + v1: sqltypes.NewInt64(-2), + v2: sqltypes.NewUint64(1), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 1 + -2"), + }, { + v1: sqltypes.NewInt64(math.MaxInt64), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewInt64(9223372036854775805), + }, { + // Normal case + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewUint64(3), + }, { + // testing for overflow uint64 + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewUint64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), + }, { + // int64 underflow + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewInt64(-2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 + -2"), + }, { + // checking int64 max value can be returned + v1: sqltypes.NewInt64(math.MaxInt64), + v2: sqltypes.NewUint64(0), + out: sqltypes.NewUint64(9223372036854775807), + }, { + // testing whether uint64 max value can be returned + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewInt64(0), + out: sqltypes.NewUint64(math.MaxUint64), + }, { + v1: sqltypes.NewUint64(math.MaxInt64), + v2: sqltypes.NewInt64(1), + out: sqltypes.NewUint64(9223372036854775808), + }, { + v1: sqltypes.NewUint64(1), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "c"), + out: sqltypes.NewUint64(1), + }, { + v1: sqltypes.NewUint64(1), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "1.2"), + out: sqltypes.NewFloat64(2.2), + }, { + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for uint64 overflow with max uint64 + int value + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), + }} + + for _, tcase := range tcases { + + got, err := Add(tcase.v1, tcase.v2) + + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Add(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Add(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } + +} + +func TestNullsafeAdd(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + out sqltypes.Value + err error + }{{ + // All nulls. + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NewInt64(0), + }, { + // First value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NewInt64(1), + }, { + // Second value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NewInt64(1), + }, { + // Normal case. + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewInt64(2), + out: sqltypes.NewInt64(3), + }, { + // Make sure underlying error is returned for LHS. + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // Make sure underlying error is returned for RHS. + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // Make sure underlying error is returned while adding. + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewInt64(-9223372036854775808), + }, { + // Make sure underlying error is returned while converting. + v1: sqltypes.NewFloat64(1), + v2: sqltypes.NewFloat64(2), + out: sqltypes.NewInt64(3), + }} + for _, tcase := range tcases { + got := NullsafeAdd(tcase.v1, tcase.v2, querypb.Type_INT64) + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("NullsafeAdd(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } +} + +func TestNullsafeCompare(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + out int + err error + }{{ + // All nulls. + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: 0, + }, { + // LHS null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt64(1), + out: -1, + }, { + // RHS null. + v1: sqltypes.NewInt64(1), + v2: sqltypes.NULL, + out: 1, + }, { + // LHS Text + v1: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), + }, { + // Make sure underlying error is returned for LHS. + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // Make sure underlying error is returned for RHS. + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // Numeric equal. + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewUint64(1), + out: 0, + }, { + // Numeric unequal. + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewUint64(2), + out: -1, + }, { + // Non-numeric equal + v1: sqltypes.TestValue(querypb.Type_VARBINARY, "abcd"), + v2: sqltypes.TestValue(querypb.Type_BINARY, "abcd"), + out: 0, + }, { + // Non-numeric unequal + v1: sqltypes.TestValue(querypb.Type_VARBINARY, "abcd"), + v2: sqltypes.TestValue(querypb.Type_BINARY, "bcde"), + out: -1, + }, { + // Date/Time types + v1: sqltypes.TestValue(querypb.Type_DATETIME, "1000-01-01 00:00:00"), + v2: sqltypes.TestValue(querypb.Type_BINARY, "1000-01-01 00:00:00"), + out: 0, + }, { + // Date/Time types + v1: sqltypes.TestValue(querypb.Type_DATETIME, "2000-01-01 00:00:00"), + v2: sqltypes.TestValue(querypb.Type_BINARY, "1000-01-01 00:00:00"), + out: 1, + }, { + // Date/Time types + v1: sqltypes.TestValue(querypb.Type_DATETIME, "1000-01-01 00:00:00"), + v2: sqltypes.TestValue(querypb.Type_BINARY, "2000-01-01 00:00:00"), + out: -1, + }} + for _, tcase := range tcases { + got, err := NullsafeCompare(tcase.v1, tcase.v2) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if got != tcase.out { + t.Errorf("NullsafeCompare(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), got, tcase.out) + } + } +} + +func TestCast(t *testing.T) { + tcases := []struct { + typ querypb.Type + v sqltypes.Value + out sqltypes.Value + err error + }{{ + typ: querypb.Type_VARCHAR, + v: sqltypes.NULL, + out: sqltypes.NULL, + }, { + typ: querypb.Type_VARCHAR, + v: sqltypes.TestValue(querypb.Type_VARCHAR, "exact types"), + out: sqltypes.TestValue(querypb.Type_VARCHAR, "exact types"), + }, { + typ: querypb.Type_INT64, + v: sqltypes.TestValue(querypb.Type_INT32, "32"), + out: sqltypes.TestValue(querypb.Type_INT64, "32"), + }, { + typ: querypb.Type_INT24, + v: sqltypes.TestValue(querypb.Type_UINT64, "64"), + out: sqltypes.TestValue(querypb.Type_INT24, "64"), + }, { + typ: querypb.Type_INT24, + v: sqltypes.TestValue(querypb.Type_VARCHAR, "bad int"), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseInt: parsing "bad int": invalid syntax`), + }, { + typ: querypb.Type_UINT64, + v: sqltypes.TestValue(querypb.Type_UINT32, "32"), + out: sqltypes.TestValue(querypb.Type_UINT64, "32"), + }, { + typ: querypb.Type_UINT24, + v: sqltypes.TestValue(querypb.Type_INT64, "64"), + out: sqltypes.TestValue(querypb.Type_UINT24, "64"), + }, { + typ: querypb.Type_UINT24, + v: sqltypes.TestValue(querypb.Type_INT64, "-1"), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseUint: parsing "-1": invalid syntax`), + }, { + typ: querypb.Type_FLOAT64, + v: sqltypes.TestValue(querypb.Type_INT64, "64"), + out: sqltypes.TestValue(querypb.Type_FLOAT64, "64"), + }, { + typ: querypb.Type_FLOAT32, + v: sqltypes.TestValue(querypb.Type_FLOAT64, "64"), + out: sqltypes.TestValue(querypb.Type_FLOAT32, "64"), + }, { + typ: querypb.Type_FLOAT32, + v: sqltypes.TestValue(querypb.Type_DECIMAL, "1.24"), + out: sqltypes.TestValue(querypb.Type_FLOAT32, "1.24"), + }, { + typ: querypb.Type_FLOAT64, + v: sqltypes.TestValue(querypb.Type_VARCHAR, "1.25"), + out: sqltypes.TestValue(querypb.Type_FLOAT64, "1.25"), + }, { + typ: querypb.Type_FLOAT64, + v: sqltypes.TestValue(querypb.Type_VARCHAR, "bad float"), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseFloat: parsing "bad float": invalid syntax`), + }, { + typ: querypb.Type_VARCHAR, + v: sqltypes.TestValue(querypb.Type_INT64, "64"), + out: sqltypes.TestValue(querypb.Type_VARCHAR, "64"), + }, { + typ: querypb.Type_VARBINARY, + v: sqltypes.TestValue(querypb.Type_FLOAT64, "64"), + out: sqltypes.TestValue(querypb.Type_VARBINARY, "64"), + }, { + typ: querypb.Type_VARBINARY, + v: sqltypes.TestValue(querypb.Type_DECIMAL, "1.24"), + out: sqltypes.TestValue(querypb.Type_VARBINARY, "1.24"), + }, { + typ: querypb.Type_VARBINARY, + v: sqltypes.TestValue(querypb.Type_VARCHAR, "1.25"), + out: sqltypes.TestValue(querypb.Type_VARBINARY, "1.25"), + }, { + typ: querypb.Type_VARCHAR, + v: sqltypes.TestValue(querypb.Type_VARBINARY, "valid string"), + out: sqltypes.TestValue(querypb.Type_VARCHAR, "valid string"), + }, { + typ: querypb.Type_VARCHAR, + v: sqltypes.TestValue(sqltypes.Expression, "bad string"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "EXPRESSION(bad string) cannot be cast to VARCHAR"), + }} + for _, tcase := range tcases { + got, err := Cast(tcase.v, tcase.typ) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Cast(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Cast(%v): %v, want %v", tcase.v, got, tcase.out) + } + } +} + +func TestToUint64(t *testing.T) { + tcases := []struct { + v sqltypes.Value + out uint64 + err error + }{{ + v: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), + }, { + v: sqltypes.NewInt64(-1), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "negative number cannot be converted to unsigned: -1"), + }, { + v: sqltypes.NewInt64(1), + out: 1, + }, { + v: sqltypes.NewUint64(1), + out: 1, + }} + for _, tcase := range tcases { + got, err := ToUint64(tcase.v) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("ToUint64(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if got != tcase.out { + t.Errorf("ToUint64(%v): %v, want %v", tcase.v, got, tcase.out) + } + } +} + +func TestToInt64(t *testing.T) { + tcases := []struct { + v sqltypes.Value + out int64 + err error + }{{ + v: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), + }, { + v: sqltypes.NewUint64(18446744073709551615), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unsigned number overflows int64 value: 18446744073709551615"), + }, { + v: sqltypes.NewInt64(1), + out: 1, + }, { + v: sqltypes.NewUint64(1), + out: 1, + }} + for _, tcase := range tcases { + got, err := ToInt64(tcase.v) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("ToInt64(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if got != tcase.out { + t.Errorf("ToInt64(%v): %v, want %v", tcase.v, got, tcase.out) + } + } +} + +func TestToFloat64(t *testing.T) { + tcases := []struct { + v sqltypes.Value + out float64 + err error + }{{ + v: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + out: 0, + }, { + v: sqltypes.NewInt64(1), + out: 1, + }, { + v: sqltypes.NewUint64(1), + out: 1, + }, { + v: sqltypes.NewFloat64(1.2), + out: 1.2, + }, { + v: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }} + for _, tcase := range tcases { + got, err := ToFloat64(tcase.v) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("ToFloat64(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if got != tcase.out { + t.Errorf("ToFloat64(%v): %v, want %v", tcase.v, got, tcase.out) + } + } +} + +func TestToNative(t *testing.T) { + testcases := []struct { + in sqltypes.Value + out interface{} + }{{ + in: sqltypes.NULL, + out: nil, + }, { + in: sqltypes.TestValue(querypb.Type_INT8, "1"), + out: int64(1), + }, { + in: sqltypes.TestValue(querypb.Type_INT16, "1"), + out: int64(1), + }, { + in: sqltypes.TestValue(querypb.Type_INT24, "1"), + out: int64(1), + }, { + in: sqltypes.TestValue(querypb.Type_INT32, "1"), + out: int64(1), + }, { + in: sqltypes.TestValue(querypb.Type_INT64, "1"), + out: int64(1), + }, { + in: sqltypes.TestValue(querypb.Type_UINT8, "1"), + out: uint64(1), + }, { + in: sqltypes.TestValue(querypb.Type_UINT16, "1"), + out: uint64(1), + }, { + in: sqltypes.TestValue(querypb.Type_UINT24, "1"), + out: uint64(1), + }, { + in: sqltypes.TestValue(querypb.Type_UINT32, "1"), + out: uint64(1), + }, { + in: sqltypes.TestValue(querypb.Type_UINT64, "1"), + out: uint64(1), + }, { + in: sqltypes.TestValue(querypb.Type_FLOAT32, "1"), + out: float64(1), + }, { + in: sqltypes.TestValue(querypb.Type_FLOAT64, "1"), + out: float64(1), + }, { + in: sqltypes.TestValue(querypb.Type_TIMESTAMP, "2012-02-24 23:19:43"), + out: []byte("2012-02-24 23:19:43"), + }, { + in: sqltypes.TestValue(querypb.Type_DATE, "2012-02-24"), + out: []byte("2012-02-24"), + }, { + in: sqltypes.TestValue(querypb.Type_TIME, "23:19:43"), + out: []byte("23:19:43"), + }, { + in: sqltypes.TestValue(querypb.Type_DATETIME, "2012-02-24 23:19:43"), + out: []byte("2012-02-24 23:19:43"), + }, { + in: sqltypes.TestValue(querypb.Type_YEAR, "1"), + out: uint64(1), + }, { + in: sqltypes.TestValue(querypb.Type_DECIMAL, "1"), + out: []byte("1"), + }, { + in: sqltypes.TestValue(querypb.Type_TEXT, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_BLOB, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_VARCHAR, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_VARBINARY, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_CHAR, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_BINARY, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_BIT, "1"), + out: []byte("1"), + }, { + in: sqltypes.TestValue(querypb.Type_ENUM, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_SET, "a"), + out: []byte("a"), + }} + for _, tcase := range testcases { + v, err := ToNative(tcase.in) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(v, tcase.out) { + t.Errorf("%v.ToNative = %#v, want %#v", tcase.in, v, tcase.out) + } + } + + // Test Expression failure. + _, err := ToNative(sqltypes.TestValue(querypb.Type_EXPRESSION, "aa")) + want := vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "EXPRESSION(aa) cannot be converted to a go type") + if !vterrors.Equals(err, want) { + t.Errorf("ToNative(EXPRESSION): %v, want %v", vterrors.Print(err), vterrors.Print(want)) + } +} + +func TestNewNumeric(t *testing.T) { + tcases := []struct { + v sqltypes.Value + out evalResult + err error + }{{ + v: sqltypes.NewInt64(1), + out: evalResult{typ: querypb.Type_INT64, ival: 1}, + }, { + v: sqltypes.NewUint64(1), + out: evalResult{typ: querypb.Type_UINT64, uval: 1}, + }, { + v: sqltypes.NewFloat64(1), + out: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + }, { + // For non-number type, Int64 is the default. + v: sqltypes.TestValue(querypb.Type_VARCHAR, "1"), + out: evalResult{typ: querypb.Type_INT64, ival: 1}, + }, { + // If Int64 can't work, we use Float64. + v: sqltypes.TestValue(querypb.Type_VARCHAR, "1.2"), + out: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2}, + }, { + // Only valid Int64 allowed if type is Int64. + v: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // Only valid Uint64 allowed if type is Uint64. + v: sqltypes.TestValue(querypb.Type_UINT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), + }, { + // Only valid Float64 allowed if type is Float64. + v: sqltypes.TestValue(querypb.Type_FLOAT64, "abcd"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseFloat: parsing \"abcd\": invalid syntax"), + }, { + v: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + out: evalResult{typ: querypb.Type_FLOAT64, fval: 0}, + }} + for _, tcase := range tcases { + got, err := newNumeric(tcase.v) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("newNumeric(%s) error: %v, want %v", printValue(tcase.v), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err == nil { + continue + } + + if got != tcase.out { + t.Errorf("newNumeric(%s): %v, want %v", printValue(tcase.v), got, tcase.out) + } + } +} + +func TestNewIntegralNumeric(t *testing.T) { + tcases := []struct { + v sqltypes.Value + out evalResult + err error + }{{ + v: sqltypes.NewInt64(1), + out: evalResult{typ: querypb.Type_INT64, ival: 1}, + }, { + v: sqltypes.NewUint64(1), + out: evalResult{typ: querypb.Type_UINT64, uval: 1}, + }, { + v: sqltypes.NewFloat64(1), + out: evalResult{typ: querypb.Type_INT64, ival: 1}, + }, { + // For non-number type, Int64 is the default. + v: sqltypes.TestValue(querypb.Type_VARCHAR, "1"), + out: evalResult{typ: querypb.Type_INT64, ival: 1}, + }, { + // If Int64 can't work, we use Uint64. + v: sqltypes.TestValue(querypb.Type_VARCHAR, "18446744073709551615"), + out: evalResult{typ: querypb.Type_UINT64, uval: 18446744073709551615}, + }, { + // Only valid Int64 allowed if type is Int64. + v: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // Only valid Uint64 allowed if type is Uint64. + v: sqltypes.TestValue(querypb.Type_UINT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), + }, { + v: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), + }} + for _, tcase := range tcases { + got, err := newIntegralNumeric(tcase.v) + if err != nil && !vterrors.Equals(err, tcase.err) { + t.Errorf("newIntegralNumeric(%s) error: %v, want %v", printValue(tcase.v), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err == nil { + continue + } + + if got != tcase.out { + t.Errorf("newIntegralNumeric(%s): %v, want %v", printValue(tcase.v), got, tcase.out) + } + } +} + +func TestAddNumeric(t *testing.T) { + tcases := []struct { + v1, v2 evalResult + out evalResult + err error + }{{ + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + out: evalResult{typ: querypb.Type_INT64, ival: 3}, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: evalResult{typ: querypb.Type_UINT64, uval: 3}, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: 3}, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: evalResult{typ: querypb.Type_UINT64, uval: 3}, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: 3}, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: 3}, + }, { + // Int64 overflow. + v1: evalResult{typ: querypb.Type_INT64, ival: 9223372036854775807}, + v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: 9223372036854775809}, + }, { + // Int64 underflow. + v1: evalResult{typ: querypb.Type_INT64, ival: -9223372036854775807}, + v2: evalResult{typ: querypb.Type_INT64, ival: -2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: -9223372036854775809}, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: -1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: 18446744073709551617}, + }, { + // Uint64 overflow. + v1: evalResult{typ: querypb.Type_UINT64, uval: 18446744073709551615}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: 18446744073709551617}, + }} + for _, tcase := range tcases { + got := addNumeric(tcase.v1, tcase.v2) + + if got != tcase.out { + t.Errorf("addNumeric(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) + } + } +} + +func TestPrioritize(t *testing.T) { + ival := evalResult{typ: querypb.Type_INT64} + uval := evalResult{typ: querypb.Type_UINT64} + fval := evalResult{typ: querypb.Type_FLOAT64} + + tcases := []struct { + v1, v2 evalResult + out1, out2 evalResult + }{{ + v1: ival, + v2: uval, + out1: uval, + out2: ival, + }, { + v1: ival, + v2: fval, + out1: fval, + out2: ival, + }, { + v1: uval, + v2: ival, + out1: uval, + out2: ival, + }, { + v1: uval, + v2: fval, + out1: fval, + out2: uval, + }, { + v1: fval, + v2: ival, + out1: fval, + out2: ival, + }, { + v1: fval, + v2: uval, + out1: fval, + out2: uval, + }} + for _, tcase := range tcases { + got1, got2 := prioritize(tcase.v1, tcase.v2) + if got1 != tcase.out1 || got2 != tcase.out2 { + t.Errorf("prioritize(%v, %v): (%v, %v) , want (%v, %v)", tcase.v1.typ, tcase.v2.typ, got1.typ, got2.typ, tcase.out1.typ, tcase.out2.typ) + } + } +} + +func TestCastFromNumeric(t *testing.T) { + tcases := []struct { + typ querypb.Type + v evalResult + out sqltypes.Value + err error + }{{ + typ: querypb.Type_INT64, + v: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: sqltypes.NewInt64(1), + }, { + typ: querypb.Type_INT64, + v: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: sqltypes.NewInt64(1), + }, { + typ: querypb.Type_INT64, + v: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + out: sqltypes.NewInt64(0), + }, { + typ: querypb.Type_UINT64, + v: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: sqltypes.NewUint64(1), + }, { + typ: querypb.Type_UINT64, + v: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: sqltypes.NewUint64(1), + }, { + typ: querypb.Type_UINT64, + v: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + out: sqltypes.NewUint64(0), + }, { + typ: querypb.Type_FLOAT64, + v: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: sqltypes.TestValue(querypb.Type_FLOAT64, "1"), + }, { + typ: querypb.Type_FLOAT64, + v: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: sqltypes.TestValue(querypb.Type_FLOAT64, "1"), + }, { + typ: querypb.Type_FLOAT64, + v: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + out: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2e-16"), + }, { + typ: querypb.Type_DECIMAL, + v: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: sqltypes.TestValue(querypb.Type_DECIMAL, "1"), + }, { + typ: querypb.Type_DECIMAL, + v: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: sqltypes.TestValue(querypb.Type_DECIMAL, "1"), + }, { + // For float, we should not use scientific notation. + typ: querypb.Type_DECIMAL, + v: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + out: sqltypes.TestValue(querypb.Type_DECIMAL, "0.00000000000000012"), + }} + for _, tcase := range tcases { + got := castFromNumeric(tcase.v, tcase.typ) + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("castFromNumeric(%v, %v): %v, want %v", tcase.v, tcase.typ, printValue(got), printValue(tcase.out)) + } + } +} + +func TestCompareNumeric(t *testing.T) { + tcases := []struct { + v1, v2 evalResult + out int + }{{ + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 2}, + v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: 1, + }, { + // Special case. + v1: evalResult{typ: querypb.Type_INT64, ival: -1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 2}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 2}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: 1, + }, { + // Special case. + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: -1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 2}, + v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 2}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 2}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: 1, + }} + for _, tcase := range tcases { + got := compareNumeric(tcase.v1, tcase.v2) + if got != tcase.out { + t.Errorf("equalNumeric(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) + } + } +} + +func TestMin(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + min sqltypes.Value + err error + }{{ + v1: sqltypes.NULL, + v2: sqltypes.NULL, + min: sqltypes.NULL, + }, { + v1: sqltypes.NewInt64(1), + v2: sqltypes.NULL, + min: sqltypes.NewInt64(1), + }, { + v1: sqltypes.NULL, + v2: sqltypes.NewInt64(1), + min: sqltypes.NewInt64(1), + }, { + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewInt64(2), + min: sqltypes.NewInt64(1), + }, { + v1: sqltypes.NewInt64(2), + v2: sqltypes.NewInt64(1), + min: sqltypes.NewInt64(1), + }, { + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewInt64(1), + min: sqltypes.NewInt64(1), + }, { + v1: sqltypes.TestValue(querypb.Type_VARCHAR, "aa"), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "aa"), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), + }} + for _, tcase := range tcases { + v, err := Min(tcase.v1, tcase.v2) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Min error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(v, tcase.min) { + t.Errorf("Min(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.min) + } + } +} + +func TestMax(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + max sqltypes.Value + err error + }{{ + v1: sqltypes.NULL, + v2: sqltypes.NULL, + max: sqltypes.NULL, + }, { + v1: sqltypes.NewInt64(1), + v2: sqltypes.NULL, + max: sqltypes.NewInt64(1), + }, { + v1: sqltypes.NULL, + v2: sqltypes.NewInt64(1), + max: sqltypes.NewInt64(1), + }, { + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewInt64(2), + max: sqltypes.NewInt64(2), + }, { + v1: sqltypes.NewInt64(2), + v2: sqltypes.NewInt64(1), + max: sqltypes.NewInt64(2), + }, { + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewInt64(1), + max: sqltypes.NewInt64(1), + }, { + v1: sqltypes.TestValue(querypb.Type_VARCHAR, "aa"), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "aa"), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), + }} + for _, tcase := range tcases { + v, err := Max(tcase.v1, tcase.v2) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Max error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(v, tcase.max) { + t.Errorf("Max(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.max) + } + } +} + +func printValue(v sqltypes.Value) string { + return fmt.Sprintf("%v:%q", v.Type(), v.ToBytes()) +} + +// These benchmarks show that using existing ASCII representations +// for numbers is about 6x slower than using native representations. +// However, 229ns is still a negligible time compared to the cost of +// other operations. The additional complexity of introducing native +// types is currently not worth it. So, we'll stay with the existing +// ASCII representation for now. Using interfaces is more expensive +// than native representation of values. This is probably because +// interfaces also allocate memory, and also perform type assertions. +// Actual benchmark is based on NoNative. So, the numbers are similar. +// Date: 6/4/17 +// Version: go1.8 +// BenchmarkAddActual-8 10000000 263 ns/op +// BenchmarkAddNoNative-8 10000000 228 ns/op +// BenchmarkAddNative-8 50000000 40.0 ns/op +// BenchmarkAddGoInterface-8 30000000 52.4 ns/op +// BenchmarkAddGoNonInterface-8 2000000000 1.00 ns/op +// BenchmarkAddGo-8 2000000000 1.00 ns/op +func BenchmarkAddActual(b *testing.B) { + v1 := sqltypes.MakeTrusted(querypb.Type_INT64, []byte("1")) + v2 := sqltypes.MakeTrusted(querypb.Type_INT64, []byte("12")) + for i := 0; i < b.N; i++ { + v1 = NullsafeAdd(v1, v2, querypb.Type_INT64) + } +} + +func BenchmarkAddNoNative(b *testing.B) { + v1 := sqltypes.MakeTrusted(querypb.Type_INT64, []byte("1")) + v2 := sqltypes.MakeTrusted(querypb.Type_INT64, []byte("12")) + for i := 0; i < b.N; i++ { + iv1, _ := ToInt64(v1) + iv2, _ := ToInt64(v2) + v1 = sqltypes.MakeTrusted(querypb.Type_INT64, strconv.AppendInt(nil, iv1+iv2, 10)) + } +} + +func BenchmarkAddNative(b *testing.B) { + v1 := makeNativeInt64(1) + v2 := makeNativeInt64(12) + for i := 0; i < b.N; i++ { + iv1 := int64(binary.BigEndian.Uint64(v1.Raw())) + iv2 := int64(binary.BigEndian.Uint64(v2.Raw())) + v1 = makeNativeInt64(iv1 + iv2) + } +} + +func makeNativeInt64(v int64) sqltypes.Value { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(v)) + return sqltypes.MakeTrusted(querypb.Type_INT64, buf) +} + +func BenchmarkAddGoInterface(b *testing.B) { + var v1, v2 interface{} + v1 = int64(1) + v2 = int64(2) + for i := 0; i < b.N; i++ { + v1 = v1.(int64) + v2.(int64) + } +} + +func BenchmarkAddGoNonInterface(b *testing.B) { + v1 := evalResult{typ: querypb.Type_INT64, ival: 1} + v2 := evalResult{typ: querypb.Type_INT64, ival: 12} + for i := 0; i < b.N; i++ { + if v1.typ != querypb.Type_INT64 { + b.Error("type assertion failed") + } + if v2.typ != querypb.Type_INT64 { + b.Error("type assertion failed") + } + v1 = evalResult{typ: querypb.Type_INT64, ival: v1.ival + v2.ival} + } +} + +func BenchmarkAddGo(b *testing.B) { + v1 := int64(1) + v2 := int64(2) + for i := 0; i < b.N; i++ { + v1 += v2 + } +} diff --git a/go/sqltypes/expressions.go b/go/vt/vtgate/evalengine/expressions.go similarity index 86% rename from go/sqltypes/expressions.go rename to go/vt/vtgate/evalengine/expressions.go index 4f3fb6c735e..a2f00c2fdce 100644 --- a/go/sqltypes/expressions.go +++ b/go/vt/vtgate/evalengine/expressions.go @@ -14,10 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -package sqltypes +package evalengine import ( "strconv" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -36,7 +37,7 @@ type ( //evaluates in, such as the current row and bindvars ExpressionEnv struct { BindVars map[string]*querypb.BindVariable - Row []Value + Row []sqltypes.Value } // EvalResult is used so we don't have to expose all parts of the private struct @@ -73,7 +74,7 @@ type ( ) //Value allows for retrieval of the value we expose for public consumption -func (e EvalResult) Value() Value { +func (e EvalResult) Value() sqltypes.Value { return castFromNumeric(e, e.typ) } @@ -82,7 +83,7 @@ func NewLiteralInt(val []byte) (Expr, error) { if err != nil { return nil, err } - return &LiteralFloat{evalResult{typ: Int64, ival: ival}}, nil + return &LiteralFloat{evalResult{typ: sqltypes.Int64, ival: ival}}, nil } func NewLiteralFloat(val []byte) (Expr, error) { @@ -90,7 +91,7 @@ func NewLiteralFloat(val []byte) (Expr, error) { if err != nil { return nil, err } - return &LiteralFloat{evalResult{typ: Float64, fval: fval}}, nil + return &LiteralFloat{evalResult{typ: sqltypes.Float64, fval: fval}}, nil } var _ Expr = (*LiteralInt)(nil) @@ -156,22 +157,22 @@ func (d *Division) Evaluate(left, right EvalResult) (EvalResult, error) { //Type implements the BinaryExpr interface func (a *Addition) Type(left querypb.Type) querypb.Type { - return addAndSubtractType(left) + return left } //Type implements the BinaryExpr interface func (m *Multiplication) Type(left querypb.Type) querypb.Type { - return addAndSubtractType(left) + return left } //Type implements the BinaryExpr interface func (d *Division) Type(querypb.Type) querypb.Type { - return Float64 + return sqltypes.Float64 } //Type implements the BinaryExpr interface func (s *Subtraction) Type(left querypb.Type) querypb.Type { - return addAndSubtractType(left) + return left } //Type implements the Expr interface @@ -190,11 +191,11 @@ func (b *BindVariable) Type(env ExpressionEnv) querypb.Type { //Type implements the Expr interface func (l *LiteralInt) Type(_ ExpressionEnv) querypb.Type { - return Int64 + return sqltypes.Int64 } func (l *LiteralFloat) Type(env ExpressionEnv) querypb.Type { - return Float64 + return sqltypes.Float64 } //String implements the BinaryExpr interface @@ -238,12 +239,12 @@ func (l *LiteralFloat) String() string { func mergeNumericalTypes(ltype, rtype querypb.Type) querypb.Type { switch ltype { - case Int64: - if rtype == Uint64 || rtype == Float64 { + case sqltypes.Int64: + if rtype == sqltypes.Uint64 || rtype == sqltypes.Float64 { return rtype } - case Uint64: - if rtype == Float64 { + case sqltypes.Uint64: + if rtype == sqltypes.Float64 { return rtype } } @@ -252,40 +253,29 @@ func mergeNumericalTypes(ltype, rtype querypb.Type) querypb.Type { func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { switch val.Type { - case Int64: + case sqltypes.Int64: ival, err := strconv.ParseInt(string(val.Value), 10, 64) if err != nil { ival = 0 } - return evalResult{typ: Int64, ival: ival}, nil - case Uint64: + return evalResult{typ: sqltypes.Int64, ival: ival}, nil + case sqltypes.Uint64: uval, err := strconv.ParseUint(string(val.Value), 10, 64) if err != nil { uval = 0 } - return evalResult{typ: Uint64, uval: uval}, nil - case Float64: + return evalResult{typ: sqltypes.Uint64, uval: uval}, nil + case sqltypes.Float64: fval, err := strconv.ParseFloat(string(val.Value), 64) if err != nil { fval = 0 } - return evalResult{typ: Float64, fval: fval}, nil - case VarChar: - return evalResult{typ: VarChar, str: string(val.Value)}, nil - case VarBinary: - return evalResult{typ: VarBinary, str: string(val.Value)}, nil + return evalResult{typ: sqltypes.Float64, fval: fval}, nil + case sqltypes.VarChar: + return evalResult{typ: sqltypes.VarChar, str: string(val.Value)}, nil + case sqltypes.VarBinary: + return evalResult{typ: sqltypes.VarBinary, str: string(val.Value)}, nil } return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Type is not supported") } -func addAndSubtractType(left querypb.Type) querypb.Type { - switch left { - case Int64: - return Int64 - case Uint64: - return Uint64 - case Float64: - return Float64 - } - panic("unreachable") -} diff --git a/go/vt/vtgate/evalengine/expressions_test.go b/go/vt/vtgate/evalengine/expressions_test.go new file mode 100644 index 00000000000..6aed3ccc3e9 --- /dev/null +++ b/go/vt/vtgate/evalengine/expressions_test.go @@ -0,0 +1,106 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "fmt" + "reflect" + "testing" + + "vitess.io/vitess/go/sqltypes" + + "github.com/magiconair/properties/assert" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +// more tests in go/sqlparser/expressions_test.go + +func TestBinaryOpTypes(t *testing.T) { + type testcase struct { + l, r, e querypb.Type + } + type ops struct { + op BinaryExpr + testcases []testcase + } + + tests := []ops{ + { + op: &Addition{}, + testcases: []testcase{ + {sqltypes.Int64, sqltypes.Int64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Int64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Int64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Uint64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Uint64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Uint64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Float64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Float64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Float64, sqltypes.Float64}, + }, + }, { + op: &Subtraction{}, + testcases: []testcase{ + {sqltypes.Int64, sqltypes.Int64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Int64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Int64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Uint64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Uint64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Uint64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Float64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Float64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Float64, sqltypes.Float64}, + }, + }, { + op: &Multiplication{}, + testcases: []testcase{ + {sqltypes.Int64, sqltypes.Int64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Int64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Int64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Uint64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Uint64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Uint64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Float64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Float64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Float64, sqltypes.Float64}, + }, + }, { + op: &Division{}, + testcases: []testcase{ + {sqltypes.Int64, sqltypes.Int64, sqltypes.Float64}, + {sqltypes.Uint64, sqltypes.Int64, sqltypes.Float64}, + {sqltypes.Float64, sqltypes.Int64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Uint64, sqltypes.Float64}, + {sqltypes.Uint64, sqltypes.Uint64, sqltypes.Float64}, + {sqltypes.Float64, sqltypes.Uint64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Float64, sqltypes.Float64}, + {sqltypes.Uint64, sqltypes.Float64, sqltypes.Float64}, + {sqltypes.Float64, sqltypes.Float64, sqltypes.Float64}, + }, + }, + } + + for _, op := range tests { + for _, tc := range op.testcases { + name := fmt.Sprintf("%s %s %s", tc.l.String(), reflect.TypeOf(op.op).String(), tc.r.String()) + t.Run(name, func(t *testing.T) { + result := op.op.Type(tc.l) + assert.Equal(t, tc.e, result) + }) + } + } +} diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 6cbdf7f29e1..78ebcf3cc92 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -20,7 +20,8 @@ import ( "errors" "fmt" - "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" ) @@ -131,7 +132,7 @@ func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab) func tryAtVtgate(sel *sqlparser.Select) engine.Primitive { if checkForDual(sel) { var err error - exprs := make([]sqltypes.Expr, len(sel.SelectExprs)) + exprs := make([]evalengine.Expr, len(sel.SelectExprs)) cols := make([]string, len(sel.SelectExprs)) for i, e := range sel.SelectExprs { expr := e.(*sqlparser.AliasedExpr) diff --git a/go/vt/vtgate/vindexes/consistent_lookup.go b/go/vt/vtgate/vindexes/consistent_lookup.go index 84072e75beb..1b7b3c5294e 100644 --- a/go/vt/vtgate/vindexes/consistent_lookup.go +++ b/go/vt/vtgate/vindexes/consistent_lookup.go @@ -22,6 +22,8 @@ import ( "fmt" "strings" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" querypb "vitess.io/vitess/go/vt/proto/query" @@ -306,7 +308,7 @@ func (lu *clCommon) Delete(vcursor VCursor, rowsColValues [][]sqltypes.Value, ks func (lu *clCommon) Update(vcursor VCursor, oldValues []sqltypes.Value, ksid []byte, newValues []sqltypes.Value) error { equal := true for i := range oldValues { - result, err := sqltypes.NullsafeCompare(oldValues[i], newValues[i]) + result, err := evalengine.NullsafeCompare(oldValues[i], newValues[i]) // errors from NullsafeCompare can be ignored. if they are real problems, we'll see them in the Create/Update if err != nil || result != 0 { equal = false diff --git a/go/vt/vtgate/vindexes/hash.go b/go/vt/vtgate/vindexes/hash.go index 21b93dfe884..9d68b167b56 100644 --- a/go/vt/vtgate/vindexes/hash.go +++ b/go/vt/vtgate/vindexes/hash.go @@ -25,6 +25,8 @@ import ( "fmt" "strconv" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/vterrors" @@ -83,7 +85,7 @@ func (vind *Hash) Map(cursor VCursor, ids []sqltypes.Value) ([]key.Destination, ival, err = strconv.ParseInt(str, 10, 64) num = uint64(ival) } else { - num, err = sqltypes.ToUint64(id) + num, err = evalengine.ToUint64(id) } if err != nil { @@ -99,7 +101,7 @@ func (vind *Hash) Map(cursor VCursor, ids []sqltypes.Value) ([]key.Destination, func (vind *Hash) Verify(_ VCursor, ids []sqltypes.Value, ksids [][]byte) ([]bool, error) { out := make([]bool, len(ids)) for i := range ids { - num, err := sqltypes.ToUint64(ids[i]) + num, err := evalengine.ToUint64(ids[i]) if err != nil { return nil, vterrors.Wrap(err, "hash.Verify") } diff --git a/go/vt/vtgate/vindexes/lookup_hash.go b/go/vt/vtgate/vindexes/lookup_hash.go index a9d2ec0102d..5787308e276 100644 --- a/go/vt/vtgate/vindexes/lookup_hash.go +++ b/go/vt/vtgate/vindexes/lookup_hash.go @@ -20,6 +20,8 @@ import ( "encoding/json" "fmt" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" topodatapb "vitess.io/vitess/go/vt/proto/topodata" @@ -119,7 +121,7 @@ func (lh *LookupHash) Map(vcursor VCursor, ids []sqltypes.Value) ([]key.Destinat } ksids := make([][]byte, 0, len(result.Rows)) for _, row := range result.Rows { - num, err := sqltypes.ToUint64(row[0]) + num, err := evalengine.ToUint64(row[0]) if err != nil { // A failure to convert is equivalent to not being // able to map. @@ -273,7 +275,7 @@ func (lhu *LookupHashUnique) Map(vcursor VCursor, ids []sqltypes.Value) ([]key.D case 0: out = append(out, key.DestinationNone{}) case 1: - num, err := sqltypes.ToUint64(result.Rows[0][0]) + num, err := evalengine.ToUint64(result.Rows[0][0]) if err != nil { out = append(out, key.DestinationNone{}) continue diff --git a/go/vt/vtgate/vindexes/lookup_unicodeloosemd5_hash.go b/go/vt/vtgate/vindexes/lookup_unicodeloosemd5_hash.go index fccec4ec110..000b6d21608 100644 --- a/go/vt/vtgate/vindexes/lookup_unicodeloosemd5_hash.go +++ b/go/vt/vtgate/vindexes/lookup_unicodeloosemd5_hash.go @@ -21,6 +21,8 @@ import ( "encoding/json" "fmt" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" topodatapb "vitess.io/vitess/go/vt/proto/topodata" @@ -124,7 +126,7 @@ func (lh *LookupUnicodeLooseMD5Hash) Map(vcursor VCursor, ids []sqltypes.Value) } ksids := make([][]byte, 0, len(result.Rows)) for _, row := range result.Rows { - num, err := sqltypes.ToUint64(row[0]) + num, err := evalengine.ToUint64(row[0]) if err != nil { // A failure to convert is equivalent to not being // able to map. @@ -289,7 +291,7 @@ func (lhu *LookupUnicodeLooseMD5HashUnique) Map(vcursor VCursor, ids []sqltypes. case 0: out = append(out, key.DestinationNone{}) case 1: - num, err := sqltypes.ToUint64(result.Rows[0][0]) + num, err := evalengine.ToUint64(result.Rows[0][0]) if err != nil { out = append(out, key.DestinationNone{}) continue diff --git a/go/vt/vtgate/vindexes/numeric.go b/go/vt/vtgate/vindexes/numeric.go index 3b34b8a2298..e031031048d 100644 --- a/go/vt/vtgate/vindexes/numeric.go +++ b/go/vt/vtgate/vindexes/numeric.go @@ -21,6 +21,8 @@ import ( "encoding/binary" "fmt" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/vterrors" @@ -67,7 +69,7 @@ func (*Numeric) Verify(_ VCursor, ids []sqltypes.Value, ksids [][]byte) ([]bool, out := make([]bool, len(ids)) for i := range ids { var keybytes [8]byte - num, err := sqltypes.ToUint64(ids[i]) + num, err := evalengine.ToUint64(ids[i]) if err != nil { return nil, vterrors.Wrap(err, "Numeric.Verify") } @@ -81,7 +83,7 @@ func (*Numeric) Verify(_ VCursor, ids []sqltypes.Value, ksids [][]byte) ([]bool, func (*Numeric) Map(cursor VCursor, ids []sqltypes.Value) ([]key.Destination, error) { out := make([]key.Destination, 0, len(ids)) for _, id := range ids { - num, err := sqltypes.ToUint64(id) + num, err := evalengine.ToUint64(id) if err != nil { out = append(out, key.DestinationNone{}) continue diff --git a/go/vt/vtgate/vindexes/numeric_static_map.go b/go/vt/vtgate/vindexes/numeric_static_map.go index 6a1ac7ecbc1..e4501a2f923 100644 --- a/go/vt/vtgate/vindexes/numeric_static_map.go +++ b/go/vt/vtgate/vindexes/numeric_static_map.go @@ -24,6 +24,8 @@ import ( "io/ioutil" "strconv" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/vterrors" @@ -90,7 +92,7 @@ func (vind *NumericStaticMap) Verify(_ VCursor, ids []sqltypes.Value, ksids [][] out := make([]bool, len(ids)) for i := range ids { var keybytes [8]byte - num, err := sqltypes.ToUint64(ids[i]) + num, err := evalengine.ToUint64(ids[i]) if err != nil { return nil, vterrors.Wrap(err, "NumericStaticMap.Verify") } @@ -108,7 +110,7 @@ func (vind *NumericStaticMap) Verify(_ VCursor, ids []sqltypes.Value, ksids [][] func (vind *NumericStaticMap) Map(cursor VCursor, ids []sqltypes.Value) ([]key.Destination, error) { out := make([]key.Destination, 0, len(ids)) for _, id := range ids { - num, err := sqltypes.ToUint64(id) + num, err := evalengine.ToUint64(id) if err != nil { out = append(out, key.DestinationNone{}) continue diff --git a/go/vt/vtgate/vindexes/region_experimental.go b/go/vt/vtgate/vindexes/region_experimental.go index 50260b08d45..7a072dd70e6 100644 --- a/go/vt/vtgate/vindexes/region_experimental.go +++ b/go/vt/vtgate/vindexes/region_experimental.go @@ -21,6 +21,8 @@ import ( "encoding/binary" "fmt" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" ) @@ -94,7 +96,7 @@ func (ge *RegionExperimental) Map(vcursor VCursor, rowsColValues [][]sqltypes.Va continue } // Compute region prefix. - rn, err := sqltypes.ToUint64(row[0]) + rn, err := evalengine.ToUint64(row[0]) if err != nil { destinations = append(destinations, key.DestinationNone{}) continue @@ -103,7 +105,7 @@ func (ge *RegionExperimental) Map(vcursor VCursor, rowsColValues [][]sqltypes.Va binary.BigEndian.PutUint16(r, uint16(rn)) // Compute hash. - hn, err := sqltypes.ToUint64(row[1]) + hn, err := evalengine.ToUint64(row[1]) if err != nil { destinations = append(destinations, key.DestinationNone{}) continue diff --git a/go/vt/vtgate/vindexes/region_json.go b/go/vt/vtgate/vindexes/region_json.go index 8f934e941ad..9720819c565 100644 --- a/go/vt/vtgate/vindexes/region_json.go +++ b/go/vt/vtgate/vindexes/region_json.go @@ -24,6 +24,8 @@ import ( "io/ioutil" "strconv" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/log" @@ -108,7 +110,7 @@ func (rv *RegionJSON) Map(vcursor VCursor, rowsColValues [][]sqltypes.Value) ([] continue } // Compute hash. - hn, err := sqltypes.ToUint64(row[0]) + hn, err := evalengine.ToUint64(row[0]) if err != nil { destinations = append(destinations, key.DestinationNone{}) continue diff --git a/go/vt/vtgate/vindexes/reverse_bits.go b/go/vt/vtgate/vindexes/reverse_bits.go index d79d1f8a23f..9052a4c94ee 100644 --- a/go/vt/vtgate/vindexes/reverse_bits.go +++ b/go/vt/vtgate/vindexes/reverse_bits.go @@ -23,6 +23,8 @@ import ( "fmt" "math/bits" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/vterrors" @@ -68,7 +70,7 @@ func (vind *ReverseBits) NeedsVCursor() bool { func (vind *ReverseBits) Map(cursor VCursor, ids []sqltypes.Value) ([]key.Destination, error) { out := make([]key.Destination, len(ids)) for i, id := range ids { - num, err := sqltypes.ToUint64(id) + num, err := evalengine.ToUint64(id) if err != nil { out[i] = key.DestinationNone{} continue @@ -82,7 +84,7 @@ func (vind *ReverseBits) Map(cursor VCursor, ids []sqltypes.Value) ([]key.Destin func (vind *ReverseBits) Verify(_ VCursor, ids []sqltypes.Value, ksids [][]byte) ([]bool, error) { out := make([]bool, len(ids)) for i := range ids { - num, err := sqltypes.ToUint64(ids[i]) + num, err := evalengine.ToUint64(ids[i]) if err != nil { return nil, vterrors.Wrap(err, "reverseBits.Verify") } diff --git a/go/vt/vttablet/heartbeat/reader.go b/go/vt/vttablet/heartbeat/reader.go index dd1deeec79c..b17e9d0a235 100644 --- a/go/vt/vttablet/heartbeat/reader.go +++ b/go/vt/vttablet/heartbeat/reader.go @@ -21,6 +21,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vterrors" "golang.org/x/net/context" @@ -205,7 +207,7 @@ func parseHeartbeatResult(res *sqltypes.Result) (int64, error) { if len(res.Rows) != 1 { return 0, fmt.Errorf("failed to read heartbeat: writer query did not result in 1 row. Got %v", len(res.Rows)) } - ts, err := sqltypes.ToInt64(res.Rows[0][0]) + ts, err := evalengine.ToInt64(res.Rows[0][0]) if err != nil { return 0, err } diff --git a/go/vt/vttablet/tabletmanager/vreplication/engine.go b/go/vt/vttablet/tabletmanager/vreplication/engine.go index 9623563cc38..1c77300455c 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/engine.go +++ b/go/vt/vttablet/tabletmanager/vreplication/engine.go @@ -24,6 +24,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" @@ -392,7 +394,7 @@ func (vre *Engine) fetchIDs(dbClient binlogplayer.DBClient, selector string) (id return nil, nil, err } for _, row := range qr.Rows { - id, err := sqltypes.ToInt64(row[0]) + id, err := evalengine.ToInt64(row[0]) if err != nil { return nil, nil, err } diff --git a/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go b/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go index e22dc9638a1..c16cbde24dd 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go @@ -21,6 +21,8 @@ import ( "strings" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" @@ -208,7 +210,7 @@ func (vr *vreplicator) readSettings(ctx context.Context) (settings binlogplayer. if len(qr.Rows) == 0 || len(qr.Rows[0]) == 0 { return settings, numTablesToCopy, fmt.Errorf("unexpected result from %s: %v", query, qr) } - numTablesToCopy, err = sqltypes.ToInt64(qr.Rows[0][0]) + numTablesToCopy, err = evalengine.ToInt64(qr.Rows[0][0]) if err != nil { return settings, numTablesToCopy, err } diff --git a/go/vt/vttablet/tabletserver/messager/message_manager.go b/go/vt/vttablet/tabletserver/messager/message_manager.go index 2412eb601f4..7a0d5825771 100644 --- a/go/vt/vttablet/tabletserver/messager/message_manager.go +++ b/go/vt/vttablet/tabletserver/messager/message_manager.go @@ -22,6 +22,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" @@ -829,28 +831,28 @@ func (mm *messageManager) GeneratePurgeQuery(timeCutoff int64) (string, map[stri func BuildMessageRow(row []sqltypes.Value) (*MessageRow, error) { mr := &MessageRow{Row: row[4:]} if !row[0].IsNull() { - v, err := sqltypes.ToInt64(row[0]) + v, err := evalengine.ToInt64(row[0]) if err != nil { return nil, err } mr.Priority = v } if !row[1].IsNull() { - v, err := sqltypes.ToInt64(row[0]) + v, err := evalengine.ToInt64(row[0]) if err != nil { return nil, err } mr.TimeNext = v } if !row[2].IsNull() { - v, err := sqltypes.ToInt64(row[1]) + v, err := evalengine.ToInt64(row[1]) if err != nil { return nil, err } mr.Epoch = v } if !row[3].IsNull() { - v, err := sqltypes.ToInt64(row[2]) + v, err := evalengine.ToInt64(row[2]) if err != nil { return nil, err } diff --git a/go/vt/vttablet/tabletserver/messager/message_manager_test.go b/go/vt/vttablet/tabletserver/messager/message_manager_test.go index b9ee7a7fd73..0f6bbd44def 100644 --- a/go/vt/vttablet/tabletserver/messager/message_manager_test.go +++ b/go/vt/vttablet/tabletserver/messager/message_manager_test.go @@ -26,6 +26,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/test/utils" "github.com/stretchr/testify/assert" @@ -741,7 +743,7 @@ func TestMMGenerate(t *testing.T) { t.Errorf("GenerateAckQuery query: %s, want %s", query, wantQuery) } bvv, _ := sqltypes.BindVariableToValue(bv["time_acked"]) - gotAcked, _ := sqltypes.ToInt64(bvv) + gotAcked, _ := evalengine.ToInt64(bvv) wantAcked := time.Now().UnixNano() if wantAcked-gotAcked > 10e9 { t.Errorf("gotAcked: %d, should be with 10s of %d", gotAcked, wantAcked) diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index 91e51990bcd..ec6a3afe8c7 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -22,6 +22,8 @@ import ( "strings" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" @@ -400,7 +402,7 @@ func (qre *QueryExecutor) execNextval() (*sqltypes.Result, error) { if len(qr.Rows) != 1 { return nil, fmt.Errorf("unexpected rows from reading sequence %s (possible mis-route): %d", tableName, len(qr.Rows)) } - nextID, err := sqltypes.ToInt64(qr.Rows[0][0]) + nextID, err := evalengine.ToInt64(qr.Rows[0][0]) if err != nil { return nil, vterrors.Wrapf(err, "error loading sequence %s", tableName) } @@ -415,7 +417,7 @@ func (qre *QueryExecutor) execNextval() (*sqltypes.Result, error) { t.SequenceInfo.NextVal = nextID t.SequenceInfo.LastVal = nextID } - cache, err := sqltypes.ToInt64(qr.Rows[0][1]) + cache, err := evalengine.ToInt64(qr.Rows[0][1]) if err != nil { return nil, vterrors.Wrapf(err, "error loading sequence %s", tableName) } @@ -688,5 +690,5 @@ func resolveNumber(pv sqltypes.PlanValue, bindVars map[string]*querypb.BindVaria if err != nil { return 0, err } - return sqltypes.ToInt64(v) + return evalengine.ToInt64(v) } diff --git a/go/vt/vttablet/tabletserver/rules/rules.go b/go/vt/vttablet/tabletserver/rules/rules.go index b641b2e55ce..8dd25ef7450 100644 --- a/go/vt/vttablet/tabletserver/rules/rules.go +++ b/go/vt/vttablet/tabletserver/rules/rules.go @@ -24,6 +24,8 @@ import ( "regexp" "strconv" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tabletserver/planbuilder" @@ -732,7 +734,7 @@ func getuint64(val *querypb.BindVariable) (uv uint64, status int) { if err != nil { return 0, QROutOfRange } - v, err := sqltypes.ToUint64(bv) + v, err := evalengine.ToUint64(bv) if err != nil { return 0, QROutOfRange } @@ -745,7 +747,7 @@ func getint64(val *querypb.BindVariable) (iv int64, status int) { if err != nil { return 0, QROutOfRange } - v, err := sqltypes.ToInt64(bv) + v, err := evalengine.ToInt64(bv) if err != nil { return 0, QROutOfRange } diff --git a/go/vt/vttablet/tabletserver/schema/engine.go b/go/vt/vttablet/tabletserver/schema/engine.go index 57b4400b794..d5fa83444f1 100644 --- a/go/vt/vttablet/tabletserver/schema/engine.go +++ b/go/vt/vttablet/tabletserver/schema/engine.go @@ -23,11 +23,12 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/acl" "vitess.io/vitess/go/mysql" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/timer" "vitess.io/vitess/go/vt/concurrency" "vitess.io/vitess/go/vt/dbconfigs" @@ -211,7 +212,7 @@ func (se *Engine) reload(ctx context.Context) error { for _, row := range tableData.Rows { tableName := row[0].ToString() curTables[tableName] = true - createTime, _ := sqltypes.ToInt64(row[2]) + createTime, _ := evalengine.ToInt64(row[2]) if _, ok := se.tables[tableName]; ok && createTime < se.lastChange { continue } @@ -266,7 +267,7 @@ func (se *Engine) mysqlTime(ctx context.Context, conn *connpool.DBConn) (int64, if len(tm.Rows) != 1 || len(tm.Rows[0]) != 1 || tm.Rows[0][0].IsNull() { return 0, vterrors.Errorf(vtrpcpb.Code_UNKNOWN, "unexpected result for MySQL time: %+v", tm.Rows) } - t, err := sqltypes.ToInt64(tm.Rows[0][0]) + t, err := evalengine.ToInt64(tm.Rows[0][0]) if err != nil { return 0, vterrors.Errorf(vtrpcpb.Code_UNKNOWN, "could not parse time %v: %v", tm, err) } diff --git a/go/vt/vttablet/tabletserver/twopc.go b/go/vt/vttablet/tabletserver/twopc.go index 8b8c87fb661..76c30ef9f63 100644 --- a/go/vt/vttablet/tabletserver/twopc.go +++ b/go/vt/vttablet/tabletserver/twopc.go @@ -20,6 +20,8 @@ import ( "fmt" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/sqlescape" @@ -288,12 +290,12 @@ func (tpc *TwoPC) ReadAllRedo(ctx context.Context) (prepared, failed []*Prepared // Initialize the new element. // A failure in time parsing will show up as a very old time, // which is harmless. - tm, _ := sqltypes.ToInt64(row[2]) + tm, _ := evalengine.ToInt64(row[2]) curTx = &PreparedTx{ Dtid: dtid, Time: time.Unix(0, tm), } - st, err := sqltypes.ToInt64(row[1]) + st, err := evalengine.ToInt64(row[1]) if err != nil { log.Errorf("Error parsing state for dtid %s: %v.", dtid, err) } @@ -330,7 +332,7 @@ func (tpc *TwoPC) CountUnresolvedRedo(ctx context.Context, unresolvedTime time.T if len(qr.Rows) < 1 { return 0, nil } - v, _ := sqltypes.ToInt64(qr.Rows[0][0]) + v, _ := evalengine.ToInt64(qr.Rows[0][0]) return v, nil } @@ -417,7 +419,7 @@ func (tpc *TwoPC) ReadTransaction(ctx context.Context, dtid string) (*querypb.Tr return result, nil } result.Dtid = qr.Rows[0][0].ToString() - st, err := sqltypes.ToInt64(qr.Rows[0][1]) + st, err := evalengine.ToInt64(qr.Rows[0][1]) if err != nil { return nil, vterrors.Wrapf(err, "Error parsing state for dtid %s", dtid) } @@ -427,7 +429,7 @@ func (tpc *TwoPC) ReadTransaction(ctx context.Context, dtid string) (*querypb.Tr } // A failure in time parsing will show up as a very old time, // which is harmless. - tm, _ := sqltypes.ToInt64(qr.Rows[0][2]) + tm, _ := evalengine.ToInt64(qr.Rows[0][2]) result.TimeCreated = tm qr, err = tpc.read(ctx, conn, tpc.readParticipants, bindVars) @@ -464,7 +466,7 @@ func (tpc *TwoPC) ReadAbandoned(ctx context.Context, abandonTime time.Time) (map } txs := make(map[string]time.Time, len(qr.Rows)) for _, row := range qr.Rows { - t, err := sqltypes.ToInt64(row[1]) + t, err := evalengine.ToInt64(row[1]) if err != nil { return nil, err } @@ -503,8 +505,8 @@ func (tpc *TwoPC) ReadAllTransactions(ctx context.Context) ([]*DistributedTx, er // Initialize the new element. // A failure in time parsing will show up as a very old time, // which is harmless. - tm, _ := sqltypes.ToInt64(row[2]) - st, err := sqltypes.ToInt64(row[1]) + tm, _ := evalengine.ToInt64(row[2]) + st, err := evalengine.ToInt64(row[1]) // Just log on error and continue. The state will show up as UNKNOWN // on the display. if err != nil { diff --git a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go index c87c5cb462d..d10a73488f0 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go +++ b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go @@ -21,6 +21,8 @@ import ( "regexp" "strings" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" @@ -115,7 +117,7 @@ func (plan *Plan) filter(values []sqltypes.Value) (bool, []sqltypes.Value, error for _, filter := range plan.Filters { switch filter.Opcode { case Equal: - result, err := sqltypes.NullsafeCompare(values[filter.ColNum], filter.Value) + result, err := evalengine.NullsafeCompare(values[filter.ColNum], filter.Value) if err != nil { return false, nil, err } diff --git a/go/vt/worker/chunk.go b/go/vt/worker/chunk.go index 89abd95383e..50c5cc5b0e8 100644 --- a/go/vt/worker/chunk.go +++ b/go/vt/worker/chunk.go @@ -19,6 +19,8 @@ package worker import ( "fmt" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -105,8 +107,8 @@ func generateChunks(ctx context.Context, wr *wrangler.Wrangler, tablet *topodata } result := sqltypes.Proto3ToResult(qr) - min, _ := sqltypes.ToNative(result.Rows[0][0]) - max, _ := sqltypes.ToNative(result.Rows[0][1]) + min, _ := evalengine.ToNative(result.Rows[0][0]) + max, _ := evalengine.ToNative(result.Rows[0][1]) if min == nil || max == nil { wr.Logger().Infof("table=%v: Not splitting the table into multiple chunks, min or max is NULL: %v", td.Name, qr.Rows[0]) diff --git a/go/vt/worker/diff_utils.go b/go/vt/worker/diff_utils.go index 132fed84575..b73785a037b 100644 --- a/go/vt/worker/diff_utils.go +++ b/go/vt/worker/diff_utils.go @@ -25,6 +25,8 @@ import ( "strings" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tmclient" @@ -480,8 +482,8 @@ func RowsEqual(left, right []sqltypes.Value) int { // TODO: This can panic if types for left and right don't match. func CompareRows(fields []*querypb.Field, compareCount int, left, right []sqltypes.Value) (int, error) { for i := 0; i < compareCount; i++ { - lv, _ := sqltypes.ToNative(left[i]) - rv, _ := sqltypes.ToNative(right[i]) + lv, _ := evalengine.ToNative(left[i]) + rv, _ := evalengine.ToNative(right[i]) switch l := lv.(type) { case int64: r := rv.(int64) diff --git a/go/vt/worker/key_resolver.go b/go/vt/worker/key_resolver.go index a79fb1b66f6..b58fa09087f 100644 --- a/go/vt/worker/key_resolver.go +++ b/go/vt/worker/key_resolver.go @@ -19,6 +19,7 @@ package worker import ( "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/sqltypes" @@ -78,7 +79,7 @@ func (r *v2Resolver) keyspaceID(row []sqltypes.Value) ([]byte, error) { case topodatapb.KeyspaceIdType_BYTES: return v.ToBytes(), nil case topodatapb.KeyspaceIdType_UINT64: - i, err := sqltypes.ToUint64(v) + i, err := evalengine.ToUint64(v) if err != nil { return nil, vterrors.Wrap(err, "Non numerical value") } diff --git a/go/vt/worker/split_clone_flaky_test.go b/go/vt/worker/split_clone_flaky_test.go index 716741da36d..0bf8336e4ee 100644 --- a/go/vt/worker/split_clone_flaky_test.go +++ b/go/vt/worker/split_clone_flaky_test.go @@ -26,6 +26,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/mysql/fakesqldb" @@ -391,7 +393,7 @@ func (sq *testQueryService) StreamExecute(ctx context.Context, target *querypb.T // Send the values. rowsAffected := 0 for _, row := range sq.rows { - v, _ := sqltypes.ToNative(row[0]) + v, _ := evalengine.ToNative(row[0]) primaryKey := v.(int64) if primaryKey >= int64(min) && primaryKey < int64(max) { diff --git a/go/vt/wrangler/materializer.go b/go/vt/wrangler/materializer.go index 43213bd7e4d..4cab0c0a9c3 100644 --- a/go/vt/wrangler/materializer.go +++ b/go/vt/wrangler/materializer.go @@ -22,6 +22,8 @@ import ( "sync" "text/template" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/gogo/protobuf/proto" "golang.org/x/net/context" @@ -476,7 +478,7 @@ func (wr *Wrangler) ExternalizeVindex(ctx context.Context, qualifiedVindexName s } qr := sqltypes.Proto3ToResult(p3qr) for _, row := range qr.Rows { - id, err := sqltypes.ToInt64(row[0]) + id, err := evalengine.ToInt64(row[0]) if err != nil { return err } diff --git a/go/vt/wrangler/stream_migrater.go b/go/vt/wrangler/stream_migrater.go index 963f6e711cb..52185107d5d 100644 --- a/go/vt/wrangler/stream_migrater.go +++ b/go/vt/wrangler/stream_migrater.go @@ -23,6 +23,8 @@ import ( "sync" "text/template" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/golang/protobuf/proto" "golang.org/x/net/context" "vitess.io/vitess/go/mysql" @@ -225,7 +227,7 @@ func (sm *streamMigrater) readTabletStreams(ctx context.Context, ti *topo.Tablet tabletStreams := make([]*vrStream, 0, len(qr.Rows)) for _, row := range qr.Rows { - id, err := sqltypes.ToInt64(row[0]) + id, err := evalengine.ToInt64(row[0]) if err != nil { return nil, err } diff --git a/go/vt/wrangler/traffic_switcher.go b/go/vt/wrangler/traffic_switcher.go index 4cef3aa9c9b..d400a61d1f6 100644 --- a/go/vt/wrangler/traffic_switcher.go +++ b/go/vt/wrangler/traffic_switcher.go @@ -26,6 +26,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/log" "github.com/golang/protobuf/proto" @@ -402,7 +404,7 @@ func (wr *Wrangler) buildTargets(ctx context.Context, targetKeyspace, workflow s } qr := sqltypes.Proto3ToResult(p3qr) for _, row := range qr.Rows { - id, err := sqltypes.ToInt64(row[0]) + id, err := evalengine.ToInt64(row[0]) if err != nil { return nil, false, err } diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 1d113cd0cab..7a635533891 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -23,6 +23,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/golang/protobuf/proto" "golang.org/x/net/context" "vitess.io/vitess/go/mysql" @@ -830,7 +832,7 @@ func (td *tableDiffer) compare(sourceRow, targetRow []sqltypes.Value, cols []int if col == -1 { continue } - c, err := sqltypes.NullsafeCompare(sourceRow[col], targetRow[col]) + c, err := evalengine.NullsafeCompare(sourceRow[col], targetRow[col]) if err != nil { return 0, err } From 06c21e1ed4d1d7fd56eaec24f0d5bba90dea8116 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Thu, 16 Apr 2020 21:46:23 -0700 Subject: [PATCH 22/23] Fix formatter check Signed-off-by: Saif Alharthi --- go/vt/vtgate/evalengine/expressions.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/vt/vtgate/evalengine/expressions.go b/go/vt/vtgate/evalengine/expressions.go index a2f00c2fdce..36748e8b0f7 100644 --- a/go/vt/vtgate/evalengine/expressions.go +++ b/go/vt/vtgate/evalengine/expressions.go @@ -18,6 +18,7 @@ package evalengine import ( "strconv" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -278,4 +279,3 @@ func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { } return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Type is not supported") } - From 23da2247e2dea68c1272abbf5bfa6e72baba7c2b Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 17 Apr 2020 10:21:38 +0200 Subject: [PATCH 23/23] Fix import Signed-off-by: Andres Taylor --- go.mod | 1 - go/vt/vtgate/evalengine/expressions_test.go | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 5f10ba6ce7c..dd26e4c38d9 100644 --- a/go.mod +++ b/go.mod @@ -98,7 +98,6 @@ require ( github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/krishicks/yaml-patch v0.0.10 github.com/lyft/protoc-gen-validate v0.0.0-20180911180927-64fcb82c878e // indirect - github.com/magiconair/properties v1.8.1 github.com/mattn/go-isatty v0.0.11 // indirect github.com/minio/minio-go v0.0.0-20190131015406-c8a261de75c1 github.com/mitchellh/copystructure v0.0.0-20160804032330-cdac8253d00f // indirect diff --git a/go/vt/vtgate/evalengine/expressions_test.go b/go/vt/vtgate/evalengine/expressions_test.go index 6aed3ccc3e9..0cfda38dde1 100644 --- a/go/vt/vtgate/evalengine/expressions_test.go +++ b/go/vt/vtgate/evalengine/expressions_test.go @@ -21,9 +21,10 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/assert" + "vitess.io/vitess/go/sqltypes" - "github.com/magiconair/properties/assert" querypb "vitess.io/vitess/go/vt/proto/query" )