Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions go/mysql/collations/integration/coercion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"strings"
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/mysql/collations/remote"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -80,12 +82,14 @@ func (tc *testConcat) Test(t *testing.T, remote *RemoteCoercionResult, local col
concat.Write(leftText)
concat.Write(rightText)

if !bytes.Equal(concat.Bytes(), remote.Expr.ToBytes()) {
rEBytes, err := remote.Expr.ToBytes()
require.NoError(t, err)
if !bytes.Equal(concat.Bytes(), rEBytes) {
t.Errorf("failed to concatenate text;\n\tCONCAT(%v COLLATE %s, %v COLLATE %s) = \n\tCONCAT(%v, %v) COLLATE %s = \n\t\t%v\n\n\texpected: %v",
tc.left.Text, tc.left.Collation.Name(),
tc.right.Text, tc.right.Collation.Name(),
leftText, rightText, localCollation.Name(),
concat.Bytes(), remote.Expr.ToBytes(),
concat.Bytes(), rEBytes,
)
}
}
Expand All @@ -102,6 +106,7 @@ func (tc *testComparison) Expression() string {
}

func (tc *testComparison) Test(t *testing.T, remote *RemoteCoercionResult, local collations.TypedCollation, coerce1, coerce2 collations.Coercion) {
localCollation := defaultenv.LookupByID(local.Collation)
leftText, err := coerce1(nil, tc.left.Text)
if err != nil {
t.Errorf("failed to transcode left: %v", err)
Expand All @@ -113,9 +118,9 @@ func (tc *testComparison) Test(t *testing.T, remote *RemoteCoercionResult, local
t.Errorf("failed to transcode right: %v", err)
return
}

remoteEquals := remote.Expr.ToBytes()[0] == '1'
localCollation := defaultenv.LookupByID(local.Collation)
rEBytes, err := remote.Expr.ToBytes()
require.NoError(t, err)
remoteEquals := rEBytes[0] == '1'
localEquals := localCollation.Collate(leftText, rightText, false) == 0
if remoteEquals != localEquals {
t.Errorf("failed to collate %#v = %#v with collation %s (expected %v, got %v)",
Expand Down
10 changes: 7 additions & 3 deletions go/mysql/collations/integration/collations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"strings"
"testing"

"github.com/stretchr/testify/require"
"golang.org/x/text/encoding/unicode/utf32"

"vitess.io/vitess/go/mysql"
Expand Down Expand Up @@ -113,14 +114,17 @@ func (u *uca900CollationTest) Test(t *testing.T, result *sqltypes.Result) {
if row[1].Len() == 0 {
continue
}
utf8Input := parseUtf32cp(row[0].ToBytes())
rowBytes, err := row[0].ToBytes()
require.NoError(t, err)
utf8Input := parseUtf32cp(rowBytes)
if utf8Input == nil {
t.Errorf("[%s] failed to parse UTF32-encoded codepoint: %s (%s)", u.collation, row[0], row[2].ToString())
errors++
continue
}

expectedWeightString := parseWeightString(row[1].ToBytes())
rowBytes, err = row[1].ToBytes()
require.NoError(t, err)
expectedWeightString := parseWeightString(rowBytes)
if expectedWeightString == nil {
t.Errorf("[%s] failed to parse weight string: %s (%s)", u.collation, row[1], row[2].ToString())
errors++
Expand Down
5 changes: 4 additions & 1 deletion go/mysql/collations/remote/charset.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ func (c *Charset) performConversion(dst []byte, dstCharset string, src []byte, s
if len(res.Rows) != 1 {
return nil, fmt.Errorf("unexpected result from MySQL: %d rows returned", len(res.Rows))
}
result := res.Rows[0][0].ToBytes()
result, err := res.Rows[0][0].ToBytes()
if err != nil {
return nil, err
}
if dst != nil {
return append(dst, result...), nil
}
Expand Down
5 changes: 3 additions & 2 deletions go/mysql/collations/remote/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,11 @@ func (c *Collation) WeightString(dst, src []byte, numCodepoints int) []byte {
c.sql.WriteString(")")

if result := c.performRemoteQuery(); result != nil {
resultBytes, _ := result[0].ToBytes()
if dst == nil {
dst = result[0].ToBytes()
dst = resultBytes
} else {
dst = append(dst, result[0].ToBytes()...)
dst = append(dst, resultBytes...)
}
}
return dst
Expand Down
9 changes: 8 additions & 1 deletion go/mysql/endtoend/replication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"testing"
"time"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/vt/vtgate/evalengine"

"context"
Expand Down Expand Up @@ -1025,7 +1027,12 @@ func TestRowReplicationTypes(t *testing.T) {
sql.WriteString(", ")
sql.WriteString(tcase.name)
sql.WriteString(" = ")
if values[i+1].Type() == querypb.Type_TIMESTAMP && !bytes.HasPrefix(values[i+1].ToBytes(), mysql.ZeroTimestamp) {
valueBytes, err := values[i+1].ToBytes()
// Expression values are not supported with ToBytes
if values[i+1].Type() != querypb.Type_EXPRESSION {
require.NoError(t, err)
}
if values[i+1].Type() == querypb.Type_TIMESTAMP && !bytes.HasPrefix(valueBytes, mysql.ZeroTimestamp) {
// Values in the binary log are UTC. Let's convert them
// to whatever timezone the connection is using,
// so MySQL properly converts them back to UTC.
Expand Down
2 changes: 1 addition & 1 deletion go/sqltypes/plan_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (pv PlanValue) MarshalJSON() ([]byte, error) {
return json.Marshal(":" + pv.Key)
case !pv.Value.IsNull():
if pv.Value.IsIntegral() {
return pv.Value.ToBytes(), nil
return pv.Value.ToBytes()
}
return json.Marshal(pv.Value.ToString())
case pv.ListKey != "":
Expand Down
4 changes: 2 additions & 2 deletions go/sqltypes/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ func saveRowsAnalysis(r Result, allRows map[string]int, totalRows *int, incremen
for _, row := range r.Rows {
newHash := hashCodeForRow(row)
if increment {
allRows[newHash] += 1
allRows[newHash]++
} else {
allRows[newHash] -= 1
allRows[newHash]--
}
}
if increment {
Expand Down
55 changes: 50 additions & 5 deletions go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,21 @@ package sqltypes

import (
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"regexp"
"strconv"
"strings"

"vitess.io/vitess/go/bytes2"
"vitess.io/vitess/go/hack"

querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/proto/vtrpc"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

var (
Expand Down Expand Up @@ -236,13 +241,21 @@ func (v Value) RawStr() string {

// ToBytes returns the value as MySQL would return it as []byte.
// In contrast, Raw returns the internal representation of the Value, which may not
// match MySQL's representation for newer types.
// If the value is not convertible like in the case of Expression, it returns nil.
func (v Value) ToBytes() []byte {
// match MySQL's representation for hex encoded binary data or newer types.
// If the value is not convertible like in the case of Expression, it returns an error.
func (v Value) ToBytes() ([]byte, error) {
if v.typ == Expression {
return nil
return nil, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "expression cannot be converted to bytes")
}
return v.val
if v.typ == HexVal {
dv, err := v.decodeHexVal()
return dv, err
}
if v.typ == HexNum {
dv, err := v.decodeHexNum()
return dv, err
}
return v.val, nil
}

// Len returns the length.
Expand Down Expand Up @@ -469,6 +482,38 @@ func (v *Value) UnmarshalJSON(b []byte) error {
return err
}

// decodeHexVal decodes the SQL hex value of the form x'A1' into a byte
// array matching what MySQL would return when querying the column where
// an INSERT was performed with x'A1' having been specified as a value
func (v *Value) decodeHexVal() ([]byte, error) {
match, err := regexp.Match("^x'.*'$", v.val)
if !match || err != nil {
return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "invalid hex value: %v", v.val)
}
hexBytes := v.val[2 : len(v.val)-1]
decodedHexBytes, err := hex.DecodeString(string(hexBytes))
if err != nil {
return nil, err
}
return decodedHexBytes, nil
}

// decodeHexNum decodes the SQL hex value of the form 0xA1 into a byte
// array matching what MySQL would return when querying the column where
// an INSERT was performed with 0xA1 having been specified as a value
func (v *Value) decodeHexNum() ([]byte, error) {
match, err := regexp.Match("^0x.*$", v.val)
if !match || err != nil {
return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "invalid hex number: %v", v.val)
}
hexBytes := v.val[2:]
decodedHexBytes, err := hex.DecodeString(string(hexBytes))
if err != nil {
return nil, err
}
return decodedHexBytes, nil
}

func encodeBytesSQL(val []byte, b BinWriter) {
buf := &bytes2.Buffer{}
encodeBytesSQLBytes2(val, buf)
Expand Down
10 changes: 8 additions & 2 deletions go/sqltypes/value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"strings"
"testing"

"github.com/stretchr/testify/require"

querypb "vitess.io/vitess/go/vt/proto/query"
)

Expand Down Expand Up @@ -403,7 +405,9 @@ func TestToBytesAndString(t *testing.T) {
TestValue(Int64, "1"),
TestValue(Int64, "12"),
} {
if b := v.ToBytes(); !bytes.Equal(b, v.Raw()) {
vBytes, err := v.ToBytes()
require.NoError(t, err)
if b := vBytes; !bytes.Equal(b, v.Raw()) {
t.Errorf("%v.ToBytes: %s, want %s", v, b, v.Raw())
}
if s := v.ToString(); s != string(v.Raw()) {
Expand All @@ -412,7 +416,9 @@ func TestToBytesAndString(t *testing.T) {
}

tv := TestValue(Expression, "aa")
if b := tv.ToBytes(); b != nil {
tvBytes, err := tv.ToBytes()
require.EqualError(t, err, "expression cannot be converted to bytes")
if b := tvBytes; b != nil {
t.Errorf("%v.ToBytes: %s, want nil", tv, b)
}
if s := tv.ToString(); s != "" {
Expand Down
1 change: 1 addition & 0 deletions go/vt/binlog/binlog_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ var (
// connecting for replication. Each such connection must identify itself to
// mysqld with a server ID that is unique both among other BinlogConnections and
// among actual replicas in the topology.
//revive:disable because I'm not trying to refactor the entire code base right now
type BinlogConnection struct {
*mysql.Conn
cp dbconfigs.Connector
Expand Down
12 changes: 10 additions & 2 deletions go/vt/binlog/binlog_streamer.go
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,11 @@ func writeValuesAsSQL(sql *sqlparser.TrackedBuffer, tce *tableCacheEntry, rs *my
if err != nil {
return keyspaceIDCell, nil, err
}
if value.Type() == querypb.Type_TIMESTAMP && !bytes.HasPrefix(value.ToBytes(), mysql.ZeroTimestamp) {
vBytes, err := value.ToBytes()
if err != nil {
return sqltypes.Value{}, nil, err
}
if value.Type() == querypb.Type_TIMESTAMP && !bytes.HasPrefix(vBytes, mysql.ZeroTimestamp) {
// Values in the binary log are UTC. Let's convert them
// to whatever timezone the connection is using,
// so MySQL properly converts them back to UTC.
Expand Down Expand Up @@ -819,7 +823,11 @@ func writeIdentifiersAsSQL(sql *sqlparser.TrackedBuffer, tce *tableCacheEntry, r
if err != nil {
return keyspaceIDCell, nil, err
}
if value.Type() == querypb.Type_TIMESTAMP && !bytes.HasPrefix(value.ToBytes(), mysql.ZeroTimestamp) {
vBytes, err := value.ToBytes()
if err != nil {
return keyspaceIDCell, nil, err
}
if value.Type() == querypb.Type_TIMESTAMP && !bytes.HasPrefix(vBytes, mysql.ZeroTimestamp) {
// Values in the binary log are UTC. Let's convert them
// to whatever timezone the connection is using,
// so MySQL properly converts them back to UTC.
Expand Down
6 changes: 5 additions & 1 deletion go/vt/binlog/keyspace_id_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ type keyspaceIDResolverFactoryV2 struct {
func (r *keyspaceIDResolverFactoryV2) keyspaceID(v sqltypes.Value) ([]byte, error) {
switch r.shardingColumnType {
case topodatapb.KeyspaceIdType_BYTES:
return v.ToBytes(), nil
vBytes, err := v.ToBytes()
if err != nil {
return nil, err
}
return vBytes, nil
case topodatapb.KeyspaceIdType_UINT64:
i, err := evalengine.ToUint64(v)
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion go/vt/mysqlctl/tmutils/permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ func NewUserPermission(fields []*querypb.Field, values []sqltypes.Value) *tablet
case "user":
up.User = values[i].ToString()
case "password":
up.PasswordChecksum = crc64.Checksum(values[i].ToBytes(), hashTable)
vBytes, _ := values[i].ToBytes()
up.PasswordChecksum = crc64.Checksum(vBytes, hashTable)
case "password_last_changed":
// we skip this one, as the value may be
// different on primary and replicas.
Expand Down
9 changes: 7 additions & 2 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"encoding/hex"
"encoding/json"
"regexp"
"strings"

"vitess.io/vitess/go/hack"
Expand Down Expand Up @@ -481,7 +482,7 @@ func (node *Literal) HexDecode() ([]byte, error) {
return hex.DecodeString(node.Val)
}

// EncodeHexValToMySQLQueryFormat encodes the hexval back into the query format
// encodeHexValToMySQLQueryFormat encodes the hexval back into the query format
// for passing on to MySQL as a bind var
func (node *Literal) encodeHexValToMySQLQueryFormat() ([]byte, error) {
nb := node.Bytes()
Expand All @@ -490,7 +491,11 @@ func (node *Literal) encodeHexValToMySQLQueryFormat() ([]byte, error) {
}

// Let's make this idempotent in case it's called more than once
if nb[0] == 'x' && nb[1] == '0' && nb[len(nb)-1] == '\'' {
match, err := regexp.Match("^x'.*'$", nb)
if err != nil {
return nb, err
}
if match {
return nb, nil
}

Expand Down
12 changes: 10 additions & 2 deletions go/vt/vtctl/workflow/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ func (s *Server) CheckReshardingJournalExistsOnTablet(ctx context.Context, table

if len(p3qr.Rows) != 0 {
qr := sqltypes.Proto3ToResult(p3qr)
if err := prototext.Unmarshal(qr.Rows[0][0].ToBytes(), &journal); err != nil {
qrBytes, err := qr.Rows[0][0].ToBytes()
if err != nil {
return nil, false, err
}
if err := prototext.Unmarshal(qrBytes, &journal); err != nil {
return nil, false, err
}

Expand Down Expand Up @@ -334,7 +338,11 @@ func (s *Server) GetWorkflows(ctx context.Context, req *vtctldatapb.GetWorkflows
}

var bls binlogdatapb.BinlogSource
if err := prototext.Unmarshal(row[2].ToBytes(), &bls); err != nil {
rowBytes, err := row[2].ToBytes()
if err != nil {
return err
}
if err := prototext.Unmarshal(rowBytes, &bls); err != nil {
return err
}

Expand Down
6 changes: 5 additions & 1 deletion go/vt/vtctl/workflow/stream_migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,11 @@ func (sm *StreamMigrator) readTabletStreams(ctx context.Context, ti *topo.Tablet
}

var bls binlogdatapb.BinlogSource
if err := prototext.Unmarshal(row[2].ToBytes(), &bls); err != nil {
rowBytes, err := row[2].ToBytes()
if err != nil {
return nil, err
}
if err := prototext.Unmarshal(rowBytes, &bls); err != nil {
return nil, err
}

Expand Down
Loading