Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 152 additions & 72 deletions go/vt/mysqlctl/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"regexp"
"strings"
"sync"

"vitess.io/vitess/go/vt/vtgate/evalengine"

Expand All @@ -46,6 +47,15 @@ func (mysqld *Mysqld) executeSchemaCommands(sql string) error {
return mysqld.executeMysqlScript(params, strings.NewReader(sql))
}

// tableList returns an IN clause "('t1', 't2'...) for a list of tables."
func tableListSql(tables []string) string {
if len(tables) == 0 {
return "()"
}

return "('" + strings.Join(tables, "', '") + "')"
}

// GetSchema returns the schema for database for tables listed in
// tables. If tables is empty, return the schema for all tables.
func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excludeTables []string, includeViews bool) (*tabletmanagerdatapb.SchemaDefinition, error) {
Expand Down Expand Up @@ -75,11 +85,36 @@ func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excl
return sd, nil
}

sd.TableDefinitions = make([]*tabletmanagerdatapb.TableDefinition, 0, len(qr.Rows))
filter, err := tmutils.NewTableFilter(tables, excludeTables, includeViews)
if err != nil {
return nil, err
}

type schemaResult struct {
idx int
err error

td *tabletmanagerdatapb.TableDefinition
}

resChan := make(chan *schemaResult, 100)
ctx, cancel := context.WithCancel(ctx)
defer cancel()

var wg sync.WaitGroup

tableNames := make([]string, 0, len(qr.Rows))
i := 0
for _, row := range qr.Rows {
tableName := row[0].ToString()
tableType := row[1].ToString()

if !filter.Includes(tableName, tableType) {
continue
}

tableNames = append(tableNames, tableName)

// compute dataLength
var dataLength uint64
if !row[2].IsNull() {
Expand All @@ -99,51 +134,111 @@ func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, tables, excl
}
}

qr, fetchErr := mysqld.FetchSuperQuery(ctx, fmt.Sprintf("SHOW CREATE TABLE %s.%s", backtickDBName, sqlescape.EscapeID(tableName)))
if fetchErr != nil {
return nil, fetchErr
}
if len(qr.Rows) == 0 {
return nil, fmt.Errorf("empty create table statement for %v", tableName)
}
wg.Add(1)
go func(idx int) {
defer wg.Done()

// Normalize & remove auto_increment because it changes on every insert
// FIXME(alainjobart) find a way to share this with
// vt/tabletserver/table_info.go:162
norm := qr.Rows[0][1].ToString()
norm = autoIncr.ReplaceAllLiteralString(norm, "")
if tableType == tmutils.TableView {
// Views will have the dbname in there, replace it
// with {{.DatabaseName}}
norm = strings.Replace(norm, backtickDBName, "{{.DatabaseName}}", -1)
}
fields, columns, schema, err := mysqld.collectSchema(ctx, dbName, tableName, tableType)
if err != nil {
resChan <- &schemaResult{
idx: idx,
err: err,
}
return
}

resChan <- &schemaResult{
idx: idx,
td: &tabletmanagerdatapb.TableDefinition{
Name: tableName,
Type: tableType,
DataLength: dataLength,
RowCount: rowCount,
Fields: fields,
Columns: columns,
Schema: schema,
},
}
}(i)

td := &tabletmanagerdatapb.TableDefinition{}
td.Name = tableName
td.Schema = norm
i++
}

td.Fields, td.Columns, err = mysqld.GetColumns(ctx, dbName, tableName)
if err != nil {
return nil, err
}
td.PrimaryKeyColumns, err = mysqld.GetPrimaryKeyColumns(ctx, dbName, tableName)
go func() {
wg.Wait()
close(resChan)
}()

colMap := map[string][]string{}
if len(tableNames) > 0 {
log.Infof("mysqld GetSchema: GetPrimaryKeyColumns")
var err error
colMap, err = mysqld.getPrimaryKeyColumns(ctx, dbName, tableNames...)
if err != nil {
return nil, err
}
td.Type = tableType
td.DataLength = dataLength
td.RowCount = rowCount
sd.TableDefinitions = append(sd.TableDefinitions, td)
log.Infof("mysqld GetSchema: GetPrimaryKeyColumns done")
}

sd, err = tmutils.FilterTables(sd, tables, excludeTables, includeViews)
if err != nil {
return nil, err
log.Infof("mysqld GetSchema: Collecting all table schemas")
tds := make([]*tabletmanagerdatapb.TableDefinition, 0, i)
for res := range resChan {
if res.err != nil {
cancel()
return nil, res.err
}

td := res.td
td.PrimaryKeyColumns = colMap[td.Name]

tds[res.idx] = res.td
}
log.Infof("mysqld GetSchema: Collecting all table schemas done")

sd.TableDefinitions = tds

tmutils.GenerateSchemaVersion(sd)
return sd, nil
}

