diff --git a/go/vt/vttablet/tabletmanager/vreplication/framework_test.go b/go/vt/vttablet/tabletmanager/vreplication/framework_test.go index e35481e7f75..12c3dab7937 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/framework_test.go +++ b/go/vt/vttablet/tabletmanager/vreplication/framework_test.go @@ -50,12 +50,13 @@ import ( ) var ( - playerEngine *Engine - streamerEngine *vstreamer.Engine - env *testenv.Env - globalFBC = &fakeBinlogClient{} - vrepldb = "vrepl" - globalDBQueries = make(chan string, 1000) + playerEngine *Engine + streamerEngine *vstreamer.Engine + env *testenv.Env + globalFBC = &fakeBinlogClient{} + vrepldb = "vrepl" + globalDBQueries = make(chan string, 1000) + testForeignKeyQueries = false ) type LogExpectation struct { @@ -395,6 +396,8 @@ func (dbc *realDBClient) ExecuteFetch(query string, maxrows int) (*sqltypes.Resu qr, err := dbc.conn.ExecuteFetch(query, 10000, true) if !strings.HasPrefix(query, "select") && !strings.HasPrefix(query, "set") && !dbc.nolog { globalDBQueries <- query + } else if testForeignKeyQueries && strings.Contains(query, "foreign_key_checks") { //allow select/set for foreign_key_checks + globalDBQueries <- query } return qr, err } diff --git a/go/vt/vttablet/tabletmanager/vreplication/vcopier_test.go b/go/vt/vttablet/tabletmanager/vreplication/vcopier_test.go index a06557bbfda..15b7a9b4a31 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vcopier_test.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vcopier_test.go @@ -22,6 +22,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "golang.org/x/net/context" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/binlog/binlogplayer" @@ -29,6 +31,106 @@ import ( "vitess.io/vitess/go/vt/vttablet/tabletserver/vstreamer" ) +func TestPlayerCopyTablesWithFK(t *testing.T) { + testForeignKeyQueries = true + defer func() { + testForeignKeyQueries = false + }() + + defer deleteTablet(addTablet(100)) + + execStatements(t, []string{ + "create table src2(id int, id2 int, primary key(id))", + "create table src1(id int, id2 int, primary key(id), foreign key (id2) references src2(id) on delete cascade)", + "insert into src2 values(1, 21), (2, 22)", + "insert into src1 values(1, 1), (2, 2)", + fmt.Sprintf("create table %s.dst2(id int, id2 int, primary key(id))", vrepldb), + fmt.Sprintf("create table %s.dst1(id int, id2 int, primary key(id), foreign key (id2) references dst2(id) on delete cascade)", vrepldb), + }) + defer execStatements(t, []string{ + "drop table src1", + fmt.Sprintf("drop table %s.dst1", vrepldb), + "drop table src2", + fmt.Sprintf("drop table %s.dst2", vrepldb), + }) + env.SchemaEngine.Reload(context.Background()) + + filter := &binlogdatapb.Filter{ + Rules: []*binlogdatapb.Rule{{ + Match: "dst1", + Filter: "select * from src1", + }, { + Match: "dst2", + Filter: "select * from src2", + }}, + } + + bls := &binlogdatapb.BinlogSource{ + Keyspace: env.KeyspaceName, + Shard: env.ShardName, + Filter: filter, + OnDdl: binlogdatapb.OnDDLAction_IGNORE, + } + query := binlogplayer.CreateVReplicationState("test", bls, "", binlogplayer.VReplicationInit, playerEngine.dbName) + qr, err := playerEngine.Exec(query) + require.NoError(t, err) + + expectDBClientQueries(t, []string{ + "/insert into _vt.vreplication", + "select @@foreign_key_checks;", + // Create the list of tables to copy and transition to Copying state. + "begin", + "/insert into _vt.copy_state", + "/update _vt.vreplication set state='Copying'", + "commit", + "set foreign_key_checks=0;", + // The first fast-forward has no starting point. So, it just saves the current position. + "/update _vt.vreplication set pos=", + "begin", + "insert into dst1(id,id2) values (1,1), (2,2)", + `/update _vt.copy_state set lastpk='fields: rows: ' where vrepl_id=.*`, + "commit", + // copy of dst1 is done: delete from copy_state. + "/delete from _vt.copy_state.*dst1", + // The next FF executes and updates the position before copying. + "set foreign_key_checks=0;", + "begin", + "/update _vt.vreplication set pos=", + "commit", + // copy dst2 + "begin", + "insert into dst2(id,id2) values (1,21), (2,22)", + `/update _vt.copy_state set lastpk='fields: rows: ' where vrepl_id=.*`, + "commit", + // copy of dst1 is done: delete from copy_state. + "/delete from _vt.copy_state.*dst2", + // All tables copied. Final catch up followed by Running state. + "set foreign_key_checks=1;", + "/update _vt.vreplication set state='Running'", + }) + + expectData(t, "dst1", [][]string{ + {"1", "1"}, + {"2", "2"}, + }) + expectData(t, "dst2", [][]string{ + {"1", "21"}, + {"2", "22"}, + }) + + query = fmt.Sprintf("delete from _vt.vreplication where id = %d", qr.InsertID) + if _, err := playerEngine.Exec(query); err != nil { + t.Fatal(err) + } + expectDBClientQueries(t, []string{ + "set foreign_key_checks=1;", + "begin", + "/delete from _vt.vreplication", + "/delete from _vt.copy_state", + "commit", + }) +} + func TestPlayerCopyTables(t *testing.T) { defer deleteTablet(addTablet(100)) diff --git a/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go b/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go index d7fcb54324e..60242eb7473 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vreplicator.go @@ -59,6 +59,8 @@ type vreplicator struct { // mysqld is used to fetch the local schema. mysqld mysqlctl.MysqlDaemon tableKeys map[string][]string + + originalFKCheckSetting int64 } // newVReplicator creates a new vreplicator. The valid fields from the source are: @@ -130,7 +132,11 @@ func (vr *vreplicator) replicate(ctx context.Context) error { return err } vr.tableKeys = tableKeys - + if err := vr.getSettingFKCheck(); err != nil { + return err + } + //defensive guard, should be a no-op since it should happen after copy is done + defer vr.resetFKCheckAfterCopy() for { select { case <-ctx.Done(): @@ -152,6 +158,10 @@ func (vr *vreplicator) replicate(ctx context.Context) error { switch { case numTablesToCopy != 0: + if err := vr.clearFKCheck(); err != nil { + log.Warningf("Unable to clear FK check %v", err) + return err + } if err := newVCopier(vr).copyNext(ctx, settings); err != nil { return err } @@ -160,6 +170,10 @@ func (vr *vreplicator) replicate(ctx context.Context) error { return err } default: + if err := vr.resetFKCheckAfterCopy(); err != nil { + log.Warningf("Unable to reset FK check %v", err) + return err + } if vr.source.StopAfterCopy { return vr.setState(binlogplayer.BlpStopped, "Stopped after copy.") } @@ -254,3 +268,28 @@ func encodeString(in string) string { sqltypes.NewVarChar(in).EncodeSQL(&buf) return buf.String() } + +func (vr *vreplicator) getSettingFKCheck() error { + qr, err := vr.dbClient.Execute("select @@foreign_key_checks;") + if err != nil { + return err + } + if qr.RowsAffected != 1 || len(qr.Fields) != 1 { + return fmt.Errorf("unable to select @@foreign_key_checks") + } + vr.originalFKCheckSetting, err = evalengine.ToInt64(qr.Rows[0][0]) + if err != nil { + return err + } + return nil +} + +func (vr *vreplicator) resetFKCheckAfterCopy() error { + _, err := vr.dbClient.Execute(fmt.Sprintf("set foreign_key_checks=%d;", vr.originalFKCheckSetting)) + return err +} + +func (vr *vreplicator) clearFKCheck() error { + _, err := vr.dbClient.Execute("set foreign_key_checks=0;") + return err +}