diff --git a/go/test/endtoend/vreplication/config.go b/go/test/endtoend/vreplication/config.go index c22a486e39c..d539252d5d0 100644 --- a/go/test/endtoend/vreplication/config.go +++ b/go/test/endtoend/vreplication/config.go @@ -4,7 +4,7 @@ var ( initialProductSchema = ` create table product(pid int, description varbinary(128), primary key(pid)); create table customer(cid int, name varbinary(128), primary key(cid)); -create table merchant(mname varchar(128), category varchar(128), primary key(mname)); +create table merchant(mname varchar(128), category varchar(128), primary key(mname)) DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci; create table orders(oid int, cid int, pid int, mname varchar(128), price int, primary key(oid)); create table customer_seq(id int, next_id bigint, cache bigint, primary key(id)) comment 'vitess_sequence'; create table order_seq(id int, next_id bigint, cache bigint, primary key(id)) comment 'vitess_sequence'; diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 850799bc572..d3f8063f03c 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -23,6 +23,8 @@ import ( "sync" "time" + "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/vtgate/evalengine" "github.com/golang/protobuf/proto" @@ -265,6 +267,42 @@ func (df *vdiff) buildVDiffPlan(ctx context.Context, filter *binlogdatapb.Filter return nil } +// findPKs identifies PKs and removes them from the columns to do data comparison +func findPKs(table *tabletmanagerdatapb.TableDefinition, targetSelect *sqlparser.Select, td *tableDiffer) (sqlparser.OrderBy, error) { + var orderby sqlparser.OrderBy + for _, pk := range table.PrimaryKeyColumns { + found := false + for i, selExpr := range targetSelect.SelectExprs { + expr := selExpr.(*sqlparser.AliasedExpr).Expr + colname := "" + switch ct := expr.(type) { + case *sqlparser.ColName: + colname = ct.Name.String() + case *sqlparser.FuncExpr: //eg. weight_string() + //no-op + default: + log.Warningf("Not considering column %v for PK, type %v not handled", selExpr, ct) + } + if strings.EqualFold(pk, colname) { + td.comparePKs = append(td.comparePKs, td.compareCols[i]) + // We'll be comparing pks separately. So, remove them from compareCols. + td.compareCols[i] = -1 + found = true + break + } + } + if !found { + // Unreachable. + return nil, fmt.Errorf("column %v not found in table %v", pk, table.Name) + } + orderby = append(orderby, &sqlparser.Order{ + Expr: &sqlparser.ColName{Name: sqlparser.NewColIdent(pk)}, + Direction: sqlparser.AscScr, + }) + } + return orderby, nil +} + // buildTablePlan builds one tableDiffer. func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, query string) (*tableDiffer, error) { statement, err := sqlparser.Parse(query) @@ -354,27 +392,9 @@ func (df *vdiff) buildTablePlan(table *tabletmanagerdatapb.TableDefinition, quer }, } - var orderby sqlparser.OrderBy - for _, pk := range table.PrimaryKeyColumns { - found := false - for i, selExpr := range targetSelect.SelectExprs { - colname := selExpr.(*sqlparser.AliasedExpr).Expr.(*sqlparser.ColName).Name.Lowered() - if pk == colname { - td.comparePKs = append(td.comparePKs, td.compareCols[i]) - // We'll be comparing pks seperately. So, remove them from compareCols. - td.compareCols[i] = -1 - found = true - break - } - } - if !found { - // Unreachable. - return nil, fmt.Errorf("column %v not found in table %v", pk, table.Name) - } - orderby = append(orderby, &sqlparser.Order{ - Expr: &sqlparser.ColName{Name: sqlparser.NewColIdent(pk)}, - Direction: sqlparser.AscScr, - }) + orderby, err := findPKs(table, targetSelect, td) + if err != nil { + return nil, err } // Remove in_keyrange. It's not understood by mysql. sourceSelect.Where = removeKeyrange(sel.Where) diff --git a/go/vt/wrangler/vdiff_test.go b/go/vt/wrangler/vdiff_test.go index 146a398d2b6..b3a0b65c490 100644 --- a/go/vt/wrangler/vdiff_test.go +++ b/go/vt/wrangler/vdiff_test.go @@ -20,6 +20,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/vt/sqlparser" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/context" @@ -891,3 +893,72 @@ func TestVDiffReplicationWait(t *testing.T) { _, err := env.wr.VDiff(context.Background(), "target", env.workflow, env.cell, env.cell, "replica", 0*time.Second, "") require.EqualError(t, err, "startQueryStreams(sources): WaitForPosition for tablet cell-0000000101: context deadline exceeded") } + +func TestVDiffFindPKs(t *testing.T) { + + testcases := []struct { + name string + table *tabletmanagerdatapb.TableDefinition + targetSelect *sqlparser.Select + tdIn *tableDiffer + tdOut *tableDiffer + errorString string + }{ + { + name: "", + table: &tabletmanagerdatapb.TableDefinition{ + Name: "t1", + Columns: []string{"c1", "c2"}, + PrimaryKeyColumns: []string{"c1"}, + Fields: sqltypes.MakeTestFields("c1|c2", "int64|int64"), + }, + targetSelect: &sqlparser.Select{ + SelectExprs: sqlparser.SelectExprs{ + &sqlparser.AliasedExpr{Expr: &sqlparser.ColName{Name: sqlparser.NewColIdent("c1")}}, + &sqlparser.AliasedExpr{Expr: &sqlparser.ColName{Name: sqlparser.NewColIdent("c2")}}, + }, + }, + tdIn: &tableDiffer{ + compareCols: []int{0, 1}, + comparePKs: []int{}, + }, + tdOut: &tableDiffer{ + compareCols: []int{-1, 1}, + comparePKs: []int{0}, + }, + }, { + name: "", + table: &tabletmanagerdatapb.TableDefinition{ + Name: "t1", + Columns: []string{"c1", "c2", "c3", "c4"}, + PrimaryKeyColumns: []string{"c1", "c4"}, + Fields: sqltypes.MakeTestFields("c1|c2|c3|c4", "int64|int64|varchar|int64"), + }, + targetSelect: &sqlparser.Select{ + SelectExprs: sqlparser.SelectExprs{ + &sqlparser.AliasedExpr{Expr: &sqlparser.ColName{Name: sqlparser.NewColIdent("c1")}}, + &sqlparser.AliasedExpr{Expr: &sqlparser.ColName{Name: sqlparser.NewColIdent("c2")}}, + &sqlparser.AliasedExpr{Expr: &sqlparser.FuncExpr{Name: sqlparser.NewColIdent("c3")}}, + &sqlparser.AliasedExpr{Expr: &sqlparser.ColName{Name: sqlparser.NewColIdent("c4")}}, + }, + }, + tdIn: &tableDiffer{ + compareCols: []int{0, 1, 2, 3}, + comparePKs: []int{}, + }, + tdOut: &tableDiffer{ + compareCols: []int{-1, 1, 2, -1}, + comparePKs: []int{0, 3}, + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + _, err := findPKs(tc.table, tc.targetSelect, tc.tdIn) + require.NoError(t, err) + require.EqualValues(t, tc.tdOut, tc.tdIn) + }) + } + +}