diff --git a/config/tablet/default.yaml b/config/tablet/default.yaml index 44b07ea8779..68de610693e 100644 --- a/config/tablet/default.yaml +++ b/config/tablet/default.yaml @@ -24,32 +24,32 @@ db: user: vt_app # db_app_user password: # db_app_password useSsl: true # db_app_use_ssl - preferSocket: true + preferTcp: false dba: user: vt_dba # db_dba_user password: # db_dba_password useSsl: true # db_dba_use_ssl - preferSocket: true + preferTcp: false filtered: user: vt_filtered # db_filtered_user password: # db_filtered_password useSsl: true # db_filtered_use_ssl - preferSocket: true + preferTcp: false repl: user: vt_repl # db_repl_user password: # db_repl_password useSsl: true # db_repl_use_ssl - preferSocket: true + preferTcp: false appdebug: user: vt_appdebug # db_appdebug_user password: # db_appdebug_password useSsl: true # db_appdebug_use_ssl - preferSocket: true + preferTcp: false allprivs: user: vt_allprivs # db_allprivs_user password: # db_allprivs_password useSsl: true # db_allprivs_use_ssl - preferSocket: true + preferTcp: false oltpReadPool: size: 16 # queryserver-config-pool-size diff --git a/go/cmd/vtcombo/main.go b/go/cmd/vtcombo/main.go index 0a25159c68d..232865bf7bf 100644 --- a/go/cmd/vtcombo/main.go +++ b/go/cmd/vtcombo/main.go @@ -97,10 +97,7 @@ func main() { servenv.Init() tabletenv.Init() - dbcfgs, err := dbconfigs.Init("") - if err != nil { - log.Warning(err) - } + dbcfgs := dbconfigs.GlobalDBConfigs.Init("") mysqld := mysqlctl.NewMysqld(dbcfgs) servenv.OnClose(mysqld.Close) diff --git a/go/cmd/vttablet/vttablet.go b/go/cmd/vttablet/vttablet.go index f9b26441c05..958d00aadd3 100644 --- a/go/cmd/vttablet/vttablet.go +++ b/go/cmd/vttablet/vttablet.go @@ -22,7 +22,6 @@ import ( "io/ioutil" "golang.org/x/net/context" - "sigs.k8s.io/yaml" "vitess.io/vitess/go/vt/dbconfigs" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/mysqlctl" @@ -34,6 +33,7 @@ import ( "vitess.io/vitess/go/vt/vttablet/tabletmanager" "vitess.io/vitess/go/vt/vttablet/tabletserver" "vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv" + "vitess.io/vitess/go/yaml2" ) var ( @@ -67,12 +67,12 @@ func main() { if err != nil { log.Exitf("error reading config file %s: %v", *tabletConfig, err) } - if err := yaml.Unmarshal(bytes, config); err != nil { + if err := yaml2.Unmarshal(bytes, config); err != nil { log.Exitf("error parsing config file %s: %v", bytes, err) } - gotBytes, _ := yaml.Marshal(config) - log.Infof("Loaded config file %s successfully:\n%s", *tabletConfig, gotBytes) } + gotBytes, _ := yaml2.Marshal(config) + log.Infof("Loaded config file %s successfully:\n%s", *tabletConfig, gotBytes) servenv.Init() @@ -90,7 +90,7 @@ func main() { // and use the socket from it. If connection parameters were specified, // we assume that the mysql is not local, and we skip loading mycnf. // This also means that backup and restore will not be allowed. - if !dbconfigs.HasConnectionParams() { + if !config.DB.HasGlobalSettings() { var err error if mycnf, err = mysqlctl.NewMycnfFromFlags(tabletAlias.Uid); err != nil { log.Exitf("mycnf read failed: %v", err) @@ -103,10 +103,7 @@ func main() { // If connection parameters were specified, socketFile will be empty. // Otherwise, the socketFile (read from mycnf) will be used to initialize // dbconfigs. - dbcfgs, err := dbconfigs.Init(socketFile) - if err != nil { - log.Warning(err) - } + config.DB = config.DB.Init(socketFile) if *tableACLConfig != "" { // To override default simpleacl, other ACL plugins must set themselves to be default ACL factory @@ -133,7 +130,7 @@ func main() { // Create mysqld and register the health reporter (needs to be done // before initializing the agent, so the initial health check // done by the agent has the right reporter) - mysqld := mysqlctl.NewMysqld(dbcfgs) + mysqld := mysqlctl.NewMysqld(config.DB) servenv.OnClose(mysqld.Close) // Depends on both query and updateStream. @@ -141,7 +138,7 @@ func main() { if servenv.GRPCPort != nil { gRPCPort = int32(*servenv.GRPCPort) } - agent, err = tabletmanager.NewActionAgent(context.Background(), ts, mysqld, qsc, tabletAlias, dbcfgs, mycnf, int32(*servenv.Port), gRPCPort) + agent, err = tabletmanager.NewActionAgent(context.Background(), ts, mysqld, qsc, tabletAlias, config.DB, mycnf, int32(*servenv.Port), gRPCPort) if err != nil { log.Exitf("NewActionAgent() failed: %v", err) } diff --git a/go/vt/dbconfigs/dbconfigs.go b/go/vt/dbconfigs/dbconfigs.go index 31ccc90aa7d..9941ec442a8 100644 --- a/go/vt/dbconfigs/dbconfigs.go +++ b/go/vt/dbconfigs/dbconfigs.go @@ -24,16 +24,32 @@ import ( "context" "encoding/json" "flag" - "fmt" "vitess.io/vitess/go/mysql" - "vitess.io/vitess/go/sync2" "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/yaml2" +) + +// config flags +const ( + App = "app" + AppDebug = "appdebug" + // AllPrivs user should have more privileges than App (should include possibility to do + // schema changes and write to internal Vitess tables), but it shouldn't have SUPER + // privilege like Dba has. + AllPrivs = "allprivs" + Dba = "dba" + Filtered = "filtered" + Repl = "repl" + ExternalRepl = "erepl" ) var ( - dbConfigs = DBConfigs{userConfigs: make(map[string]*userConfig)} - baseConfig = mysql.ConnParams{} + // GlobalDBConfigs contains the initial values of dbconfgis from flags. + GlobalDBConfigs DBConfigs + + // All can be used to register all flags: RegisterFlags(All...) + All = []string{App, AppDebug, AllPrivs, Dba, Filtered, Repl, ExternalRepl} ) // DBConfigs stores all the data needed to build various connection @@ -42,105 +58,115 @@ var ( // It contains other connection parameters like socket, charset, etc. // It also stores the default db name, which it can combine with the // rest of the data to build db-sepcific connection parameters. -// It also supplies the SidecarDBName. This is currently hardcoded -// to "_vt", but will soon become customizable. -// The life-cycle of this package is as follows: +// +// The legacy way of initializing is as follows: // App must call RegisterFlags to request the types of connections -// it wants support for. This must be done before involing flags.Parse. +// it wants support for. This must be done before invoking flags.Parse. // After flag parsing, app invokes the Init function, which will return // a DBConfigs object. // The app must store the DBConfigs object internally, and use it to // build connection parameters as needed. -// The DBName is initially empty and may later be set or changed by the app. type DBConfigs struct { - userConfigs map[string]*userConfig - DBName sync2.AtomicString - SidecarDBName sync2.AtomicString + Socket string `json:"socket,omitempty"` + Host string `json:"host,omitempty"` + Port int `json:"port,omitempty"` + Charset string `json:"charset,omitempty"` + Flags uint64 `json:"flags,omitempty"` + Flavor string `json:"flavor,omitempty"` + SslCa string `json:"sslCa,omitempty"` + SslCaPath string `json:"sslCaPath,omitempty"` + SslCert string `json:"sslCert,omitempty"` + SslKey string `json:"sslKey,omitempty"` + ServerName string `json:"serverName,omitempty"` + ConnectTimeoutMilliseconds int `json:"connectTimeoutMilliseconds,omitempty"` + + App UserConfig `json:"app,omitempty"` + Dba UserConfig `json:"dba,omitempty"` + Filtered UserConfig `json:"filtered,omitempty"` + Repl UserConfig `json:"repl,omitempty"` + Appdebug UserConfig `json:"appdebug,omitempty"` + Allprivs UserConfig `json:"allprivs,omitempty"` + externalRepl UserConfig + + appParams mysql.ConnParams + dbaParams mysql.ConnParams + filteredParams mysql.ConnParams + replParams mysql.ConnParams + appdebugParams mysql.ConnParams + allprivsParams mysql.ConnParams + externalReplParams mysql.ConnParams + + dbname string } -type userConfig struct { - useSSL bool - param mysql.ConnParams +// UserConfig contains user-specific configs. +type UserConfig struct { + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` + UseSSL bool `json:"useSsl,omitempty"` + UseTCP bool `json:"useTcp,omitempty"` } -// config flags -const ( - App = "app" - AppDebug = "appdebug" - // AllPrivs user should have more privileges than App (should include possibility to do - // schema changes and write to internal Vitess tables), but it shouldn't have SUPER - // privilege like Dba has. - AllPrivs = "allprivs" - Dba = "dba" - Filtered = "filtered" - Repl = "repl" - ExternalRepl = "erepl" -) - -// All can be used to register all flags: RegisterFlags(All...) -var All = []string{App, AppDebug, AllPrivs, Dba, Filtered, Repl, ExternalRepl} - // RegisterFlags registers the flags for the given DBConfigFlag. // For instance, vttablet will register client, dba and repl. // Returns all registered flags. func RegisterFlags(userKeys ...string) { registerBaseFlags() for _, userKey := range userKeys { - uc := &userConfig{} - dbConfigs.userConfigs[userKey] = uc - registerPerUserFlags(uc, userKey) + uc, cp := GlobalDBConfigs.getParams(userKey, &GlobalDBConfigs) + registerPerUserFlags(userKey, uc, cp) } } func registerBaseFlags() { - flag.StringVar(&baseConfig.UnixSocket, "db_socket", "", "The unix socket to connect on. If this is specified, host and port will not be used.") - flag.StringVar(&baseConfig.Host, "db_host", "", "The host name for the tcp connection.") - flag.IntVar(&baseConfig.Port, "db_port", 0, "tcp port") - flag.StringVar(&baseConfig.Charset, "db_charset", "", "Character set. Only utf8 or latin1 based character sets are supported.") - flag.Uint64Var(&baseConfig.Flags, "db_flags", 0, "Flag values as defined by MySQL.") - flag.StringVar(&baseConfig.Flavor, "db_flavor", "", "Flavor overrid. Valid value is FilePos.") - flag.StringVar(&baseConfig.SslCa, "db_ssl_ca", "", "connection ssl ca") - flag.StringVar(&baseConfig.SslCaPath, "db_ssl_ca_path", "", "connection ssl ca path") - flag.StringVar(&baseConfig.SslCert, "db_ssl_cert", "", "connection ssl certificate") - flag.StringVar(&baseConfig.SslKey, "db_ssl_key", "", "connection ssl key") - flag.StringVar(&baseConfig.ServerName, "db_server_name", "", "server name of the DB we are connecting to.") - flag.Uint64Var(&baseConfig.ConnectTimeoutMs, "db_connect_timeout_ms", 0, "connection timeout to mysqld in milliseconds (0 for no timeout)") + flag.StringVar(&GlobalDBConfigs.Socket, "db_socket", "", "The unix socket to connect on. If this is specified, host and port will not be used.") + flag.StringVar(&GlobalDBConfigs.Host, "db_host", "", "The host name for the tcp connection.") + flag.IntVar(&GlobalDBConfigs.Port, "db_port", 0, "tcp port") + flag.StringVar(&GlobalDBConfigs.Charset, "db_charset", "", "Character set. Only utf8 or latin1 based character sets are supported.") + flag.Uint64Var(&GlobalDBConfigs.Flags, "db_flags", 0, "Flag values as defined by MySQL.") + flag.StringVar(&GlobalDBConfigs.Flavor, "db_flavor", "", "Flavor overrid. Valid value is FilePos.") + flag.StringVar(&GlobalDBConfigs.SslCa, "db_ssl_ca", "", "connection ssl ca") + flag.StringVar(&GlobalDBConfigs.SslCaPath, "db_ssl_ca_path", "", "connection ssl ca path") + flag.StringVar(&GlobalDBConfigs.SslCert, "db_ssl_cert", "", "connection ssl certificate") + flag.StringVar(&GlobalDBConfigs.SslKey, "db_ssl_key", "", "connection ssl key") + flag.StringVar(&GlobalDBConfigs.ServerName, "db_server_name", "", "server name of the DB we are connecting to.") + flag.IntVar(&GlobalDBConfigs.ConnectTimeoutMilliseconds, "db_connect_timeout_ms", 0, "connection timeout to mysqld in milliseconds (0 for no timeout)") } // The flags will change the global singleton // TODO(sougou): deprecate the legacy flags. -func registerPerUserFlags(dbc *userConfig, userKey string) { +func registerPerUserFlags(userKey string, uc *UserConfig, cp *mysql.ConnParams) { newUserFlag := "db_" + userKey + "_user" - flag.StringVar(&dbc.param.Uname, "db-config-"+userKey+"-uname", "vt_"+userKey, "deprecated: use "+newUserFlag) - flag.StringVar(&dbc.param.Uname, newUserFlag, "vt_"+userKey, "db "+userKey+" user userKey") + flag.StringVar(&uc.User, "db-config-"+userKey+"-uname", "vt_"+userKey, "deprecated: use "+newUserFlag) + flag.StringVar(&uc.User, newUserFlag, "vt_"+userKey, "db "+userKey+" user userKey") newPasswordFlag := "db_" + userKey + "_password" - flag.StringVar(&dbc.param.Pass, "db-config-"+userKey+"-pass", "", "db "+userKey+" deprecated: use "+newPasswordFlag) - flag.StringVar(&dbc.param.Pass, newPasswordFlag, "", "db "+userKey+" password") - - flag.BoolVar(&dbc.useSSL, "db_"+userKey+"_use_ssl", true, "Set this flag to false to make the "+userKey+" connection to not use ssl") - - flag.StringVar(&dbc.param.Host, "db-config-"+userKey+"-host", "", "deprecated: use db_host") - flag.IntVar(&dbc.param.Port, "db-config-"+userKey+"-port", 0, "deprecated: use db_port") - flag.StringVar(&dbc.param.UnixSocket, "db-config-"+userKey+"-unixsocket", "", "deprecated: use db_socket") - flag.StringVar(&dbc.param.Charset, "db-config-"+userKey+"-charset", "utf8", "deprecated: use db_charset") - flag.Uint64Var(&dbc.param.Flags, "db-config-"+userKey+"-flags", 0, "deprecated: use db_flags") - flag.StringVar(&dbc.param.SslCa, "db-config-"+userKey+"-ssl-ca", "", "deprecated: use db_ssl_ca") - flag.StringVar(&dbc.param.SslCaPath, "db-config-"+userKey+"-ssl-ca-path", "", "deprecated: use db_ssl_ca_path") - flag.StringVar(&dbc.param.SslCert, "db-config-"+userKey+"-ssl-cert", "", "deprecated: use db_ssl_cert") - flag.StringVar(&dbc.param.SslKey, "db-config-"+userKey+"-ssl-key", "", "deprecated: use db_ssl_key") - flag.StringVar(&dbc.param.ServerName, "db-config-"+userKey+"-server_name", "", "deprecated: use db_server_name") - flag.StringVar(&dbc.param.Flavor, "db-config-"+userKey+"-flavor", "", "deprecated: use db_flavor") - - flag.StringVar(&dbc.param.DeprecatedDBName, "db-config-"+userKey+"-dbname", "", "deprecated: dbname does not need to be explicitly configured") + flag.StringVar(&uc.Password, "db-config-"+userKey+"-pass", "", "db "+userKey+" deprecated: use "+newPasswordFlag) + flag.StringVar(&uc.Password, newPasswordFlag, "", "db "+userKey+" password") + + flag.BoolVar(&uc.UseSSL, "db_"+userKey+"_use_ssl", true, "Set this flag to false to make the "+userKey+" connection to not use ssl") + + flag.StringVar(&cp.Host, "db-config-"+userKey+"-host", "", "deprecated: use db_host") + flag.IntVar(&cp.Port, "db-config-"+userKey+"-port", 0, "deprecated: use db_port") + flag.StringVar(&cp.UnixSocket, "db-config-"+userKey+"-unixsocket", "", "deprecated: use db_socket") + flag.StringVar(&cp.Charset, "db-config-"+userKey+"-charset", "utf8", "deprecated: use db_charset") + flag.Uint64Var(&cp.Flags, "db-config-"+userKey+"-flags", 0, "deprecated: use db_flags") + flag.StringVar(&cp.SslCa, "db-config-"+userKey+"-ssl-ca", "", "deprecated: use db_ssl_ca") + flag.StringVar(&cp.SslCaPath, "db-config-"+userKey+"-ssl-ca-path", "", "deprecated: use db_ssl_ca_path") + flag.StringVar(&cp.SslCert, "db-config-"+userKey+"-ssl-cert", "", "deprecated: use db_ssl_cert") + flag.StringVar(&cp.SslKey, "db-config-"+userKey+"-ssl-key", "", "deprecated: use db_ssl_key") + flag.StringVar(&cp.ServerName, "db-config-"+userKey+"-server_name", "", "deprecated: use db_server_name") + flag.StringVar(&cp.Flavor, "db-config-"+userKey+"-flavor", "", "deprecated: use db_flavor") + + if userKey == ExternalRepl { + flag.StringVar(&cp.DeprecatedDBName, "db-config-"+userKey+"-dbname", "", "deprecated: dbname does not need to be explicitly configured") + } } // Connector contains Connection Parameters for mysql connection type Connector struct { connParams *mysql.ConnParams - dbName string - host string } // New initializes a ConnParams from mysql connection parameters @@ -174,59 +200,69 @@ func (c Connector) MysqlParams() (*mysql.ConnParams, error) { // DBName gets the dbname from mysql.ConnParams func (c Connector) DBName() string { - params, _ := c.MysqlParams() - return params.DbName + return c.connParams.DbName } // Host gets the host from mysql.ConnParams func (c Connector) Host() string { - params, _ := c.MysqlParams() - return params.Host + return c.connParams.Host +} + +// WithDBName returns a new DBConfigs with the dbname set. +func (dbcfgs *DBConfigs) WithDBName(dbname string) *DBConfigs { + dbcfgs = dbcfgs.Clone() + dbcfgs.dbname = dbname + return dbcfgs +} + +// DBName returns the db name. +func (dbcfgs *DBConfigs) DBName() string { + return dbcfgs.dbname } // AppWithDB returns connection parameters for app with dbname set. func (dbcfgs *DBConfigs) AppWithDB() Connector { - return dbcfgs.makeParams(App, true) + return dbcfgs.makeParams(&dbcfgs.appParams, true) } // AppDebugWithDB returns connection parameters for appdebug with dbname set. func (dbcfgs *DBConfigs) AppDebugWithDB() Connector { - return dbcfgs.makeParams(AppDebug, true) + return dbcfgs.makeParams(&dbcfgs.appdebugParams, true) } // AllPrivsWithDB returns connection parameters for appdebug with dbname set. func (dbcfgs *DBConfigs) AllPrivsWithDB() Connector { - return dbcfgs.makeParams(AllPrivs, true) + return dbcfgs.makeParams(&dbcfgs.allprivsParams, true) } -// Dba returns connection parameters for dba with no dbname set. -func (dbcfgs *DBConfigs) Dba() Connector { - return dbcfgs.makeParams(Dba, false) +// DbaConnector returns connection parameters for dba with no dbname set. +func (dbcfgs *DBConfigs) DbaConnector() Connector { + return dbcfgs.makeParams(&dbcfgs.dbaParams, false) } // DbaWithDB returns connection parameters for appdebug with dbname set. func (dbcfgs *DBConfigs) DbaWithDB() Connector { - return dbcfgs.makeParams(Dba, true) + return dbcfgs.makeParams(&dbcfgs.dbaParams, true) } // FilteredWithDB returns connection parameters for filtered with dbname set. func (dbcfgs *DBConfigs) FilteredWithDB() Connector { - return dbcfgs.makeParams(Filtered, true) + return dbcfgs.makeParams(&dbcfgs.filteredParams, true) } -// Repl returns connection parameters for repl with no dbname set. -func (dbcfgs *DBConfigs) Repl() Connector { - return dbcfgs.makeParams(Repl, false) +// ReplConnector returns connection parameters for repl with no dbname set. +func (dbcfgs *DBConfigs) ReplConnector() Connector { + return dbcfgs.makeParams(&dbcfgs.replParams, false) } // ExternalRepl returns connection parameters for repl with no dbname set. func (dbcfgs *DBConfigs) ExternalRepl() Connector { - return dbcfgs.makeParams(ExternalRepl, true) + return dbcfgs.makeParams(&dbcfgs.externalReplParams, true) } // ExternalReplWithDB returns connection parameters for repl with dbname set. func (dbcfgs *DBConfigs) ExternalReplWithDB() Connector { - params := dbcfgs.makeParams(ExternalRepl, true) + params := dbcfgs.makeParams(&dbcfgs.externalReplParams, true) // TODO @rafael: This is a hack to allows to configure external databases by providing // db-config-erepl-dbname. if params.connParams.DeprecatedDBName != "" { @@ -237,16 +273,10 @@ func (dbcfgs *DBConfigs) ExternalReplWithDB() Connector { } // AppWithDB returns connection parameters for app with dbname set. -func (dbcfgs *DBConfigs) makeParams(userKey string, withDB bool) Connector { - orig := dbcfgs.userConfigs[userKey] - if orig == nil { - return Connector{ - connParams: &mysql.ConnParams{}, - } - } - result := orig.param +func (dbcfgs *DBConfigs) makeParams(cp *mysql.ConnParams, withDB bool) Connector { + result := *cp if withDB { - result.DbName = dbcfgs.DBName.Get() + result.DbName = dbcfgs.dbname } return Connector{ connParams: &result, @@ -255,123 +285,140 @@ func (dbcfgs *DBConfigs) makeParams(userKey string, withDB bool) Connector { // IsZero returns true if DBConfigs was uninitialized. func (dbcfgs *DBConfigs) IsZero() bool { - return len(dbcfgs.userConfigs) == 0 + return *dbcfgs == DBConfigs{} +} + +// HasGlobalSettings returns true if DBConfigs contains values +// for gloabl configs. +func (dbcfgs *DBConfigs) HasGlobalSettings() bool { + return dbcfgs.Host != "" || dbcfgs.Socket != "" } func (dbcfgs *DBConfigs) String() string { - out := struct { - Conn mysql.ConnParams - Users map[string]string - }{ - Users: make(map[string]string), - } - if conn := dbcfgs.userConfigs[App]; conn != nil { - out.Conn = conn.param - } else if conn := dbcfgs.userConfigs[Dba]; conn != nil { - out.Conn = conn.param - } - out.Conn.Pass = "****" - for k, uc := range dbcfgs.userConfigs { - out.Users[k] = uc.param.Uname - } - data, err := json.MarshalIndent(out, "", " ") + out, err := yaml2.Marshal(dbcfgs.Redacted()) if err != nil { return err.Error() } - return string(data) + return string(out) } -// Copy returns a copy of the DBConfig. -func (dbcfgs *DBConfigs) Copy() *DBConfigs { - result := &DBConfigs{userConfigs: make(map[string]*userConfig)} - for k, u := range dbcfgs.userConfigs { - newu := *u - result.userConfigs[k] = &newu - } - result.DBName.Set(dbcfgs.DBName.Get()) - result.SidecarDBName.Set(dbcfgs.SidecarDBName.Get()) - return result +// MarshalJSON marshals after redacting passwords. +func (dbcfgs *DBConfigs) MarshalJSON() ([]byte, error) { + type nonCustom DBConfigs + return json.Marshal((*nonCustom)(dbcfgs.Redacted())) +} + +// Redacted redacts passwords from DBConfigs. +func (dbcfgs *DBConfigs) Redacted() *DBConfigs { + dbcfgs = dbcfgs.Clone() + dbcfgs.App.Password = "****" + dbcfgs.Dba.Password = "****" + dbcfgs.Filtered.Password = "****" + dbcfgs.Repl.Password = "****" + dbcfgs.Appdebug.Password = "****" + dbcfgs.Allprivs.Password = "****" + return dbcfgs } -// HasConnectionParams returns true if connection parameters were -// specified in the command-line. This will allow the caller to -// search for alternate ways to connect, like looking in the my.cnf -// file. -func HasConnectionParams() bool { - return baseConfig.Host != "" || baseConfig.UnixSocket != "" +// Clone returns a clone of the DBConfig. +func (dbcfgs *DBConfigs) Clone() *DBConfigs { + result := *dbcfgs + return &result } // Init will initialize all the necessary connection parameters. -// Precedence is as follows: if baseConfig command line options are -// set, they supersede all other settings. -// If baseConfig is not set, the next priority is with per-user connection +// Precedence is as follows: if UserConfig settings are set, +// they supersede all other settings. +// The next priority is with per-user connection // parameters. This is only for legacy support. // If no per-user parameters are supplied, then the defaultSocketFile // is used to initialize the per-user conn params. -func Init(defaultSocketFile string) (*DBConfigs, error) { - // The new base configs, if set, supersede legacy settings. - for user, uc := range dbConfigs.userConfigs { +func (dbcfgs *DBConfigs) Init(defaultSocketFile string) *DBConfigs { + dbcfgs = dbcfgs.Clone() + for _, userKey := range All { + uc, cp := dbcfgs.getParams(userKey, dbcfgs) // TODO @rafael: For ExternalRepl we need to respect the provided host / port // At the moment this is an snowflake user connection type that it used by // vreplication to connect to external mysql hosts that are not part of a vitess // cluster. In the future we need to refactor all dbconfig to support custom users // in a more flexible way. - if HasConnectionParams() && user != ExternalRepl { - uc.param.Host = baseConfig.Host - uc.param.Port = baseConfig.Port - uc.param.UnixSocket = baseConfig.UnixSocket - } else if uc.param.UnixSocket == "" && uc.param.Host == "" { - uc.param.UnixSocket = defaultSocketFile + if dbcfgs.HasGlobalSettings() && userKey != ExternalRepl { + cp.Host = dbcfgs.Host + cp.Port = dbcfgs.Port + if !uc.UseTCP { + cp.UnixSocket = dbcfgs.Socket + } + } else if cp.UnixSocket == "" && cp.Host == "" { + cp.UnixSocket = defaultSocketFile } - if baseConfig.Charset != "" { - uc.param.Charset = baseConfig.Charset + if dbcfgs.Charset != "" { + cp.Charset = dbcfgs.Charset } - if baseConfig.Flags != 0 { - uc.param.Flags = baseConfig.Flags + if dbcfgs.Flags != 0 { + cp.Flags = dbcfgs.Flags } - if user != ExternalRepl { - uc.param.Flavor = baseConfig.Flavor + if userKey != ExternalRepl { + cp.Flavor = dbcfgs.Flavor } - if uc.useSSL { - uc.param.SslCa = baseConfig.SslCa - uc.param.SslCaPath = baseConfig.SslCaPath - uc.param.SslCert = baseConfig.SslCert - uc.param.SslKey = baseConfig.SslKey - uc.param.ServerName = baseConfig.ServerName + cp.ConnectTimeoutMs = uint64(dbcfgs.ConnectTimeoutMilliseconds) + + cp.Uname = uc.User + cp.Pass = uc.Password + if uc.UseSSL { + cp.SslCa = dbcfgs.SslCa + cp.SslCaPath = dbcfgs.SslCaPath + cp.SslCert = dbcfgs.SslCert + cp.SslKey = dbcfgs.SslKey + cp.ServerName = dbcfgs.ServerName } - uc.param.ConnectTimeoutMs = baseConfig.ConnectTimeoutMs } - // See if the CredentialsServer is working. We do not use the - // result for anything, this is just a check. - for _, uc := range dbConfigs.userConfigs { - if _, err := withCredentials(&uc.param); err != nil { - return nil, fmt.Errorf("dbconfig cannot be initialized: %v", err) - } - // Check for only one. - break - } - dbConfigs.SidecarDBName.Set("_vt") + log.Infof("DBConfigs: %v\n", dbcfgs.String()) + return dbcfgs +} - log.Infof("DBConfigs: %v\n", dbConfigs.String()) - return &dbConfigs, nil +func (dbcfgs *DBConfigs) getParams(userKey string, dbc *DBConfigs) (*UserConfig, *mysql.ConnParams) { + var uc *UserConfig + var cp *mysql.ConnParams + switch userKey { + case App: + uc = &dbcfgs.App + cp = &dbcfgs.appParams + case AppDebug: + uc = &dbcfgs.Appdebug + cp = &dbcfgs.appdebugParams + case AllPrivs: + uc = &dbcfgs.Allprivs + cp = &dbcfgs.allprivsParams + case Dba: + uc = &dbcfgs.Dba + cp = &dbcfgs.dbaParams + case Filtered: + uc = &dbcfgs.Filtered + cp = &dbcfgs.filteredParams + case Repl: + uc = &dbcfgs.Repl + cp = &dbcfgs.replParams + case ExternalRepl: + uc = &dbcfgs.externalRepl + cp = &dbcfgs.externalReplParams + default: + log.Exitf("Invalid db user key requested: %s", userKey) + } + return uc, cp } // NewTestDBConfigs returns a DBConfigs meant for testing. -func NewTestDBConfigs(genParams, appDebugParams mysql.ConnParams, dbName string) *DBConfigs { - dbcfgs := &DBConfigs{ - userConfigs: map[string]*userConfig{ - App: {param: genParams}, - AppDebug: {param: appDebugParams}, - AllPrivs: {param: genParams}, - Dba: {param: genParams}, - Filtered: {param: genParams}, - Repl: {param: genParams}, - ExternalRepl: {param: genParams}, - }, +func NewTestDBConfigs(genParams, appDebugParams mysql.ConnParams, dbname string) *DBConfigs { + return &DBConfigs{ + appParams: genParams, + appdebugParams: appDebugParams, + allprivsParams: genParams, + dbaParams: genParams, + filteredParams: genParams, + replParams: genParams, + externalReplParams: genParams, + dbname: dbname, } - dbcfgs.DBName.Set(dbName) - dbcfgs.SidecarDBName.Set("_vt") - return dbcfgs } diff --git a/go/vt/dbconfigs/dbconfigs_test.go b/go/vt/dbconfigs/dbconfigs_test.go index 323da7eceee..23855b6c5ac 100644 --- a/go/vt/dbconfigs/dbconfigs_test.go +++ b/go/vt/dbconfigs/dbconfigs_test.go @@ -20,313 +20,210 @@ import ( "fmt" "io/ioutil" "os" - "reflect" "syscall" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/yaml2" ) -func TestRegisterFlagsWithSomeFlags(t *testing.T) { - f := saveDBConfigs() - defer f() - - dbConfigs = DBConfigs{userConfigs: make(map[string]*userConfig)} - RegisterFlags(Dba, Repl) - for k := range dbConfigs.userConfigs { - if k != Dba && k != Repl { - t.Errorf("dbConfigs.params: %v, want dba or repl", k) - } - } -} - func TestInit(t *testing.T) { - f := saveDBConfigs() - defer f() + dbConfigs := DBConfigs{ + appParams: mysql.ConnParams{UnixSocket: "socket"}, + dbaParams: mysql.ConnParams{Host: "host"}, + } + dbc := dbConfigs.Init("default") + assert.Equal(t, mysql.ConnParams{UnixSocket: "socket"}, dbc.appParams) + assert.Equal(t, mysql.ConnParams{Host: "host"}, dbc.dbaParams) + assert.Equal(t, mysql.ConnParams{UnixSocket: "default"}, dbc.appdebugParams) dbConfigs = DBConfigs{ - userConfigs: map[string]*userConfig{ - App: {param: mysql.ConnParams{UnixSocket: "socket"}}, - AppDebug: {}, - Dba: {param: mysql.ConnParams{Host: "host"}}, + Host: "a", + Port: 1, + Socket: "b", + Charset: "c", + Flags: 2, + Flavor: "flavor", + SslCa: "d", + SslCaPath: "e", + SslCert: "f", + SslKey: "g", + ConnectTimeoutMilliseconds: 250, + App: UserConfig{ + User: "app", + Password: "apppass", }, - } - dbc, err := Init("default") - if err != nil { - t.Fatal(err) - } - if got, want := dbc.userConfigs[App].param.UnixSocket, "socket"; got != want { - t.Errorf("dbc.app.UnixSocket: %v, want %v", got, want) - } - if got, want := dbc.userConfigs[Dba].param.Host, "host"; got != want { - t.Errorf("dbc.app.Host: %v, want %v", got, want) - } - if got, want := dbc.userConfigs[AppDebug].param.UnixSocket, "default"; got != want { - t.Errorf("dbc.app.UnixSocket: %v, want %v", got, want) - } - - baseConfig = mysql.ConnParams{ - Host: "a", - Port: 1, - Uname: "b", - Pass: "c", - DbName: "d", - UnixSocket: "e", - Charset: "f", - Flags: 2, - Flavor: "flavor", - SslCa: "g", - SslCaPath: "h", - SslCert: "i", - SslKey: "j", - } - dbConfigs = DBConfigs{ - userConfigs: map[string]*userConfig{ - App: { - param: mysql.ConnParams{ - Uname: "app", - Pass: "apppass", - UnixSocket: "socket", - }, - }, - AppDebug: { - useSSL: true, - }, - Dba: { - useSSL: true, - param: mysql.ConnParams{ - Uname: "dba", - Pass: "dbapass", - Host: "host", - }, - }, + Appdebug: UserConfig{ + UseSSL: true, }, - } - dbc, err = Init("default") - if err != nil { - t.Fatal(err) - } - want := &DBConfigs{ - userConfigs: map[string]*userConfig{ - App: { - param: mysql.ConnParams{ - Host: "a", - Port: 1, - Uname: "app", - Pass: "apppass", - UnixSocket: "e", - Charset: "f", - Flags: 2, - Flavor: "flavor", - }, - }, - AppDebug: { - useSSL: true, - param: mysql.ConnParams{ - Host: "a", - Port: 1, - UnixSocket: "e", - Charset: "f", - Flags: 2, - Flavor: "flavor", - SslCa: "g", - SslCaPath: "h", - SslCert: "i", - SslKey: "j", - }, - }, - Dba: { - useSSL: true, - param: mysql.ConnParams{ - Host: "a", - Port: 1, - Uname: "dba", - Pass: "dbapass", - UnixSocket: "e", - Charset: "f", - Flags: 2, - Flavor: "flavor", - SslCa: "g", - SslCaPath: "h", - SslCert: "i", - SslKey: "j", - }, - }, + Dba: UserConfig{ + User: "dba", + Password: "dbapass", + UseSSL: true, + }, + appParams: mysql.ConnParams{ + UnixSocket: "socket", + }, + dbaParams: mysql.ConnParams{ + Host: "host", }, } - // Compare individually, otherwise the errors are not readable. - if !reflect.DeepEqual(dbc.userConfigs[App].param, want.userConfigs[App].param) { - t.Errorf("dbc: \n%#v, want \n%#v", dbc.userConfigs[App].param, want.userConfigs[App].param) + dbc = dbConfigs.Init("default") + + want := mysql.ConnParams{ + Host: "a", + Port: 1, + Uname: "app", + Pass: "apppass", + UnixSocket: "b", + Charset: "c", + Flags: 2, + Flavor: "flavor", + ConnectTimeoutMs: 250, } - if !reflect.DeepEqual(dbc.userConfigs[AppDebug].param, want.userConfigs[AppDebug].param) { - t.Errorf("dbc: \n%#v, want \n%#v", dbc.userConfigs[AppDebug].param, want.userConfigs[AppDebug].param) + assert.Equal(t, want, dbc.appParams) + + want = mysql.ConnParams{ + Host: "a", + Port: 1, + UnixSocket: "b", + Charset: "c", + Flags: 2, + Flavor: "flavor", + SslCa: "d", + SslCaPath: "e", + SslCert: "f", + SslKey: "g", + ConnectTimeoutMs: 250, } - if !reflect.DeepEqual(dbc.userConfigs[Dba].param, want.userConfigs[Dba].param) { - t.Errorf("dbc: \n%#v, want \n%#v", dbc.userConfigs[Dba].param, want.userConfigs[Dba].param) + assert.Equal(t, want, dbc.appdebugParams) + want = mysql.ConnParams{ + Host: "a", + Port: 1, + Uname: "dba", + Pass: "dbapass", + UnixSocket: "b", + Charset: "c", + Flags: 2, + Flavor: "flavor", + SslCa: "d", + SslCaPath: "e", + SslCert: "f", + SslKey: "g", + ConnectTimeoutMs: 250, } + assert.Equal(t, want, dbc.dbaParams) // Test that baseConfig does not override Charset and Flag if they're // not specified. - baseConfig = mysql.ConnParams{ - Host: "a", - Port: 1, - Uname: "b", - Pass: "c", - DbName: "d", - UnixSocket: "e", - SslCa: "g", - SslCaPath: "h", - SslCert: "i", - SslKey: "j", - } dbConfigs = DBConfigs{ - userConfigs: map[string]*userConfig{ - App: { - param: mysql.ConnParams{ - Uname: "app", - Pass: "apppass", - UnixSocket: "socket", - Charset: "f", - }, - }, - AppDebug: { - useSSL: true, - }, - Dba: { - useSSL: true, - param: mysql.ConnParams{ - Uname: "dba", - Pass: "dbapass", - Host: "host", - Flags: 2, - }, - }, + Host: "a", + Port: 1, + Socket: "b", + SslCa: "d", + SslCaPath: "e", + SslCert: "f", + SslKey: "g", + App: UserConfig{ + User: "app", + Password: "apppass", }, - } - dbc, err = Init("default") - if err != nil { - t.Fatal(err) - } - want = &DBConfigs{ - userConfigs: map[string]*userConfig{ - App: { - param: mysql.ConnParams{ - Host: "a", - Port: 1, - Uname: "app", - Pass: "apppass", - UnixSocket: "e", - Charset: "f", - }, - }, - AppDebug: { - useSSL: true, - param: mysql.ConnParams{ - Host: "a", - Port: 1, - UnixSocket: "e", - SslCa: "g", - SslCaPath: "h", - SslCert: "i", - SslKey: "j", - }, - }, - Dba: { - useSSL: true, - param: mysql.ConnParams{ - Host: "a", - Port: 1, - Uname: "dba", - Pass: "dbapass", - UnixSocket: "e", - Flags: 2, - SslCa: "g", - SslCaPath: "h", - SslCert: "i", - SslKey: "j", - }, - }, + Appdebug: UserConfig{ + UseSSL: true, + }, + Dba: UserConfig{ + User: "dba", + Password: "dbapass", + UseSSL: true, + }, + appParams: mysql.ConnParams{ + UnixSocket: "socket", + Charset: "f", + }, + dbaParams: mysql.ConnParams{ + Host: "host", + Flags: 2, }, } - // Compare individually, otherwise the errors are not readable. - if !reflect.DeepEqual(dbc.userConfigs[App].param, want.userConfigs[App].param) { - t.Errorf("dbc: \n%#v, want \n%#v", dbc.userConfigs[App].param, want.userConfigs[App].param) - } - if !reflect.DeepEqual(dbc.userConfigs[AppDebug].param, want.userConfigs[AppDebug].param) { - t.Errorf("dbc: \n%#v, want \n%#v", dbc.userConfigs[AppDebug].param, want.userConfigs[AppDebug].param) + dbc = dbConfigs.Init("default") + want = mysql.ConnParams{ + Host: "a", + Port: 1, + Uname: "app", + Pass: "apppass", + UnixSocket: "b", + Charset: "f", } - if !reflect.DeepEqual(dbc.userConfigs[Dba].param, want.userConfigs[Dba].param) { - t.Errorf("dbc: \n%#v, want \n%#v", dbc.userConfigs[Dba].param, want.userConfigs[Dba].param) + assert.Equal(t, want, dbc.appParams) + want = mysql.ConnParams{ + Host: "a", + Port: 1, + UnixSocket: "b", + SslCa: "d", + SslCaPath: "e", + SslCert: "f", + SslKey: "g", + } + assert.Equal(t, want, dbc.appdebugParams) + want = mysql.ConnParams{ + Host: "a", + Port: 1, + Uname: "dba", + Pass: "dbapass", + UnixSocket: "b", + Flags: 2, + SslCa: "d", + SslCaPath: "e", + SslCert: "f", + SslKey: "g", } + assert.Equal(t, want, dbc.dbaParams) } -func TestInitTimeout(t *testing.T) { - f := saveDBConfigs() - defer f() - - baseConfig = mysql.ConnParams{ - Host: "a", - Port: 1, - Uname: "b", - Pass: "c", - DbName: "d", - UnixSocket: "e", - Charset: "f", - Flags: 2, - Flavor: "flavor", - ConnectTimeoutMs: 250, - } - dbConfigs = DBConfigs{ - userConfigs: map[string]*userConfig{ - App: { - param: mysql.ConnParams{ - Uname: "app", - Pass: "apppass", - }, - }, +func TestUseTCP(t *testing.T) { + dbConfigs := DBConfigs{ + Host: "a", + Port: 1, + Socket: "b", + App: UserConfig{ + User: "app", + UseTCP: true, + }, + Dba: UserConfig{ + User: "dba", }, } + dbc := dbConfigs.Init("default") - dbc, err := Init("default") - if err != nil { - t.Fatal(err) - } - want := &DBConfigs{ - userConfigs: map[string]*userConfig{ - App: { - param: mysql.ConnParams{ - Host: "a", - Port: 1, - Uname: "app", - Pass: "apppass", - UnixSocket: "e", - Charset: "f", - Flags: 2, - Flavor: "flavor", - ConnectTimeoutMs: 250, - }, - }, - }, + want := mysql.ConnParams{ + Host: "a", + Port: 1, + Uname: "app", } + assert.Equal(t, want, dbc.appParams) - if !reflect.DeepEqual(dbc.userConfigs[App].param, want.userConfigs[App].param) { - t.Errorf("dbc: \n%#v, want \n%#v", dbc.userConfigs[App].param, want.userConfigs[App].param) + want = mysql.ConnParams{ + Host: "a", + Port: 1, + Uname: "dba", + UnixSocket: "b", } + assert.Equal(t, want, dbc.dbaParams) } func TestAccessors(t *testing.T) { dbc := &DBConfigs{ - userConfigs: map[string]*userConfig{ - App: {}, - AppDebug: {}, - AllPrivs: {}, - Dba: {}, - Filtered: {}, - Repl: {}, - }, - } - dbc.DBName.Set("db") + appParams: mysql.ConnParams{}, + appdebugParams: mysql.ConnParams{}, + allprivsParams: mysql.ConnParams{}, + dbaParams: mysql.ConnParams{}, + filteredParams: mysql.ConnParams{}, + replParams: mysql.ConnParams{}, + } + dbc = dbc.WithDBName("db") if got, want := dbc.AppWithDB().connParams.DbName, "db"; got != want { t.Errorf("dbc.AppWithDB().DbName: %v, want %v", got, want) } @@ -336,7 +233,7 @@ func TestAccessors(t *testing.T) { if got, want := dbc.AppDebugWithDB().connParams.DbName, "db"; got != want { t.Errorf("dbc.AppDebugWithDB().DbName: %v, want %v", got, want) } - if got, want := dbc.Dba().connParams.DbName, ""; got != want { + if got, want := dbc.DbaConnector().connParams.DbName, ""; got != want { t.Errorf("dbc.Dba().DbName: %v, want %v", got, want) } if got, want := dbc.DbaWithDB().connParams.DbName, "db"; got != want { @@ -345,28 +242,11 @@ func TestAccessors(t *testing.T) { if got, want := dbc.FilteredWithDB().connParams.DbName, "db"; got != want { t.Errorf("dbc.FilteredWithDB().DbName: %v, want %v", got, want) } - if got, want := dbc.Repl().connParams.DbName, ""; got != want { + if got, want := dbc.ReplConnector().connParams.DbName, ""; got != want { t.Errorf("dbc.Repl().DbName: %v, want %v", got, want) } } -func TestCopy(t *testing.T) { - want := &DBConfigs{ - userConfigs: map[string]*userConfig{ - App: {param: mysql.ConnParams{UnixSocket: "aa"}}, - AppDebug: {}, - Repl: {}, - }, - } - want.DBName.Set("db") - want.SidecarDBName.Set("_vt") - - got := want.Copy() - if !reflect.DeepEqual(got, want) { - t.Errorf("DBConfig: %v, want %v", got, want) - } -} - func TestCredentialsFileHUP(t *testing.T) { tmpFile, err := ioutil.TempFile("", "credentials.json") if err != nil { @@ -411,19 +291,63 @@ func hupTest(t *testing.T, tmpFile *os.File, oldStr, newStr string) { } } -func saveDBConfigs() (restore func()) { - savedDBConfigs := DBConfigs{ - userConfigs: dbConfigs.userConfigs, +func TestYaml(t *testing.T) { + db := DBConfigs{ + Socket: "a", + Port: 1, + Flags: 20, + App: UserConfig{ + User: "vt_app", + UseSSL: true, + }, + Dba: UserConfig{ + User: "vt_dba", + }, } - savedDBConfigs.DBName.Set(dbConfigs.DBName.Get()) - savedDBConfigs.SidecarDBName.Set(dbConfigs.SidecarDBName.Get()) - savedBaseConfig := baseConfig - return func() { - dbConfigs := DBConfigs{ - userConfigs: savedDBConfigs.userConfigs, - } - dbConfigs.DBName.Set(savedDBConfigs.DBName.Get()) - dbConfigs.SidecarDBName.Set(savedDBConfigs.SidecarDBName.Get()) - baseConfig = savedBaseConfig + gotBytes, err := yaml2.Marshal(&db) + require.NoError(t, err) + wantBytes := `allprivs: + password: '****' +app: + password: '****' + useSsl: true + user: vt_app +appdebug: + password: '****' +dba: + password: '****' + user: vt_dba +filtered: + password: '****' +flags: 20 +port: 1 +repl: + password: '****' +socket: a +` + assert.Equal(t, wantBytes, string(gotBytes)) + + inBytes := []byte(`socket: a +port: 1 +flags: 20 +app: + user: vt_app + useSsl: true + useTCP: false +dba: + user: vt_dba +`) + gotdb := DBConfigs{ + Port: 1, + Flags: 20, + App: UserConfig{ + UseTCP: true, + }, + Dba: UserConfig{ + User: "aaa", + }, } + err = yaml2.Unmarshal(inBytes, &gotdb) + require.NoError(t, err) + assert.Equal(t, &db, &gotdb) } diff --git a/go/vt/mysqlctl/cmd.go b/go/vt/mysqlctl/cmd.go index 7d4a9ea0c96..971f704bc63 100644 --- a/go/vt/mysqlctl/cmd.go +++ b/go/vt/mysqlctl/cmd.go @@ -46,11 +46,7 @@ func CreateMysqldAndMycnf(tabletUID uint32, mysqlSocket string, mysqlPort int32) mycnf.SocketFile = mysqlSocket } - dbcfgs, err := dbconfigs.Init(mycnf.SocketFile) - if err != nil { - return nil, nil, fmt.Errorf("couldn't Init dbconfigs: %v", err) - } - + dbcfgs := dbconfigs.GlobalDBConfigs.Init(mycnf.SocketFile) return NewMysqld(dbcfgs), mycnf, nil } @@ -64,10 +60,6 @@ func OpenMysqldAndMycnf(tabletUID uint32) (*Mysqld, *Mycnf, error) { return nil, nil, fmt.Errorf("couldn't read my.cnf file: %v", err) } - dbcfgs, err := dbconfigs.Init(mycnf.SocketFile) - if err != nil { - return nil, nil, fmt.Errorf("couldn't Init dbconfigs: %v", err) - } - + dbcfgs := dbconfigs.GlobalDBConfigs.Init(mycnf.SocketFile) return NewMysqld(dbcfgs), mycnf, nil } diff --git a/go/vt/mysqlctl/mycnf_test.go b/go/vt/mysqlctl/mycnf_test.go index 3412e56c57b..e2ee0fcc4d2 100644 --- a/go/vt/mysqlctl/mycnf_test.go +++ b/go/vt/mysqlctl/mycnf_test.go @@ -96,7 +96,7 @@ func NoTestMycnfHook(t *testing.T) { // this is not being passed, so it should be nil os.Setenv("MY_VAR", "myvalue") - dbcfgs, _ := dbconfigs.Init(cnf.SocketFile) + dbcfgs := dbconfigs.GlobalDBConfigs.Init(cnf.SocketFile) mysqld := NewMysqld(dbcfgs) servenv.OnClose(mysqld.Close) diff --git a/go/vt/mysqlctl/mysqld.go b/go/vt/mysqlctl/mysqld.go index 44a32d907ea..14ce1c91173 100644 --- a/go/vt/mysqlctl/mysqld.go +++ b/go/vt/mysqlctl/mysqld.go @@ -121,7 +121,7 @@ func NewMysqld(dbcfgs *dbconfigs.DBConfigs) *Mysqld { but also relies on none of the flavor detection features being used at runtime. Currently this assumption is guaranteed true. */ - if dbconfigs.HasConnectionParams() { + if dbconfigs.GlobalDBConfigs.HasGlobalSettings() { log.Info("mysqld is unmanaged or remote. Skipping flavor detection") return result } @@ -269,7 +269,7 @@ func (mysqld *Mysqld) RunMysqlUpgrade() error { // privileges' right in the middle, and then subsequent // commands fail if we don't use valid credentials. So let's // use dba credentials. - params, err := mysqld.dbcfgs.Dba().MysqlParams() + params, err := mysqld.dbcfgs.DbaConnector().MysqlParams() if err != nil { return err } @@ -436,7 +436,7 @@ func (mysqld *Mysqld) startNoWait(ctx context.Context, cnf *Mycnf, mysqldArgs .. // will use the dba credentials to try to connect. Use wait() with // different credentials if needed. func (mysqld *Mysqld) Wait(ctx context.Context, cnf *Mycnf) error { - params, err := mysqld.dbcfgs.Dba().MysqlParams() + params, err := mysqld.dbcfgs.DbaConnector().MysqlParams() if err != nil { return err } @@ -526,7 +526,7 @@ func (mysqld *Mysqld) Shutdown(ctx context.Context, cnf *Mycnf, waitForMysqld bo if err != nil { return err } - params, err := mysqld.dbcfgs.Dba().MysqlParams() + params, err := mysqld.dbcfgs.DbaConnector().MysqlParams() if err != nil { return err } @@ -1097,7 +1097,7 @@ func (mysqld *Mysqld) GetAppConnection(ctx context.Context) (*dbconnpool.PooledD // GetDbaConnection creates a new DBConnection. func (mysqld *Mysqld) GetDbaConnection() (*dbconnpool.DBConnection, error) { - return dbconnpool.NewDBConnection(context.TODO(), mysqld.dbcfgs.Dba()) + return dbconnpool.NewDBConnection(context.TODO(), mysqld.dbcfgs.DbaConnector()) } // GetAllPrivsConnection creates a new DBConnection. diff --git a/go/vt/mysqlctl/replication.go b/go/vt/mysqlctl/replication.go index 5b3e5d3627d..14fa5b39417 100644 --- a/go/vt/mysqlctl/replication.go +++ b/go/vt/mysqlctl/replication.go @@ -276,7 +276,7 @@ func (mysqld *Mysqld) SetSlavePosition(ctx context.Context, pos mysql.Position) // SetMaster makes the provided host / port the master. It optionally // stops replication before, and starts it after. func (mysqld *Mysqld) SetMaster(ctx context.Context, masterHost string, masterPort int, slaveStopBefore bool, slaveStartAfter bool) error { - params, err := mysqld.dbcfgs.Repl().MysqlParams() + params, err := mysqld.dbcfgs.ReplConnector().MysqlParams() if err != nil { return err } diff --git a/go/vt/mysqlctl/schema.go b/go/vt/mysqlctl/schema.go index cde232c9a1c..2fc5aab5bbb 100644 --- a/go/vt/mysqlctl/schema.go +++ b/go/vt/mysqlctl/schema.go @@ -38,7 +38,7 @@ var autoIncr = regexp.MustCompile(` AUTO_INCREMENT=\d+`) // executeSchemaCommands executes some SQL commands, using the mysql // command line tool. It uses the dba connection parameters, with credentials. func (mysqld *Mysqld) executeSchemaCommands(sql string) error { - params, err := mysqld.dbcfgs.Dba().MysqlParams() + params, err := mysqld.dbcfgs.DbaConnector().MysqlParams() if err != nil { return err } diff --git a/go/vt/vtcombo/tablet_map.go b/go/vt/vtcombo/tablet_map.go index 94cd5cadf83..337472e4e10 100644 --- a/go/vt/vtcombo/tablet_map.go +++ b/go/vt/vtcombo/tablet_map.go @@ -179,9 +179,8 @@ func InitTabletMap(ts *topo.Server, tpb *vttestpb.VTTestTopology, mysqld mysqlct if dbname == "" { dbname = fmt.Sprintf("vt_%v_%v", keyspace, shard) } - // Copy dbcfgs and override SidecarDBName because there will be one for each db. - copydbcfgs := dbcfgs.Copy() - copydbcfgs.SidecarDBName.Set("_" + dbname) + // Clone dbcfgs and override SidecarDBName because there will be one for each db. + copydbcfgs := dbcfgs.Clone() replicas := int(kpb.ReplicaCount) if replicas == 0 { diff --git a/go/vt/vttablet/heartbeat/reader.go b/go/vt/vttablet/heartbeat/reader.go index b85cbbdf40f..ec9522dd1b2 100644 --- a/go/vt/vttablet/heartbeat/reader.go +++ b/go/vt/vttablet/heartbeat/reader.go @@ -27,7 +27,6 @@ import ( "golang.org/x/net/context" - "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/timer" "vitess.io/vitess/go/vt/log" @@ -55,7 +54,6 @@ type Reader struct { enabled bool interval time.Duration keyspaceShard string - dbName string now func() time.Time errorLog *logutil.ThrottledLogger @@ -91,13 +89,11 @@ func NewReader(env tabletenv.Env) *Reader { } } -// Init does last minute initialization of db settings, such as dbName -// and keyspaceShard +// Init does last minute initialization of db settings, such as keyspaceShard. func (r *Reader) Init(target querypb.Target) { if !r.enabled { return } - r.dbName = sqlescape.EscapeID(r.env.DBConfigs().SidecarDBName.Get()) r.keyspaceShard = fmt.Sprintf("%s:%s", target.Keyspace, target.Shard) } @@ -198,7 +194,7 @@ func (r *Reader) bindHeartbeatFetch() (string, error) { bindVars := map[string]*querypb.BindVariable{ "ks": sqltypes.StringBindVariable(r.keyspaceShard), } - parsed := sqlparser.BuildParsedQuery(sqlFetchMostRecentHeartbeat, r.dbName, ":ks") + parsed := sqlparser.BuildParsedQuery(sqlFetchMostRecentHeartbeat, "_vt", ":ks") bound, err := parsed.GenerateQuery(bindVars, nil) if err != nil { return "", err diff --git a/go/vt/vttablet/heartbeat/reader_test.go b/go/vt/vttablet/heartbeat/reader_test.go index d3d7da107ea..446c3c4cee0 100644 --- a/go/vt/vttablet/heartbeat/reader_test.go +++ b/go/vt/vttablet/heartbeat/reader_test.go @@ -22,7 +22,6 @@ import ( "time" "vitess.io/vitess/go/mysql/fakesqldb" - "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/dbconfigs" "vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv" @@ -38,7 +37,7 @@ func TestReaderReadHeartbeat(t *testing.T) { tr := newReader(db, mockNowFunc) defer tr.Close() - db.AddQuery(fmt.Sprintf("SELECT ts FROM %s.heartbeat WHERE keyspaceShard='%s'", tr.dbName, tr.keyspaceShard), &sqltypes.Result{ + db.AddQuery(fmt.Sprintf("SELECT ts FROM %s.heartbeat WHERE keyspaceShard='%s'", "_vt", tr.keyspaceShard), &sqltypes.Result{ Fields: []*querypb.Field{ {Name: "ts", Type: sqltypes.Int64}, }, @@ -107,7 +106,6 @@ func newReader(db *fakesqldb.DB, nowFunc func() time.Time) *Reader { dbc := dbconfigs.NewTestDBConfigs(cp, cp, "") tr := NewReader(tabletenv.NewTestEnv(config, nil, "ReaderTest")) - tr.dbName = sqlescape.EscapeID(dbc.SidecarDBName.Get()) tr.keyspaceShard = "test:0" tr.now = nowFunc tr.pool.Open(dbc.AppWithDB(), dbc.DbaWithDB(), dbc.AppDebugWithDB()) diff --git a/go/vt/vttablet/heartbeat/writer.go b/go/vt/vttablet/heartbeat/writer.go index 2583da0b751..54acc84dbaf 100644 --- a/go/vt/vttablet/heartbeat/writer.go +++ b/go/vt/vttablet/heartbeat/writer.go @@ -25,7 +25,6 @@ import ( "golang.org/x/net/context" - "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/timer" "vitess.io/vitess/go/vt/dbconfigs" @@ -60,7 +59,6 @@ type Writer struct { interval time.Duration tabletAlias topodatapb.TabletAlias keyspaceShard string - dbName string now func() time.Time errorLog *logutil.ThrottledLogger @@ -101,7 +99,6 @@ func (w *Writer) Init(target querypb.Target) error { w.mu.Lock() defer w.mu.Unlock() log.Info("Initializing heartbeat table.") - w.dbName = sqlescape.EscapeID(w.env.DBConfigs().SidecarDBName.Get()) w.keyspaceShard = fmt.Sprintf("%s:%s", target.Keyspace, target.Shard) if target.TabletType == topodatapb.TabletType_MASTER { @@ -162,8 +159,8 @@ func (w *Writer) initializeTables(cp dbconfigs.Connector) error { } defer conn.Close() statements := []string{ - fmt.Sprintf(sqlCreateSidecarDB, w.dbName), - fmt.Sprintf(sqlCreateHeartbeatTable, w.dbName), + fmt.Sprintf(sqlCreateSidecarDB, "_vt"), + fmt.Sprintf(sqlCreateHeartbeatTable, "_vt"), } for _, s := range statements { if _, err := conn.ExecuteFetch(s, 0, false); err != nil { @@ -191,7 +188,7 @@ func (w *Writer) bindHeartbeatVars(query string) (string, error) { "ts": sqltypes.Int64BindVariable(w.now().UnixNano()), "uid": sqltypes.Int64BindVariable(int64(w.tabletAlias.Uid)), } - parsed := sqlparser.BuildParsedQuery(query, w.dbName, ":ts", ":uid", ":ks") + parsed := sqlparser.BuildParsedQuery(query, "_vt", ":ts", ":uid", ":ks") bound, err := parsed.GenerateQuery(bindVars, nil) if err != nil { return "", err diff --git a/go/vt/vttablet/heartbeat/writer_test.go b/go/vt/vttablet/heartbeat/writer_test.go index d19cce64c36..a25d5d5e7e5 100644 --- a/go/vt/vttablet/heartbeat/writer_test.go +++ b/go/vt/vttablet/heartbeat/writer_test.go @@ -22,7 +22,6 @@ import ( "time" "vitess.io/vitess/go/mysql/fakesqldb" - "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/dbconfigs" topodatapb "vitess.io/vitess/go/vt/proto/topodata" @@ -47,13 +46,13 @@ func TestCreateSchema(t *testing.T) { defer tw.Close() writes.Reset() - db.AddQuery(fmt.Sprintf(sqlCreateHeartbeatTable, tw.dbName), &sqltypes.Result{}) - db.AddQuery(fmt.Sprintf("INSERT INTO %s.heartbeat (ts, tabletUid, keyspaceShard) VALUES (%d, %d, '%s') ON DUPLICATE KEY UPDATE ts=VALUES(ts)", tw.dbName, now.UnixNano(), tw.tabletAlias.Uid, tw.keyspaceShard), &sqltypes.Result{}) + db.AddQuery(fmt.Sprintf(sqlCreateHeartbeatTable, "_vt"), &sqltypes.Result{}) + db.AddQuery(fmt.Sprintf("INSERT INTO %s.heartbeat (ts, tabletUid, keyspaceShard) VALUES (%d, %d, '%s') ON DUPLICATE KEY UPDATE ts=VALUES(ts)", "_vt", now.UnixNano(), tw.tabletAlias.Uid, tw.keyspaceShard), &sqltypes.Result{}) if err := tw.initializeTables(db.ConnParams()); err == nil { t.Fatal("initializeTables() should not have succeeded") } - db.AddQuery(fmt.Sprintf(sqlCreateSidecarDB, tw.dbName), &sqltypes.Result{}) + db.AddQuery(fmt.Sprintf(sqlCreateSidecarDB, "_vt"), &sqltypes.Result{}) if err := tw.initializeTables(db.ConnParams()); err != nil { t.Fatalf("Should not be in error: %v", err) } @@ -70,7 +69,7 @@ func TestWriteHeartbeat(t *testing.T) { defer db.Close() tw := newTestWriter(db, mockNowFunc) - db.AddQuery(fmt.Sprintf("UPDATE %s.heartbeat SET ts=%d, tabletUid=%d WHERE keyspaceShard='%s'", tw.dbName, now.UnixNano(), tw.tabletAlias.Uid, tw.keyspaceShard), &sqltypes.Result{}) + db.AddQuery(fmt.Sprintf("UPDATE %s.heartbeat SET ts=%d, tabletUid=%d WHERE keyspaceShard='%s'", "_vt", now.UnixNano(), tw.tabletAlias.Uid, tw.keyspaceShard), &sqltypes.Result{}) writes.Reset() writeErrors.Reset() @@ -112,7 +111,6 @@ func newTestWriter(db *fakesqldb.DB, nowFunc func() time.Time) *Writer { dbc := dbconfigs.NewTestDBConfigs(cp, cp, "") tw := NewWriter(tabletenv.NewTestEnv(config, nil, "WriterTest"), topodatapb.TabletAlias{Cell: "test", Uid: 1111}) - tw.dbName = sqlescape.EscapeID(dbc.SidecarDBName.Get()) tw.keyspaceShard = "test:0" tw.now = nowFunc tw.pool.Open(dbc.AppWithDB(), dbc.DbaWithDB(), dbc.AppDebugWithDB()) diff --git a/go/vt/vttablet/tabletmanager/action_agent.go b/go/vt/vttablet/tabletmanager/action_agent.go index 349e0d336c5..427eb9b0f58 100644 --- a/go/vt/vttablet/tabletmanager/action_agent.go +++ b/go/vt/vttablet/tabletmanager/action_agent.go @@ -272,7 +272,6 @@ func NewActionAgent( TabletAlias: tabletAlias, Cnf: mycnf, MysqlDaemon: mysqld, - DBConfigs: dbcfgs, History: history.New(historyLength), DemoteMasterType: demoteMasterTabletType, _healthy: fmt.Errorf("healthcheck not run yet"), @@ -317,7 +316,7 @@ func NewActionAgent( } // Start will get the tablet info, and update our state from it - if err := agent.Start(batchCtx, mysqlHost, int32(mysqlPort), port, gRPCPort, true); err != nil { + if err := agent.Start(batchCtx, dbcfgs, mysqlHost, int32(mysqlPort), port, gRPCPort, true); err != nil { return nil, err } @@ -419,7 +418,6 @@ func NewTestActionAgent(batchCtx context.Context, ts *topo.Server, tabletAlias * TabletAlias: tabletAlias, Cnf: nil, MysqlDaemon: mysqlDaemon, - DBConfigs: &dbconfigs.DBConfigs{}, VREngine: vreplication.NewEngine(ts, tabletAlias.Cell, mysqlDaemon, binlogplayer.NewFakeDBClient, ti.DbName()), History: history.New(historyLength), DemoteMasterType: demoteMasterTabletType, @@ -430,7 +428,7 @@ func NewTestActionAgent(batchCtx context.Context, ts *topo.Server, tabletAlias * } // Start will update the topology and setup services. - if err := agent.Start(batchCtx, "", 0, vtPort, grpcPort, false); err != nil { + if err := agent.Start(batchCtx, &dbconfigs.DBConfigs{}, "", 0, vtPort, grpcPort, false); err != nil { panic(vterrors.Wrapf(err, "agent.Start(%v) failed", tabletAlias)) } @@ -464,7 +462,6 @@ func NewComboActionAgent(batchCtx context.Context, ts *topo.Server, tabletAlias TabletAlias: tabletAlias, Cnf: nil, MysqlDaemon: mysqlDaemon, - DBConfigs: dbcfgs, VREngine: vreplication.NewEngine(nil, "", nil, nil, ""), gotMysqlPort: true, History: history.New(historyLength), @@ -483,7 +480,7 @@ func NewComboActionAgent(batchCtx context.Context, ts *topo.Server, tabletAlias } // Start the agent. - if err := agent.Start(batchCtx, "", 0, vtPort, grpcPort, false); err != nil { + if err := agent.Start(batchCtx, dbcfgs, "", 0, vtPort, grpcPort, false); err != nil { panic(vterrors.Wrapf(err, "agent.Start(%v) failed", tabletAlias)) } @@ -634,7 +631,7 @@ func (agent *ActionAgent) verifyTopology(ctx context.Context) { // Start validates and updates the topology records for the tablet, and performs // the initial state change callback to start tablet services. // If initUpdateStream is set, update stream service will also be registered. -func (agent *ActionAgent) Start(ctx context.Context, mysqlHost string, mysqlPort, vtPort, gRPCPort int32, initUpdateStream bool) error { +func (agent *ActionAgent) Start(ctx context.Context, dbcfgs *dbconfigs.DBConfigs, mysqlHost string, mysqlPort, vtPort, gRPCPort int32, initUpdateStream bool) error { // find our hostname as fully qualified, and IP hostname := *tabletHostname if hostname == "" { @@ -686,11 +683,8 @@ func (agent *ActionAgent) Start(ctx context.Context, mysqlHost string, mysqlPort // Verify the topology is correct. agent.verifyTopology(ctx) - // Get and fix the dbname if necessary, only for real instances. - if !agent.DBConfigs.IsZero() { - dbname := topoproto.TabletDbName(agent.initialTablet) - agent.DBConfigs.DBName.Set(dbname) - } + dbname := topoproto.TabletDbName(agent.initialTablet) + agent.DBConfigs = dbcfgs.WithDBName(dbname) // Create and register the RPC services from UpdateStream. // (it needs the dbname, so it has to be delayed up to here, diff --git a/go/vt/vttablet/tabletmanager/rpc_lock_tables.go b/go/vt/vttablet/tabletmanager/rpc_lock_tables.go index 30a0fabeb3e..b81d78cd7c1 100644 --- a/go/vt/vttablet/tabletmanager/rpc_lock_tables.go +++ b/go/vt/vttablet/tabletmanager/rpc_lock_tables.go @@ -102,7 +102,7 @@ func (agent *ActionAgent) lockTablesUsingLockTables(conn *dbconnpool.DBConnectio tableNames = append(tableNames, fmt.Sprintf("%s READ", sqlescape.EscapeID(name))) } lockStatement := fmt.Sprintf("LOCK TABLES %v", strings.Join(tableNames, ", ")) - _, err := conn.ExecuteFetch(fmt.Sprintf("USE %s", agent.DBConfigs.DBName.Get()), 0, false) + _, err := conn.ExecuteFetch(fmt.Sprintf("USE %s", agent.DBConfigs.DBName()), 0, false) if err != nil { return err } diff --git a/go/vt/vttablet/tabletserver/connpool/pool_test.go b/go/vt/vttablet/tabletserver/connpool/pool_test.go index 49e385c06ce..1f7bafa3977 100644 --- a/go/vt/vttablet/tabletserver/connpool/pool_test.go +++ b/go/vt/vttablet/tabletserver/connpool/pool_test.go @@ -86,7 +86,10 @@ func TestConnPoolMaxWaiters(t *testing.T) { go func() { defer wg.Done() c1, err := connPool.Get(context.Background()) - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } c1.Recycle() }() // Wait for the first waiter to increment count. diff --git a/go/vt/vttablet/tabletserver/query_executor_test.go b/go/vt/vttablet/tabletserver/query_executor_test.go index f3388f1a335..c58947bf23f 100644 --- a/go/vt/vttablet/tabletserver/query_executor_test.go +++ b/go/vt/vttablet/tabletserver/query_executor_test.go @@ -1121,16 +1121,16 @@ func getTestTableFields() []*querypb.Field { func getQueryExecutorSupportedQueries(testTableHasMultipleUniqueKeys bool) map[string]*sqltypes.Result { return map[string]*sqltypes.Result{ // queries for twopc - sqlTurnoffBinlog: {}, - fmt.Sprintf(sqlCreateSidecarDB, "`_vt`"): {}, - fmt.Sprintf(sqlDropLegacy1, "`_vt`"): {}, - fmt.Sprintf(sqlDropLegacy2, "`_vt`"): {}, - fmt.Sprintf(sqlDropLegacy3, "`_vt`"): {}, - fmt.Sprintf(sqlDropLegacy4, "`_vt`"): {}, - fmt.Sprintf(sqlCreateTableRedoState, "`_vt`"): {}, - fmt.Sprintf(sqlCreateTableRedoStatement, "`_vt`"): {}, - fmt.Sprintf(sqlCreateTableDTState, "`_vt`"): {}, - fmt.Sprintf(sqlCreateTableDTParticipant, "`_vt`"): {}, + sqlTurnoffBinlog: {}, + fmt.Sprintf(sqlCreateSidecarDB, "_vt"): {}, + fmt.Sprintf(sqlDropLegacy1, "_vt"): {}, + fmt.Sprintf(sqlDropLegacy2, "_vt"): {}, + fmt.Sprintf(sqlDropLegacy3, "_vt"): {}, + fmt.Sprintf(sqlDropLegacy4, "_vt"): {}, + fmt.Sprintf(sqlCreateTableRedoState, "_vt"): {}, + fmt.Sprintf(sqlCreateTableRedoStatement, "_vt"): {}, + fmt.Sprintf(sqlCreateTableDTState, "_vt"): {}, + fmt.Sprintf(sqlCreateTableDTParticipant, "_vt"): {}, // queries for schema info "select unix_timestamp()": { Fields: []*querypb.Field{{ @@ -1251,6 +1251,6 @@ func getQueryExecutorSupportedQueries(testTableHasMultipleUniqueKeys bool) map[s "begin": {}, "commit": {}, "rollback": {}, - fmt.Sprintf(sqlReadAllRedo, "`_vt`", "`_vt`"): {}, + fmt.Sprintf(sqlReadAllRedo, "_vt", "_vt"): {}, } } diff --git a/go/vt/vttablet/tabletserver/tabletenv/config.go b/go/vt/vttablet/tabletserver/tabletenv/config.go index 23573ee624d..189b795715a 100644 --- a/go/vt/vttablet/tabletserver/tabletenv/config.go +++ b/go/vt/vttablet/tabletserver/tabletenv/config.go @@ -26,6 +26,7 @@ import ( "vitess.io/vitess/go/flagutil" "vitess.io/vitess/go/streamlog" + "vitess.io/vitess/go/vt/dbconfigs" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/throttler" ) @@ -39,7 +40,9 @@ const ( ) var ( - currentConfig TabletConfig + currentConfig = TabletConfig{ + DB: &dbconfigs.GlobalDBConfigs, + } queryLogHandler = flag.String("query-log-stream-handler", "/debug/querylog", "URL handler for streaming queries log") txLogHandler = flag.String("transaction-log-stream-handler", "/debug/txlog", "URL handler for streaming transactions log") @@ -185,6 +188,8 @@ func Init() { // TabletConfig contains all the configuration for query service type TabletConfig struct { + DB *dbconfigs.DBConfigs `json:"db,omitempty"` + OltpReadPool ConnPoolConfig `json:"oltpReadPool,omitempty"` OlapReadPool ConnPoolConfig `json:"olapReadPool,omitempty"` TxPool ConnPoolConfig `json:"txPool,omitempty"` @@ -268,6 +273,9 @@ func NewDefaultConfig() *TabletConfig { // Clone creates a clone of TabletConfig. func (c *TabletConfig) Clone() *TabletConfig { tc := *c + if tc.DB != nil { + tc.DB = c.DB.Clone() + } return &tc } diff --git a/go/vt/vttablet/tabletserver/tabletenv/config_test.go b/go/vt/vttablet/tabletserver/tabletenv/config_test.go index 8dc17dc08f0..cdea4848209 100644 --- a/go/vt/vttablet/tabletserver/tabletenv/config_test.go +++ b/go/vt/vttablet/tabletserver/tabletenv/config_test.go @@ -21,11 +21,21 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "sigs.k8s.io/yaml" + "vitess.io/vitess/go/vt/dbconfigs" + "vitess.io/vitess/go/yaml2" ) func TestConfigParse(t *testing.T) { cfg := TabletConfig{ + DB: &dbconfigs.DBConfigs{ + Socket: "a", + App: dbconfigs.UserConfig{ + User: "b", + }, + Dba: dbconfigs.UserConfig{ + User: "c", + }, + }, OltpReadPool: ConnPoolConfig{ Size: 16, TimeoutSeconds: 10, @@ -34,9 +44,25 @@ func TestConfigParse(t *testing.T) { MaxWaiters: 40, }, } - gotBytes, err := yaml.Marshal(&cfg) + gotBytes, err := yaml2.Marshal(&cfg) require.NoError(t, err) - wantBytes := `hotRowProtection: {} + wantBytes := `db: + allprivs: + password: '****' + app: + password: '****' + user: b + appdebug: + password: '****' + dba: + password: '****' + user: c + filtered: + password: '****' + repl: + password: '****' + socket: a +hotRowProtection: {} olapReadPool: {} oltp: {} oltpReadPool: @@ -49,21 +75,31 @@ txPool: {} ` assert.Equal(t, wantBytes, string(gotBytes)) - // Make sure TimeoutSeconds doesn't get overwritten. - inBytes := []byte(`oltpReadPool: + // Make sure things already set don't get overwritten, + // and thing specified do overwrite. + // OltpReadPool.TimeoutSeconds should not get overwritten. + // DB.App.User should not get overwritten. + // DB.Dba.User should get overwritten. + inBytes := []byte(`db: + socket: a + dba: + user: c +oltpReadPool: size: 16 idleTimeoutSeconds: 20 prefillParallelism: 30 maxWaiters: 40 `) gotCfg := cfg - err = yaml.Unmarshal(inBytes, &gotCfg) + gotCfg.DB = cfg.DB.Clone() + gotCfg.DB.Dba = dbconfigs.UserConfig{} + err = yaml2.Unmarshal(inBytes, &gotCfg) require.NoError(t, err) assert.Equal(t, cfg, gotCfg) } func TestDefaultConfig(t *testing.T) { - gotBytes, err := yaml.Marshal(NewDefaultConfig()) + gotBytes, err := yaml2.Marshal(NewDefaultConfig()) require.NoError(t, err) want := `cacheResultFields: true consolidator: enable diff --git a/go/vt/vttablet/tabletserver/tabletserver_test.go b/go/vt/vttablet/tabletserver/tabletserver_test.go index c589c7522a0..af4155095d3 100644 --- a/go/vt/vttablet/tabletserver/tabletserver_test.go +++ b/go/vt/vttablet/tabletserver/tabletserver_test.go @@ -654,8 +654,8 @@ func TestTabletServerCreateTransaction(t *testing.T) { ctx := context.Background() target := querypb.Target{TabletType: topodatapb.TabletType_MASTER} - db.AddQueryPattern(fmt.Sprintf("insert into `_vt`\\.dt_state\\(dtid, state, time_created\\) values \\('aa', %d,.*", int(querypb.TransactionState_PREPARE)), &sqltypes.Result{}) - db.AddQueryPattern("insert into `_vt`\\.dt_participant\\(dtid, id, keyspace, shard\\) values \\('aa', 1,.*", &sqltypes.Result{}) + db.AddQueryPattern(fmt.Sprintf("insert into _vt\\.dt_state\\(dtid, state, time_created\\) values \\('aa', %d,.*", int(querypb.TransactionState_PREPARE)), &sqltypes.Result{}) + db.AddQueryPattern("insert into _vt\\.dt_participant\\(dtid, id, keyspace, shard\\) values \\('aa', 1,.*", &sqltypes.Result{}) err := tsv.CreateTransaction(ctx, &target, "aa", []*querypb.Target{{ Keyspace: "t1", Shard: "0", @@ -670,7 +670,7 @@ func TestTabletServerStartCommit(t *testing.T) { ctx := context.Background() target := querypb.Target{TabletType: topodatapb.TabletType_MASTER} - commitTransition := fmt.Sprintf("update `_vt`.dt_state set state = %d where dtid = 'aa' and state = %d", int(querypb.TransactionState_COMMIT), int(querypb.TransactionState_PREPARE)) + commitTransition := fmt.Sprintf("update _vt.dt_state set state = %d where dtid = 'aa' and state = %d", int(querypb.TransactionState_COMMIT), int(querypb.TransactionState_PREPARE)) db.AddQuery(commitTransition, &sqltypes.Result{RowsAffected: 1}) txid := newTxForPrep(tsv) err := tsv.StartCommit(ctx, &target, txid, "aa") @@ -692,7 +692,7 @@ func TestTabletserverSetRollback(t *testing.T) { ctx := context.Background() target := querypb.Target{TabletType: topodatapb.TabletType_MASTER} - rollbackTransition := fmt.Sprintf("update `_vt`.dt_state set state = %d where dtid = 'aa' and state = %d", int(querypb.TransactionState_ROLLBACK), int(querypb.TransactionState_PREPARE)) + rollbackTransition := fmt.Sprintf("update _vt.dt_state set state = %d where dtid = 'aa' and state = %d", int(querypb.TransactionState_ROLLBACK), int(querypb.TransactionState_PREPARE)) db.AddQuery(rollbackTransition, &sqltypes.Result{RowsAffected: 1}) txid := newTxForPrep(tsv) err := tsv.SetRollback(ctx, &target, "aa", txid) @@ -714,7 +714,7 @@ func TestTabletServerReadTransaction(t *testing.T) { ctx := context.Background() target := querypb.Target{TabletType: topodatapb.TabletType_MASTER} - db.AddQuery("select dtid, state, time_created from `_vt`.dt_state where dtid = 'aa'", &sqltypes.Result{}) + db.AddQuery("select dtid, state, time_created from _vt.dt_state where dtid = 'aa'", &sqltypes.Result{}) got, err := tsv.ReadTransaction(ctx, &target, "aa") require.NoError(t, err) want := &querypb.TransactionMetadata{} @@ -734,8 +734,8 @@ func TestTabletServerReadTransaction(t *testing.T) { sqltypes.NewVarBinary("1"), }}, } - db.AddQuery("select dtid, state, time_created from `_vt`.dt_state where dtid = 'aa'", txResult) - db.AddQuery("select keyspace, shard from `_vt`.dt_participant where dtid = 'aa'", &sqltypes.Result{ + db.AddQuery("select dtid, state, time_created from _vt.dt_state where dtid = 'aa'", txResult) + db.AddQuery("select keyspace, shard from _vt.dt_participant where dtid = 'aa'", &sqltypes.Result{ Fields: []*querypb.Field{ {Type: sqltypes.VarBinary}, {Type: sqltypes.VarBinary}, @@ -780,7 +780,7 @@ func TestTabletServerReadTransaction(t *testing.T) { sqltypes.NewVarBinary("1"), }}, } - db.AddQuery("select dtid, state, time_created from `_vt`.dt_state where dtid = 'aa'", txResult) + db.AddQuery("select dtid, state, time_created from _vt.dt_state where dtid = 'aa'", txResult) want.State = querypb.TransactionState_COMMIT got, err = tsv.ReadTransaction(ctx, &target, "aa") require.NoError(t, err) @@ -800,7 +800,7 @@ func TestTabletServerReadTransaction(t *testing.T) { sqltypes.NewVarBinary("1"), }}, } - db.AddQuery("select dtid, state, time_created from `_vt`.dt_state where dtid = 'aa'", txResult) + db.AddQuery("select dtid, state, time_created from _vt.dt_state where dtid = 'aa'", txResult) want.State = querypb.TransactionState_ROLLBACK got, err = tsv.ReadTransaction(ctx, &target, "aa") require.NoError(t, err) @@ -816,8 +816,8 @@ func TestTabletServerConcludeTransaction(t *testing.T) { ctx := context.Background() target := querypb.Target{TabletType: topodatapb.TabletType_MASTER} - db.AddQuery("delete from `_vt`.dt_state where dtid = 'aa'", &sqltypes.Result{}) - db.AddQuery("delete from `_vt`.dt_participant where dtid = 'aa'", &sqltypes.Result{}) + db.AddQuery("delete from _vt.dt_state where dtid = 'aa'", &sqltypes.Result{}) + db.AddQuery("delete from _vt.dt_participant where dtid = 'aa'", &sqltypes.Result{}) err := tsv.ConcludeTransaction(ctx, &target, "aa") require.NoError(t, err) } @@ -2551,16 +2551,16 @@ func getSupportedQueries() map[string]*sqltypes.Result { RowsAffected: 1, }, // queries for twopc - sqlTurnoffBinlog: {}, - fmt.Sprintf(sqlCreateSidecarDB, "`_vt`"): {}, - fmt.Sprintf(sqlDropLegacy1, "`_vt`"): {}, - fmt.Sprintf(sqlDropLegacy2, "`_vt`"): {}, - fmt.Sprintf(sqlDropLegacy3, "`_vt`"): {}, - fmt.Sprintf(sqlDropLegacy4, "`_vt`"): {}, - fmt.Sprintf(sqlCreateTableRedoState, "`_vt`"): {}, - fmt.Sprintf(sqlCreateTableRedoStatement, "`_vt`"): {}, - fmt.Sprintf(sqlCreateTableDTState, "`_vt`"): {}, - fmt.Sprintf(sqlCreateTableDTParticipant, "`_vt`"): {}, + sqlTurnoffBinlog: {}, + fmt.Sprintf(sqlCreateSidecarDB, "_vt"): {}, + fmt.Sprintf(sqlDropLegacy1, "_vt"): {}, + fmt.Sprintf(sqlDropLegacy2, "_vt"): {}, + fmt.Sprintf(sqlDropLegacy3, "_vt"): {}, + fmt.Sprintf(sqlDropLegacy4, "_vt"): {}, + fmt.Sprintf(sqlCreateTableRedoState, "_vt"): {}, + fmt.Sprintf(sqlCreateTableRedoStatement, "_vt"): {}, + fmt.Sprintf(sqlCreateTableDTState, "_vt"): {}, + fmt.Sprintf(sqlCreateTableDTParticipant, "_vt"): {}, // queries for schema info "select unix_timestamp()": { Fields: []*querypb.Field{{ @@ -2647,7 +2647,7 @@ func getSupportedQueries() map[string]*sqltypes.Result { "begin": {}, "commit": {}, "rollback": {}, - fmt.Sprintf(sqlReadAllRedo, "`_vt`", "`_vt`"): {}, + fmt.Sprintf(sqlReadAllRedo, "_vt", "_vt"): {}, } } diff --git a/go/vt/vttablet/tabletserver/twopc.go b/go/vt/vttablet/tabletserver/twopc.go index a9932c8cfe9..4c365afb5d6 100644 --- a/go/vt/vttablet/tabletserver/twopc.go +++ b/go/vt/vttablet/tabletserver/twopc.go @@ -24,7 +24,6 @@ import ( "golang.org/x/net/context" - "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/dbconfigs" "vitess.io/vitess/go/vt/dbconnpool" @@ -127,8 +126,8 @@ func NewTwoPC(readPool *connpool.Pool) *TwoPC { // Init initializes TwoPC. If the metadata database or tables // are not present, they are created. -func (tpc *TwoPC) Init(sidecarDBName string, dbaparams dbconfigs.Connector) error { - dbname := sqlescape.EscapeID(sidecarDBName) +func (tpc *TwoPC) Init(dbaparams dbconfigs.Connector) error { + dbname := "_vt" conn, err := dbconnpool.NewDBConnection(context.TODO(), dbaparams) if err != nil { return err diff --git a/go/vt/vttablet/tabletserver/tx_engine.go b/go/vt/vttablet/tabletserver/tx_engine.go index dca7647e155..30e14a2cefb 100644 --- a/go/vt/vttablet/tabletserver/tx_engine.go +++ b/go/vt/vttablet/tabletserver/tx_engine.go @@ -340,7 +340,7 @@ func (te *TxEngine) transitionTo(nextState txEngineState) error { // up the metadata tables. func (te *TxEngine) Init() error { if te.twopcEnabled { - return te.twoPC.Init(te.env.DBConfigs().SidecarDBName.Get(), te.env.DBConfigs().DbaWithDB()) + return te.twoPC.Init(te.env.DBConfigs().DbaWithDB()) } return nil } diff --git a/go/vt/vttablet/tabletserver/tx_executor_test.go b/go/vt/vttablet/tabletserver/tx_executor_test.go index bc4bcb64b0a..2e5f775d396 100644 --- a/go/vt/vttablet/tabletserver/tx_executor_test.go +++ b/go/vt/vttablet/tabletserver/tx_executor_test.go @@ -156,11 +156,11 @@ func TestTxExecutorCommitRedoFail(t *testing.T) { defer tsv.StopService() txid := newTxForPrep(tsv) // Allow all additions to redo logs to succeed - db.AddQueryPattern("insert into `_vt`\\.redo_state.*", &sqltypes.Result{}) + db.AddQueryPattern("insert into _vt\\.redo_state.*", &sqltypes.Result{}) err := txe.Prepare(txid, "bb") require.NoError(t, err) defer txe.RollbackPrepared("bb", 0) - db.AddQuery("update `_vt`.redo_state set state = 'Failed' where dtid = 'bb'", &sqltypes.Result{}) + db.AddQuery("update _vt.redo_state set state = 'Failed' where dtid = 'bb'", &sqltypes.Result{}) err = txe.CommitPrepared("bb") want := "is not supported" if err == nil || !strings.Contains(err.Error(), want) { @@ -211,7 +211,7 @@ func TestTxExecutorRollbackRedoFail(t *testing.T) { defer tsv.StopService() txid := newTxForPrep(tsv) // Allow all additions to redo logs to succeed - db.AddQueryPattern("insert into `_vt`\\.redo_state.*", &sqltypes.Result{}) + db.AddQueryPattern("insert into _vt\\.redo_state.*", &sqltypes.Result{}) err := txe.Prepare(txid, "bb") require.NoError(t, err) err = txe.RollbackPrepared("bb", txid) @@ -226,8 +226,8 @@ func TestExecutorCreateTransaction(t *testing.T) { defer db.Close() defer tsv.StopService() - db.AddQueryPattern(fmt.Sprintf("insert into `_vt`\\.dt_state\\(dtid, state, time_created\\) values \\('aa', %d,.*", int(querypb.TransactionState_PREPARE)), &sqltypes.Result{}) - db.AddQueryPattern("insert into `_vt`\\.dt_participant\\(dtid, id, keyspace, shard\\) values \\('aa', 1,.*", &sqltypes.Result{}) + db.AddQueryPattern(fmt.Sprintf("insert into _vt\\.dt_state\\(dtid, state, time_created\\) values \\('aa', %d,.*", int(querypb.TransactionState_PREPARE)), &sqltypes.Result{}) + db.AddQueryPattern("insert into _vt\\.dt_participant\\(dtid, id, keyspace, shard\\) values \\('aa', 1,.*", &sqltypes.Result{}) err := txe.CreateTransaction("aa", []*querypb.Target{{ Keyspace: "t1", Shard: "0", @@ -240,7 +240,7 @@ func TestExecutorStartCommit(t *testing.T) { defer db.Close() defer tsv.StopService() - commitTransition := fmt.Sprintf("update `_vt`.dt_state set state = %d where dtid = 'aa' and state = %d", int(querypb.TransactionState_COMMIT), int(querypb.TransactionState_PREPARE)) + commitTransition := fmt.Sprintf("update _vt.dt_state set state = %d where dtid = 'aa' and state = %d", int(querypb.TransactionState_COMMIT), int(querypb.TransactionState_PREPARE)) db.AddQuery(commitTransition, &sqltypes.Result{RowsAffected: 1}) txid := newTxForPrep(tsv) err := txe.StartCommit(txid, "aa") @@ -260,7 +260,7 @@ func TestExecutorSetRollback(t *testing.T) { defer db.Close() defer tsv.StopService() - rollbackTransition := fmt.Sprintf("update `_vt`.dt_state set state = %d where dtid = 'aa' and state = %d", int(querypb.TransactionState_ROLLBACK), int(querypb.TransactionState_PREPARE)) + rollbackTransition := fmt.Sprintf("update _vt.dt_state set state = %d where dtid = 'aa' and state = %d", int(querypb.TransactionState_ROLLBACK), int(querypb.TransactionState_PREPARE)) db.AddQuery(rollbackTransition, &sqltypes.Result{RowsAffected: 1}) txid := newTxForPrep(tsv) err := txe.SetRollback("aa", txid) @@ -280,8 +280,8 @@ func TestExecutorConcludeTransaction(t *testing.T) { defer db.Close() defer tsv.StopService() - db.AddQuery("delete from `_vt`.dt_state where dtid = 'aa'", &sqltypes.Result{}) - db.AddQuery("delete from `_vt`.dt_participant where dtid = 'aa'", &sqltypes.Result{}) + db.AddQuery("delete from _vt.dt_state where dtid = 'aa'", &sqltypes.Result{}) + db.AddQuery("delete from _vt.dt_participant where dtid = 'aa'", &sqltypes.Result{}) err := txe.ConcludeTransaction("aa") require.NoError(t, err) } @@ -291,7 +291,7 @@ func TestExecutorReadTransaction(t *testing.T) { defer db.Close() defer tsv.StopService() - db.AddQuery("select dtid, state, time_created from `_vt`.dt_state where dtid = 'aa'", &sqltypes.Result{}) + db.AddQuery("select dtid, state, time_created from _vt.dt_state where dtid = 'aa'", &sqltypes.Result{}) got, err := txe.ReadTransaction("aa") require.NoError(t, err) want := &querypb.TransactionMetadata{} @@ -311,8 +311,8 @@ func TestExecutorReadTransaction(t *testing.T) { sqltypes.NewVarBinary("1"), }}, } - db.AddQuery("select dtid, state, time_created from `_vt`.dt_state where dtid = 'aa'", txResult) - db.AddQuery("select keyspace, shard from `_vt`.dt_participant where dtid = 'aa'", &sqltypes.Result{ + db.AddQuery("select dtid, state, time_created from _vt.dt_state where dtid = 'aa'", txResult) + db.AddQuery("select keyspace, shard from _vt.dt_participant where dtid = 'aa'", &sqltypes.Result{ Fields: []*querypb.Field{ {Type: sqltypes.VarChar}, {Type: sqltypes.VarChar}, @@ -357,7 +357,7 @@ func TestExecutorReadTransaction(t *testing.T) { sqltypes.NewVarBinary("1"), }}, } - db.AddQuery("select dtid, state, time_created from `_vt`.dt_state where dtid = 'aa'", txResult) + db.AddQuery("select dtid, state, time_created from _vt.dt_state where dtid = 'aa'", txResult) want.State = querypb.TransactionState_COMMIT got, err = txe.ReadTransaction("aa") require.NoError(t, err) @@ -377,7 +377,7 @@ func TestExecutorReadTransaction(t *testing.T) { sqltypes.NewVarBinary("1"), }}, } - db.AddQuery("select dtid, state, time_created from `_vt`.dt_state where dtid = 'aa'", txResult) + db.AddQuery("select dtid, state, time_created from _vt.dt_state where dtid = 'aa'", txResult) want.State = querypb.TransactionState_ROLLBACK got, err = txe.ReadTransaction("aa") require.NoError(t, err) @@ -451,7 +451,7 @@ func TestExecutorResolveTransaction(t *testing.T) { defer tsv.StopService() want := "aa" db.AddQueryPattern( - "select dtid, time_created from `_vt`\\.dt_state where time_created.*", + "select dtid, time_created from _vt\\.dt_state where time_created.*", &sqltypes.Result{ Fields: []*querypb.Field{ {Type: sqltypes.VarChar}, @@ -525,10 +525,10 @@ func newTestTxExecutor(t *testing.T) (txe *TxExecutor, tsv *TabletServer, db *fa ctx := context.Background() logStats := tabletenv.NewLogStats(ctx, "TestTxExecutor") tsv = newTestTabletServer(ctx, smallTxPool, db) - db.AddQueryPattern("insert into `_vt`\\.redo_state\\(dtid, state, time_created\\) values \\('aa', 1,.*", &sqltypes.Result{}) - db.AddQueryPattern("insert into `_vt`\\.redo_statement.*", &sqltypes.Result{}) - db.AddQuery("delete from `_vt`.redo_state where dtid = 'aa'", &sqltypes.Result{}) - db.AddQuery("delete from `_vt`.redo_statement where dtid = 'aa'", &sqltypes.Result{}) + db.AddQueryPattern("insert into _vt\\.redo_state\\(dtid, state, time_created\\) values \\('aa', 1,.*", &sqltypes.Result{}) + db.AddQueryPattern("insert into _vt\\.redo_statement.*", &sqltypes.Result{}) + db.AddQuery("delete from _vt.redo_state where dtid = 'aa'", &sqltypes.Result{}) + db.AddQuery("delete from _vt.redo_statement where dtid = 'aa'", &sqltypes.Result{}) db.AddQuery("update test_table set name = 2 where pk = 1 limit 10001", &sqltypes.Result{}) return &TxExecutor{ ctx: ctx, @@ -543,10 +543,10 @@ func newShortAgeExecutor(t *testing.T) (txe *TxExecutor, tsv *TabletServer, db * ctx := context.Background() logStats := tabletenv.NewLogStats(ctx, "TestTxExecutor") tsv = newTestTabletServer(ctx, smallTxPool|shortTwopcAge, db) - db.AddQueryPattern("insert into `_vt`\\.redo_state\\(dtid, state, time_created\\) values \\('aa', 1,.*", &sqltypes.Result{}) - db.AddQueryPattern("insert into `_vt`\\.redo_statement.*", &sqltypes.Result{}) - db.AddQuery("delete from `_vt`.redo_state where dtid = 'aa'", &sqltypes.Result{}) - db.AddQuery("delete from `_vt`.redo_statement where dtid = 'aa'", &sqltypes.Result{}) + db.AddQueryPattern("insert into _vt\\.redo_state\\(dtid, state, time_created\\) values \\('aa', 1,.*", &sqltypes.Result{}) + db.AddQueryPattern("insert into _vt\\.redo_statement.*", &sqltypes.Result{}) + db.AddQuery("delete from _vt.redo_state where dtid = 'aa'", &sqltypes.Result{}) + db.AddQuery("delete from _vt.redo_statement where dtid = 'aa'", &sqltypes.Result{}) db.AddQuery("update test_table set name = 2 where pk = 1 limit 10001", &sqltypes.Result{}) return &TxExecutor{ ctx: ctx, diff --git a/go/yaml2/yaml.go b/go/yaml2/yaml.go new file mode 100644 index 00000000000..b1a8342feff --- /dev/null +++ b/go/yaml2/yaml.go @@ -0,0 +1,28 @@ +/* +Copyright 2020 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package yaml2 ensures that the right yaml package gets imported. +// The default package that goimports adds is not the one we want to use. +package yaml2 + +import "sigs.k8s.io/yaml" + +var ( + // Marshal marshals to YAML. + Marshal = yaml.Marshal + // Unmarshal unmarshals from YAML. + Unmarshal = yaml.Unmarshal +)