diff --git a/go/test/endtoend/tabletmanager/commands_test.go b/go/test/endtoend/tabletmanager/commands_test.go index 1a2d2424cb4..64598e4a891 100644 --- a/go/test/endtoend/tabletmanager/commands_test.go +++ b/go/test/endtoend/tabletmanager/commands_test.go @@ -39,7 +39,7 @@ var ( getSchemaT1Results8030 = "CREATE TABLE `t1` (\n `id` bigint NOT NULL,\n `value` varchar(16) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb3" getSchemaT1Results80 = "CREATE TABLE `t1` (\n `id` bigint NOT NULL,\n `value` varchar(16) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8" getSchemaT1Results57 = "CREATE TABLE `t1` (\n `id` bigint(20) NOT NULL,\n `value` varchar(16) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8" - getSchemaV1Results = fmt.Sprintf("CREATE ALGORITHM=UNDEFINED DEFINER=`%s`@`%s` SQL SECURITY DEFINER VIEW {{.DatabaseName}}.`v1` AS select {{.DatabaseName}}.`t1`.`id` AS `id`,{{.DatabaseName}}.`t1`.`value` AS `value` from {{.DatabaseName}}.`t1`", username, hostname) + getSchemaV1Results = fmt.Sprintf("create algorithm = UNDEFINED definer = %s@%s sql security DEFINER view {{.DatabaseName}}.v1 as select {{.DatabaseName}}.t1.id as id, {{.DatabaseName}}.t1.value as value from {{.DatabaseName}}.t1", username, hostname) ) // TabletCommands tests the basic tablet commands diff --git a/go/vt/mysqlctl/schema.go b/go/vt/mysqlctl/schema.go index 397668145ef..c2ecb5a50dc 100644 --- a/go/vt/mysqlctl/schema.go +++ b/go/vt/mysqlctl/schema.go @@ -34,6 +34,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" ) @@ -99,7 +100,7 @@ func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, request *tab if len(qr.Rows) == 0 { return nil, fmt.Errorf("empty create database statement for %v", dbName) } - sd.DatabaseSchema = strings.Replace(qr.Rows[0][1].ToString(), backtickDBName, "{{.DatabaseName}}", 1) + sd.DatabaseSchema = strings.Replace(qr.Rows[0][1].ToString(), backtickDBName, tmutils.DatabaseNamePlaceholder, 1) tds, err := mysqld.collectBasicTableData(ctx, dbName, request.Tables, request.ExcludeTables, request.IncludeViews) if err != nil { @@ -249,6 +250,29 @@ func (mysqld *Mysqld) collectSchema(ctx context.Context, dbName, tableName, tabl return fields, columns, schema, nil } +// normalizedStatement normalizes a CREATE TABLE or CREATE VIEW statement as follows: +// - For CREATE TABLE, it stripts away any AUTO_INCREMENT=... clause. +// - For CREATE VIEW, it replaces the schema name with given `dbName` +func normalizedStatement(ctx context.Context, statementQuery, dbName, tableType string) (string, error) { + // Normalize & remove auto_increment because it changes on every insert + // FIXME(alainjobart) find a way to share this with + // vt/tabletserver/table_info.go:162 + norm := statementQuery + norm = autoIncr.ReplaceAllLiteralString(norm, "") + if tableType == tmutils.TableView { + replaced, err := sqlparser.ReplaceTableQualifiers(norm, dbName, tmutils.DatabaseNamePlaceholder) + if err != nil { + // parsing unsuccessful + return norm, err + } + // Parsing successful + replaced = tmutils.UnqualifyDatabaseNamePlaceholder(replaced) + return replaced, nil + } + + return norm, nil +} + // normalizedSchema returns a table schema with database names replaced, and auto_increment annotations removed. func (mysqld *Mysqld) normalizedSchema(ctx context.Context, dbName, tableName, tableType string) (string, error) { backtickDBName := sqlescape.EscapeID(dbName) @@ -263,15 +287,7 @@ func (mysqld *Mysqld) normalizedSchema(ctx context.Context, dbName, tableName, t // Normalize & remove auto_increment because it changes on every insert // FIXME(alainjobart) find a way to share this with // vt/tabletserver/table_info.go:162 - norm := qr.Rows[0][1].ToString() - norm = autoIncr.ReplaceAllLiteralString(norm, "") - if tableType == tmutils.TableView { - // Views will have the dbname in there, replace it - // with {{.DatabaseName}} - norm = strings.Replace(norm, backtickDBName, "{{.DatabaseName}}", -1) - } - - return norm, nil + return normalizedStatement(ctx, qr.Rows[0][1].ToString(), dbName, tableType) } // ResolveTables returns a list of actual tables+views matching a list @@ -410,76 +426,9 @@ func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, t return colMap, err } -// PreflightSchemaChange checks the schema changes in "changes" by applying them -// to an intermediate database that has the same schema as the target database. +// PreflightSchemaChange is deprecated func (mysqld *Mysqld) PreflightSchemaChange(ctx context.Context, dbName string, changes []string) ([]*tabletmanagerdatapb.SchemaChangeResult, error) { - results := make([]*tabletmanagerdatapb.SchemaChangeResult, len(changes)) - - // Get current schema from the real database. - req := &tabletmanagerdatapb.GetSchemaRequest{IncludeViews: true, TableSchemaOnly: true} - originalSchema, err := mysqld.GetSchema(ctx, dbName, req) - if err != nil { - return nil, err - } - - // Populate temporary database with it. - initialCopySQL := "SET sql_log_bin = 0;\n" - initialCopySQL += "DROP DATABASE IF EXISTS _vt_preflight;\n" - initialCopySQL += "CREATE DATABASE _vt_preflight;\n" - initialCopySQL += "USE _vt_preflight;\n" - // We're not smart enough to create the tables in a foreign-key-compatible way, - // so we temporarily disable foreign key checks while adding the existing tables. - initialCopySQL += "SET foreign_key_checks = 0;\n" - for _, td := range originalSchema.TableDefinitions { - if td.Type == tmutils.TableBaseTable { - initialCopySQL += td.Schema + ";\n" - } - } - for _, td := range originalSchema.TableDefinitions { - if td.Type == tmutils.TableView { - // Views will have {{.DatabaseName}} in there, replace - // it with _vt_preflight - s := strings.Replace(td.Schema, "{{.DatabaseName}}", "`_vt_preflight`", -1) - initialCopySQL += s + ";\n" - } - } - if err = mysqld.executeSchemaCommands(ctx, initialCopySQL); err != nil { - return nil, err - } - - // For each change, record the schema before and after. - for i, change := range changes { - req := &tabletmanagerdatapb.GetSchemaRequest{IncludeViews: true} - beforeSchema, err := mysqld.GetSchema(ctx, "_vt_preflight", req) - if err != nil { - return nil, err - } - - // apply schema change to the temporary database - sql := "SET sql_log_bin = 0;\n" - sql += "USE _vt_preflight;\n" - sql += change - if err = mysqld.executeSchemaCommands(ctx, sql); err != nil { - return nil, err - } - - // get the result - afterSchema, err := mysqld.GetSchema(ctx, "_vt_preflight", req) - if err != nil { - return nil, err - } - - results[i] = &tabletmanagerdatapb.SchemaChangeResult{BeforeSchema: beforeSchema, AfterSchema: afterSchema} - } - - // and clean up the extra database - dropSQL := "SET sql_log_bin = 0;\n" - dropSQL += "DROP DATABASE _vt_preflight;\n" - if err = mysqld.executeSchemaCommands(ctx, dropSQL); err != nil { - return nil, err - } - - return results, nil + return nil, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "PreflightSchemaChange is deprecated") } // ApplySchemaChange will apply the schema change to the given database. diff --git a/go/vt/mysqlctl/schema_test.go b/go/vt/mysqlctl/schema_test.go index fb64f8ca8ee..cf4e42376a6 100644 --- a/go/vt/mysqlctl/schema_test.go +++ b/go/vt/mysqlctl/schema_test.go @@ -1,13 +1,16 @@ package mysqlctl import ( + "context" "fmt" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/fakesqldb" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/mysqlctl/tmutils" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -103,3 +106,58 @@ func TestColumnList(t *testing.T) { require.Equal(t, `[name:"col1" type:VARCHAR]`, fmt.Sprintf("%+v", fields)) } + +func TestNormalizedStatement(t *testing.T) { + tcases := []struct { + statement string + db string + typ string + expect string + }{ + { + statement: "create table mydb.t (id int auto_increment primary key) AUTO_INCREMENT=4", + db: "mydb", + typ: tmutils.TableBaseTable, + expect: "create table mydb.t (id int auto_increment primary key)", + }, + { + statement: "create table `mydb`.t (id int primary key)", + db: "mydb", + typ: tmutils.TableBaseTable, + expect: "create table `mydb`.t (id int primary key)", + }, + { + statement: "create view `mydb`.v as select * from t", + db: "mydb", + typ: tmutils.TableView, + expect: "create view {{.DatabaseName}}.v as select * from t", + }, + { + statement: "create view `mydb`.v as select * from `mydb`.`t`", + db: "mydb", + typ: tmutils.TableView, + expect: "create view {{.DatabaseName}}.v as select * from {{.DatabaseName}}.t", + }, + { + statement: "create view `mydb`.v as select * from `mydb`.mydb", + db: "mydb", + typ: tmutils.TableView, + expect: "create view {{.DatabaseName}}.v as select * from {{.DatabaseName}}.mydb", + }, + { + statement: "create view `mydb`.v as select * from `mydb`.`mydb`", + db: "mydb", + typ: tmutils.TableView, + expect: "create view {{.DatabaseName}}.v as select * from {{.DatabaseName}}.mydb", + }, + } + ctx := context.Background() + for _, tcase := range tcases { + testName := tcase.statement + t.Run(testName, func(t *testing.T) { + result, err := normalizedStatement(ctx, tcase.statement, tcase.db, tcase.typ) + assert.NoError(t, err) + assert.Equal(t, tcase.expect, result) + }) + } +} diff --git a/go/vt/mysqlctl/tmutils/schema.go b/go/vt/mysqlctl/tmutils/schema.go index aae529f89b0..c6f2eeb7158 100644 --- a/go/vt/mysqlctl/tmutils/schema.go +++ b/go/vt/mysqlctl/tmutils/schema.go @@ -27,6 +27,8 @@ import ( "vitess.io/vitess/go/vt/concurrency" "vitess.io/vitess/go/vt/schema" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" ) @@ -38,8 +40,17 @@ const ( TableBaseTable = "BASE TABLE" // TableView indicates the table type is a view. TableView = "VIEW" + + DatabaseNamePlaceholder = "{{.DatabaseName}}" ) +func UnqualifyDatabaseNamePlaceholder(s string) string { + return strings.Replace(s, sqlescape.EscapeID(DatabaseNamePlaceholder), DatabaseNamePlaceholder, -1) +} +func QualifyDatabaseNamePlaceholder(s string) string { + return strings.Replace(s, DatabaseNamePlaceholder, sqlescape.EscapeID(DatabaseNamePlaceholder), -1) +} + // TableDefinitionGetColumn returns the index of a column inside a // TableDefinition. func TableDefinitionGetColumn(td *tabletmanagerdatapb.TableDefinition, name string) (index int, ok bool) { @@ -206,14 +217,15 @@ func SchemaDefinitionGetTable(sd *tabletmanagerdatapb.SchemaDefinition, table st // SchemaDefinitionToSQLStrings converts a SchemaDefinition to an array of SQL strings. The array contains all // the SQL statements needed for creating the database, tables, and views - in that order. // All SQL statements will have {{.DatabaseName}} in place of the actual db name. -func SchemaDefinitionToSQLStrings(sd *tabletmanagerdatapb.SchemaDefinition) []string { +func SchemaDefinitionToSQLStrings(sd *tabletmanagerdatapb.SchemaDefinition) ([]string, error) { sqlStrings := make([]string, 0, len(sd.TableDefinitions)+1) createViewSQL := make([]string, 0, len(sd.TableDefinitions)) // Backtick database name since keyspace names appear in the routing rules, and they might need to be escaped. // We unescape() them first in case we have an explicitly escaped string was specified. - createDatabaseSQL := strings.Replace(sd.DatabaseSchema, "`{{.DatabaseName}}`", "{{.DatabaseName}}", -1) - createDatabaseSQL = strings.Replace(createDatabaseSQL, "{{.DatabaseName}}", sqlescape.EscapeID("{{.DatabaseName}}"), -1) + createDatabaseSQL := sd.DatabaseSchema + createDatabaseSQL = UnqualifyDatabaseNamePlaceholder(createDatabaseSQL) + createDatabaseSQL = QualifyDatabaseNamePlaceholder(createDatabaseSQL) sqlStrings = append(sqlStrings, createDatabaseSQL) for _, td := range sd.TableDefinitions { @@ -223,17 +235,16 @@ func SchemaDefinitionToSQLStrings(sd *tabletmanagerdatapb.SchemaDefinition) []st if td.Type == TableView { createViewSQL = append(createViewSQL, td.Schema) } else { - lines := strings.Split(td.Schema, "\n") - for i, line := range lines { - if strings.HasPrefix(line, "CREATE TABLE `") { - lines[i] = strings.Replace(line, "CREATE TABLE `", "CREATE TABLE `{{.DatabaseName}}`.`", 1) - } + replaced, err := sqlparser.ReplaceTableQualifiers(td.Schema, "", DatabaseNamePlaceholder) + if err != nil { + // parsing unsuccessful + return nil, vterrors.Wrapf(err, "parsing schema: %v", td.Schema) } - sqlStrings = append(sqlStrings, strings.Join(lines, "\n")) + sqlStrings = append(sqlStrings, replaced) } } - return append(sqlStrings, createViewSQL...) + return append(sqlStrings, createViewSQL...), nil } // DiffSchema generates a report on what's different between two SchemaDefinitions diff --git a/go/vt/mysqlctl/tmutils/schema_test.go b/go/vt/mysqlctl/tmutils/schema_test.go index 0f3d9572107..d288ca8b744 100644 --- a/go/vt/mysqlctl/tmutils/schema_test.go +++ b/go/vt/mysqlctl/tmutils/schema_test.go @@ -30,32 +30,32 @@ import ( var basicTable1 = &tabletmanagerdatapb.TableDefinition{ Name: "table1", - Schema: "table schema 1", + Schema: "create table table1 (id int primary key)", Type: TableBaseTable, } var basicTable2 = &tabletmanagerdatapb.TableDefinition{ Name: "table2", - Schema: "table schema 2", + Schema: "create table table2 (id int primary key)", Type: TableBaseTable, } var table3 = &tabletmanagerdatapb.TableDefinition{ - Name: "table2", + Name: "table3", Schema: "CREATE TABLE `table3` (\n" + - "id bigint not null,\n" + + "id bigint not null\n" + ") Engine=InnoDB", Type: TableBaseTable, } var view1 = &tabletmanagerdatapb.TableDefinition{ Name: "view1", - Schema: "view schema 1", + Schema: "create view view1 as select id from t1", Type: TableView, } var view2 = &tabletmanagerdatapb.TableDefinition{ Name: "view2", - Schema: "view schema 2", + Schema: "create view view2 as select id from t2", Type: TableView, } @@ -73,7 +73,11 @@ func TestToSQLStrings(t *testing.T) { view1, }, }, - want: []string{"CREATE DATABASE `{{.DatabaseName}}`", basicTable1.Schema, view1.Schema}, + want: []string{ + "CREATE DATABASE `{{.DatabaseName}}`", + "create table `{{.DatabaseName}}`.table1 (\n\tid int primary key\n)", + "create view view1 as select id from t1", + }, }, { // SchemaDefinition doesn't need any tables or views @@ -96,7 +100,11 @@ func TestToSQLStrings(t *testing.T) { basicTable2, }, }, - want: []string{"CREATE DATABASE `{{.DatabaseName}}`", basicTable1.Schema, basicTable2.Schema}, + want: []string{ + "CREATE DATABASE `{{.DatabaseName}}`", + "create table `{{.DatabaseName}}`.table1 (\n\tid int primary key\n)", + "create table `{{.DatabaseName}}`.table2 (\n\tid int primary key\n)", + }, }, { // multiple tables and views should be ordered with all tables before views @@ -110,9 +118,10 @@ func TestToSQLStrings(t *testing.T) { }, }, want: []string{ - "CREATE DATABASE `{{.DatabaseName}}`", - basicTable1.Schema, basicTable2.Schema, - view1.Schema, view2.Schema, + "CREATE DATABASE `{{.DatabaseName}}`", "create table `{{.DatabaseName}}`.table1 (\n\tid int primary key\n)", + "create table `{{.DatabaseName}}`.table2 (\n\tid int primary key\n)", + "create view view1 as select id from t1", + "create view view2 as select id from t2", }, }, { @@ -126,16 +135,15 @@ func TestToSQLStrings(t *testing.T) { }, want: []string{ "CREATE DATABASE `{{.DatabaseName}}`", - basicTable1.Schema, - "CREATE TABLE `{{.DatabaseName}}`.`table3` (\n" + - "id bigint not null,\n" + - ") Engine=InnoDB", + "create table `{{.DatabaseName}}`.table1 (\n\tid int primary key\n)", + "create table `{{.DatabaseName}}`.table3 (\n\tid bigint not null\n) Engine InnoDB", }, }, } for _, tc := range testcases { - got := SchemaDefinitionToSQLStrings(tc.input) + got, err := SchemaDefinitionToSQLStrings(tc.input) + assert.NoError(t, err) assert.Equal(t, tc.want, got) } } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 20088bee795..3ca10c04ddd 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -86,7 +86,8 @@ type IndexOption struct { String string } -// TableOption is used for create table options like AUTO_INCREMENT, INSERT_METHOD, etc +// TableOption is used for create table options like AUTO_INCREMENT, INSERT_METHOD, +// TABLESPACE, etc type TableOption struct { Name string Value *Literal diff --git a/go/vt/sqlparser/utils.go b/go/vt/sqlparser/utils.go index 0f3c66f2ea3..b237831c8a4 100644 --- a/go/vt/sqlparser/utils.go +++ b/go/vt/sqlparser/utils.go @@ -135,15 +135,31 @@ func ReplaceTableQualifiers(query, olddb, newdb string) (string, error) { upd := Rewrite(in, func(cursor *Cursor) bool { switch node := cursor.Node().(type) { case TableName: - if !node.Qualifier.IsEmpty() && - node.Qualifier.String() == oldQualifier.String() { + if node.Qualifier.String() == oldQualifier.String() { node.Qualifier = newQualifier cursor.Replace(node) modified = true } + case *CreateDatabase: + if node.DBName.String() == oldQualifier.String() { + node.DBName = newQualifier + cursor.Replace(node) + modified = true + } + case *AlterDatabase: + if node.DBName.String() == oldQualifier.String() { + node.DBName = newQualifier + cursor.Replace(node) + modified = true + } + case *DropDatabase: + if node.DBName.String() == oldQualifier.String() { + node.DBName = newQualifier + cursor.Replace(node) + modified = true + } case *ShowBasic: // for things like 'show tables from _vt' - if !node.DbName.IsEmpty() && - node.DbName.String() == oldQualifier.String() { + if node.DbName.String() == oldQualifier.String() { node.DbName = newQualifier cursor.Replace(node) modified = true diff --git a/go/vt/sqlparser/utils_test.go b/go/vt/sqlparser/utils_test.go index 63c9b10ba43..b4f41241e94 100644 --- a/go/vt/sqlparser/utils_test.go +++ b/go/vt/sqlparser/utils_test.go @@ -187,85 +187,153 @@ func TestReplaceTableQualifiers(t *testing.T) { tests := []struct { name string in string + origdb string newdb string out string wantErr bool }{ { name: "invalid select", + origdb: origDB, in: "select frog bar person", out: "", wantErr: true, }, { - name: "simple select", - in: "select * from _vt.foo", - out: "select * from foo", + name: "simple select", + origdb: origDB, + in: "select * from _vt.foo", + out: "select * from foo", }, { - name: "simple select with new db", - in: "select * from _vt.foo", - newdb: "_vt_test", - out: "select * from _vt_test.foo", + name: "simple select with new db", + in: "select * from _vt.foo", + origdb: origDB, + newdb: "_vt_test", + out: "select * from _vt_test.foo", }, { - name: "simple select with new db same", - in: "select * from _vt.foo where id=1", // should be unchanged - newdb: "_vt", - out: "select * from _vt.foo where id=1", + name: "simple select with new db same", + in: "select * from _vt.foo where id=1", // should be unchanged + origdb: origDB, + newdb: "_vt", + out: "select * from _vt.foo where id=1", }, { - name: "simple select with new db needing escaping", - in: "select * from _vt.foo", - newdb: "1_vt-test", - out: "select * from `1_vt-test`.foo", + name: "simple select with new db needing escaping", + in: "select * from _vt.foo", + origdb: origDB, + newdb: "1_vt-test", + out: "select * from `1_vt-test`.foo", }, { - name: "complex select", - in: "select table_name, lastpk from _vt.copy_state where vrepl_id = 1 and id in (select max(id) from _vt.copy_state where vrepl_id = 1 group by vrepl_id, table_name)", - out: "select table_name, lastpk from copy_state where vrepl_id = 1 and id in (select max(id) from copy_state where vrepl_id = 1 group by vrepl_id, table_name)", + name: "complex select", + origdb: origDB, + in: "select table_name, lastpk from _vt.copy_state where vrepl_id = 1 and id in (select max(id) from _vt.copy_state where vrepl_id = 1 group by vrepl_id, table_name)", + out: "select table_name, lastpk from copy_state where vrepl_id = 1 and id in (select max(id) from copy_state where vrepl_id = 1 group by vrepl_id, table_name)", }, { - name: "complex mixed exists select", - in: "select workflow_name, db_name from _vt.vreplication where id = 1 and exists (select v1 from mydb.foo where fname = 'matt') and not exists (select v2 from _vt.newsidecartable where _vt.newsidecartable.id = _vt.vreplication.workflow_name)", - newdb: "_vt_import", - out: "select workflow_name, db_name from _vt_import.vreplication where id = 1 and exists (select v1 from mydb.foo where fname = 'matt') and not exists (select v2 from _vt_import.newsidecartable where _vt_import.newsidecartable.id = _vt_import.vreplication.workflow_name)", + name: "complex mixed exists select", + in: "select workflow_name, db_name from _vt.vreplication where id = 1 and exists (select v1 from mydb.foo where fname = 'matt') and not exists (select v2 from _vt.newsidecartable where _vt.newsidecartable.id = _vt.vreplication.workflow_name)", + origdb: origDB, + newdb: "_vt_import", + out: "select workflow_name, db_name from _vt_import.vreplication where id = 1 and exists (select v1 from mydb.foo where fname = 'matt') and not exists (select v2 from _vt_import.newsidecartable where _vt_import.newsidecartable.id = _vt_import.vreplication.workflow_name)", }, { - name: "derived table select", - in: "select myder.id from (select max(id) as id from _vt.copy_state where vrepl_id = 1 group by vrepl_id, table_name) as myder where id = 1", - newdb: "__vt-metadata", - out: "select myder.id from (select max(id) as id from `__vt-metadata`.copy_state where vrepl_id = 1 group by vrepl_id, table_name) as myder where id = 1", + name: "derived table select", + in: "select myder.id from (select max(id) as id from _vt.copy_state where vrepl_id = 1 group by vrepl_id, table_name) as myder where id = 1", + origdb: origDB, + newdb: "__vt-metadata", + out: "select myder.id from (select max(id) as id from `__vt-metadata`.copy_state where vrepl_id = 1 group by vrepl_id, table_name) as myder where id = 1", }, { - name: "complex select", - in: "select t1.col1, t2.col2 from _vt.t1 as t1 join _vt.t2 as t2 on t1.id = t2.id", - out: "select t1.col1, t2.col2 from t1 as t1 join t2 as t2 on t1.id = t2.id", + name: "complex select", + origdb: origDB, + in: "select t1.col1, t2.col2 from _vt.t1 as t1 join _vt.t2 as t2 on t1.id = t2.id", + out: "select t1.col1, t2.col2 from t1 as t1 join t2 as t2 on t1.id = t2.id", }, { - name: "simple insert", - in: "insert into _vt.foo(id) values (1)", - out: "insert into foo(id) values (1)", + name: "simple insert", + origdb: origDB, + in: "insert into _vt.foo(id) values (1)", + out: "insert into foo(id) values (1)", }, { - name: "simple update", - in: "update _vt.foo set id = 1", - out: "update foo set id = 1", + name: "simple update", + origdb: origDB, + in: "update _vt.foo set id = 1", + out: "update foo set id = 1", }, { - name: "simple delete", - in: "delete from _vt.foo where id = 1", - out: "delete from foo where id = 1", + name: "simple delete", + origdb: origDB, + in: "delete from _vt.foo where id = 1", + out: "delete from foo where id = 1", }, { - name: "simple set", - in: "set names 'binary'", - out: "set names 'binary'", + name: "simple set", + origdb: origDB, + in: "set names 'binary'", + out: "set names 'binary'", + }, + { + name: "simple create table", + origdb: origDB, + in: "CREATE TABLE t (id int primary key)", + out: "CREATE TABLE t (id int primary key)", + }, + { + name: "qualified create table", + origdb: origDB, + in: "CREATE TABLE `t` (id int primary key)", + out: "CREATE TABLE `t` (id int primary key)", + }, + { + name: "qualified create table, _vt", + origdb: origDB, + newdb: "mydb", + in: "create table `_vt`.`t` (id int primary key)", + out: "create table mydb.t (\n\tid int primary key\n)", + }, + { + name: "empty qualifier create table", + origdb: "", + newdb: "mydb", + in: "CREATE TABLE `t` (id int primary key)", + out: "create table mydb.t (\n\tid int primary key\n)", + }, + { + name: "create database", + origdb: "{{.DatabaseName}}", + newdb: "mydb", + in: "CREATE DATABASE `{{.DatabaseName}}`", + out: "create database mydb", + }, + { + name: "create database, comments", + origdb: "{{.DatabaseName}}", + newdb: "mydb", + in: "CREATE DATABASE /*!32312 IF NOT EXISTS*/ `{{.DatabaseName}}` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci */ /*!80016 DEFAULT ENCRYPTION='N' */", + out: "create database if not exists mydb default character set utf8mb4 collate utf8mb4_0900_ai_ci default encryption 'N'", + }, + { + name: "create database, unqualified", + origdb: origDB, + newdb: "mydb", + in: "CREATE DATABASE _vt", + out: "create database mydb", + }, + { + name: "drop database, unqualified", + origdb: origDB, + newdb: "mydb", + in: "drop database _vt", + out: "drop database mydb", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := ReplaceTableQualifiers(tt.in, origDB, tt.newdb) + got, err := ReplaceTableQualifiers(tt.in, tt.origdb, tt.newdb) if tt.wantErr { require.Error(t, err) } else { diff --git a/go/vt/vtctl/workflow/server.go b/go/vt/vtctl/workflow/server.go index f42a2dda59c..793e278291b 100644 --- a/go/vt/vtctl/workflow/server.go +++ b/go/vt/vtctl/workflow/server.go @@ -2788,7 +2788,10 @@ func (s *Server) CopySchemaShard(ctx context.Context, sourceTabletAlias *topodat return fmt.Errorf("GetSchema(%v, %v, %v, %v) failed: %v", sourceTabletAlias, tables, excludeTables, includeViews, err) } - createSQLstmts := tmutils.SchemaDefinitionToSQLStrings(sourceSd) + createSQLstmts, err := tmutils.SchemaDefinitionToSQLStrings(sourceSd) + if err != nil { + return fmt.Errorf("SchemaDefinitionToSQLStrings(%v) failed: %v", sourceSd, err) + } destTabletInfo, err := s.ts.GetTablet(ctx, destShardInfo.PrimaryAlias) if err != nil { diff --git a/go/vt/wrangler/schema.go b/go/vt/wrangler/schema.go index 84bc078f240..adc45a73a8e 100644 --- a/go/vt/wrangler/schema.go +++ b/go/vt/wrangler/schema.go @@ -17,11 +17,9 @@ limitations under the License. package wrangler import ( - "bytes" "context" "fmt" "sync" - "text/template" "time" "vitess.io/vitess/go/vt/concurrency" @@ -29,9 +27,11 @@ import ( "vitess.io/vitess/go/vt/logutil" "vitess.io/vitess/go/vt/mysqlctl/tmutils" "vitess.io/vitess/go/vt/schema" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/topo/topoproto" "vitess.io/vitess/go/vt/vtctl/schematools" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tabletmanager/vreplication" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" @@ -225,7 +225,10 @@ func (wr *Wrangler) CopySchemaShard(ctx context.Context, sourceTabletAlias *topo return fmt.Errorf("GetSchema(%v, %v, %v, %v) failed: %v", sourceTabletAlias, tables, excludeTables, includeViews, err) } - createSQLstmts := tmutils.SchemaDefinitionToSQLStrings(sourceSd) + createSQLstmts, err := tmutils.SchemaDefinitionToSQLStrings(sourceSd) + if err != nil { + return err + } destTabletInfo, err := wr.ts.GetTablet(ctx, destShardInfo.PrimaryAlias) if err != nil { @@ -234,10 +237,10 @@ func (wr *Wrangler) CopySchemaShard(ctx context.Context, sourceTabletAlias *topo for _, createSQL := range createSQLstmts { err = wr.applySQLShard(ctx, destTabletInfo, createSQL) if err != nil { - return fmt.Errorf("creating a table failed."+ + return vterrors.Wrapf(err, "creating a table failed."+ " Most likely some tables already exist on the destination and differ from the source."+ " Please remove all to be copied tables from the destination manually and run this command again."+ - " Full error: %v", err) + " CREATE statement: %v", createSQL) } } @@ -288,7 +291,7 @@ func (wr *Wrangler) CopySchemaShard(ctx context.Context, sourceTabletAlias *topo // it shouldn't be used for anything that will require a pivot. // The SQL statement string is expected to have {{.DatabaseName}} in place of the actual db name. func (wr *Wrangler) applySQLShard(ctx context.Context, tabletInfo *topo.TabletInfo, change string) error { - filledChange, err := fillStringTemplate(change, map[string]string{"DatabaseName": tabletInfo.DbName()}) + filledChange, err := sqlparser.ReplaceTableQualifiers(change, tmutils.DatabaseNamePlaceholder, tabletInfo.DbName()) if err != nil { return fmt.Errorf("fillStringTemplate failed: %v", err) } @@ -303,13 +306,3 @@ func (wr *Wrangler) applySQLShard(ctx context.Context, tabletInfo *topo.TabletIn }) return err } - -// fillStringTemplate returns the string template filled -func fillStringTemplate(tmpl string, vars any) (string, error) { - myTemplate := template.Must(template.New("").Parse(tmpl)) - data := new(bytes.Buffer) - if err := myTemplate.Execute(data, vars); err != nil { - return "", err - } - return data.String(), nil -} diff --git a/go/vt/wrangler/testlib/copy_schema_shard_test.go b/go/vt/wrangler/testlib/copy_schema_shard_test.go index 866ec2fe931..dfb1745d8a6 100644 --- a/go/vt/wrangler/testlib/copy_schema_shard_test.go +++ b/go/vt/wrangler/testlib/copy_schema_shard_test.go @@ -96,12 +96,12 @@ func copySchema(t *testing.T, useShardAsSource bool) { TableDefinitions: []*tabletmanagerdatapb.TableDefinition{ { Name: "table1", - Schema: "CREATE TABLE `table1` (\n `id` bigint(20) NOT NULL AUTO_INCREMENT,\n `msg` varchar(64) DEFAULT NULL,\n `keyspace_id` bigint(20) unsigned NOT NULL,\n PRIMARY KEY (`id`),\n KEY `by_msg` (`msg`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8", + Schema: "CREATE TABLE `vt_ks`.`table1` (\n `id` bigint(20) NOT NULL AUTO_INCREMENT,\n `msg` varchar(64) DEFAULT NULL,\n `keyspace_id` bigint(20) unsigned NOT NULL,\n PRIMARY KEY (`id`),\n KEY `by_msg` (`msg`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8", Type: tmutils.TableBaseTable, }, { Name: "view1", - Schema: "CREATE TABLE `view1` (\n `id` bigint(20) NOT NULL AUTO_INCREMENT,\n `msg` varchar(64) DEFAULT NULL,\n `keyspace_id` bigint(20) unsigned NOT NULL,\n PRIMARY KEY (`id`),\n KEY `by_msg` (`msg`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8", + Schema: "CREATE TABLE `vt_ks`.`view1` (\n `id` bigint(20) NOT NULL AUTO_INCREMENT,\n `msg` varchar(64) DEFAULT NULL,\n `keyspace_id` bigint(20) unsigned NOT NULL,\n PRIMARY KEY (`id`),\n KEY `by_msg` (`msg`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8", Type: tmutils.TableView, }, }, @@ -115,7 +115,7 @@ func copySchema(t *testing.T, useShardAsSource bool) { setSQLMode := fmt.Sprintf("SET @@session.sql_mode='%v'", vreplication.SQLMode) changeToDb := "USE `vt_ks`" - createDb := "CREATE DATABASE `vt_ks` /*!40100 DEFAULT CHARACTER SET utf8 */" + createDb := "create database vt_ks default character set utf8" createTable := "CREATE TABLE `vt_ks`.`table1` (\n" + " `id` bigint(20) NOT NULL AUTO_INCREMENT,\n" + " `msg` varchar(64) DEFAULT NULL,\n" + @@ -123,7 +123,7 @@ func copySchema(t *testing.T, useShardAsSource bool) { " PRIMARY KEY (`id`),\n" + " KEY `by_msg` (`msg`)\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8" - createTableView := "CREATE TABLE `view1` (\n" + + createTableView := "CREATE TABLE `vt_ks`.`view1` (\n" + " `id` bigint(20) NOT NULL AUTO_INCREMENT,\n" + " `msg` varchar(64) DEFAULT NULL,\n" + " `keyspace_id` bigint(20) unsigned NOT NULL,\n" +