diff --git a/go/mysql/endtoend/replication_test.go b/go/mysql/endtoend/replication_test.go index df5debcc3d3..11a7a3c3927 100644 --- a/go/mysql/endtoend/replication_test.go +++ b/go/mysql/endtoend/replication_test.go @@ -25,6 +25,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" @@ -68,7 +70,7 @@ func connectForReplication(t *testing.T, rbr bool) (*mysql.Conn, mysql.BinlogFor t.Fatalf("SHOW MASTER STATUS returned unexpected result: %v", result) } file := result.Rows[0][0].ToString() - position, err := sqltypes.ToUint64(result.Rows[0][1]) + position, err := evalengine.ToUint64(result.Rows[0][1]) if err != nil { t.Fatalf("SHOW MASTER STATUS returned invalid position: %v", result.Rows[0][1]) } diff --git a/go/sqltypes/arithmetic_test.go b/go/sqltypes/arithmetic_test.go deleted file mode 100644 index d874ea97efd..00000000000 --- a/go/sqltypes/arithmetic_test.go +++ /dev/null @@ -1,1521 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package sqltypes - -import ( - "encoding/binary" - "fmt" - "math" - "reflect" - "strconv" - "testing" - - querypb "vitess.io/vitess/go/vt/proto/query" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" -) - -func TestDivide(t *testing.T) { - tcases := []struct { - v1, v2 Value - out Value - err error - }{{ - //All Nulls - v1: NULL, - v2: NULL, - out: NULL, - }, { - // First value null. - v1: NULL, - v2: NewInt32(1), - out: NULL, - }, { - // Second value null. - v1: NewInt32(1), - v2: NULL, - out: NULL, - }, { - // Second arg 0 - v1: NewInt32(5), - v2: NewInt32(0), - out: NULL, - }, { - // Both arguments zero - v1: NewInt32(0), - v2: NewInt32(0), - out: NULL, - }, { - // case with negative value - v1: NewInt64(-1), - v2: NewInt64(-2), - out: NewFloat64(0.5000), - }, { - // float64 division by zero - v1: NewFloat64(2), - v2: NewFloat64(0), - out: NULL, - }, { - // Lower bound for int64 - v1: NewInt64(math.MinInt64), - v2: NewInt64(1), - out: NewFloat64(math.MinInt64), - }, { - // upper bound for uint64 - v1: NewUint64(math.MaxUint64), - v2: NewUint64(1), - out: NewFloat64(math.MaxUint64), - }, { - // testing for error in types - v1: TestValue(Int64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for error in types - v1: NewInt64(2), - v2: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for uint/int - v1: NewUint64(4), - v2: NewInt64(5), - out: NewFloat64(0.8), - }, { - // testing for uint/uint - v1: NewUint64(1), - v2: NewUint64(2), - out: NewFloat64(0.5), - }, { - // testing for float64/int64 - v1: TestValue(Float64, "1.2"), - v2: NewInt64(-2), - out: NewFloat64(-0.6), - }, { - // testing for float64/uint64 - v1: TestValue(Float64, "1.2"), - v2: NewUint64(2), - out: NewFloat64(0.6), - }, { - // testing for overflow of float64 - v1: NewFloat64(math.MaxFloat64), - v2: NewFloat64(0.5), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in 1.7976931348623157e+308 / 0.5"), - }} - - for _, tcase := range tcases { - got, err := Divide(tcase.v1, tcase.v2) - - if !vterrors.Equals(err, tcase.err) { - t.Errorf("%v %v %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err)) - t.Errorf("Divide(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Divide(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } - -} - -func TestMultiply(t *testing.T) { - tcases := []struct { - v1, v2 Value - out Value - err error - }{{ - //All Nulls - v1: NULL, - v2: NULL, - out: NULL, - }, { - // First value null. - v1: NewInt32(1), - v2: NULL, - out: NULL, - }, { - // Second value null. - v1: NULL, - v2: NewInt32(1), - out: NULL, - }, { - // case with negative value - v1: NewInt64(-1), - v2: NewInt64(-2), - out: NewInt64(2), - }, { - // testing for int64 overflow with min negative value - v1: NewInt64(math.MinInt64), - v2: NewInt64(1), - out: NewInt64(math.MinInt64), - }, { - // testing for error in types - v1: TestValue(Int64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for error in types - v1: NewInt64(2), - v2: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for uint*int - v1: NewUint64(4), - v2: NewInt64(5), - out: NewUint64(20), - }, { - // testing for uint*uint - v1: NewUint64(1), - v2: NewUint64(2), - out: NewUint64(2), - }, { - // testing for float64*int64 - v1: TestValue(Float64, "1.2"), - v2: NewInt64(-2), - out: NewFloat64(-2.4), - }, { - // testing for float64*uint64 - v1: TestValue(Float64, "1.2"), - v2: NewUint64(2), - out: NewFloat64(2.4), - }, { - // testing for overflow of int64 - v1: NewInt64(math.MaxInt64), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in 9223372036854775807 * 2"), - }, { - // testing for underflow of uint64*max.uint64 - v1: NewInt64(2), - v2: NewUint64(math.MaxUint64), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 * 2"), - }, { - v1: NewUint64(math.MaxUint64), - v2: NewUint64(1), - out: NewUint64(math.MaxUint64), - }, { - //Checking whether maxInt value can be passed as uint value - v1: NewUint64(math.MaxInt64), - v2: NewInt64(3), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 9223372036854775807 * 3"), - }} - - for _, tcase := range tcases { - - got, err := Multiply(tcase.v1, tcase.v2) - - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Multiply(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Multiply(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } - -} - -func TestSubtract(t *testing.T) { - tcases := []struct { - v1, v2 Value - out Value - err error - }{{ - // All Nulls - v1: NULL, - v2: NULL, - out: NULL, - }, { - // First value null. - v1: NewInt32(1), - v2: NULL, - out: NULL, - }, { - // Second value null. - v1: NULL, - v2: NewInt32(1), - out: NULL, - }, { - // case with negative value - v1: NewInt64(-1), - v2: NewInt64(-2), - out: NewInt64(1), - }, { - // testing for int64 overflow with min negative value - v1: NewInt64(math.MinInt64), - v2: NewInt64(1), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 - 1"), - }, { - v1: NewUint64(4), - v2: NewInt64(5), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 4 - 5"), - }, { - // testing uint - int - v1: NewUint64(7), - v2: NewInt64(5), - out: NewUint64(2), - }, { - v1: NewUint64(math.MaxUint64), - v2: NewInt64(0), - out: NewUint64(math.MaxUint64), - }, { - // testing for int64 overflow - v1: NewInt64(math.MinInt64), - v2: NewUint64(0), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in -9223372036854775808 - 0"), - }, { - v1: TestValue(VarChar, "c"), - v2: NewInt64(1), - out: NewInt64(-1), - }, { - v1: NewUint64(1), - v2: TestValue(VarChar, "c"), - out: NewUint64(1), - }, { - // testing for error for parsing float value to uint64 - v1: TestValue(Uint64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), - }, { - // testing for error for parsing float value to uint64 - v1: NewUint64(2), - v2: TestValue(Uint64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), - }, { - // uint64 - uint64 - v1: NewUint64(8), - v2: NewUint64(4), - out: NewUint64(4), - }, { - // testing for float subtraction: float - int - v1: NewFloat64(1.2), - v2: NewInt64(2), - out: NewFloat64(-0.8), - }, { - // testing for float subtraction: float - uint - v1: NewFloat64(1.2), - v2: NewUint64(2), - out: NewFloat64(-0.8), - }, { - v1: NewInt64(-1), - v2: NewUint64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in -1 - 2"), - }, { - v1: NewInt64(2), - v2: NewUint64(1), - out: NewUint64(1), - }, { - // testing int64 - float64 method - v1: NewInt64(-2), - v2: NewFloat64(1.0), - out: NewFloat64(-3.0), - }, { - // testing uint64 - float64 method - v1: NewUint64(1), - v2: NewFloat64(-2.0), - out: NewFloat64(3.0), - }, { - // testing uint - int to return uintplusint - v1: NewUint64(1), - v2: NewInt64(-2), - out: NewUint64(3), - }, { - // testing for float - float - v1: NewFloat64(1.2), - v2: NewFloat64(3.2), - out: NewFloat64(-2), - }, { - // testing uint - uint if v2 > v1 - v1: NewUint64(2), - v2: NewUint64(4), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 2 - 4"), - }, { - // testing uint - (- int) - v1: NewUint64(1), - v2: NewInt64(-2), - out: NewUint64(3), - }} - - for _, tcase := range tcases { - - got, err := Subtract(tcase.v1, tcase.v2) - - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Subtract(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Subtract(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } - -} - -func TestAdd(t *testing.T) { - tcases := []struct { - v1, v2 Value - out Value - err error - }{{ - // All Nulls - v1: NULL, - v2: NULL, - out: NULL, - }, { - // First value null. - v1: NewInt32(1), - v2: NULL, - out: NULL, - }, { - // Second value null. - v1: NULL, - v2: NewInt32(1), - out: NULL, - }, { - // case with negatives - v1: NewInt64(-1), - v2: NewInt64(-2), - out: NewInt64(-3), - }, { - // testing for overflow int64, result will be unsigned int - v1: NewInt64(math.MaxInt64), - v2: NewUint64(2), - out: NewUint64(9223372036854775809), - }, { - v1: NewInt64(-2), - v2: NewUint64(1), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 1 + -2"), - }, { - v1: NewInt64(math.MaxInt64), - v2: NewInt64(-2), - out: NewInt64(9223372036854775805), - }, { - // Normal case - v1: NewUint64(1), - v2: NewUint64(2), - out: NewUint64(3), - }, { - // testing for overflow uint64 - v1: NewUint64(math.MaxUint64), - v2: NewUint64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), - }, { - // int64 underflow - v1: NewInt64(math.MinInt64), - v2: NewInt64(-2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 + -2"), - }, { - // checking int64 max value can be returned - v1: NewInt64(math.MaxInt64), - v2: NewUint64(0), - out: NewUint64(9223372036854775807), - }, { - // testing whether uint64 max value can be returned - v1: NewUint64(math.MaxUint64), - v2: NewInt64(0), - out: NewUint64(math.MaxUint64), - }, { - v1: NewUint64(math.MaxInt64), - v2: NewInt64(1), - out: NewUint64(9223372036854775808), - }, { - v1: NewUint64(1), - v2: TestValue(VarChar, "c"), - out: NewUint64(1), - }, { - v1: NewUint64(1), - v2: TestValue(VarChar, "1.2"), - out: NewFloat64(2.2), - }, { - v1: TestValue(Int64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - v1: NewInt64(2), - v2: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for uint64 overflow with max uint64 + int value - v1: NewUint64(math.MaxUint64), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), - }} - - for _, tcase := range tcases { - - got, err := Add(tcase.v1, tcase.v2) - - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Add(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Add(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } - -} - -func TestNullsafeAdd(t *testing.T) { - tcases := []struct { - v1, v2 Value - out Value - err error - }{{ - // All nulls. - v1: NULL, - v2: NULL, - out: NewInt64(0), - }, { - // First value null. - v1: NewInt32(1), - v2: NULL, - out: NewInt64(1), - }, { - // Second value null. - v1: NULL, - v2: NewInt32(1), - out: NewInt64(1), - }, { - // Normal case. - v1: NewInt64(1), - v2: NewInt64(2), - out: NewInt64(3), - }, { - // Make sure underlying error is returned for LHS. - v1: TestValue(Int64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // Make sure underlying error is returned for RHS. - v1: NewInt64(2), - v2: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // Make sure underlying error is returned while adding. - v1: NewInt64(-1), - v2: NewUint64(2), - out: NewInt64(-9223372036854775808), - }, { - // Make sure underlying error is returned while converting. - v1: NewFloat64(1), - v2: NewFloat64(2), - out: NewInt64(3), - }} - for _, tcase := range tcases { - got := NullsafeAdd(tcase.v1, tcase.v2, Int64) - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("NullsafeAdd(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } -} - -func TestNullsafeCompare(t *testing.T) { - tcases := []struct { - v1, v2 Value - out int - err error - }{{ - // All nulls. - v1: NULL, - v2: NULL, - out: 0, - }, { - // LHS null. - v1: NULL, - v2: NewInt64(1), - out: -1, - }, { - // RHS null. - v1: NewInt64(1), - v2: NULL, - out: 1, - }, { - // LHS Text - v1: TestValue(VarChar, "abcd"), - v2: TestValue(VarChar, "abcd"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), - }, { - // Make sure underlying error is returned for LHS. - v1: TestValue(Int64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // Make sure underlying error is returned for RHS. - v1: NewInt64(2), - v2: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // Numeric equal. - v1: NewInt64(1), - v2: NewUint64(1), - out: 0, - }, { - // Numeric unequal. - v1: NewInt64(1), - v2: NewUint64(2), - out: -1, - }, { - // Non-numeric equal - v1: TestValue(VarBinary, "abcd"), - v2: TestValue(Binary, "abcd"), - out: 0, - }, { - // Non-numeric unequal - v1: TestValue(VarBinary, "abcd"), - v2: TestValue(Binary, "bcde"), - out: -1, - }, { - // Date/Time types - v1: TestValue(Datetime, "1000-01-01 00:00:00"), - v2: TestValue(Binary, "1000-01-01 00:00:00"), - out: 0, - }, { - // Date/Time types - v1: TestValue(Datetime, "2000-01-01 00:00:00"), - v2: TestValue(Binary, "1000-01-01 00:00:00"), - out: 1, - }, { - // Date/Time types - v1: TestValue(Datetime, "1000-01-01 00:00:00"), - v2: TestValue(Binary, "2000-01-01 00:00:00"), - out: -1, - }} - for _, tcase := range tcases { - got, err := NullsafeCompare(tcase.v1, tcase.v2) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if got != tcase.out { - t.Errorf("NullsafeCompare(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), got, tcase.out) - } - } -} - -func TestCast(t *testing.T) { - tcases := []struct { - typ querypb.Type - v Value - out Value - err error - }{{ - typ: VarChar, - v: NULL, - out: NULL, - }, { - typ: VarChar, - v: TestValue(VarChar, "exact types"), - out: TestValue(VarChar, "exact types"), - }, { - typ: Int64, - v: TestValue(Int32, "32"), - out: TestValue(Int64, "32"), - }, { - typ: Int24, - v: TestValue(Uint64, "64"), - out: TestValue(Int24, "64"), - }, { - typ: Int24, - v: TestValue(VarChar, "bad int"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseInt: parsing "bad int": invalid syntax`), - }, { - typ: Uint64, - v: TestValue(Uint32, "32"), - out: TestValue(Uint64, "32"), - }, { - typ: Uint24, - v: TestValue(Int64, "64"), - out: TestValue(Uint24, "64"), - }, { - typ: Uint24, - v: TestValue(Int64, "-1"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseUint: parsing "-1": invalid syntax`), - }, { - typ: Float64, - v: TestValue(Int64, "64"), - out: TestValue(Float64, "64"), - }, { - typ: Float32, - v: TestValue(Float64, "64"), - out: TestValue(Float32, "64"), - }, { - typ: Float32, - v: TestValue(Decimal, "1.24"), - out: TestValue(Float32, "1.24"), - }, { - typ: Float64, - v: TestValue(VarChar, "1.25"), - out: TestValue(Float64, "1.25"), - }, { - typ: Float64, - v: TestValue(VarChar, "bad float"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseFloat: parsing "bad float": invalid syntax`), - }, { - typ: VarChar, - v: TestValue(Int64, "64"), - out: TestValue(VarChar, "64"), - }, { - typ: VarBinary, - v: TestValue(Float64, "64"), - out: TestValue(VarBinary, "64"), - }, { - typ: VarBinary, - v: TestValue(Decimal, "1.24"), - out: TestValue(VarBinary, "1.24"), - }, { - typ: VarBinary, - v: TestValue(VarChar, "1.25"), - out: TestValue(VarBinary, "1.25"), - }, { - typ: VarChar, - v: TestValue(VarBinary, "valid string"), - out: TestValue(VarChar, "valid string"), - }, { - typ: VarChar, - v: TestValue(Expression, "bad string"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "EXPRESSION(bad string) cannot be cast to VARCHAR"), - }} - for _, tcase := range tcases { - got, err := Cast(tcase.v, tcase.typ) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Cast(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Cast(%v): %v, want %v", tcase.v, got, tcase.out) - } - } -} - -func TestToUint64(t *testing.T) { - tcases := []struct { - v Value - out uint64 - err error - }{{ - v: TestValue(VarChar, "abcd"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), - }, { - v: NewInt64(-1), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "negative number cannot be converted to unsigned: -1"), - }, { - v: NewInt64(1), - out: 1, - }, { - v: NewUint64(1), - out: 1, - }} - for _, tcase := range tcases { - got, err := ToUint64(tcase.v) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("ToUint64(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if got != tcase.out { - t.Errorf("ToUint64(%v): %v, want %v", tcase.v, got, tcase.out) - } - } -} - -func TestToInt64(t *testing.T) { - tcases := []struct { - v Value - out int64 - err error - }{{ - v: TestValue(VarChar, "abcd"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), - }, { - v: NewUint64(18446744073709551615), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unsigned number overflows int64 value: 18446744073709551615"), - }, { - v: NewInt64(1), - out: 1, - }, { - v: NewUint64(1), - out: 1, - }} - for _, tcase := range tcases { - got, err := ToInt64(tcase.v) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("ToInt64(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if got != tcase.out { - t.Errorf("ToInt64(%v): %v, want %v", tcase.v, got, tcase.out) - } - } -} - -func TestToFloat64(t *testing.T) { - tcases := []struct { - v Value - out float64 - err error - }{{ - v: TestValue(VarChar, "abcd"), - out: 0, - }, { - v: NewInt64(1), - out: 1, - }, { - v: NewUint64(1), - out: 1, - }, { - v: NewFloat64(1.2), - out: 1.2, - }, { - v: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }} - for _, tcase := range tcases { - got, err := ToFloat64(tcase.v) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("ToFloat64(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if got != tcase.out { - t.Errorf("ToFloat64(%v): %v, want %v", tcase.v, got, tcase.out) - } - } -} - -func TestToNative(t *testing.T) { - testcases := []struct { - in Value - out interface{} - }{{ - in: NULL, - out: nil, - }, { - in: TestValue(Int8, "1"), - out: int64(1), - }, { - in: TestValue(Int16, "1"), - out: int64(1), - }, { - in: TestValue(Int24, "1"), - out: int64(1), - }, { - in: TestValue(Int32, "1"), - out: int64(1), - }, { - in: TestValue(Int64, "1"), - out: int64(1), - }, { - in: TestValue(Uint8, "1"), - out: uint64(1), - }, { - in: TestValue(Uint16, "1"), - out: uint64(1), - }, { - in: TestValue(Uint24, "1"), - out: uint64(1), - }, { - in: TestValue(Uint32, "1"), - out: uint64(1), - }, { - in: TestValue(Uint64, "1"), - out: uint64(1), - }, { - in: TestValue(Float32, "1"), - out: float64(1), - }, { - in: TestValue(Float64, "1"), - out: float64(1), - }, { - in: TestValue(Timestamp, "2012-02-24 23:19:43"), - out: []byte("2012-02-24 23:19:43"), - }, { - in: TestValue(Date, "2012-02-24"), - out: []byte("2012-02-24"), - }, { - in: TestValue(Time, "23:19:43"), - out: []byte("23:19:43"), - }, { - in: TestValue(Datetime, "2012-02-24 23:19:43"), - out: []byte("2012-02-24 23:19:43"), - }, { - in: TestValue(Year, "1"), - out: uint64(1), - }, { - in: TestValue(Decimal, "1"), - out: []byte("1"), - }, { - in: TestValue(Text, "a"), - out: []byte("a"), - }, { - in: TestValue(Blob, "a"), - out: []byte("a"), - }, { - in: TestValue(VarChar, "a"), - out: []byte("a"), - }, { - in: TestValue(VarBinary, "a"), - out: []byte("a"), - }, { - in: TestValue(Char, "a"), - out: []byte("a"), - }, { - in: TestValue(Binary, "a"), - out: []byte("a"), - }, { - in: TestValue(Bit, "1"), - out: []byte("1"), - }, { - in: TestValue(Enum, "a"), - out: []byte("a"), - }, { - in: TestValue(Set, "a"), - out: []byte("a"), - }} - for _, tcase := range testcases { - v, err := ToNative(tcase.in) - if err != nil { - t.Error(err) - } - if !reflect.DeepEqual(v, tcase.out) { - t.Errorf("%v.ToNative = %#v, want %#v", tcase.in, v, tcase.out) - } - } - - // Test Expression failure. - _, err := ToNative(TestValue(Expression, "aa")) - want := vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "EXPRESSION(aa) cannot be converted to a go type") - if !vterrors.Equals(err, want) { - t.Errorf("ToNative(EXPRESSION): %v, want %v", vterrors.Print(err), vterrors.Print(want)) - } -} - -func TestNewNumeric(t *testing.T) { - tcases := []struct { - v Value - out numeric - err error - }{{ - v: NewInt64(1), - out: numeric{typ: Int64, ival: 1}, - }, { - v: NewUint64(1), - out: numeric{typ: Uint64, uval: 1}, - }, { - v: NewFloat64(1), - out: numeric{typ: Float64, fval: 1}, - }, { - // For non-number type, Int64 is the default. - v: TestValue(VarChar, "1"), - out: numeric{typ: Int64, ival: 1}, - }, { - // If Int64 can't work, we use Float64. - v: TestValue(VarChar, "1.2"), - out: numeric{typ: Float64, fval: 1.2}, - }, { - // Only valid Int64 allowed if type is Int64. - v: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // Only valid Uint64 allowed if type is Uint64. - v: TestValue(Uint64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), - }, { - // Only valid Float64 allowed if type is Float64. - v: TestValue(Float64, "abcd"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseFloat: parsing \"abcd\": invalid syntax"), - }, { - v: TestValue(VarChar, "abcd"), - out: numeric{typ: Float64, fval: 0}, - }} - for _, tcase := range tcases { - got, err := newNumeric(tcase.v) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("newNumeric(%s) error: %v, want %v", printValue(tcase.v), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err == nil { - continue - } - - if got != tcase.out { - t.Errorf("newNumeric(%s): %v, want %v", printValue(tcase.v), got, tcase.out) - } - } -} - -func TestNewIntegralNumeric(t *testing.T) { - tcases := []struct { - v Value - out numeric - err error - }{{ - v: NewInt64(1), - out: numeric{typ: Int64, ival: 1}, - }, { - v: NewUint64(1), - out: numeric{typ: Uint64, uval: 1}, - }, { - v: NewFloat64(1), - out: numeric{typ: Int64, ival: 1}, - }, { - // For non-number type, Int64 is the default. - v: TestValue(VarChar, "1"), - out: numeric{typ: Int64, ival: 1}, - }, { - // If Int64 can't work, we use Uint64. - v: TestValue(VarChar, "18446744073709551615"), - out: numeric{typ: Uint64, uval: 18446744073709551615}, - }, { - // Only valid Int64 allowed if type is Int64. - v: TestValue(Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // Only valid Uint64 allowed if type is Uint64. - v: TestValue(Uint64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), - }, { - v: TestValue(VarChar, "abcd"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), - }} - for _, tcase := range tcases { - got, err := newIntegralNumeric(tcase.v) - if err != nil && !vterrors.Equals(err, tcase.err) { - t.Errorf("newIntegralNumeric(%s) error: %v, want %v", printValue(tcase.v), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err == nil { - continue - } - - if got != tcase.out { - t.Errorf("newIntegralNumeric(%s): %v, want %v", printValue(tcase.v), got, tcase.out) - } - } -} - -func TestAddNumeric(t *testing.T) { - tcases := []struct { - v1, v2 numeric - out numeric - err error - }{{ - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Int64, ival: 2}, - out: numeric{typ: Int64, ival: 3}, - }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Uint64, uval: 2}, - out: numeric{typ: Uint64, uval: 3}, - }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Float64, fval: 2}, - out: numeric{typ: Float64, fval: 3}, - }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Uint64, uval: 2}, - out: numeric{typ: Uint64, uval: 3}, - }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Float64, fval: 2}, - out: numeric{typ: Float64, fval: 3}, - }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Float64, fval: 2}, - out: numeric{typ: Float64, fval: 3}, - }, { - // Int64 overflow. - v1: numeric{typ: Int64, ival: 9223372036854775807}, - v2: numeric{typ: Int64, ival: 2}, - out: numeric{typ: Float64, fval: 9223372036854775809}, - }, { - // Int64 underflow. - v1: numeric{typ: Int64, ival: -9223372036854775807}, - v2: numeric{typ: Int64, ival: -2}, - out: numeric{typ: Float64, fval: -9223372036854775809}, - }, { - v1: numeric{typ: Int64, ival: -1}, - v2: numeric{typ: Uint64, uval: 2}, - out: numeric{typ: Float64, fval: 18446744073709551617}, - }, { - // Uint64 overflow. - v1: numeric{typ: Uint64, uval: 18446744073709551615}, - v2: numeric{typ: Uint64, uval: 2}, - out: numeric{typ: Float64, fval: 18446744073709551617}, - }} - for _, tcase := range tcases { - got := addNumeric(tcase.v1, tcase.v2) - - if got != tcase.out { - t.Errorf("addNumeric(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) - } - } -} - -func TestPrioritize(t *testing.T) { - ival := numeric{typ: Int64} - uval := numeric{typ: Uint64} - fval := numeric{typ: Float64} - - tcases := []struct { - v1, v2 numeric - out1, out2 numeric - }{{ - v1: ival, - v2: uval, - out1: uval, - out2: ival, - }, { - v1: ival, - v2: fval, - out1: fval, - out2: ival, - }, { - v1: uval, - v2: ival, - out1: uval, - out2: ival, - }, { - v1: uval, - v2: fval, - out1: fval, - out2: uval, - }, { - v1: fval, - v2: ival, - out1: fval, - out2: ival, - }, { - v1: fval, - v2: uval, - out1: fval, - out2: uval, - }} - for _, tcase := range tcases { - got1, got2 := prioritize(tcase.v1, tcase.v2) - if got1 != tcase.out1 || got2 != tcase.out2 { - t.Errorf("prioritize(%v, %v): (%v, %v) , want (%v, %v)", tcase.v1.typ, tcase.v2.typ, got1.typ, got2.typ, tcase.out1.typ, tcase.out2.typ) - } - } -} - -func TestCastFromNumeric(t *testing.T) { - tcases := []struct { - typ querypb.Type - v numeric - out Value - err error - }{{ - typ: Int64, - v: numeric{typ: Int64, ival: 1}, - out: NewInt64(1), - }, { - typ: Int64, - v: numeric{typ: Uint64, uval: 1}, - out: NewInt64(1), - }, { - typ: Int64, - v: numeric{typ: Float64, fval: 1.2e-16}, - out: NewInt64(0), - }, { - typ: Uint64, - v: numeric{typ: Int64, ival: 1}, - out: NewUint64(1), - }, { - typ: Uint64, - v: numeric{typ: Uint64, uval: 1}, - out: NewUint64(1), - }, { - typ: Uint64, - v: numeric{typ: Float64, fval: 1.2e-16}, - out: NewUint64(0), - }, { - typ: Float64, - v: numeric{typ: Int64, ival: 1}, - out: TestValue(Float64, "1"), - }, { - typ: Float64, - v: numeric{typ: Uint64, uval: 1}, - out: TestValue(Float64, "1"), - }, { - typ: Float64, - v: numeric{typ: Float64, fval: 1.2e-16}, - out: TestValue(Float64, "1.2e-16"), - }, { - typ: Decimal, - v: numeric{typ: Int64, ival: 1}, - out: TestValue(Decimal, "1"), - }, { - typ: Decimal, - v: numeric{typ: Uint64, uval: 1}, - out: TestValue(Decimal, "1"), - }, { - // For float, we should not use scientific notation. - typ: Decimal, - v: numeric{typ: Float64, fval: 1.2e-16}, - out: TestValue(Decimal, "0.00000000000000012"), - }, { - typ: VarBinary, - v: numeric{typ: Int64, ival: 1}, - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected type conversion to non-numeric: VARBINARY"), - }} - for _, tcase := range tcases { - got := castFromNumeric(tcase.v, tcase.typ) - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("castFromNumeric(%v, %v): %v, want %v", tcase.v, tcase.typ, printValue(got), printValue(tcase.out)) - } - } -} - -func TestCompareNumeric(t *testing.T) { - tcases := []struct { - v1, v2 numeric - out int - }{{ - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Int64, ival: 1}, - out: 0, - }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Int64, ival: 2}, - out: -1, - }, { - v1: numeric{typ: Int64, ival: 2}, - v2: numeric{typ: Int64, ival: 1}, - out: 1, - }, { - // Special case. - v1: numeric{typ: Int64, ival: -1}, - v2: numeric{typ: Uint64, uval: 1}, - out: -1, - }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Uint64, uval: 1}, - out: 0, - }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Uint64, uval: 2}, - out: -1, - }, { - v1: numeric{typ: Int64, ival: 2}, - v2: numeric{typ: Uint64, uval: 1}, - out: 1, - }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Float64, fval: 1}, - out: 0, - }, { - v1: numeric{typ: Int64, ival: 1}, - v2: numeric{typ: Float64, fval: 2}, - out: -1, - }, { - v1: numeric{typ: Int64, ival: 2}, - v2: numeric{typ: Float64, fval: 1}, - out: 1, - }, { - // Special case. - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Int64, ival: -1}, - out: 1, - }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Int64, ival: 1}, - out: 0, - }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Int64, ival: 2}, - out: -1, - }, { - v1: numeric{typ: Uint64, uval: 2}, - v2: numeric{typ: Int64, ival: 1}, - out: 1, - }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Uint64, uval: 1}, - out: 0, - }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Uint64, uval: 2}, - out: -1, - }, { - v1: numeric{typ: Uint64, uval: 2}, - v2: numeric{typ: Uint64, uval: 1}, - out: 1, - }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Float64, fval: 1}, - out: 0, - }, { - v1: numeric{typ: Uint64, uval: 1}, - v2: numeric{typ: Float64, fval: 2}, - out: -1, - }, { - v1: numeric{typ: Uint64, uval: 2}, - v2: numeric{typ: Float64, fval: 1}, - out: 1, - }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Int64, ival: 1}, - out: 0, - }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Int64, ival: 2}, - out: -1, - }, { - v1: numeric{typ: Float64, fval: 2}, - v2: numeric{typ: Int64, ival: 1}, - out: 1, - }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Uint64, uval: 1}, - out: 0, - }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Uint64, uval: 2}, - out: -1, - }, { - v1: numeric{typ: Float64, fval: 2}, - v2: numeric{typ: Uint64, uval: 1}, - out: 1, - }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Float64, fval: 1}, - out: 0, - }, { - v1: numeric{typ: Float64, fval: 1}, - v2: numeric{typ: Float64, fval: 2}, - out: -1, - }, { - v1: numeric{typ: Float64, fval: 2}, - v2: numeric{typ: Float64, fval: 1}, - out: 1, - }} - for _, tcase := range tcases { - got := compareNumeric(tcase.v1, tcase.v2) - if got != tcase.out { - t.Errorf("equalNumeric(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) - } - } -} - -func TestMin(t *testing.T) { - tcases := []struct { - v1, v2 Value - min Value - err error - }{{ - v1: NULL, - v2: NULL, - min: NULL, - }, { - v1: NewInt64(1), - v2: NULL, - min: NewInt64(1), - }, { - v1: NULL, - v2: NewInt64(1), - min: NewInt64(1), - }, { - v1: NewInt64(1), - v2: NewInt64(2), - min: NewInt64(1), - }, { - v1: NewInt64(2), - v2: NewInt64(1), - min: NewInt64(1), - }, { - v1: NewInt64(1), - v2: NewInt64(1), - min: NewInt64(1), - }, { - v1: TestValue(VarChar, "aa"), - v2: TestValue(VarChar, "aa"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), - }} - for _, tcase := range tcases { - v, err := Min(tcase.v1, tcase.v2) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Min error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(v, tcase.min) { - t.Errorf("Min(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.min) - } - } -} - -func TestMax(t *testing.T) { - tcases := []struct { - v1, v2 Value - max Value - err error - }{{ - v1: NULL, - v2: NULL, - max: NULL, - }, { - v1: NewInt64(1), - v2: NULL, - max: NewInt64(1), - }, { - v1: NULL, - v2: NewInt64(1), - max: NewInt64(1), - }, { - v1: NewInt64(1), - v2: NewInt64(2), - max: NewInt64(2), - }, { - v1: NewInt64(2), - v2: NewInt64(1), - max: NewInt64(2), - }, { - v1: NewInt64(1), - v2: NewInt64(1), - max: NewInt64(1), - }, { - v1: TestValue(VarChar, "aa"), - v2: TestValue(VarChar, "aa"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), - }} - for _, tcase := range tcases { - v, err := Max(tcase.v1, tcase.v2) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Max error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(v, tcase.max) { - t.Errorf("Max(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.max) - } - } -} - -func printValue(v Value) string { - return fmt.Sprintf("%v:%q", v.typ, v.val) -} - -// These benchmarks show that using existing ASCII representations -// for numbers is about 6x slower than using native representations. -// However, 229ns is still a negligible time compared to the cost of -// other operations. The additional complexity of introducing native -// types is currently not worth it. So, we'll stay with the existing -// ASCII representation for now. Using interfaces is more expensive -// than native representation of values. This is probably because -// interfaces also allocate memory, and also perform type assertions. -// Actual benchmark is based on NoNative. So, the numbers are similar. -// Date: 6/4/17 -// Version: go1.8 -// BenchmarkAddActual-8 10000000 263 ns/op -// BenchmarkAddNoNative-8 10000000 228 ns/op -// BenchmarkAddNative-8 50000000 40.0 ns/op -// BenchmarkAddGoInterface-8 30000000 52.4 ns/op -// BenchmarkAddGoNonInterface-8 2000000000 1.00 ns/op -// BenchmarkAddGo-8 2000000000 1.00 ns/op -func BenchmarkAddActual(b *testing.B) { - v1 := MakeTrusted(Int64, []byte("1")) - v2 := MakeTrusted(Int64, []byte("12")) - for i := 0; i < b.N; i++ { - v1 = NullsafeAdd(v1, v2, Int64) - } -} - -func BenchmarkAddNoNative(b *testing.B) { - v1 := MakeTrusted(Int64, []byte("1")) - v2 := MakeTrusted(Int64, []byte("12")) - for i := 0; i < b.N; i++ { - iv1, _ := ToInt64(v1) - iv2, _ := ToInt64(v2) - v1 = MakeTrusted(Int64, strconv.AppendInt(nil, iv1+iv2, 10)) - } -} - -func BenchmarkAddNative(b *testing.B) { - v1 := makeNativeInt64(1) - v2 := makeNativeInt64(12) - for i := 0; i < b.N; i++ { - iv1 := int64(binary.BigEndian.Uint64(v1.Raw())) - iv2 := int64(binary.BigEndian.Uint64(v2.Raw())) - v1 = makeNativeInt64(iv1 + iv2) - } -} - -func makeNativeInt64(v int64) Value { - buf := make([]byte, 8) - binary.BigEndian.PutUint64(buf, uint64(v)) - return MakeTrusted(Int64, buf) -} - -func BenchmarkAddGoInterface(b *testing.B) { - var v1, v2 interface{} - v1 = int64(1) - v2 = int64(2) - for i := 0; i < b.N; i++ { - v1 = v1.(int64) + v2.(int64) - } -} - -func BenchmarkAddGoNonInterface(b *testing.B) { - v1 := numeric{typ: Int64, ival: 1} - v2 := numeric{typ: Int64, ival: 12} - for i := 0; i < b.N; i++ { - if v1.typ != Int64 { - b.Error("type assertion failed") - } - if v2.typ != Int64 { - b.Error("type assertion failed") - } - v1 = numeric{typ: Int64, ival: v1.ival + v2.ival} - } -} - -func BenchmarkAddGo(b *testing.B) { - v1 := int64(1) - v2 := int64(2) - for i := 0; i < b.N; i++ { - v1 += v2 - } -} diff --git a/go/sqltypes/type.go b/go/sqltypes/type.go index 79fc5e7b294..adc763cc23e 100644 --- a/go/sqltypes/type.go +++ b/go/sqltypes/type.go @@ -81,8 +81,8 @@ func IsBinary(t querypb.Type) bool { return int(t)&flagIsBinary == flagIsBinary } -// isNumber returns true if the type is any type of number. -func isNumber(t querypb.Type) bool { +// IsNumber returns true if the type is any type of number. +func IsNumber(t querypb.Type) bool { return IsIntegral(t) || IsFloat(t) || t == Decimal } diff --git a/go/sqltypes/type_test.go b/go/sqltypes/type_test.go index 660d3c87bcd..efaeb726121 100644 --- a/go/sqltypes/type_test.go +++ b/go/sqltypes/type_test.go @@ -247,7 +247,7 @@ func TestIsFunctions(t *testing.T) { if !IsBinary(Binary) { t.Error("Char: !IsBinary, must be true") } - if !isNumber(Int64) { + if !IsNumber(Int64) { t.Error("Int64: !isNumber, must be true") } } diff --git a/go/test/endtoend/messaging/message_test.go b/go/test/endtoend/messaging/message_test.go index 89574b6ae0d..8de8c808186 100644 --- a/go/test/endtoend/messaging/message_test.go +++ b/go/test/endtoend/messaging/message_test.go @@ -26,6 +26,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql" @@ -226,8 +228,8 @@ func getTimeEpoch(qr *sqltypes.Result) (int64, int64) { if len(qr.Rows) != 1 { return 0, 0 } - t, _ := sqltypes.ToInt64(qr.Rows[0][0]) - e, _ := sqltypes.ToInt64(qr.Rows[0][1]) + t, _ := evalengine.ToInt64(qr.Rows[0][0]) + e, _ := evalengine.ToInt64(qr.Rows[0][1]) return t, e } diff --git a/go/test/endtoend/sharding/verticalsplit/vertical_split_test.go b/go/test/endtoend/sharding/verticalsplit/vertical_split_test.go index e8624315b77..64e8e1f3705 100644 --- a/go/test/endtoend/sharding/verticalsplit/vertical_split_test.go +++ b/go/test/endtoend/sharding/verticalsplit/vertical_split_test.go @@ -27,6 +27,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/json2" "vitess.io/vitess/go/vt/vtgate/vtgateconn" @@ -546,7 +548,7 @@ func checkValues(t *testing.T, tablet *cluster.Vttablet, keyspace string, dbname assert.Equal(t, count, len(qr.Rows), fmt.Sprintf("got wrong number of rows: %d != %d", len(qr.Rows), count)) i := 0 for i < count { - result, _ := sqltypes.ToInt64(qr.Rows[i][0]) + result, _ := evalengine.ToInt64(qr.Rows[i][0]) assert.Equal(t, int64(first+i), result, fmt.Sprintf("got wrong number of rows: %d != %d", len(qr.Rows), first+i)) assert.Contains(t, qr.Rows[i][1].String(), fmt.Sprintf("value %d", first+i), fmt.Sprintf("invalid msg[%d]: 'value %d' != '%s'", i, first+i, qr.Rows[i][1].String())) i++ diff --git a/go/vt/binlog/binlogplayer/binlog_player.go b/go/vt/binlog/binlogplayer/binlog_player.go index d662129cd17..3e3b47d8c96 100644 --- a/go/vt/binlog/binlogplayer/binlog_player.go +++ b/go/vt/binlog/binlogplayer/binlog_player.go @@ -27,6 +27,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/golang/protobuf/proto" "golang.org/x/net/context" @@ -542,11 +544,11 @@ func ReadVRSettings(dbClient DBClient, uid uint32) (VRSettings, error) { } vrRow := qr.Rows[0] - maxTPS, err := sqltypes.ToInt64(vrRow[2]) + maxTPS, err := evalengine.ToInt64(vrRow[2]) if err != nil { return VRSettings{}, fmt.Errorf("failed to parse max_tps column: %v", err) } - maxReplicationLag, err := sqltypes.ToInt64(vrRow[3]) + maxReplicationLag, err := evalengine.ToInt64(vrRow[3]) if err != nil { return VRSettings{}, fmt.Errorf("failed to parse max_replication_lag column: %v", err) } diff --git a/go/vt/binlog/keyspace_id_resolver.go b/go/vt/binlog/keyspace_id_resolver.go index 1b45e7fed03..ab010eab5d9 100644 --- a/go/vt/binlog/keyspace_id_resolver.go +++ b/go/vt/binlog/keyspace_id_resolver.go @@ -21,6 +21,8 @@ import ( "fmt" "strings" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/sqltypes" @@ -103,7 +105,7 @@ func (r *keyspaceIDResolverFactoryV2) keyspaceID(v sqltypes.Value) ([]byte, erro case topodatapb.KeyspaceIdType_BYTES: return v.ToBytes(), nil case topodatapb.KeyspaceIdType_UINT64: - i, err := sqltypes.ToUint64(v) + i, err := evalengine.ToUint64(v) if err != nil { return nil, fmt.Errorf("non numerical value: %v", err) } diff --git a/go/vt/mysqlctl/replication.go b/go/vt/mysqlctl/replication.go index 4bcec1d8408..5b3e5d3627d 100644 --- a/go/vt/mysqlctl/replication.go +++ b/go/vt/mysqlctl/replication.go @@ -27,11 +27,12 @@ import ( "strings" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/netutil" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/hook" "vitess.io/vitess/go/vt/log" ) @@ -145,7 +146,7 @@ func (mysqld *Mysqld) GetMysqlPort() (int32, error) { if len(qr.Rows) != 1 { return 0, errors.New("no port variable in mysql") } - utemp, err := sqltypes.ToUint64(qr.Rows[0][1]) + utemp, err := evalengine.ToUint64(qr.Rows[0][1]) if err != nil { return 0, err } diff --git a/go/vt/mysqlctl/schema.go b/go/vt/mysqlctl/schema.go index 18706cb36ed..cde232c9a1c 100644 --- a/go/vt/mysqlctl/schema.go +++ b/go/vt/mysqlctl/schema.go @@ -21,10 +21,11 @@ import ( "regexp" "strings" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/sqlescape" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/mysqlctl/tmutils" @@ -84,7 +85,7 @@ func (mysqld *Mysqld) GetSchema(dbName string, tables, excludeTables []string, i var dataLength uint64 if !row[2].IsNull() { // dataLength is NULL for views, then we use 0 - dataLength, err = sqltypes.ToUint64(row[2]) + dataLength, err = evalengine.ToUint64(row[2]) if err != nil { return nil, err } @@ -93,7 +94,7 @@ func (mysqld *Mysqld) GetSchema(dbName string, tables, excludeTables []string, i // get row count var rowCount uint64 if !row[3].IsNull() { - rowCount, err = sqltypes.ToUint64(row[3]) + rowCount, err = evalengine.ToUint64(row[3]) if err != nil { return nil, err } @@ -214,7 +215,7 @@ func (mysqld *Mysqld) GetPrimaryKeyColumns(dbName, table string) ([]string, erro } // check the Seq_in_index is always increasing - seqInIndex, err := sqltypes.ToInt64(row[seqInIndexIndex]) + seqInIndex, err := evalengine.ToInt64(row[seqInIndexIndex]) if err != nil { return nil, err } diff --git a/go/vt/schemamanager/schemaswap/schema_swap.go b/go/vt/schemamanager/schemaswap/schema_swap.go index ded32dc9be3..14e7161d919 100644 --- a/go/vt/schemamanager/schemaswap/schema_swap.go +++ b/go/vt/schemamanager/schemaswap/schema_swap.go @@ -27,6 +27,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/sqltypes" @@ -614,14 +616,14 @@ func (shardSwap *shardSchemaSwap) readShardMetadata(metadata *shardSwapMetadata, for _, row := range queryResult.Rows { switch row[0].ToString() { case lastStartedMetadataName: - swapID, err := sqltypes.ToUint64(row[1]) + swapID, err := evalengine.ToUint64(row[1]) if err != nil { log.Warningf("Could not parse value of last started schema swap id %v, ignoring the value: %v", row[1], err) } else { metadata.lastStartedSwap = swapID } case lastFinishedMetadataName: - swapID, err := sqltypes.ToUint64(row[1]) + swapID, err := evalengine.ToUint64(row[1]) if err != nil { log.Warningf("Could not parse value of last finished schema swap id %v, ignoring the value: %v", row[1], err) } else { @@ -909,7 +911,7 @@ func (shardSwap *shardSchemaSwap) isSwapApplied(tablet *topodatapb.Tablet) (bool // No such row means we need to apply the swap. return false, nil } - swapID, err := sqltypes.ToUint64(swapIDResult.Rows[0][0]) + swapID, err := evalengine.ToUint64(swapIDResult.Rows[0][0]) if err != nil { return false, err } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index b9ed7137435..15c834f73af 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -24,10 +24,7 @@ import ( "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/vterrors" - querypb "vitess.io/vitess/go/vt/proto/query" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) // Walk calls visit on every node. @@ -425,24 +422,6 @@ func (node *ComparisonExpr) IsImpossible() bool { return false } -// ExprFromValue converts the given Value into an Expr or returns an error. -func ExprFromValue(value sqltypes.Value) (Expr, error) { - // The type checks here follow the rules defined in sqltypes/types.go. - switch { - case value.Type() == sqltypes.Null: - return &NullVal{}, nil - case value.IsIntegral(): - return NewIntVal(value.ToBytes()), nil - case value.IsFloat() || value.Type() == sqltypes.Decimal: - return NewFloatVal(value.ToBytes()), nil - case value.IsQuoted(): - return NewStrVal(value.ToBytes()), nil - default: - // We cannot support sqltypes.Expression, or any other invalid type. - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot convert value %v to AST", value) - } -} - // NewStrVal builds a new StrVal. func NewStrVal(in []byte) *SQLVal { return &SQLVal{Type: StrVal, Val: in} diff --git a/go/vt/sqlparser/ast_test.go b/go/vt/sqlparser/ast_test.go index f229c0fa4ec..82d680567b8 100644 --- a/go/vt/sqlparser/ast_test.go +++ b/go/vt/sqlparser/ast_test.go @@ -25,7 +25,6 @@ import ( "unsafe" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/sqltypes" ) func TestAppend(t *testing.T) { @@ -562,47 +561,6 @@ func TestReplaceExpr(t *testing.T) { } } -func TestExprFromValue(t *testing.T) { - tcases := []struct { - in sqltypes.Value - out SQLNode - err string - }{{ - in: sqltypes.NULL, - out: &NullVal{}, - }, { - in: sqltypes.NewInt64(1), - out: NewIntVal([]byte("1")), - }, { - in: sqltypes.NewFloat64(1.1), - out: NewFloatVal([]byte("1.1")), - }, { - in: sqltypes.MakeTrusted(sqltypes.Decimal, []byte("1.1")), - out: NewFloatVal([]byte("1.1")), - }, { - in: sqltypes.NewVarChar("aa"), - out: NewStrVal([]byte("aa")), - }, { - in: sqltypes.MakeTrusted(sqltypes.Expression, []byte("rand()")), - err: "cannot convert value EXPRESSION(rand()) to AST", - }} - for _, tcase := range tcases { - got, err := ExprFromValue(tcase.in) - if tcase.err != "" { - if err == nil || err.Error() != tcase.err { - t.Errorf("ExprFromValue(%v) err: %v, want %s", tcase.in, err, tcase.err) - } - continue - } - if err != nil { - t.Error(err) - } - if got, want := got, tcase.out; !reflect.DeepEqual(got, want) { - t.Errorf("ExprFromValue(%v): %v, want %s", tcase.in, got, want) - } - } -} - func TestColNameEqual(t *testing.T) { var c1, c2 *ColName if c1.Equal(c2) { diff --git a/go/vt/sqlparser/expression_converter.go b/go/vt/sqlparser/expression_converter.go new file mode 100644 index 00000000000..ce7763e470c --- /dev/null +++ b/go/vt/sqlparser/expression_converter.go @@ -0,0 +1,69 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlparser + +import ( + "fmt" + + "vitess.io/vitess/go/vt/vtgate/evalengine" +) + +var ExprNotSupported = fmt.Errorf("Expr Not Supported") + +//Convert converts between AST expressions and executable expressions +func Convert(e Expr) (evalengine.Expr, error) { + switch node := e.(type) { + case *SQLVal: + switch node.Type { + case IntVal: + return evalengine.NewLiteralInt(node.Val) + case FloatVal: + return evalengine.NewLiteralFloat(node.Val) + case ValArg: + return &evalengine.BindVariable{Key: string(node.Val[1:])}, nil + } + case *BinaryExpr: + var op evalengine.BinaryExpr + switch node.Operator { + case PlusStr: + op = &evalengine.Addition{} + case MinusStr: + op = &evalengine.Subtraction{} + case MultStr: + op = &evalengine.Multiplication{} + case DivStr: + op = &evalengine.Division{} + default: + return nil, ExprNotSupported + } + left, err := Convert(node.Left) + if err != nil { + return nil, err + } + right, err := Convert(node.Right) + if err != nil { + return nil, err + } + return &evalengine.BinaryOp{ + Expr: op, + Left: left, + Right: right, + }, nil + + } + return nil, ExprNotSupported +} diff --git a/go/vt/sqlparser/expressions_test.go b/go/vt/sqlparser/expressions_test.go new file mode 100644 index 00000000000..f0264ada165 --- /dev/null +++ b/go/vt/sqlparser/expressions_test.go @@ -0,0 +1,101 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sqlparser + +import ( + "testing" + + "vitess.io/vitess/go/vt/vtgate/evalengine" + + "vitess.io/vitess/go/sqltypes" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +/* +These tests should in theory live in the sqltypes package but they live here so we can +exercise both expression conversion and evaluation in the same test file +*/ + +func TestEvaluate(t *testing.T) { + type testCase struct { + expression string + expected sqltypes.Value + } + + tests := []testCase{{ + expression: "42", + expected: sqltypes.NewInt64(42), + }, { + expression: "42.42", + expected: sqltypes.NewFloat64(42.42), + }, { + expression: "40+2", + expected: sqltypes.NewInt64(42), + }, { + expression: "40-2", + expected: sqltypes.NewInt64(38), + }, { + expression: "40*2", + expected: sqltypes.NewInt64(80), + }, { + expression: "40/2", + expected: sqltypes.NewFloat64(20), + }, { + expression: ":exp", + expected: sqltypes.NewInt64(66), + }, { + expression: ":uint64_bind_variable", + expected: sqltypes.NewUint64(22), + }, { + expression: ":string_bind_variable", + expected: sqltypes.NewVarBinary("bar"), + }, { + expression: ":float_bind_variable", + expected: sqltypes.NewFloat64(2.2), + }} + + for _, test := range tests { + t.Run(test.expression, func(t *testing.T) { + // Given + stmt, err := Parse("select " + test.expression) + require.NoError(t, err) + astExpr := stmt.(*Select).SelectExprs[0].(*AliasedExpr).Expr + sqltypesExpr, err := Convert(astExpr) + require.Nil(t, err) + require.NotNil(t, sqltypesExpr) + env := evalengine.ExpressionEnv{ + BindVars: map[string]*querypb.BindVariable{ + "exp": sqltypes.Int64BindVariable(66), + "string_bind_variable": sqltypes.StringBindVariable("bar"), + "uint64_bind_variable": sqltypes.Uint64BindVariable(22), + "float_bind_variable": sqltypes.Float64BindVariable(2.2), + }, + Row: nil, + } + + // When + r, err := sqltypesExpr.Evaluate(env) + + // Then + require.NoError(t, err) + assert.Equal(t, test.expected, r.Value(), "expected %s", test.expected.String()) + }) + } +} diff --git a/go/vt/vitessdriver/convert.go b/go/vt/vitessdriver/convert.go index 16d1b61134d..703ab076b0d 100644 --- a/go/vt/vitessdriver/convert.go +++ b/go/vt/vitessdriver/convert.go @@ -21,6 +21,8 @@ import ( "fmt" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -36,7 +38,7 @@ func (cv *converter) ToNative(v sqltypes.Value) (interface{}, error) { case sqltypes.Date: return DateToNative(v, cv.location) } - return sqltypes.ToNative(v) + return evalengine.ToNative(v) } func (cv *converter) BuildBindVariable(v interface{}) (*querypb.BindVariable, error) { diff --git a/go/vt/vtgate/endtoend/database_func_test.go b/go/vt/vtgate/endtoend/database_func_test.go index e0cb0805f59..3a3b27c65f8 100644 --- a/go/vt/vtgate/endtoend/database_func_test.go +++ b/go/vt/vtgate/endtoend/database_func_test.go @@ -34,7 +34,7 @@ func TestDatabaseFunc(t *testing.T) { exec(t, conn, "use ks") qr := exec(t, conn, "select database()") - if got, want := fmt.Sprintf("%v", qr.Rows), `[[VARCHAR("ks")]]`; got != want { + if got, want := fmt.Sprintf("%v", qr.Rows), `[[VARBINARY("ks")]]`; got != want { t.Errorf("select:\n%v want\n%v", got, want) } } diff --git a/go/vt/vtgate/endtoend/last_insert_id_test.go b/go/vt/vtgate/endtoend/last_insert_id_test.go index 81b90d4fe76..60694090637 100644 --- a/go/vt/vtgate/endtoend/last_insert_id_test.go +++ b/go/vt/vtgate/endtoend/last_insert_id_test.go @@ -21,10 +21,10 @@ import ( "fmt" "testing" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/mysql" ) @@ -36,7 +36,7 @@ func TestLastInsertId(t *testing.T) { // figure out the last inserted id before we run change anything qr := exec(t, conn, "select max(id) from t1_last_insert_id") - oldLastID, err := sqltypes.ToUint64(qr.Rows[0][0]) + oldLastID, err := evalengine.ToUint64(qr.Rows[0][0]) require.NoError(t, err) exec(t, conn, "insert into t1_last_insert_id(id1) values(42)") @@ -44,7 +44,7 @@ func TestLastInsertId(t *testing.T) { // even without a transaction, we should get the last inserted id back qr = exec(t, conn, "select last_insert_id()") got := fmt.Sprintf("%v", qr.Rows) - want := fmt.Sprintf("[[INT64(%d)]]", oldLastID+1) + want := fmt.Sprintf("[[UINT64(%d)]]", oldLastID+1) if diff := cmp.Diff(want, got); diff != "" { t.Error(diff) @@ -59,7 +59,7 @@ func TestLastInsertIdWithRollback(t *testing.T) { // figure out the last inserted id before we run our tests qr := exec(t, conn, "select max(id) from t1_last_insert_id") - oldLastID, err := sqltypes.ToUint64(qr.Rows[0][0]) + oldLastID, err := evalengine.ToUint64(qr.Rows[0][0]) require.NoError(t, err) // add row inside explicit transaction @@ -67,7 +67,7 @@ func TestLastInsertIdWithRollback(t *testing.T) { exec(t, conn, "insert into t1_last_insert_id(id1) values(42)") qr = exec(t, conn, "select last_insert_id()") got := fmt.Sprintf("%v", qr.Rows) - want := fmt.Sprintf("[[INT64(%d)]]", oldLastID+1) + want := fmt.Sprintf("[[UINT64(%d)]]", oldLastID+1) if diff := cmp.Diff(want, got); diff != "" { t.Error(diff) diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go index 6a9db408cae..7e9d7af07e5 100644 --- a/go/vt/vtgate/engine/insert.go +++ b/go/vt/vtgate/engine/insert.go @@ -23,6 +23,8 @@ import ( "strings" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/sqltypes" @@ -299,7 +301,7 @@ func (ins *Insert) processGenerate(vcursor VCursor, bindVars map[string]*querypb } // If no rows are returned, it's an internal error, and the code // must panic, which will be caught and reported. - insertID, err = sqltypes.ToInt64(qr.Rows[0][0]) + insertID, err = evalengine.ToInt64(qr.Rows[0][0]) if err != nil { return 0, err } diff --git a/go/vt/vtgate/engine/limit.go b/go/vt/vtgate/engine/limit.go index 68b7b4b327c..c43b6a4b28f 100644 --- a/go/vt/vtgate/engine/limit.go +++ b/go/vt/vtgate/engine/limit.go @@ -20,6 +20,8 @@ import ( "fmt" "io" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -157,7 +159,7 @@ func (l *Limit) fetchCount(bindVars map[string]*querypb.BindVariable) (int, erro if err != nil { return 0, err } - num, err := sqltypes.ToUint64(resolved) + num, err := evalengine.ToUint64(resolved) if err != nil { return 0, err } @@ -176,7 +178,7 @@ func (l *Limit) fetchOffset(bindVars map[string]*querypb.BindVariable) (int, err if err != nil { return 0, err } - num, err := sqltypes.ToUint64(resolved) + num, err := evalengine.ToUint64(resolved) if err != nil { return 0, err } diff --git a/go/vt/vtgate/engine/memory_sort.go b/go/vt/vtgate/engine/memory_sort.go index 4db9b6a1448..c243a31d607 100644 --- a/go/vt/vtgate/engine/memory_sort.go +++ b/go/vt/vtgate/engine/memory_sort.go @@ -24,6 +24,8 @@ import ( "sort" "strings" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -161,7 +163,7 @@ func (ms *MemorySort) fetchCount(bindVars map[string]*querypb.BindVariable) (int if resolved.IsNull() { return math.MaxInt64, nil } - num, err := sqltypes.ToUint64(resolved) + num, err := evalengine.ToUint64(resolved) if err != nil { return 0, err } @@ -230,7 +232,7 @@ func (sh *sortHeap) Less(i, j int) bool { if sh.err != nil { return true } - cmp, err := sqltypes.NullsafeCompare(sh.rows[i][order.Col], sh.rows[j][order.Col]) + cmp, err := evalengine.NullsafeCompare(sh.rows[i][order.Col], sh.rows[j][order.Col]) if err != nil { sh.err = err return true diff --git a/go/vt/vtgate/engine/merge_sort.go b/go/vt/vtgate/engine/merge_sort.go index 17ae57c1c69..aea1cc90f11 100644 --- a/go/vt/vtgate/engine/merge_sort.go +++ b/go/vt/vtgate/engine/merge_sort.go @@ -20,6 +20,8 @@ import ( "container/heap" "io" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/sqltypes" @@ -250,7 +252,7 @@ func (sh *scatterHeap) Less(i, j int) bool { if sh.err != nil { return true } - cmp, err := sqltypes.NullsafeCompare(sh.rows[i].row[order.Col], sh.rows[j].row[order.Col]) + cmp, err := evalengine.NullsafeCompare(sh.rows[i].row[order.Col], sh.rows[j].row[order.Col]) if err != nil { sh.err = err return true diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index 74b1667e734..bee3cc52de5 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -20,6 +20,8 @@ import ( "fmt" "strconv" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -298,7 +300,7 @@ func (oa *OrderedAggregate) convertRow(row []sqltypes.Value) (newRow []sqltypes. case AggregateSumDistinct: curDistinct = row[aggr.Col] var err error - newRow[aggr.Col], err = sqltypes.Cast(row[aggr.Col], opcodeType[aggr.Opcode]) + newRow[aggr.Col], err = evalengine.Cast(row[aggr.Col], opcodeType[aggr.Opcode]) if err != nil { newRow[aggr.Col] = sumZero } @@ -328,7 +330,7 @@ func (oa *OrderedAggregate) NeedsTransaction() bool { func (oa *OrderedAggregate) keysEqual(row1, row2 []sqltypes.Value) (bool, error) { for _, key := range oa.Keys { - cmp, err := sqltypes.NullsafeCompare(row1[key], row2[key]) + cmp, err := evalengine.NullsafeCompare(row1[key], row2[key]) if err != nil { return false, err } @@ -346,7 +348,7 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes if row2[aggr.Col].IsNull() { continue } - cmp, err := sqltypes.NullsafeCompare(curDistinct, row2[aggr.Col]) + cmp, err := evalengine.NullsafeCompare(curDistinct, row2[aggr.Col]) if err != nil { return nil, sqltypes.NULL, err } @@ -358,15 +360,15 @@ func (oa *OrderedAggregate) merge(fields []*querypb.Field, row1, row2 []sqltypes var err error switch aggr.Opcode { case AggregateCount, AggregateSum: - result[aggr.Col] = sqltypes.NullsafeAdd(row1[aggr.Col], row2[aggr.Col], fields[aggr.Col].Type) + result[aggr.Col] = evalengine.NullsafeAdd(row1[aggr.Col], row2[aggr.Col], fields[aggr.Col].Type) case AggregateMin: - result[aggr.Col], err = sqltypes.Min(row1[aggr.Col], row2[aggr.Col]) + result[aggr.Col], err = evalengine.Min(row1[aggr.Col], row2[aggr.Col]) case AggregateMax: - result[aggr.Col], err = sqltypes.Max(row1[aggr.Col], row2[aggr.Col]) + result[aggr.Col], err = evalengine.Max(row1[aggr.Col], row2[aggr.Col]) case AggregateCountDistinct: - result[aggr.Col] = sqltypes.NullsafeAdd(row1[aggr.Col], countOne, opcodeType[aggr.Opcode]) + result[aggr.Col] = evalengine.NullsafeAdd(row1[aggr.Col], countOne, opcodeType[aggr.Opcode]) case AggregateSumDistinct: - result[aggr.Col] = sqltypes.NullsafeAdd(row1[aggr.Col], row2[aggr.Col], opcodeType[aggr.Opcode]) + result[aggr.Col] = evalengine.NullsafeAdd(row1[aggr.Col], row2[aggr.Col], opcodeType[aggr.Opcode]) default: return nil, sqltypes.NULL, fmt.Errorf("BUG: Unexpected opcode: %v", aggr.Opcode) } diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go new file mode 100644 index 00000000000..0e0f9487adc --- /dev/null +++ b/go/vt/vtgate/engine/projection.go @@ -0,0 +1,123 @@ +package engine + +import ( + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/evalengine" +) + +var _ Primitive = (*Projection)(nil) + +type Projection struct { + Cols []string + Exprs []evalengine.Expr + Input Primitive + noTxNeeded +} + +func (p *Projection) RouteType() string { + return p.Input.RouteType() +} + +func (p *Projection) GetKeyspaceName() string { + return p.Input.GetKeyspaceName() +} + +func (p *Projection) GetTableName() string { + return p.Input.GetTableName() +} + +func (p *Projection) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + result, err := p.Input.Execute(vcursor, bindVars, wantfields) + if err != nil { + return nil, err + } + + env := evalengine.ExpressionEnv{ + BindVars: bindVars, + } + + if wantfields { + p.addFields(result, bindVars) + } + var rows [][]sqltypes.Value + for _, row := range result.Rows { + env.Row = row + for _, exp := range p.Exprs { + result, err := exp.Evaluate(env) + if err != nil { + return nil, err + } + row = append(row, result.Value()) + } + rows = append(rows, row) + } + result.Rows = rows + return result, nil +} + +func (p *Projection) StreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { + result, err := p.Input.Execute(vcursor, bindVars, wantields) + if err != nil { + return err + } + + env := evalengine.ExpressionEnv{ + BindVars: bindVars, + } + + if wantields { + p.addFields(result, bindVars) + } + var rows [][]sqltypes.Value + for _, row := range result.Rows { + env.Row = row + for _, exp := range p.Exprs { + result, err := exp.Evaluate(env) + if err != nil { + return err + } + row = append(row, result.Value()) + } + rows = append(rows, row) + } + result.Rows = rows + return callback(result) +} + +func (p *Projection) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + qr, err := p.Input.GetFields(vcursor, bindVars) + if err != nil { + return nil, err + } + p.addFields(qr, bindVars) + return qr, nil +} + +func (p *Projection) addFields(qr *sqltypes.Result, bindVars map[string]*querypb.BindVariable) { + env := evalengine.ExpressionEnv{BindVars: bindVars} + for i, col := range p.Cols { + qr.Fields = append(qr.Fields, &querypb.Field{ + Name: col, + Type: p.Exprs[i].Type(env), + }) + } +} + +func (p *Projection) Inputs() []Primitive { + return []Primitive{p.Input} +} + +func (p *Projection) description() PrimitiveDescription { + var exprs []string + for _, e := range p.Exprs { + exprs = append(exprs, e.String()) + } + return PrimitiveDescription{ + OperatorType: "Projection", + Other: map[string]interface{}{ + "Expressions": exprs, + "Columns": p.Cols, + }, + } +} diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index 9fff1e5afeb..c772f8f76dc 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -23,6 +23,8 @@ import ( "strconv" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/stats" @@ -439,7 +441,7 @@ func (route *Route) sort(in *sqltypes.Result) (*sqltypes.Result, error) { return true } var cmp int - cmp, err = sqltypes.NullsafeCompare(out.Rows[i][order.Col], out.Rows[j][order.Col]) + cmp, err = evalengine.NullsafeCompare(out.Rows[i][order.Col], out.Rows[j][order.Col]) if err != nil { return true } diff --git a/go/vt/vtgate/engine/singlerow.go b/go/vt/vtgate/engine/singlerow.go new file mode 100644 index 00000000000..d7c2e1fefe4 --- /dev/null +++ b/go/vt/vtgate/engine/singlerow.go @@ -0,0 +1,82 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/query" +) + +var _ Primitive = (*SingleRow)(nil) + +// SingleRow defines an empty result +type SingleRow struct { + noInputs + noTxNeeded +} + +// RouteType returns a description of the query routing type used by the primitive +func (s *SingleRow) RouteType() string { + return "" +} + +// GetKeyspaceName specifies the Keyspace that this primitive routes to. +func (s *SingleRow) GetKeyspaceName() string { + return "" +} + +// GetTableName specifies the table that this primitive routes to. +func (s *SingleRow) GetTableName() string { + return "" +} + +// Execute performs a non-streaming exec. +func (s *SingleRow) Execute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantfields bool) (*sqltypes.Result, error) { + result := sqltypes.Result{ + Fields: nil, + RowsAffected: 0, + InsertID: 0, + Rows: [][]sqltypes.Value{ + {}, + }, + } + return &result, nil +} + +// StreamExecute performs a streaming exec. +func (s *SingleRow) StreamExecute(vcursor VCursor, bindVars map[string]*query.BindVariable, wantields bool, callback func(*sqltypes.Result) error) error { + result := sqltypes.Result{ + Fields: nil, + RowsAffected: 0, + InsertID: 0, + Rows: [][]sqltypes.Value{ + {}, + }, + } + return callback(&result) +} + +// GetFields fetches the field info. +func (s *SingleRow) GetFields(vcursor VCursor, bindVars map[string]*query.BindVariable) (*sqltypes.Result, error) { + return &sqltypes.Result{}, nil +} + +func (s *SingleRow) description() PrimitiveDescription { + return PrimitiveDescription{ + OperatorType: "SingleRow", + } +} diff --git a/go/vt/vtgate/engine/vindex_func.go b/go/vt/vtgate/engine/vindex_func.go index 0e4e516c3b8..c30d0bcd80d 100644 --- a/go/vt/vtgate/engine/vindex_func.go +++ b/go/vt/vtgate/engine/vindex_func.go @@ -19,6 +19,8 @@ package engine import ( "encoding/json" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/vtgate/vindexes" @@ -109,7 +111,7 @@ func (vf *VindexFunc) mapVindex(vcursor VCursor, bindVars map[string]*querypb.Bi if err != nil { return nil, err } - vkey, err := sqltypes.Cast(k, sqltypes.VarBinary) + vkey, err := evalengine.Cast(k, sqltypes.VarBinary) if err != nil { return nil, err } diff --git a/go/sqltypes/arithmetic.go b/go/vt/vtgate/evalengine/arithmetic.go similarity index 50% rename from go/sqltypes/arithmetic.go rename to go/vt/vtgate/evalengine/arithmetic.go index 6149a58c2c4..97e9e5f321e 100644 --- a/go/sqltypes/arithmetic.go +++ b/go/vt/vtgate/evalengine/arithmetic.go @@ -14,11 +14,12 @@ See the License for the specific language governing permissions and limitations under the License. */ -package sqltypes +package evalengine import ( "bytes" "fmt" + "vitess.io/vitess/go/sqltypes" "strconv" @@ -29,112 +30,105 @@ import ( // numeric represents a numeric value extracted from // a Value, used for arithmetic operations. -type numeric struct { - typ querypb.Type - ival int64 - uval uint64 - fval float64 -} - var zeroBytes = []byte("0") // Add adds two values together // if v1 or v2 is null, then it returns null -func Add(v1, v2 Value) (Value, error) { +func Add(v1, v2 sqltypes.Value) (sqltypes.Value, error) { if v1.IsNull() || v2.IsNull() { - return NULL, nil + return sqltypes.NULL, nil } lv1, err := newNumeric(v1) if err != nil { - return NULL, err + return sqltypes.NULL, err } lv2, err := newNumeric(v2) if err != nil { - return NULL, err + return sqltypes.NULL, err } lresult, err := addNumericWithError(lv1, lv2) if err != nil { - return NULL, err + return sqltypes.NULL, err } return castFromNumeric(lresult, lresult.typ), nil } // Subtract takes two values and subtracts them -func Subtract(v1, v2 Value) (Value, error) { +func Subtract(v1, v2 sqltypes.Value) (sqltypes.Value, error) { if v1.IsNull() || v2.IsNull() { - return NULL, nil + return sqltypes.NULL, nil } lv1, err := newNumeric(v1) if err != nil { - return NULL, err + return sqltypes.NULL, err } lv2, err := newNumeric(v2) if err != nil { - return NULL, err + return sqltypes.NULL, err } lresult, err := subtractNumericWithError(lv1, lv2) if err != nil { - return NULL, err + return sqltypes.NULL, err } return castFromNumeric(lresult, lresult.typ), nil } // Multiply takes two values and multiplies it together -func Multiply(v1, v2 Value) (Value, error) { +func Multiply(v1, v2 sqltypes.Value) (sqltypes.Value, error) { if v1.IsNull() || v2.IsNull() { - return NULL, nil + return sqltypes.NULL, nil } lv1, err := newNumeric(v1) if err != nil { - return NULL, err + return sqltypes.NULL, err } lv2, err := newNumeric(v2) if err != nil { - return NULL, err + return sqltypes.NULL, err } lresult, err := multiplyNumericWithError(lv1, lv2) if err != nil { - return NULL, err + return sqltypes.NULL, err } return castFromNumeric(lresult, lresult.typ), nil } -// Float Division for MySQL. Replicates behavior of "/" operator -func Divide(v1, v2 Value) (Value, error) { +// Divide (Float) for MySQL. Replicates behavior of "/" operator +func Divide(v1, v2 sqltypes.Value) (sqltypes.Value, error) { if v1.IsNull() || v2.IsNull() { - return NULL, nil + return sqltypes.NULL, nil } lv2AsFloat, err := ToFloat64(v2) divisorIsZero := lv2AsFloat == 0 if divisorIsZero || err != nil { - return NULL, err + return sqltypes.NULL, err } lv1, err := newNumeric(v1) if err != nil { - return NULL, err + return sqltypes.NULL, err } lv2, err := newNumeric(v2) if err != nil { - return NULL, err + return sqltypes.NULL, err } lresult, err := divideNumericWithError(lv1, lv2) if err != nil { - return NULL, err + return sqltypes.NULL, err } return castFromNumeric(lresult, lresult.typ), nil @@ -151,21 +145,21 @@ func Divide(v1, v2 Value) (Value, error) { // addition, if one of the input types was Decimal, then // a Decimal is built. Otherwise, the final type of the // result is preserved. -func NullsafeAdd(v1, v2 Value, resultType querypb.Type) Value { +func NullsafeAdd(v1, v2 sqltypes.Value, resultType querypb.Type) sqltypes.Value { if v1.IsNull() { - v1 = MakeTrusted(resultType, zeroBytes) + v1 = sqltypes.MakeTrusted(resultType, zeroBytes) } if v2.IsNull() { - v2 = MakeTrusted(resultType, zeroBytes) + v2 = sqltypes.MakeTrusted(resultType, zeroBytes) } lv1, err := newNumeric(v1) if err != nil { - return NULL + return sqltypes.NULL } lv2, err := newNumeric(v2) if err != nil { - return NULL + return sqltypes.NULL } lresult := addNumeric(lv1, lv2) @@ -177,7 +171,7 @@ func NullsafeAdd(v1, v2 Value, resultType querypb.Type) Value { // numeric, then a numeric comparison is performed after // necessary conversions. If none are numeric, then it's // a simple binary comparison. Uncomparable values return an error. -func NullsafeCompare(v1, v2 Value) (int, error) { +func NullsafeCompare(v1, v2 sqltypes.Value) (int, error) { // Based on the categorization defined for the types, // we're going to allow comparison of the following: // Null, isNumber, IsBinary. This will exclude IsQuoted @@ -191,7 +185,7 @@ func NullsafeCompare(v1, v2 Value) (int, error) { if v2.IsNull() { return 1, nil } - if isNumber(v1.Type()) || isNumber(v2.Type()) { + if sqltypes.IsNumber(v1.Type()) || sqltypes.IsNumber(v2.Type()) { lv1, err := newNumeric(v1) if err != nil { return 0, err @@ -209,12 +203,12 @@ func NullsafeCompare(v1, v2 Value) (int, error) { } // isByteComparable returns true if the type is binary or date/time. -func isByteComparable(v Value) bool { +func isByteComparable(v sqltypes.Value) bool { if v.IsBinary() { return true } switch v.Type() { - case Timestamp, Date, Time, Datetime: + case sqltypes.Timestamp, sqltypes.Date, sqltypes.Time, sqltypes.Datetime: return true } return false @@ -223,18 +217,18 @@ func isByteComparable(v Value) bool { // Min returns the minimum of v1 and v2. If one of the // values is NULL, it returns the other value. If both // are NULL, it returns NULL. -func Min(v1, v2 Value) (Value, error) { +func Min(v1, v2 sqltypes.Value) (sqltypes.Value, error) { return minmax(v1, v2, true) } // Max returns the maximum of v1 and v2. If one of the // values is NULL, it returns the other value. If both // are NULL, it returns NULL. -func Max(v1, v2 Value) (Value, error) { +func Max(v1, v2 sqltypes.Value) (sqltypes.Value, error) { return minmax(v1, v2, false) } -func minmax(v1, v2 Value, min bool) (Value, error) { +func minmax(v1, v2 sqltypes.Value, min bool) (sqltypes.Value, error) { if v1.IsNull() { return v2, nil } @@ -244,7 +238,7 @@ func minmax(v1, v2 Value, min bool) (Value, error) { n, err := NullsafeCompare(v1, v2) if err != nil { - return NULL, err + return sqltypes.NULL, err } // XNOR construct. See tests. @@ -256,61 +250,61 @@ func minmax(v1, v2 Value, min bool) (Value, error) { } // Cast converts a Value to the target type. -func Cast(v Value, typ querypb.Type) (Value, error) { +func Cast(v sqltypes.Value, typ querypb.Type) (sqltypes.Value, error) { if v.Type() == typ || v.IsNull() { return v, nil } - if IsSigned(typ) && v.IsSigned() { - return MakeTrusted(typ, v.ToBytes()), nil + if sqltypes.IsSigned(typ) && v.IsSigned() { + return sqltypes.MakeTrusted(typ, v.ToBytes()), nil } - if IsUnsigned(typ) && v.IsUnsigned() { - return MakeTrusted(typ, v.ToBytes()), nil + if sqltypes.IsUnsigned(typ) && v.IsUnsigned() { + return sqltypes.MakeTrusted(typ, v.ToBytes()), nil } - if (IsFloat(typ) || typ == Decimal) && (v.IsIntegral() || v.IsFloat() || v.Type() == Decimal) { - return MakeTrusted(typ, v.ToBytes()), nil + if (sqltypes.IsFloat(typ) || typ == sqltypes.Decimal) && (v.IsIntegral() || v.IsFloat() || v.Type() == sqltypes.Decimal) { + return sqltypes.MakeTrusted(typ, v.ToBytes()), nil } - if IsQuoted(typ) && (v.IsIntegral() || v.IsFloat() || v.Type() == Decimal || v.IsQuoted()) { - return MakeTrusted(typ, v.ToBytes()), nil + if sqltypes.IsQuoted(typ) && (v.IsIntegral() || v.IsFloat() || v.Type() == sqltypes.Decimal || v.IsQuoted()) { + return sqltypes.MakeTrusted(typ, v.ToBytes()), nil } // Explicitly disallow Expression. - if v.Type() == Expression { - return NULL, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v cannot be cast to %v", v, typ) + if v.Type() == sqltypes.Expression { + return sqltypes.NULL, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v cannot be cast to %v", v, typ) } // If the above fast-paths were not possible, // go through full validation. - return NewValue(typ, v.ToBytes()) + return sqltypes.NewValue(typ, v.ToBytes()) } // ToUint64 converts Value to uint64. -func ToUint64(v Value) (uint64, error) { +func ToUint64(v sqltypes.Value) (uint64, error) { num, err := newIntegralNumeric(v) if err != nil { return 0, err } switch num.typ { - case Int64: + case sqltypes.Int64: if num.ival < 0 { return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "negative number cannot be converted to unsigned: %d", num.ival) } return uint64(num.ival), nil - case Uint64: + case sqltypes.Uint64: return num.uval, nil } panic("unreachable") } // ToInt64 converts Value to int64. -func ToInt64(v Value) (int64, error) { +func ToInt64(v sqltypes.Value) (int64, error) { num, err := newIntegralNumeric(v) if err != nil { return 0, err } switch num.typ { - case Int64: + case sqltypes.Int64: return num.ival, nil - case Uint64: + case sqltypes.Uint64: ival := int64(num.uval) if ival < 0 { return 0, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unsigned number overflows int64 value: %d", num.uval) @@ -321,17 +315,17 @@ func ToInt64(v Value) (int64, error) { } // ToFloat64 converts Value to float64. -func ToFloat64(v Value) (float64, error) { +func ToFloat64(v sqltypes.Value) (float64, error) { num, err := newNumeric(v) if err != nil { return 0, err } switch num.typ { - case Int64: + case sqltypes.Int64: return float64(num.ival), nil - case Uint64: + case sqltypes.Uint64: return float64(num.uval), nil - case Float64: + case sqltypes.Float64: return num.fval, nil } panic("unreachable") @@ -339,11 +333,11 @@ func ToFloat64(v Value) (float64, error) { // ToNative converts Value to a native go type. // Decimal is returned as []byte. -func ToNative(v Value) (interface{}, error) { +func ToNative(v sqltypes.Value) (interface{}, error) { var out interface{} var err error switch { - case v.Type() == Null: + case v.Type() == sqltypes.Null: // no-op case v.IsSigned(): return ToInt64(v) @@ -351,165 +345,165 @@ func ToNative(v Value) (interface{}, error) { return ToUint64(v) case v.IsFloat(): return ToFloat64(v) - case v.IsQuoted() || v.Type() == Bit || v.Type() == Decimal: - out = v.val - case v.Type() == Expression: + case v.IsQuoted() || v.Type() == sqltypes.Bit || v.Type() == sqltypes.Decimal: + out = v.ToBytes() + case v.Type() == sqltypes.Expression: err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v cannot be converted to a go type", v) } return out, err } // newNumeric parses a value and produces an Int64, Uint64 or Float64. -func newNumeric(v Value) (numeric, error) { +func newNumeric(v sqltypes.Value) (evalResult, error) { str := v.ToString() switch { case v.IsSigned(): ival, err := strconv.ParseInt(str, 10, 64) if err != nil { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return numeric{ival: ival, typ: Int64}, nil + return evalResult{ival: ival, typ: sqltypes.Int64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(str, 10, 64) if err != nil { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return numeric{uval: uval, typ: Uint64}, nil + return evalResult{uval: uval, typ: sqltypes.Uint64}, nil case v.IsFloat(): fval, err := strconv.ParseFloat(str, 64) if err != nil { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return numeric{fval: fval, typ: Float64}, nil + return evalResult{fval: fval, typ: sqltypes.Float64}, nil } // For other types, do best effort. if ival, err := strconv.ParseInt(str, 10, 64); err == nil { - return numeric{ival: ival, typ: Int64}, nil + return evalResult{ival: ival, typ: sqltypes.Int64}, nil } if fval, err := strconv.ParseFloat(str, 64); err == nil { - return numeric{fval: fval, typ: Float64}, nil + return evalResult{fval: fval, typ: sqltypes.Float64}, nil } - return numeric{ival: 0, typ: Int64}, nil + return evalResult{ival: 0, typ: sqltypes.Int64}, nil } // newIntegralNumeric parses a value and produces an Int64 or Uint64. -func newIntegralNumeric(v Value) (numeric, error) { +func newIntegralNumeric(v sqltypes.Value) (evalResult, error) { str := v.ToString() switch { case v.IsSigned(): ival, err := strconv.ParseInt(str, 10, 64) if err != nil { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return numeric{ival: ival, typ: Int64}, nil + return evalResult{ival: ival, typ: sqltypes.Int64}, nil case v.IsUnsigned(): uval, err := strconv.ParseUint(str, 10, 64) if err != nil { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) } - return numeric{uval: uval, typ: Uint64}, nil + return evalResult{uval: uval, typ: sqltypes.Uint64}, nil } // For other types, do best effort. if ival, err := strconv.ParseInt(str, 10, 64); err == nil { - return numeric{ival: ival, typ: Int64}, nil + return evalResult{ival: ival, typ: sqltypes.Int64}, nil } if uval, err := strconv.ParseUint(str, 10, 64); err == nil { - return numeric{uval: uval, typ: Uint64}, nil + return evalResult{uval: uval, typ: sqltypes.Uint64}, nil } - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", str) } -func addNumeric(v1, v2 numeric) numeric { +func addNumeric(v1, v2 evalResult) evalResult { v1, v2 = prioritize(v1, v2) switch v1.typ { - case Int64: + case sqltypes.Int64: return intPlusInt(v1.ival, v2.ival) - case Uint64: + case sqltypes.Uint64: switch v2.typ { - case Int64: + case sqltypes.Int64: return uintPlusInt(v1.uval, v2.ival) - case Uint64: + case sqltypes.Uint64: return uintPlusUint(v1.uval, v2.uval) } - case Float64: + case sqltypes.Float64: return floatPlusAny(v1.fval, v2) } panic("unreachable") } -func addNumericWithError(v1, v2 numeric) (numeric, error) { +func addNumericWithError(v1, v2 evalResult) (evalResult, error) { v1, v2 = prioritize(v1, v2) switch v1.typ { - case Int64: + case sqltypes.Int64: return intPlusIntWithError(v1.ival, v2.ival) - case Uint64: + case sqltypes.Uint64: switch v2.typ { - case Int64: + case sqltypes.Int64: return uintPlusIntWithError(v1.uval, v2.ival) - case Uint64: + case sqltypes.Uint64: return uintPlusUintWithError(v1.uval, v2.uval) } - case Float64: + case sqltypes.Float64: return floatPlusAny(v1.fval, v2), nil } panic("unreachable") } -func subtractNumericWithError(v1, v2 numeric) (numeric, error) { +func subtractNumericWithError(v1, v2 evalResult) (evalResult, error) { switch v1.typ { - case Int64: + case sqltypes.Int64: switch v2.typ { - case Int64: + case sqltypes.Int64: return intMinusIntWithError(v1.ival, v2.ival) - case Uint64: + case sqltypes.Uint64: return intMinusUintWithError(v1.ival, v2.uval) - case Float64: + case sqltypes.Float64: return anyMinusFloat(v1, v2.fval), nil } - case Uint64: + case sqltypes.Uint64: switch v2.typ { - case Int64: + case sqltypes.Int64: return uintMinusIntWithError(v1.uval, v2.ival) - case Uint64: + case sqltypes.Uint64: return uintMinusUintWithError(v1.uval, v2.uval) - case Float64: + case sqltypes.Float64: return anyMinusFloat(v1, v2.fval), nil } - case Float64: + case sqltypes.Float64: return floatMinusAny(v1.fval, v2), nil } panic("unreachable") } -func multiplyNumericWithError(v1, v2 numeric) (numeric, error) { +func multiplyNumericWithError(v1, v2 evalResult) (evalResult, error) { v1, v2 = prioritize(v1, v2) switch v1.typ { - case Int64: + case sqltypes.Int64: return intTimesIntWithError(v1.ival, v2.ival) - case Uint64: + case sqltypes.Uint64: switch v2.typ { - case Int64: + case sqltypes.Int64: return uintTimesIntWithError(v1.uval, v2.ival) - case Uint64: + case sqltypes.Uint64: return uintTimesUintWithError(v1.uval, v2.uval) } - case Float64: + case sqltypes.Float64: return floatTimesAny(v1.fval, v2), nil } panic("unreachable") } -func divideNumericWithError(v1, v2 numeric) (numeric, error) { +func divideNumericWithError(v1, v2 evalResult) (evalResult, error) { switch v1.typ { - case Int64: + case sqltypes.Int64: return floatDivideAnyWithError(float64(v1.ival), v2) - case Uint64: + case sqltypes.Uint64: return floatDivideAnyWithError(float64(v1.uval), v2) - case Float64: + case sqltypes.Float64: return floatDivideAnyWithError(v1.fval, v2) } panic("unreachable") @@ -517,21 +511,21 @@ func divideNumericWithError(v1, v2 numeric) (numeric, error) { // prioritize reorders the input parameters // to be Float64, Uint64, Int64. -func prioritize(v1, v2 numeric) (altv1, altv2 numeric) { +func prioritize(v1, v2 evalResult) (altv1, altv2 evalResult) { switch v1.typ { - case Int64: - if v2.typ == Uint64 || v2.typ == Float64 { + case sqltypes.Int64: + if v2.typ == sqltypes.Uint64 || v2.typ == sqltypes.Float64 { return v2, v1 } - case Uint64: - if v2.typ == Float64 { + case sqltypes.Uint64: + if v2.typ == sqltypes.Float64 { return v2, v1 } } return v1, v2 } -func intPlusInt(v1, v2 int64) numeric { +func intPlusInt(v1, v2 int64) evalResult { result := v1 + v2 if v1 > 0 && v2 > 0 && result < 0 { goto overflow @@ -539,61 +533,61 @@ func intPlusInt(v1, v2 int64) numeric { if v1 < 0 && v2 < 0 && result > 0 { goto overflow } - return numeric{typ: Int64, ival: result} + return evalResult{typ: sqltypes.Int64, ival: result} overflow: - return numeric{typ: Float64, fval: float64(v1) + float64(v2)} + return evalResult{typ: sqltypes.Float64, fval: float64(v1) + float64(v2)} } -func intPlusIntWithError(v1, v2 int64) (numeric, error) { +func intPlusIntWithError(v1, v2 int64) (evalResult, error) { result := v1 + v2 if (result > v1) != (v2 > 0) { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v + %v", v1, v2) } - return numeric{typ: Int64, ival: result}, nil + return evalResult{typ: sqltypes.Int64, ival: result}, nil } -func intMinusIntWithError(v1, v2 int64) (numeric, error) { +func intMinusIntWithError(v1, v2 int64) (evalResult, error) { result := v1 - v2 if (result < v1) != (v2 > 0) { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v - %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v - %v", v1, v2) } - return numeric{typ: Int64, ival: result}, nil + return evalResult{typ: sqltypes.Int64, ival: result}, nil } -func intTimesIntWithError(v1, v2 int64) (numeric, error) { +func intTimesIntWithError(v1, v2 int64) (evalResult, error) { result := v1 * v2 if v1 != 0 && result/v1 != v2 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v * %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in %v * %v", v1, v2) } - return numeric{typ: Int64, ival: result}, nil + return evalResult{typ: sqltypes.Int64, ival: result}, nil } -func intMinusUintWithError(v1 int64, v2 uint64) (numeric, error) { +func intMinusUintWithError(v1 int64, v2 uint64) (evalResult, error) { if v1 < 0 || v1 < int64(v2) { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) } return uintMinusUintWithError(uint64(v1), v2) } -func uintPlusInt(v1 uint64, v2 int64) numeric { +func uintPlusInt(v1 uint64, v2 int64) evalResult { return uintPlusUint(v1, uint64(v2)) } -func uintPlusIntWithError(v1 uint64, v2 int64) (numeric, error) { +func uintPlusIntWithError(v1 uint64, v2 int64) (evalResult, error) { if v2 < 0 && v1 < uint64(v2) { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) } // convert to int -> uint is because for numeric operators (such as + or -) // where one of the operands is an unsigned integer, the result is unsigned by default. return uintPlusUintWithError(v1, uint64(v2)) } -func uintMinusIntWithError(v1 uint64, v2 int64) (numeric, error) { +func uintMinusIntWithError(v1 uint64, v2 int64) (evalResult, error) { if int64(v1) < v2 && v2 > 0 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) } // uint - (- int) = uint + int if v2 < 0 { @@ -602,192 +596,194 @@ func uintMinusIntWithError(v1 uint64, v2 int64) (numeric, error) { return uintMinusUintWithError(v1, uint64(v2)) } -func uintTimesIntWithError(v1 uint64, v2 int64) (numeric, error) { +func uintTimesIntWithError(v1 uint64, v2 int64) (evalResult, error) { if v2 < 0 || int64(v1) < 0 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) } return uintTimesUintWithError(v1, uint64(v2)) } -func uintPlusUint(v1, v2 uint64) numeric { +func uintPlusUint(v1, v2 uint64) evalResult { result := v1 + v2 if result < v2 { - return numeric{typ: Float64, fval: float64(v1) + float64(v2)} + return evalResult{typ: sqltypes.Float64, fval: float64(v1) + float64(v2)} } - return numeric{typ: Uint64, uval: result} + return evalResult{typ: sqltypes.Uint64, uval: result} } -func uintPlusUintWithError(v1, v2 uint64) (numeric, error) { +func uintPlusUintWithError(v1, v2 uint64) (evalResult, error) { result := v1 + v2 if result < v2 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v + %v", v1, v2) } - return numeric{typ: Uint64, uval: result}, nil + return evalResult{typ: sqltypes.Uint64, uval: result}, nil } -func uintMinusUintWithError(v1, v2 uint64) (numeric, error) { +func uintMinusUintWithError(v1, v2 uint64) (evalResult, error) { result := v1 - v2 if v2 > v1 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v - %v", v1, v2) } - return numeric{typ: Uint64, uval: result}, nil + return evalResult{typ: sqltypes.Uint64, uval: result}, nil } -func uintTimesUintWithError(v1, v2 uint64) (numeric, error) { +func uintTimesUintWithError(v1, v2 uint64) (evalResult, error) { result := v1 * v2 if result < v2 || result < v1 { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in %v * %v", v1, v2) } - return numeric{typ: Uint64, uval: result}, nil + return evalResult{typ: sqltypes.Uint64, uval: result}, nil } -func floatPlusAny(v1 float64, v2 numeric) numeric { +func floatPlusAny(v1 float64, v2 evalResult) evalResult { switch v2.typ { - case Int64: + case sqltypes.Int64: v2.fval = float64(v2.ival) - case Uint64: + case sqltypes.Uint64: v2.fval = float64(v2.uval) } - return numeric{typ: Float64, fval: v1 + v2.fval} + return evalResult{typ: sqltypes.Float64, fval: v1 + v2.fval} } -func floatMinusAny(v1 float64, v2 numeric) numeric { +func floatMinusAny(v1 float64, v2 evalResult) evalResult { switch v2.typ { - case Int64: + case sqltypes.Int64: v2.fval = float64(v2.ival) - case Uint64: + case sqltypes.Uint64: v2.fval = float64(v2.uval) } - return numeric{typ: Float64, fval: v1 - v2.fval} + return evalResult{typ: sqltypes.Float64, fval: v1 - v2.fval} } -func floatTimesAny(v1 float64, v2 numeric) numeric { +func floatTimesAny(v1 float64, v2 evalResult) evalResult { switch v2.typ { - case Int64: + case sqltypes.Int64: v2.fval = float64(v2.ival) - case Uint64: + case sqltypes.Uint64: v2.fval = float64(v2.uval) } - return numeric{typ: Float64, fval: v1 * v2.fval} + return evalResult{typ: sqltypes.Float64, fval: v1 * v2.fval} } -func floatDivideAnyWithError(v1 float64, v2 numeric) (numeric, error) { +func floatDivideAnyWithError(v1 float64, v2 evalResult) (evalResult, error) { switch v2.typ { - case Int64: + case sqltypes.Int64: v2.fval = float64(v2.ival) - case Uint64: + case sqltypes.Uint64: v2.fval = float64(v2.uval) } result := v1 / v2.fval divisorLessThanOne := v2.fval < 1 - resultMismatch := (v2.fval*result != v1) + resultMismatch := v2.fval*result != v1 if divisorLessThanOne && resultMismatch { - return numeric{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in %v / %v", v1, v2.fval) + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in %v / %v", v1, v2.fval) } - return numeric{typ: Float64, fval: v1 / v2.fval}, nil + return evalResult{typ: sqltypes.Float64, fval: v1 / v2.fval}, nil } -func anyMinusFloat(v1 numeric, v2 float64) numeric { +func anyMinusFloat(v1 evalResult, v2 float64) evalResult { switch v1.typ { - case Int64: + case sqltypes.Int64: v1.fval = float64(v1.ival) - case Uint64: + case sqltypes.Uint64: v1.fval = float64(v1.uval) } - return numeric{typ: Float64, fval: v1.fval - v2} + return evalResult{typ: sqltypes.Float64, fval: v1.fval - v2} } -func castFromNumeric(v numeric, resultType querypb.Type) Value { +func castFromNumeric(v evalResult, resultType querypb.Type) sqltypes.Value { switch { - case IsSigned(resultType): + case sqltypes.IsSigned(resultType): switch v.typ { - case Int64: - return MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)) - case Uint64: - return MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.uval), 10)) - case Float64: - return MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.fval), 10)) + case sqltypes.Int64: + return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)) + case sqltypes.Uint64: + return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.uval), 10)) + case sqltypes.Float64: + return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(v.fval), 10)) } - case IsUnsigned(resultType): + case sqltypes.IsUnsigned(resultType): switch v.typ { - case Uint64: - return MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)) - case Int64: - return MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.ival), 10)) - case Float64: - return MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.fval), 10)) + case sqltypes.Uint64: + return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)) + case sqltypes.Int64: + return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.ival), 10)) + case sqltypes.Float64: + return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(v.fval), 10)) } - case IsFloat(resultType) || resultType == Decimal: + case sqltypes.IsFloat(resultType) || resultType == sqltypes.Decimal: switch v.typ { - case Int64: - return MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)) - case Uint64: - return MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)) - case Float64: + case sqltypes.Int64: + return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, v.ival, 10)) + case sqltypes.Uint64: + return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, v.uval, 10)) + case sqltypes.Float64: format := byte('g') - if resultType == Decimal { + if resultType == sqltypes.Decimal { format = 'f' } - return MakeTrusted(resultType, strconv.AppendFloat(nil, v.fval, format, -1, 64)) + return sqltypes.MakeTrusted(resultType, strconv.AppendFloat(nil, v.fval, format, -1, 64)) } + case resultType == sqltypes.VarChar || resultType == sqltypes.VarBinary: + return sqltypes.MakeTrusted(resultType, []byte(v.str)) } - return NULL + return sqltypes.NULL } -func compareNumeric(v1, v2 numeric) int { +func compareNumeric(v1, v2 evalResult) int { // Equalize the types. switch v1.typ { - case Int64: + case sqltypes.Int64: switch v2.typ { - case Uint64: + case sqltypes.Uint64: if v1.ival < 0 { return -1 } - v1 = numeric{typ: Uint64, uval: uint64(v1.ival)} - case Float64: - v1 = numeric{typ: Float64, fval: float64(v1.ival)} + v1 = evalResult{typ: sqltypes.Uint64, uval: uint64(v1.ival)} + case sqltypes.Float64: + v1 = evalResult{typ: sqltypes.Float64, fval: float64(v1.ival)} } - case Uint64: + case sqltypes.Uint64: switch v2.typ { - case Int64: + case sqltypes.Int64: if v2.ival < 0 { return 1 } - v2 = numeric{typ: Uint64, uval: uint64(v2.ival)} - case Float64: - v1 = numeric{typ: Float64, fval: float64(v1.uval)} + v2 = evalResult{typ: sqltypes.Uint64, uval: uint64(v2.ival)} + case sqltypes.Float64: + v1 = evalResult{typ: sqltypes.Float64, fval: float64(v1.uval)} } - case Float64: + case sqltypes.Float64: switch v2.typ { - case Int64: - v2 = numeric{typ: Float64, fval: float64(v2.ival)} - case Uint64: - v2 = numeric{typ: Float64, fval: float64(v2.uval)} + case sqltypes.Int64: + v2 = evalResult{typ: sqltypes.Float64, fval: float64(v2.ival)} + case sqltypes.Uint64: + v2 = evalResult{typ: sqltypes.Float64, fval: float64(v2.uval)} } } // Both values are of the same type. switch v1.typ { - case Int64: + case sqltypes.Int64: switch { case v1.ival == v2.ival: return 0 case v1.ival < v2.ival: return -1 } - case Uint64: + case sqltypes.Uint64: switch { case v1.uval == v2.uval: return 0 case v1.uval < v2.uval: return -1 } - case Float64: + case sqltypes.Float64: switch { case v1.fval == v2.fval: return 0 diff --git a/go/vt/vtgate/evalengine/arithmetic_test.go b/go/vt/vtgate/evalengine/arithmetic_test.go new file mode 100644 index 00000000000..781e5d5de82 --- /dev/null +++ b/go/vt/vtgate/evalengine/arithmetic_test.go @@ -0,0 +1,1519 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "encoding/binary" + "fmt" + "math" + "reflect" + "strconv" + "testing" + + "vitess.io/vitess/go/sqltypes" + + querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) + +func TestDivide(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + out sqltypes.Value + err error + }{{ + //All Nulls + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // First value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NULL, + }, { + // Second value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // Second arg 0 + v1: sqltypes.NewInt32(5), + v2: sqltypes.NewInt32(0), + out: sqltypes.NULL, + }, { + // Both arguments zero + v1: sqltypes.NewInt32(0), + v2: sqltypes.NewInt32(0), + out: sqltypes.NULL, + }, { + // case with negative value + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewFloat64(0.5000), + }, { + // float64 division by zero + v1: sqltypes.NewFloat64(2), + v2: sqltypes.NewFloat64(0), + out: sqltypes.NULL, + }, { + // Lower bound for int64 + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewInt64(1), + out: sqltypes.NewFloat64(math.MinInt64), + }, { + // upper bound for uint64 + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewUint64(1), + out: sqltypes.NewFloat64(math.MaxUint64), + }, { + // testing for error in types + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for error in types + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for uint/int + v1: sqltypes.NewUint64(4), + v2: sqltypes.NewInt64(5), + out: sqltypes.NewFloat64(0.8), + }, { + // testing for uint/uint + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewFloat64(0.5), + }, { + // testing for float64/int64 + v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewFloat64(-0.6), + }, { + // testing for float64/uint64 + v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewFloat64(0.6), + }, { + // testing for overflow of float64 + v1: sqltypes.NewFloat64(math.MaxFloat64), + v2: sqltypes.NewFloat64(0.5), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in 1.7976931348623157e+308 / 0.5"), + }} + + for _, tcase := range tcases { + got, err := Divide(tcase.v1, tcase.v2) + + if !vterrors.Equals(err, tcase.err) { + t.Errorf("%v %v %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err)) + t.Errorf("Divide(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Divide(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } + +} + +func TestMultiply(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + out sqltypes.Value + err error + }{{ + //All Nulls + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // First value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // Second value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NULL, + }, { + // case with negative value + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewInt64(2), + }, { + // testing for int64 overflow with min negative value + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewInt64(1), + out: sqltypes.NewInt64(math.MinInt64), + }, { + // testing for error in types + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for error in types + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for uint*int + v1: sqltypes.NewUint64(4), + v2: sqltypes.NewInt64(5), + out: sqltypes.NewUint64(20), + }, { + // testing for uint*uint + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewUint64(2), + }, { + // testing for float64*int64 + v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewFloat64(-2.4), + }, { + // testing for float64*uint64 + v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewFloat64(2.4), + }, { + // testing for overflow of int64 + v1: sqltypes.NewInt64(math.MaxInt64), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in 9223372036854775807 * 2"), + }, { + // testing for underflow of uint64*max.uint64 + v1: sqltypes.NewInt64(2), + v2: sqltypes.NewUint64(math.MaxUint64), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 * 2"), + }, { + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewUint64(1), + out: sqltypes.NewUint64(math.MaxUint64), + }, { + //Checking whether maxInt value can be passed as uint value + v1: sqltypes.NewUint64(math.MaxInt64), + v2: sqltypes.NewInt64(3), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 9223372036854775807 * 3"), + }} + + for _, tcase := range tcases { + + got, err := Multiply(tcase.v1, tcase.v2) + + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Multiply(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Multiply(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } + +} + +func TestSubtract(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + out sqltypes.Value + err error + }{{ + // All Nulls + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // First value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // Second value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NULL, + }, { + // case with negative value + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewInt64(1), + }, { + // testing for int64 overflow with min negative value + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewInt64(1), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 - 1"), + }, { + v1: sqltypes.NewUint64(4), + v2: sqltypes.NewInt64(5), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 4 - 5"), + }, { + // testing uint - int + v1: sqltypes.NewUint64(7), + v2: sqltypes.NewInt64(5), + out: sqltypes.NewUint64(2), + }, { + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewInt64(0), + out: sqltypes.NewUint64(math.MaxUint64), + }, { + // testing for int64 overflow + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewUint64(0), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in -9223372036854775808 - 0"), + }, { + v1: sqltypes.TestValue(querypb.Type_VARCHAR, "c"), + v2: sqltypes.NewInt64(1), + out: sqltypes.NewInt64(-1), + }, { + v1: sqltypes.NewUint64(1), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "c"), + out: sqltypes.NewUint64(1), + }, { + // testing for error for parsing float value to uint64 + v1: sqltypes.TestValue(querypb.Type_UINT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), + }, { + // testing for error for parsing float value to uint64 + v1: sqltypes.NewUint64(2), + v2: sqltypes.TestValue(querypb.Type_UINT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), + }, { + // uint64 - uint64 + v1: sqltypes.NewUint64(8), + v2: sqltypes.NewUint64(4), + out: sqltypes.NewUint64(4), + }, { + // testing for float subtraction: float - int + v1: sqltypes.NewFloat64(1.2), + v2: sqltypes.NewInt64(2), + out: sqltypes.NewFloat64(-0.8), + }, { + // testing for float subtraction: float - uint + v1: sqltypes.NewFloat64(1.2), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewFloat64(-0.8), + }, { + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewUint64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in -1 - 2"), + }, { + v1: sqltypes.NewInt64(2), + v2: sqltypes.NewUint64(1), + out: sqltypes.NewUint64(1), + }, { + // testing int64 - float64 method + v1: sqltypes.NewInt64(-2), + v2: sqltypes.NewFloat64(1.0), + out: sqltypes.NewFloat64(-3.0), + }, { + // testing uint64 - float64 method + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewFloat64(-2.0), + out: sqltypes.NewFloat64(3.0), + }, { + // testing uint - int to return uintplusint + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewUint64(3), + }, { + // testing for float - float + v1: sqltypes.NewFloat64(1.2), + v2: sqltypes.NewFloat64(3.2), + out: sqltypes.NewFloat64(-2), + }, { + // testing uint - uint if v2 > v1 + v1: sqltypes.NewUint64(2), + v2: sqltypes.NewUint64(4), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 2 - 4"), + }, { + // testing uint - (- int) + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewUint64(3), + }} + + for _, tcase := range tcases { + + got, err := Subtract(tcase.v1, tcase.v2) + + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Subtract(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Subtract(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } + +} + +func TestAdd(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + out sqltypes.Value + err error + }{{ + // All Nulls + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // First value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // Second value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NULL, + }, { + // case with negatives + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewInt64(-3), + }, { + // testing for overflow int64, result will be unsigned int + v1: sqltypes.NewInt64(math.MaxInt64), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewUint64(9223372036854775809), + }, { + v1: sqltypes.NewInt64(-2), + v2: sqltypes.NewUint64(1), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 1 + -2"), + }, { + v1: sqltypes.NewInt64(math.MaxInt64), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewInt64(9223372036854775805), + }, { + // Normal case + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewUint64(3), + }, { + // testing for overflow uint64 + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewUint64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), + }, { + // int64 underflow + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewInt64(-2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 + -2"), + }, { + // checking int64 max value can be returned + v1: sqltypes.NewInt64(math.MaxInt64), + v2: sqltypes.NewUint64(0), + out: sqltypes.NewUint64(9223372036854775807), + }, { + // testing whether uint64 max value can be returned + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewInt64(0), + out: sqltypes.NewUint64(math.MaxUint64), + }, { + v1: sqltypes.NewUint64(math.MaxInt64), + v2: sqltypes.NewInt64(1), + out: sqltypes.NewUint64(9223372036854775808), + }, { + v1: sqltypes.NewUint64(1), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "c"), + out: sqltypes.NewUint64(1), + }, { + v1: sqltypes.NewUint64(1), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "1.2"), + out: sqltypes.NewFloat64(2.2), + }, { + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // testing for uint64 overflow with max uint64 + int value + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), + }} + + for _, tcase := range tcases { + + got, err := Add(tcase.v1, tcase.v2) + + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Add(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Add(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } + +} + +func TestNullsafeAdd(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + out sqltypes.Value + err error + }{{ + // All nulls. + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NewInt64(0), + }, { + // First value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NewInt64(1), + }, { + // Second value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NewInt64(1), + }, { + // Normal case. + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewInt64(2), + out: sqltypes.NewInt64(3), + }, { + // Make sure underlying error is returned for LHS. + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // Make sure underlying error is returned for RHS. + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // Make sure underlying error is returned while adding. + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewInt64(-9223372036854775808), + }, { + // Make sure underlying error is returned while converting. + v1: sqltypes.NewFloat64(1), + v2: sqltypes.NewFloat64(2), + out: sqltypes.NewInt64(3), + }} + for _, tcase := range tcases { + got := NullsafeAdd(tcase.v1, tcase.v2, querypb.Type_INT64) + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("NullsafeAdd(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) + } + } +} + +func TestNullsafeCompare(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + out int + err error + }{{ + // All nulls. + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: 0, + }, { + // LHS null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt64(1), + out: -1, + }, { + // RHS null. + v1: sqltypes.NewInt64(1), + v2: sqltypes.NULL, + out: 1, + }, { + // LHS Text + v1: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), + }, { + // Make sure underlying error is returned for LHS. + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // Make sure underlying error is returned for RHS. + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // Numeric equal. + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewUint64(1), + out: 0, + }, { + // Numeric unequal. + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewUint64(2), + out: -1, + }, { + // Non-numeric equal + v1: sqltypes.TestValue(querypb.Type_VARBINARY, "abcd"), + v2: sqltypes.TestValue(querypb.Type_BINARY, "abcd"), + out: 0, + }, { + // Non-numeric unequal + v1: sqltypes.TestValue(querypb.Type_VARBINARY, "abcd"), + v2: sqltypes.TestValue(querypb.Type_BINARY, "bcde"), + out: -1, + }, { + // Date/Time types + v1: sqltypes.TestValue(querypb.Type_DATETIME, "1000-01-01 00:00:00"), + v2: sqltypes.TestValue(querypb.Type_BINARY, "1000-01-01 00:00:00"), + out: 0, + }, { + // Date/Time types + v1: sqltypes.TestValue(querypb.Type_DATETIME, "2000-01-01 00:00:00"), + v2: sqltypes.TestValue(querypb.Type_BINARY, "1000-01-01 00:00:00"), + out: 1, + }, { + // Date/Time types + v1: sqltypes.TestValue(querypb.Type_DATETIME, "1000-01-01 00:00:00"), + v2: sqltypes.TestValue(querypb.Type_BINARY, "2000-01-01 00:00:00"), + out: -1, + }} + for _, tcase := range tcases { + got, err := NullsafeCompare(tcase.v1, tcase.v2) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if got != tcase.out { + t.Errorf("NullsafeCompare(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), got, tcase.out) + } + } +} + +func TestCast(t *testing.T) { + tcases := []struct { + typ querypb.Type + v sqltypes.Value + out sqltypes.Value + err error + }{{ + typ: querypb.Type_VARCHAR, + v: sqltypes.NULL, + out: sqltypes.NULL, + }, { + typ: querypb.Type_VARCHAR, + v: sqltypes.TestValue(querypb.Type_VARCHAR, "exact types"), + out: sqltypes.TestValue(querypb.Type_VARCHAR, "exact types"), + }, { + typ: querypb.Type_INT64, + v: sqltypes.TestValue(querypb.Type_INT32, "32"), + out: sqltypes.TestValue(querypb.Type_INT64, "32"), + }, { + typ: querypb.Type_INT24, + v: sqltypes.TestValue(querypb.Type_UINT64, "64"), + out: sqltypes.TestValue(querypb.Type_INT24, "64"), + }, { + typ: querypb.Type_INT24, + v: sqltypes.TestValue(querypb.Type_VARCHAR, "bad int"), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseInt: parsing "bad int": invalid syntax`), + }, { + typ: querypb.Type_UINT64, + v: sqltypes.TestValue(querypb.Type_UINT32, "32"), + out: sqltypes.TestValue(querypb.Type_UINT64, "32"), + }, { + typ: querypb.Type_UINT24, + v: sqltypes.TestValue(querypb.Type_INT64, "64"), + out: sqltypes.TestValue(querypb.Type_UINT24, "64"), + }, { + typ: querypb.Type_UINT24, + v: sqltypes.TestValue(querypb.Type_INT64, "-1"), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseUint: parsing "-1": invalid syntax`), + }, { + typ: querypb.Type_FLOAT64, + v: sqltypes.TestValue(querypb.Type_INT64, "64"), + out: sqltypes.TestValue(querypb.Type_FLOAT64, "64"), + }, { + typ: querypb.Type_FLOAT32, + v: sqltypes.TestValue(querypb.Type_FLOAT64, "64"), + out: sqltypes.TestValue(querypb.Type_FLOAT32, "64"), + }, { + typ: querypb.Type_FLOAT32, + v: sqltypes.TestValue(querypb.Type_DECIMAL, "1.24"), + out: sqltypes.TestValue(querypb.Type_FLOAT32, "1.24"), + }, { + typ: querypb.Type_FLOAT64, + v: sqltypes.TestValue(querypb.Type_VARCHAR, "1.25"), + out: sqltypes.TestValue(querypb.Type_FLOAT64, "1.25"), + }, { + typ: querypb.Type_FLOAT64, + v: sqltypes.TestValue(querypb.Type_VARCHAR, "bad float"), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, `strconv.ParseFloat: parsing "bad float": invalid syntax`), + }, { + typ: querypb.Type_VARCHAR, + v: sqltypes.TestValue(querypb.Type_INT64, "64"), + out: sqltypes.TestValue(querypb.Type_VARCHAR, "64"), + }, { + typ: querypb.Type_VARBINARY, + v: sqltypes.TestValue(querypb.Type_FLOAT64, "64"), + out: sqltypes.TestValue(querypb.Type_VARBINARY, "64"), + }, { + typ: querypb.Type_VARBINARY, + v: sqltypes.TestValue(querypb.Type_DECIMAL, "1.24"), + out: sqltypes.TestValue(querypb.Type_VARBINARY, "1.24"), + }, { + typ: querypb.Type_VARBINARY, + v: sqltypes.TestValue(querypb.Type_VARCHAR, "1.25"), + out: sqltypes.TestValue(querypb.Type_VARBINARY, "1.25"), + }, { + typ: querypb.Type_VARCHAR, + v: sqltypes.TestValue(querypb.Type_VARBINARY, "valid string"), + out: sqltypes.TestValue(querypb.Type_VARCHAR, "valid string"), + }, { + typ: querypb.Type_VARCHAR, + v: sqltypes.TestValue(sqltypes.Expression, "bad string"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "EXPRESSION(bad string) cannot be cast to VARCHAR"), + }} + for _, tcase := range tcases { + got, err := Cast(tcase.v, tcase.typ) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Cast(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("Cast(%v): %v, want %v", tcase.v, got, tcase.out) + } + } +} + +func TestToUint64(t *testing.T) { + tcases := []struct { + v sqltypes.Value + out uint64 + err error + }{{ + v: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), + }, { + v: sqltypes.NewInt64(-1), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "negative number cannot be converted to unsigned: -1"), + }, { + v: sqltypes.NewInt64(1), + out: 1, + }, { + v: sqltypes.NewUint64(1), + out: 1, + }} + for _, tcase := range tcases { + got, err := ToUint64(tcase.v) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("ToUint64(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if got != tcase.out { + t.Errorf("ToUint64(%v): %v, want %v", tcase.v, got, tcase.out) + } + } +} + +func TestToInt64(t *testing.T) { + tcases := []struct { + v sqltypes.Value + out int64 + err error + }{{ + v: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), + }, { + v: sqltypes.NewUint64(18446744073709551615), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unsigned number overflows int64 value: 18446744073709551615"), + }, { + v: sqltypes.NewInt64(1), + out: 1, + }, { + v: sqltypes.NewUint64(1), + out: 1, + }} + for _, tcase := range tcases { + got, err := ToInt64(tcase.v) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("ToInt64(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if got != tcase.out { + t.Errorf("ToInt64(%v): %v, want %v", tcase.v, got, tcase.out) + } + } +} + +func TestToFloat64(t *testing.T) { + tcases := []struct { + v sqltypes.Value + out float64 + err error + }{{ + v: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + out: 0, + }, { + v: sqltypes.NewInt64(1), + out: 1, + }, { + v: sqltypes.NewUint64(1), + out: 1, + }, { + v: sqltypes.NewFloat64(1.2), + out: 1.2, + }, { + v: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }} + for _, tcase := range tcases { + got, err := ToFloat64(tcase.v) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("ToFloat64(%v) error: %v, want %v", tcase.v, vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if got != tcase.out { + t.Errorf("ToFloat64(%v): %v, want %v", tcase.v, got, tcase.out) + } + } +} + +func TestToNative(t *testing.T) { + testcases := []struct { + in sqltypes.Value + out interface{} + }{{ + in: sqltypes.NULL, + out: nil, + }, { + in: sqltypes.TestValue(querypb.Type_INT8, "1"), + out: int64(1), + }, { + in: sqltypes.TestValue(querypb.Type_INT16, "1"), + out: int64(1), + }, { + in: sqltypes.TestValue(querypb.Type_INT24, "1"), + out: int64(1), + }, { + in: sqltypes.TestValue(querypb.Type_INT32, "1"), + out: int64(1), + }, { + in: sqltypes.TestValue(querypb.Type_INT64, "1"), + out: int64(1), + }, { + in: sqltypes.TestValue(querypb.Type_UINT8, "1"), + out: uint64(1), + }, { + in: sqltypes.TestValue(querypb.Type_UINT16, "1"), + out: uint64(1), + }, { + in: sqltypes.TestValue(querypb.Type_UINT24, "1"), + out: uint64(1), + }, { + in: sqltypes.TestValue(querypb.Type_UINT32, "1"), + out: uint64(1), + }, { + in: sqltypes.TestValue(querypb.Type_UINT64, "1"), + out: uint64(1), + }, { + in: sqltypes.TestValue(querypb.Type_FLOAT32, "1"), + out: float64(1), + }, { + in: sqltypes.TestValue(querypb.Type_FLOAT64, "1"), + out: float64(1), + }, { + in: sqltypes.TestValue(querypb.Type_TIMESTAMP, "2012-02-24 23:19:43"), + out: []byte("2012-02-24 23:19:43"), + }, { + in: sqltypes.TestValue(querypb.Type_DATE, "2012-02-24"), + out: []byte("2012-02-24"), + }, { + in: sqltypes.TestValue(querypb.Type_TIME, "23:19:43"), + out: []byte("23:19:43"), + }, { + in: sqltypes.TestValue(querypb.Type_DATETIME, "2012-02-24 23:19:43"), + out: []byte("2012-02-24 23:19:43"), + }, { + in: sqltypes.TestValue(querypb.Type_YEAR, "1"), + out: uint64(1), + }, { + in: sqltypes.TestValue(querypb.Type_DECIMAL, "1"), + out: []byte("1"), + }, { + in: sqltypes.TestValue(querypb.Type_TEXT, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_BLOB, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_VARCHAR, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_VARBINARY, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_CHAR, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_BINARY, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_BIT, "1"), + out: []byte("1"), + }, { + in: sqltypes.TestValue(querypb.Type_ENUM, "a"), + out: []byte("a"), + }, { + in: sqltypes.TestValue(querypb.Type_SET, "a"), + out: []byte("a"), + }} + for _, tcase := range testcases { + v, err := ToNative(tcase.in) + if err != nil { + t.Error(err) + } + if !reflect.DeepEqual(v, tcase.out) { + t.Errorf("%v.ToNative = %#v, want %#v", tcase.in, v, tcase.out) + } + } + + // Test Expression failure. + _, err := ToNative(sqltypes.TestValue(querypb.Type_EXPRESSION, "aa")) + want := vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "EXPRESSION(aa) cannot be converted to a go type") + if !vterrors.Equals(err, want) { + t.Errorf("ToNative(EXPRESSION): %v, want %v", vterrors.Print(err), vterrors.Print(want)) + } +} + +func TestNewNumeric(t *testing.T) { + tcases := []struct { + v sqltypes.Value + out evalResult + err error + }{{ + v: sqltypes.NewInt64(1), + out: evalResult{typ: querypb.Type_INT64, ival: 1}, + }, { + v: sqltypes.NewUint64(1), + out: evalResult{typ: querypb.Type_UINT64, uval: 1}, + }, { + v: sqltypes.NewFloat64(1), + out: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + }, { + // For non-number type, Int64 is the default. + v: sqltypes.TestValue(querypb.Type_VARCHAR, "1"), + out: evalResult{typ: querypb.Type_INT64, ival: 1}, + }, { + // If Int64 can't work, we use Float64. + v: sqltypes.TestValue(querypb.Type_VARCHAR, "1.2"), + out: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2}, + }, { + // Only valid Int64 allowed if type is Int64. + v: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // Only valid Uint64 allowed if type is Uint64. + v: sqltypes.TestValue(querypb.Type_UINT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), + }, { + // Only valid Float64 allowed if type is Float64. + v: sqltypes.TestValue(querypb.Type_FLOAT64, "abcd"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseFloat: parsing \"abcd\": invalid syntax"), + }, { + v: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + out: evalResult{typ: querypb.Type_FLOAT64, fval: 0}, + }} + for _, tcase := range tcases { + got, err := newNumeric(tcase.v) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("newNumeric(%s) error: %v, want %v", printValue(tcase.v), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err == nil { + continue + } + + if got != tcase.out { + t.Errorf("newNumeric(%s): %v, want %v", printValue(tcase.v), got, tcase.out) + } + } +} + +func TestNewIntegralNumeric(t *testing.T) { + tcases := []struct { + v sqltypes.Value + out evalResult + err error + }{{ + v: sqltypes.NewInt64(1), + out: evalResult{typ: querypb.Type_INT64, ival: 1}, + }, { + v: sqltypes.NewUint64(1), + out: evalResult{typ: querypb.Type_UINT64, uval: 1}, + }, { + v: sqltypes.NewFloat64(1), + out: evalResult{typ: querypb.Type_INT64, ival: 1}, + }, { + // For non-number type, Int64 is the default. + v: sqltypes.TestValue(querypb.Type_VARCHAR, "1"), + out: evalResult{typ: querypb.Type_INT64, ival: 1}, + }, { + // If Int64 can't work, we use Uint64. + v: sqltypes.TestValue(querypb.Type_VARCHAR, "18446744073709551615"), + out: evalResult{typ: querypb.Type_UINT64, uval: 18446744073709551615}, + }, { + // Only valid Int64 allowed if type is Int64. + v: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), + }, { + // Only valid Uint64 allowed if type is Uint64. + v: sqltypes.TestValue(querypb.Type_UINT64, "1.2"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), + }, { + v: sqltypes.TestValue(querypb.Type_VARCHAR, "abcd"), + err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), + }} + for _, tcase := range tcases { + got, err := newIntegralNumeric(tcase.v) + if err != nil && !vterrors.Equals(err, tcase.err) { + t.Errorf("newIntegralNumeric(%s) error: %v, want %v", printValue(tcase.v), vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err == nil { + continue + } + + if got != tcase.out { + t.Errorf("newIntegralNumeric(%s): %v, want %v", printValue(tcase.v), got, tcase.out) + } + } +} + +func TestAddNumeric(t *testing.T) { + tcases := []struct { + v1, v2 evalResult + out evalResult + err error + }{{ + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + out: evalResult{typ: querypb.Type_INT64, ival: 3}, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: evalResult{typ: querypb.Type_UINT64, uval: 3}, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: 3}, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: evalResult{typ: querypb.Type_UINT64, uval: 3}, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: 3}, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: 3}, + }, { + // Int64 overflow. + v1: evalResult{typ: querypb.Type_INT64, ival: 9223372036854775807}, + v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: 9223372036854775809}, + }, { + // Int64 underflow. + v1: evalResult{typ: querypb.Type_INT64, ival: -9223372036854775807}, + v2: evalResult{typ: querypb.Type_INT64, ival: -2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: -9223372036854775809}, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: -1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: 18446744073709551617}, + }, { + // Uint64 overflow. + v1: evalResult{typ: querypb.Type_UINT64, uval: 18446744073709551615}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: evalResult{typ: querypb.Type_FLOAT64, fval: 18446744073709551617}, + }} + for _, tcase := range tcases { + got := addNumeric(tcase.v1, tcase.v2) + + if got != tcase.out { + t.Errorf("addNumeric(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) + } + } +} + +func TestPrioritize(t *testing.T) { + ival := evalResult{typ: querypb.Type_INT64} + uval := evalResult{typ: querypb.Type_UINT64} + fval := evalResult{typ: querypb.Type_FLOAT64} + + tcases := []struct { + v1, v2 evalResult + out1, out2 evalResult + }{{ + v1: ival, + v2: uval, + out1: uval, + out2: ival, + }, { + v1: ival, + v2: fval, + out1: fval, + out2: ival, + }, { + v1: uval, + v2: ival, + out1: uval, + out2: ival, + }, { + v1: uval, + v2: fval, + out1: fval, + out2: uval, + }, { + v1: fval, + v2: ival, + out1: fval, + out2: ival, + }, { + v1: fval, + v2: uval, + out1: fval, + out2: uval, + }} + for _, tcase := range tcases { + got1, got2 := prioritize(tcase.v1, tcase.v2) + if got1 != tcase.out1 || got2 != tcase.out2 { + t.Errorf("prioritize(%v, %v): (%v, %v) , want (%v, %v)", tcase.v1.typ, tcase.v2.typ, got1.typ, got2.typ, tcase.out1.typ, tcase.out2.typ) + } + } +} + +func TestCastFromNumeric(t *testing.T) { + tcases := []struct { + typ querypb.Type + v evalResult + out sqltypes.Value + err error + }{{ + typ: querypb.Type_INT64, + v: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: sqltypes.NewInt64(1), + }, { + typ: querypb.Type_INT64, + v: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: sqltypes.NewInt64(1), + }, { + typ: querypb.Type_INT64, + v: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + out: sqltypes.NewInt64(0), + }, { + typ: querypb.Type_UINT64, + v: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: sqltypes.NewUint64(1), + }, { + typ: querypb.Type_UINT64, + v: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: sqltypes.NewUint64(1), + }, { + typ: querypb.Type_UINT64, + v: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + out: sqltypes.NewUint64(0), + }, { + typ: querypb.Type_FLOAT64, + v: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: sqltypes.TestValue(querypb.Type_FLOAT64, "1"), + }, { + typ: querypb.Type_FLOAT64, + v: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: sqltypes.TestValue(querypb.Type_FLOAT64, "1"), + }, { + typ: querypb.Type_FLOAT64, + v: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + out: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2e-16"), + }, { + typ: querypb.Type_DECIMAL, + v: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: sqltypes.TestValue(querypb.Type_DECIMAL, "1"), + }, { + typ: querypb.Type_DECIMAL, + v: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: sqltypes.TestValue(querypb.Type_DECIMAL, "1"), + }, { + // For float, we should not use scientific notation. + typ: querypb.Type_DECIMAL, + v: evalResult{typ: querypb.Type_FLOAT64, fval: 1.2e-16}, + out: sqltypes.TestValue(querypb.Type_DECIMAL, "0.00000000000000012"), + }} + for _, tcase := range tcases { + got := castFromNumeric(tcase.v, tcase.typ) + + if !reflect.DeepEqual(got, tcase.out) { + t.Errorf("castFromNumeric(%v, %v): %v, want %v", tcase.v, tcase.typ, printValue(got), printValue(tcase.out)) + } + } +} + +func TestCompareNumeric(t *testing.T) { + tcases := []struct { + v1, v2 evalResult + out int + }{{ + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 2}, + v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: 1, + }, { + // Special case. + v1: evalResult{typ: querypb.Type_INT64, ival: -1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 2}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_INT64, ival: 2}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: 1, + }, { + // Special case. + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: -1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 2}, + v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 2}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_UINT64, uval: 2}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_INT64, ival: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + v2: evalResult{typ: querypb.Type_INT64, ival: 1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + v2: evalResult{typ: querypb.Type_UINT64, uval: 1}, + out: 1, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: 0, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + out: -1, + }, { + v1: evalResult{typ: querypb.Type_FLOAT64, fval: 2}, + v2: evalResult{typ: querypb.Type_FLOAT64, fval: 1}, + out: 1, + }} + for _, tcase := range tcases { + got := compareNumeric(tcase.v1, tcase.v2) + if got != tcase.out { + t.Errorf("equalNumeric(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) + } + } +} + +func TestMin(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + min sqltypes.Value + err error + }{{ + v1: sqltypes.NULL, + v2: sqltypes.NULL, + min: sqltypes.NULL, + }, { + v1: sqltypes.NewInt64(1), + v2: sqltypes.NULL, + min: sqltypes.NewInt64(1), + }, { + v1: sqltypes.NULL, + v2: sqltypes.NewInt64(1), + min: sqltypes.NewInt64(1), + }, { + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewInt64(2), + min: sqltypes.NewInt64(1), + }, { + v1: sqltypes.NewInt64(2), + v2: sqltypes.NewInt64(1), + min: sqltypes.NewInt64(1), + }, { + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewInt64(1), + min: sqltypes.NewInt64(1), + }, { + v1: sqltypes.TestValue(querypb.Type_VARCHAR, "aa"), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "aa"), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), + }} + for _, tcase := range tcases { + v, err := Min(tcase.v1, tcase.v2) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Min error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(v, tcase.min) { + t.Errorf("Min(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.min) + } + } +} + +func TestMax(t *testing.T) { + tcases := []struct { + v1, v2 sqltypes.Value + max sqltypes.Value + err error + }{{ + v1: sqltypes.NULL, + v2: sqltypes.NULL, + max: sqltypes.NULL, + }, { + v1: sqltypes.NewInt64(1), + v2: sqltypes.NULL, + max: sqltypes.NewInt64(1), + }, { + v1: sqltypes.NULL, + v2: sqltypes.NewInt64(1), + max: sqltypes.NewInt64(1), + }, { + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewInt64(2), + max: sqltypes.NewInt64(2), + }, { + v1: sqltypes.NewInt64(2), + v2: sqltypes.NewInt64(1), + max: sqltypes.NewInt64(2), + }, { + v1: sqltypes.NewInt64(1), + v2: sqltypes.NewInt64(1), + max: sqltypes.NewInt64(1), + }, { + v1: sqltypes.TestValue(querypb.Type_VARCHAR, "aa"), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "aa"), + err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), + }} + for _, tcase := range tcases { + v, err := Max(tcase.v1, tcase.v2) + if !vterrors.Equals(err, tcase.err) { + t.Errorf("Max error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) + } + if tcase.err != nil { + continue + } + + if !reflect.DeepEqual(v, tcase.max) { + t.Errorf("Max(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.max) + } + } +} + +func printValue(v sqltypes.Value) string { + return fmt.Sprintf("%v:%q", v.Type(), v.ToBytes()) +} + +// These benchmarks show that using existing ASCII representations +// for numbers is about 6x slower than using native representations. +// However, 229ns is still a negligible time compared to the cost of +// other operations. The additional complexity of introducing native +// types is currently not worth it. So, we'll stay with the existing +// ASCII representation for now. Using interfaces is more expensive +// than native representation of values. This is probably because +// interfaces also allocate memory, and also perform type assertions. +// Actual benchmark is based on NoNative. So, the numbers are similar. +// Date: 6/4/17 +// Version: go1.8 +// BenchmarkAddActual-8 10000000 263 ns/op +// BenchmarkAddNoNative-8 10000000 228 ns/op +// BenchmarkAddNative-8 50000000 40.0 ns/op +// BenchmarkAddGoInterface-8 30000000 52.4 ns/op +// BenchmarkAddGoNonInterface-8 2000000000 1.00 ns/op +// BenchmarkAddGo-8 2000000000 1.00 ns/op +func BenchmarkAddActual(b *testing.B) { + v1 := sqltypes.MakeTrusted(querypb.Type_INT64, []byte("1")) + v2 := sqltypes.MakeTrusted(querypb.Type_INT64, []byte("12")) + for i := 0; i < b.N; i++ { + v1 = NullsafeAdd(v1, v2, querypb.Type_INT64) + } +} + +func BenchmarkAddNoNative(b *testing.B) { + v1 := sqltypes.MakeTrusted(querypb.Type_INT64, []byte("1")) + v2 := sqltypes.MakeTrusted(querypb.Type_INT64, []byte("12")) + for i := 0; i < b.N; i++ { + iv1, _ := ToInt64(v1) + iv2, _ := ToInt64(v2) + v1 = sqltypes.MakeTrusted(querypb.Type_INT64, strconv.AppendInt(nil, iv1+iv2, 10)) + } +} + +func BenchmarkAddNative(b *testing.B) { + v1 := makeNativeInt64(1) + v2 := makeNativeInt64(12) + for i := 0; i < b.N; i++ { + iv1 := int64(binary.BigEndian.Uint64(v1.Raw())) + iv2 := int64(binary.BigEndian.Uint64(v2.Raw())) + v1 = makeNativeInt64(iv1 + iv2) + } +} + +func makeNativeInt64(v int64) sqltypes.Value { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(v)) + return sqltypes.MakeTrusted(querypb.Type_INT64, buf) +} + +func BenchmarkAddGoInterface(b *testing.B) { + var v1, v2 interface{} + v1 = int64(1) + v2 = int64(2) + for i := 0; i < b.N; i++ { + v1 = v1.(int64) + v2.(int64) + } +} + +func BenchmarkAddGoNonInterface(b *testing.B) { + v1 := evalResult{typ: querypb.Type_INT64, ival: 1} + v2 := evalResult{typ: querypb.Type_INT64, ival: 12} + for i := 0; i < b.N; i++ { + if v1.typ != querypb.Type_INT64 { + b.Error("type assertion failed") + } + if v2.typ != querypb.Type_INT64 { + b.Error("type assertion failed") + } + v1 = evalResult{typ: querypb.Type_INT64, ival: v1.ival + v2.ival} + } +} + +func BenchmarkAddGo(b *testing.B) { + v1 := int64(1) + v2 := int64(2) + for i := 0; i < b.N; i++ { + v1 += v2 + } +} diff --git a/go/vt/vtgate/evalengine/expressions.go b/go/vt/vtgate/evalengine/expressions.go new file mode 100644 index 00000000000..36748e8b0f7 --- /dev/null +++ b/go/vt/vtgate/evalengine/expressions.go @@ -0,0 +1,281 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "strconv" + + "vitess.io/vitess/go/sqltypes" + + querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) + +type ( + evalResult struct { + typ querypb.Type + ival int64 + uval uint64 + fval float64 + str string + } + //ExpressionEnv contains the environment that the expression + //evaluates in, such as the current row and bindvars + ExpressionEnv struct { + BindVars map[string]*querypb.BindVariable + Row []sqltypes.Value + } + + // EvalResult is used so we don't have to expose all parts of the private struct + EvalResult = evalResult + + // Expr is the interface that all evaluating expressions must implement + Expr interface { + Evaluate(env ExpressionEnv) (EvalResult, error) + Type(env ExpressionEnv) querypb.Type + String() string + } + + //BinaryExpr allows binary expressions to not have to evaluate child expressions - this is done by the BinaryOp + BinaryExpr interface { + Evaluate(left, right EvalResult) (EvalResult, error) + Type(left querypb.Type) querypb.Type + String() string + } + + // Expressions + LiteralInt struct{ Val EvalResult } + LiteralFloat struct{ Val EvalResult } + BindVariable struct{ Key string } + BinaryOp struct { + Expr BinaryExpr + Left, Right Expr + } + + // Binary ops + Addition struct{} + Subtraction struct{} + Multiplication struct{} + Division struct{} +) + +//Value allows for retrieval of the value we expose for public consumption +func (e EvalResult) Value() sqltypes.Value { + return castFromNumeric(e, e.typ) +} + +func NewLiteralInt(val []byte) (Expr, error) { + ival, err := strconv.ParseInt(string(val), 10, 64) + if err != nil { + return nil, err + } + return &LiteralFloat{evalResult{typ: sqltypes.Int64, ival: ival}}, nil +} + +func NewLiteralFloat(val []byte) (Expr, error) { + fval, err := strconv.ParseFloat(string(val), 64) + if err != nil { + return nil, err + } + return &LiteralFloat{evalResult{typ: sqltypes.Float64, fval: fval}}, nil +} + +var _ Expr = (*LiteralInt)(nil) +var _ Expr = (*LiteralFloat)(nil) +var _ Expr = (*BindVariable)(nil) +var _ Expr = (*BinaryOp)(nil) + +var _ BinaryExpr = (*Addition)(nil) +var _ BinaryExpr = (*Subtraction)(nil) +var _ BinaryExpr = (*Multiplication)(nil) +var _ BinaryExpr = (*Division)(nil) + +//Evaluate implements the Expr interface +func (b *BinaryOp) Evaluate(env ExpressionEnv) (EvalResult, error) { + lVal, err := b.Left.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + rVal, err := b.Right.Evaluate(env) + if err != nil { + return EvalResult{}, err + } + return b.Expr.Evaluate(lVal, rVal) +} + +//Evaluate implements the Expr interface +func (l *LiteralInt) Evaluate(ExpressionEnv) (EvalResult, error) { + return l.Val, nil +} + +func (l *LiteralFloat) Evaluate(env ExpressionEnv) (EvalResult, error) { + return l.Val, nil +} + +//Evaluate implements the Expr interface +func (b *BindVariable) Evaluate(env ExpressionEnv) (EvalResult, error) { + val, ok := env.BindVars[b.Key] + if !ok { + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Bind variable not found") + } + return evaluateByType(val) +} + +//Evaluate implements the BinaryOp interface +func (a *Addition) Evaluate(left, right EvalResult) (EvalResult, error) { + return addNumericWithError(left, right) +} + +//Evaluate implements the BinaryOp interface +func (s *Subtraction) Evaluate(left, right EvalResult) (EvalResult, error) { + return subtractNumericWithError(left, right) +} + +//Evaluate implements the BinaryOp interface +func (m *Multiplication) Evaluate(left, right EvalResult) (EvalResult, error) { + return multiplyNumericWithError(left, right) +} + +//Evaluate implements the BinaryOp interface +func (d *Division) Evaluate(left, right EvalResult) (EvalResult, error) { + return divideNumericWithError(left, right) +} + +//Type implements the BinaryExpr interface +func (a *Addition) Type(left querypb.Type) querypb.Type { + return left +} + +//Type implements the BinaryExpr interface +func (m *Multiplication) Type(left querypb.Type) querypb.Type { + return left +} + +//Type implements the BinaryExpr interface +func (d *Division) Type(querypb.Type) querypb.Type { + return sqltypes.Float64 +} + +//Type implements the BinaryExpr interface +func (s *Subtraction) Type(left querypb.Type) querypb.Type { + return left +} + +//Type implements the Expr interface +func (b *BinaryOp) Type(env ExpressionEnv) querypb.Type { + ltype := b.Left.Type(env) + rtype := b.Right.Type(env) + typ := mergeNumericalTypes(ltype, rtype) + return b.Expr.Type(typ) +} + +//Type implements the Expr interface +func (b *BindVariable) Type(env ExpressionEnv) querypb.Type { + e := env.BindVars + return e[b.Key].Type +} + +//Type implements the Expr interface +func (l *LiteralInt) Type(_ ExpressionEnv) querypb.Type { + return sqltypes.Int64 +} + +func (l *LiteralFloat) Type(env ExpressionEnv) querypb.Type { + return sqltypes.Float64 +} + +//String implements the BinaryExpr interface +func (d *Division) String() string { + return "/" +} + +//String implements the BinaryExpr interface +func (m *Multiplication) String() string { + return "*" +} + +//String implements the BinaryExpr interface +func (s *Subtraction) String() string { + return "-" +} + +//String implements the BinaryExpr interface +func (a *Addition) String() string { + return "+" +} + +//String implements the Expr interface +func (b *BinaryOp) String() string { + return b.Left.String() + " " + b.Expr.String() + " " + b.Right.String() +} + +//String implements the Expr interface +func (b *BindVariable) String() string { + return ":" + b.Key +} + +//String implements the Expr interface +func (l *LiteralInt) String() string { + return l.Val.Value().String() +} + +func (l *LiteralFloat) String() string { + return l.Val.Value().String() +} + +func mergeNumericalTypes(ltype, rtype querypb.Type) querypb.Type { + switch ltype { + case sqltypes.Int64: + if rtype == sqltypes.Uint64 || rtype == sqltypes.Float64 { + return rtype + } + case sqltypes.Uint64: + if rtype == sqltypes.Float64 { + return rtype + } + } + return ltype +} + +func evaluateByType(val *querypb.BindVariable) (EvalResult, error) { + switch val.Type { + case sqltypes.Int64: + ival, err := strconv.ParseInt(string(val.Value), 10, 64) + if err != nil { + ival = 0 + } + return evalResult{typ: sqltypes.Int64, ival: ival}, nil + case sqltypes.Uint64: + uval, err := strconv.ParseUint(string(val.Value), 10, 64) + if err != nil { + uval = 0 + } + return evalResult{typ: sqltypes.Uint64, uval: uval}, nil + case sqltypes.Float64: + fval, err := strconv.ParseFloat(string(val.Value), 64) + if err != nil { + fval = 0 + } + return evalResult{typ: sqltypes.Float64, fval: fval}, nil + case sqltypes.VarChar: + return evalResult{typ: sqltypes.VarChar, str: string(val.Value)}, nil + case sqltypes.VarBinary: + return evalResult{typ: sqltypes.VarBinary, str: string(val.Value)}, nil + } + return evalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Type is not supported") +} diff --git a/go/vt/vtgate/evalengine/expressions_test.go b/go/vt/vtgate/evalengine/expressions_test.go new file mode 100644 index 00000000000..0cfda38dde1 --- /dev/null +++ b/go/vt/vtgate/evalengine/expressions_test.go @@ -0,0 +1,107 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + + "vitess.io/vitess/go/sqltypes" + + querypb "vitess.io/vitess/go/vt/proto/query" +) + +// more tests in go/sqlparser/expressions_test.go + +func TestBinaryOpTypes(t *testing.T) { + type testcase struct { + l, r, e querypb.Type + } + type ops struct { + op BinaryExpr + testcases []testcase + } + + tests := []ops{ + { + op: &Addition{}, + testcases: []testcase{ + {sqltypes.Int64, sqltypes.Int64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Int64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Int64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Uint64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Uint64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Uint64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Float64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Float64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Float64, sqltypes.Float64}, + }, + }, { + op: &Subtraction{}, + testcases: []testcase{ + {sqltypes.Int64, sqltypes.Int64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Int64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Int64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Uint64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Uint64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Uint64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Float64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Float64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Float64, sqltypes.Float64}, + }, + }, { + op: &Multiplication{}, + testcases: []testcase{ + {sqltypes.Int64, sqltypes.Int64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Int64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Int64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Uint64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Uint64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Uint64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Float64, sqltypes.Int64}, + {sqltypes.Uint64, sqltypes.Float64, sqltypes.Uint64}, + {sqltypes.Float64, sqltypes.Float64, sqltypes.Float64}, + }, + }, { + op: &Division{}, + testcases: []testcase{ + {sqltypes.Int64, sqltypes.Int64, sqltypes.Float64}, + {sqltypes.Uint64, sqltypes.Int64, sqltypes.Float64}, + {sqltypes.Float64, sqltypes.Int64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Uint64, sqltypes.Float64}, + {sqltypes.Uint64, sqltypes.Uint64, sqltypes.Float64}, + {sqltypes.Float64, sqltypes.Uint64, sqltypes.Float64}, + {sqltypes.Int64, sqltypes.Float64, sqltypes.Float64}, + {sqltypes.Uint64, sqltypes.Float64, sqltypes.Float64}, + {sqltypes.Float64, sqltypes.Float64, sqltypes.Float64}, + }, + }, + } + + for _, op := range tests { + for _, tc := range op.testcases { + name := fmt.Sprintf("%s %s %s", tc.l.String(), reflect.TypeOf(op.op).String(), tc.r.String()) + t.Run(name, func(t *testing.T) { + result := op.op.Type(tc.l) + assert.Equal(t, tc.e, result) + }) + } + } +} diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index cac773aaaad..c8b90bd2ce3 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -222,42 +222,50 @@ func TestStreamBuffering(t *testing.T) { func TestSelectLastInsertId(t *testing.T) { masterSession.LastInsertId = 52 - executor, sbc1, _, _ := createExecutorEnv() + executor, _, _, _ := createExecutorEnv() executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) sql := "select last_insert_id()" - _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + masterSession.LastInsertId = 42 + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "last_insert_id()", Type: sqltypes.Uint64}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewUint64(42), + }}, + } require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__lastInsertId as `last_insert_id()` from dual", - BindVariables: map[string]*querypb.BindVariable{"__lastInsertId": sqltypes.Uint64BindVariable(52)}, - }} - - assert.Equal(t, wantQueries, sbc1.Queries) + utils.MustMatch(t, result, wantResult, "Mismatch") } -func TestSelectUserDefindVariable(t *testing.T) { - executor, sbc1, _, _ := createExecutorEnv() +func TestSelectUserDefinedVariable(t *testing.T) { + executor, _, _, _ := createExecutorEnv() executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) sql := "select @foo" masterSession = &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo"}, []interface{}{"bar"})} - _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__vtudvfoo as `@foo` from dual", - BindVariables: map[string]*querypb.BindVariable{"__vtudvfoo": sqltypes.BytesBindVariable([]byte("bar"))}, - }} - - assert.Equal(t, wantQueries, sbc1.Queries) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "@foo", Type: sqltypes.VarBinary}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewVarBinary("bar"), + }}, + } + require.NoError(t, err) + utils.MustMatch(t, result, wantResult, "Mismatch") } func TestFoundRows(t *testing.T) { - executor, sbc1, _, _ := createExecutorEnv() + executor, _, _, _ := createExecutorEnv() executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) @@ -267,14 +275,17 @@ func TestFoundRows(t *testing.T) { require.NoError(t, err) sql := "select found_rows()" - _, err = executorExec(executor, sql, map[string]*querypb.BindVariable{}) - require.NoError(t, err) - expected := &querypb.BoundQuery{ - Sql: "select :__vtfrows as `found_rows()` from dual", - BindVariables: map[string]*querypb.BindVariable{"__vtfrows": sqltypes.Uint64BindVariable(1)}, + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "found_rows()", Type: sqltypes.Uint64}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewUint64(0), + }}, } - - assert.Equal(t, expected, sbc1.Queries[1]) + require.NoError(t, err) + utils.MustMatch(t, result, wantResult, "Mismatch") } func TestSelectLastInsertIdInUnion(t *testing.T) { @@ -361,26 +372,29 @@ func TestLastInsertIDInSubQueryExpression(t *testing.T) { } func TestSelectDatabase(t *testing.T) { - executor, sbc1, _, _ := createExecutorEnv() + executor, _, _, _ := createExecutorEnv() executor.normalize = true sql := "select database()" newSession := *masterSession session := NewSafeSession(&newSession) session.TargetString = "TestExecutor@master" - _, err := executor.Execute( + result, err := executor.Execute( context.Background(), "TestExecute", session, sql, map[string]*querypb.BindVariable{}) - + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "database()", Type: sqltypes.VarBinary}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewVarBinary("TestExecutor"), + }}, + } require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__vtdbname as `database()` from dual", - BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestExecutor")}, - }} + utils.MustMatch(t, result, wantResult, "Mismatch") - assert.Equal(t, wantQueries, sbc1.Queries) } func TestSelectBindvars(t *testing.T) { diff --git a/go/vt/vtgate/plan_executor_select_test.go b/go/vt/vtgate/plan_executor_select_test.go index 72a2426c390..01d2a644909 100644 --- a/go/vt/vtgate/plan_executor_select_test.go +++ b/go/vt/vtgate/plan_executor_select_test.go @@ -22,6 +22,8 @@ import ( "strings" "testing" + "vitess.io/vitess/go/test/utils" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -221,42 +223,50 @@ func TestPlanStreamBuffering(t *testing.T) { func TestPlanSelectLastInsertId(t *testing.T) { masterSession.LastInsertId = 52 - executor, sbc1, _, _ := createExecutorEnvUsing(planAllTheThings) + executor, _, _, _ := createExecutorEnvUsing(planAllTheThings) executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) sql := "select last_insert_id()" - _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "last_insert_id()", Type: sqltypes.Uint64}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewUint64(52), + }}, + } require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__lastInsertId as `last_insert_id()` from dual", - BindVariables: map[string]*querypb.BindVariable{"__lastInsertId": sqltypes.Uint64BindVariable(52)}, - }} + utils.MustMatch(t, result, wantResult, "Mismatch") - assert.Equal(t, wantQueries, sbc1.Queries) } func TestPlanSelectUserDefindVariable(t *testing.T) { - executor, sbc1, _, _ := createExecutorEnvUsing(planAllTheThings) + executor, _, _, _ := createExecutorEnvUsing(planAllTheThings) executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) sql := "select @foo" masterSession = &vtgatepb.Session{UserDefinedVariables: createMap([]string{"foo"}, []interface{}{"bar"})} - _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__vtudvfoo as `@foo` from dual", - BindVariables: map[string]*querypb.BindVariable{"__vtudvfoo": sqltypes.BytesBindVariable([]byte("bar"))}, - }} - - assert.Equal(t, wantQueries, sbc1.Queries) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "@foo", Type: sqltypes.VarBinary}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewVarBinary("bar"), + }}, + } + require.NoError(t, err) + utils.MustMatch(t, result, wantResult, "Mismatch") } func TestPlanFoundRows(t *testing.T) { - executor, sbc1, _, _ := createExecutorEnvUsing(planAllTheThings) + executor, _, _, _ := createExecutorEnvUsing(planAllTheThings) executor.normalize = true logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) @@ -266,14 +276,18 @@ func TestPlanFoundRows(t *testing.T) { require.NoError(t, err) sql := "select found_rows()" - _, err = executorExec(executor, sql, map[string]*querypb.BindVariable{}) - require.NoError(t, err) - expected := &querypb.BoundQuery{ - Sql: "select :__vtfrows as `found_rows()` from dual", - BindVariables: map[string]*querypb.BindVariable{"__vtfrows": sqltypes.Uint64BindVariable(1)}, + result, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "found_rows()", Type: sqltypes.Uint64}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewUint64(0), + }}, } + require.NoError(t, err) + utils.MustMatch(t, result, wantResult, "Mismatch") - assert.Equal(t, expected, sbc1.Queries[1]) } func TestPlanSelectLastInsertIdInUnion(t *testing.T) { @@ -360,26 +374,29 @@ func TestPlanLastInsertIDInSubQueryExpression(t *testing.T) { } func TestPlanSelectDatabase(t *testing.T) { - executor, sbc1, _, _ := createExecutorEnvUsing(planAllTheThings) + executor, _, _, _ := createExecutorEnvUsing(planAllTheThings) executor.normalize = true sql := "select database()" newSession := *masterSession session := NewSafeSession(&newSession) session.TargetString = "TestExecutor@master" - _, err := executor.Execute( + result, err := executor.Execute( context.Background(), "TestExecute", session, sql, map[string]*querypb.BindVariable{}) - + wantResult := &sqltypes.Result{ + Fields: []*querypb.Field{ + {Name: "database()", Type: sqltypes.VarBinary}, + }, + Rows: [][]sqltypes.Value{{ + sqltypes.NewVarBinary("TestExecutor"), + }}, + } require.NoError(t, err) - wantQueries := []*querypb.BoundQuery{{ - Sql: "select :__vtdbname as `database()` from dual", - BindVariables: map[string]*querypb.BindVariable{"__vtdbname": sqltypes.StringBindVariable("TestExecutor")}, - }} + utils.MustMatch(t, result, wantResult, "Mismatch") - assert.Equal(t, wantQueries, sbc1.Queries) } func TestPlanSelectBindvars(t *testing.T) { diff --git a/go/vt/vtgate/planbuilder/select.go b/go/vt/vtgate/planbuilder/select.go index 98a38757fb6..78ebcf3cc92 100644 --- a/go/vt/vtgate/planbuilder/select.go +++ b/go/vt/vtgate/planbuilder/select.go @@ -20,6 +20,8 @@ import ( "errors" "fmt" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" ) @@ -27,6 +29,12 @@ import ( // buildSelectPlan is the new function to build a Select plan. func buildSelectPlan(stmt sqlparser.Statement, vschema ContextVSchema) (engine.Primitive, error) { sel := stmt.(*sqlparser.Select) + + p := tryAtVtgate(sel) + if p != nil { + return p, nil + } + pb := newPrimitiveBuilder(vschema, newJointab(sqlparser.GetBindvars(sel))) if err := pb.processSelect(sel, nil); err != nil { return nil, err @@ -121,6 +129,50 @@ func (pb *primitiveBuilder) processSelect(sel *sqlparser.Select, outer *symtab) return nil } +func tryAtVtgate(sel *sqlparser.Select) engine.Primitive { + if checkForDual(sel) { + var err error + exprs := make([]evalengine.Expr, len(sel.SelectExprs)) + cols := make([]string, len(sel.SelectExprs)) + for i, e := range sel.SelectExprs { + expr := e.(*sqlparser.AliasedExpr) + exprs[i], err = sqlparser.Convert(expr.Expr) + if err != nil { + return nil + } + cols[i] = expr.As.String() + } + return &engine.Projection{ + Exprs: exprs, + Cols: cols, + Input: &engine.SingleRow{}, + } + } + return nil +} + +func checkForDual(sel *sqlparser.Select) bool { + if len(sel.From) == 1 { + if from, ok := sel.From[0].(*sqlparser.AliasedTableExpr); ok { + if tableName, ok := from.Expr.(sqlparser.TableName); ok { + if tableName.Name.String() == "dual" && tableName.Qualifier.IsEmpty() { + for _, expr := range sel.SelectExprs { + e, ok := expr.(*sqlparser.AliasedExpr) + if !ok { + return false + } + if _, ok := e.Expr.(*sqlparser.SQLVal); !ok { + return false + } + } + return true + } + } + } + } + return false +} + // pushFilter identifies the target route for the specified bool expr, // pushes it down, and updates the route info if the new constraint improves // the primitive. This function can push to a WHERE or HAVING clause. diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.txt b/go/vt/vtgate/planbuilder/testdata/from_cases.txt index 8ec63f3ea51..be353439399 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.txt @@ -2049,15 +2049,20 @@ "QueryType": "SELECT", "Original": "select last_insert_id()", "Instructions": { - "OperatorType": "Route", - "Variant": "SelectReference", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "FieldQuery": "select :__lastInsertId as `last_insert_id()` from dual where 1 != 1", - "Query": "select :__lastInsertId as `last_insert_id()` from dual", - "Table": "dual" + "OperatorType": "Projection", + "Variant": "", + "Columns": [ + "last_insert_id()" + ], + "Expressions": [ + ":__lastInsertId" + ], + "Inputs": [ + { + "OperatorType": "SingleRow", + "Variant": "" + } + ] } } diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index 1573f501f99..c4bc3f25327 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -411,15 +411,20 @@ "QueryType": "SELECT", "Original": "select database() from dual", "Instructions": { - "OperatorType": "Route", - "Variant": "SelectReference", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "FieldQuery": "select :__vtdbname as `database()` from dual where 1 != 1", - "Query": "select :__vtdbname as `database()` from dual", - "Table": "dual" + "OperatorType": "Projection", + "Variant": "", + "Columns": [ + "database()" + ], + "Expressions": [ + ":__vtdbname" + ], + "Inputs": [ + { + "OperatorType": "SingleRow", + "Variant": "" + } + ] } } @@ -1350,3 +1355,26 @@ "Table": "unsharded" } } + +# testing SingleRow Projection +"select 42" +{ + "QueryType": "SELECT", + "Original": "select 42", + "Instructions": { + "OperatorType": "Projection", + "Variant": "", + "Columns": [ + "" + ], + "Expressions": [ + "INT64(42)" + ], + "Inputs": [ + { + "OperatorType": "SingleRow", + "Variant": "" + } + ] + } +} diff --git a/go/vt/vtgate/vindexes/consistent_lookup.go b/go/vt/vtgate/vindexes/consistent_lookup.go index 84072e75beb..1b7b3c5294e 100644 --- a/go/vt/vtgate/vindexes/consistent_lookup.go +++ b/go/vt/vtgate/vindexes/consistent_lookup.go @@ -22,6 +22,8 @@ import ( "fmt" "strings" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" querypb "vitess.io/vitess/go/vt/proto/query" @@ -306,7 +308,7 @@ func (lu *clCommon) Delete(vcursor VCursor, rowsColValues [][]sqltypes.Value, ks func (lu *clCommon) Update(vcursor VCursor, oldValues []sqltypes.Value, ksid []byte, newValues []sqltypes.Value) error { equal := true for i := range oldValues { - result, err := sqltypes.NullsafeCompare(oldValues[i], newValues[i]) + result, err := evalengine.NullsafeCompare(oldValues[i], newValues[i]) // errors from NullsafeCompare can be ignored. if they are real problems, we'll see them in the Create/Update if err != nil || result != 0 { equal = false diff --git a/go/vt/vtgate/vindexes/hash.go b/go/vt/vtgate/vindexes/hash.go index 21b93dfe884..9d68b167b56 100644 --- a/go/vt/vtgate/vindexes/hash.go +++ b/go/vt/vtgate/vindexes/hash.go @@ -25,6 +25,8 @@ import ( "fmt" "strconv" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/vterrors" @@ -83,7 +85,7 @@ func (vind *Hash) Map(cursor VCursor, ids []sqltypes.Value) ([]key.Destination, ival, err = strconv.ParseInt(str, 10, 64) num = uint64(ival) } else { - num, err = sqltypes.ToUint64(id) + num, err = evalengine.ToUint64(id) } if err != nil { @@ -99,7 +101,7 @@ func (vind *Hash) Map(cursor VCursor, ids []sqltypes.Value) ([]key.Destination, func (vind *Hash) Verify(_ VCursor, ids []sqltypes.Value, ksids [][]byte) ([]bool, error) { out := make([]bool, len(ids)) for i := range ids { - num, err := sqltypes.ToUint64(ids[i]) + num, err := evalengine.ToUint64(ids[i]) if err != nil { return nil, vterrors.Wrap(err, "hash.Verify") } diff --git a/go/vt/vtgate/vindexes/lookup_hash.go b/go/vt/vtgate/vindexes/lookup_hash.go index a9d2ec0102d..5787308e276 100644 --- a/go/vt/vtgate/vindexes/lookup_hash.go +++ b/go/vt/vtgate/vindexes/lookup_hash.go @@ -20,6 +20,8 @@ import ( "encoding/json" "fmt" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" topodatapb "vitess.io/vitess/go/vt/proto/topodata" @@ -119,7 +121,7 @@ func (lh *LookupHash) Map(vcursor VCursor, ids []sqltypes.Value) ([]key.Destinat } ksids := make([][]byte, 0, len(result.Rows)) for _, row := range result.Rows { - num, err := sqltypes.ToUint64(row[0]) + num, err := evalengine.ToUint64(row[0]) if err != nil { // A failure to convert is equivalent to not being // able to map. @@ -273,7 +275,7 @@ func (lhu *LookupHashUnique) Map(vcursor VCursor, ids []sqltypes.Value) ([]key.D case 0: out = append(out, key.DestinationNone{}) case 1: - num, err := sqltypes.ToUint64(result.Rows[0][0]) + num, err := evalengine.ToUint64(result.Rows[0][0]) if err != nil { out = append(out, key.DestinationNone{}) continue diff --git a/go/vt/vtgate/vindexes/lookup_unicodeloosemd5_hash.go b/go/vt/vtgate/vindexes/lookup_unicodeloosemd5_hash.go index fccec4ec110..000b6d21608 100644 --- a/go/vt/vtgate/vindexes/lookup_unicodeloosemd5_hash.go +++ b/go/vt/vtgate/vindexes/lookup_unicodeloosemd5_hash.go @@ -21,6 +21,8 @@ import ( "encoding/json" "fmt" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" topodatapb "vitess.io/vitess/go/vt/proto/topodata" @@ -124,7 +126,7 @@ func (lh *LookupUnicodeLooseMD5Hash) Map(vcursor VCursor, ids []sqltypes.Value) } ksids := make([][]byte, 0, len(result.Rows)) for _, row := range result.Rows { - num, err := sqltypes.ToUint64(row[0]) + num, err := evalengine.ToUint64(row[0]) if err != nil { // A failure to convert is equivalent to not being // able to map. @@ -289,7 +291,7 @@ func (lhu *LookupUnicodeLooseMD5HashUnique) Map(vcursor VCursor, ids []sqltypes. case 0: out = append(out, key.DestinationNone{}) case 1: - num, err := sqltypes.ToUint64(result.Rows[0][0]) + num, err := evalengine.ToUint64(result.Rows[0][0]) if err != nil { out = append(out, key.DestinationNone{}) continue diff --git a/go/vt/vtgate/vindexes/numeric.go b/go/vt/vtgate/vindexes/numeric.go index 3b34b8a2298..e031031048d 100644 --- a/go/vt/vtgate/vindexes/numeric.go +++ b/go/vt/vtgate/vindexes/numeric.go @@ -21,6 +21,8 @@ import ( "encoding/binary" "fmt" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/vterrors" @@ -67,7 +69,7 @@ func (*Numeric) Verify(_ VCursor, ids []sqltypes.Value, ksids [][]byte) ([]bool, out := make([]bool, len(ids)) for i := range ids { var keybytes [8]byte - num, err := sqltypes.ToUint64(ids[i]) + num, err := evalengine.ToUint64(ids[i]) if err != nil { return nil, vterrors.Wrap(err, "Numeric.Verify") } @@ -81,7 +83,7 @@ func (*Numeric) Verify(_ VCursor, ids []sqltypes.Value, ksids [][]byte) ([]bool, func (*Numeric) Map(cursor VCursor, ids []sqltypes.Value) ([]key.Destination, error) { out := make([]key.Destination, 0, len(ids)) for _, id := range ids { - num, err := sqltypes.ToUint64(id) + num, err := evalengine.ToUint64(id) if err != nil { out = append(out, key.DestinationNone{}) continue diff --git a/go/vt/vtgate/vindexes/numeric_static_map.go b/go/vt/vtgate/vindexes/numeric_static_map.go index 6a1ac7ecbc1..e4501a2f923 100644 --- a/go/vt/vtgate/vindexes/numeric_static_map.go +++ b/go/vt/vtgate/vindexes/numeric_static_map.go @@ -24,6 +24,8 @@ import ( "io/ioutil" "strconv" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/vterrors" @@ -90,7 +92,7 @@ func (vind *NumericStaticMap) Verify(_ VCursor, ids []sqltypes.Value, ksids [][] out := make([]bool, len(ids)) for i := range ids { var keybytes [8]byte - num, err := sqltypes.ToUint64(ids[i]) + num, err := evalengine.ToUint64(ids[i]) if err != nil { return nil, vterrors.Wrap(err, "NumericStaticMap.Verify") } @@ -108,7 +110,7 @@ func (vind *NumericStaticMap) Verify(_ VCursor, ids []sqltypes.Value, ksids [][] func (vind *NumericStaticMap) Map(cursor VCursor, ids []sqltypes.Value) ([]key.Destination, error) { out := make([]key.Destination, 0, len(ids)) for _, id := range ids { - num, err := sqltypes.ToUint64(id) + num, err := evalengine.ToUint64(id) if err != nil { out = append(out, key.DestinationNone{}) continue diff --git a/go/vt/vtgate/vindexes/region_experimental.go b/go/vt/vtgate/vindexes/region_experimental.go index 50260b08d45..7a072dd70e6 100644 --- a/go/vt/vtgate/vindexes/region_experimental.go +++ b/go/vt/vtgate/vindexes/region_experimental.go @@ -21,6 +21,8 @@ import ( "encoding/binary" "fmt" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" ) @@ -94,7 +96,7 @@ func (ge *RegionExperimental) Map(vcursor VCursor, rowsColValues [][]sqltypes.Va continue } // Compute region prefix. - rn, err := sqltypes.ToUint64(row[0]) + rn, err := evalengine.ToUint64(row[0]) if err != nil { destinations = append(destinations, key.DestinationNone{}) continue @@ -103,7 +105,7 @@ func (ge *RegionExperimental) Map(vcursor VCursor, rowsColValues [][]sqltypes.Va binary.BigEndian.PutUint16(r, uint16(rn)) // Compute hash. - hn, err := sqltypes.ToUint64(row[1]) + hn, err := evalengine.ToUint64(row[1]) if err != nil { destinations = append(destinations, key.DestinationNone{}) continue diff --git a/go/vt/vtgate/vindexes/region_json.go b/go/vt/vtgate/vindexes/region_json.go index 8f934e941ad..9720819c565 100644 --- a/go/vt/vtgate/vindexes/region_json.go +++ b/go/vt/vtgate/vindexes/region_json.go @@ -24,6 +24,8 @@ import ( "io/ioutil" "strconv" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/log" @@ -108,7 +110,7 @@ func (rv *RegionJSON) Map(vcursor VCursor, rowsColValues [][]sqltypes.Value) ([] continue } // Compute hash. - hn, err := sqltypes.ToUint64(row[0]) + hn, err := evalengine.ToUint64(row[0]) if err != nil { destinations = append(destinations, key.DestinationNone{}) continue diff --git a/go/vt/vtgate/vindexes/reverse_bits.go b/go/vt/vtgate/vindexes/reverse_bits.go index d79d1f8a23f..9052a4c94ee 100644 --- a/go/vt/vtgate/vindexes/reverse_bits.go +++ b/go/vt/vtgate/vindexes/reverse_bits.go @@ -23,6 +23,8 @@ import ( "fmt" "math/bits" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/vterrors" @@ -68,7 +70,7 @@ func (vind *ReverseBits) NeedsVCursor() bool { func (vind *ReverseBits) Map(cursor VCursor, ids []sqltypes.Value) ([]key.Destination, error) { out := make([]key.Destination, len(ids)) for i, id := range ids { - num, err := sqltypes.ToUint64(id) + num, err := evalengine.ToUint64(id) if err != nil { out[i] = key.DestinationNone{} continue @@ -82,7 +84,7 @@ func (vind *ReverseBits) Map(cursor VCursor, ids []sqltypes.Value) ([]key.Destin func (vind *ReverseBits) Verify(_ VCursor, ids []sqltypes.Value, ksids [][]byte) ([]bool, error) { out := make([]bool, len(ids)) for i := range ids { - num, err := sqltypes.ToUint64(ids[i]) + num, err := evalengine.ToUint64(ids[i]) if err != nil { return nil, vterrors.Wrap(err, "reverseBits.Verify") } diff --git a/go/vt/vttablet/heartbeat/reader.go b/go/vt/vttablet/heartbeat/reader.go index dd1deeec79c..b17e9d0a235 100644 --- a/go/vt/vttablet/heartbeat/reader.go +++ b/go/vt/vttablet/heartbeat/reader.go @@ -21,6 +21,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vterrors" "golang.org/x/net/context" @@ -205,7 +207,7 @@ func parseHeartbeatResult(res *sqltypes.Result) (int64, error) { if len(res.Rows) != 1 { return 0, fmt.Errorf("failed to read heartbeat: writer query did not result in 1 row. Got %v", len(res.Rows)) } - ts, err := sqltypes.ToInt64(res.Rows[0][0]) + ts, err := evalengine.ToInt64(res.Rows[0][0]) if err != nil { return 0, err } diff --git a/go/vt/vttablet/tabletmanager/vreplication/engine.go b/go/vt/vttablet/tabletmanager/vreplication/engine.go index 9623563cc38..1c77300455c 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/engine.go +++ b/go/vt/vttablet/tabletmanager/vreplication/engine.go @@ -24,6 +24,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" @@ -392,7 +394,7 @@ func (vre *Engine) fetchIDs(dbClient binlogplayer.DBClient, selector string) (id return nil, nil, err } for _, row := range qr.Rows { - id, err := sqltypes.ToInt64(row[0]) + id, err := evalengine.ToInt64(row[0]) if err != nil { return nil, nil, err } diff --git a/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go b/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go index e22dc9638a1..c16cbde24dd 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go @@ -21,6 +21,8 @@ import ( "strings" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" @@ -208,7 +210,7 @@ func (vr *vreplicator) readSettings(ctx context.Context) (settings binlogplayer. if len(qr.Rows) == 0 || len(qr.Rows[0]) == 0 { return settings, numTablesToCopy, fmt.Errorf("unexpected result from %s: %v", query, qr) } - numTablesToCopy, err = sqltypes.ToInt64(qr.Rows[0][0]) + numTablesToCopy, err = evalengine.ToInt64(qr.Rows[0][0]) if err != nil { return settings, numTablesToCopy, err } diff --git a/go/vt/vttablet/tabletserver/messager/message_manager.go b/go/vt/vttablet/tabletserver/messager/message_manager.go index 2412eb601f4..7a0d5825771 100644 --- a/go/vt/vttablet/tabletserver/messager/message_manager.go +++ b/go/vt/vttablet/tabletserver/messager/message_manager.go @@ -22,6 +22,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" @@ -829,28 +831,28 @@ func (mm *messageManager) GeneratePurgeQuery(timeCutoff int64) (string, map[stri func BuildMessageRow(row []sqltypes.Value) (*MessageRow, error) { mr := &MessageRow{Row: row[4:]} if !row[0].IsNull() { - v, err := sqltypes.ToInt64(row[0]) + v, err := evalengine.ToInt64(row[0]) if err != nil { return nil, err } mr.Priority = v } if !row[1].IsNull() { - v, err := sqltypes.ToInt64(row[0]) + v, err := evalengine.ToInt64(row[0]) if err != nil { return nil, err } mr.TimeNext = v } if !row[2].IsNull() { - v, err := sqltypes.ToInt64(row[1]) + v, err := evalengine.ToInt64(row[1]) if err != nil { return nil, err } mr.Epoch = v } if !row[3].IsNull() { - v, err := sqltypes.ToInt64(row[2]) + v, err := evalengine.ToInt64(row[2]) if err != nil { return nil, err } diff --git a/go/vt/vttablet/tabletserver/messager/message_manager_test.go b/go/vt/vttablet/tabletserver/messager/message_manager_test.go index b9ee7a7fd73..0f6bbd44def 100644 --- a/go/vt/vttablet/tabletserver/messager/message_manager_test.go +++ b/go/vt/vttablet/tabletserver/messager/message_manager_test.go @@ -26,6 +26,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/test/utils" "github.com/stretchr/testify/assert" @@ -741,7 +743,7 @@ func TestMMGenerate(t *testing.T) { t.Errorf("GenerateAckQuery query: %s, want %s", query, wantQuery) } bvv, _ := sqltypes.BindVariableToValue(bv["time_acked"]) - gotAcked, _ := sqltypes.ToInt64(bvv) + gotAcked, _ := evalengine.ToInt64(bvv) wantAcked := time.Now().UnixNano() if wantAcked-gotAcked > 10e9 { t.Errorf("gotAcked: %d, should be with 10s of %d", gotAcked, wantAcked) diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index 91e51990bcd..ec6a3afe8c7 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -22,6 +22,8 @@ import ( "strings" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" @@ -400,7 +402,7 @@ func (qre *QueryExecutor) execNextval() (*sqltypes.Result, error) { if len(qr.Rows) != 1 { return nil, fmt.Errorf("unexpected rows from reading sequence %s (possible mis-route): %d", tableName, len(qr.Rows)) } - nextID, err := sqltypes.ToInt64(qr.Rows[0][0]) + nextID, err := evalengine.ToInt64(qr.Rows[0][0]) if err != nil { return nil, vterrors.Wrapf(err, "error loading sequence %s", tableName) } @@ -415,7 +417,7 @@ func (qre *QueryExecutor) execNextval() (*sqltypes.Result, error) { t.SequenceInfo.NextVal = nextID t.SequenceInfo.LastVal = nextID } - cache, err := sqltypes.ToInt64(qr.Rows[0][1]) + cache, err := evalengine.ToInt64(qr.Rows[0][1]) if err != nil { return nil, vterrors.Wrapf(err, "error loading sequence %s", tableName) } @@ -688,5 +690,5 @@ func resolveNumber(pv sqltypes.PlanValue, bindVars map[string]*querypb.BindVaria if err != nil { return 0, err } - return sqltypes.ToInt64(v) + return evalengine.ToInt64(v) } diff --git a/go/vt/vttablet/tabletserver/rules/rules.go b/go/vt/vttablet/tabletserver/rules/rules.go index b641b2e55ce..8dd25ef7450 100644 --- a/go/vt/vttablet/tabletserver/rules/rules.go +++ b/go/vt/vttablet/tabletserver/rules/rules.go @@ -24,6 +24,8 @@ import ( "regexp" "strconv" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tabletserver/planbuilder" @@ -732,7 +734,7 @@ func getuint64(val *querypb.BindVariable) (uv uint64, status int) { if err != nil { return 0, QROutOfRange } - v, err := sqltypes.ToUint64(bv) + v, err := evalengine.ToUint64(bv) if err != nil { return 0, QROutOfRange } @@ -745,7 +747,7 @@ func getint64(val *querypb.BindVariable) (iv int64, status int) { if err != nil { return 0, QROutOfRange } - v, err := sqltypes.ToInt64(bv) + v, err := evalengine.ToInt64(bv) if err != nil { return 0, QROutOfRange } diff --git a/go/vt/vttablet/tabletserver/schema/engine.go b/go/vt/vttablet/tabletserver/schema/engine.go index 57b4400b794..d5fa83444f1 100644 --- a/go/vt/vttablet/tabletserver/schema/engine.go +++ b/go/vt/vttablet/tabletserver/schema/engine.go @@ -23,11 +23,12 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/acl" "vitess.io/vitess/go/mysql" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/timer" "vitess.io/vitess/go/vt/concurrency" "vitess.io/vitess/go/vt/dbconfigs" @@ -211,7 +212,7 @@ func (se *Engine) reload(ctx context.Context) error { for _, row := range tableData.Rows { tableName := row[0].ToString() curTables[tableName] = true - createTime, _ := sqltypes.ToInt64(row[2]) + createTime, _ := evalengine.ToInt64(row[2]) if _, ok := se.tables[tableName]; ok && createTime < se.lastChange { continue } @@ -266,7 +267,7 @@ func (se *Engine) mysqlTime(ctx context.Context, conn *connpool.DBConn) (int64, if len(tm.Rows) != 1 || len(tm.Rows[0]) != 1 || tm.Rows[0][0].IsNull() { return 0, vterrors.Errorf(vtrpcpb.Code_UNKNOWN, "unexpected result for MySQL time: %+v", tm.Rows) } - t, err := sqltypes.ToInt64(tm.Rows[0][0]) + t, err := evalengine.ToInt64(tm.Rows[0][0]) if err != nil { return 0, vterrors.Errorf(vtrpcpb.Code_UNKNOWN, "could not parse time %v: %v", tm, err) } diff --git a/go/vt/vttablet/tabletserver/twopc.go b/go/vt/vttablet/tabletserver/twopc.go index 8b8c87fb661..76c30ef9f63 100644 --- a/go/vt/vttablet/tabletserver/twopc.go +++ b/go/vt/vttablet/tabletserver/twopc.go @@ -20,6 +20,8 @@ import ( "fmt" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/sqlescape" @@ -288,12 +290,12 @@ func (tpc *TwoPC) ReadAllRedo(ctx context.Context) (prepared, failed []*Prepared // Initialize the new element. // A failure in time parsing will show up as a very old time, // which is harmless. - tm, _ := sqltypes.ToInt64(row[2]) + tm, _ := evalengine.ToInt64(row[2]) curTx = &PreparedTx{ Dtid: dtid, Time: time.Unix(0, tm), } - st, err := sqltypes.ToInt64(row[1]) + st, err := evalengine.ToInt64(row[1]) if err != nil { log.Errorf("Error parsing state for dtid %s: %v.", dtid, err) } @@ -330,7 +332,7 @@ func (tpc *TwoPC) CountUnresolvedRedo(ctx context.Context, unresolvedTime time.T if len(qr.Rows) < 1 { return 0, nil } - v, _ := sqltypes.ToInt64(qr.Rows[0][0]) + v, _ := evalengine.ToInt64(qr.Rows[0][0]) return v, nil } @@ -417,7 +419,7 @@ func (tpc *TwoPC) ReadTransaction(ctx context.Context, dtid string) (*querypb.Tr return result, nil } result.Dtid = qr.Rows[0][0].ToString() - st, err := sqltypes.ToInt64(qr.Rows[0][1]) + st, err := evalengine.ToInt64(qr.Rows[0][1]) if err != nil { return nil, vterrors.Wrapf(err, "Error parsing state for dtid %s", dtid) } @@ -427,7 +429,7 @@ func (tpc *TwoPC) ReadTransaction(ctx context.Context, dtid string) (*querypb.Tr } // A failure in time parsing will show up as a very old time, // which is harmless. - tm, _ := sqltypes.ToInt64(qr.Rows[0][2]) + tm, _ := evalengine.ToInt64(qr.Rows[0][2]) result.TimeCreated = tm qr, err = tpc.read(ctx, conn, tpc.readParticipants, bindVars) @@ -464,7 +466,7 @@ func (tpc *TwoPC) ReadAbandoned(ctx context.Context, abandonTime time.Time) (map } txs := make(map[string]time.Time, len(qr.Rows)) for _, row := range qr.Rows { - t, err := sqltypes.ToInt64(row[1]) + t, err := evalengine.ToInt64(row[1]) if err != nil { return nil, err } @@ -503,8 +505,8 @@ func (tpc *TwoPC) ReadAllTransactions(ctx context.Context) ([]*DistributedTx, er // Initialize the new element. // A failure in time parsing will show up as a very old time, // which is harmless. - tm, _ := sqltypes.ToInt64(row[2]) - st, err := sqltypes.ToInt64(row[1]) + tm, _ := evalengine.ToInt64(row[2]) + st, err := evalengine.ToInt64(row[1]) // Just log on error and continue. The state will show up as UNKNOWN // on the display. if err != nil { diff --git a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go index c87c5cb462d..d10a73488f0 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go +++ b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go @@ -21,6 +21,8 @@ import ( "regexp" "strings" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" @@ -115,7 +117,7 @@ func (plan *Plan) filter(values []sqltypes.Value) (bool, []sqltypes.Value, error for _, filter := range plan.Filters { switch filter.Opcode { case Equal: - result, err := sqltypes.NullsafeCompare(values[filter.ColNum], filter.Value) + result, err := evalengine.NullsafeCompare(values[filter.ColNum], filter.Value) if err != nil { return false, nil, err } diff --git a/go/vt/worker/chunk.go b/go/vt/worker/chunk.go index 89abd95383e..50c5cc5b0e8 100644 --- a/go/vt/worker/chunk.go +++ b/go/vt/worker/chunk.go @@ -19,6 +19,8 @@ package worker import ( "fmt" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -105,8 +107,8 @@ func generateChunks(ctx context.Context, wr *wrangler.Wrangler, tablet *topodata } result := sqltypes.Proto3ToResult(qr) - min, _ := sqltypes.ToNative(result.Rows[0][0]) - max, _ := sqltypes.ToNative(result.Rows[0][1]) + min, _ := evalengine.ToNative(result.Rows[0][0]) + max, _ := evalengine.ToNative(result.Rows[0][1]) if min == nil || max == nil { wr.Logger().Infof("table=%v: Not splitting the table into multiple chunks, min or max is NULL: %v", td.Name, qr.Rows[0]) diff --git a/go/vt/worker/diff_utils.go b/go/vt/worker/diff_utils.go index 132fed84575..b73785a037b 100644 --- a/go/vt/worker/diff_utils.go +++ b/go/vt/worker/diff_utils.go @@ -25,6 +25,8 @@ import ( "strings" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tmclient" @@ -480,8 +482,8 @@ func RowsEqual(left, right []sqltypes.Value) int { // TODO: This can panic if types for left and right don't match. func CompareRows(fields []*querypb.Field, compareCount int, left, right []sqltypes.Value) (int, error) { for i := 0; i < compareCount; i++ { - lv, _ := sqltypes.ToNative(left[i]) - rv, _ := sqltypes.ToNative(right[i]) + lv, _ := evalengine.ToNative(left[i]) + rv, _ := evalengine.ToNative(right[i]) switch l := lv.(type) { case int64: r := rv.(int64) diff --git a/go/vt/worker/key_resolver.go b/go/vt/worker/key_resolver.go index a79fb1b66f6..b58fa09087f 100644 --- a/go/vt/worker/key_resolver.go +++ b/go/vt/worker/key_resolver.go @@ -19,6 +19,7 @@ package worker import ( "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/sqltypes" @@ -78,7 +79,7 @@ func (r *v2Resolver) keyspaceID(row []sqltypes.Value) ([]byte, error) { case topodatapb.KeyspaceIdType_BYTES: return v.ToBytes(), nil case topodatapb.KeyspaceIdType_UINT64: - i, err := sqltypes.ToUint64(v) + i, err := evalengine.ToUint64(v) if err != nil { return nil, vterrors.Wrap(err, "Non numerical value") } diff --git a/go/vt/worker/split_clone_flaky_test.go b/go/vt/worker/split_clone_flaky_test.go index 716741da36d..0bf8336e4ee 100644 --- a/go/vt/worker/split_clone_flaky_test.go +++ b/go/vt/worker/split_clone_flaky_test.go @@ -26,6 +26,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "golang.org/x/net/context" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/mysql/fakesqldb" @@ -391,7 +393,7 @@ func (sq *testQueryService) StreamExecute(ctx context.Context, target *querypb.T // Send the values. rowsAffected := 0 for _, row := range sq.rows { - v, _ := sqltypes.ToNative(row[0]) + v, _ := evalengine.ToNative(row[0]) primaryKey := v.(int64) if primaryKey >= int64(min) && primaryKey < int64(max) { diff --git a/go/vt/wrangler/materializer.go b/go/vt/wrangler/materializer.go index 43213bd7e4d..4cab0c0a9c3 100644 --- a/go/vt/wrangler/materializer.go +++ b/go/vt/wrangler/materializer.go @@ -22,6 +22,8 @@ import ( "sync" "text/template" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/gogo/protobuf/proto" "golang.org/x/net/context" @@ -476,7 +478,7 @@ func (wr *Wrangler) ExternalizeVindex(ctx context.Context, qualifiedVindexName s } qr := sqltypes.Proto3ToResult(p3qr) for _, row := range qr.Rows { - id, err := sqltypes.ToInt64(row[0]) + id, err := evalengine.ToInt64(row[0]) if err != nil { return err } diff --git a/go/vt/wrangler/stream_migrater.go b/go/vt/wrangler/stream_migrater.go index 963f6e711cb..52185107d5d 100644 --- a/go/vt/wrangler/stream_migrater.go +++ b/go/vt/wrangler/stream_migrater.go @@ -23,6 +23,8 @@ import ( "sync" "text/template" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/golang/protobuf/proto" "golang.org/x/net/context" "vitess.io/vitess/go/mysql" @@ -225,7 +227,7 @@ func (sm *streamMigrater) readTabletStreams(ctx context.Context, ti *topo.Tablet tabletStreams := make([]*vrStream, 0, len(qr.Rows)) for _, row := range qr.Rows { - id, err := sqltypes.ToInt64(row[0]) + id, err := evalengine.ToInt64(row[0]) if err != nil { return nil, err } diff --git a/go/vt/wrangler/traffic_switcher.go b/go/vt/wrangler/traffic_switcher.go index 4cef3aa9c9b..d400a61d1f6 100644 --- a/go/vt/wrangler/traffic_switcher.go +++ b/go/vt/wrangler/traffic_switcher.go @@ -26,6 +26,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/log" "github.com/golang/protobuf/proto" @@ -402,7 +404,7 @@ func (wr *Wrangler) buildTargets(ctx context.Context, targetKeyspace, workflow s } qr := sqltypes.Proto3ToResult(p3qr) for _, row := range qr.Rows { - id, err := sqltypes.ToInt64(row[0]) + id, err := evalengine.ToInt64(row[0]) if err != nil { return nil, false, err } diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 1d113cd0cab..7a635533891 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -23,6 +23,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "github.com/golang/protobuf/proto" "golang.org/x/net/context" "vitess.io/vitess/go/mysql" @@ -830,7 +832,7 @@ func (td *tableDiffer) compare(sourceRow, targetRow []sqltypes.Value, cols []int if col == -1 { continue } - c, err := sqltypes.NullsafeCompare(sourceRow[col], targetRow[col]) + c, err := evalengine.NullsafeCompare(sourceRow[col], targetRow[col]) if err != nil { return 0, err }