func (mysqld *Mysqld) collectSchema(ctx context.Context, dbName, tableName, tableType string) ([]*querypb.Field, []string, string, error) {
fields, columns, err := mysqld.GetColumns(ctx, dbName, tableName)
if err != nil {
return nil, nil, "", err
}

schema, err := mysqld.normalizedSchema(ctx, dbName, tableName, tableType)
if err != nil {
return nil, nil, "", err
}

return fields, columns, schema, nil
}

func (mysqld *Mysqld) normalizedSchema(ctx context.Context, dbName, tableName, tableType string) (string, error) {
backtickDBName := sqlescape.EscapeID(dbName)
qr, fetchErr := mysqld.FetchSuperQuery(ctx, fmt.Sprintf("SHOW CREATE TABLE %s.%s", dbName, sqlescape.EscapeID(tableName)))
if fetchErr != nil {
return "", fetchErr
}
if len(qr.Rows) == 0 {
return "", fmt.Errorf("empty create table statement for %v", tableName)
}

// Normalize & remove auto_increment because it changes on every insert
// FIXME(alainjobart) find a way to share this with
// vt/tabletserver/table_info.go:162
norm := qr.Rows[0][1].ToString()
norm = autoIncr.ReplaceAllLiteralString(norm, "")
if tableType == tmutils.TableView {
// Views will have the dbname in there, replace it
// with {{.DatabaseName}}
norm = strings.Replace(norm, backtickDBName, "{{.DatabaseName}}", -1)
}

return norm, nil
}

// ResolveTables returns a list of actual tables+views matching a list
// of regexps
func ResolveTables(ctx context.Context, mysqld MysqlDaemon, dbName string, tables []string) ([]string, error) {
Expand All @@ -166,11 +261,11 @@ func (mysqld *Mysqld) GetColumns(ctx context.Context, dbName, table string) ([]*
}
defer conn.Recycle()

sql := fmt.Sprintf("SELECT * FROM %s.%s WHERE 1=0", sqlescape.EscapeID(dbName), sqlescape.EscapeID(table))
qr, err := mysqld.executeFetchContext(ctx, conn, sql, 0, true)
qr, err := conn.ExecuteFetch(fmt.Sprintf("SELECT * FROM %s.%s WHERE 1=0", sqlescape.EscapeID(dbName), sqlescape.EscapeID(table)), 0, true)
if err != nil {
return nil, nil, err
}

columns := make([]string, len(qr.Fields))
for i, field := range qr.Fields {
columns[i] = field.Name
Expand All @@ -181,55 +276,40 @@ func (mysqld *Mysqld) GetColumns(ctx context.Context, dbName, table string) ([]*

// GetPrimaryKeyColumns returns the primary key columns of table.
func (mysqld *Mysqld) GetPrimaryKeyColumns(ctx context.Context, dbName, table string) ([]string, error) {
cs, err := mysqld.getPrimaryKeyColumns(ctx, dbName, table)
if err != nil {
return nil, err
}

return cs[dbName], nil
}

func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, tables ...string) (map[string][]string, error) {
conn, err := getPoolReconnect(ctx, mysqld.dbaPool)
if err != nil {
return nil, err
}
defer conn.Recycle()

sql := fmt.Sprintf("SHOW INDEX FROM %s.%s", sqlescape.EscapeID(dbName), sqlescape.EscapeID(table))
qr, err := mysqld.executeFetchContext(ctx, conn, sql, 100, true)
tableList := tableListSql(tables)
sql := fmt.Sprintf(`
SELECT table_name, ordinal_position, column_name
FROM information_schema.key_column_usage
WHERE table_schema = '%s'
AND table_name IN %s
AND constraint_name='PRIMARY'
ORDER BY table_name, ordinal_position`, dbName, tableList)
qr, err := conn.ExecuteFetch(sql, len(tables)*100, true)
if err != nil {
return nil, err
}
keyNameIndex := -1
seqInIndexIndex := -1
columnNameIndex := -1
for i, field := range qr.Fields {
switch field.Name {
case "Key_name":
keyNameIndex = i
case "Seq_in_index":
seqInIndexIndex = i
case "Column_name":
columnNameIndex = i
}
}
if keyNameIndex == -1 || seqInIndexIndex == -1 || columnNameIndex == -1 {
return nil, fmt.Errorf("unknown columns in 'show index' result: %v", qr.Fields)
}

columns := make([]string, 0, 5)
var expectedIndex int64 = 1
colMap := map[string][]string{}
for _, row := range qr.Rows {
// skip non-primary keys
if row[keyNameIndex].ToString() != "PRIMARY" {
continue
}

// check the Seq_in_index is always increasing
seqInIndex, err := evalengine.ToInt64(row[seqInIndexIndex])
if err != nil {
return nil, err
}
if seqInIndex != expectedIndex {
return nil, fmt.Errorf("unexpected index: %v != %v", seqInIndex, expectedIndex)
}
expectedIndex++

columns = append(columns, row[columnNameIndex].ToString())
tableName := row[0].ToString()
colMap[tableName] = append(colMap[tableName], row[2].ToString())
}
return columns, err
return colMap, err
}

// PreflightSchemaChange checks the schema changes in "changes" by applying them
Expand Down
Loading