diff --git a/go/vt/mysqlctl/schema.go b/go/vt/mysqlctl/schema.go index 7d57c5488cb..574f8178915 100644 --- a/go/vt/mysqlctl/schema.go +++ b/go/vt/mysqlctl/schema.go @@ -20,6 +20,7 @@ import ( "fmt" "regexp" "strings" + "sync" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -46,6 +47,15 @@ func (mysqld *Mysqld) executeSchemaCommands(sql string) error { return mysqld.executeMysqlScript(params, strings.NewReader(sql)) } +// tableList returns an IN clause "('t1', 't2'...) for a list of tables." +func tableListSql(tables []string) string { + if len(tables) == 0 { + return "()" + } + + return "('" + strings.Join(tables, "', '") + "')" +} + // GetSchema returns the schema for database for tables listed in // tables. If tables is empty, return the schema for all tables. func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excludeTables []string, includeViews bool) (*tabletmanagerdatapb.SchemaDefinition, error) { @@ -75,11 +85,36 @@ func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excl return sd, nil } - sd.TableDefinitions = make([]*tabletmanagerdatapb.TableDefinition, 0, len(qr.Rows)) + filter, err := tmutils.NewTableFilter(tables, excludeTables, includeViews) + if err != nil { + return nil, err + } + + type schemaResult struct { + idx int + err error + + td *tabletmanagerdatapb.TableDefinition + } + + resChan := make(chan *schemaResult, 100) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var wg sync.WaitGroup + + tableNames := make([]string, 0, len(qr.Rows)) + i := 0 for _, row := range qr.Rows { tableName := row[0].ToString() tableType := row[1].ToString() + if !filter.Includes(tableName, tableType) { + continue + } + + tableNames = append(tableNames, tableName) + // compute dataLength var dataLength uint64 if !row[2].IsNull() { @@ -99,51 +134,111 @@ func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excl } } - qr, fetchErr := mysqld.FetchSuperQuery(ctx, fmt.Sprintf("SHOW CREATE TABLE %s.%s", backtickDBName, sqlescape.EscapeID(tableName))) - if fetchErr != nil { - return nil, fetchErr - } - if len(qr.Rows) == 0 { - return nil, fmt.Errorf("empty create table statement for %v", tableName) - } + wg.Add(1) + go func(idx int) { + defer wg.Done() - // Normalize & remove auto_increment because it changes on every insert - // FIXME(alainjobart) find a way to share this with - // vt/tabletserver/table_info.go:162 - norm := qr.Rows[0][1].ToString() - norm = autoIncr.ReplaceAllLiteralString(norm, "") - if tableType == tmutils.TableView { - // Views will have the dbname in there, replace it - // with {{.DatabaseName}} - norm = strings.Replace(norm, backtickDBName, "{{.DatabaseName}}", -1) - } + fields, columns, schema, err := mysqld.collectSchema(ctx, dbName, tableName, tableType) + if err != nil { + resChan <- &schemaResult{ + idx: idx, + err: err, + } + return + } + + resChan <- &schemaResult{ + idx: idx, + td: &tabletmanagerdatapb.TableDefinition{ + Name: tableName, + Type: tableType, + DataLength: dataLength, + RowCount: rowCount, + Fields: fields, + Columns: columns, + Schema: schema, + }, + } + }(i) - td := &tabletmanagerdatapb.TableDefinition{} - td.Name = tableName - td.Schema = norm + i++ + } - td.Fields, td.Columns, err = mysqld.GetColumns(ctx, dbName, tableName) - if err != nil { - return nil, err - } - td.PrimaryKeyColumns, err = mysqld.GetPrimaryKeyColumns(ctx, dbName, tableName) + go func() { + wg.Wait() + close(resChan) + }() + + colMap := map[string][]string{} + if len(tableNames) > 0 { + log.Infof("mysqld GetSchema: GetPrimaryKeyColumns") + var err error + colMap, err = mysqld.getPrimaryKeyColumns(ctx, dbName, tableNames...) if err != nil { return nil, err } - td.Type = tableType - td.DataLength = dataLength - td.RowCount = rowCount - sd.TableDefinitions = append(sd.TableDefinitions, td) + log.Infof("mysqld GetSchema: GetPrimaryKeyColumns done") } - sd, err = tmutils.FilterTables(sd, tables, excludeTables, includeViews) - if err != nil { - return nil, err + log.Infof("mysqld GetSchema: Collecting all table schemas") + tds := make([]*tabletmanagerdatapb.TableDefinition, 0, i) + for res := range resChan { + if res.err != nil { + cancel() + return nil, res.err + } + + td := res.td + td.PrimaryKeyColumns = colMap[td.Name] + + tds[res.idx] = res.td } + log.Infof("mysqld GetSchema: Collecting all table schemas done") + + sd.TableDefinitions = tds + tmutils.GenerateSchemaVersion(sd) return sd, nil } +func (mysqld *Mysqld) collectSchema(ctx context.Context, dbName, tableName, tableType string) ([]*querypb.Field, []string, string, error) { + fields, columns, err := mysqld.GetColumns(ctx, dbName, tableName) + if err != nil { + return nil, nil, "", err + } + + schema, err := mysqld.normalizedSchema(ctx, dbName, tableName, tableType) + if err != nil { + return nil, nil, "", err + } + + return fields, columns, schema, nil +} + +func (mysqld *Mysqld) normalizedSchema(ctx context.Context, dbName, tableName, tableType string) (string, error) { + backtickDBName := sqlescape.EscapeID(dbName) + qr, fetchErr := mysqld.FetchSuperQuery(ctx, fmt.Sprintf("SHOW CREATE TABLE %s.%s", dbName, sqlescape.EscapeID(tableName))) + if fetchErr != nil { + return "", fetchErr + } + if len(qr.Rows) == 0 { + return "", fmt.Errorf("empty create table statement for %v", tableName) + } + + // Normalize & remove auto_increment because it changes on every insert + // FIXME(alainjobart) find a way to share this with + // vt/tabletserver/table_info.go:162 + norm := qr.Rows[0][1].ToString() + norm = autoIncr.ReplaceAllLiteralString(norm, "") + if tableType == tmutils.TableView { + // Views will have the dbname in there, replace it + // with {{.DatabaseName}} + norm = strings.Replace(norm, backtickDBName, "{{.DatabaseName}}", -1) + } + + return norm, nil +} + // ResolveTables returns a list of actual tables+views matching a list // of regexps func ResolveTables(ctx context.Context, mysqld MysqlDaemon, dbName string, tables []string) ([]string, error) { @@ -166,11 +261,11 @@ func (mysqld *Mysqld) GetColumns(ctx context.Context, dbName, table string) ([]* } defer conn.Recycle() - sql := fmt.Sprintf("SELECT * FROM %s.%s WHERE 1=0", sqlescape.EscapeID(dbName), sqlescape.EscapeID(table)) - qr, err := mysqld.executeFetchContext(ctx, conn, sql, 0, true) + qr, err := conn.ExecuteFetch(fmt.Sprintf("SELECT * FROM %s.%s WHERE 1=0", sqlescape.EscapeID(dbName), sqlescape.EscapeID(table)), 0, true) if err != nil { return nil, nil, err } + columns := make([]string, len(qr.Fields)) for i, field := range qr.Fields { columns[i] = field.Name @@ -181,55 +276,40 @@ func (mysqld *Mysqld) GetColumns(ctx context.Context, dbName, table string) ([]* // GetPrimaryKeyColumns returns the primary key columns of table. func (mysqld *Mysqld) GetPrimaryKeyColumns(ctx context.Context, dbName, table string) ([]string, error) { + cs, err := mysqld.getPrimaryKeyColumns(ctx, dbName, table) + if err != nil { + return nil, err + } + + return cs[dbName], nil +} + +func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, tables ...string) (map[string][]string, error) { conn, err := getPoolReconnect(ctx, mysqld.dbaPool) if err != nil { return nil, err } defer conn.Recycle() - sql := fmt.Sprintf("SHOW INDEX FROM %s.%s", sqlescape.EscapeID(dbName), sqlescape.EscapeID(table)) - qr, err := mysqld.executeFetchContext(ctx, conn, sql, 100, true) + tableList := tableListSql(tables) + sql := fmt.Sprintf(` + SELECT table_name, ordinal_position, column_name + FROM information_schema.key_column_usage + WHERE table_schema = '%s' + AND table_name IN %s + AND constraint_name='PRIMARY' + ORDER BY table_name, ordinal_position`, dbName, tableList) + qr, err := conn.ExecuteFetch(sql, len(tables)*100, true) if err != nil { return nil, err } - keyNameIndex := -1 - seqInIndexIndex := -1 - columnNameIndex := -1 - for i, field := range qr.Fields { - switch field.Name { - case "Key_name": - keyNameIndex = i - case "Seq_in_index": - seqInIndexIndex = i - case "Column_name": - columnNameIndex = i - } - } - if keyNameIndex == -1 || seqInIndexIndex == -1 || columnNameIndex == -1 { - return nil, fmt.Errorf("unknown columns in 'show index' result: %v", qr.Fields) - } - columns := make([]string, 0, 5) - var expectedIndex int64 = 1 + colMap := map[string][]string{} for _, row := range qr.Rows { - // skip non-primary keys - if row[keyNameIndex].ToString() != "PRIMARY" { - continue - } - - // check the Seq_in_index is always increasing - seqInIndex, err := evalengine.ToInt64(row[seqInIndexIndex]) - if err != nil { - return nil, err - } - if seqInIndex != expectedIndex { - return nil, fmt.Errorf("unexpected index: %v != %v", seqInIndex, expectedIndex) - } - expectedIndex++ - - columns = append(columns, row[columnNameIndex].ToString()) + tableName := row[0].ToString() + colMap[tableName] = append(colMap[tableName], row[2].ToString()) } - return columns, err + return colMap, err } // PreflightSchemaChange checks the schema changes in "changes" by applying them diff --git a/go/vt/mysqlctl/tmutils/schema.go b/go/vt/mysqlctl/tmutils/schema.go index 58e547d16a2..86cf509e1ba 100644 --- a/go/vt/mysqlctl/tmutils/schema.go +++ b/go/vt/mysqlctl/tmutils/schema.go @@ -63,94 +63,130 @@ func (tds TableDefinitions) Swap(i, j int) { tds[i], tds[j] = tds[j], tds[i] } -// FilterTables returns a copy which includes only whitelisted tables +// TableFilter is a filter for table names and types. +type TableFilter struct { + includeViews bool + + filterTables bool + tableNames []string + tableREs []*regexp.Regexp + + filterExcludeTables bool + excludeTableNames []string + excludeTableREs []*regexp.Regexp +} + +// NewTableFilter creates a TableFilter for whitelisted tables // (tables), no blacklisted tables (excludeTables) and optionally // views (includeViews). -func FilterTables(sd *tabletmanagerdatapb.SchemaDefinition, tables, excludeTables []string, includeViews bool) (*tabletmanagerdatapb.SchemaDefinition, error) { - copy := *sd - copy.TableDefinitions = make([]*tabletmanagerdatapb.TableDefinition, 0, len(sd.TableDefinitions)) +func NewTableFilter(tables, excludeTables []string, includeViews bool) (*TableFilter, error) { + f := &TableFilter{ + includeViews: includeViews, + } // Build a list of regexp to match table names against. // We only use regexps if the name starts and ends with '/'. // Otherwise the entry in the arrays is nil, and we use the original // table name. - var tableRegexps []*regexp.Regexp if len(tables) > 0 { - tableRegexps = make([]*regexp.Regexp, len(tables)) - for i, table := range tables { - if strings.HasPrefix(table, "/") { + f.filterTables = true + for _, table := range tables { + if strings.HasPrefix(table, "/") && strings.HasSuffix(table, "/") { table = strings.Trim(table, "/") - var err error - tableRegexps[i], err = regexp.Compile(table) + re, err := regexp.Compile(table) if err != nil { return nil, fmt.Errorf("cannot compile regexp %v for table: %v", table, err) } + + f.tableREs = append(f.tableREs, re) + } else { + f.tableNames = append(f.tableNames, table) } } } - var excludeTableRegexps []*regexp.Regexp + if len(excludeTables) > 0 { - excludeTableRegexps = make([]*regexp.Regexp, len(excludeTables)) - for i, table := range excludeTables { - if strings.HasPrefix(table, "/") { + f.filterExcludeTables = true + for _, table := range excludeTables { + if strings.HasPrefix(table, "/") && strings.HasSuffix(table, "/") { table = strings.Trim(table, "/") - var err error - excludeTableRegexps[i], err = regexp.Compile(table) + re, err := regexp.Compile(table) if err != nil { return nil, fmt.Errorf("cannot compile regexp %v for excludeTable: %v", table, err) } + + f.excludeTableREs = append(f.tableREs, re) + } else { + f.excludeTableNames = append(f.excludeTableNames, table) } } } - for _, table := range sd.TableDefinitions { - // Check it's a table we want. - if len(tables) > 0 { - foundMatch := false - for i, tableRegexp := range tableRegexps { - if tableRegexp == nil { - // Not a regexp, just compare in a - // case insensitive way. - if strings.EqualFold(tables[i], table.Name) { - foundMatch = true - break - } - } else { - if tableRegexp.MatchString(table.Name) { - foundMatch = true - break - } - } - } - if !foundMatch { - continue + return f, nil +} + +// Includes returns whether a tableName/tableType should be included in this TableFilter. +func (f *TableFilter) Includes(tableName string, tableType string) bool { + if f.filterTables { + matches := false + for _, name := range f.tableNames { + if strings.EqualFold(name, tableName) { + matches = true + break } } - excluded := false - for i, tableRegexp := range excludeTableRegexps { - if tableRegexp == nil { - // Not a regexp, just compare in a - // case insensitive way. - if strings.EqualFold(excludeTables[i], table.Name) { - excluded = true - break - } - } else { - if tableRegexp.MatchString(table.Name) { - excluded = true + + if !matches { + for _, re := range f.tableREs { + if re.MatchString(tableName) { + matches = true break } } } - if excluded { - continue + + if !matches { + return false } + } - if !includeViews && table.Type == TableView { - continue + if f.filterExcludeTables { + for _, name := range f.excludeTableNames { + if strings.EqualFold(name, tableName) { + return false + } + } + + for _, re := range f.excludeTableREs { + if re.MatchString(tableName) { + return false + } } + } + + if !f.includeViews && tableType == TableView { + return false + } + + return true +} - copy.TableDefinitions = append(copy.TableDefinitions, table) +// FilterTables returns a copy which includes only whitelisted tables +// (tables), no blacklisted tables (excludeTables) and optionally +// views (includeViews). +func FilterTables(sd *tabletmanagerdatapb.SchemaDefinition, tables, excludeTables []string, includeViews bool) (*tabletmanagerdatapb.SchemaDefinition, error) { + copy := *sd + copy.TableDefinitions = make([]*tabletmanagerdatapb.TableDefinition, 0, len(sd.TableDefinitions)) + + f, err := NewTableFilter(tables, excludeTables, includeViews) + if err != nil { + return nil, err + } + + for _, table := range sd.TableDefinitions { + if f.Includes(table.Name, table.Type) { + copy.TableDefinitions = append(copy.TableDefinitions, table) + } } // Regenerate hash over tables because it may have changed. diff --git a/go/vt/mysqlctl/tmutils/schema_test.go b/go/vt/mysqlctl/tmutils/schema_test.go index 1775aae805d..0b851721614 100644 --- a/go/vt/mysqlctl/tmutils/schema_test.go +++ b/go/vt/mysqlctl/tmutils/schema_test.go @@ -256,6 +256,216 @@ func TestSchemaDiff(t *testing.T) { testDiff(t, sd1, sd2, "sd1", "sd2", []string{"schemas differ on table table2:\nsd1: schema2\n differs from:\nsd2: schema3"}) } +func TestTableFilter(t *testing.T) { + includedTable := "t1" + includedTable2 := "t2" + excludedTable := "e1" + view := "v1" + + includedTableRE := "/t.*/" + excludedTableRE := "/e.*/" + + tcs := []struct { + desc string + tables []string + excludeTables []string + includeViews bool + + tableName string + tableType string + + hasErr bool + included bool + }{ + { + desc: "everything allowed includes table", + includeViews: true, + + tableName: includedTable, + tableType: TableBaseTable, + + included: true, + }, + { + desc: "everything allowed includes view", + includeViews: true, + + tableName: view, + tableType: TableView, + + included: true, + }, + { + desc: "table list includes matching 1st table", + tables: []string{includedTable, includedTable2}, + includeViews: true, + + tableName: includedTable, + tableType: TableBaseTable, + + included: true, + }, + { + desc: "table list includes matching 2nd table", + tables: []string{includedTable, includedTable2}, + includeViews: true, + + tableName: includedTable2, + tableType: TableBaseTable, + + included: true, + }, + { + desc: "table list excludes non-matching table", + tables: []string{includedTable, includedTable2}, + includeViews: true, + + tableName: excludedTable, + tableType: TableBaseTable, + + included: false, + }, + { + desc: "table list include view includes matching view", + tables: []string{view}, + includeViews: true, + + tableName: view, + tableType: TableView, + + included: true, + }, + { + desc: "table list exclude view excludes matching view", + tables: []string{view}, + includeViews: false, + + tableName: view, + tableType: TableView, + + included: false, + }, + { + desc: "table regexp list includes matching table", + tables: []string{includedTableRE}, + includeViews: false, + + tableName: includedTable, + tableType: TableBaseTable, + + included: true, + }, + { + desc: "exclude table list excludes matching table", + excludeTables: []string{excludedTable}, + + tableName: excludedTable, + tableType: TableBaseTable, + + included: false, + }, + { + desc: "exclude table list includes non-matching table", + excludeTables: []string{excludedTable}, + + tableName: includedTable, + tableType: TableBaseTable, + + included: true, + }, + { + desc: "exclude table list includes non-matching view", + excludeTables: []string{excludedTable}, + includeViews: true, + + tableName: view, + tableType: TableView, + + included: true, + }, + { + desc: "exclude table list excludes matching view", + excludeTables: []string{excludedTable}, + includeViews: true, + + tableName: excludedTable, + tableType: TableView, + + included: false, + }, + { + desc: "exclude table list excludes matching view", + excludeTables: []string{excludedTable}, + includeViews: true, + + tableName: excludedTable, + tableType: TableView, + + included: false, + }, + { + desc: "exclude table regexp list excludes matching table", + excludeTables: []string{excludedTableRE}, + includeViews: false, + + tableName: excludedTable, + tableType: TableBaseTable, + + included: false, + }, + { + desc: "table list with excludes includes matching table", + tables: []string{includedTable}, + excludeTables: []string{excludedTable}, + + tableName: includedTable, + tableType: TableBaseTable, + + included: true, + }, + { + desc: "table list with excludes excludes matching excluded table", + tables: []string{includedTable}, + excludeTables: []string{excludedTable}, + + tableName: excludedTable, + tableType: TableBaseTable, + + included: false, + }, + { + desc: "bad table regexp", + tables: []string{"/*/"}, + + hasErr: true, + }, + { + desc: "bad exclude table regexp", + excludeTables: []string{"/*/"}, + + hasErr: true, + }, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + f, err := NewTableFilter(tc.tables, tc.excludeTables, tc.includeViews) + if tc.hasErr != (err != nil) { + t.Fatalf("hasErr not right: %v, tc: %+v", err, tc) + } + + if tc.hasErr { + return + } + + included := f.Includes(tc.tableName, tc.tableType) + if tc.included != included { + t.Fatalf("included is not right: %v\nfilter: %+v\ntc: %+v", included, f, tc) + } + }) + } +} + func TestFilterTables(t *testing.T) { var testcases = []struct { desc string