diff --git a/go/sqltypes/proto3.go b/go/sqltypes/proto3.go index 146692061a5..2a0af8a3a01 100644 --- a/go/sqltypes/proto3.go +++ b/go/sqltypes/proto3.go @@ -10,6 +10,30 @@ import "github.com/youtube/vitess/go/vt/vterrors" // This file contains the proto3 conversion functions for the structures // defined here. +// RowToProto3 converts []Value to proto3. +func RowToProto3(row []Value) *querypb.Row { + result := &querypb.Row{} + result.Lengths = make([]int64, 0, len(row)) + total := 0 + for _, c := range row { + if c.IsNull() { + result.Lengths = append(result.Lengths, -1) + continue + } + length := c.Len() + result.Lengths = append(result.Lengths, int64(length)) + total += length + } + result.Values = make([]byte, 0, total) + for _, c := range row { + if c.IsNull() { + continue + } + result.Values = append(result.Values, c.Raw()...) + } + return result +} + // RowsToProto3 converts [][]Value to proto3. func RowsToProto3(rows [][]Value) []*querypb.Row { if len(rows) == 0 { @@ -18,26 +42,7 @@ func RowsToProto3(rows [][]Value) []*querypb.Row { result := make([]*querypb.Row, len(rows)) for i, r := range rows { - row := &querypb.Row{} - result[i] = row - row.Lengths = make([]int64, 0, len(r)) - total := 0 - for _, c := range r { - if c.IsNull() { - row.Lengths = append(row.Lengths, -1) - continue - } - length := c.Len() - row.Lengths = append(row.Lengths, int64(length)) - total += length - } - row.Values = make([]byte, 0, total) - for _, c := range r { - if c.IsNull() { - continue - } - row.Values = append(row.Values, c.Raw()...) - } + result[i] = RowToProto3(r) } return result } diff --git a/go/vt/binlog/binlog_streamer.go b/go/vt/binlog/binlog_streamer.go index 8bc74014553..3a9295b07e4 100644 --- a/go/vt/binlog/binlog_streamer.go +++ b/go/vt/binlog/binlog_streamer.go @@ -15,6 +15,7 @@ import ( "github.com/youtube/vitess/go/mysqlconn/replication" "github.com/youtube/vitess/go/sqldb" + "github.com/youtube/vitess/go/sqltypes" "github.com/youtube/vitess/go/stats" "github.com/youtube/vitess/go/vt/mysqlctl" "github.com/youtube/vitess/go/vt/sqlparser" @@ -60,7 +61,7 @@ type FullBinlogStatement struct { Table string KeyspaceID []byte PKNames []*querypb.Field - PKRow *querypb.Row + PKValues []sqltypes.Value } // sendTransactionFunc is used to send binlog events. @@ -75,14 +76,57 @@ func getStatementCategory(sql string) binlogdatapb.BinlogTransaction_Statement_C return statementPrefixes[strings.ToLower(sql)] } +// tableCacheEntry contains everything we know about a table. +// It is created when we get a TableMap event. +type tableCacheEntry struct { + // tm is what we get from a TableMap event. + tm *replication.TableMap + + // ti is the table descriptor we get from the schema engine. + ti *schema.Table + + // The following fields are used if we want to extract the + // keyspace_id of a row. + + // resolver is only set if Streamer.resolverFactory is set. + resolver keyspaceIDResolver + + // keyspaceIDIndex is the index of the field that can be used + // to compute the keyspaceID. Set to -1 if no resolver is in used. + keyspaceIDIndex int + + // The following fields are used if we want to extract the + // primary key of a row. + + // pkNames contains an array of fields for the PK. + pkNames []*querypb.Field + + // pkIndexes contains the index of a given column in the + // PK. It is -1 f the column is not in any PK. It contains as + // many fields as there are columns in the table. + // For instance, in a table defined like this: + // field1 varchar() + // pkpart2 int + // pkpart1 int + // pkIndexes would contain: [ + // -1 // field1 is not in the pk + // 1 // pkpart2 is the second part of the PK + // 0 // pkpart1 is the first part of the PK + // This array is built this way so when we extract the columns + // in a row, we can just save them in the PK array easily. + pkIndexes []int +} + // Streamer streams binlog events from MySQL by connecting as a slave. // A Streamer should only be used once. To start another stream, call // NewStreamer() again. type Streamer struct { // The following fields at set at creation and immutable. - dbname string - mysqld mysqlctl.MysqlDaemon - se *schema.Engine + dbname string + mysqld mysqlctl.MysqlDaemon + se *schema.Engine + resolverFactory keyspaceIDResolverFactory + extractPK bool clientCharset *binlogdatapb.Charset startPos replication.Position @@ -193,7 +237,7 @@ func (bls *Streamer) parseEvents(ctx context.Context, events <-chan replication. // Remember the RBR state. // tableMaps is indexed by tableID. - tableMaps := make(map[uint64]*replication.TableMap) + tableMaps := make(map[uint64]*tableCacheEntry) // A begin can be triggered either by a BEGIN query, or by a GTID_EVENT. begin := func() { @@ -388,21 +432,63 @@ func (bls *Streamer) parseEvents(ctx context.Context, events <-chan replication. if err != nil { return pos, err } - tableMaps[tableID] = tm + // TODO(alainjobart) if table is already in map, + // just use it. + + tce := &tableCacheEntry{ + tm: tm, + keyspaceIDIndex: -1, + } + tableMaps[tableID] = tce + + // Check we're in the right database, and if so, fill + // in more data. + if tm.Database != "" && tm.Database != bls.dbname { + continue + } + + // Find and fill in the table schema. + tce.ti = bls.se.GetTable(sqlparser.NewTableIdent(tm.Name)) + if tce.ti == nil { + return pos, fmt.Errorf("unknown table %v in schema", tm.Name) + } + + // Fill in the resolver if needed. + if bls.resolverFactory != nil { + tce.keyspaceIDIndex, tce.resolver, err = bls.resolverFactory(tce.ti) + if err != nil { + return pos, fmt.Errorf("cannot find column to use to find keyspace_id for table %v", tm.Name) + } + } + + // Fill in PK indexes if necessary. + if bls.extractPK { + tce.pkNames = make([]*querypb.Field, len(tce.ti.PKColumns)) + tce.pkIndexes = make([]int, len(tce.ti.Columns)) + for i := range tce.pkIndexes { + // Put -1 as default in here. + tce.pkIndexes[i] = -1 + } + for i, c := range tce.ti.PKColumns { + // Patch in every PK column index. + tce.pkIndexes[c] = i + // Fill in pknames + tce.pkNames[i] = &querypb.Field{ + Name: tce.ti.Columns[c].Name.String(), + Type: tce.ti.Columns[c].Type, + } + } + } case ev.IsWriteRows(): tableID := ev.TableID(format) - tm, ok := tableMaps[tableID] + tce, ok := tableMaps[tableID] if !ok { return pos, fmt.Errorf("unknown tableID %v in WriteRows event", tableID) } - if tm.Database != "" && tm.Database != bls.dbname { + if tce.ti == nil { // Skip cross-db statements. continue } - ti := bls.se.GetTable(sqlparser.NewTableIdent(tm.Name)) - if ti == nil { - return pos, fmt.Errorf("unknown table %v in schema", tm.Name) - } setTimestamp := &binlogdatapb.BinlogTransaction_Statement{ Category: binlogdatapb.BinlogTransaction_Statement_BL_SET, Sql: []byte(fmt.Sprintf("SET TIMESTAMP=%d", ev.Timestamp())), @@ -411,12 +497,12 @@ func (bls *Streamer) parseEvents(ctx context.Context, events <-chan replication. Statement: setTimestamp, }) - rows, err := ev.Rows(format, tm) + rows, err := ev.Rows(format, tce.tm) if err != nil { return pos, err } - statements = appendInserts(statements, &rows, tm, ti) + statements = bls.appendInserts(statements, tce, &rows) if autocommit { if err = commit(ev.Timestamp()); err != nil { @@ -425,18 +511,14 @@ func (bls *Streamer) parseEvents(ctx context.Context, events <-chan replication. } case ev.IsUpdateRows(): tableID := ev.TableID(format) - tm, ok := tableMaps[tableID] + tce, ok := tableMaps[tableID] if !ok { return pos, fmt.Errorf("unknown tableID %v in UpdateRows event", tableID) } - if tm.Database != "" && tm.Database != bls.dbname { + if tce.ti == nil { // Skip cross-db statements. continue } - ti := bls.se.GetTable(sqlparser.NewTableIdent(tm.Name)) - if ti == nil { - return pos, fmt.Errorf("unknown table %v in schema", tm.Name) - } setTimestamp := &binlogdatapb.BinlogTransaction_Statement{ Category: binlogdatapb.BinlogTransaction_Statement_BL_SET, Sql: []byte(fmt.Sprintf("SET TIMESTAMP=%d", ev.Timestamp())), @@ -445,12 +527,12 @@ func (bls *Streamer) parseEvents(ctx context.Context, events <-chan replication. Statement: setTimestamp, }) - rows, err := ev.Rows(format, tm) + rows, err := ev.Rows(format, tce.tm) if err != nil { return pos, err } - statements = appendUpdates(statements, &rows, tm, ti) + statements = bls.appendUpdates(statements, tce, &rows) if autocommit { if err = commit(ev.Timestamp()); err != nil { @@ -459,18 +541,14 @@ func (bls *Streamer) parseEvents(ctx context.Context, events <-chan replication. } case ev.IsDeleteRows(): tableID := ev.TableID(format) - tm, ok := tableMaps[tableID] + tce, ok := tableMaps[tableID] if !ok { return pos, fmt.Errorf("unknown tableID %v in DeleteRows event", tableID) } - if tm.Database != "" && tm.Database != bls.dbname { + if tce.ti == nil { // Skip cross-db statements. continue } - ti := bls.se.GetTable(sqlparser.NewTableIdent(tm.Name)) - if ti == nil { - return pos, fmt.Errorf("unknown table %v in schema", tm.Name) - } setTimestamp := &binlogdatapb.BinlogTransaction_Statement{ Category: binlogdatapb.BinlogTransaction_Statement_BL_SET, Sql: []byte(fmt.Sprintf("SET TIMESTAMP=%d", ev.Timestamp())), @@ -479,12 +557,12 @@ func (bls *Streamer) parseEvents(ctx context.Context, events <-chan replication. Statement: setTimestamp, }) - rows, err := ev.Rows(format, tm) + rows, err := ev.Rows(format, tce.tm) if err != nil { return pos, err } - statements = appendDeletes(statements, &rows, tm, ti) + statements = bls.appendDeletes(statements, tce, &rows) if autocommit { if err = commit(ev.Timestamp()); err != nil { @@ -495,100 +573,142 @@ func (bls *Streamer) parseEvents(ctx context.Context, events <-chan replication. } } -func appendInserts(statements []FullBinlogStatement, rows *replication.Rows, tm *replication.TableMap, ti *schema.Table) []FullBinlogStatement { +func (bls *Streamer) appendInserts(statements []FullBinlogStatement, tce *tableCacheEntry, rows *replication.Rows) []FullBinlogStatement { for i := range rows.Rows { var sql bytes.Buffer sql.WriteString("INSERT INTO ") - sql.WriteString(tm.Name) + sql.WriteString(tce.tm.Name) sql.WriteString(" SET ") - if err := writeValuesAsSQL(&sql, rows, tm, ti, i); err != nil { + keyspaceIDCell, pkValues, err := writeValuesAsSQL(&sql, tce, rows, i, tce.pkNames != nil) + if err != nil { log.Warningf("writeValuesAsSQL(%v) failed: %v", i, err) continue } + // Fill in keyspace id if needed. + var ksid []byte + if tce.resolver != nil { + var err error + ksid, err = tce.resolver.keyspaceID(keyspaceIDCell) + if err != nil { + log.Warningf("resolver(%v) failed: %v", err) + } + } + statement := &binlogdatapb.BinlogTransaction_Statement{ Category: binlogdatapb.BinlogTransaction_Statement_BL_INSERT, Sql: sql.Bytes(), } statements = append(statements, FullBinlogStatement{ - Statement: statement, - Table: tm.Name, + Statement: statement, + Table: tce.tm.Name, + KeyspaceID: ksid, + PKNames: tce.pkNames, + PKValues: pkValues, }) - // TODO(alainjobart): fill in keyspaceID, pkNames, pkRows - // if necessary. } return statements } -func appendUpdates(statements []FullBinlogStatement, rows *replication.Rows, tm *replication.TableMap, ti *schema.Table) []FullBinlogStatement { +func (bls *Streamer) appendUpdates(statements []FullBinlogStatement, tce *tableCacheEntry, rows *replication.Rows) []FullBinlogStatement { for i := range rows.Rows { var sql bytes.Buffer sql.WriteString("UPDATE ") - sql.WriteString(tm.Name) + sql.WriteString(tce.tm.Name) sql.WriteString(" SET ") - if err := writeValuesAsSQL(&sql, rows, tm, ti, i); err != nil { + keyspaceIDCell, pkValues, err := writeValuesAsSQL(&sql, tce, rows, i, tce.pkNames != nil) + if err != nil { log.Warningf("writeValuesAsSQL(%v) failed: %v", i, err) continue } sql.WriteString(" WHERE ") - if err := writeIdentifiesAsSQL(&sql, rows, tm, ti, i); err != nil { + if _, _, err := writeIdentifiesAsSQL(&sql, tce, rows, i, false); err != nil { log.Warningf("writeIdentifiesAsSQL(%v) failed: %v", i, err) continue } + // Fill in keyspace id if needed. + var ksid []byte + if tce.resolver != nil { + var err error + ksid, err = tce.resolver.keyspaceID(keyspaceIDCell) + if err != nil { + log.Warningf("resolver(%v) failed: %v", err) + } + } + update := &binlogdatapb.BinlogTransaction_Statement{ Category: binlogdatapb.BinlogTransaction_Statement_BL_UPDATE, Sql: sql.Bytes(), } statements = append(statements, FullBinlogStatement{ - Statement: update, - Table: tm.Name, + Statement: update, + Table: tce.tm.Name, + KeyspaceID: ksid, + PKNames: tce.pkNames, + PKValues: pkValues, }) - // TODO(alainjobart): fill in keyspaceID, pkNames, pkRows - // if necessary. } return statements } -func appendDeletes(statements []FullBinlogStatement, rows *replication.Rows, tm *replication.TableMap, ti *schema.Table) []FullBinlogStatement { +func (bls *Streamer) appendDeletes(statements []FullBinlogStatement, tce *tableCacheEntry, rows *replication.Rows) []FullBinlogStatement { for i := range rows.Rows { var sql bytes.Buffer sql.WriteString("DELETE FROM ") - sql.WriteString(tm.Name) + sql.WriteString(tce.tm.Name) sql.WriteString(" WHERE ") - if err := writeIdentifiesAsSQL(&sql, rows, tm, ti, i); err != nil { + keyspaceIDCell, pkValues, err := writeIdentifiesAsSQL(&sql, tce, rows, i, tce.pkNames != nil) + if err != nil { log.Warningf("writeIdentifiesAsSQL(%v) failed: %v", i, err) continue } + // Fill in keyspace id if needed. + var ksid []byte + if tce.resolver != nil { + var err error + ksid, err = tce.resolver.keyspaceID(keyspaceIDCell) + if err != nil { + log.Warningf("resolver(%v) failed: %v", err) + } + } + statement := &binlogdatapb.BinlogTransaction_Statement{ Category: binlogdatapb.BinlogTransaction_Statement_BL_DELETE, Sql: sql.Bytes(), } statements = append(statements, FullBinlogStatement{ - Statement: statement, - Table: tm.Name, + Statement: statement, + Table: tce.tm.Name, + KeyspaceID: ksid, + PKNames: tce.pkNames, + PKValues: pkValues, }) - // TODO(alainjobart): fill in keyspaceID, pkNames, pkRows - // if necessary. } return statements } // writeValuesAsSQL is a helper method to print the values as SQL in the -// provided bytes.Buffer. -func writeValuesAsSQL(sql *bytes.Buffer, rs *replication.Rows, tm *replication.TableMap, ti *schema.Table, rowIndex int) error { +// provided bytes.Buffer. It also returns the value for the keyspaceIDColumn, +// and the array of values for the PK, if necessary. +func writeValuesAsSQL(sql *bytes.Buffer, tce *tableCacheEntry, rs *replication.Rows, rowIndex int, getPK bool) (sqltypes.Value, []sqltypes.Value, error) { valueIndex := 0 data := rs.Rows[rowIndex].Data pos := 0 + var keyspaceIDCell sqltypes.Value + var pkValues []sqltypes.Value + if getPK { + pkValues = make([]sqltypes.Value, len(tce.pkNames)) + } for c := 0; c < rs.DataColumns.Count(); c++ { if !rs.DataColumns.Bit(c) { continue @@ -598,7 +718,7 @@ func writeValuesAsSQL(sql *bytes.Buffer, rs *replication.Rows, tm *replication.T if valueIndex > 0 { sql.WriteString(", ") } - sql.WriteString(ti.Columns[c].Name.String()) + sql.WriteString(tce.ti.Columns[c].Name.String()) sql.WriteByte('=') if rs.Rows[rowIndex].NullColumns.Bit(valueIndex) { @@ -608,25 +728,39 @@ func writeValuesAsSQL(sql *bytes.Buffer, rs *replication.Rows, tm *replication.T continue } - // We have real data - value, l, err := replication.CellValue(data, pos, tm.Types[c], tm.Metadata[c], ti.Columns[c].Type) + // We have real data. + value, l, err := replication.CellValue(data, pos, tce.tm.Types[c], tce.tm.Metadata[c], tce.ti.Columns[c].Type) if err != nil { - return err + return keyspaceIDCell, nil, err } value.EncodeSQL(sql) + if c == tce.keyspaceIDIndex { + keyspaceIDCell = value + } + if getPK { + if tce.pkIndexes[c] != -1 { + pkValues[tce.pkIndexes[c]] = value + } + } pos += l valueIndex++ } - return nil + return keyspaceIDCell, pkValues, nil } // writeIdentifiesAsSQL is a helper method to print the identifies as SQL in the -// provided bytes.Buffer. -func writeIdentifiesAsSQL(sql *bytes.Buffer, rs *replication.Rows, tm *replication.TableMap, ti *schema.Table, rowIndex int) error { +// provided bytes.Buffer. It also returns the value for the keyspaceIDColumn, +// and the array of values for the PK, if necessary. +func writeIdentifiesAsSQL(sql *bytes.Buffer, tce *tableCacheEntry, rs *replication.Rows, rowIndex int, getPK bool) (sqltypes.Value, []sqltypes.Value, error) { valueIndex := 0 data := rs.Rows[rowIndex].Identify pos := 0 + var keyspaceIDCell sqltypes.Value + var pkValues []sqltypes.Value + if getPK { + pkValues = make([]sqltypes.Value, len(tce.pkNames)) + } for c := 0; c < rs.IdentifyColumns.Count(); c++ { if !rs.IdentifyColumns.Bit(c) { continue @@ -636,7 +770,7 @@ func writeIdentifiesAsSQL(sql *bytes.Buffer, rs *replication.Rows, tm *replicati if valueIndex > 0 { sql.WriteString(" AND ") } - sql.WriteString(ti.Columns[c].Name.String()) + sql.WriteString(tce.ti.Columns[c].Name.String()) sql.WriteByte('=') if rs.Rows[rowIndex].NullIdentifyColumns.Bit(valueIndex) { @@ -646,15 +780,23 @@ func writeIdentifiesAsSQL(sql *bytes.Buffer, rs *replication.Rows, tm *replicati continue } - // We have real data - value, l, err := replication.CellValue(data, pos, tm.Types[c], tm.Metadata[c], ti.Columns[c].Type) + // We have real data. + value, l, err := replication.CellValue(data, pos, tce.tm.Types[c], tce.tm.Metadata[c], tce.ti.Columns[c].Type) if err != nil { - return err + return keyspaceIDCell, nil, err } value.EncodeSQL(sql) + if c == tce.keyspaceIDIndex { + keyspaceIDCell = value + } + if getPK { + if tce.pkIndexes[c] != -1 { + pkValues[tce.pkIndexes[c]] = value + } + } pos += l valueIndex++ } - return nil + return keyspaceIDCell, pkValues, nil } diff --git a/go/vt/binlog/event_streamer.go b/go/vt/binlog/event_streamer.go index f1da61b447e..f9737aeef8a 100644 --- a/go/vt/binlog/event_streamer.go +++ b/go/vt/binlog/event_streamer.go @@ -45,6 +45,7 @@ func NewEventStreamer(dbname string, mysqld mysqlctl.MysqlDaemon, se *schema.Eng sendEvent: sendEvent, } evs.bls = NewStreamer(dbname, mysqld, se, nil, startPos, timestamp, evs.transactionToEvent) + evs.bls.extractPK = true return evs } @@ -74,7 +75,7 @@ func (evs *EventStreamer) transactionToEvent(eventToken *querypb.EventToken, sta binlogdatapb.BinlogTransaction_Statement_BL_UPDATE, binlogdatapb.BinlogTransaction_Statement_BL_DELETE: var dmlStatement *querypb.StreamEvent_Statement - dmlStatement, insertid, err = evs.buildDMLStatement(string(stmt.Statement.Sql), insertid) + dmlStatement, insertid, err = evs.buildDMLStatement(stmt, insertid) if err != nil { dmlStatement = &querypb.StreamEvent_Statement{ Category: querypb.StreamEvent_Statement_Error, @@ -103,12 +104,30 @@ func (evs *EventStreamer) transactionToEvent(eventToken *querypb.EventToken, sta } /* -buildDMLStatement parses the tuples of the full stream comment. +buildDMLStatement recovers the PK from a FullBinlogStatement. +For RBR, the values are already in there, just need to be translated. +For SBR, parses the tuples of the full stream comment. The _stream comment is extracted into a StreamEvent.Statement. */ // Example query: insert into _table_(foo) values ('foo') /* _stream _table_ (eid id name ) (null 1 'bmFtZQ==' ); */ // the "null" value is used for auto-increment columns. -func (evs *EventStreamer) buildDMLStatement(sql string, insertid int64) (*querypb.StreamEvent_Statement, int64, error) { +func (evs *EventStreamer) buildDMLStatement(stmt FullBinlogStatement, insertid int64) (*querypb.StreamEvent_Statement, int64, error) { + // For RBR events, we know all this already, just extract it. + if stmt.PKNames != nil { + // We get an array of []sqltypes.Value, need to convert to querypb.Row. + dmlStatement := &querypb.StreamEvent_Statement{ + Category: querypb.StreamEvent_Statement_DML, + TableName: stmt.Table, + PrimaryKeyFields: stmt.PKNames, + PrimaryKeyValues: []*querypb.Row{sqltypes.RowToProto3(stmt.PKValues)}, + } + // InsertID is only needed to fill in the ID on next queries, + // but if we use RBR, it's already in the values, so just return 0. + return dmlStatement, 0, nil + } + + sql := string(stmt.Statement.Sql) + // first extract the comment commentIndex := strings.LastIndex(sql, streamCommentStart) if commentIndex == -1 { @@ -116,7 +135,7 @@ func (evs *EventStreamer) buildDMLStatement(sql string, insertid int64) (*queryp } dmlComment := sql[commentIndex+streamCommentStartLen:] - // then strat building the response + // then start building the response dmlStatement := &querypb.StreamEvent_Statement{ Category: querypb.StreamEvent_Statement_DML, } diff --git a/go/vt/binlog/keyrange_filter.go b/go/vt/binlog/keyrange_filter.go index 614700d9e2a..b203e0861eb 100644 --- a/go/vt/binlog/keyrange_filter.go +++ b/go/vt/binlog/keyrange_filter.go @@ -37,6 +37,18 @@ func KeyRangeFilterFunc(keyrange *topodatapb.KeyRange, callback func(*binlogdata case binlogdatapb.BinlogTransaction_Statement_BL_INSERT, binlogdatapb.BinlogTransaction_Statement_BL_UPDATE, binlogdatapb.BinlogTransaction_Statement_BL_DELETE: + // Handle RBR case first. + if statement.KeyspaceID != nil { + if !key.KeyRangeContains(keyrange, statement.KeyspaceID) { + // Skip keyspace ids that don't belong to the destination shard. + continue + } + filtered = append(filtered, statement.Statement) + matched = true + continue + } + + // SBR case. keyspaceIDS, err := sqlannotation.ExtractKeyspaceIDS(string(statement.Statement.Sql)) if err != nil { if statement.Statement.Category == binlogdatapb.BinlogTransaction_Statement_BL_INSERT { diff --git a/go/vt/binlog/keyspace_id_resolver.go b/go/vt/binlog/keyspace_id_resolver.go new file mode 100644 index 00000000000..4f71f3c3c61 --- /dev/null +++ b/go/vt/binlog/keyspace_id_resolver.go @@ -0,0 +1,164 @@ +package binlog + +import ( + "flag" + "fmt" + + "golang.org/x/net/context" + + "github.com/youtube/vitess/go/sqltypes" + "github.com/youtube/vitess/go/vt/key" + "github.com/youtube/vitess/go/vt/topo" + "github.com/youtube/vitess/go/vt/vtgate/vindexes" + "github.com/youtube/vitess/go/vt/vttablet/tabletserver/schema" + + topodatapb "github.com/youtube/vitess/go/vt/proto/topodata" +) + +var useV3ReshardingMode = flag.Bool("binlog_use_v3_resharding_mode", false, "True iff the binlog streamer should use V3-style sharding, which doesn't require a preset sharding key column.") + +// keyspaceIDResolver is constructed for a tableMap entry in RBR. It +// is used for each row, and passed in the value used for figuring out +// the keyspace id. +type keyspaceIDResolver interface { + // keyspaceID takes a table row, and returns the keyspace id as bytes. + // It will return an error if no sharding key can be found. + // The bitmap describes which columns are present in the row. + keyspaceID(value sqltypes.Value) ([]byte, error) +} + +// keyspaceIDResolverFactory creates a keyspaceIDResolver for a table +// given its schema. It returns the index of the field to used to compute +// the keyspaceID, and a function that given a value for that +// field, returns the keyspace id. +type keyspaceIDResolverFactory func(*schema.Table) (int, keyspaceIDResolver, error) + +// newKeyspaceIDResolverFactory creates a new +// keyspaceIDResolverFactory for the provided keyspace and cell. +func newKeyspaceIDResolverFactory(ctx context.Context, ts topo.Server, keyspace string, cell string) (keyspaceIDResolverFactory, error) { + if *useV3ReshardingMode { + return newKeyspaceIDResolverFactoryV3(ctx, ts, keyspace, cell) + } + + return newKeyspaceIDResolverFactoryV2(ctx, ts, keyspace) +} + +// newKeyspaceIDResolverFactoryV2 finds the ShardingColumnName / Type +// from the keyspace, and uses it to find the column name. +func newKeyspaceIDResolverFactoryV2(ctx context.Context, ts topo.Server, keyspace string) (keyspaceIDResolverFactory, error) { + ki, err := ts.GetKeyspace(ctx, keyspace) + if err != nil { + return nil, err + } + if ki.ShardingColumnName == "" { + return nil, fmt.Errorf("ShardingColumnName needs to be set for a v2 sharding key for keyspace %v", keyspace) + } + switch ki.ShardingColumnType { + case topodatapb.KeyspaceIdType_UNSET: + return nil, fmt.Errorf("ShardingColumnType needs to be set for a v2 sharding key for keyspace %v", keyspace) + case topodatapb.KeyspaceIdType_BYTES, topodatapb.KeyspaceIdType_UINT64: + // Supported values, we're good. + default: + return nil, fmt.Errorf("unknown ShardingColumnType %v for v2 sharding key for keyspace %v", ki.ShardingColumnType, keyspace) + } + return func(table *schema.Table) (int, keyspaceIDResolver, error) { + for i, col := range table.Columns { + if col.Name.EqualString(ki.ShardingColumnName) { + // We found the column. + return i, &keyspaceIDResolverFactoryV2{ + shardingColumnType: ki.ShardingColumnType, + }, nil + } + } + // The column was not found. + return -1, nil, fmt.Errorf("cannot find column %v in table %v", ki.ShardingColumnName, table.Name) + }, nil +} + +// keyspaceIDResolverFactoryV2 uses the KeyspaceInfo of the Keyspace +// to find the sharding column name. +type keyspaceIDResolverFactoryV2 struct { + shardingColumnType topodatapb.KeyspaceIdType +} + +func (r *keyspaceIDResolverFactoryV2) keyspaceID(v sqltypes.Value) ([]byte, error) { + switch r.shardingColumnType { + case topodatapb.KeyspaceIdType_BYTES: + return v.Raw(), nil + case topodatapb.KeyspaceIdType_UINT64: + i, err := v.ParseUint64() + if err != nil { + return nil, fmt.Errorf("Non numerical value: %v", err) + } + return key.Uint64Key(i).Bytes(), nil + default: + panic("unreachable") + } +} + +// newKeyspaceIDResolverFactoryV3 finds the SrvVSchema in the cell, +// gets the keyspace part, and uses it to find the column name. +func newKeyspaceIDResolverFactoryV3(ctx context.Context, ts topo.Server, keyspace string, cell string) (keyspaceIDResolverFactory, error) { + srvVSchema, err := ts.GetSrvVSchema(ctx, cell) + if err != nil { + return nil, err + } + kschema, ok := srvVSchema.Keyspaces[keyspace] + if !ok { + return nil, fmt.Errorf("SrvVSchema has no entry for keyspace %v", keyspace) + } + keyspaceSchema, err := vindexes.BuildKeyspaceSchema(kschema, keyspace) + if err != nil { + return nil, fmt.Errorf("cannot build vschema for keyspace %v: %v", keyspace, err) + } + return func(table *schema.Table) (int, keyspaceIDResolver, error) { + // Find the v3 schema. + tableSchema, ok := keyspaceSchema.Tables[table.Name.String()] + if !ok { + return -1, nil, fmt.Errorf("no vschema definition for table %v", table.Name) + } + + // The primary vindex is most likely the sharding key, + // and has to be unique. + if len(tableSchema.ColumnVindexes) == 0 { + return -1, nil, fmt.Errorf("no vindex definition for table %v", table.Name) + } + colVindex := tableSchema.ColumnVindexes[0] + if colVindex.Vindex.Cost() > 1 { + return -1, nil, fmt.Errorf("primary vindex cost is too high for table %v", table.Name) + } + unique, ok := colVindex.Vindex.(vindexes.Unique) + if !ok { + return -1, nil, fmt.Errorf("primary vindex is not unique for table %v", table.Name) + } + + shardingColumnName := colVindex.Column.String() + for i, col := range table.Columns { + if col.Name.EqualString(shardingColumnName) { + // We found the column. + return i, &keyspaceIDResolverFactoryV3{ + vindex: unique, + }, nil + } + } + // The column was not found. + return -1, nil, fmt.Errorf("cannot find column %v in table %v", shardingColumnName, table.Name) + }, nil +} + +// keyspaceIDResolverFactoryV3 uses the Vindex to compute the value. +type keyspaceIDResolverFactoryV3 struct { + vindex vindexes.Unique +} + +func (r *keyspaceIDResolverFactoryV3) keyspaceID(v sqltypes.Value) ([]byte, error) { + ids := []interface{}{v} + ksids, err := r.vindex.Map(nil, ids) + if err != nil { + return nil, err + } + if len(ksids) != 1 { + return nil, fmt.Errorf("maping row to keyspace id returned an invalid array of keyspace ids: %v", ksids) + } + return ksids[0], nil +} diff --git a/go/vt/binlog/updatestreamctl.go b/go/vt/binlog/updatestreamctl.go index 11eddf9935c..6e191531003 100644 --- a/go/vt/binlog/updatestreamctl.go +++ b/go/vt/binlog/updatestreamctl.go @@ -16,6 +16,7 @@ import ( "github.com/youtube/vitess/go/sync2" "github.com/youtube/vitess/go/tb" "github.com/youtube/vitess/go/vt/mysqlctl" + "github.com/youtube/vitess/go/vt/topo" "github.com/youtube/vitess/go/vt/vttablet/tabletserver/schema" binlogdatapb "github.com/youtube/vitess/go/vt/proto/binlogdata" @@ -93,9 +94,12 @@ func (m *UpdateStreamControlMock) IsEnabled() bool { // and UpdateStreamControl type UpdateStreamImpl struct { // the following variables are set at construction time - mysqld mysqlctl.MysqlDaemon - dbname string - se *schema.Engine + ts topo.Server + keyspace string + cell string + mysqld mysqlctl.MysqlDaemon + dbname string + se *schema.Engine // actionLock protects the following variables actionLock sync.Mutex @@ -155,11 +159,14 @@ type RegisterUpdateStreamServiceFunc func(UpdateStream) var RegisterUpdateStreamServices []RegisterUpdateStreamServiceFunc // NewUpdateStream returns a new UpdateStreamImpl object -func NewUpdateStream(mysqld mysqlctl.MysqlDaemon, se *schema.Engine, dbname string) *UpdateStreamImpl { +func NewUpdateStream(ts topo.Server, keyspace string, cell string, mysqld mysqlctl.MysqlDaemon, se *schema.Engine, dbname string) *UpdateStreamImpl { return &UpdateStreamImpl{ - mysqld: mysqld, - se: se, - dbname: dbname, + ts: ts, + keyspace: keyspace, + cell: cell, + mysqld: mysqld, + se: se, + dbname: dbname, } } @@ -245,6 +252,10 @@ func (updateStream *UpdateStreamImpl) StreamKeyRange(ctx context.Context, positi return callback(trans) }) bls := NewStreamer(updateStream.dbname, updateStream.mysqld, updateStream.se, charset, pos, 0, f) + bls.resolverFactory, err = newKeyspaceIDResolverFactory(ctx, updateStream.ts, updateStream.keyspace, updateStream.cell) + if err != nil { + return fmt.Errorf("newKeyspaceIDResolverFactory failed: %v", err) + } streamCtx, cancel := context.WithCancel(ctx) i := updateStream.streams.Add(cancel) diff --git a/go/vt/vttablet/tabletmanager/action_agent.go b/go/vt/vttablet/tabletmanager/action_agent.go index 7472db16a2b..275dede479f 100644 --- a/go/vt/vttablet/tabletmanager/action_agent.go +++ b/go/vt/vttablet/tabletmanager/action_agent.go @@ -546,7 +546,7 @@ func (agent *ActionAgent) Start(ctx context.Context, mysqlPort, vtPort, gRPCPort // (it needs the dbname, so it has to be delayed up to here, // but it has to be before updateState below that may use it) if initUpdateStream { - us := binlog.NewUpdateStream(agent.MysqlDaemon, agent.QueryServiceControl.SchemaEngine(), agent.DBConfigs.App.DbName) + us := binlog.NewUpdateStream(agent.TopoServer, agent.initialTablet.Keyspace, agent.TabletAlias.Cell, agent.MysqlDaemon, agent.QueryServiceControl.SchemaEngine(), agent.DBConfigs.App.DbName) agent.UpdateStream = us servenv.OnRun(func() { us.RegisterService() diff --git a/go/vt/worker/key_resolver.go b/go/vt/worker/key_resolver.go index 21597c6ce51..efb1d07f09a 100644 --- a/go/vt/worker/key_resolver.go +++ b/go/vt/worker/key_resolver.go @@ -12,10 +12,11 @@ import ( "github.com/youtube/vitess/go/vt/key" "github.com/youtube/vitess/go/vt/mysqlctl/tmutils" - tabletmanagerdatapb "github.com/youtube/vitess/go/vt/proto/tabletmanagerdata" - topodatapb "github.com/youtube/vitess/go/vt/proto/topodata" "github.com/youtube/vitess/go/vt/topo" "github.com/youtube/vitess/go/vt/vtgate/vindexes" + + tabletmanagerdatapb "github.com/youtube/vitess/go/vt/proto/tabletmanagerdata" + topodatapb "github.com/youtube/vitess/go/vt/proto/topodata" ) // This file defines the interface and implementations of sharding key resolvers. diff --git a/test/base_sharding.py b/test/base_sharding.py index 2a6ce34b1d0..e3cb51a6873 100644 --- a/test/base_sharding.py +++ b/test/base_sharding.py @@ -17,6 +17,7 @@ keyspace_id_type = keyrange_constants.KIT_UINT64 +use_rbr = False pack_keyspace_id = struct.Struct('!Q').pack # fixed_parent_id is used as fixed value for the "parent_id" column in all rows. diff --git a/test/config.json b/test/config.json index 03cffaa5e59..0996dd59c96 100644 --- a/test/config.json +++ b/test/config.json @@ -280,6 +280,17 @@ "worker_test" ] }, + "resharding_rbr": { + "File": "resharding_rbr.py", + "Args": [], + "Command": [], + "Manual": false, + "Shard": 0, + "RetryMax": 0, + "Tags": [ + "worker_test" + ] + }, "schema": { "File": "schema.py", "Args": [], @@ -385,6 +396,15 @@ "RetryMax": 0, "Tags": [] }, + "update_stream_rbr": { + "File": "update_stream_rbr.py", + "Args": [], + "Command": [], + "Manual": false, + "Shard": 4, + "RetryMax": 0, + "Tags": [] + }, "vertical_split": { "File": "vertical_split.py", "Args": [], diff --git a/test/resharding.py b/test/resharding.py index 7fa8c923377..ce8b738dd17 100755 --- a/test/resharding.py +++ b/test/resharding.py @@ -70,7 +70,8 @@ def setUpModule(): try: environment.topo_server().setup() - setup_procs = [t.init_mysql() for t in all_tablets] + setup_procs = [t.init_mysql(use_rbr=base_sharding.use_rbr) + for t in all_tablets] utils.Vtctld().start() utils.wait_procs(setup_procs) except: @@ -728,8 +729,17 @@ def test_resharding(self): utils.pause('Good time to test vtworker for diffs') # get status for destination master tablets, make sure we have it all - self.check_running_binlog_player(shard_2_master, 4022, 2008) - self.check_running_binlog_player(shard_3_master, 4024, 2008) + if base_sharding.use_rbr: + # We submitted non-annotated DMLs, that are properly routed + # with RBR, but not with SBR. So the first shard counts + # are smaller. In the second shard, we submitted statements + # that affect more than one keyspace id. These will result + # in two queries with RBR. So the count there is higher. + self.check_running_binlog_player(shard_2_master, 4018, 2008) + self.check_running_binlog_player(shard_3_master, 4028, 2008) + else: + self.check_running_binlog_player(shard_2_master, 4022, 2008) + self.check_running_binlog_player(shard_3_master, 4024, 2008) # start a thread to insert data into shard_1 in the background # with current time, and monitor the delay diff --git a/test/resharding_rbr.py b/test/resharding_rbr.py new file mode 100755 index 00000000000..f70e3100f3c --- /dev/null +++ b/test/resharding_rbr.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# +# Copyright 2017, Google Inc. All rights reserved. +# Use of this source code is governed by a BSD-style license that can +# be found in the LICENSE file. + +"""Re-runs resharding.py with RBR.""" + +import base_sharding +import resharding +import utils + +if __name__ == '__main__': + base_sharding.use_rbr = True + utils.main(resharding) diff --git a/test/tablet.py b/test/tablet.py index fa7df312d21..c57c5bb6bb5 100644 --- a/test/tablet.py +++ b/test/tablet.py @@ -180,7 +180,8 @@ def mysqlctld(self, cmd, extra_my_cnf=None, verbose=False, extra_args=None): args.extend(cmd) return utils.run_bg(args, extra_env=extra_env) - def init_mysql(self, extra_my_cnf=None, init_db=None, extra_args=None): + def init_mysql(self, extra_my_cnf=None, init_db=None, extra_args=None, + use_rbr=False): """Init the mysql tablet directory, starts mysqld. Either runs 'mysqlctl init', or starts a mysqlctld process. @@ -189,10 +190,17 @@ def init_mysql(self, extra_my_cnf=None, init_db=None, extra_args=None): extra_my_cnf: to pass to mysqlctl. init_db: if set, use this init_db script instead of the default. extra_args: passed to mysqlctld / mysqlctl. + use_rbr: configure the MySQL daemon to use RBR. Returns: The forked process. """ + if use_rbr: + if extra_my_cnf: + extra_my_cnf += ':' + environment.vttop + '/config/mycnf/rbr.cnf' + else: + extra_my_cnf = environment.vttop + '/config/mycnf/rbr.cnf' + if not init_db: init_db = environment.vttop + '/config/init_db.sql' diff --git a/test/update_stream.py b/test/update_stream.py index e78f010d7a6..2a01a4563c4 100755 --- a/test/update_stream.py +++ b/test/update_stream.py @@ -16,6 +16,9 @@ from protocols_flavor import protocols_flavor from vtgate_gateway_flavor.gateway import vtgate_gateway_flavor +# global flag to control which type of replication we use. +use_rbr = False + master_tablet = tablet.Tablet() replica_tablet = tablet.Tablet() @@ -59,8 +62,8 @@ def setUpModule(): environment.topo_server().setup() # start mysql instance external to the test - setup_procs = [master_tablet.init_mysql(), - replica_tablet.init_mysql()] + setup_procs = [master_tablet.init_mysql(use_rbr=use_rbr), + replica_tablet.init_mysql(use_rbr=use_rbr)] utils.wait_procs(setup_procs) # start a vtctld so the vtctl insert commands are just RPCs, not forks diff --git a/test/update_stream_rbr.py b/test/update_stream_rbr.py new file mode 100755 index 00000000000..6625b763dc3 --- /dev/null +++ b/test/update_stream_rbr.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python +# +# Copyright 2017, Google Inc. All rights reserved. +# Use of this source code is governed by a BSD-style license that can +# be found in the LICENSE file. + +"""Re-runs update_stream.py with RBR.""" + +import update_stream +import utils + +if __name__ == '__main__': + update_stream.use_rbr = True + utils.main(update_stream) diff --git a/test/vertical_split.py b/test/vertical_split.py index 72277debaa4..64205f5c510 100755 --- a/test/vertical_split.py +++ b/test/vertical_split.py @@ -18,9 +18,6 @@ from vtdb import keyrange_constants from vtdb import vtgate_client -# Global variables, for tests flavors. -use_rbr = False - # source keyspace, with 4 tables source_master = tablet.Tablet() source_replica = tablet.Tablet() @@ -41,10 +38,8 @@ def setUpModule(): try: environment.topo_server().setup() - extra_my_cnf = None - if use_rbr: - extra_my_cnf = environment.vttop + '/config/mycnf/rbr.cnf' - setup_procs = [t.init_mysql(extra_my_cnf=extra_my_cnf) for t in all_tablets] + setup_procs = [t.init_mysql(use_rbr=base_sharding.use_rbr) + for t in all_tablets] utils.Vtctld().start() utils.wait_procs(setup_procs) except: diff --git a/test/vertical_split_rbr.py b/test/vertical_split_rbr.py index 9b36fe74e77..39c1e6c6ddc 100755 --- a/test/vertical_split_rbr.py +++ b/test/vertical_split_rbr.py @@ -6,9 +6,10 @@ """Re-runs resharding.py with RBR on.""" +import base_sharding import vertical_split import utils if __name__ == '__main__': - vertical_split.use_rbr = True + base_sharding.use_rbr = True utils.main(vertical_split)