Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions go/sqltypes/bind_variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ func BuildBindVariables(in map[string]interface{}) (map[string]*querypb.BindVari
return out, nil
}

// HexNumBindVariable converts bytes representing a hex number to a bind var.
func HexNumBindVariable(v []byte) *querypb.BindVariable {
return ValueBindVariable(NewHexNum(v))
}

// HexValBindVariable converts bytes representing a hex encoded string to a bind var.
func HexValBindVariable(v []byte) *querypb.BindVariable {
return ValueBindVariable(NewHexVal(v))
}

// Int8BindVariable converts an int8 to a bind var.
func Int8BindVariable(v int8) *querypb.BindVariable {
return ValueBindVariable(NewInt8(v))
Expand Down
57 changes: 57 additions & 0 deletions go/sqltypes/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package sqltypes

import (
"crypto/sha256"
"fmt"
"reflect"

"github.com/golang/protobuf/proto"
Expand Down Expand Up @@ -175,6 +177,61 @@ func ResultsEqual(r1, r2 []Result) bool {
return true
}

// ResultsEqualUnordered compares two unordered arrays of Result.
func ResultsEqualUnordered(r1, r2 []Result) bool {
if len(r1) != len(r2) {
return false
}

// allRows is a hash map that contains a row hashed as a key and
// the number of occurrence as the value. we use this map to ensure
// equality between the two result sets. when analyzing r1, we
// increment each key's value by one for each row's occurrence, and
// then we decrement it by one each time we see the same key in r2.
// if one of the key's value is not equal to zero, then r1 and r2 do
// not match.
allRows := map[string]int{}
countRows := 0
for _, r := range r1 {
saveRowsAnalysis(r, allRows, &countRows, true)
}
for _, r := range r2 {
saveRowsAnalysis(r, allRows, &countRows, false)
}
if countRows != 0 {
return false
}
for _, i := range allRows {
if i != 0 {
return false
}
}
return true
}

func saveRowsAnalysis(r Result, allRows map[string]int, totalRows *int, increment bool) {
for _, row := range r.Rows {
newHash := hashCodeForRow(row)
if increment {
allRows[newHash]++
} else {
allRows[newHash]--
}
}
if increment {
*totalRows += int(r.RowsAffected)
} else {
*totalRows -= int(r.RowsAffected)
}
}

func hashCodeForRow(val []Value) string {
h := sha256.New()
h.Write([]byte(fmt.Sprintf("%v", val)))

return fmt.Sprintf("%x", h.Sum(nil))
}

// MakeRowTrusted converts a *querypb.Row to []Value based on the types
// in fields. It does not sanity check the values against the type.
// Every place this function is called, a comment is needed that explains
Expand Down
2 changes: 2 additions & 0 deletions go/sqltypes/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ const (
Geometry = querypb.Type_GEOMETRY
TypeJSON = querypb.Type_JSON
Expression = querypb.Type_EXPRESSION
HexNum = querypb.Type_HEXNUM
HexVal = querypb.Type_HEXVAL
)

// bit-shift the mysql flags by two byte so we
Expand Down
10 changes: 9 additions & 1 deletion go/sqltypes/type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ func TestTypeValues(t *testing.T) {
}, {
defined: Expression,
expected: 31,
}, {
defined: HexNum,
expected: 32 | flagIsText,
}, {
defined: HexVal,
expected: 33 | flagIsText,
}}
for _, tcase := range testcases {
if int(tcase.defined) != tcase.expected {
Expand Down Expand Up @@ -162,6 +168,8 @@ func TestCategory(t *testing.T) {
Geometry,
TypeJSON,
Expression,
HexNum,
HexVal,
}
for _, typ := range alltypes {
matched := false
Expand Down Expand Up @@ -192,7 +200,7 @@ func TestCategory(t *testing.T) {
}
matched = true
}
if typ == Null || typ == Decimal || typ == Expression || typ == Bit {
if typ == Null || typ == Decimal || typ == Expression || typ == Bit || typ == HexNum || typ == HexVal {
if matched {
t.Errorf("%v matched more than one category", typ)
}
Expand Down
63 changes: 62 additions & 1 deletion go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,21 @@ package sqltypes

import (
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"regexp"
"strconv"
"strings"

"vitess.io/vitess/go/bytes2"
"vitess.io/vitess/go/hack"

"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

var (
Expand Down Expand Up @@ -79,7 +84,7 @@ func NewValue(typ querypb.Type, val []byte) (v Value, err error) {
return NULL, err
}
return MakeTrusted(typ, val), nil
case IsQuoted(typ) || typ == Bit || typ == Null:
case IsQuoted(typ) || typ == Bit || typ == HexNum || typ == HexVal || typ == Null:
return MakeTrusted(typ, val), nil
}
// All other types are unsafe or invalid.
Expand All @@ -102,6 +107,16 @@ func MakeTrusted(typ querypb.Type, val []byte) Value {
return Value{typ: typ, val: val}
}

// NewHexNum builds an Hex Value.
func NewHexNum(v []byte) Value {
return MakeTrusted(HexNum, v)
}

// NewHexVal builds a HexVal Value.
func NewHexVal(v []byte) Value {
return MakeTrusted(HexVal, v)
}

// NewInt64 builds an Int64 Value.
func NewInt64(v int64) Value {
return MakeTrusted(Int64, strconv.AppendInt(nil, v, 10))
Expand Down Expand Up @@ -200,6 +215,20 @@ func (v Value) ToBytes() []byte {
if v.typ == Expression {
return nil
}
if v.typ == HexVal {
dv, err := v.decodeHexVal()
if err != nil {
log.Errorf("Unexpected error seen when returning MySQL representation of SQL Hex value: %v", err)
}
return dv
}
if v.typ == HexNum {
dv, err := v.decodeHexNum()
if err != nil {
log.Errorf("Unexpected error seen when returning MySQL representation of SQL Hex number: %v", err)
}
return dv
}
return v.val
}

Expand Down Expand Up @@ -385,6 +414,38 @@ func (v *Value) UnmarshalJSON(b []byte) error {
return err
}

// decodeHexVal decodes the SQL hex value of the form x'A1' into a byte
// array matching what MySQL would return when querying the column where
// an INSERT was performed with x'A1' having been specified as a value
func (v *Value) decodeHexVal() ([]byte, error) {
match, err := regexp.Match("^x'.*'$", v.val)
if !match || err != nil {
return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "invalid hex value: %v", v.val)
}
hexBytes := v.val[2 : len(v.val)-1]
decodedHexBytes, err := hex.DecodeString(string(hexBytes))
if err != nil {
return nil, err
}
return decodedHexBytes, nil
}

// decodeHexNum decodes the SQL hex value of the form 0xA1 into a byte
// array matching what MySQL would return when querying the column where
// an INSERT was performed with 0xA1 having been specified as a value
func (v *Value) decodeHexNum() ([]byte, error) {
match, err := regexp.Match("^0x.*$", v.val)
if !match || err != nil {
return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "invalid hex number: %v", v.val)
}
hexBytes := v.val[2:]
decodedHexBytes, err := hex.DecodeString(string(hexBytes))
if err != nil {
return nil, err
}
return decodedHexBytes, nil
}

func encodeBytesSQL(val []byte, b BinWriter) {
buf := &bytes2.Buffer{}
buf.WriteByte('\'')
Expand Down
92 changes: 92 additions & 0 deletions go/test/endtoend/vtgate/queries/normalize/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
Copyright 2021 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package normalize

import (
"flag"
"os"
"testing"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/test/endtoend/cluster"
)

var (
clusterInstance *cluster.LocalProcessCluster
vtParams mysql.ConnParams
KeyspaceName = "ks_normalize"
Cell = "test_normalize"
SchemaSQL = `
create table t1(
id bigint unsigned not null,
charcol char(10),
vcharcol varchar(50),
bincol binary(50),
varbincol varbinary(50),
floatcol float,
deccol decimal(5,2),
bitcol bit,
datecol date,
enumcol enum('small', 'medium', 'large'),
setcol set('a', 'b', 'c'),
jsoncol json,
geocol geometry,
primary key(id)
) Engine=InnoDB;
`
)

func TestMain(m *testing.M) {
defer cluster.PanicHandler(nil)
flag.Parse()

exitCode := func() int {
clusterInstance = cluster.NewCluster(Cell, "localhost")
defer clusterInstance.Teardown()

// Start topo server
err := clusterInstance.StartTopo()
if err != nil {
return 1
}

// Start keyspace
keyspace := &cluster.Keyspace{
Name: KeyspaceName,
SchemaSQL: SchemaSQL,
}
clusterInstance.VtGateExtraArgs = []string{}
clusterInstance.VtTabletExtraArgs = []string{}
err = clusterInstance.StartKeyspace(*keyspace, []string{"-"}, 1, false)
if err != nil {
return 1
}

// Start vtgate
err = clusterInstance.StartVtgate()
if err != nil {
return 1
}

vtParams = mysql.ConnParams{
Host: clusterInstance.Hostname,
Port: clusterInstance.VtgateMySQLPort,
}
return m.Run()
}()
os.Exit(exitCode)
}
Loading