Skip to content

Commit 83f0cc1

Browse files
Consolidate fetching of MySQL server info
Signed-off-by: Tim Vaillancourt <[email protected]>
1 parent 59fd18d commit 83f0cc1

13 files changed

+215
-130
lines changed

go/base/context.go

+3-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2022 GitHub Inc.
2+
Copyright 2023 GitHub Inc.
33
See https://github.com/github/gh-ost/blob/master/LICENSE
44
*/
55

@@ -163,18 +163,15 @@ type MigrationContext struct {
163163

164164
Hostname string
165165
AssumeMasterHostname string
166-
ApplierTimeZone string
167166
TableEngine string
168167
RowsEstimate int64
169168
RowsDeltaEstimate int64
170169
UsedRowsEstimateMethod RowsEstimateMethod
171170
HasSuperPrivilege bool
172-
OriginalBinlogFormat string
173-
OriginalBinlogRowImage string
174171
InspectorConnectionConfig *mysql.ConnectionConfig
175-
InspectorMySQLVersion string
172+
InspectorServerInfo *mysql.ServerInfo
176173
ApplierConnectionConfig *mysql.ConnectionConfig
177-
ApplierMySQLVersion string
174+
ApplierServerInfo *mysql.ServerInfo
178175
StartTime time.Time
179176
RowCopyStartTime time.Time
180177
RowCopyEndTime time.Time
@@ -359,11 +356,6 @@ func (this *MigrationContext) GetVoluntaryLockName() string {
359356
return fmt.Sprintf("%s.%s.lock", this.DatabaseName, this.OriginalTableName)
360357
}
361358

362-
// RequiresBinlogFormatChange is `true` when the original binlog format isn't `ROW`
363-
func (this *MigrationContext) RequiresBinlogFormatChange() bool {
364-
return this.OriginalBinlogFormat != "ROW"
365-
}
366-
367359
// GetApplierHostname is a safe access method to the applier hostname
368360
func (this *MigrationContext) GetApplierHostname() string {
369361
if this.ApplierConnectionConfig == nil {

go/base/utils.go

+13-28
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2022 GitHub Inc.
2+
Copyright 2023 GitHub Inc.
33
See https://github.com/github/gh-ost/blob/master/LICENSE
44
*/
55

@@ -12,8 +12,6 @@ import (
1212
"strings"
1313
"time"
1414

15-
gosql "database/sql"
16-
1715
"github.com/github/gh-ost/go/mysql"
1816
)
1917

@@ -61,35 +59,22 @@ func StringContainsAll(s string, substrings ...string) bool {
6159
return nonEmptyStringsFound
6260
}
6361

64-
func ValidateConnection(db *gosql.DB, connectionConfig *mysql.ConnectionConfig, migrationContext *MigrationContext, name string) (string, error) {
65-
versionQuery := `select @@global.version`
66-
var port, extraPort int
67-
var version string
68-
if err := db.QueryRow(versionQuery).Scan(&version); err != nil {
69-
return "", err
70-
}
71-
extraPortQuery := `select @@global.extra_port`
72-
if err := db.QueryRow(extraPortQuery).Scan(&extraPort); err != nil { //nolint:staticcheck
73-
// swallow this error. not all servers support extra_port
74-
}
62+
// ValidateConnection confirms the database server info matches the provided connection config.
63+
func ValidateConnection(serverInfo *mysql.ServerInfo, connectionConfig *mysql.ConnectionConfig, migrationContext *MigrationContext, name string) error {
7564
// AliyunRDS set users port to "NULL", replace it by gh-ost param
7665
// GCP set users port to "NULL", replace it by gh-ost param
77-
// Azure MySQL set users port to a different value by design, replace it by gh-ost para
66+
// Azure MySQL set users port to a different value by design, replace it by gh-ost param
7867
if migrationContext.AliyunRDS || migrationContext.GoogleCloudPlatform || migrationContext.AzureMySQL {
79-
port = connectionConfig.Key.Port
80-
} else {
81-
portQuery := `select @@global.port`
82-
if err := db.QueryRow(portQuery).Scan(&port); err != nil {
83-
return "", err
84-
}
68+
serverInfo.Port.Int64 = connectionConfig.Key.Port
69+
serverInfo.Port.Valid = connectionConfig.Key.Port > 0
8570
}
8671

87-
if connectionConfig.Key.Port == port || (extraPort > 0 && connectionConfig.Key.Port == extraPort) {
88-
migrationContext.Log.Infof("%s connection validated on %+v", name, connectionConfig.Key)
89-
return version, nil
90-
} else if extraPort == 0 {
91-
return "", fmt.Errorf("Unexpected database port reported: %+v", port)
92-
} else {
93-
return "", fmt.Errorf("Unexpected database port reported: %+v / extra_port: %+v", port, extraPort)
72+
if !serverInfo.Port.Valid && !serverInfo.ExtraPort.Valid {
73+
return fmt.Errorf("Unexpected database port reported: %+v", serverInfo.Port.Int64)
74+
} else if connectionConfig.Key.Port != serverInfo.Port.Int64 && connectionConfig.Key.Port != serverInfo.ExtraPort.Int64 {
75+
return fmt.Errorf("Unexpected database port reported: %+v / extra_port: %+v", serverInfo.Port.Int64, serverInfo.ExtraPort.Int64)
9476
}
77+
78+
migrationContext.Log.Infof("%s connection validated on %+v", name, connectionConfig.Key)
79+
return nil
9580
}

go/base/utils_test.go

+84-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
/*
2-
Copyright 2016 GitHub Inc.
2+
Copyright 2023 GitHub Inc.
33
See https://github.com/github/gh-ost/blob/master/LICENSE
44
*/
55

66
package base
77

88
import (
9+
gosql "database/sql"
910
"testing"
1011

12+
"github.com/github/gh-ost/go/mysql"
1113
"github.com/openark/golib/log"
1214
test "github.com/openark/golib/tests"
1315
)
@@ -16,6 +18,10 @@ func init() {
1618
log.SetLevel(log.ERROR)
1719
}
1820

21+
func newMysqlPort(port int64) gosql.NullInt64 {
22+
return gosql.NullInt64{Int64: port, Valid: port > 0}
23+
}
24+
1925
func TestStringContainsAll(t *testing.T) {
2026
s := `insert,delete,update`
2127

@@ -27,3 +33,80 @@ func TestStringContainsAll(t *testing.T) {
2733
test.S(t).ExpectTrue(StringContainsAll(s, "insert", ""))
2834
test.S(t).ExpectTrue(StringContainsAll(s, "insert", "update", "delete"))
2935
}
36+
37+
func TestValidateConnection(t *testing.T) {
38+
connectionConfig := &mysql.ConnectionConfig{
39+
Key: mysql.InstanceKey{
40+
Hostname: t.Name(),
41+
Port: mysql.DefaultInstancePort,
42+
},
43+
}
44+
45+
// check valid port matching connectionConfig validates
46+
{
47+
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
48+
serverInfo := &mysql.ServerInfo{
49+
Port: newMysqlPort(mysql.DefaultInstancePort),
50+
ExtraPort: newMysqlPort(mysql.DefaultInstancePort + 1),
51+
}
52+
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
53+
}
54+
// check NULL port validates when AliyunRDS=true
55+
{
56+
migrationContext := &MigrationContext{
57+
Log: NewDefaultLogger(),
58+
AliyunRDS: true,
59+
}
60+
serverInfo := &mysql.ServerInfo{}
61+
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
62+
}
63+
// check NULL port validates when AzureMySQL=true
64+
{
65+
migrationContext := &MigrationContext{
66+
Log: NewDefaultLogger(),
67+
AzureMySQL: true,
68+
}
69+
serverInfo := &mysql.ServerInfo{}
70+
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
71+
}
72+
// check NULL port validates when GoogleCloudPlatform=true
73+
{
74+
migrationContext := &MigrationContext{
75+
Log: NewDefaultLogger(),
76+
GoogleCloudPlatform: true,
77+
}
78+
serverInfo := &mysql.ServerInfo{}
79+
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
80+
}
81+
// check extra_port validates when port=NULL
82+
{
83+
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
84+
serverInfo := &mysql.ServerInfo{
85+
ExtraPort: newMysqlPort(mysql.DefaultInstancePort),
86+
}
87+
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
88+
}
89+
// check extra_port validates when port does not match but extra_port does
90+
{
91+
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
92+
serverInfo := &mysql.ServerInfo{
93+
Port: newMysqlPort(12345),
94+
ExtraPort: newMysqlPort(mysql.DefaultInstancePort),
95+
}
96+
test.S(t).ExpectNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
97+
}
98+
// check validation fails when valid port does not match connectionConfig
99+
{
100+
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
101+
serverInfo := &mysql.ServerInfo{
102+
Port: newMysqlPort(9999),
103+
}
104+
test.S(t).ExpectNotNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
105+
}
106+
// check validation fails when port and extra_port are invalid
107+
{
108+
migrationContext := &MigrationContext{Log: NewDefaultLogger()}
109+
serverInfo := &mysql.ServerInfo{}
110+
test.S(t).ExpectNotNil(ValidateConnection(serverInfo, connectionConfig, migrationContext, "test"))
111+
}
112+
}

go/cmd/gh-ost/main.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2022 GitHub Inc.
2+
Copyright 2023 GitHub Inc.
33
See https://github.com/github/gh-ost/blob/master/LICENSE
44
*/
55

@@ -49,7 +49,7 @@ func main() {
4949
migrationContext := base.NewMigrationContext()
5050
flag.StringVar(&migrationContext.InspectorConnectionConfig.Key.Hostname, "host", "127.0.0.1", "MySQL hostname (preferably a replica, not the master)")
5151
flag.StringVar(&migrationContext.AssumeMasterHostname, "assume-master-host", "", "(optional) explicitly tell gh-ost the identity of the master. Format: some.host.com[:port] This is useful in master-master setups where you wish to pick an explicit master, or in a tungsten-replicator where gh-ost is unable to determine the master")
52-
flag.IntVar(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)")
52+
flag.Int64Var(&migrationContext.InspectorConnectionConfig.Key.Port, "port", 3306, "MySQL port (preferably a replica, not the master)")
5353
flag.Float64Var(&migrationContext.InspectorConnectionConfig.Timeout, "mysql-timeout", 0.0, "Connect, read and write timeout for MySQL")
5454
flag.StringVar(&migrationContext.CliUser, "user", "", "MySQL user")
5555
flag.StringVar(&migrationContext.CliPassword, "password", "", "MySQL password")

go/logic/applier.go

+14-25
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright 2022 GitHub Inc.
2+
Copyright 2023 GitHub Inc.
33
See https://github.com/github/gh-ost/blob/master/LICENSE
44
*/
55

@@ -71,25 +71,24 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier {
7171
}
7272
}
7373

74+
func (this *Applier) ServerInfo() *mysql.ServerInfo {
75+
return this.migrationContext.ApplierServerInfo
76+
}
77+
7478
func (this *Applier) InitDBConnections() (err error) {
7579
applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName)
7680
if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil {
7781
return err
7882
}
83+
if this.migrationContext.ApplierServerInfo, err = mysql.GetServerInfo(this.db); err != nil {
84+
return err
85+
}
7986
singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri)
8087
if this.singletonDB, _, err = mysql.GetDB(this.migrationContext.Uuid, singletonApplierUri); err != nil {
8188
return err
8289
}
8390
this.singletonDB.SetMaxOpenConns(1)
84-
version, err := base.ValidateConnection(this.db, this.connectionConfig, this.migrationContext, this.name)
85-
if err != nil {
86-
return err
87-
}
88-
if _, err := base.ValidateConnection(this.singletonDB, this.connectionConfig, this.migrationContext, this.name); err != nil {
89-
return err
90-
}
91-
this.migrationContext.ApplierMySQLVersion = version
92-
if err := this.validateAndReadTimeZone(); err != nil {
91+
if err = base.ValidateConnection(this.ServerInfo(), this.connectionConfig, this.migrationContext, this.name); err != nil {
9392
return err
9493
}
9594
if !this.migrationContext.AliyunRDS && !this.migrationContext.GoogleCloudPlatform && !this.migrationContext.AzureMySQL {
@@ -102,18 +101,8 @@ func (this *Applier) InitDBConnections() (err error) {
102101
if err := this.readTableColumns(); err != nil {
103102
return err
104103
}
105-
this.migrationContext.Log.Infof("Applier initiated on %+v, version %+v", this.connectionConfig.ImpliedKey, this.migrationContext.ApplierMySQLVersion)
106-
return nil
107-
}
108-
109-
// validateAndReadTimeZone potentially reads server time-zone
110-
func (this *Applier) validateAndReadTimeZone() error {
111-
query := `select /* gh-ost */ @@global.time_zone`
112-
if err := this.db.QueryRow(query).Scan(&this.migrationContext.ApplierTimeZone); err != nil {
113-
return err
114-
}
115-
116-
this.migrationContext.Log.Infof("will use time_zone='%s' on applier", this.migrationContext.ApplierTimeZone)
104+
this.migrationContext.Log.Infof("Applier initiated on %+v, version %+v (%+v)", this.connectionConfig.ImpliedKey,
105+
this.ServerInfo().Version, this.ServerInfo().VersionComment)
117106
return nil
118107
}
119108

@@ -238,7 +227,7 @@ func (this *Applier) CreateGhostTable() error {
238227
}
239228
defer tx.Rollback()
240229

241-
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone)
230+
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.ServerInfo().TimeZone)
242231
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())
243232

244233
if _, err := tx.Exec(sessionQuery); err != nil {
@@ -279,7 +268,7 @@ func (this *Applier) AlterGhost() error {
279268
}
280269
defer tx.Rollback()
281270

282-
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone)
271+
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.ServerInfo().TimeZone)
283272
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())
284273

285274
if _, err := tx.Exec(sessionQuery); err != nil {
@@ -640,7 +629,7 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected
640629
}
641630
defer tx.Rollback()
642631

643-
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.migrationContext.ApplierTimeZone)
632+
sessionQuery := fmt.Sprintf(`SET SESSION time_zone = '%s'`, this.ServerInfo().TimeZone)
644633
sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery())
645634

646635
if _, err := tx.Exec(sessionQuery); err != nil {

0 commit comments

Comments
 (0)